Compare commits

..

7 Commits

Author SHA1 Message Date
joachim-danswer
013bed3157 fix 2025-06-30 15:19:42 -07:00
joachim-danswer
289f27c43a updates 2025-06-30 15:06:12 -07:00
joachim-danswer
736a9bd332 erase history 2025-06-30 09:01:23 -07:00
joachim-danswer
8bcad415bb nit 2025-06-30 08:16:43 -07:00
joachim-danswer
93e6e4a089 mypy nits 2025-06-30 07:49:55 -07:00
joachim-danswer
ed0062dce0 fix 2025-06-30 02:45:03 -07:00
joachim-danswer
6e8bf3120c hackathon v1 changes 2025-06-30 01:39:36 -07:00
365 changed files with 6656 additions and 17335 deletions

View File

@@ -13,14 +13,6 @@ env:
# MinIO
S3_ENDPOINT_URL: "http://localhost:9004"
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
discover-test-dirs:
runs-on: ubuntu-latest

View File

@@ -1,38 +0,0 @@
name: PR Labeler
on:
pull_request_target:
branches:
- main
types:
- opened
- reopened
- synchronize
- edited
permissions:
contents: read
pull-requests: write
jobs:
validate_pr_title:
runs-on: ubuntu-latest
steps:
- name: Check PR title for Conventional Commits
env:
PR_TITLE: ${{ github.event.pull_request.title }}
run: |
echo "PR Title: $PR_TITLE"
if [[ ! "$PR_TITLE" =~ ^(feat|fix|docs|test|ci|refactor|perf|chore|revert|build)(\(.+\))?:\ .+ ]]; then
echo "::error::❌ Your PR title does not follow the Conventional Commits format.
This check ensures that all pull requests use clear, consistent titles that help automate changelogs and improve project history.
Please update your PR title to follow the Conventional Commits style.
Here is a link to a blog explaining the reason why we've included the Conventional Commits style into our PR titles: https://xfuture-blog.com/working-with-conventional-commits
**Here are some examples of valid PR titles:**
- feat: add user authentication
- fix(login): handle null password error
- docs(readme): update installation instructions"
exit 1
fi

View File

@@ -47,7 +47,7 @@ jobs:
-i /local/openapi.json \
-g python \
-o /local/onyx_openapi_client \
--package-name onyx_openapi_client \
--package-name onyx_openapi_client
- name: Run MyPy
run: |

View File

@@ -16,8 +16,8 @@ env:
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_IS_CLOUD: ${{ secrets.CONFLUENCE_IS_CLOUD }}
CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
@@ -53,12 +53,6 @@ env:
# Hubspot
HUBSPOT_ACCESS_TOKEN: ${{ secrets.HUBSPOT_ACCESS_TOKEN }}
# IMAP
IMAP_HOST: ${{ secrets.IMAP_HOST }}
IMAP_USERNAME: ${{ secrets.IMAP_USERNAME }}
IMAP_PASSWORD: ${{ secrets.IMAP_PASSWORD }}
IMAP_MAILBOXES: ${{ secrets.IMAP_MAILBOXES }}
# Airtable
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}

View File

@@ -45,9 +45,8 @@ PYTHONPATH=../backend
PYTHONUNBUFFERED=1
# Internet Search
# Internet Search
BING_API_KEY=<REPLACE THIS>
EXA_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features

View File

@@ -24,8 +24,8 @@
"Celery primary",
"Celery light",
"Celery heavy",
"Celery docfetching",
"Celery docprocessing",
"Celery indexing",
"Celery user files indexing",
"Celery beat",
"Celery monitoring"
],
@@ -46,8 +46,8 @@
"Celery primary",
"Celery light",
"Celery heavy",
"Celery docfetching",
"Celery docprocessing",
"Celery indexing",
"Celery user files indexing",
"Celery beat",
"Celery monitoring"
],
@@ -226,66 +226,35 @@
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery docfetching",
"name": "Celery indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.docfetching",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docfetching@%n",
"-Q",
"connector_doc_fetching,user_files_indexing"
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=indexing@%n",
"-Q",
"connector_indexing"
],
"presentation": {
"group": "2"
"group": "2"
},
"consoleTitle": "Celery docfetching Console",
"justMyCode": false
},
{
"name": "Celery docprocessing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.docprocessing",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docprocessing@%n",
"-Q",
"docprocessing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery docprocessing Console",
"justMyCode": false
},
"consoleTitle": "Celery indexing Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
@@ -334,6 +303,35 @@
},
"consoleTitle": "Celery beat Console"
},
{
"name": "Celery user files indexing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_files_indexing@%n",
"-Q",
"user_files_indexing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user files indexing Console"
},
{
"name": "Pytest",
"consoleName": "Pytest",
@@ -428,7 +426,7 @@
},
"args": [
"--filename",
"generated/openapi.json"
"generated/openapi.json",
]
},
{

View File

@@ -59,7 +59,6 @@ Onyx being a fully functional app, relies on some external software, specificall
- [Postgres](https://www.postgresql.org/) (Relational DB)
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
- [Redis](https://redis.io/) (Cache)
- [MinIO](https://min.io/) (File Store)
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
> **Note:**
@@ -172,10 +171,10 @@ Otherwise, you can follow the instructions below to run the application for deve
You will need Docker installed to run these containers.
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis with:
```bash
docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db cache minio
docker compose -f docker-compose.dev.yml -p onyx-stack up -d index relational_db cache
```
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)

View File

@@ -23,7 +23,7 @@ from sqlalchemy.sql.schema import SchemaItem
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import (
MULTI_TENANT,
POSTGRES_DEFAULT_SCHEMA,
POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE,
TENANT_ID_PREFIX,
)
from onyx.db.models import Base
@@ -271,7 +271,7 @@ async def run_async_migrations() -> None:
) = get_schema_options()
if not schemas and not MULTI_TENANT:
schemas = [POSTGRES_DEFAULT_SCHEMA]
schemas = [POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE]
# without init_engine, subsequent engine calls fail hard intentionally
SqlEngine.init_engine(pool_size=20, max_overflow=5)

View File

@@ -1,72 +0,0 @@
"""add federated connector tables
Revision ID: 0816326d83aa
Revises: 12635f6655b7
Create Date: 2025-06-29 14:09:45.109518
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "0816326d83aa"
down_revision = "12635f6655b7"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create federated_connector table
op.create_table(
"federated_connector",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("source", sa.String(), nullable=False),
sa.Column("credentials", sa.LargeBinary(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# Create federated_connector_oauth_token table
op.create_table(
"federated_connector_oauth_token",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("federated_connector_id", sa.Integer(), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("token", sa.LargeBinary(), nullable=False),
sa.Column("expires_at", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(
["federated_connector_id"], ["federated_connector.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
# Create federated_connector__document_set table
op.create_table(
"federated_connector__document_set",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("federated_connector_id", sa.Integer(), nullable=False),
sa.Column("document_set_id", sa.Integer(), nullable=False),
sa.Column("entities", postgresql.JSONB(), nullable=False),
sa.ForeignKeyConstraint(
["federated_connector_id"], ["federated_connector.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(
["document_set_id"], ["document_set.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"federated_connector_id",
"document_set_id",
name="uq_federated_connector_document_set",
),
)
def downgrade() -> None:
# Drop tables in reverse order due to foreign key dependencies
op.drop_table("federated_connector__document_set")
op.drop_table("federated_connector_oauth_token")
op.drop_table("federated_connector")

View File

@@ -1,596 +0,0 @@
"""drive-canonical-ids
Revision ID: 12635f6655b7
Revises: 58c50ef19f08
Create Date: 2025-06-20 14:44:54.241159
"""
from alembic import op
import sqlalchemy as sa
from urllib.parse import urlparse, urlunparse
from httpx import HTTPStatusError
import httpx
from onyx.document_index.factory import get_default_document_index
from onyx.db.search_settings import SearchSettings
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.utils.logger import setup_logger
import os
logger = setup_logger()
# revision identifiers, used by Alembic.
revision = "12635f6655b7"
down_revision = "58c50ef19f08"
branch_labels = None
depends_on = None
SKIP_CANON_DRIVE_IDS = os.environ.get("SKIP_CANON_DRIVE_IDS", "true").lower() == "true"
def active_search_settings() -> tuple[SearchSettings, SearchSettings | None]:
result = op.get_bind().execute(
sa.text(
"""
SELECT * FROM search_settings WHERE status = 'PRESENT' ORDER BY id DESC LIMIT 1
"""
)
)
search_settings_fetch = result.fetchall()
search_settings = (
SearchSettings(**search_settings_fetch[0]._asdict())
if search_settings_fetch
else None
)
result2 = op.get_bind().execute(
sa.text(
"""
SELECT * FROM search_settings WHERE status = 'FUTURE' ORDER BY id DESC LIMIT 1
"""
)
)
search_settings_future_fetch = result2.fetchall()
search_settings_future = (
SearchSettings(**search_settings_future_fetch[0]._asdict())
if search_settings_future_fetch
else None
)
if not isinstance(search_settings, SearchSettings):
raise RuntimeError(
"current search settings is of type " + str(type(search_settings))
)
if (
not isinstance(search_settings_future, SearchSettings)
and search_settings_future is not None
):
raise RuntimeError(
"future search settings is of type " + str(type(search_settings_future))
)
return search_settings, search_settings_future
def normalize_google_drive_url(url: str) -> str:
"""Remove query parameters from Google Drive URLs to create canonical document IDs.
NOTE: copied from drive doc_conversion.py
"""
parsed_url = urlparse(url)
parsed_url = parsed_url._replace(query="")
spl_path = parsed_url.path.split("/")
if spl_path and (spl_path[-1] in ["edit", "view", "preview"]):
spl_path.pop()
parsed_url = parsed_url._replace(path="/".join(spl_path))
# Remove query parameters and reconstruct URL
return urlunparse(parsed_url)
def get_google_drive_documents_from_database() -> list[dict]:
"""Get all Google Drive documents from the database."""
bind = op.get_bind()
result = bind.execute(
sa.text(
"""
SELECT d.id
FROM document d
JOIN document_by_connector_credential_pair dcc ON d.id = dcc.id
JOIN connector_credential_pair cc ON dcc.connector_id = cc.connector_id
AND dcc.credential_id = cc.credential_id
JOIN connector c ON cc.connector_id = c.id
WHERE c.source = 'GOOGLE_DRIVE'
"""
)
)
documents = []
for row in result:
documents.append({"document_id": row.id})
return documents
def update_document_id_in_database(
old_doc_id: str, new_doc_id: str, index_name: str
) -> None:
"""Update document IDs in all relevant database tables using copy-and-swap approach."""
bind = op.get_bind()
# print(f"Updating database tables for document {old_doc_id} -> {new_doc_id}")
# Check if new document ID already exists
result = bind.execute(
sa.text("SELECT COUNT(*) FROM document WHERE id = :new_id"),
{"new_id": new_doc_id},
)
row = result.fetchone()
if row and row[0] > 0:
# print(f"Document with ID {new_doc_id} already exists, deleting old one")
delete_document_from_db(old_doc_id, index_name)
return
# Step 1: Create a new document row with the new ID (copy all fields from old row)
# Use a conservative approach to handle columns that might not exist in all installations
try:
bind.execute(
sa.text(
"""
INSERT INTO document (id, from_ingestion_api, boost, hidden, semantic_id,
link, doc_updated_at, primary_owners, secondary_owners,
external_user_emails, external_user_group_ids, is_public,
chunk_count, last_modified, last_synced, kg_stage, kg_processing_time)
SELECT :new_id, from_ingestion_api, boost, hidden, semantic_id,
link, doc_updated_at, primary_owners, secondary_owners,
external_user_emails, external_user_group_ids, is_public,
chunk_count, last_modified, last_synced, kg_stage, kg_processing_time
FROM document
WHERE id = :old_id
"""
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated database tables for document {old_doc_id} -> {new_doc_id}")
except Exception as e:
# If the full INSERT fails, try a more basic version with only core columns
logger.warning(f"Full INSERT failed, trying basic version: {e}")
bind.execute(
sa.text(
"""
INSERT INTO document (id, from_ingestion_api, boost, hidden, semantic_id,
link, doc_updated_at, primary_owners, secondary_owners)
SELECT :new_id, from_ingestion_api, boost, hidden, semantic_id,
link, doc_updated_at, primary_owners, secondary_owners
FROM document
WHERE id = :old_id
"""
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# Step 2: Update all foreign key references to point to the new ID
# Update document_by_connector_credential_pair table
bind.execute(
sa.text(
"UPDATE document_by_connector_credential_pair SET id = :new_id WHERE id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated document_by_connector_credential_pair table for document {old_doc_id} -> {new_doc_id}")
# Update search_doc table (stores search results for chat replay)
# This is critical for agent functionality
bind.execute(
sa.text(
"UPDATE search_doc SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated search_doc table for document {old_doc_id} -> {new_doc_id}")
# Update document_retrieval_feedback table (user feedback on documents)
bind.execute(
sa.text(
"UPDATE document_retrieval_feedback SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated document_retrieval_feedback table for document {old_doc_id} -> {new_doc_id}")
# Update document__tag table (document-tag relationships)
bind.execute(
sa.text(
"UPDATE document__tag SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated document__tag table for document {old_doc_id} -> {new_doc_id}")
# Update user_file table (user uploaded files linked to documents)
bind.execute(
sa.text(
"UPDATE user_file SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated user_file table for document {old_doc_id} -> {new_doc_id}")
# Update KG and chunk_stats tables (these may not exist in all installations)
try:
# Update kg_entity table
bind.execute(
sa.text(
"UPDATE kg_entity SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated kg_entity table for document {old_doc_id} -> {new_doc_id}")
# Update kg_entity_extraction_staging table
bind.execute(
sa.text(
"UPDATE kg_entity_extraction_staging SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated kg_entity_extraction_staging table for document {old_doc_id} -> {new_doc_id}")
# Update kg_relationship table
bind.execute(
sa.text(
"UPDATE kg_relationship SET source_document = :new_id WHERE source_document = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated kg_relationship table for document {old_doc_id} -> {new_doc_id}")
# Update kg_relationship_extraction_staging table
bind.execute(
sa.text(
"UPDATE kg_relationship_extraction_staging SET source_document = :new_id WHERE source_document = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated kg_relationship_extraction_staging table for document {old_doc_id} -> {new_doc_id}")
# Update chunk_stats table
bind.execute(
sa.text(
"UPDATE chunk_stats SET document_id = :new_id WHERE document_id = :old_id"
),
{"new_id": new_doc_id, "old_id": old_doc_id},
)
# print(f"Successfully updated chunk_stats table for document {old_doc_id} -> {new_doc_id}")
# Update chunk_stats ID field which includes document_id
bind.execute(
sa.text(
"""
UPDATE chunk_stats
SET id = REPLACE(id, :old_id, :new_id)
WHERE id LIKE :old_id_pattern
"""
),
{
"new_id": new_doc_id,
"old_id": old_doc_id,
"old_id_pattern": f"{old_doc_id}__%",
},
)
# print(f"Successfully updated chunk_stats ID field for document {old_doc_id} -> {new_doc_id}")
except Exception as e:
logger.warning(f"Some KG/chunk tables may not exist or failed to update: {e}")
# Step 3: Delete the old document row (this should now be safe since all FKs point to new row)
bind.execute(
sa.text("DELETE FROM document WHERE id = :old_id"), {"old_id": old_doc_id}
)
# print(f"Successfully deleted document {old_doc_id} from database")
def _visit_chunks(
*,
http_client: httpx.Client,
index_name: str,
selection: str,
continuation: str | None = None,
) -> tuple[list[dict], str | None]:
"""Helper that calls the /document/v1 visit API once and returns (docs, next_token)."""
# Use the same URL as the document API, but with visit-specific params
base_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
params: dict[str, str] = {
"selection": selection,
"wantedDocumentCount": "1000",
}
if continuation:
params["continuation"] = continuation
# print(f"Visiting chunks for selection '{selection}' with params {params}")
resp = http_client.get(base_url, params=params, timeout=None)
# print(f"Visited chunks for document {selection}")
resp.raise_for_status()
payload = resp.json()
return payload.get("documents", []), payload.get("continuation")
def delete_document_chunks_from_vespa(index_name: str, doc_id: str) -> None:
"""Delete all chunks for *doc_id* from Vespa using continuation-token paging (no offset)."""
total_deleted = 0
# Use exact match instead of contains - Document Selector Language doesn't support contains
selection = f'{index_name}.document_id=="{doc_id}"'
with get_vespa_http_client() as http_client:
continuation: str | None = None
while True:
docs, continuation = _visit_chunks(
http_client=http_client,
index_name=index_name,
selection=selection,
continuation=continuation,
)
if not docs:
break
for doc in docs:
vespa_full_id = doc.get("id")
if not vespa_full_id:
continue
vespa_doc_uuid = vespa_full_id.split("::")[-1]
delete_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"
try:
resp = http_client.delete(delete_url)
resp.raise_for_status()
total_deleted += 1
except Exception as e:
print(f"Failed to delete chunk {vespa_doc_uuid}: {e}")
if not continuation:
break
def update_document_id_in_vespa(
index_name: str, old_doc_id: str, new_doc_id: str
) -> None:
"""Update all chunks' document_id field from *old_doc_id* to *new_doc_id* using continuation paging."""
clean_new_doc_id = replace_invalid_doc_id_characters(new_doc_id)
# Use exact match instead of contains - Document Selector Language doesn't support contains
selection = f'{index_name}.document_id=="{old_doc_id}"'
with get_vespa_http_client() as http_client:
continuation: str | None = None
while True:
# print(f"Visiting chunks for document {old_doc_id} -> {new_doc_id}")
docs, continuation = _visit_chunks(
http_client=http_client,
index_name=index_name,
selection=selection,
continuation=continuation,
)
if not docs:
break
for doc in docs:
vespa_full_id = doc.get("id")
if not vespa_full_id:
continue
vespa_doc_uuid = vespa_full_id.split("::")[-1]
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}"
update_request = {
"fields": {"document_id": {"assign": clean_new_doc_id}}
}
try:
resp = http_client.put(vespa_url, json=update_request)
resp.raise_for_status()
except Exception as e:
print(f"Failed to update chunk {vespa_doc_uuid}: {e}")
raise
if not continuation:
break
def delete_document_from_db(current_doc_id: str, index_name: str) -> None:
# Delete all foreign key references first, then delete the document
try:
bind = op.get_bind()
# Delete from agent-related tables first (order matters due to foreign keys)
# Delete from agent__sub_query__search_doc first since it references search_doc
bind.execute(
sa.text(
"""
DELETE FROM agent__sub_query__search_doc
WHERE search_doc_id IN (
SELECT id FROM search_doc WHERE document_id = :doc_id
)
"""
),
{"doc_id": current_doc_id},
)
# Delete from chat_message__search_doc
bind.execute(
sa.text(
"""
DELETE FROM chat_message__search_doc
WHERE search_doc_id IN (
SELECT id FROM search_doc WHERE document_id = :doc_id
)
"""
),
{"doc_id": current_doc_id},
)
# Now we can safely delete from search_doc
bind.execute(
sa.text("DELETE FROM search_doc WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
# Delete from document_by_connector_credential_pair
bind.execute(
sa.text(
"DELETE FROM document_by_connector_credential_pair WHERE id = :doc_id"
),
{"doc_id": current_doc_id},
)
# Delete from other tables that reference this document
bind.execute(
sa.text(
"DELETE FROM document_retrieval_feedback WHERE document_id = :doc_id"
),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text("DELETE FROM document__tag WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text("DELETE FROM user_file WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
# Delete from KG tables if they exist
try:
bind.execute(
sa.text("DELETE FROM kg_entity WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text(
"DELETE FROM kg_entity_extraction_staging WHERE document_id = :doc_id"
),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text("DELETE FROM kg_relationship WHERE source_document = :doc_id"),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text(
"DELETE FROM kg_relationship_extraction_staging WHERE source_document = :doc_id"
),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text("DELETE FROM chunk_stats WHERE document_id = :doc_id"),
{"doc_id": current_doc_id},
)
bind.execute(
sa.text("DELETE FROM chunk_stats WHERE id LIKE :doc_id_pattern"),
{"doc_id_pattern": f"{current_doc_id}__%"},
)
except Exception as e:
logger.warning(
f"Some KG/chunk tables may not exist or failed to delete from: {e}"
)
# Finally delete the document itself
bind.execute(
sa.text("DELETE FROM document WHERE id = :doc_id"),
{"doc_id": current_doc_id},
)
# Delete chunks from vespa
delete_document_chunks_from_vespa(index_name, current_doc_id)
except Exception as e:
print(f"Failed to delete duplicate document {current_doc_id}: {e}")
# Continue with other documents instead of failing the entire migration
def upgrade() -> None:
if SKIP_CANON_DRIVE_IDS:
return
current_search_settings, future_search_settings = active_search_settings()
document_index = get_default_document_index(
current_search_settings,
future_search_settings,
)
# Get the index name
if hasattr(document_index, "index_name"):
index_name = document_index.index_name
else:
# Default index name if we can't get it from the document_index
index_name = "danswer_index"
# Get all Google Drive documents from the database (this is faster and more reliable)
gdrive_documents = get_google_drive_documents_from_database()
if not gdrive_documents:
return
# Track normalized document IDs to detect duplicates
all_normalized_doc_ids = set()
updated_count = 0
for doc_info in gdrive_documents:
current_doc_id = doc_info["document_id"]
normalized_doc_id = normalize_google_drive_url(current_doc_id)
print(f"Processing document {current_doc_id} -> {normalized_doc_id}")
# Check for duplicates
if normalized_doc_id in all_normalized_doc_ids:
# print(f"Deleting duplicate document {current_doc_id}")
delete_document_from_db(current_doc_id, index_name)
continue
all_normalized_doc_ids.add(normalized_doc_id)
# If the document ID already doesn't have query parameters, skip it
if current_doc_id == normalized_doc_id:
# print(f"Skipping document {current_doc_id} -> {normalized_doc_id} because it already has no query parameters")
continue
try:
# Update both database and Vespa in order
# Database first to ensure consistency
update_document_id_in_database(
current_doc_id, normalized_doc_id, index_name
)
# For Vespa, we can now use the original document IDs since we're using contains matching
update_document_id_in_vespa(index_name, current_doc_id, normalized_doc_id)
updated_count += 1
# print(f"Finished updating document {current_doc_id} -> {normalized_doc_id}")
except Exception as e:
print(f"Failed to update document {current_doc_id}: {e}")
if isinstance(e, HTTPStatusError):
print(f"HTTPStatusError: {e}")
print(f"Response: {e.response.text}")
print(f"Status: {e.response.status_code}")
print(f"Headers: {e.response.headers}")
print(f"Request: {e.request.url}")
print(f"Request headers: {e.request.headers}")
# Note: Rollback is complex with copy-and-swap approach since the old document is already deleted
# In case of failure, manual intervention may be required
# Continue with other documents instead of failing the entire migration
continue
logger.info(f"Migration complete. Updated {updated_count} Google Drive documents")
def downgrade() -> None:
# this is a one way migration, so no downgrade.
# It wouldn't make sense to store the extra query parameters
# and duplicate documents to allow a reversal.
pass

View File

@@ -144,34 +144,27 @@ def upgrade() -> None:
def downgrade() -> None:
op.execute("TRUNCATE TABLE index_attempt")
conn = op.get_bind()
inspector = sa.inspect(conn)
existing_columns = {col["name"] for col in inspector.get_columns("index_attempt")}
if "input_type" not in existing_columns:
op.add_column(
"index_attempt",
sa.Column("input_type", sa.VARCHAR(), autoincrement=False, nullable=False),
)
if "source" not in existing_columns:
op.add_column(
"index_attempt",
sa.Column("source", sa.VARCHAR(), autoincrement=False, nullable=False),
)
if "connector_specific_config" not in existing_columns:
op.add_column(
"index_attempt",
sa.Column(
"connector_specific_config",
postgresql.JSONB(astext_type=sa.Text()),
autoincrement=False,
nullable=False,
),
)
op.add_column(
"index_attempt",
sa.Column("input_type", sa.VARCHAR(), autoincrement=False, nullable=False),
)
op.add_column(
"index_attempt",
sa.Column("source", sa.VARCHAR(), autoincrement=False, nullable=False),
)
op.add_column(
"index_attempt",
sa.Column(
"connector_specific_config",
postgresql.JSONB(astext_type=sa.Text()),
autoincrement=False,
nullable=False,
),
)
# Check if the constraint exists before dropping
conn = op.get_bind()
inspector = sa.inspect(conn)
constraints = inspector.get_foreign_keys("index_attempt")
if any(
@@ -190,12 +183,8 @@ def downgrade() -> None:
"fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
)
if "credential_id" in existing_columns:
op.drop_column("index_attempt", "credential_id")
if "connector_id" in existing_columns:
op.drop_column("index_attempt", "connector_id")
op.execute("DROP TABLE IF EXISTS connector_credential_pair CASCADE")
op.execute("DROP TABLE IF EXISTS credential CASCADE")
op.execute("DROP TABLE IF EXISTS connector CASCADE")
op.drop_column("index_attempt", "credential_id")
op.drop_column("index_attempt", "connector_id")
op.drop_table("connector_credential_pair")
op.drop_table("credential")
op.drop_table("connector")

View File

@@ -1,115 +0,0 @@
"""add_indexing_coordination
Revision ID: 2f95e36923e6
Revises: 0816326d83aa
Create Date: 2025-07-10 16:17:57.762182
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2f95e36923e6"
down_revision = "0816326d83aa"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add database-based coordination fields (replacing Redis fencing)
op.add_column(
"index_attempt", sa.Column("celery_task_id", sa.String(), nullable=True)
)
op.add_column(
"index_attempt",
sa.Column(
"cancellation_requested",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
# Add batch coordination fields (replacing FileStore state)
op.add_column(
"index_attempt", sa.Column("total_batches", sa.Integer(), nullable=True)
)
op.add_column(
"index_attempt",
sa.Column(
"completed_batches", sa.Integer(), nullable=False, server_default="0"
),
)
op.add_column(
"index_attempt",
sa.Column(
"total_failures_batch_level",
sa.Integer(),
nullable=False,
server_default="0",
),
)
op.add_column(
"index_attempt",
sa.Column("total_chunks", sa.Integer(), nullable=False, server_default="0"),
)
# Progress tracking for stall detection
op.add_column(
"index_attempt",
sa.Column("last_progress_time", sa.DateTime(timezone=True), nullable=True),
)
op.add_column(
"index_attempt",
sa.Column(
"last_batches_completed_count",
sa.Integer(),
nullable=False,
server_default="0",
),
)
# Heartbeat tracking for worker liveness detection
op.add_column(
"index_attempt",
sa.Column(
"heartbeat_counter", sa.Integer(), nullable=False, server_default="0"
),
)
op.add_column(
"index_attempt",
sa.Column(
"last_heartbeat_value", sa.Integer(), nullable=False, server_default="0"
),
)
op.add_column(
"index_attempt",
sa.Column("last_heartbeat_time", sa.DateTime(timezone=True), nullable=True),
)
# Add index for coordination queries
op.create_index(
"ix_index_attempt_active_coordination",
"index_attempt",
["connector_credential_pair_id", "search_settings_id", "status"],
)
def downgrade() -> None:
# Remove the new index
op.drop_index("ix_index_attempt_active_coordination", table_name="index_attempt")
# Remove the new columns
op.drop_column("index_attempt", "last_batches_completed_count")
op.drop_column("index_attempt", "last_progress_time")
op.drop_column("index_attempt", "last_heartbeat_time")
op.drop_column("index_attempt", "last_heartbeat_value")
op.drop_column("index_attempt", "heartbeat_counter")
op.drop_column("index_attempt", "total_chunks")
op.drop_column("index_attempt", "total_failures_batch_level")
op.drop_column("index_attempt", "completed_batches")
op.drop_column("index_attempt", "total_batches")
op.drop_column("index_attempt", "cancellation_requested")
op.drop_column("index_attempt", "celery_task_id")

View File

@@ -9,7 +9,7 @@ Create Date: 2025-06-22 17:33:25.833733
from alembic import op
from sqlalchemy.orm import Session
from sqlalchemy import text
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
# revision identifiers, used by Alembic.
revision = "36e9220ab794"
@@ -66,7 +66,7 @@ def upgrade() -> None:
-- Set name and name trigrams
NEW.name = name;
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name);
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
@@ -111,7 +111,7 @@ def upgrade() -> None:
UPDATE "{tenant_id}".kg_entity
SET
name = doc_name,
name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name)
name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
WHERE document_id = NEW.id;
RETURN NEW;
END;

View File

@@ -15,7 +15,7 @@ from datetime import datetime, timedelta
from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
# revision identifiers, used by Alembic.
@@ -80,7 +80,6 @@ def upgrade() -> None:
)
)
op.execute("DROP TABLE IF EXISTS kg_config CASCADE")
op.create_table(
"kg_config",
sa.Column("id", sa.Integer(), primary_key=True, nullable=False, index=True),
@@ -124,7 +123,6 @@ def upgrade() -> None:
],
)
op.execute("DROP TABLE IF EXISTS kg_entity_type CASCADE")
op.create_table(
"kg_entity_type",
sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True),
@@ -158,7 +156,6 @@ def upgrade() -> None:
),
)
op.execute("DROP TABLE IF EXISTS kg_relationship_type CASCADE")
# Create KGRelationshipType table
op.create_table(
"kg_relationship_type",
@@ -197,7 +194,6 @@ def upgrade() -> None:
),
)
op.execute("DROP TABLE IF EXISTS kg_relationship_type_extraction_staging CASCADE")
# Create KGRelationshipTypeExtractionStaging table
op.create_table(
"kg_relationship_type_extraction_staging",
@@ -231,8 +227,6 @@ def upgrade() -> None:
),
)
op.execute("DROP TABLE IF EXISTS kg_entity CASCADE")
# Create KGEntity table
op.create_table(
"kg_entity",
@@ -287,7 +281,6 @@ def upgrade() -> None:
"ix_entity_name_search", "kg_entity", ["name", "entity_type_id_name"]
)
op.execute("DROP TABLE IF EXISTS kg_entity_extraction_staging CASCADE")
# Create KGEntityExtractionStaging table
op.create_table(
"kg_entity_extraction_staging",
@@ -337,7 +330,6 @@ def upgrade() -> None:
["name", "entity_type_id_name"],
)
op.execute("DROP TABLE IF EXISTS kg_relationship CASCADE")
# Create KGRelationship table
op.create_table(
"kg_relationship",
@@ -379,7 +371,6 @@ def upgrade() -> None:
"ix_kg_relationship_nodes", "kg_relationship", ["source_node", "target_node"]
)
op.execute("DROP TABLE IF EXISTS kg_relationship_extraction_staging CASCADE")
# Create KGRelationshipExtractionStaging table
op.create_table(
"kg_relationship_extraction_staging",
@@ -423,7 +414,6 @@ def upgrade() -> None:
["source_node", "target_node"],
)
op.execute("DROP TABLE IF EXISTS kg_term CASCADE")
# Create KGTerm table
op.create_table(
"kg_term",
@@ -478,7 +468,7 @@ def upgrade() -> None:
# Create GIN index for clustering and normalization
op.execute(
"CREATE INDEX IF NOT EXISTS idx_kg_entity_clustering_trigrams "
f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA}.gin_trgm_ops)"
f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.gin_trgm_ops)"
)
op.execute(
"CREATE INDEX IF NOT EXISTS idx_kg_entity_normalization_trigrams "
@@ -518,7 +508,7 @@ def upgrade() -> None:
-- Set name and name trigrams
NEW.name = name;
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name);
NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name);
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
@@ -563,7 +553,7 @@ def upgrade() -> None:
UPDATE kg_entity
SET
name = doc_name,
name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name)
name_trigrams = {POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE}.show_trgm(cleaned_name)
WHERE document_id = NEW.id;
RETURN NEW;
END;

View File

@@ -159,7 +159,7 @@ def _migrate_files_to_postgres() -> None:
# only create external store if we have files to migrate. This line
# makes it so we need to have S3/MinIO configured to run this migration.
external_store = get_s3_file_store()
external_store = get_s3_file_store(db_session=session)
for i, file_id in enumerate(files_to_migrate, 1):
print(f"Migrating file {i}/{total_files}: {file_id}")
@@ -219,7 +219,7 @@ def _migrate_files_to_external_storage() -> None:
# Get database session
bind = op.get_bind()
session = Session(bind=bind)
external_store = get_s3_file_store()
external_store = get_s3_file_store(db_session=session)
# Find all files currently stored in PostgreSQL (lobj_oid is not null)
result = session.execute(
@@ -236,9 +236,6 @@ def _migrate_files_to_external_storage() -> None:
print("No files found in PostgreSQL storage to migrate.")
return
# might need to move this above the if statement when creating a new multi-tenant
# system. VERY extreme edge case.
external_store.initialize()
print(f"Found {total_files} files to migrate from PostgreSQL to external storage.")
_set_tenant_contextvar(session)

View File

@@ -18,13 +18,11 @@ depends_on: None = None
def upgrade() -> None:
op.execute("DROP TABLE IF EXISTS document CASCADE")
op.create_table(
"document",
sa.Column("id", sa.String(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.execute("DROP TABLE IF EXISTS chunk CASCADE")
op.create_table(
"chunk",
sa.Column("id", sa.String(), nullable=False),
@@ -45,7 +43,6 @@ def upgrade() -> None:
),
sa.PrimaryKeyConstraint("id", "document_store_type"),
)
op.execute("DROP TABLE IF EXISTS deletion_attempt CASCADE")
op.create_table(
"deletion_attempt",
sa.Column("id", sa.Integer(), nullable=False),
@@ -87,7 +84,6 @@ def upgrade() -> None:
),
sa.PrimaryKeyConstraint("id"),
)
op.execute("DROP TABLE IF EXISTS document_by_connector_credential_pair CASCADE")
op.create_table(
"document_by_connector_credential_pair",
sa.Column("id", sa.String(), nullable=False),
@@ -110,10 +106,7 @@ def upgrade() -> None:
def downgrade() -> None:
# upstream tables first
op.drop_table("document_by_connector_credential_pair")
op.drop_table("deletion_attempt")
op.drop_table("chunk")
# Alembic op.drop_table() has no "cascade" flag issue raw SQL
op.execute("DROP TABLE IF EXISTS document CASCADE")
op.drop_table("document")

View File

@@ -91,7 +91,7 @@ def export_query_history_task(
with get_session_with_current_tenant() as db_session:
try:
stream.seek(0)
get_default_file_store().save_file(
get_default_file_store(db_session).save_file(
content=stream,
display_name=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,

View File

@@ -422,7 +422,7 @@ def connector_permission_sync_generator_task(
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.cc_pair_id}",
+ f"_{redis_connector.id}",
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
thread_local=False,
)

View File

@@ -383,7 +383,7 @@ def connector_external_group_sync_generator_task(
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.cc_pair_id}",
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
)

View File

@@ -114,6 +114,7 @@ def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]:
def get_usage_report_data(
db_session: Session,
report_display_name: str,
) -> IO:
"""
@@ -127,7 +128,7 @@ def get_usage_report_data(
Returns:
The usage report data.
"""
file_store = get_default_file_store()
file_store = get_default_file_store(db_session)
# usage report may be very large, so don't load it all into memory
return file_store.read_file(
file_id=report_display_name, mode="b", use_tempfile=True

View File

@@ -128,14 +128,11 @@ def validate_object_creation_for_user(
target_group_ids: list[int] | None = None,
object_is_public: bool | None = None,
object_is_perm_sync: bool | None = None,
object_is_owned_by_user: bool = False,
object_is_new: bool = False,
) -> None:
"""
All users can create/edit permission synced objects if they don't specify a group
All admin actions are allowed.
Curators and global curators can create public objects.
Prevents other non-admins from creating/editing:
Prevents non-admins from creating/editing:
- public objects
- objects with no groups
- objects that belong to a group they don't curate
@@ -146,23 +143,13 @@ def validate_object_creation_for_user(
if not user or user.role == UserRole.ADMIN:
return
# Allow curators and global curators to create public objects
# w/o associated groups IF the object is new/owned by them
if (
object_is_public
and user.role in [UserRole.CURATOR, UserRole.GLOBAL_CURATOR]
and (object_is_new or object_is_owned_by_user)
):
return
if object_is_public and user.role == UserRole.BASIC:
detail = "User does not have permission to create public objects"
if object_is_public:
detail = "User does not have permission to create public credentials"
logger.error(detail)
raise HTTPException(
status_code=400,
detail=detail,
)
if not target_group_ids:
detail = "Curators must specify 1+ groups"
logger.error(detail)

View File

@@ -40,28 +40,8 @@ def _get_slim_doc_generator(
)
def _merge_permissions_lists(
permission_lists: list[list[GoogleDrivePermission]],
) -> list[GoogleDrivePermission]:
"""
Merge a list of permission lists into a single list of permissions.
"""
seen_permission_ids: set[str] = set()
merged_permissions: list[GoogleDrivePermission] = []
for permission_list in permission_lists:
for permission in permission_list:
if permission.id not in seen_permission_ids:
merged_permissions.append(permission)
seen_permission_ids.add(permission.id)
return merged_permissions
def get_external_access_for_raw_gdrive_file(
file: GoogleDriveFileType,
company_domain: str,
retriever_drive_service: GoogleDriveService | None,
admin_drive_service: GoogleDriveService,
file: GoogleDriveFileType, company_domain: str, drive_service: GoogleDriveService
) -> ExternalAccess:
"""
Get the external access for a raw Google Drive file.
@@ -82,28 +62,11 @@ def get_external_access_for_raw_gdrive_file(
GoogleDrivePermission.from_drive_permission(p) for p in permissions
]
elif permission_ids:
def _get_permissions(
drive_service: GoogleDriveService,
) -> list[GoogleDrivePermission]:
return get_permissions_by_ids(
drive_service=drive_service,
doc_id=doc_id,
permission_ids=permission_ids,
)
permissions_list = _get_permissions(
retriever_drive_service or admin_drive_service
permissions_list = get_permissions_by_ids(
drive_service=drive_service,
doc_id=doc_id,
permission_ids=permission_ids,
)
if len(permissions_list) != len(permission_ids) and retriever_drive_service:
logger.warning(
f"Failed to get all permissions for file {doc_id} with retriever service, "
"trying admin service"
)
backup_permissions_list = _get_permissions(admin_drive_service)
permissions_list = _merge_permissions_lists(
[permissions_list, backup_permissions_list]
)
folder_ids_to_inherit_permissions_from: set[str] = set()
user_emails: set[str] = set()

View File

@@ -44,17 +44,11 @@ def _get_all_folders(
TODO: tweak things so we can fetch deltas.
"""
MAX_FAILED_PERCENTAGE = 0.5
all_folders: list[FolderInfo] = []
seen_folder_ids: set[str] = set()
def _get_all_folders_for_user(
google_drive_connector: GoogleDriveConnector,
skip_folders_without_permissions: bool,
user_email: str,
) -> None:
"""Helper to get folders for a specific user + update shared seen_folder_ids"""
user_emails = google_drive_connector._get_all_user_emails()
for user_email in user_emails:
drive_service = get_drive_service(
google_drive_connector.creds,
user_email,
@@ -104,20 +98,6 @@ def _get_all_folders(
)
)
failed_count = 0
user_emails = google_drive_connector._get_all_user_emails()
for user_email in user_emails:
try:
_get_all_folders_for_user(
google_drive_connector, skip_folders_without_permissions, user_email
)
except Exception:
logger.exception(f"Error getting folders for user {user_email}")
failed_count += 1
if failed_count > MAX_FAILED_PERCENTAGE * len(user_emails):
raise RuntimeError("Too many failed folder fetches during group sync")
return all_folders

View File

@@ -134,14 +134,15 @@ def ee_fetch_settings() -> EnterpriseSettings:
def put_logo(
file: UploadFile,
is_logotype: bool = False,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
) -> None:
upload_logo(file=file, is_logotype=is_logotype)
upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
def fetch_logo_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store()
file_store = get_default_file_store(db_session)
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")
@@ -157,7 +158,7 @@ def fetch_logo_helper(db_session: Session) -> Response:
def fetch_logotype_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store()
file_store = get_default_file_store(db_session)
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")

View File

@@ -6,6 +6,7 @@ from typing import IO
from fastapi import HTTPException
from fastapi import UploadFile
from sqlalchemy.orm import Session
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
@@ -98,7 +99,9 @@ def guess_file_type(filename: str) -> str:
return "application/octet-stream"
def upload_logo(file: UploadFile | str, is_logotype: bool = False) -> bool:
def upload_logo(
db_session: Session, file: UploadFile | str, is_logotype: bool = False
) -> bool:
content: IO[Any]
if isinstance(file, str):
@@ -126,7 +129,7 @@ def upload_logo(file: UploadFile | str, is_logotype: bool = False) -> bool:
display_name = file.filename
file_type = file.content_type or "image/jpeg"
file_store = get_default_file_store()
file_store = get_default_file_store(db_session)
file_store.save_file(
content=content,
display_name=display_name,

View File

@@ -1,6 +1,5 @@
import re
from typing import cast
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
@@ -74,7 +73,6 @@ def _get_final_context_doc_indices(
def _convert_packet_stream_to_response(
packets: ChatPacketStream,
chat_session_id: UUID,
) -> ChatBasicResponse:
response = ChatBasicResponse()
final_context_docs: list[LlmDoc] = []
@@ -218,8 +216,6 @@ def _convert_packet_stream_to_response(
if answer:
response.answer_citationless = remove_answer_citations(answer)
response.chat_session_id = chat_session_id
return response
@@ -241,36 +237,13 @@ def handle_simplified_chat_message(
if not chat_message_req.message:
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
# Handle chat session creation if chat_session_id is not provided
if chat_message_req.chat_session_id is None:
if chat_message_req.persona_id is None:
raise HTTPException(
status_code=400,
detail="Either chat_session_id or persona_id must be provided",
)
# Create a new chat session with the provided persona_id
try:
new_chat_session = create_chat_session(
db_session=db_session,
description="", # Leave empty for simple API
user_id=user.id if user else None,
persona_id=chat_message_req.persona_id,
)
chat_session_id = new_chat_session.id
except Exception as e:
logger.exception(e)
raise HTTPException(status_code=400, detail="Invalid Persona provided.")
else:
chat_session_id = chat_message_req.chat_session_id
try:
parent_message, _ = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
chat_session_id=chat_message_req.chat_session_id, db_session=db_session
)
except Exception:
parent_message = get_or_create_root_message(
chat_session_id=chat_session_id, db_session=db_session
chat_session_id=chat_message_req.chat_session_id, db_session=db_session
)
if (
@@ -285,7 +258,7 @@ def handle_simplified_chat_message(
retrieval_options = chat_message_req.retrieval_options
full_chat_msg_info = CreateChatMessageRequest(
chat_session_id=chat_session_id,
chat_session_id=chat_message_req.chat_session_id,
parent_message_id=parent_message.id,
message=chat_message_req.message,
file_descriptors=[],
@@ -310,7 +283,7 @@ def handle_simplified_chat_message(
enforce_chat_session_id_for_search_docs=False,
)
return _convert_packet_stream_to_response(packets, chat_session_id)
return _convert_packet_stream_to_response(packets)
@router.post("/send-message-simple-with-history")
@@ -430,4 +403,4 @@ def handle_send_message_simple_with_history(
enforce_chat_session_id_for_search_docs=False,
)
return _convert_packet_stream_to_response(packets, chat_session.id)
return _convert_packet_stream_to_response(packets)

View File

@@ -41,13 +41,11 @@ class DocumentSearchRequest(ChunkContext):
class BasicCreateChatMessageRequest(ChunkContext):
"""If a chat_session_id is not provided, a persona_id must be provided to automatically create a new chat session
"""Before creating messages, be sure to create a chat_session and get an id
Note, for simplicity this option only allows for a single linear chain of messages
"""
chat_session_id: UUID | None = None
# Optional persona_id to create a new chat session if chat_session_id is not provided
persona_id: int | None = None
chat_session_id: UUID
# New message contents
message: str
# Defaults to using retrieval with no additional filters
@@ -64,12 +62,6 @@ class BasicCreateChatMessageRequest(ChunkContext):
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
@model_validator(mode="after")
def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
if self.chat_session_id is None and self.persona_id is None:
raise ValueError("Either chat_session_id or persona_id must be provided")
return self
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# Last element is the new query. All previous elements are historical context
@@ -179,9 +171,6 @@ class ChatBasicResponse(BaseModel):
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
agent_refined_answer_improvement: bool | None = None
# Chat session ID for tracking conversation continuity
chat_session_id: UUID | None = None
class OneShotQARequest(ChunkContext):
# Supports simplier APIs that don't deal with chat histories or message edits

View File

@@ -358,7 +358,7 @@ def get_query_history_export_status(
# If task is None, then it's possible that the task has already finished processing.
# Therefore, we should then check if the export file has already been stored inside of the file-store.
# If that *also* doesn't exist, then we can return a 404.
file_store = get_default_file_store()
file_store = get_default_file_store(db_session)
report_name = construct_query_history_report_name(request_id)
has_file = file_store.has_file(
@@ -385,7 +385,7 @@ def download_query_history_csv(
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
report_name = construct_query_history_report_name(request_id)
file_store = get_default_file_store()
file_store = get_default_file_store(db_session)
has_file = file_store.has_file(
file_id=report_name,
file_origin=FileOrigin.QUERY_HISTORY_CSV,

View File

@@ -53,7 +53,7 @@ def read_usage_report(
db_session: Session = Depends(get_session),
) -> Response:
try:
file = get_usage_report_data(report_name)
file = get_usage_report_data(db_session, report_name)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -112,7 +112,7 @@ def create_new_usage_report(
period: tuple[datetime, datetime] | None,
) -> UsageReportMetadata:
report_id = str(uuid.uuid4())
file_store = get_default_file_store()
file_store = get_default_file_store(db_session)
messages_file_id = generate_chat_messages_report(
db_session, file_store, report_id, period

View File

@@ -200,10 +200,10 @@ def _seed_enterprise_settings(seed_config: SeedConfiguration) -> None:
store_ee_settings(final_enterprise_settings)
def _seed_logo(logo_path: str | None) -> None:
def _seed_logo(db_session: Session, logo_path: str | None) -> None:
if logo_path:
logger.notice("Uploading logo")
upload_logo(file=logo_path)
upload_logo(db_session=db_session, file=logo_path)
def _seed_analytics_script(seed_config: SeedConfiguration) -> None:
@@ -245,7 +245,7 @@ def seed_db() -> None:
if seed_config.custom_tools is not None:
_seed_custom_tools(db_session, seed_config.custom_tools)
_seed_logo(seed_config.seeded_logo_path)
_seed_logo(db_session, seed_config.seeded_logo_path)
_seed_enterprise_settings(seed_config)
_seed_analytics_script(seed_config)

View File

@@ -10,12 +10,10 @@ from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
@@ -49,26 +47,6 @@ def gate_product(
return ProductGatingResponse(updated=False, error=str(e))
@router.post("/product-gating/full-sync")
def gate_product_full_sync(
product_gating_request: ProductGatingFullSyncRequest,
_: None = Depends(control_plane_dep),
) -> ProductGatingResponse:
"""
Bulk operation to overwrite the entire gated tenant set.
This replaces all currently gated tenants with the provided list.
Gated tenants are not available to access the product and will be
directed to the billing page when their subscription has ended.
"""
try:
overwrite_full_gated_set(product_gating_request.gated_tenant_ids)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate products during full sync")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),

View File

@@ -19,10 +19,6 @@ class ProductGatingRequest(BaseModel):
application_status: ApplicationStatus
class ProductGatingFullSyncRequest(BaseModel):
gated_tenant_ids: list[str]
class SubscriptionStatusResponse(BaseModel):
subscribed: bool

View File

@@ -16,6 +16,10 @@ logger = setup_logger()
def update_tenant_gating(tenant_id: str, status: ApplicationStatus) -> None:
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
# Store the full status
status_key = f"tenant:{tenant_id}:status"
redis_client.set(status_key, status.value)
# Maintain the GATED_ACCESS set
if status == ApplicationStatus.GATED_ACCESS:
redis_client.sadd(GATED_TENANTS_KEY, tenant_id)
@@ -42,25 +46,6 @@ def store_product_gating(tenant_id: str, application_status: ApplicationStatus)
raise
def overwrite_full_gated_set(tenant_ids: list[str]) -> None:
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
pipeline = redis_client.pipeline()
# using pipeline doesn't automatically add the tenant_id prefix
full_gated_set_key = f"{ONYX_CLOUD_TENANT_ID}:{GATED_TENANTS_KEY}"
# Clear the existing set
pipeline.delete(full_gated_set_key)
# Add all tenant IDs to the set and set their status
for tenant_id in tenant_ids:
pipeline.sadd(full_gated_set_key, tenant_id)
# Execute all commands at once
pipeline.execute()
def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))

View File

@@ -203,8 +203,6 @@ def generate_simple_sql(
if state.kg_entity_temp_view_name is None:
raise ValueError("kg_entity_temp_view_name is not set")
sql_statement_display: str | None = None
## STEP 3 - articulate goals
stream_write_step_activities(writer, _KG_STEP_NR)
@@ -383,18 +381,7 @@ def generate_simple_sql(
raise e
# display sql statement with view names replaced by general view names
sql_statement_display = sql_statement.replace(
state.kg_doc_temp_view_name, "<your_allowed_docs_view_name>"
)
sql_statement_display = sql_statement_display.replace(
state.kg_rel_temp_view_name, "<your_relationship_view_name>"
)
sql_statement_display = sql_statement_display.replace(
state.kg_entity_temp_view_name, "<your_entity_view_name>"
)
logger.debug(f"A3 - sql_statement after correction: {sql_statement_display}")
logger.debug(f"A3 - sql_statement after correction: {sql_statement}")
# Get SQL for source documents
@@ -422,20 +409,7 @@ def generate_simple_sql(
"relationship_table", rel_temp_view
)
if source_documents_sql:
source_documents_sql_display = source_documents_sql.replace(
state.kg_doc_temp_view_name, "<your_allowed_docs_view_name>"
)
source_documents_sql_display = source_documents_sql_display.replace(
state.kg_rel_temp_view_name, "<your_relationship_view_name>"
)
source_documents_sql_display = source_documents_sql_display.replace(
state.kg_entity_temp_view_name, "<your_entity_view_name>"
)
else:
source_documents_sql_display = "(No source documents SQL generated)"
logger.debug(f"A3 source_documents_sql: {source_documents_sql_display}")
logger.debug(f"A3 source_documents_sql: {source_documents_sql}")
scalar_result = None
query_results = None
@@ -461,13 +435,7 @@ def generate_simple_sql(
rows = result.fetchall()
query_results = [dict(row._mapping) for row in rows]
except Exception as e:
# TODO: raise error on frontend
logger.error(f"Error executing SQL query: {e}")
drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
)
raise e
@@ -491,14 +459,8 @@ def generate_simple_sql(
for source_document_result in query_source_document_results
]
except Exception as e:
# No stopping here, the individualized SQL query is not mandatory
# TODO: raise error on frontend
drop_views(
allowed_docs_view_name=doc_temp_view,
kg_relationships_view_name=rel_temp_view,
kg_entity_view_name=ent_temp_view,
)
logger.error(f"Error executing Individualized SQL query: {e}")
else:
@@ -531,11 +493,11 @@ def generate_simple_sql(
if reasoning:
stream_write_step_answer_explicit(writer, step_nr=_KG_STEP_NR, answer=reasoning)
if sql_statement_display:
if main_sql_statement:
stream_write_step_answer_explicit(
writer,
step_nr=_KG_STEP_NR,
answer=f" \n Generated SQL: {sql_statement_display}",
answer=f" \n Generated SQL: {main_sql_statement}",
)
stream_close_step_answer(writer, _KG_STEP_NR)

View File

@@ -51,6 +51,7 @@ def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
else:
continue
history_segments.append(f"{role}:\n {msg.content}\n\n")
return "\n".join(history_segments)
@@ -127,7 +128,6 @@ def choose_tool(
override_kwargs: SearchToolOverrideKwargs = (
force_use_tool.override_kwargs or SearchToolOverrideKwargs()
)
override_kwargs.original_query = agent_config.inputs.prompt_builder.raw_user_query
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder

View File

@@ -174,6 +174,7 @@ def get_test_config(
# The docs retrieved by this flow are already relevance-filtered
all_docs_useful=True
),
document_pruning_config=document_pruning_config,
structured_response_format=None,
)
@@ -197,7 +198,7 @@ def get_test_config(
prompt_config=prompt_config,
llm=primary_llm,
fast_llm=fast_llm,
document_pruning_config=search_tool_config.document_pruning_config,
pruning_config=search_tool_config.document_pruning_config,
answer_style_config=search_tool_config.answer_style_config,
selected_sections=search_tool_config.selected_sections,
chunks_above=search_tool_config.chunks_above,

View File

@@ -24,14 +24,13 @@ from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatt
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.celery_utils import make_probe_path
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
@@ -40,7 +39,6 @@ from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import PlainFormatter
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -94,13 +92,7 @@ def on_task_prerun(
kwargs: dict[str, Any] | None = None,
**other_kwargs: Any,
) -> None:
# Reset any per-task logging context so that prefixes (e.g. pruning_ctx)
# from a previous task executed in the same worker process do not leak
# into the next task's log messages. This fixes incorrect [CC Pair:/Index Attempt]
# prefixes observed when a pruning task finishes and an indexing task
# runs in the same process.
LoggerContextVars.reset()
pass
def on_task_postrun(
@@ -153,11 +145,8 @@ def on_task_postrun(
r = get_redis_client(tenant_id=tenant_id)
# NOTE: we want to remove the `Redis*` classes, prefer to just have functions to
# do these things going forward. In short, things should generally be like the doc
# sync task rather than the others below
if task_id.startswith(DOCUMENT_SYNC_PREFIX):
r.srem(DOCUMENT_SYNC_TASKSET_KEY, task_id)
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
return
if task_id.startswith(RedisDocumentSet.PREFIX):
@@ -481,8 +470,7 @@ class TenantContextFilter(logging.Filter):
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id:
# Match the 8 character tenant abbreviation used in OnyxLoggingAdapter
tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:8]
tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:5]
record.name = f"[t:{tenant_id}]"
else:
record.name = ""

View File

@@ -1,102 +0,0 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.docfetching")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME)
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.docfetching",
]
)

View File

@@ -12,7 +12,7 @@ from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -21,7 +21,7 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.docprocessing")
celery_app.config_from_object("onyx.background.celery.configs.indexing")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@@ -60,7 +60,7 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME)
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
# rkuo: Transient errors keep happening in the indexing watchdog threads.
# "SSL connection has been closed unexpectedly"
@@ -108,6 +108,6 @@ for bootstep in base_bootsteps:
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.indexing",
]
)

View File

@@ -116,6 +116,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.user_file_folder_sync",
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.indexing",
]
)

View File

@@ -9,7 +9,6 @@ from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.exceptions import WorkerShutdown
from celery.result import AsyncResult
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
@@ -19,7 +18,9 @@ from redis.lock import Lock as RedisLock
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync
from onyx.background.celery.tasks.indexing.utils import (
get_unfenced_index_attempt_ids,
)
from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
@@ -28,7 +29,9 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.redis.redis_connector_credential_pair import (
RedisGlobalConnectorCredentialPair,
)
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
@@ -153,10 +156,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
r.delete(OnyxRedisConstants.ACTIVE_FENCES)
# NOTE: we want to remove the `Redis*` classes, prefer to just have functions
# This is the preferred way to do this going forward
reset_document_sync(r)
RedisGlobalConnectorCredentialPair.reset_all(r)
RedisDocumentSet.reset_all(r)
RedisUserGroup.reset_all(r)
RedisConnectorDelete.reset_all(r)
@@ -167,50 +167,24 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
RedisConnectorExternalGroupSync.reset_all(r)
# mark orphaned index attempts as failed
# This uses database coordination instead of Redis fencing
with get_session_with_current_tenant() as db_session:
# Get potentially orphaned attempts (those with active status and task IDs)
potentially_orphaned_ids = IndexingCoordination.get_orphaned_index_attempt_ids(
db_session
)
for attempt_id in potentially_orphaned_ids:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
# handle case where not started or docfetching is done but indexing is not
if (
not attempt
or not attempt.celery_task_id
or attempt.total_batches is not None
):
if not attempt:
continue
# Check if the Celery task actually exists
try:
result: AsyncResult = AsyncResult(attempt.celery_task_id)
# If the task is not in PENDING state, it exists in Celery
if result.state != "PENDING":
continue
# Task is orphaned - mark as failed
failure_reason = (
f"Orphaned index attempt found on startup - Celery task not found: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id} "
f"celery_task_id={attempt.celery_task_id}"
)
logger.warning(failure_reason)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
except Exception:
# If we can't check the task status, be conservative and continue
logger.warning(
f"Could not verify Celery task status on startup for attempt {attempt.id}, "
f"task_id={attempt.celery_task_id}"
)
failure_reason = (
f"Canceling leftover index attempt found on startup: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
logger.exception(
f"Marking attempt {attempt.id} as canceled due to validation error 2"
)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
@worker_ready.connect
@@ -317,7 +291,7 @@ for bootstep in base_bootsteps:
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.docprocessing",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.periodic",
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.shared",

View File

@@ -26,7 +26,7 @@ def celery_get_unacked_length(r: Redis) -> int:
def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
"""Gets the set of task id's matching the given queue in the unacked hash.
Unacked entries belonging to the indexing queues are "prefetched", so this gives
Unacked entries belonging to the indexing queue are "prefetched", so this gives
us crucial visibility as to what tasks are in that state.
"""
tasks: set[str] = set()

View File

@@ -1,22 +0,0 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_DOCFETCHING_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# Docfetching worker configuration
worker_concurrency = CELERY_WORKER_DOCFETCHING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -1,5 +1,5 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_DOCPROCESSING_CONCURRENCY
from onyx.configs.app_configs import CELERY_WORKER_INDEXING_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
@@ -24,6 +24,6 @@ task_acks_late = shared_config.task_acks_late
# which means a duplicate run might change the task state unexpectedly
# task_track_started = True
worker_concurrency = CELERY_WORKER_DOCPROCESSING_CONCURRENCY
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -100,6 +100,24 @@ beat_task_templates: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,

View File

@@ -40,11 +40,9 @@ from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.index_attempt import delete_index_attempts
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
from onyx.db.search_settings import get_all_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
@@ -71,21 +69,13 @@ def revoke_tasks_blocking_deletion(
) -> None:
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
try:
recent_index_attempts = get_recent_attempts_for_cc_pair(
cc_pair_id=redis_connector.cc_pair_id,
search_settings_id=search_settings.id,
limit=1,
db_session=db_session,
)
if (
recent_index_attempts
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
and recent_index_attempts[0].celery_task_id
):
app.control.revoke(recent_index_attempts[0].celery_task_id)
index_payload = redis_connector_index.payload
if index_payload and index_payload.celery_task_id:
app.control.revoke(index_payload.celery_task_id)
task_logger.info(
f"Revoked indexing task {recent_index_attempts[0].celery_task_id}."
f"Revoked indexing task {index_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking indexing task")
@@ -291,16 +281,8 @@ def try_generate_document_cc_pair_cleanup_tasks(
# do not proceed if connector indexing or connector pruning are running
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
recent_index_attempts = get_recent_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=search_settings.id,
limit=1,
db_session=db_session,
)
if (
recent_index_attempts
and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS
):
redis_connector_index = redis_connector.new_index(search_settings.id)
if redis_connector_index.fenced:
raise TaskDependencyError(
"Connector deletion - Delayed (indexing in progress): "
f"cc_pair={cc_pair_id} "

View File

@@ -1,637 +0,0 @@
import multiprocessing
import os
import time
import traceback
from http import HTTPStatus
from time import sleep
import sentry_sdk
from celery import Celery
from celery import shared_task
from celery import Task
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
from onyx.background.celery.tasks.docprocessing.heartbeat import stop_heartbeat
from onyx.background.celery.tasks.docprocessing.tasks import ConnectorIndexingLogBuilder
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallback
from onyx.background.celery.tasks.models import DocProcessingContext
from onyx.background.celery.tasks.models import IndexingWatchdogTerminalStatus
from onyx.background.celery.tasks.models import SimpleJobResult
from onyx.background.indexing.job_client import SimpleJob
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.job_client import SimpleJobException
from onyx.background.indexing.run_docfetching import run_indexing_entrypoint
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
def _verify_indexing_attempt(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
) -> None:
"""
Verify that the indexing attempt exists and is in the correct state.
"""
with get_session_with_current_tenant() as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise SimpleJobException(
f"docfetching_task - IndexAttempt not found: attempt_id={index_attempt_id}",
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
)
if attempt.connector_credential_pair_id != cc_pair_id:
raise SimpleJobException(
f"docfetching_task - CC pair mismatch: "
f"expected={cc_pair_id} actual={attempt.connector_credential_pair_id}",
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
)
if attempt.search_settings_id != search_settings_id:
raise SimpleJobException(
f"docfetching_task - Search settings mismatch: "
f"expected={search_settings_id} actual={attempt.search_settings_id}",
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
)
if attempt.status not in [
IndexingStatus.NOT_STARTED,
IndexingStatus.IN_PROGRESS,
]:
raise SimpleJobException(
f"docfetching_task - Invalid attempt status: "
f"attempt_id={index_attempt_id} status={attempt.status}",
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
)
# Check for cancellation
if IndexingCoordination.check_cancellation_requested(
db_session, index_attempt_id
):
raise SimpleJobException(
f"docfetching_task - Cancellation requested: attempt_id={index_attempt_id}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
)
logger.info(
f"docfetching_task - IndexAttempt verified: "
f"attempt_id={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
def docfetching_task(
app: Celery,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
is_ee: bool,
tenant_id: str,
) -> None:
"""
This function is run in a SimpleJob as a new process. It is responsible for validating
some stuff, but basically it just calls run_indexing_entrypoint.
NOTE: if an exception is raised out of this task, the primary worker will detect
that the task transitioned to a "READY" state but the generator_complete_key doesn't exist.
This will cause the primary worker to abort the indexing attempt and clean up.
"""
# Start heartbeat for this indexing attempt
heartbeat_thread, stop_event = start_heartbeat(index_attempt_id)
try:
_docfetching_task(
app, index_attempt_id, cc_pair_id, search_settings_id, is_ee, tenant_id
)
finally:
stop_heartbeat(heartbeat_thread, stop_event) # Stop heartbeat before exiting
def _docfetching_task(
app: Celery,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
is_ee: bool,
tenant_id: str,
) -> None:
# Since connector_indexing_proxy_task spawns a new process using this function as
# the entrypoint, we init Sentry here.
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
traces_sample_rate=0.1,
)
logger.info("Sentry initialized")
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
logger.info(
f"Indexing spawned task starting: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector.new_index(search_settings_id)
# TODO: remove all fences, cause all signals to be set in postgres
if redis_connector.delete.fenced:
raise SimpleJobException(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION.code,
)
if redis_connector.stop.fenced:
raise SimpleJobException(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
)
# Verify the indexing attempt exists and is valid
# This replaces the Redis fence payload waiting
_verify_indexing_attempt(index_attempt_id, cc_pair_id, search_settings_id)
try:
with get_session_with_current_tenant() as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise SimpleJobException(
f"Index attempt not found: index_attempt={index_attempt_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
raise SimpleJobException(
f"cc_pair not found: cc_pair={cc_pair_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
# define a callback class
callback = IndexingCallback(
redis_connector,
)
logger.info(
f"Indexing spawned task running entrypoint: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
# This is where the heavy/real work happens
run_indexing_entrypoint(
app,
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
callback=callback,
)
except ConnectorValidationError:
raise SimpleJobException(
f"Indexing task failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}",
code=IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR.code,
)
except Exception as e:
logger.exception(
f"Indexing spawned task failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
# special bulletproofing ... truncate long exception messages
# for exception types that require more args, this will fail
# thus the try/except
try:
sanitized_e = type(e)(str(e)[:1024])
sanitized_e.__traceback__ = e.__traceback__
raise sanitized_e
except Exception:
raise e
logger.info(
f"Indexing spawned task finished: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
os._exit(0) # ensure process exits cleanly
def process_job_result(
job: SimpleJob,
connector_source: str | None,
redis_connector_index: RedisConnectorIndex,
log_builder: ConnectorIndexingLogBuilder,
) -> SimpleJobResult:
result = SimpleJobResult()
result.connector_source = connector_source
if job.process:
result.exit_code = job.process.exitcode
if job.status != "error":
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
return result
ignore_exitcode = False
# In EKS, there is an edge case where successful tasks return exit
# code 1 in the cloud due to the set_spawn_method not sticking.
# We've since worked around this, but the following is a safe way to
# work around this issue. Basically, we ignore the job error state
# if the completion signal is OK.
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if ignore_exitcode:
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
task_logger.warning(
log_builder.build(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...",
exit_code=str(result.exit_code),
)
)
else:
if result.exit_code is not None:
result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code)
result.exception_str = job.exception()
return result
@shared_task(
name=OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
bind=True,
acks_late=False,
track_started=True,
)
def docfetching_proxy_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str,
) -> None:
"""
This task is the entrypoint for the full indexing pipeline, which is composed of two tasks:
docfetching and docprocessing.
This task is spawned by "try_creating_indexing_task" which is called in the "check_for_indexing" task.
This task spawns a new process for a new scheduled index attempt. That
new process (which runs the docfetching_task function) does the following:
1) determines parameters of the indexing attempt (which connector indexing function to run,
start and end time, from prev checkpoint or not), then run that connector. Specifically,
connectors are responsible for reading data from an outside source and converting it to Onyx documents.
At the moment these two steps (reading external data and converting to an Onyx document)
are not parallelized in most connectors; that's a subject for future work.
Each document batch produced by step 1 is stored in the file store, and a docprocessing task is spawned
to process it. docprocessing involves the steps listed below.
2) upserts documents to postgres (index_doc_batch_prepare)
3) chunks each document (optionally adds context for contextual rag)
4) embeds chunks (embed_chunks_with_failure_handling) via a call to the model server
5) write chunks to vespa (write_chunks_to_vector_db_with_backoff)
6) update document and indexing metadata in postgres
7) pulls all document IDs from the source and compares those IDs to locally stored documents and deletes
all locally stored IDs missing from the most recently pulled document ID list
Some important notes:
Invariants:
- docfetching proxy tasks are spawned by check_for_indexing. The proxy then runs the docfetching_task wrapped in a watchdog.
The watchdog is responsible for monitoring the docfetching_task and marking the index attempt as failed
if it is not making progress.
- All docprocessing tasks are spawned by a docfetching task.
- all docfetching tasks, docprocessing tasks, and document batches in the file store are
associated with a specific index attempt.
- the index attempt status is the source of truth for what is currently happening with the index attempt.
It is coupled with the creation/running of docfetching and docprocessing tasks as much as possible.
How we deal with failures/ partial indexing:
- non-checkpointed connectors/ new runs in general => delete the old document batches from the file store and do the new run
- checkpointed connectors + resuming from checkpoint => reissue the old document batches and do a new run
Misc:
- most inter-process communication is handled in postgres, some is still in redis and we're trying to remove it
- Heartbeat spawned in docfetching and docprocessing is how check_for_indexing monitors liveliness
- progress based liveliness check: if nothing is done in 3-6 hours, mark the attempt as failed
- TODO: task level timeouts (i.e. a connector stuck in an infinite loop)
Comments below are from the old version and some may no longer be valid.
TODO(rkuo): refactor this so that there is a single return path where we canonically
log the result of running this function.
Some more Richard notes:
celery out of process task execution strategy is pool=prefork, but it uses fork,
and forking is inherently unstable.
To work around this, we use pool=threads and proxy our work to a spawned task.
acks_late must be set to False. Otherwise, celery's visibility timeout will
cause any task that runs longer than the timeout to be redispatched by the broker.
There appears to be no good workaround for this, so we need to handle redispatching
manually.
NOTE: we try/except all db access in this function because as a watchdog, this function
needs to be extremely stable.
"""
# TODO: remove dependence on Redis
start = time.monotonic()
result = SimpleJobResult()
ctx = DocProcessingContext(
tenant_id=tenant_id,
cc_pair_id=cc_pair_id,
search_settings_id=search_settings_id,
index_attempt_id=index_attempt_id,
)
log_builder = ConnectorIndexingLogBuilder(ctx)
task_logger.info(
log_builder.build(
"Indexing watchdog - starting",
mp_start_method=str(multiprocessing.get_start_method()),
)
)
if not self.request.id:
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
task_logger.info(f"submitting docfetching_task with tenant_id={tenant_id}")
job = client.submit(
docfetching_task,
self.app,
index_attempt_id,
cc_pair_id,
search_settings_id,
global_version.is_ee_version(),
tenant_id,
)
if not job or not job.process:
result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
status=str(result.status.value),
exit_code=str(result.exit_code),
)
)
return
# Ensure the process has moved out of the starting state
num_waits = 0
while True:
if num_waits > 15:
result.status = IndexingWatchdogTerminalStatus.SPAWN_NOT_ALIVE
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
status=str(result.status.value),
exit_code=str(result.exit_code),
)
)
job.release()
return
if job.process.is_alive() or job.process.exitcode is not None:
break
sleep(1)
num_waits += 1
task_logger.info(
log_builder.build(
"Indexing watchdog - spawn succeeded",
pid=str(job.process.pid),
)
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# Track the last time memory info was emitted
last_memory_emit_time = 0.0
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session=db_session,
index_attempt_id=index_attempt_id,
eager_load_cc_pair=True,
)
if not index_attempt:
raise RuntimeError("Index attempt not found")
result.connector_source = (
index_attempt.connector_credential_pair.connector.source.value
)
while True:
sleep(5)
time.monotonic()
# if the job is done, clean up and break
if job.done():
try:
result = process_job_result(
job, result.connector_source, redis_connector_index, log_builder
)
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - spawned task exceptioned"
)
)
finally:
job.release()
break
# log the memory usage for tracking down memory leaks / connector-specific memory issues
pid = job.process.pid
if pid is not None:
# Only emit memory info once per minute (60 seconds)
current_time = time.monotonic()
if current_time - last_memory_emit_time >= 60.0:
emit_process_memory(
pid,
"indexing_worker",
{
"cc_pair_id": cc_pair_id,
"search_settings_id": search_settings_id,
"index_attempt_id": index_attempt_id,
},
)
last_memory_emit_time = current_time
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception looking up index attempt"
)
)
continue
except Exception as e:
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
if isinstance(e, ConnectorValidationError):
# No need to expose full stack trace for validation errors
result.exception_str = str(e)
else:
result.exception_str = traceback.format_exc()
# handle exit and reporting
elapsed = time.monotonic() - start
if result.exception_str is not None:
# print with exception
try:
with get_session_with_current_tenant() as db_session:
failure_reason = (
f"Spawned task exceptioned: exit_code={result.exit_code}"
)
mark_attempt_failed(
ctx.index_attempt_id,
db_session,
failure_reason=failure_reason,
full_exception_trace=result.exception_str,
)
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
normalized_exception_str = "None"
if result.exception_str:
normalized_exception_str = result.exception_str.replace(
"\n", "\\n"
).replace('"', '\\"')
task_logger.warning(
log_builder.build(
"Indexing watchdog - finished",
source=result.connector_source,
status=result.status.value,
exit_code=str(result.exit_code),
exception=f'"{normalized_exception_str}"',
elapsed=f"{elapsed:.2f}s",
)
)
raise RuntimeError(f"Exception encountered: traceback={result.exception_str}")
# print without exception
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
try:
with get_session_with_current_tenant() as db_session:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to termination signal"
)
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as canceled"
)
)
job.cancel()
elif result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT:
try:
with get_session_with_current_tenant() as db_session:
mark_attempt_failed(
index_attempt_id,
db_session,
"Indexing watchdog - activity timeout exceeded: "
f"attempt={index_attempt_id} "
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
except Exception:
logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
job.cancel()
else:
pass
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
source=result.connector_source,
status=str(result.status.value),
exit_code=str(result.exit_code),
elapsed=f"{elapsed:.2f}s",
)
)

View File

@@ -1,36 +0,0 @@
import threading
from sqlalchemy import update
from onyx.configs.constants import INDEXING_WORKER_HEARTBEAT_INTERVAL
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import IndexAttempt
def start_heartbeat(index_attempt_id: int) -> tuple[threading.Thread, threading.Event]:
"""Start a heartbeat thread for the given index attempt"""
stop_event = threading.Event()
def heartbeat_loop() -> None:
while not stop_event.wait(INDEXING_WORKER_HEARTBEAT_INTERVAL):
try:
with get_session_with_current_tenant() as db_session:
db_session.execute(
update(IndexAttempt)
.where(IndexAttempt.id == index_attempt_id)
.values(heartbeat_counter=IndexAttempt.heartbeat_counter + 1)
)
db_session.commit()
except Exception:
# Silently continue if heartbeat fails
pass
thread = threading.Thread(target=heartbeat_loop, daemon=True)
thread.start()
return thread, stop_event
def stop_heartbeat(thread: threading.Thread, stop_event: threading.Event) -> None:
"""Stop the heartbeat thread"""
stop_event.set()
thread.join(timeout=5) # Wait up to 5 seconds for clean shutdown

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,10 @@
import time
from datetime import datetime
from datetime import timezone
from uuid import uuid4
from typing import Any
from typing import cast
import redis
from celery import Celery
from redis import Redis
from redis.exceptions import LockError
@@ -10,6 +12,8 @@ from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
@@ -17,19 +21,27 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import create_index_attempt
from onyx.db.index_attempt import delete_index_attempt
from onyx.db.index_attempt import get_all_index_attempts_by_status
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.logger import setup_logger
@@ -38,6 +50,54 @@ logger = setup_logger()
NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE = 5
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = index_attempt in non terminal state
# outer = r.fence_key down
# check the db for index attempts in a non terminal state
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
# if the fence is down / doesn't exist, possible error but not confirmed
if r.exists(fence_key):
continue
# Between the time the attempts are first looked up and the time we see the fence down,
# the attempt may have completed and taken down the fence normally.
# We need to double check that the index attempt is still in a non terminal state
# and matches the original state, which confirms we are really in a bad state.
attempt_2 = get_index_attempt(db_session, attempt.id)
if not attempt_2:
continue
if attempt.status != attempt_2.status:
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
class IndexingCallbackBase(IndexingHeartbeatInterface):
PARENT_CHECK_INTERVAL = 60
@@ -63,9 +123,10 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
self.last_parent_check = time.monotonic()
def should_stop(self) -> bool:
# Check if the associated indexing attempt has been cancelled
# TODO: Pass index_attempt_id to the callback and check cancellation using the db
return bool(self.redis_connector.stop.fenced)
if self.redis_connector.stop.fenced:
return True
return False
def progress(self, tag: str, amount: int) -> None:
"""Amount isn't used yet."""
@@ -110,28 +171,186 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
raise
# NOTE: we're in the process of removing all fences from indexing; this will
# eventually no longer be used. For now, it is used only for connector pausing.
class IndexingCallback(IndexingHeartbeatInterface):
class IndexingCallback(IndexingCallbackBase):
def __init__(
self,
parent_pid: int,
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
redis_connector_index: RedisConnectorIndex,
):
self.redis_connector = redis_connector
super().__init__(parent_pid, redis_connector, redis_lock, redis_client)
def should_stop(self) -> bool:
# Check if the associated indexing attempt has been cancelled
# TODO: Pass index_attempt_id to the callback and check cancellation using the db
return bool(self.redis_connector.stop.fenced)
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
# included to satisfy old interface
def progress(self, tag: str, amount: int) -> None:
pass
self.redis_connector_index.set_active()
self.redis_connector_index.set_connector_active()
super().progress(tag, amount)
self.redis_client.incrby(
self.redis_connector_index.generator_progress_key, amount
)
# NOTE: The validate_indexing_fence and validate_indexing_fences functions have been removed
# as they are no longer needed with database-based coordination. The new validation is
# handled by validate_active_indexing_attempts in the main indexing tasks module.
def validate_indexing_fence(
tenant_id: str,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"validate_indexing_fence - could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# check to see if the fence/payload exists
if not redis_connector_index.fenced:
return
payload = redis_connector_index.payload
if not payload:
return
# OK, there's actually something for us to validate
if payload.celery_task_id is None:
# the fence is just barely set up.
if redis_connector_index.active():
return
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
f"validate_indexing_fence - "
f"Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector_index.reset()
return
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
if found:
# the celery task exists in the redis queue
redis_connector_index.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within the indexing worker
redis_connector_index.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector_index.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
f"index_attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"fence={fence_key}"
)
if payload.index_attempt_id:
try:
mark_attempt_failed(
payload.index_attempt_id,
db_session,
"validate_indexing_fence - Canceling index attempt due to missing celery tasks: "
f"index_attempt={payload.index_attempt_id}",
)
except Exception:
logger.exception(
"validate_indexing_fence - Exception while marking index attempt as failed: "
f"index_attempt={payload.index_attempt_id}",
)
redis_connector_index.reset()
return
def validate_indexing_fences(
tenant_id: str,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
"""Validates all indexing fences for this tenant ... aka makes sure
indexing tasks sent to celery are still in flight.
"""
reserved_indexing_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
# Use replica for this because the worst thing that happens
# is that we don't run the validation on this pass
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
continue
with get_session_with_current_tenant() as db_session:
validate_indexing_fence(
tenant_id,
key_bytes,
reserved_indexing_tasks,
r_celery,
db_session,
)
lock_beat.reacquire()
return
def is_in_repeated_error_state(
@@ -195,12 +414,10 @@ def should_index(
)
# uncomment for debugging
task_logger.info(
f"_should_index: "
f"cc_pair={cc_pair.id} "
f"connector={cc_pair.connector_id} "
f"refresh_freq={connector.refresh_freq}"
)
# task_logger.info(f"_should_index: "
# f"cc_pair={cc_pair.id} "
# f"connector={cc_pair.connector_id} "
# f"refresh_freq={connector.refresh_freq}")
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
@@ -300,7 +517,7 @@ def should_index(
return True
def try_creating_docfetching_task(
def try_creating_indexing_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
@@ -314,11 +531,10 @@ def try_creating_docfetching_task(
Does not check for scheduling related conditions as this function
is used to trigger indexing immediately.
Now uses database-based coordination instead of Redis fencing.
"""
LOCK_TIMEOUT = 30
index_attempt_id: int | None = None
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
@@ -331,42 +547,61 @@ def try_creating_docfetching_task(
if not acquired:
return None
index_attempt_id = None
redis_connector_index: RedisConnectorIndex
try:
# Basic status checks
redis_connector = RedisConnector(tenant_id, cc_pair.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
# skip if already indexing
if redis_connector_index.fenced:
return None
# skip indexing if the cc_pair is deleting
if redis_connector.delete.fenced:
return None
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# Generate custom task ID for tracking
custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}"
# add a long running generator task to the queue
redis_connector_index.generator_clear()
# Try to create a new index attempt using database coordination
# This replaces the Redis fencing mechanism
index_attempt_id = IndexingCoordination.try_create_index_attempt(
db_session=db_session,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
celery_task_id=custom_task_id,
from_beginning=reindex,
# set a basic fence to start
payload = RedisConnectorIndexPayload(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=None,
)
if index_attempt_id is None:
# Another indexing attempt is already running
return None
redis_connector_index.set_active()
redis_connector_index.set_fence(payload)
# create the index attempt for tracking purposes
# code elsewhere checks for index attempts without an associated redis key
# and cleans them up
# therefore we must create the attempt and the task after the fence goes up
index_attempt_id = create_index_attempt(
cc_pair.id,
search_settings.id,
from_beginning=reindex,
db_session=db_session,
)
custom_task_id = redis_connector_index.generate_generator_task_id()
# Determine which queue to use based on whether this is a user file
# TODO: at the moment the indexing pipeline is
# shared between user files and connectors
queue = (
OnyxCeleryQueues.USER_FILES_INDEXING
if cc_pair.is_user_file
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
else OnyxCeleryQueues.CONNECTOR_INDEXING
)
# Send the task to Celery
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
OnyxCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
@@ -378,18 +613,14 @@ def try_creating_docfetching_task(
priority=OnyxCeleryPriority.MEDIUM,
)
if not result:
raise RuntimeError("send_task for connector_doc_fetching_task failed.")
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
task_logger.info(
f"Created docfetching task: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id} "
f"attempt_id={index_attempt_id} "
f"celery_task_id={custom_task_id}"
)
return index_attempt_id
# now fill out the fence with the rest of the data
redis_connector_index.set_active()
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
except Exception:
task_logger.exception(
f"try_creating_indexing_task - Unexpected exception: "
@@ -397,10 +628,9 @@ def try_creating_docfetching_task(
f"search_settings={search_settings.id}"
)
# Clean up on failure
if index_attempt_id is not None:
mark_attempt_failed(index_attempt_id, db_session)
delete_index_attempt(db_session, index_attempt_id)
redis_connector_index.set_fence(None)
return None
finally:
if lock.owned():

View File

@@ -1,110 +0,0 @@
from enum import Enum
from pydantic import BaseModel
class DocProcessingContext(BaseModel):
tenant_id: str
cc_pair_id: int
search_settings_id: int
index_attempt_id: int
class IndexingWatchdogTerminalStatus(str, Enum):
"""The different statuses the watchdog can finish with.
TODO: create broader success/failure/abort categories
"""
UNDEFINED = "undefined"
SUCCEEDED = "succeeded"
SPAWN_FAILED = "spawn_failed" # connector spawn failed
SPAWN_NOT_ALIVE = (
"spawn_not_alive" # spawn succeeded but process did not come alive
)
BLOCKED_BY_DELETION = "blocked_by_deletion"
BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
FENCE_NOT_FOUND = "fence_not_found" # fence does not exist
FENCE_READINESS_TIMEOUT = (
"fence_readiness_timeout" # fence exists but wasn't ready within the timeout
)
FENCE_MISMATCH = "fence_mismatch" # task and fence metadata mismatch
TASK_ALREADY_RUNNING = "task_already_running" # task appears to be running already
INDEX_ATTEMPT_MISMATCH = (
"index_attempt_mismatch" # expected index attempt metadata not found in db
)
CONNECTOR_VALIDATION_ERROR = (
"connector_validation_error" # the connector validation failed
)
CONNECTOR_EXCEPTIONED = "connector_exceptioned" # the connector itself exceptioned
WATCHDOG_EXCEPTIONED = "watchdog_exceptioned" # the watchdog exceptioned
# the watchdog received a termination signal
TERMINATED_BY_SIGNAL = "terminated_by_signal"
# the watchdog terminated the task due to no activity
TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"
# NOTE: this may actually be the same as SIGKILL, but parsed differently by python
# consolidate once we know more
OUT_OF_MEMORY = "out_of_memory"
PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"
@property
def code(self) -> int:
_ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR: 247,
IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT: 251,
IndexingWatchdogTerminalStatus.FENCE_MISMATCH: 252,
IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING: 253,
IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH: 254,
IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED: 255,
}
return _ENUM_TO_CODE[self]
@classmethod
def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
_CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
-9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
137: IndexingWatchdogTerminalStatus.OUT_OF_MEMORY,
247: IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR,
248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
251: IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT,
252: IndexingWatchdogTerminalStatus.FENCE_MISMATCH,
253: IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING,
254: IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH,
255: IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED,
}
if code in _CODE_TO_ENUM:
return _CODE_TO_ENUM[code]
return IndexingWatchdogTerminalStatus.UNDEFINED
class SimpleJobResult:
"""The data we want to have when the watchdog finishes"""
def __init__(self) -> None:
self.status = IndexingWatchdogTerminalStatus.UNDEFINED
self.connector_source = None
self.exit_code = None
self.exception_str = None
status: IndexingWatchdogTerminalStatus
connector_source: str | None
exit_code: int | None
exception_str: str | None

View File

@@ -147,7 +147,7 @@ def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
metrics = []
queue_mappings = {
"celery_queue_length": "celery",
"docprocessing_queue_length": "docprocessing",
"indexing_queue_length": "indexing",
"sync_queue_length": "sync",
"deletion_queue_length": "deletion",
"pruning_queue_length": "pruning",
@@ -882,13 +882,7 @@ def monitor_celery_queues_helper(
r_celery = task.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_docfetching = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
)
n_docprocessing = celery_get_queue_length(OnyxCeleryQueues.DOCPROCESSING, r_celery)
n_user_files_indexing = celery_get_queue_length(
OnyxCeleryQueues.USER_FILES_INDEXING, r_celery
)
n_indexing = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery)
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
@@ -902,20 +896,14 @@ def monitor_celery_queues_helper(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
n_docfetching_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
)
n_docprocessing_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.DOCPROCESSING, r_celery
n_indexing_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"docfetching={n_docfetching} "
f"docfetching_prefetched={len(n_docfetching_prefetched)} "
f"docprocessing={n_docprocessing} "
f"docprocessing_prefetched={len(n_docprocessing_prefetched)} "
f"user_files_indexing={n_user_files_indexing} "
f"indexing={n_indexing} "
f"indexing_prefetched={len(n_indexing_prefetched)} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "

View File

@@ -22,7 +22,7 @@ from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallbackBase
from onyx.background.celery.tasks.indexing.utils import IndexingCallbackBase
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
@@ -138,11 +138,8 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
# if we've never indexed, we can't prune
return False
# if never pruned, use the connector creation time. We could also
# compute the completion time of the first successful index attempt, but
# that is a reasonably heavy operation. This is a reasonable approximation —
# in the worst case, we'll prune a little bit earlier than we should.
last_pruned = cc_pair.connector.time_created
# if never pruned, use the last time the connector indexed successfully
last_pruned = cc_pair.last_successful_index_time
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
@@ -176,9 +173,6 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
# but pruning only kicks off once per hour
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
task_logger.info("Checking for pruning due")
cc_pair_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
@@ -193,18 +187,15 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
cc_pair_id=cc_pair_id,
)
if not cc_pair:
logger.error(f"CC pair not found: {cc_pair_id}")
continue
if not _is_pruning_due(cc_pair):
logger.info(f"CC pair not due for pruning: {cc_pair_id}")
continue
payload_id = try_creating_prune_generator_task(
self.app, cc_pair, db_session, r, tenant_id
)
if not payload_id:
logger.info(f"Pruning not created: {cc_pair_id}")
continue
task_logger.info(
@@ -273,8 +264,6 @@ def try_creating_prune_generator_task(
is used to trigger prunes immediately, e.g. via the web ui.
"""
logger.info(f"try_creating_prune_generator_task: cc_pair={cc_pair.id}")
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if not ALLOW_SIMULTANEOUS_PRUNING:
@@ -298,30 +287,18 @@ def try_creating_prune_generator_task(
try:
# skip pruning if already pruning
if redis_connector.prune.fenced:
logger.info(
f"try_creating_prune_generator_task: cc_pair={cc_pair.id} already pruning"
)
return None
# skip pruning if the cc_pair is deleting
if redis_connector.delete.fenced:
logger.info(
f"try_creating_prune_generator_task: cc_pair={cc_pair.id} deleting"
)
return None
# skip pruning if doc permissions sync is running
if redis_connector.permissions.fenced:
logger.info(
f"try_creating_prune_generator_task: cc_pair={cc_pair.id} permissions sync running"
)
return None
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
logger.info(
f"try_creating_prune_generator_task: cc_pair={cc_pair.id} deleting"
)
return None
# add a long running generator task to the queue
@@ -464,7 +441,7 @@ def connector_pruning_generator_task(
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.cc_pair_id}",
OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
thread_local=False,
)

View File

@@ -1,178 +0,0 @@
import time
from typing import cast
from uuid import uuid4
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.document import construct_document_id_select_by_needs_sync
from onyx.db.document import count_documents_by_needs_sync
from onyx.utils.logger import setup_logger
# Redis keys for document sync tracking
DOCUMENT_SYNC_PREFIX = "documentsync"
DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence"
DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset"
logger = setup_logger()
def is_document_sync_fenced(r: Redis) -> bool:
"""Check if document sync tasks are currently in progress."""
return bool(r.exists(DOCUMENT_SYNC_FENCE_KEY))
def get_document_sync_payload(r: Redis) -> int | None:
"""Get the initial number of tasks that were created."""
bytes_result = r.get(DOCUMENT_SYNC_FENCE_KEY)
if bytes_result is None:
return None
return int(cast(int, bytes_result))
def get_document_sync_remaining(r: Redis) -> int:
"""Get the number of tasks still pending completion."""
return cast(int, r.scard(DOCUMENT_SYNC_TASKSET_KEY))
def set_document_sync_fence(r: Redis, payload: int | None) -> None:
"""Set up the fence and register with active fences."""
if payload is None:
r.srem(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
r.delete(DOCUMENT_SYNC_FENCE_KEY)
return
r.set(DOCUMENT_SYNC_FENCE_KEY, payload)
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
def delete_document_sync_taskset(r: Redis) -> None:
"""Clear the document sync taskset."""
r.delete(DOCUMENT_SYNC_TASKSET_KEY)
def reset_document_sync(r: Redis) -> None:
"""Reset all document sync tracking data."""
r.srem(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
r.delete(DOCUMENT_SYNC_TASKSET_KEY)
r.delete(DOCUMENT_SYNC_FENCE_KEY)
def generate_document_sync_tasks(
r: Redis,
max_tasks: int,
celery_app: Celery,
db_session: Session,
lock: RedisLock,
tenant_id: str,
) -> tuple[int, int]:
"""Generate sync tasks for all documents that need syncing.
Args:
r: Redis client
max_tasks: Maximum number of tasks to generate
celery_app: Celery application instance
db_session: Database session
lock: Redis lock for coordination
tenant_id: Tenant identifier
Returns:
tuple[int, int]: (tasks_generated, total_docs_found)
"""
last_lock_time = time.monotonic()
num_tasks_sent = 0
num_docs = 0
# Get all documents that need syncing
stmt = construct_document_id_select_by_needs_sync()
for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
doc_id = cast(str, doc_id)
current_time = time.monotonic()
# Reacquire lock periodically to prevent timeout
if current_time - last_lock_time >= (CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4):
lock.reacquire()
last_lock_time = current_time
num_docs += 1
# Create a unique task ID
custom_task_id = f"{DOCUMENT_SYNC_PREFIX}_{uuid4()}"
# Add to the tracking taskset in Redis BEFORE creating the celery task
r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id)
# Create the Celery task
celery_app.send_task(
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
kwargs=dict(document_id=doc_id, tenant_id=tenant_id),
queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
ignore_result=True,
)
num_tasks_sent += 1
if num_tasks_sent >= max_tasks:
break
return num_tasks_sent, num_docs
def try_generate_stale_document_sync_tasks(
celery_app: Celery,
max_tasks: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
) -> int | None:
# the fence is up, do nothing
if is_document_sync_fenced(r):
return None
# add tasks to celery and build up the task set to monitor in redis
stale_doc_count = count_documents_by_needs_sync(db_session)
if stale_doc_count == 0:
logger.info("No stale documents found. Skipping sync tasks generation.")
return None
logger.info(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks in one batch."
)
logger.info("generate_document_sync_tasks starting for all documents.")
# Generate all tasks in one pass
result = generate_document_sync_tasks(
r, max_tasks, celery_app, db_session, lock_beat, tenant_id
)
if result is None:
return None
tasks_generated, total_docs = result
if tasks_generated >= max_tasks:
logger.info(
f"generate_document_sync_tasks reached the task generation limit: "
f"tasks_generated={tasks_generated} max_tasks={max_tasks}"
)
else:
logger.info(
f"generate_document_sync_tasks finished for all documents. "
f"tasks_generated={tasks_generated} total_docs_found={total_docs}"
)
set_document_sync_fence(r, tasks_generated)
return tasks_generated

View File

@@ -20,19 +20,14 @@ from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocument
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_FENCE_KEY
from onyx.background.celery.tasks.vespa.document_sync import get_document_sync_payload
from onyx.background.celery.tasks.vespa.document_sync import get_document_sync_remaining
from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync
from onyx.background.celery.tasks.vespa.document_sync import (
try_generate_stale_document_sync_tasks,
)
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import count_documents_by_needs_sync
from onyx.db.document import get_document
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import delete_document_set
@@ -52,6 +47,10 @@ from onyx.db.sync_record import update_sync_record_status
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_credential_pair import (
RedisGlobalConnectorCredentialPair,
)
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
@@ -167,11 +166,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
continue
key_str = key_bytes.decode("utf-8")
# NOTE: removing the "Redis*" classes, prefer to just have functions to
# do these things going forward. In short, things should generally be like the doc
# sync task rather than the others
if key_str == DOCUMENT_SYNC_FENCE_KEY:
monitor_document_sync_taskset(r)
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
monitor_connector_taskset(r)
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
with get_session_with_current_tenant() as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
@@ -207,6 +203,82 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
return True
def try_generate_stale_document_sync_tasks(
celery_app: Celery,
max_tasks: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
) -> int | None:
# the fence is up, do nothing
redis_global_ccpair = RedisGlobalConnectorCredentialPair(r)
if redis_global_ccpair.fenced:
return None
redis_global_ccpair.delete_taskset()
# add tasks to celery and build up the task set to monitor in redis
stale_doc_count = count_documents_by_needs_sync(db_session)
if stale_doc_count == 0:
return None
task_logger.info(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
task_logger.info(
"RedisConnector.generate_tasks starting by cc_pair. "
"Documents spanning multiple cc_pairs will only be synced once."
)
docs_to_skip: set[str] = set()
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
total_tasks_generated = 0
tasks_remaining = max_tasks
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
lock_beat.reacquire()
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
rc.set_skip_docs(docs_to_skip)
result = rc.generate_tasks(
tasks_remaining, celery_app, db_session, r, lock_beat, tenant_id
)
if result is None:
continue
if result[1] == 0:
continue
task_logger.info(
f"RedisConnector.generate_tasks finished for single cc_pair. "
f"cc_pair={cc_pair.id} tasks_generated={result[0]} tasks_possible={result[1]}"
)
total_tasks_generated += result[0]
tasks_remaining -= result[0]
if tasks_remaining <= 0:
break
if tasks_remaining <= 0:
task_logger.info(
f"RedisConnector.generate_tasks reached the task generation limit: "
f"total_tasks_generated={total_tasks_generated} max_tasks={max_tasks}"
)
else:
task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
)
redis_global_ccpair.set_fence(total_tasks_generated)
return total_tasks_generated
def try_generate_document_set_sync_tasks(
celery_app: Celery,
document_set_id: int,
@@ -361,18 +433,19 @@ def try_generate_user_group_sync_tasks(
return tasks_generated
def monitor_document_sync_taskset(r: Redis) -> None:
initial_count = get_document_sync_payload(r)
def monitor_connector_taskset(r: Redis) -> None:
redis_global_ccpair = RedisGlobalConnectorCredentialPair(r)
initial_count = redis_global_ccpair.payload
if initial_count is None:
return
remaining = get_document_sync_remaining(r)
remaining = redis_global_ccpair.get_remaining()
task_logger.info(
f"Document sync progress: remaining={remaining} initial={initial_count}"
f"Stale document sync progress: remaining={remaining} initial={initial_count}"
)
if remaining == 0:
reset_document_sync(r)
task_logger.info(f"Successfully synced all documents. count={initial_count}")
redis_global_ccpair.reset()
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
def monitor_document_set_taskset(

View File

@@ -1,18 +0,0 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""
from celery import Celery
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.docprocessing import celery_app
return celery_app
app = get_app()

View File

@@ -10,7 +10,7 @@ set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.docfetching import celery_app
from onyx.background.celery.apps.indexing import celery_app
return celery_app

View File

@@ -33,7 +33,7 @@ def save_checkpoint(
"""Save a checkpoint for a given index attempt to the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store()
file_store = get_default_file_store(db_session)
file_store.save_file(
content=BytesIO(checkpoint.model_dump_json().encode()),
display_name=checkpoint_pointer,
@@ -52,11 +52,11 @@ def save_checkpoint(
def load_checkpoint(
index_attempt_id: int, connector: BaseConnector
db_session: Session, index_attempt_id: int, connector: BaseConnector
) -> ConnectorCheckpoint:
"""Load a checkpoint for a given index attempt from the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store()
file_store = get_default_file_store(db_session)
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
checkpoint_data = checkpoint_io.read().decode("utf-8")
if isinstance(connector, CheckpointedConnector):
@@ -71,7 +71,7 @@ def get_latest_valid_checkpoint(
window_start: datetime,
window_end: datetime,
connector: BaseConnector,
) -> tuple[ConnectorCheckpoint, bool]:
) -> ConnectorCheckpoint:
"""Get the latest valid checkpoint for a given connector credential pair"""
checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
@@ -83,7 +83,7 @@ def get_latest_valid_checkpoint(
# don't keep using checkpoints if we've had a bunch of failed attempts in a row
# where we make no progress. Only do this if we have had at least
# _NUM_RECENT_ATTEMPTS_TO_CONSIDER completed attempts.
if len(checkpoint_candidates) >= _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
if len(checkpoint_candidates) == _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
had_any_progress = False
for candidate in checkpoint_candidates:
if (
@@ -99,7 +99,7 @@ def get_latest_valid_checkpoint(
f"found for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
"from scratch."
)
return connector.build_dummy_checkpoint(), False
return connector.build_dummy_checkpoint()
# filter out any candidates that don't meet the criteria
checkpoint_candidates = [
@@ -140,10 +140,11 @@ def get_latest_valid_checkpoint(
logger.info(
f"No valid checkpoint found for cc_pair={cc_pair_id}. Starting from scratch."
)
return checkpoint, False
return checkpoint
try:
previous_checkpoint = load_checkpoint(
db_session=db_session,
index_attempt_id=latest_valid_checkpoint_candidate.id,
connector=connector,
)
@@ -152,14 +153,14 @@ def get_latest_valid_checkpoint(
f"Failed to load checkpoint from previous failed attempt with ID "
f"{latest_valid_checkpoint_candidate.id}. Falling back to default checkpoint."
)
return checkpoint, False
return checkpoint
logger.info(
f"Using checkpoint from previous failed attempt with ID "
f"{latest_valid_checkpoint_candidate.id}. Previous checkpoint: "
f"{previous_checkpoint}"
)
return previous_checkpoint, True
return previous_checkpoint
def get_index_attempts_with_old_checkpoints(
@@ -200,7 +201,7 @@ def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
if not index_attempt.checkpoint_pointer:
return None
file_store = get_default_file_store()
file_store = get_default_file_store(db_session)
file_store.delete_file(index_attempt.checkpoint_pointer)
index_attempt.checkpoint_pointer = None

View File

@@ -1,4 +1,3 @@
import sys
import time
import traceback
from collections import defaultdict
@@ -6,7 +5,7 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from celery import Celery
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.access.access import source_should_fetch_permissions_during_indexing
@@ -19,25 +18,18 @@ from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
from onyx.connectors.models import DocExtractionContext
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import TextSection
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
from onyx.db.connector_credential_pair import update_connector_credential_pair
@@ -57,16 +49,13 @@ from onyx.db.index_attempt import mark_attempt_partially_succeeded
from onyx.db.index_attempt import mark_attempt_succeeded
from onyx.db.index_attempt import transition_attempt_to_in_progress
from onyx.db.index_attempt import update_docs_indexed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.document_index.factory import get_default_document_index
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
@@ -79,7 +68,7 @@ from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(propagate=False)
logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
@@ -157,10 +146,6 @@ def _get_connector_runner(
def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
cleaned_batch = []
for doc in doc_batch:
if sys.getsizeof(doc) > MAX_FILE_SIZE_BYTES:
logger.warning(
f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}"
)
cleaned_doc = doc.model_copy()
# Postgres cannot handle NUL characters in text fields
@@ -195,11 +180,25 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
return cleaned_batch
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
class RunIndexingContext(BaseModel):
index_name: str
cc_pair_id: int
connector_id: int
credential_id: int
source: DocumentSource
earliest_index_time: float
from_beginning: bool
is_primary: bool
should_fetch_permissions_during_indexing: bool
search_settings_status: IndexModelStatus
def _check_connector_and_attempt_status(
db_session_temp: Session,
cc_pair_id: int,
search_settings_status: IndexModelStatus,
index_attempt_id: int,
db_session_temp: Session, ctx: RunIndexingContext, index_attempt_id: int
) -> None:
"""
Checks the status of the connector credential pair and index attempt.
@@ -207,34 +206,27 @@ def _check_connector_and_attempt_status(
"""
cc_pair_loop = get_connector_credential_pair_from_id(
db_session_temp,
cc_pair_id,
ctx.cc_pair_id,
)
if not cc_pair_loop:
raise RuntimeError(f"CC pair {cc_pair_id} not found in DB.")
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
if (
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
and search_settings_status != IndexModelStatus.FUTURE
and ctx.search_settings_status != IndexModelStatus.FUTURE
) or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING:
raise ConnectorStopSignal(f"Connector {cc_pair_loop.status.value.lower()}")
raise RuntimeError("Connector was disabled mid run")
index_attempt_loop = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt_loop:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
if index_attempt_loop.status == IndexingStatus.CANCELED:
raise ConnectorStopSignal(f"Index attempt {index_attempt_id} was canceled")
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
raise RuntimeError(
f"Index Attempt is not running, status is {index_attempt_loop.status}"
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
)
if index_attempt_loop.celery_task_id is None:
raise RuntimeError(f"Index attempt {index_attempt_id} has no celery task id")
# TODO: delete from here if ends up unused
def _check_failure_threshold(
total_failures: int,
document_count: int,
@@ -265,9 +257,6 @@ def _check_failure_threshold(
)
# NOTE: this is the old run_indexing function that the new decoupled approach
# is based on. Leaving this for comparison purposes, but if you see this comment
# has been here for >1 month, please delete this function.
def _run_indexing(
db_session: Session,
index_attempt_id: int,
@@ -282,12 +271,7 @@ def _run_indexing(
start_time = time.monotonic() # jsut used for logging
with get_session_with_current_tenant() as db_session_temp:
index_attempt_start = get_index_attempt(
db_session_temp,
index_attempt_id,
eager_load_cc_pair=True,
eager_load_search_settings=True,
)
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt_start:
raise ValueError(
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
@@ -308,7 +292,7 @@ def _run_indexing(
index_attempt_start.connector_credential_pair.last_successful_index_time
is not None
)
ctx = DocExtractionContext(
ctx = RunIndexingContext(
index_name=index_attempt_start.search_settings.index_name,
cc_pair_id=index_attempt_start.connector_credential_pair.id,
connector_id=db_connector.id,
@@ -333,7 +317,6 @@ def _run_indexing(
and (from_beginning or not has_successful_attempt)
),
search_settings_status=index_attempt_start.search_settings.status,
doc_extraction_complete_batch_num=None,
)
last_successful_index_poll_range_end = (
@@ -401,6 +384,19 @@ def _run_indexing(
httpx_client=HttpxPool.get("vespa"),
)
indexing_pipeline = build_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=(
ctx.from_beginning
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
),
db_session=db_session,
tenant_id=tenant_id,
callback=callback,
)
# Initialize memory tracer. NOTE: won't actually do anything if
# `INDEXING_TRACER_INTERVAL` is 0.
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
@@ -420,9 +416,7 @@ def _run_indexing(
index_attempt: IndexAttempt | None = None
try:
with get_session_with_current_tenant() as db_session_temp:
index_attempt = get_index_attempt(
db_session_temp, index_attempt_id, eager_load_cc_pair=True
)
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
@@ -445,7 +439,7 @@ def _run_indexing(
):
checkpoint = connector_runner.connector.build_dummy_checkpoint()
else:
checkpoint, _ = get_latest_valid_checkpoint(
checkpoint = get_latest_valid_checkpoint(
db_session=db_session_temp,
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
@@ -502,10 +496,7 @@ def _run_indexing(
with get_session_with_current_tenant() as db_session_temp:
# will exception if the connector/index attempt is marked as paused/failed
_check_connector_and_attempt_status(
db_session_temp,
ctx.cc_pair_id,
ctx.search_settings_status,
index_attempt_id,
db_session_temp, ctx, index_attempt_id
)
# save record of any failures at the connector level
@@ -563,16 +554,7 @@ def _run_indexing(
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=(
ctx.from_beginning
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
),
db_session=db_session,
tenant_id=tenant_id,
index_pipeline_result = indexing_pipeline(
document_batch=doc_batch_cleaned,
index_attempt_metadata=index_attempt_md,
)
@@ -833,7 +815,6 @@ def _run_indexing(
def run_indexing_entrypoint(
app: Celery,
index_attempt_id: int,
tenant_id: str,
connector_credential_pair_id: int,
@@ -851,6 +832,7 @@ def run_indexing_entrypoint(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_current_tenant() as db_session:
# TODO: remove long running session entirely
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
@@ -864,514 +846,18 @@ def run_indexing_entrypoint(
credential_id = attempt.connector_credential_pair.credential_id
logger.info(
f"Docfetching starting{tenant_str}: "
f"Indexing starting{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
connector_document_extraction(
app,
index_attempt_id,
attempt.connector_credential_pair_id,
attempt.search_settings_id,
tenant_id,
callback,
)
logger.info(
f"Docfetching finished{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
def connector_document_extraction(
app: Celery,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
"""Extract documents from connector and queue them for indexing pipeline processing.
This is the first part of the split indexing process that runs the connector
and extracts documents, storing them in the filestore for later processing.
"""
start_time = time.monotonic()
logger.info(
f"Document extraction starting: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"tenant={tenant_id}"
)
# Get batch storage (transition to IN_PROGRESS is handled by run_indexing_entrypoint)
batch_storage = get_document_batch_storage(cc_pair_id, index_attempt_id)
# Initialize memory tracer. NOTE: won't actually do anything if
# `INDEXING_TRACER_INTERVAL` is 0.
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
memory_tracer.start()
index_attempt = None
last_batch_num = 0 # used to continue from checkpointing
# comes from _run_indexing
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session,
index_attempt_id,
eager_load_cc_pair=True,
eager_load_search_settings=True,
)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found")
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
if index_attempt.search_settings is None:
raise ValueError("Search settings must be set for indexing")
# Clear the indexing trigger if it was set, to prevent duplicate indexing attempts
if index_attempt.connector_credential_pair.indexing_trigger is not None:
logger.info(
"Clearing indexing trigger: "
f"cc_pair={index_attempt.connector_credential_pair.id} "
f"trigger={index_attempt.connector_credential_pair.indexing_trigger}"
)
mark_ccpair_with_indexing_trigger(
index_attempt.connector_credential_pair.id, None, db_session
)
db_connector = index_attempt.connector_credential_pair.connector
db_credential = index_attempt.connector_credential_pair.credential
is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT
from_beginning = index_attempt.from_beginning
has_successful_attempt = (
index_attempt.connector_credential_pair.last_successful_index_time
is not None
)
earliest_index_time = (
db_connector.indexing_start.timestamp()
if db_connector.indexing_start
else 0
)
should_fetch_permissions_during_indexing = (
index_attempt.connector_credential_pair.access_type == AccessType.SYNC
and source_should_fetch_permissions_during_indexing(db_connector.source)
and is_primary
# if we've already successfully indexed, let the doc_sync job
# take care of doc-level permissions
and (from_beginning or not has_successful_attempt)
)
# Set up time windows for polling
last_successful_index_poll_range_end = (
earliest_index_time
if from_beginning
else get_last_successful_attempt_poll_range_end(
cc_pair_id=cc_pair_id,
earliest_index=earliest_index_time,
search_settings=index_attempt.search_settings,
db_session=db_session,
)
)
if last_successful_index_poll_range_end > POLL_CONNECTOR_OFFSET:
window_start = datetime.fromtimestamp(
last_successful_index_poll_range_end, tz=timezone.utc
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
else:
# don't go into "negative" time if we've never indexed before
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
most_recent_attempt = next(
iter(
get_recent_completed_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
db_session=db_session,
limit=1,
)
),
None,
)
# if the last attempt failed, try and use the same window. This is necessary
# to ensure correctness with checkpointing. If we don't do this, things like
# new slack channels could be missed (since existing slack channels are
# cached as part of the checkpoint).
if (
most_recent_attempt
and most_recent_attempt.poll_range_end
and (
most_recent_attempt.status == IndexingStatus.FAILED
or most_recent_attempt.status == IndexingStatus.CANCELED
)
):
window_end = most_recent_attempt.poll_range_end
else:
window_end = datetime.now(tz=timezone.utc)
# set time range in db
index_attempt.poll_range_start = window_start
index_attempt.poll_range_end = window_end
db_session.commit()
# TODO: maybe memory tracer here
# Set up connector runner
connector_runner = _get_connector_runner(
db_session=db_session,
attempt=index_attempt,
batch_size=INDEX_BATCH_SIZE,
start_time=window_start,
end_time=window_end,
include_permissions=should_fetch_permissions_during_indexing,
)
# don't use a checkpoint if we're explicitly indexing from
# the beginning in order to avoid weird interactions between
# checkpointing / failure handling
# OR
# if the last attempt was successful
if index_attempt.from_beginning or (
most_recent_attempt and most_recent_attempt.status.is_successful()
):
logger.info(
f"Cleaning up all old batches for index attempt {index_attempt_id} before starting new run"
)
batch_storage.cleanup_all_batches()
checkpoint = connector_runner.connector.build_dummy_checkpoint()
else:
logger.info(
f"Getting latest valid checkpoint for index attempt {index_attempt_id}"
)
checkpoint, resuming_from_checkpoint = get_latest_valid_checkpoint(
db_session=db_session,
cc_pair_id=cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
window_start=window_start,
window_end=window_end,
connector=connector_runner.connector,
)
# checkpoint resumption OR the connector already finished.
if (
isinstance(connector_runner.connector, CheckpointedConnector)
and resuming_from_checkpoint
) or (
most_recent_attempt
and most_recent_attempt.total_batches is not None
and not checkpoint.has_more
):
reissued_batch_count, completed_batches = reissue_old_batches(
batch_storage,
index_attempt_id,
cc_pair_id,
tenant_id,
app,
most_recent_attempt,
)
last_batch_num = reissued_batch_count + completed_batches
index_attempt.completed_batches = completed_batches
db_session.commit()
else:
logger.info(
f"Cleaning up all batches for index attempt {index_attempt_id} before starting new run"
)
# for non-checkpointed connectors, throw out batches from previous unsuccessful attempts
# because we'll be getting those documents again anyways.
batch_storage.cleanup_all_batches()
# Save initial checkpoint
save_checkpoint(
db_session=db_session,
index_attempt_id=index_attempt_id,
checkpoint=checkpoint,
)
try:
batch_num = last_batch_num # starts at 0 if no last batch
total_doc_batches_queued = 0
total_failures = 0
document_count = 0
# Main extraction loop
while checkpoint.has_more:
logger.info(
f"Running '{db_connector.source.value}' connector with checkpoint: {checkpoint}"
)
for document_batch, failure, next_checkpoint in connector_runner.run(
checkpoint
):
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
# contents still need to be initially pulled.
if callback and callback.should_stop():
raise ConnectorStopSignal("Connector stop signal detected")
# will exception if the connector/index attempt is marked as paused/failed
with get_session_with_current_tenant() as db_session_tmp:
_check_connector_and_attempt_status(
db_session_tmp,
cc_pair_id,
index_attempt.search_settings.status,
index_attempt_id,
)
# save record of any failures at the connector level
if failure is not None:
total_failures += 1
with get_session_with_current_tenant() as db_session:
create_index_attempt_error(
index_attempt_id,
cc_pair_id,
failure,
db_session,
)
_check_failure_threshold(
total_failures, document_count, batch_num, failure
)
# Save checkpoint if provided
if next_checkpoint:
checkpoint = next_checkpoint
# below is all document processing task, so if no batch we can just continue
if not document_batch:
continue
# Clean documents and create batch
doc_batch_cleaned = strip_null_characters(document_batch)
batch_description = []
for doc in doc_batch_cleaned:
batch_description.append(doc.to_short_descriptor())
doc_size = 0
for section in doc.sections:
if (
isinstance(section, TextSection)
and section.text is not None
):
doc_size += len(section.text)
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
logger.warning(
f"Document size: doc='{doc.to_short_descriptor()}' "
f"size={doc_size} "
f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
)
logger.debug(f"Indexing batch of documents: {batch_description}")
memory_tracer.increment_and_maybe_trace()
# Store documents in storage
batch_storage.store_batch(batch_num, doc_batch_cleaned)
# Create processing task data
processing_batch_data = {
"index_attempt_id": index_attempt_id,
"cc_pair_id": cc_pair_id,
"tenant_id": tenant_id,
"batch_num": batch_num, # 0-indexed
}
# Queue document processing task
app.send_task(
OnyxCeleryTask.DOCPROCESSING_TASK,
kwargs=processing_batch_data,
queue=OnyxCeleryQueues.DOCPROCESSING,
priority=OnyxCeleryPriority.MEDIUM,
)
batch_num += 1
total_doc_batches_queued += 1
logger.info(
f"Queued document processing batch: "
f"batch_num={batch_num} "
f"docs={len(doc_batch_cleaned)} "
f"attempt={index_attempt_id}"
)
# Check checkpoint size periodically
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
check_checkpoint_size(checkpoint)
# Save latest checkpoint
# NOTE: checkpointing is used to track which batches have
# been sent to the filestore, NOT which batches have been fully indexed
# as it used to be.
with get_session_with_current_tenant() as db_session:
save_checkpoint(
db_session=db_session,
index_attempt_id=index_attempt_id,
checkpoint=checkpoint,
)
elapsed_time = time.monotonic() - start_time
logger.info(
f"Document extraction completed: "
f"attempt={index_attempt_id} "
f"batches_queued={total_doc_batches_queued} "
f"elapsed={elapsed_time:.2f}s"
)
# Set total batches in database to signal extraction completion.
# Used by check_for_indexing to determine if the index attempt is complete.
with get_session_with_current_tenant() as db_session:
IndexingCoordination.set_total_batches(
db_session=db_session,
index_attempt_id=index_attempt_id,
total_batches=batch_num,
)
except Exception as e:
logger.exception(
f"Document extraction failed: "
f"attempt={index_attempt_id} "
f"error={str(e)}"
)
# Do NOT clean up batches on failure; future runs will use those batches
# while docfetching will continue from the saved checkpoint if one exists
if isinstance(e, ConnectorValidationError):
# On validation errors during indexing, we want to cancel the indexing attempt
# and mark the CCPair as invalid. This prevents the connector from being
# used in the future until the credentials are updated.
with get_session_with_current_tenant() as db_session_temp:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to validation error."
)
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
)
if is_primary:
if not index_attempt:
# should always be set by now
raise RuntimeError("Should never happen.")
VALIDATION_ERROR_THRESHOLD = 5
recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
limit=VALIDATION_ERROR_THRESHOLD,
db_session=db_session_temp,
)
num_validation_errors = len(
[
index_attempt
for index_attempt in recent_index_attempts
if index_attempt.error_msg
and index_attempt.error_msg.startswith(
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
)
]
)
if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
logger.warning(
f"Connector {db_connector.id} has {num_validation_errors} consecutive validation"
f" errors. Marking the CC Pair as invalid."
)
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=db_connector.id,
credential_id=db_credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise e
elif isinstance(e, ConnectorStopSignal):
with get_session_with_current_tenant() as db_session_temp:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to stop signal."
)
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
else:
with get_session_with_current_tenant() as db_session_temp:
# don't overwrite attempts that are already failed/canceled for another reason
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
if index_attempt and index_attempt.status in [
IndexingStatus.CANCELED,
IndexingStatus.FAILED,
]:
logger.info(
f"Attempt {index_attempt_id} is already failed/canceled, skipping marking as failed."
)
raise e
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
raise e
finally:
memory_tracer.stop()
def reissue_old_batches(
batch_storage: DocumentBatchStorage,
index_attempt_id: int,
cc_pair_id: int,
tenant_id: str,
app: Celery,
most_recent_attempt: IndexAttempt | None,
) -> tuple[int, int]:
# When loading from a checkpoint, we need to start new docprocessing tasks
# tied to the new index attempt for any batches left over in the file store
old_batches = batch_storage.get_all_batches_for_cc_pair()
batch_storage.update_old_batches_to_new_index_attempt(old_batches)
for batch_id in old_batches:
logger.info(
f"Re-issuing docprocessing task for batch {batch_id} for index attempt {index_attempt_id}"
)
path_info = batch_storage.extract_path_info(batch_id)
if path_info is None:
continue
if path_info.cc_pair_id != cc_pair_id:
raise RuntimeError(f"Batch {batch_id} is not for cc pair {cc_pair_id}")
app.send_task(
OnyxCeleryTask.DOCPROCESSING_TASK,
kwargs={
"index_attempt_id": index_attempt_id,
"cc_pair_id": cc_pair_id,
"tenant_id": tenant_id,
"batch_num": path_info.batch_num, # use same batch num as previously
},
queue=OnyxCeleryQueues.DOCPROCESSING,
priority=OnyxCeleryPriority.MEDIUM,
)
recent_batches = most_recent_attempt.completed_batches if most_recent_attempt else 0
# resume from the batch num of the last attempt. This should be one more
# than the last batch created by docfetching regardless of whether the batch
# is still in the filestore waiting for processing or not.
last_batch_num = len(old_batches) + recent_batches
logger.info(
f"Starting from batch {last_batch_num} due to "
f"re-issued batches: {old_batches}, completed batches: {recent_batches}"
f"Indexing finished{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
return len(old_batches), recent_batches

View File

@@ -1,7 +1,12 @@
import csv
import json
import os
from collections import defaultdict
from collections.abc import Callable
from pathlib import Path
from uuid import UUID
from langchain_core.messages import HumanMessage
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import GraphConfig
@@ -11,6 +16,9 @@ from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_agent_search_graph
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import (
run_basic_graph as run_hackathon_graph,
) # You can create your own graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_kb_graph
from onyx.chat.models import AgentAnswerPiece
@@ -22,9 +30,11 @@ from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionKey
from onyx.chat.models import ToolCallFinalResult
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import AGENT_ALLOW_REFINEMENT
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
from onyx.configs.app_configs import HACKATHON_OUTPUT_CSV_PATH
from onyx.configs.chat_configs import USE_DIV_CON_AGENT
from onyx.configs.constants import BASIC_KEY
from onyx.context.search.models import RerankingDetails
@@ -44,6 +54,190 @@ logger = setup_logger()
BASIC_SQ_KEY = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
def _calc_score_for_pos(pos: int, max_acceptable_pos: int = 15) -> float:
"""
Calculate the score for a given position.
"""
if pos > max_acceptable_pos:
return 0
elif pos == 1:
return 1
elif pos == 2:
return 0.8
else:
return 4 / (pos + 5)
def _clean_doc_id_link(doc_link: str) -> str:
"""
Clean the google doc link.
"""
if "google.com" in doc_link:
if "/edit" in doc_link:
return "/edit".join(doc_link.split("/edit")[:-1])
elif "/view" in doc_link:
return "/view".join(doc_link.split("/view")[:-1])
else:
return doc_link
if "app.fireflies.ai" in doc_link:
return "?".join(doc_link.split("?")[:-1])
return doc_link
def _get_doc_score(doc_id: str, doc_results: list[str]) -> float:
"""
Get the score of a document from the document results.
"""
match_pos = None
for pos, comp_doc in enumerate(doc_results, start=1):
clear_doc_id = _clean_doc_id_link(doc_id)
clear_comp_doc = _clean_doc_id_link(comp_doc)
if clear_doc_id == clear_comp_doc:
match_pos = pos
if match_pos is None:
return 0.0
return _calc_score_for_pos(match_pos)
def _append_empty_line(csv_path: str = HACKATHON_OUTPUT_CSV_PATH):
"""
Append an empty line to the CSV file.
"""
_append_answer_to_csv("", "", csv_path)
def _append_ground_truth_to_csv(
query: str,
ground_truth_docs: list[str],
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
"""
Append the score to the CSV file.
"""
file_exists = os.path.isfile(csv_path)
# Create directory if it doesn't exist
csv_dir = os.path.dirname(csv_path)
if csv_dir and not os.path.exists(csv_dir):
Path(csv_dir).mkdir(parents=True, exist_ok=True)
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
# Write header if file is new
if not file_exists:
writer.writerow(["query", "position", "document_id", "answer", "score"])
# Write the ranking stats
for doc_id in ground_truth_docs:
writer.writerow([query, "-1", _clean_doc_id_link(doc_id), "", ""])
logger.debug("Appended score to csv file")
def _append_score_to_csv(
query: str,
score: float,
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
"""
Append the score to the CSV file.
"""
file_exists = os.path.isfile(csv_path)
# Create directory if it doesn't exist
csv_dir = os.path.dirname(csv_path)
if csv_dir and not os.path.exists(csv_dir):
Path(csv_dir).mkdir(parents=True, exist_ok=True)
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
# Write header if file is new
if not file_exists:
writer.writerow(["query", "position", "document_id", "answer", "score"])
# Write the ranking stats
writer.writerow([query, "", "", "", score])
logger.debug("Appended score to csv file")
def _append_search_results_to_csv(
query: str,
doc_results: list[str],
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
"""
Append the search results to the CSV file.
"""
file_exists = os.path.isfile(csv_path)
# Create directory if it doesn't exist
csv_dir = os.path.dirname(csv_path)
if csv_dir and not os.path.exists(csv_dir):
Path(csv_dir).mkdir(parents=True, exist_ok=True)
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
# Write header if file is new
if not file_exists:
writer.writerow(["query", "position", "document_id", "answer", "score"])
# Write the ranking stats
for pos, doc in enumerate(doc_results, start=1):
writer.writerow([query, pos, _clean_doc_id_link(doc), "", ""])
logger.debug("Appended search results to csv file")
def _append_answer_to_csv(
query: str,
answer: str,
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
"""
Append ranking statistics to a CSV file.
Args:
ranking_stats: List of tuples containing (query, hit_position, document_id)
csv_path: Path to the CSV file to append to
"""
file_exists = os.path.isfile(csv_path)
# Create directory if it doesn't exist
csv_dir = os.path.dirname(csv_path)
if csv_dir and not os.path.exists(csv_dir):
Path(csv_dir).mkdir(parents=True, exist_ok=True)
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
# Write header if file is new
if not file_exists:
writer.writerow(["query", "position", "document_id", "answer", "score"])
# Write the ranking stats
writer.writerow([query, "", "", answer, ""])
logger.debug("Appended answer to csv file")
class Answer:
def __init__(
self,
@@ -134,6 +328,9 @@ class Answer:
@property
def processed_streamed_output(self) -> AnswerStream:
_HACKATHON_TEST_EXECUTION = False
if self._processed_stream is not None:
yield from self._processed_stream
return
@@ -154,20 +351,117 @@ class Answer:
)
):
run_langgraph = run_dc_graph
elif (
self.graph_config.inputs.persona
and self.graph_config.inputs.persona.description.startswith(
"Hackathon Test"
)
):
_HACKATHON_TEST_EXECUTION = True
run_langgraph = run_hackathon_graph
else:
run_langgraph = run_basic_graph
stream = run_langgraph(self.graph_config)
if _HACKATHON_TEST_EXECUTION:
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
input_data = str(self.graph_config.inputs.prompt_builder.raw_user_query)
if input_data.startswith("["):
input_type = "json"
input_list = json.loads(input_data)
else:
input_type = "list"
input_list = input_data.split(";")
num_examples_with_ground_truth = 0
total_score = 0.0
for question_num, question_data in enumerate(input_list):
ground_truth_docs = None
if input_type == "json":
question = question_data["question"]
ground_truth = question_data.get("ground_truth")
if ground_truth:
ground_truth_docs = [x.get("doc_link") for x in ground_truth]
logger.info(f"Question {question_num}: {question}")
_append_ground_truth_to_csv(question, ground_truth_docs)
else:
continue
else:
question = question_data
self.graph_config.inputs.prompt_builder.raw_user_query = question
self.graph_config.inputs.prompt_builder.user_message_and_token_cnt = (
HumanMessage(
content=question, additional_kwargs={}, response_metadata={}
),
2,
)
self.graph_config.tooling.force_use_tool.force_use = True
stream = run_langgraph(
self.graph_config,
)
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield packet
break
processed_stream.append(packet)
yield packet
llm_answer_segments: list[str] = []
doc_results: list[str] | None = None
for answer_piece in processed_stream:
if isinstance(answer_piece, OnyxAnswerPiece):
llm_answer_segments.append(answer_piece.answer_piece or "")
elif isinstance(answer_piece, ToolCallFinalResult):
doc_results = [x.get("link") for x in answer_piece.tool_result]
if doc_results:
_append_search_results_to_csv(question, doc_results)
_append_answer_to_csv(question, "".join(llm_answer_segments))
if ground_truth_docs and doc_results:
num_examples_with_ground_truth += 1
doc_score = 0.0
for doc_id in ground_truth_docs:
doc_score += _get_doc_score(doc_id, doc_results)
_append_score_to_csv(question, doc_score)
total_score += doc_score
self._processed_stream = processed_stream
if num_examples_with_ground_truth > 0:
comprehensive_score = total_score / num_examples_with_ground_truth
else:
comprehensive_score = 0
_append_empty_line()
_append_score_to_csv(question, comprehensive_score)
else:
stream = run_langgraph(
self.graph_config,
)
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield packet
break
processed_stream.append(packet)
yield packet
break
processed_stream.append(packet)
yield packet
self._processed_stream = processed_stream
self._processed_stream = processed_stream
@property
def llm_answer(self) -> str:

View File

@@ -309,10 +309,7 @@ class ContextualPruningConfig(DocumentPruningConfig):
def from_doc_pruning_config(
cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
) -> "ContextualPruningConfig":
return cls(
num_chunk_multiple=num_chunk_multiple,
**doc_pruning_config.model_dump(),
)
return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())
class CitationConfig(BaseModel):
@@ -321,6 +318,9 @@ class CitationConfig(BaseModel):
class AnswerStyleConfig(BaseModel):
citation_config: CitationConfig
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
# right now, only used by the simple chat API

View File

@@ -128,17 +128,17 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_SUMMARY_ID,
INTERNET_SEARCH_RESPONSE_ID,
)
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
internet_search_response_to_search_docs,
)
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchResponse,
)
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from onyx.tools.tool_implementations.internet_search.models import (
InternetSearchResponseSummary,
)
from onyx.tools.tool_implementations.internet_search.utils import (
internet_search_response_to_search_docs,
)
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
@@ -281,7 +281,7 @@ def _handle_internet_search_tool_response_summary(
packet: ToolResponse,
db_session: Session,
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
internet_search_response = cast(InternetSearchResponseSummary, packet.response)
internet_search_response = cast(InternetSearchResponse, packet.response)
server_search_docs = internet_search_response_to_search_docs(
internet_search_response
)
@@ -296,10 +296,10 @@ def _handle_internet_search_tool_response_summary(
]
return (
QADocsResponse(
rephrased_query=internet_search_response.query,
rephrased_query=internet_search_response.revised_query,
top_documents=response_docs,
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=SearchType.INTERNET,
predicted_search=SearchType.SEMANTIC,
applied_source_filters=[],
applied_time_cutoff=None,
recency_bias_multiplier=1.0,
@@ -491,7 +491,7 @@ def _process_tool_response(
]
)
yield FileChatDisplay(file_ids=[str(file_id) for file_id in file_ids])
elif packet.id == INTERNET_SEARCH_RESPONSE_SUMMARY_ID:
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
info.qa_docs_response,
info.reference_db_search_docs,
@@ -725,7 +725,9 @@ def stream_chat_message_objects(
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(history_msgs, new_msg_req.file_descriptors)
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
)
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
latest_query_files = [file for file in files if file.file_id in req_file_ids]
user_file_ids = new_msg_req.user_file_ids or []
@@ -904,6 +906,7 @@ def stream_chat_message_objects(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
structured_response_format=new_msg_req.structured_response_format,
)
@@ -933,7 +936,6 @@ def stream_chat_message_objects(
),
internet_search_tool_config=InternetSearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
),
image_generation_tool_config=ImageGenerationToolConfig(
additional_headers=litellm_additional_headers,
@@ -1010,7 +1012,6 @@ def stream_chat_message_objects(
tools=tools,
db_session=db_session,
use_agentic_search=new_msg_req.use_agentic_search,
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
)
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(

View File

@@ -187,8 +187,12 @@ class AnswerPromptBuilder:
final_messages_with_tokens.append(self.user_message_and_token_cnt)
if self.new_messages_and_token_cnts:
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
if (
self.new_messages_and_token_cnts
and isinstance(self.user_message_and_token_cnt[0].content, str)
and self.user_message_and_token_cnt[0].content.startswith("Refer")
):
final_messages_with_tokens.extend(self.new_messages_and_token_cnts[-2:])
return drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens

View File

@@ -11,7 +11,6 @@ from onyx.chat.models import (
)
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.citations_prompt import compute_max_document_tokens
from onyx.configs.app_configs import MAX_FEDERATED_SECTIONS
from onyx.configs.constants import IGNORE_FOR_QA
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.context.search.models import InferenceChunk
@@ -68,38 +67,6 @@ def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
return combined_ranges
def _separate_federated_sections(
sections: list[InferenceSection],
section_relevance_list: list[bool] | None,
) -> tuple[list[InferenceSection], list[InferenceSection], list[bool] | None]:
"""
Separates out the first NUM_FEDERATED_SECTIONS federated sections to be spared from pruning.
Any remaining federated sections are treated as normal sections, and will get added if it
fits within the allocated context window. This is done as federated sections do not have
a score and would otherwise always get pruned.
"""
federated_sections: list[InferenceSection] = []
normal_sections: list[InferenceSection] = []
normal_section_relevance_list: list[bool] = []
for i, section in enumerate(sections):
if (
len(federated_sections) < MAX_FEDERATED_SECTIONS
and section.center_chunk.is_federated
):
federated_sections.append(section)
continue
normal_sections.append(section)
if section_relevance_list is not None:
normal_section_relevance_list.append(section_relevance_list[i])
return (
federated_sections[:MAX_FEDERATED_SECTIONS],
normal_sections,
normal_section_relevance_list if section_relevance_list is not None else None,
)
def _compute_limit(
prompt_config: PromptConfig,
llm_config: LLMConfig,
@@ -138,7 +105,7 @@ def _compute_limit(
return int(min(limit_options))
def _reorder_sections(
def reorder_sections(
sections: list[InferenceSection],
section_relevance_list: list[bool] | None,
) -> list[InferenceSection]:
@@ -146,10 +113,11 @@ def _reorder_sections(
return sections
reordered_sections: list[InferenceSection] = []
for selection_target in [True, False]:
for section, is_relevant in zip(sections, section_relevance_list):
if is_relevant == selection_target:
reordered_sections.append(section)
if section_relevance_list is not None:
for selection_target in [True, False]:
for section, is_relevant in zip(sections, section_relevance_list):
if is_relevant == selection_target:
reordered_sections.append(section)
return reordered_sections
@@ -166,7 +134,6 @@ def _remove_sections_to_ignore(
def _apply_pruning(
sections: list[InferenceSection],
section_relevance_list: list[bool] | None,
keep_sections: list[InferenceSection],
token_limit: int,
is_manually_selected_docs: bool,
use_sections: bool,
@@ -177,22 +144,10 @@ def _apply_pruning(
provider_type=llm_config.model_provider,
model_name=llm_config.model_name,
)
# combine the section lists, making sure to add the keep_sections first
sections = deepcopy(keep_sections) + deepcopy(sections)
# build combined relevance list, treating the keep_sections as relevant
if section_relevance_list is not None:
section_relevance_list = [True] * len(keep_sections) + section_relevance_list
# map unique_id: relevance for final ordering step
section_id_to_relevance: dict[str, bool] = {}
if section_relevance_list is not None:
for sec, rel in zip(sections, section_relevance_list):
section_id_to_relevance[sec.center_chunk.unique_id] = rel
sections = deepcopy(sections) # don't modify in place
# re-order docs with all the "relevant" docs at the front
sections = _reorder_sections(
sections = reorder_sections(
sections=sections, section_relevance_list=section_relevance_list
)
# remove docs that are explicitly marked as not for QA
@@ -319,14 +274,6 @@ def _apply_pruning(
)
sections = [sections[0]]
# sort by relevance, then by score (as we added the keep_sections first)
sections.sort(
key=lambda s: (
not section_id_to_relevance.get(s.center_chunk.unique_id, True),
-(s.center_chunk.score or 0.0),
),
)
return sections
@@ -342,16 +289,9 @@ def prune_sections(
if section_relevance_list is not None:
assert len(sections) == len(section_relevance_list)
# get federated sections (up to NUM_FEDERATED_SECTIONS)
# TODO: if we can somehow score the federated sections well, we don't need this
federated_sections, normal_sections, normal_section_relevance_list = (
_separate_federated_sections(sections, section_relevance_list)
)
actual_num_chunks = (
contextual_pruning_config.max_chunks
* contextual_pruning_config.num_chunk_multiple
+ len(federated_sections)
if contextual_pruning_config.max_chunks
else None
)
@@ -367,9 +307,8 @@ def prune_sections(
)
return _apply_pruning(
sections=normal_sections,
section_relevance_list=normal_section_relevance_list,
keep_sections=federated_sections,
sections=sections,
section_relevance_list=section_relevance_list,
token_limit=token_limit,
is_manually_selected_docs=contextual_pruning_config.is_manually_selected_docs,
use_sections=contextual_pruning_config.use_sections, # Now default True

View File

@@ -35,9 +35,6 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
) # 1 day
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
# Controls whether users can use User Knowledge (personal documents) in assistants
DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() == "true"
# Controls whether to allow admin query history reports with:
# 1. associated user emails
# 2. anonymized user emails
@@ -311,40 +308,25 @@ except ValueError:
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
)
CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT = 6
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 3
try:
env_value = os.environ.get("CELERY_WORKER_DOCPROCESSING_CONCURRENCY")
env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
if not env_value:
env_value = os.environ.get("NUM_INDEXING_WORKERS")
if not env_value:
env_value = str(CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT)
CELERY_WORKER_DOCPROCESSING_CONCURRENCY = int(env_value)
env_value = str(CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT)
CELERY_WORKER_INDEXING_CONCURRENCY = int(env_value)
except ValueError:
CELERY_WORKER_DOCPROCESSING_CONCURRENCY = (
CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT
)
CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT = 1
try:
env_value = os.environ.get("CELERY_WORKER_DOCFETCHING_CONCURRENCY")
if not env_value:
env_value = os.environ.get("NUM_DOCFETCHING_WORKERS")
if not env_value:
env_value = str(CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT)
CELERY_WORKER_DOCFETCHING_CONCURRENCY = int(env_value)
except ValueError:
CELERY_WORKER_DOCFETCHING_CONCURRENCY = (
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT
)
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4
)
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
VESPA_SYNC_MAX_TASKS = 8192
VESPA_SYNC_MAX_TASKS = 1024
DB_YIELD_PER_DEFAULT = 64
@@ -468,11 +450,6 @@ GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
# Default size threshold for SharePoint files (20MB)
SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
)
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
@@ -504,7 +481,6 @@ LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
# Slack specific configs
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 8)
MAX_SLACK_QUERY_EXPANSIONS = int(os.environ.get("MAX_SLACK_QUERY_EXPANSIONS", "5"))
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
@@ -678,14 +654,6 @@ except json.JSONDecodeError:
# LLM Model Update API endpoint
LLM_MODEL_UPDATE_API_URL = os.environ.get("LLM_MODEL_UPDATE_API_URL")
# Federated Search Configs
MAX_FEDERATED_SECTIONS = int(
os.environ.get("MAX_FEDERATED_SECTIONS", "5")
) # max no. of federated sections to always keep
MAX_FEDERATED_CHUNKS = int(
os.environ.get("MAX_FEDERATED_CHUNKS", "5")
) # max no. of chunks to retrieve per federated connector
#####
# Enterprise Edition Configs
#####
@@ -819,3 +787,7 @@ S3_AWS_SECRET_ACCESS_KEY = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")
# Forcing Vespa Language
# English: en, German:de, etc. See: https://docs.vespa.ai/en/linguistics.html
VESPA_LANGUAGE_OVERRIDE = os.environ.get("VESPA_LANGUAGE_OVERRIDE")
HACKATHON_OUTPUT_CSV_PATH = os.environ.get(
"HACKATHON_OUTPUT_CSV_PATH", "/tmp/hackathon_output.csv"
)

View File

@@ -91,10 +91,6 @@ HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
# Internet Search
BING_API_KEY = os.environ.get("BING_API_KEY") or None
EXA_API_KEY = os.environ.get("EXA_API_KEY") or None
NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10)
NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50)
# Enable in-house model for detecting connector-based filtering in queries
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)

View File

@@ -65,8 +65,7 @@ POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
@@ -122,8 +121,6 @@ CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT = 3 * 60 * 60 # 3 hours (in seconds)
# hard termination should always fire first if the connector is hung
CELERY_INDEXING_LOCK_TIMEOUT = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 900
# Heartbeat interval for indexing worker liveness detection
INDEXING_WORKER_HEARTBEAT_INTERVAL = 30 # seconds
# how long a task should wait for associated fence to be ready
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
@@ -189,21 +186,10 @@ class DocumentSource(str, Enum):
AIRTABLE = "airtable"
HIGHSPOT = "highspot"
IMAP = "imap"
# Special case just for integration tests
MOCK_CONNECTOR = "mock_connector"
class FederatedConnectorSource(str, Enum):
FEDERATED_SLACK = "federated_slack"
def to_non_federated_source(self) -> DocumentSource | None:
if self == FederatedConnectorSource.FEDERATED_SLACK:
return DocumentSource.SLACK
return None
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -334,12 +320,9 @@ class OnyxCeleryQueues:
CSV_GENERATION = "csv_generation"
# Indexing queue
CONNECTOR_INDEXING = "connector_indexing"
USER_FILES_INDEXING = "user_files_indexing"
# Document processing pipeline queue
DOCPROCESSING = "docprocessing"
CONNECTOR_DOC_FETCHING = "connector_doc_fetching"
# Monitoring queue
MONITORING = "monitoring"
@@ -470,11 +453,7 @@ class OnyxCeleryTask:
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
"connector_external_group_sync_generator_task"
)
# New split indexing tasks
CONNECTOR_DOC_FETCHING_TASK = "connector_doc_fetching_task"
DOCPROCESSING_TASK = "docprocessing_task"
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"

View File

@@ -34,6 +34,7 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
@@ -280,28 +281,30 @@ class BlobStorageConnector(LoadConnector, PollConnector):
# TODO: Refactor to avoid direct DB access in connector
# This will require broader refactoring across the codebase
image_section, _ = store_image_and_create_section(
image_data=downloaded_file,
file_id=f"{self.bucket_type}_{self.bucket_name}_{key.replace('/', '_')}",
display_name=file_name,
link=link,
file_origin=FileOrigin.CONNECTOR,
)
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
sections=[image_section],
source=DocumentSource(self.bucket_type.value),
semantic_identifier=file_name,
doc_updated_at=last_modified,
metadata={},
with get_session_with_current_tenant() as db_session:
image_section, _ = store_image_and_create_section(
db_session=db_session,
image_data=downloaded_file,
file_id=f"{self.bucket_type}_{self.bucket_name}_{key.replace('/', '_')}",
display_name=file_name,
link=link,
file_origin=FileOrigin.CONNECTOR,
)
)
if len(batch) == self.batch_size:
yield batch
batch = []
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
sections=[image_section],
source=DocumentSource(self.bucket_type.value),
semantic_identifier=file_name,
doc_updated_at=last_modified,
metadata={},
)
)
if len(batch) == self.batch_size:
yield batch
batch = []
except Exception:
logger.exception(f"Error processing image {key}")
continue

View File

@@ -1,17 +1,3 @@
"""
# README (notes on Confluence pagination):
We've noticed that the `search/users` and `users/memberof` endpoints for Confluence Cloud use offset-based pagination as
opposed to cursor-based. We also know that page-retrieval uses cursor-based pagination.
Our default pagination strategy right now for cloud is to assume cursor-based.
However, if you notice that a cloud API is not being properly paginated (i.e., if the `_links.next` is not appearing in the
returned payload), then you can force offset-based pagination.
# TODO (@raunakab)
We haven't explored all of the cloud APIs' pagination strategies. @raunakab take time to go through this and figure them out.
"""
import json
import time
from collections.abc import Callable
@@ -60,13 +46,16 @@ _REPLACEMENT_EXPANSIONS = "body.view.value"
_USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
_USER_EMAIL_CACHE: dict[str, str | None] = {}
_DEFAULT_PAGINATION_LIMIT = 1000
class ConfluenceRateLimitError(Exception):
pass
_DEFAULT_PAGINATION_LIMIT = 1000
_MINIMUM_PAGINATION_LIMIT = 50
class OnyxConfluence:
"""
This is a custom Confluence class that:
@@ -322,8 +311,8 @@ class OnyxConfluence:
return confluence
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# This uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling.
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def _make_rate_limited_confluence_method(
self, name: str, credential_provider: CredentialsProviderInterface | None
) -> Callable[..., Any]:
@@ -389,6 +378,25 @@ class OnyxConfluence:
return wrapped_call
# def _wrap_methods(self) -> None:
# """
# For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
# wrap it with handle_confluence_rate_limit.
# """
# for attr_name in dir(self):
# if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
# setattr(
# self,
# attr_name,
# handle_confluence_rate_limit(getattr(self, attr_name)),
# )
# def _ensure_token_valid(self) -> None:
# if self._token_is_expired():
# self._refresh_token()
# # Re-init the Confluence client with the originally stored args
# self._confluence = Confluence(self._url, *self._args, **self._kwargs)
def __getattr__(self, name: str) -> Any:
"""Dynamically intercept attribute/method access."""
attr = getattr(self._confluence, name, None)
@@ -475,7 +483,6 @@ class OnyxConfluence:
limit: int | None = None,
# Called with the next url to use to get the next page
next_page_callback: Callable[[str], None] | None = None,
force_offset_pagination: bool = False,
) -> Iterator[dict[str, Any]]:
"""
This will paginate through the top level query.
@@ -491,10 +498,6 @@ class OnyxConfluence:
raw_response = self.get(
path=url_suffix,
advanced_mode=True,
params={
"body-format": "atlas_doc_format",
"expand": "body.atlas_doc_format",
},
)
except Exception as e:
logger.exception(f"Error in confluence call to {url_suffix}")
@@ -561,32 +564,14 @@ class OnyxConfluence:
)
raise e
# Yield the results individually.
# yield the results individually
results = cast(list[dict[str, Any]], next_response.get("results", []))
# Note 1:
# Make sure we don't update the start by more than the amount
# make sure we don't update the start by more than the amount
# of results we were able to retrieve. The Confluence API has a
# weird behavior where if you pass in a limit that is too large for
# the configured server, it will artificially limit the amount of
# results returned BUT will not apply this to the start parameter.
# This will cause us to miss results.
#
# Note 2:
# We specifically perform manual yielding (i.e., `for x in xs: yield x`) as opposed to using a `yield from xs`
# because we *have to call the `next_page_callback`* prior to yielding the last element!
#
# If we did:
#
# ```py
# yield from results
# if next_page_callback:
# next_page_callback(url_suffix)
# ```
#
# then the logic would fail since the iterator would finish (and the calling scope would exit out of its driving
# loop) prior to the callback being called.
old_url_suffix = url_suffix
updated_start = get_start_param_from_url(old_url_suffix)
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
@@ -602,12 +587,6 @@ class OnyxConfluence:
)
# notify the caller of the new url
next_page_callback(url_suffix)
elif force_offset_pagination and i == len(results) - 1:
url_suffix = update_param_in_path(
old_url_suffix, "start", str(updated_start)
)
yield result
# we've observed that Confluence sometimes returns a next link despite giving
@@ -705,9 +684,7 @@ class OnyxConfluence:
url = "rest/api/search/user"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
for user_result in self._paginate_url(
url, limit, force_offset_pagination=True
):
for user_result in self._paginate_url(url, limit):
# Example response:
# {
# 'user': {
@@ -797,7 +774,7 @@ class OnyxConfluence:
user_query = f"{user_field}={quote(user_value)}"
url = f"rest/api/user/memberof?{user_query}"
yield from self._paginate_url(url, limit, force_offset_pagination=True)
yield from self._paginate_url(url, limit)
def paginated_groups_retrieval(
self,
@@ -949,9 +926,6 @@ def extract_text_from_confluence_html(
object_html = body.get("storage", body.get("view", {})).get("value")
soup = bs4.BeautifulSoup(object_html, "html.parser")
_remove_macro_stylings(soup=soup)
for user in soup.findAll("ri:user"):
user_id = (
user.attrs["ri:account-id"]
@@ -1033,15 +1007,3 @@ def extract_text_from_confluence_html(
logger.warning(f"Error processing ac:link-body: {e}")
return format_document_soup(soup)
def _remove_macro_stylings(soup: bs4.BeautifulSoup) -> None:
for macro_root in soup.findAll("ac:structured-macro"):
if not isinstance(macro_root, bs4.Tag):
continue
macro_styling = macro_root.find(name="ac:parameter", attrs={"ac:name": "page"})
if not macro_styling or not isinstance(macro_styling, bs4.Tag):
continue
macro_styling.extract()

View File

@@ -23,6 +23,7 @@ from onyx.configs.app_configs import (
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.configs.constants import FileOrigin
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
@@ -223,17 +224,19 @@ def _process_image_attachment(
"""Process an image attachment by saving it without generating a summary."""
try:
# Use the standardized image storage and section creation
section, file_name = store_image_and_create_section(
image_data=raw_bytes,
file_id=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
file_origin=FileOrigin.CONNECTOR,
)
logger.info(f"Stored image attachment with file name: {file_name}")
with get_session_with_current_tenant() as db_session:
section, file_name = store_image_and_create_section(
db_session=db_session,
image_data=raw_bytes,
file_id=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
file_origin=FileOrigin.CONNECTOR,
)
logger.info(f"Stored image attachment with file name: {file_name}")
# Return empty text but include the file_name for later processing
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
# Return empty text but include the file_name for later processing
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
except Exception as e:
msg = f"Image storage failed for {attachment['title']}: {e}"
logger.error(msg, exc_info=e)

View File

@@ -109,10 +109,8 @@ def process_onyx_metadata(
return (
OnyxMetadata(
source_type=metadata.get("connector_type"),
link=metadata.get("link"),
file_display_name=metadata.get("file_display_name"),
title=metadata.get("title"),
primary_owners=p_owners,
secondary_owners=s_owners,
doc_updated_at=doc_updated_at,

View File

@@ -33,7 +33,6 @@ from onyx.connectors.google_site.connector import GoogleSitesConnector
from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.highspot.connector import HighspotConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.imap.connector import ImapConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CredentialsConnector
@@ -122,7 +121,6 @@ def identify_connector_class(
DocumentSource.EGNYTE: EgnyteConnector,
DocumentSource.AIRTABLE: AirtableConnector,
DocumentSource.HIGHSPOT: HighspotConnector,
DocumentSource.IMAP: ImapConnector,
# just for integration tests
DocumentSource.MOCK_CONNECTOR: MockConnector,
}

View File

@@ -5,6 +5,8 @@ from pathlib import Path
from typing import Any
from typing import IO
from sqlalchemy.orm import Session
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
@@ -16,6 +18,7 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
@@ -29,6 +32,7 @@ logger = setup_logger()
def _create_image_section(
image_data: bytes,
db_session: Session,
parent_file_name: str,
display_name: str,
link: str | None = None,
@@ -54,6 +58,7 @@ def _create_image_section(
# Store the image and create a section
try:
section, stored_file_name = store_image_and_create_section(
db_session=db_session,
image_data=image_data,
file_id=file_id,
display_name=display_name,
@@ -72,6 +77,7 @@ def _process_file(
file: IO[Any],
metadata: dict[str, Any] | None,
pdf_pass: str | None,
db_session: Session,
) -> list[Document]:
"""
Process a file and return a list of Documents.
@@ -119,6 +125,7 @@ def _process_file(
try:
section, _ = _create_image_section(
image_data=image_data,
db_session=db_session,
parent_file_name=file_id,
display_name=title,
)
@@ -164,12 +171,10 @@ def _process_file(
custom_tags.update(more_custom_tags)
# File-specific metadata overrides metadata processed so far
source_type = onyx_metadata.source_type or source_type
primary_owners = onyx_metadata.primary_owners or primary_owners
secondary_owners = onyx_metadata.secondary_owners or secondary_owners
time_updated = onyx_metadata.doc_updated_at or time_updated
file_display_name = onyx_metadata.file_display_name or file_display_name
title = onyx_metadata.title or onyx_metadata.file_display_name or title
link = onyx_metadata.link or link
# Build sections: first the text as a single Section
@@ -189,6 +194,7 @@ def _process_file(
try:
image_section, stored_file_name = _create_image_section(
image_data=img_data,
db_session=db_session,
parent_file_name=file_id,
display_name=f"{title} - image {idx}",
idx=idx,
@@ -252,33 +258,37 @@ class LocalFileConnector(LoadConnector):
"""
documents: list[Document] = []
for file_id in self.file_locations:
file_store = get_default_file_store()
file_record = file_store.read_file_record(file_id=file_id)
if not file_record:
# typically an unsupported extension
logger.warning(f"No file record found for '{file_id}' in PG; skipping.")
continue
with get_session_with_current_tenant() as db_session:
for file_id in self.file_locations:
file_store = get_default_file_store(db_session)
file_record = file_store.read_file_record(file_id=file_id)
if not file_record:
# typically an unsupported extension
logger.warning(
f"No file record found for '{file_id}' in PG; skipping."
)
continue
metadata = self._get_file_metadata(file_id)
file_io = file_store.read_file(file_id=file_id, mode="b")
new_docs = _process_file(
file_id=file_id,
file_name=file_record.display_name,
file=file_io,
metadata=metadata,
pdf_pass=self.pdf_pass,
)
documents.extend(new_docs)
metadata = self._get_file_metadata(file_id)
file_io = file_store.read_file(file_id=file_id, mode="b")
new_docs = _process_file(
file_id=file_id,
file_name=file_record.display_name,
file=file_io,
metadata=metadata,
pdf_pass=self.pdf_pass,
db_session=db_session,
)
documents.extend(new_docs)
if len(documents) >= self.batch_size:
if len(documents) >= self.batch_size:
yield documents
documents = []
if documents:
yield documents
documents = []
if documents:
yield documents
if __name__ == "__main__":
connector = LocalFileConnector(

View File

@@ -35,7 +35,6 @@ _FIREFLIES_API_QUERY = """
organizer_email
participants
date
duration
transcript_url
sentences {
text
@@ -102,14 +101,7 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
sections=cast(list[TextSection | ImageSection], sections),
source=DocumentSource.FIREFLIES,
semantic_identifier=meeting_title,
metadata={
k: str(v)
for k, v in {
"meeting_date": meeting_date,
"duration_min": transcript.get("duration"),
}.items()
if v is not None
},
metadata={},
doc_updated_at=meeting_date,
primary_owners=organizer_email_user_info,
secondary_owners=meeting_participants_email_list,

View File

@@ -1,10 +1,6 @@
import copy
import json
import os
import sys
import threading
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
@@ -207,9 +203,7 @@ class GoogleDriveConnector(
specific_requests_made = False
if bool(shared_drive_urls) or bool(my_drive_emails) or bool(shared_folder_urls):
specific_requests_made = True
self.specific_requests_made = specific_requests_made
# NOTE: potentially modified in load_credentials if using service account
self.include_files_shared_with_me = (
False if specific_requests_made else include_files_shared_with_me
)
@@ -290,16 +284,6 @@ class GoogleDriveConnector(
source=DocumentSource.GOOGLE_DRIVE,
)
# Service account connectors don't have a specific setting determining whether
# to include "shared with me" for each user, so we default to true unless the connector
# is in specific folders/drives mode. Note that shared files are only picked up during
# the My Drive stage, so this does nothing if the connector is set to only index shared drives.
if (
isinstance(self._creds, ServiceAccountCredentials)
and not self.specific_requests_made
):
self.include_files_shared_with_me = True
self._creds_dict = new_creds_dict
return new_creds_dict
@@ -1378,139 +1362,3 @@ class GoogleDriveConnector(
@override
def validate_checkpoint_json(self, checkpoint_json: str) -> GoogleDriveCheckpoint:
return GoogleDriveCheckpoint.model_validate_json(checkpoint_json)
def get_credentials_from_env(email: str, oauth: bool) -> dict:
if oauth:
raw_credential_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"]
else:
raw_credential_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"]
refried_credential_string = json.dumps(json.loads(raw_credential_string))
# This is the Oauth token
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens"
# This is the service account key
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
# The email saved for both auth types
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
cred_key = (
DB_CREDENTIALS_DICT_TOKEN_KEY
if oauth
else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
)
return {
cred_key: refried_credential_string,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded",
}
class CheckpointOutputWrapper:
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format.
The connector format is easier for the connector implementor (e.g. it enforces exactly
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
formats.
"""
def __init__(self) -> None:
self.next_checkpoint: GoogleDriveCheckpoint | None = None
def __call__(
self,
checkpoint_connector_generator: CheckpointOutput[GoogleDriveCheckpoint],
) -> Generator[
tuple[Document | None, ConnectorFailure | None, GoogleDriveCheckpoint | None],
None,
None,
]:
# grabs the final return value and stores it in the `next_checkpoint` variable
def _inner_wrapper(
checkpoint_connector_generator: CheckpointOutput[GoogleDriveCheckpoint],
) -> CheckpointOutput[GoogleDriveCheckpoint]:
self.next_checkpoint = yield from checkpoint_connector_generator
return self.next_checkpoint # not used
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
if isinstance(document_or_failure, Document):
yield document_or_failure, None, None
elif isinstance(document_or_failure, ConnectorFailure):
yield None, document_or_failure, None
else:
raise ValueError(
f"Invalid document_or_failure type: {type(document_or_failure)}"
)
if self.next_checkpoint is None:
raise RuntimeError(
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
)
yield None, None, self.next_checkpoint
def yield_all_docs_from_checkpoint_connector(
connector: GoogleDriveConnector,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> Iterator[Document | ConnectorFailure]:
num_iterations = 0
checkpoint = connector.build_dummy_checkpoint()
while checkpoint.has_more:
doc_batch_generator = CheckpointOutputWrapper()(
connector.load_from_checkpoint(start, end, checkpoint)
)
for document, failure, next_checkpoint in doc_batch_generator:
if failure is not None:
yield failure
if document is not None:
yield document
if next_checkpoint is not None:
checkpoint = next_checkpoint
num_iterations += 1
if num_iterations > 100_000:
raise RuntimeError("Too many iterations. Infinite loop?")
if __name__ == "__main__":
import time
creds = get_credentials_from_env(
os.environ["GOOGLE_DRIVE_PRIMARY_ADMIN_EMAIL"], False
)
connector = GoogleDriveConnector(
include_shared_drives=True,
shared_drive_urls=None,
include_my_drives=True,
my_drive_emails=None,
shared_folder_urls=None,
include_files_shared_with_me=True,
specific_user_emails=None,
)
connector.load_credentials(creds)
max_fsize = 0
biggest_fsize = 0
num_errors = 0
start_time = time.time()
with open("stats.txt", "w") as f:
for num, doc_or_failure in enumerate(
yield_all_docs_from_checkpoint_connector(connector, 0, time.time())
):
if num % 200 == 0:
f.write(f"Processed {num} files\n")
f.write(f"Max file size: {max_fsize/1000_000:.2f} MB\n")
f.write(f"Time so far: {time.time() - start_time:.2f} seconds\n")
f.write(f"Docs per minute: {num/(time.time() - start_time)*60:.2f}\n")
biggest_fsize = max(biggest_fsize, max_fsize)
max_fsize = 0
if isinstance(doc_or_failure, Document):
max_fsize = max(max_fsize, sys.getsizeof(doc_or_failure))
elif isinstance(doc_or_failure, ConnectorFailure):
num_errors += 1
print(f"Num errors: {num_errors}")
print(f"Biggest file size: {biggest_fsize/1000_000:.2f} MB")
print(f"Time taken: {time.time() - start_time:.2f} seconds")

View File

@@ -3,8 +3,6 @@ from collections.abc import Callable
from datetime import datetime
from typing import Any
from typing import cast
from urllib.parse import urlparse
from urllib.parse import urlunparse
from googleapiclient.errors import HttpError # type: ignore
from googleapiclient.http import MediaIoBaseDownload # type: ignore
@@ -29,6 +27,7 @@ from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import docx_to_text_and_images
from onyx.file_processing.extract_file_text import extract_file_text
@@ -78,15 +77,7 @@ class PermissionSyncContext(BaseModel):
def onyx_document_id_from_drive_file(file: GoogleDriveFileType) -> str:
link = file[WEB_VIEW_LINK_KEY]
parsed_url = urlparse(link)
parsed_url = parsed_url._replace(query="") # remove query parameters
spl_path = parsed_url.path.split("/")
if spl_path and (spl_path[-1] in ["edit", "view", "preview"]):
spl_path.pop()
parsed_url = parsed_url._replace(path="/".join(spl_path))
# Remove query parameters and reconstruct URL
return urlunparse(parsed_url)
return file[WEB_VIEW_LINK_KEY]
def is_gdrive_image_mime_type(mime_type: str) -> bool:
@@ -128,32 +119,9 @@ def _download_and_extract_sections_basic(
mime_type = file["mimeType"]
link = file.get(WEB_VIEW_LINK_KEY, "")
# For non-Google files, download the file
# Use the correct API call for downloading files
# lazy evaluation to only download the file if necessary
def response_call() -> bytes:
return download_request(service, file_id)
if is_gdrive_image_mime_type(mime_type):
# Skip images if not explicitly enabled
if not allow_images:
return []
# Store images for later processing
sections: list[TextSection | ImageSection] = []
try:
section, embedded_id = store_image_and_create_section(
image_data=response_call(),
file_id=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
@@ -176,6 +144,12 @@ def _download_and_extract_sections_basic(
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
# lazy evaluation to only download the file if necessary
def response_call() -> bytes:
return download_request(service, file_id)
# Process based on mime type
if mime_type == "text/plain":
try:
@@ -205,6 +179,25 @@ def _download_and_extract_sections_basic(
text = pptx_to_text(io.BytesIO(response_call()), file_name=file_name)
return [TextSection(link=link, text=text)] if text else []
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response_call(),
file_id=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
pdf_sections: list[TextSection | ImageSection] = [
@@ -213,30 +206,41 @@ def _download_and_extract_sections_basic(
# Process embedded images in the PDF
try:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
image_data=img_data,
file_id=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_id=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
# Final attempt at extracting text
file_ext = get_file_ext(file.get("name", ""))
if file_ext not in ALL_ACCEPTED_FILE_EXTENSIONS:
logger.warning(f"Skipping file {file.get('name')} due to extension.")
return []
else:
# For unsupported file types, try to extract text
if mime_type in [
"application/vnd.google-apps.video",
"application/vnd.google-apps.audio",
"application/zip",
]:
return []
try:
text = extract_file_text(io.BytesIO(response_call()), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
# don't download the file at all if it's an unhandled extension
file_ext = get_file_ext(file.get("name", ""))
if file_ext not in ALL_ACCEPTED_FILE_EXTENSIONS:
logger.warning(f"Skipping file {file.get('name')} due to extension.")
return []
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response_call()), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
def _find_nth(haystack: str, needle: str, n: int, start: int = 0) -> int:
@@ -311,17 +315,13 @@ def align_basic_advanced(
def _get_external_access_for_raw_gdrive_file(
file: GoogleDriveFileType,
company_domain: str,
retriever_drive_service: GoogleDriveService | None,
admin_drive_service: GoogleDriveService,
drive_service: GoogleDriveService,
) -> ExternalAccess:
"""
Get the external access for a raw Google Drive file.
"""
external_access_fn = cast(
Callable[
[GoogleDriveFileType, str, GoogleDriveService | None, GoogleDriveService],
ExternalAccess,
],
Callable[[GoogleDriveFileType, str, GoogleDriveService], ExternalAccess],
fetch_versioned_implementation_with_fallback(
"onyx.external_permissions.google_drive.doc_sync",
"get_external_access_for_raw_gdrive_file",
@@ -331,8 +331,7 @@ def _get_external_access_for_raw_gdrive_file(
return external_access_fn(
file,
company_domain,
retriever_drive_service,
admin_drive_service,
drive_service,
)
@@ -437,19 +436,6 @@ def _convert_drive_item_to_document(
logger.info("Skipping shortcut/folder.")
return None
size_str = file.get("size")
if size_str:
try:
size_int = int(size_str)
except ValueError:
logger.warning(f"Parsing string to int failed: size_str={size_str}")
else:
if size_int > size_threshold:
logger.warning(
f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping."
)
return None
# If it's a Google Doc, we might do advanced parsing
if file.get("mimeType") == GDriveMimeType.DOC.value:
try:
@@ -475,8 +461,22 @@ def _convert_drive_item_to_document(
logger.warning(
f"Error in advanced parsing: {e}. Falling back to basic extraction."
)
# Not Google Doc, attempt basic extraction
else:
size_str = file.get("size")
if size_str:
try:
size_int = int(size_str)
except ValueError:
logger.warning(f"Parsing string to int failed: size_str={size_str}")
else:
if size_int > size_threshold:
logger.warning(
f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping."
)
return None
# If we don't have sections yet, use the basic extraction method
if not sections:
sections = _download_and_extract_sections_basic(
file, _get_drive_service(), allow_images
)
@@ -491,9 +491,7 @@ def _convert_drive_item_to_document(
_get_external_access_for_raw_gdrive_file(
file=file,
company_domain=permission_sync_context.google_domain,
# try both retriever_email and primary_admin_email if necessary
retriever_drive_service=_get_drive_service(),
admin_drive_service=get_drive_service(
drive_service=get_drive_service(
creds, user_email=permission_sync_context.primary_admin_email
),
)
@@ -559,22 +557,14 @@ def build_slim_document(
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
return None
owner_email = cast(str | None, file.get("owners", [{}])[0].get("emailAddress"))
owner_email = file.get("owners", [{}])[0].get("emailAddress")
external_access = (
_get_external_access_for_raw_gdrive_file(
file=file,
company_domain=permission_sync_context.google_domain,
retriever_drive_service=(
get_drive_service(
creds,
user_email=owner_email,
)
if owner_email
else None
),
admin_drive_service=get_drive_service(
drive_service=get_drive_service(
creds,
user_email=permission_sync_context.primary_admin_email,
user_email=owner_email or permission_sync_context.primary_admin_email,
),
)
if permission_sync_context

View File

@@ -12,6 +12,7 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_processing.html_utils import web_html_cleanup
@@ -67,7 +68,10 @@ class GoogleSitesConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
file_content_io = get_default_file_store().read_file(self.zip_path, mode="b")
with get_session_with_current_tenant() as db_session:
file_content_io = get_default_file_store(db_session).read_file(
self.zip_path, mode="b"
)
# load the HTML files
files = load_files_from_zip(file_content_io)

View File

@@ -1,484 +0,0 @@
import copy
import email
import imaplib
import os
import re
from datetime import datetime
from datetime import timezone
from email.message import Message
from email.utils import parseaddr
from enum import Enum
from typing import Any
from typing import cast
import bs4
from pydantic import BaseModel
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.imap.models import EmailHeaders
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
_DEFAULT_IMAP_PORT_NUMBER = int(os.environ.get("IMAP_PORT", 993))
_IMAP_OKAY_STATUS = "OK"
_PAGE_SIZE = 100
_USERNAME_KEY = "imap_username"
_PASSWORD_KEY = "imap_password"
class CurrentMailbox(BaseModel):
mailbox: str
todo_email_ids: list[str]
# An email has a list of mailboxes.
# Each mailbox has a list of email-ids inside of it.
#
# Usage:
# To use this checkpointer, first fetch all the mailboxes.
# Then, pop a mailbox and fetch all of its email-ids.
# Then, pop each email-id and fetch its content (and parse it, etc..).
# When you have popped all email-ids for this mailbox, pop the next mailbox and repeat the above process until you're done.
#
# For initial checkpointing, set both fields to `None`.
class ImapCheckpoint(ConnectorCheckpoint):
todo_mailboxes: list[str] | None = None
current_mailbox: CurrentMailbox | None = None
class LoginState(str, Enum):
LoggedIn = "logged_in"
LoggedOut = "logged_out"
class ImapConnector(
CredentialsConnector,
CheckpointedConnectorWithPermSync[ImapCheckpoint],
):
def __init__(
self,
host: str,
port: int = _DEFAULT_IMAP_PORT_NUMBER,
mailboxes: list[str] | None = None,
) -> None:
self._host = host
self._port = port
self._mailboxes = mailboxes
self._credentials: dict[str, Any] | None = None
@property
def credentials(self) -> dict[str, Any]:
if not self._credentials:
raise RuntimeError(
"Credentials have not been initialized; call `set_credentials_provider` first"
)
return self._credentials
def _get_mail_client(self) -> imaplib.IMAP4_SSL:
"""
Returns a new `imaplib.IMAP4_SSL` instance.
The `imaplib.IMAP4_SSL` object is supposed to be an "ephemeral" object; it's not something that you can login,
logout, then log back into again. I.e., the following will fail:
```py
mail_client.login(..)
mail_client.logout();
mail_client.login(..)
```
Therefore, you need a fresh, new instance in order to operate with IMAP. This function gives one to you.
# Notes
This function will throw an error if the credentials have not yet been set.
"""
def get_or_raise(name: str) -> str:
value = self.credentials.get(name)
if not value:
raise RuntimeError(f"Credential item {name=} was not found")
if not isinstance(value, str):
raise RuntimeError(
f"Credential item {name=} must be of type str, instead received {type(name)=}"
)
return value
username = get_or_raise(_USERNAME_KEY)
password = get_or_raise(_PASSWORD_KEY)
mail_client = imaplib.IMAP4_SSL(host=self._host, port=self._port)
status, _data = mail_client.login(user=username, password=password)
if status != _IMAP_OKAY_STATUS:
raise RuntimeError(f"Failed to log into imap server; {status=}")
return mail_client
def _load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ImapCheckpoint,
include_perm_sync: bool,
) -> CheckpointOutput[ImapCheckpoint]:
checkpoint = cast(ImapCheckpoint, copy.deepcopy(checkpoint))
checkpoint.has_more = True
mail_client = self._get_mail_client()
if checkpoint.todo_mailboxes is None:
# This is the dummy checkpoint.
# Fill it with mailboxes first.
if self._mailboxes:
checkpoint.todo_mailboxes = _sanitize_mailbox_names(self._mailboxes)
else:
fetched_mailboxes = _fetch_all_mailboxes_for_email_account(
mail_client=mail_client
)
if not fetched_mailboxes:
raise RuntimeError(
"Failed to find any mailboxes for this email account"
)
checkpoint.todo_mailboxes = _sanitize_mailbox_names(fetched_mailboxes)
return checkpoint
if (
not checkpoint.current_mailbox
or not checkpoint.current_mailbox.todo_email_ids
):
if not checkpoint.todo_mailboxes:
checkpoint.has_more = False
return checkpoint
mailbox = checkpoint.todo_mailboxes.pop()
email_ids = _fetch_email_ids_in_mailbox(
mail_client=mail_client,
mailbox=mailbox,
start=start,
end=end,
)
checkpoint.current_mailbox = CurrentMailbox(
mailbox=mailbox,
todo_email_ids=email_ids,
)
_select_mailbox(
mail_client=mail_client, mailbox=checkpoint.current_mailbox.mailbox
)
current_todos = cast(
list, copy.deepcopy(checkpoint.current_mailbox.todo_email_ids[:_PAGE_SIZE])
)
checkpoint.current_mailbox.todo_email_ids = (
checkpoint.current_mailbox.todo_email_ids[_PAGE_SIZE:]
)
for email_id in current_todos:
email_msg = _fetch_email(mail_client=mail_client, email_id=email_id)
if not email_msg:
logger.warn(f"Failed to fetch message {email_id=}; skipping")
continue
email_headers = EmailHeaders.from_email_msg(email_msg=email_msg)
yield _convert_email_headers_and_body_into_document(
email_msg=email_msg,
email_headers=email_headers,
include_perm_sync=include_perm_sync,
)
return checkpoint
# impls for BaseConnector
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
raise NotImplementedError("Use `set_credentials_provider` instead")
def validate_connector_settings(self) -> None:
self._get_mail_client()
# impls for CredentialsConnector
def set_credentials_provider(
self, credentials_provider: CredentialsProviderInterface
) -> None:
self._credentials = credentials_provider.get_credentials()
# impls for CheckpointedConnector
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ImapCheckpoint,
) -> CheckpointOutput[ImapCheckpoint]:
return self._load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint, include_perm_sync=False
)
def build_dummy_checkpoint(self) -> ImapCheckpoint:
return ImapCheckpoint(has_more=True)
def validate_checkpoint_json(self, checkpoint_json: str) -> ImapCheckpoint:
return ImapCheckpoint.model_validate_json(json_data=checkpoint_json)
# impls for CheckpointedConnectorWithPermSync
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ImapCheckpoint,
) -> CheckpointOutput[ImapCheckpoint]:
return self._load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint, include_perm_sync=True
)
def _fetch_all_mailboxes_for_email_account(mail_client: imaplib.IMAP4_SSL) -> list[str]:
status, mailboxes_data = mail_client.list(directory="*", pattern="*")
if status != _IMAP_OKAY_STATUS:
raise RuntimeError(f"Failed to fetch mailboxes; {status=}")
mailboxes = []
for mailboxes_raw in mailboxes_data:
if isinstance(mailboxes_raw, bytes):
mailboxes_str = mailboxes_raw.decode()
elif isinstance(mailboxes_raw, str):
mailboxes_str = mailboxes_raw
else:
logger.warn(
f"Expected the mailbox data to be of type str, instead got {type(mailboxes_raw)=} {mailboxes_raw}; skipping"
)
continue
# The mailbox LIST response output can be found here:
# https://www.rfc-editor.org/rfc/rfc3501.html#section-7.2.2
#
# The general format is:
# `(<name-attributes>) <hierarchy-delimiter> <mailbox-name>`
#
# The below regex matches on that pattern; from there, we select the 3rd match (index 2), which is the mailbox-name.
match = re.match(r'\([^)]*\)\s+"([^"]+)"\s+"?(.+?)"?$', mailboxes_str)
if not match:
logger.warn(
f"Invalid mailbox-data formatting structure: {mailboxes_str=}; skipping"
)
continue
mailbox = match.group(2)
mailboxes.append(mailbox)
return mailboxes
def _select_mailbox(mail_client: imaplib.IMAP4_SSL, mailbox: str) -> None:
status, _ids = mail_client.select(mailbox=mailbox, readonly=True)
if status != _IMAP_OKAY_STATUS:
raise RuntimeError(f"Failed to select {mailbox=}")
def _fetch_email_ids_in_mailbox(
mail_client: imaplib.IMAP4_SSL,
mailbox: str,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> list[str]:
_select_mailbox(mail_client=mail_client, mailbox=mailbox)
start_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime("%d-%b-%Y")
end_str = datetime.fromtimestamp(end, tz=timezone.utc).strftime("%d-%b-%Y")
search_criteria = f'(SINCE "{start_str}" BEFORE "{end_str}")'
status, email_ids_byte_array = mail_client.search(None, search_criteria)
if status != _IMAP_OKAY_STATUS or not email_ids_byte_array:
raise RuntimeError(f"Failed to fetch email ids; {status=}")
email_ids: bytes = email_ids_byte_array[0]
return [email_id.decode() for email_id in email_ids.split()]
def _fetch_email(mail_client: imaplib.IMAP4_SSL, email_id: str) -> Message | None:
status, msg_data = mail_client.fetch(message_set=email_id, message_parts="(RFC822)")
if status != _IMAP_OKAY_STATUS or not msg_data:
return None
data = msg_data[0]
if not isinstance(data, tuple):
raise RuntimeError(
f"Message data should be a tuple; instead got a {type(data)=} {data=}"
)
_metadata, raw_email = data
return email.message_from_bytes(raw_email)
def _convert_email_headers_and_body_into_document(
email_msg: Message,
email_headers: EmailHeaders,
include_perm_sync: bool,
) -> Document:
sender_name, sender_addr = _parse_singular_addr(raw_header=email_headers.sender)
parsed_recipients = (
_parse_addrs(raw_header=email_headers.recipients)
if email_headers.recipients
else []
)
expert_info_map = {
recipient_addr: BasicExpertInfo(
display_name=recipient_name, email=recipient_addr
)
for recipient_name, recipient_addr in parsed_recipients
}
if sender_addr not in expert_info_map:
expert_info_map[sender_addr] = BasicExpertInfo(
display_name=sender_name, email=sender_addr
)
email_body = _parse_email_body(email_msg=email_msg, email_headers=email_headers)
primary_owners = list(expert_info_map.values())
external_access = (
ExternalAccess(
external_user_emails=set(expert_info_map.keys()),
external_user_group_ids=set(),
is_public=False,
)
if include_perm_sync
else None
)
return Document(
id=email_headers.id,
title=email_headers.subject,
semantic_identifier=email_headers.subject,
metadata={},
source=DocumentSource.IMAP,
sections=[TextSection(text=email_body)],
primary_owners=primary_owners,
external_access=external_access,
)
def _parse_email_body(
email_msg: Message,
email_headers: EmailHeaders,
) -> str:
body = None
for part in email_msg.walk():
if part.is_multipart():
# Multipart parts are *containers* for other parts, not the actual content itself.
# Therefore, we skip until we find the individual parts instead.
continue
charset = part.get_content_charset() or "utf-8"
try:
raw_payload = part.get_payload(decode=True)
if not isinstance(raw_payload, bytes):
logger.warn(
"Payload section from email was expected to be an array of bytes, instead got "
f"{type(raw_payload)=}, {raw_payload=}"
)
continue
body = raw_payload.decode(charset)
break
except (UnicodeDecodeError, LookupError) as e:
print(f"Warning: Could not decode part with charset {charset}. Error: {e}")
continue
if not body:
logger.warn(
f"Email with {email_headers.id=} has an empty body; returning an empty string"
)
return ""
soup = bs4.BeautifulSoup(markup=body, features="html.parser")
return " ".join(str_section for str_section in soup.stripped_strings)
def _sanitize_mailbox_names(mailboxes: list[str]) -> list[str]:
"""
Mailboxes with special characters in them must be enclosed by double-quotes, as per the IMAP protocol.
Just to be safe, we wrap *all* mailboxes with double-quotes.
"""
return [f'"{mailbox}"' for mailbox in mailboxes if mailbox]
def _parse_addrs(raw_header: str) -> list[tuple[str, str]]:
addrs = raw_header.split(",")
name_addr_pairs = [parseaddr(addr=addr, strict=True) for addr in addrs if addr]
return [(name, addr) for name, addr in name_addr_pairs if addr]
def _parse_singular_addr(raw_header: str) -> tuple[str, str]:
addrs = _parse_addrs(raw_header=raw_header)
if not addrs:
raise RuntimeError(
f"Parsing email header resulted in no addresses being found; {raw_header=}"
)
elif len(addrs) >= 2:
raise RuntimeError(
f"Expected a singular address, but instead got multiple; {raw_header=} {addrs=}"
)
return addrs[0]
if __name__ == "__main__":
import time
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
host = os.environ.get("IMAP_HOST")
mailboxes_str = os.environ.get("IMAP_MAILBOXES")
username = os.environ.get("IMAP_USERNAME")
password = os.environ.get("IMAP_PASSWORD")
mailboxes = (
[mailbox.strip() for mailbox in mailboxes_str.split(",")]
if mailboxes_str
else []
)
if not host:
raise RuntimeError("`IMAP_HOST` must be set")
imap_connector = ImapConnector(
host=host,
mailboxes=mailboxes,
)
imap_connector.set_credentials_provider(
OnyxStaticCredentialsProvider(
tenant_id=None,
connector_name=DocumentSource.IMAP,
credential_json={
_USERNAME_KEY: username,
_PASSWORD_KEY: password,
},
)
)
for doc in load_all_docs_from_checkpoint_connector(
connector=imap_connector,
start=0,
end=time.time(),
):
print(doc)

View File

@@ -1,75 +0,0 @@
import email
from datetime import datetime
from email.message import Message
from enum import Enum
from pydantic import BaseModel
class Header(str, Enum):
SUBJECT_HEADER = "subject"
FROM_HEADER = "from"
TO_HEADER = "to"
DELIVERED_TO_HEADER = (
"Delivered-To" # Used in mailing lists instead of the "to" header.
)
DATE_HEADER = "date"
MESSAGE_ID_HEADER = "Message-ID"
class EmailHeaders(BaseModel):
"""
Model for email headers extracted from IMAP messages.
"""
id: str
subject: str
sender: str
recipients: str | None
date: datetime
@classmethod
def from_email_msg(cls, email_msg: Message) -> "EmailHeaders":
def _decode(header: str, default: str | None = None) -> str | None:
value = email_msg.get(header, default)
if not value:
return None
decoded_value, _encoding = email.header.decode_header(value)[0]
if isinstance(decoded_value, bytes):
return decoded_value.decode()
elif isinstance(decoded_value, str):
return decoded_value
else:
return None
def _parse_date(date_str: str | None) -> datetime | None:
if not date_str:
return None
try:
return email.utils.parsedate_to_datetime(date_str)
except (TypeError, ValueError):
return None
message_id = _decode(header=Header.MESSAGE_ID_HEADER)
# It's possible for the subject line to not exist or be an empty string.
subject = _decode(header=Header.SUBJECT_HEADER) or "Unknown Subject"
from_ = _decode(header=Header.FROM_HEADER)
to = _decode(header=Header.TO_HEADER)
if not to:
to = _decode(header=Header.DELIVERED_TO_HEADER)
date_str = _decode(header=Header.DATE_HEADER)
date = _parse_date(date_str=date_str)
# If any of the above are `None`, model validation will fail.
# Therefore, no guards (i.e.: `if <header> is None: raise RuntimeError(..)`) were written.
return cls.model_validate(
{
"id": message_id,
"subject": subject,
"sender": from_,
"recipients": to,
"date": date,
}
)

View File

@@ -129,14 +129,7 @@ class MediaWikiConnector(LoadConnector, PollConnector):
self.family = family_class_dispatch(hostname, "WikipediaConnector")()
self.site = pywikibot.Site(fam=self.family, code=language_code)
self.categories = [
pywikibot.Category(
self.site,
(
f"{category.replace(' ', '_')}"
if category.startswith("Category:")
else f"Category:{category.replace(' ', '_')}"
),
)
pywikibot.Category(self.site, f"Category:{category.replace(' ', '_')}")
for category in categories
]

View File

@@ -11,7 +11,6 @@ from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.db.enums import IndexModelStatus
from onyx.utils.text_processing import make_url_compatible
@@ -364,38 +363,9 @@ class ConnectorFailure(BaseModel):
return values
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
class OnyxMetadata(BaseModel):
# Note that doc_id cannot be overriden here as it may cause issues
# with the display functionalities in the UI. Ask @chris if clarification is needed.
source_type: DocumentSource | None = None
link: str | None = None
file_display_name: str | None = None
primary_owners: list[BasicExpertInfo] | None = None
secondary_owners: list[BasicExpertInfo] | None = None
doc_updated_at: datetime | None = None
title: str | None = None
class DocExtractionContext(BaseModel):
index_name: str
cc_pair_id: int
connector_id: int
credential_id: int
source: DocumentSource
earliest_index_time: float
from_beginning: bool
is_primary: bool
should_fetch_permissions_during_indexing: bool
search_settings_status: IndexModelStatus
doc_extraction_complete_batch_num: int | None
class DocIndexingContext(BaseModel):
batches_done: int
total_failures: int
net_doc_change: int
total_chunks: int

View File

@@ -267,7 +267,7 @@ class NotionConnector(LoadConnector, PollConnector):
result = ""
for prop_name, prop in properties.items():
if not prop or not isinstance(prop, dict):
if not prop:
continue
try:

View File

@@ -992,7 +992,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
doc_metadata_list: list[SlimDocument] = []
for parent_object_type in self.parent_object_list:
query = f"SELECT Id FROM {parent_object_type}"
query_result = self.sf_client.safe_query_all(query)
query_result = self.sf_client.query_all(query)
doc_metadata_list.extend(
SlimDocument(
id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",

View File

@@ -1,31 +1,18 @@
import time
from typing import Any
from simple_salesforce import Salesforce
from simple_salesforce import SFType
from simple_salesforce.exceptions import SalesforceRefusedRequest
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_OBJECTS
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_PREFIXES
from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_SUFFIXES
from onyx.connectors.salesforce.salesforce_calls import get_object_by_id_query
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
def is_salesforce_rate_limit_error(exception: Exception) -> bool:
"""Check if an exception is a Salesforce rate limit error."""
return isinstance(
exception, SalesforceRefusedRequest
) and "REQUEST_LIMIT_EXCEEDED" in str(exception)
class OnyxSalesforce(Salesforce):
SOQL_MAX_SUBQUERIES = 20
@@ -65,48 +52,6 @@ class OnyxSalesforce(Salesforce):
return False
@retry_builder(
tries=5,
delay=20,
backoff=1.5,
max_delay=60,
exceptions=(SalesforceRefusedRequest,),
)
@rate_limit_builder(max_calls=50, period=60)
def safe_query(self, query: str, **kwargs: Any) -> dict[str, Any]:
"""Wrapper around the original query method with retry logic and rate limiting."""
try:
return super().query(query, **kwargs)
except SalesforceRefusedRequest as e:
if is_salesforce_rate_limit_error(e):
logger.warning(
f"Salesforce rate limit exceeded for query: {query[:100]}..."
)
# Add additional delay for rate limit errors
time.sleep(5)
raise
@retry_builder(
tries=5,
delay=20,
backoff=1.5,
max_delay=60,
exceptions=(SalesforceRefusedRequest,),
)
@rate_limit_builder(max_calls=50, period=60)
def safe_query_all(self, query: str, **kwargs: Any) -> dict[str, Any]:
"""Wrapper around the original query_all method with retry logic and rate limiting."""
try:
return super().query_all(query, **kwargs)
except SalesforceRefusedRequest as e:
if is_salesforce_rate_limit_error(e):
logger.warning(
f"Salesforce rate limit exceeded for query_all: {query[:100]}..."
)
# Add additional delay for rate limit errors
time.sleep(5)
raise
@staticmethod
def _make_child_objects_by_id_query(
object_id: str,
@@ -154,7 +99,7 @@ class OnyxSalesforce(Salesforce):
queryable_fields = type_to_queryable_fields[object_type]
query = get_object_by_id_query(object_id, object_type, queryable_fields)
result = self.safe_query(query)
result = self.query(query)
if not result:
return None
@@ -206,7 +151,7 @@ class OnyxSalesforce(Salesforce):
)
try:
result = self.safe_query(query)
result = self.query(query)
except Exception:
logger.exception(f"Query failed: {query=}")
else:
@@ -244,25 +189,10 @@ class OnyxSalesforce(Salesforce):
return child_records
@retry_builder(
tries=3,
delay=1,
backoff=2,
exceptions=(SalesforceRefusedRequest,),
)
def describe_type(self, name: str) -> Any:
sf_object = SFType(name, self.session_id, self.sf_instance)
try:
result = sf_object.describe()
return result
except SalesforceRefusedRequest as e:
if is_salesforce_rate_limit_error(e):
logger.warning(
f"Salesforce rate limit exceeded for describe_type: {name}"
)
# Add additional delay for rate limit errors
time.sleep(3)
raise
result = sf_object.describe()
return result
def get_queryable_fields_by_type(self, name: str) -> list[str]:
object_description = self.describe_type(name)

View File

@@ -1,6 +1,5 @@
import gc
import os
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
@@ -8,25 +7,13 @@ from pytz import UTC
from simple_salesforce import Salesforce
from simple_salesforce.bulk2 import SFBulk2Handler
from simple_salesforce.bulk2 import SFBulk2Type
from simple_salesforce.exceptions import SalesforceRefusedRequest
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
def is_salesforce_rate_limit_error(exception: Exception) -> bool:
"""Check if an exception is a Salesforce rate limit error."""
return isinstance(
exception, SalesforceRefusedRequest
) and "REQUEST_LIMIT_EXCEEDED" in str(exception)
def _build_last_modified_time_filter_for_salesforce(
start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> str:
@@ -84,14 +71,6 @@ def get_object_by_id_query(
return query
@retry_builder(
tries=5,
delay=2,
backoff=2,
max_delay=60,
exceptions=(SalesforceRefusedRequest,),
)
@rate_limit_builder(max_calls=50, period=60)
def _object_type_has_api_data(
sf_client: Salesforce, sf_type: str, time_filter: str
) -> bool:
@@ -103,15 +82,6 @@ def _object_type_has_api_data(
result = sf_client.query(query)
if result["totalSize"] == 0:
return False
except SalesforceRefusedRequest as e:
if is_salesforce_rate_limit_error(e):
logger.warning(
f"Salesforce rate limit exceeded for object type check: {sf_type}"
)
# Add additional delay for rate limit errors
time.sleep(3)
raise
except Exception as e:
if "OPERATION_TOO_LARGE" not in str(e):
logger.warning(f"Object type {sf_type} doesn't support query: {e}")

View File

@@ -1,6 +1,5 @@
import io
import os
import time
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
@@ -12,11 +11,9 @@ from office365.graph_client import GraphClient # type: ignore
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore
from office365.onedrive.sites.site import Site # type: ignore
from office365.onedrive.sites.sites_with_root import SitesWithRoot # type: ignore
from office365.runtime.client_request import ClientRequestException # type: ignore
from pydantic import BaseModel
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import SHAREPOINT_CONNECTOR_SIZE_THRESHOLD
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
@@ -49,72 +46,12 @@ class SiteDescriptor(BaseModel):
folder_path: str | None
def _sleep_and_retry(query_obj: Any, method_name: str, max_retries: int = 3) -> Any:
"""
Execute a SharePoint query with retry logic for rate limiting.
"""
for attempt in range(max_retries + 1):
try:
return query_obj.execute_query()
except ClientRequestException as e:
if (
e.response
and e.response.status_code in [429, 503]
and attempt < max_retries
):
logger.warning(
f"Rate limit exceeded on {method_name}, attempt {attempt + 1}/{max_retries + 1}, sleeping and retrying"
)
retry_after = e.response.headers.get("Retry-After")
if retry_after:
sleep_time = int(retry_after)
else:
# Exponential backoff: 2^attempt * 5 seconds
sleep_time = min(30, (2**attempt) * 5)
logger.info(f"Sleeping for {sleep_time} seconds before retry")
time.sleep(sleep_time)
else:
# Either not a rate limit error, or we've exhausted retries
if e.response and e.response.status_code == 429:
logger.error(
f"Rate limit retry exhausted for {method_name} after {max_retries} attempts"
)
raise e
def _convert_driveitem_to_document(
driveitem: DriveItem,
drive_name: str,
) -> Document | None:
# Check file size before downloading
try:
size_value = getattr(driveitem, "size", None)
if size_value is not None:
file_size = int(size_value)
if file_size > SHAREPOINT_CONNECTOR_SIZE_THRESHOLD:
logger.warning(
f"File '{driveitem.name}' exceeds size threshold of {SHAREPOINT_CONNECTOR_SIZE_THRESHOLD} bytes. "
f"File size: {file_size} bytes. Skipping."
)
return None
else:
logger.warning(
f"Could not access file size for '{driveitem.name}' Proceeding with download."
)
except (ValueError, TypeError, AttributeError) as e:
logger.info(
f"Could not access file size for '{driveitem.name}': {e}. Proceeding with download."
)
# Proceed with download if size is acceptable or not available
content = _sleep_and_retry(driveitem.get_content(), "get_content")
if content is None:
logger.warning(f"Could not access content for '{driveitem.name}'")
return None
) -> Document:
file_text = extract_file_text(
file=io.BytesIO(content.value),
file=io.BytesIO(driveitem.get_content().execute_query().value),
file_name=driveitem.name,
break_on_unprocessable=False,
)
@@ -338,11 +275,7 @@ class SharepointConnector(LoadConnector, PollConnector):
driveitems = self._fetch_driveitems(site_descriptor, start=start, end=end)
for driveitem, drive_name in driveitems:
logger.debug(f"Processing: {driveitem.web_url}")
# Convert driveitem to document with size checking
doc = _convert_driveitem_to_document(driveitem, drive_name)
if doc is not None:
doc_batch.append(doc)
doc_batch.append(_convert_driveitem_to_document(driveitem, drive_name))
if len(doc_batch) >= self.batch_size:
yield doc_batch

View File

@@ -286,14 +286,9 @@ class TeamsConnector(
def _construct_semantic_identifier(channel: Channel, top_message: Message) -> str:
top_message_user_name: str
if top_message.from_ and top_message.from_.user:
top_message_user_name = top_message.from_.user.display_name
else:
logger.warn(f"Message {top_message=} has no `from.user` field")
top_message_user_name = "Unknown User"
top_message_user_name = (
top_message.from_.user.display_name if top_message.from_ else "Unknown User"
)
top_message_content = top_message.body.content or ""
top_message_subject = top_message.subject or "Unknown Subject"
channel_name = channel.properties.get("displayName", "Unknown")

View File

@@ -27,7 +27,7 @@ class User(BaseModel):
class From(BaseModel):
user: User | None
user: User
model_config = ConfigDict(
alias_generator=to_camel,

View File

@@ -23,7 +23,6 @@ class OptionalSearchSetting(str, Enum):
class SearchType(str, Enum):
KEYWORD = "keyword"
SEMANTIC = "semantic"
INTERNET = "internet"
class LLMEvaluationType(str, Enum):

View File

@@ -1,18 +0,0 @@
from datetime import datetime
from pydantic import BaseModel
class SlackMessage(BaseModel):
document_id: str
channel_id: str
message_id: str
thread_id: str | None
link: str
metadata: dict[str, str | list[str]]
timestamp: datetime
recency_bias: float
semantic_identifier: str
text: str
highlighted_texts: set[str]
slack_score: float

View File

@@ -1,416 +0,0 @@
import re
from datetime import datetime
from datetime import timedelta
from typing import Any
from langchain_core.messages import HumanMessage
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from sqlalchemy.orm import Session
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import TextSection
from onyx.context.search.federated.models import SlackMessage
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import SearchQuery
from onyx.db.document import DocumentSource
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.document_index_utils import (
get_multipass_config,
)
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.llm.factory import get_default_llms
from onyx.llm.interfaces import LLM
from onyx.llm.utils import message_to_string
from onyx.prompts.federated_search import SLACK_QUERY_EXPANSION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
logger = setup_logger()
HIGHLIGHT_START_CHAR = "\ue000"
HIGHLIGHT_END_CHAR = "\ue001"
def build_slack_queries(query: SearchQuery, llm: LLM) -> list[str]:
# get time filter
time_filter = ""
time_cutoff = query.filters.time_cutoff
if time_cutoff is not None:
# slack after: is exclusive, so we need to subtract one day
time_cutoff = time_cutoff - timedelta(days=1)
time_filter = f" after:{time_cutoff.strftime('%Y-%m-%d')}"
# use llm to generate slack queries (use original query to use same keywords as the user)
prompt = SLACK_QUERY_EXPANSION_PROMPT.format(query=query.original_query)
try:
msg = HumanMessage(content=prompt)
response = llm.invoke([msg])
rephrased_queries = message_to_string(response).split("\n")
except Exception as e:
logger.error(f"Error expanding query: {e}")
rephrased_queries = [query.query]
return [
rephrased_query.strip() + time_filter
for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
]
def query_slack(
query_string: str,
original_query: SearchQuery,
access_token: str,
limit: int | None = None,
) -> list[SlackMessage]:
# query slack
slack_client = WebClient(token=access_token)
try:
response = slack_client.search_messages(
query=query_string, count=limit, highlight=True
)
response.validate()
messages: dict[str, Any] = response.get("messages", {})
matches: list[dict[str, Any]] = messages.get("matches", [])
except SlackApiError as e:
logger.error(f"Slack API error in query_slack: {e}")
return []
# convert matches to slack messages
slack_messages: list[SlackMessage] = []
for match in matches:
text: str | None = match.get("text")
permalink: str | None = match.get("permalink")
message_id: str | None = match.get("ts")
channel_id: str | None = match.get("channel", {}).get("id")
channel_name: str | None = match.get("channel", {}).get("name")
username: str | None = match.get("username")
score: float = match.get("score", 0.0)
if ( # can't use any() because of type checking :(
not text
or not permalink
or not message_id
or not channel_id
or not channel_name
or not username
):
continue
# generate thread id and document id
thread_id = (
permalink.split("?thread_ts=", 1)[1] if "?thread_ts=" in permalink else None
)
document_id = f"{channel_id}_{message_id}"
# compute recency bias (parallels vespa calculation) and metadata
decay_factor = DOC_TIME_DECAY * original_query.recency_bias_multiplier
doc_time = datetime.fromtimestamp(float(message_id))
doc_age_years = (datetime.now() - doc_time).total_seconds() / (
365 * 24 * 60 * 60
)
recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75)
metadata: dict[str, str | list[str]] = {
"channel": channel_name,
"time": doc_time.isoformat(),
}
# extract out the highlighted texts
highlighted_texts = set(
re.findall(
rf"{re.escape(HIGHLIGHT_START_CHAR)}(.*?){re.escape(HIGHLIGHT_END_CHAR)}",
text,
)
)
cleaned_text = text.replace(HIGHLIGHT_START_CHAR, "").replace(
HIGHLIGHT_END_CHAR, ""
)
# get the semantic identifier
snippet = (
cleaned_text[:50].rstrip() + "..." if len(cleaned_text) > 50 else text
).replace("\n", " ")
doc_sem_id = f"{username} in #{channel_name}: {snippet}"
slack_messages.append(
SlackMessage(
document_id=document_id,
channel_id=channel_id,
message_id=message_id,
thread_id=thread_id,
link=permalink,
metadata=metadata,
timestamp=doc_time,
recency_bias=recency_bias,
semantic_identifier=doc_sem_id,
text=f"{username}: {cleaned_text}",
highlighted_texts=highlighted_texts,
slack_score=score,
)
)
return slack_messages
def merge_slack_messages(
slack_messages: list[list[SlackMessage]],
) -> tuple[list[SlackMessage], dict[str, SlackMessage]]:
merged_messages: list[SlackMessage] = []
docid_to_message: dict[str, SlackMessage] = {}
for messages in slack_messages:
for message in messages:
if message.document_id in docid_to_message:
# update the score and highlighted texts, rest should be identical
docid_to_message[message.document_id].slack_score = max(
docid_to_message[message.document_id].slack_score,
message.slack_score,
)
docid_to_message[message.document_id].highlighted_texts.update(
message.highlighted_texts
)
continue
# add the message to the list
docid_to_message[message.document_id] = message
merged_messages.append(message)
# re-sort by score
merged_messages.sort(key=lambda x: x.slack_score, reverse=True)
return merged_messages, docid_to_message
def get_contextualized_thread_text(message: SlackMessage, access_token: str) -> str:
"""
Retrieves the initial thread message as well as the text following the message
and combines them into a single string. If the slack query fails, returns the
original message text.
The idea is that the message (the one that actually matched the search), the
initial thread message, and the replies to the message are important in answering
the user's query.
"""
channel_id = message.channel_id
thread_id = message.thread_id
message_id = message.message_id
# if it's not a thread, return the message text
if thread_id is None:
return message.text
# get the thread messages
slack_client = WebClient(token=access_token)
try:
response = slack_client.conversations_replies(
channel=channel_id,
ts=thread_id,
)
response.validate()
messages: list[dict[str, Any]] = response.get("messages", [])
except SlackApiError as e:
logger.error(f"Slack API error in get_contextualized_thread_text: {e}")
return message.text
# make sure we didn't get an empty response or a single message (not a thread)
if len(messages) <= 1:
return message.text
# add the initial thread message
msg_text = messages[0].get("text", "")
msg_sender = messages[0].get("user", "")
thread_text = f"<@{msg_sender}>: {msg_text}"
# add the message (unless it's the initial message)
thread_text += "\n\nReplies:"
if thread_id == message_id:
message_id_idx = 0
else:
message_id_idx = next(
(i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0
)
if not message_id_idx:
return thread_text
# add the message
thread_text += "\n..." if message_id_idx > 1 else ""
msg_text = messages[message_id_idx].get("text", "")
msg_sender = messages[message_id_idx].get("user", "")
thread_text += f"\n<@{msg_sender}>: {msg_text}"
# add the following replies to the thread text
len_replies = 0
for msg in messages[message_id_idx + 1 :]:
msg_text = msg.get("text", "")
msg_sender = msg.get("user", "")
reply = f"\n\n<@{msg_sender}>: {msg_text}"
thread_text += reply
# stop if len_replies exceeds chunk_size * 4 chars as the rest likely won't fit
len_replies += len(reply)
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
thread_text += "\n..."
break
# replace user ids with names in the thread text
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
for userid in userids:
try:
response = slack_client.users_profile_get(user=userid)
response.validate()
profile: dict[str, Any] = response.get("profile", {})
name: str | None = profile.get("real_name") or profile.get("email")
except SlackApiError as e:
logger.error(f"Slack API error in get_contextualized_thread_text: {e}")
continue
if not name:
continue
thread_text = thread_text.replace(f"<@{userid}>", name)
return thread_text
def convert_slack_score(slack_score: float) -> float:
"""
Convert slack score to a score between 0 and 1.
Will affect UI ordering and LLM ordering, but not the pruning.
I.e., should have very little effect on the search/answer quality.
"""
return max(0.0, min(1.0, slack_score / 90_000))
@log_function_time(print_only=True)
def slack_retrieval(
query: SearchQuery,
access_token: str,
db_session: Session,
limit: int | None = None,
) -> list[InferenceChunk]:
# query slack
_, fast_llm = get_default_llms()
query_strings = build_slack_queries(query, fast_llm)
results: list[list[SlackMessage]] = run_functions_tuples_in_parallel(
[
(query_slack, (query_string, query, access_token, limit))
for query_string in query_strings
]
)
slack_messages, docid_to_message = merge_slack_messages(results)
slack_messages = slack_messages[: limit or len(slack_messages)]
if not slack_messages:
return []
# contextualize the slack messages
thread_texts: list[str] = run_functions_tuples_in_parallel(
[
(get_contextualized_thread_text, (slack_message, access_token))
for slack_message in slack_messages
]
)
for slack_message, thread_text in zip(slack_messages, thread_texts):
slack_message.text = thread_text
# get the highlighted texts from shortest to longest
highlighted_texts: set[str] = set()
for slack_message in slack_messages:
highlighted_texts.update(slack_message.highlighted_texts)
sorted_highlighted_texts = sorted(highlighted_texts, key=len)
# convert slack messages to index documents
index_docs: list[IndexingDocument] = []
for slack_message in slack_messages:
section: TextSection = TextSection(
text=slack_message.text, link=slack_message.link
)
index_docs.append(
IndexingDocument(
id=slack_message.document_id,
sections=[section],
processed_sections=[section],
source=DocumentSource.SLACK,
title=slack_message.semantic_identifier,
semantic_identifier=slack_message.semantic_identifier,
metadata=slack_message.metadata,
doc_updated_at=slack_message.timestamp,
)
)
# chunk index docs into doc aware chunks
# a single index doc can get split into multiple chunks
search_settings = get_current_search_settings(db_session)
embedder = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
)
multipass_config = get_multipass_config(search_settings)
enable_contextual_rag = (
search_settings.enable_contextual_rag or ENABLE_CONTEXTUAL_RAG
)
chunker = Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass_config.multipass_indexing,
enable_large_chunks=multipass_config.enable_large_chunks,
enable_contextual_rag=enable_contextual_rag,
)
chunks = chunker.chunk(index_docs)
# prune chunks without any highlighted texts
relevant_chunks: list[DocAwareChunk] = []
chunkid_to_match_highlight: dict[str, str] = {}
for chunk in chunks:
match_highlight = chunk.content
for highlight in sorted_highlighted_texts: # faster than re sub
match_highlight = match_highlight.replace(
highlight, f"<hi>{highlight}</hi>"
)
# if nothing got replaced, the chunk is irrelevant
if len(match_highlight) == len(chunk.content):
continue
chunk_id = f"{chunk.source_document.id}__{chunk.chunk_id}"
relevant_chunks.append(chunk)
chunkid_to_match_highlight[chunk_id] = match_highlight
if limit and len(relevant_chunks) >= limit:
break
# convert to inference chunks
top_chunks: list[InferenceChunk] = []
for chunk in relevant_chunks:
document_id = chunk.source_document.id
chunk_id = f"{document_id}__{chunk.chunk_id}"
top_chunks.append(
InferenceChunk(
chunk_id=chunk.chunk_id,
blurb=chunk.blurb,
content=chunk.content,
source_links=chunk.source_links,
image_file_id=chunk.image_file_id,
section_continuation=chunk.section_continuation,
semantic_identifier=docid_to_message[document_id].semantic_identifier,
document_id=document_id,
source_type=DocumentSource.SLACK,
title=chunk.title_prefix,
boost=0,
recency_bias=docid_to_message[document_id].recency_bias,
score=convert_slack_score(docid_to_message[document_id].slack_score),
hidden=False,
is_relevant=None,
relevance_explanation="",
metadata=docid_to_message[document_id].metadata,
match_highlights=[chunkid_to_match_highlight[chunk_id]],
doc_summary="",
chunk_context="",
updated_at=docid_to_message[document_id].timestamp,
is_federated=True,
)
)
return top_chunks

View File

@@ -154,7 +154,6 @@ class SearchRequest(ChunkContext):
query: str
expanded_queries: QueryExpansions | None = None
original_query: str | None = None
search_type: SearchType = SearchType.SEMANTIC
@@ -206,7 +205,6 @@ class SearchQuery(ChunkContext):
precomputed_query_embedding: Embedding | None = None
expanded_queries: QueryExpansions | None = None
original_query: str | None
class RetrievalDetails(ChunkContext):
@@ -254,8 +252,6 @@ class InferenceChunk(BaseChunk):
secondary_owners: list[str] | None = None
large_chunk_reference_ids: list[int] = Field(default_factory=list)
is_federated: bool = False
@property
def unique_id(self) -> str:
return f"{self.document_id}__{self.chunk_id}"

View File

@@ -158,7 +158,6 @@ class SearchPipeline:
# These chunks do not include large chunks and have been deduped
self._retrieved_chunks = retrieve_chunks(
query=self.search_query,
user_id=self.user.id if self.user else None,
document_index=self.document_index,
db_session=self.db_session,
retrieval_metrics_callback=self.retrieval_metrics_callback,

View File

@@ -25,6 +25,7 @@ from onyx.context.search.models import MAX_METRICS_CONTENT
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RerankMetricsContainer
from onyx.context.search.models import SearchQuery
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
@@ -69,16 +70,18 @@ def update_image_sections_with_query(
logger.debug(
f"Processing image chunk with ID: {chunk.unique_id}, image: {chunk.image_file_id}"
)
file_record = get_default_file_store().read_file(
cast(str, chunk.image_file_id), mode="b"
)
if not file_record:
logger.error(f"Image file not found: {chunk.image_file_id}")
raise Exception("File not found")
file_content = file_record.read()
image_base64 = base64.b64encode(file_content).decode()
logger.debug(f"Successfully loaded image data for {chunk.image_file_id}")
with get_session_with_current_tenant() as db_session:
file_record = get_default_file_store(db_session).read_file(
cast(str, chunk.image_file_id), mode="b"
)
if not file_record:
logger.error(f"Image file not found: {chunk.image_file_id}")
raise Exception("File not found")
file_content = file_record.read()
image_base64 = base64.b64encode(file_content).decode()
logger.debug(
f"Successfully loaded image data for {chunk.image_file_id}"
)
messages: list[BaseMessage] = [
SystemMessage(content=IMAGE_ANALYSIS_SYSTEM_PROMPT),

View File

@@ -250,7 +250,6 @@ def retrieval_preprocessing(
return SearchQuery(
query=query,
original_query=search_request.original_query,
processed_keywords=processed_keywords,
search_type=SearchType.KEYWORD if is_keyword else SearchType.SEMANTIC,
evaluation_type=llm_evaluation_type,

View File

@@ -1,6 +1,5 @@
import string
from collections.abc import Callable
from uuid import UUID
import nltk # type:ignore
from sqlalchemy.orm import Session
@@ -18,18 +17,15 @@ from onyx.context.search.models import SearchQuery
from onyx.context.search.postprocessing.postprocessing import cleanup_chunks
from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA
from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA_KEYWORD
from onyx.context.search.utils import get_query_embedding
from onyx.context.search.utils import get_query_embeddings
from onyx.context.search.utils import inference_section_from_chunks
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_multilingual_expansion
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from onyx.federated_connectors.federated_retrieval import (
get_federated_retrieval_functions,
)
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.secondary_llm_flows.query_expansion import multilingual_query_expansion
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -37,6 +33,9 @@ from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import TimeoutThread
from onyx.utils.threadpool_concurrency import wait_on_background
from onyx.utils.timing import log_function_time
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -116,6 +115,34 @@ def combine_retrieval_results(
return sorted_chunks
def get_query_embedding(query: str, db_session: Session) -> Embedding:
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode([query], text_type=EmbedTextType.QUERY)[0]
return query_embedding
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
return query_embedding
@log_function_time(print_only=True)
def doc_index_retrieval(
query: SearchQuery,
@@ -323,7 +350,6 @@ def _simplify_text(text: str) -> str:
def retrieve_chunks(
query: SearchQuery,
user_id: UUID | None,
document_index: DocumentIndex,
db_session: Session,
retrieval_metrics_callback: (
@@ -333,34 +359,14 @@ def retrieve_chunks(
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
multilingual_expansion = get_multilingual_expansion(db_session)
run_queries: list[tuple[Callable, tuple]] = []
source_filters = (
set(query.filters.source_type) if query.filters.source_type else None
)
# Federated retrieval
federated_retrieval_infos = get_federated_retrieval_functions(
db_session, user_id, query.filters.source_type, query.filters.document_set
)
federated_sources = set(
federated_retrieval_info.source.to_non_federated_source()
for federated_retrieval_info in federated_retrieval_infos
)
for federated_retrieval_info in federated_retrieval_infos:
run_queries.append((federated_retrieval_info.retrieval_function, (query,)))
# Normal retrieval
normal_search_enabled = (source_filters is None) or (
len(set(source_filters) - federated_sources) > 0
)
if normal_search_enabled and (
not multilingual_expansion or "\n" in query.query or "\r" in query.query
):
# Don't do query expansion on complex queries, rephrasings likely would not work well
run_queries.append((doc_index_retrieval, (query, document_index, db_session)))
elif normal_search_enabled:
# Don't do query expansion on complex queries, rephrasings likely would not work well
if not multilingual_expansion or "\n" in query.query or "\r" in query.query:
top_chunks = doc_index_retrieval(
query=query, document_index=document_index, db_session=db_session
)
else:
simplified_queries = set()
run_queries: list[tuple[Callable, tuple]] = []
# Currently only uses query expansion on multilingual use cases
query_rephrases = multilingual_query_expansion(
@@ -387,11 +393,13 @@ def retrieve_chunks(
deep=True,
)
run_queries.append(
(doc_index_retrieval, (q_copy, document_index, db_session))
(
doc_index_retrieval,
(q_copy, document_index, db_session),
)
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
top_chunks = combine_retrieval_results(parallel_search_results)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
top_chunks = combine_retrieval_results(parallel_search_results)
if not top_chunks:
logger.warning(

View File

@@ -4,7 +4,6 @@ from typing import TypeVar
from nltk.corpus import stopwords # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from sqlalchemy.orm import Session
from onyx.chat.models import SectionRelevancePiece
from onyx.context.search.models import InferenceChunk
@@ -13,13 +12,7 @@ from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SavedSearchDocWithContent
from onyx.context.search.models import SearchDoc
from onyx.db.models import SearchDoc as DBSearchDoc
from onyx.db.search_settings import get_current_search_settings
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -167,21 +160,3 @@ def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
except Exception as e:
logger.warning(f"Error removing stop words and punctuation: {e}")
return keywords
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
return query_embedding
def get_query_embedding(query: str, db_session: Session) -> Embedding:
return get_query_embeddings([query], db_session)[0]

View File

@@ -148,10 +148,7 @@ def regenerate_api_key(db_session: Session, api_key_id: int) -> ApiKeyDescriptor
if api_key_user is None:
raise RuntimeError("API Key does not have associated user.")
# Get tenant_id from context var (will be default schema for single tenant)
tenant_id = get_current_tenant_id()
new_api_key = generate_api_key(tenant_id)
new_api_key = generate_api_key()
existing_api_key.hashed_api_key = hash_api_key(new_api_key)
existing_api_key.api_key_display = build_displayable_api_key(new_api_key)
db_session.commit()

Some files were not shown because too many files have changed in this diff Show More