mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-07 16:02:45 +00:00
Compare commits
1 Commits
cli/v0.2.1
...
release/v.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cd96de146e |
197
.vscode/launch.template.jsonc
vendored
197
.vscode/launch.template.jsonc
vendored
@@ -23,12 +23,10 @@
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery background",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
"Celery user file processing"
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
@@ -42,16 +40,29 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
"name": "Celery (lightweight mode)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery background",
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
},
|
||||
"stopAll": true
|
||||
},
|
||||
{
|
||||
"name": "Celery (standard mode)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery kg_processing",
|
||||
"Celery monitoring",
|
||||
"Celery user_file_processing",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat",
|
||||
"Celery monitoring",
|
||||
"Celery user file processing"
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
@@ -199,6 +210,35 @@
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery background",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.background",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=20",
|
||||
"--prefetch-multiplier=4",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=background@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,user_files_indexing,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery background Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
@@ -221,13 +261,100 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery kg_processing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.kg_processing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=2",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=kg_processing@%n",
|
||||
"-Q",
|
||||
"kg_processing"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery kg_processing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery user_file_processing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.user_file_processing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=2",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_file_processing@%n",
|
||||
"-Q",
|
||||
"user_file_processing,user_file_project_sync"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user_file_processing Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery docfetching",
|
||||
"type": "debugpy",
|
||||
@@ -311,58 +438,6 @@
|
||||
},
|
||||
"consoleTitle": "Celery beat Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery monitoring",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=solo",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery monitoring Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery user file processing",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.user_file_processing",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_file_processing@%n",
|
||||
"--pool=threads",
|
||||
"-Q",
|
||||
"user_file_processing,user_file_project_sync"
|
||||
],
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user file processing Console"
|
||||
},
|
||||
{
|
||||
"name": "Pytest",
|
||||
"consoleName": "Pytest",
|
||||
|
||||
32
AGENTS.md
32
AGENTS.md
@@ -70,7 +70,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
- Single thread (monitoring doesn't need parallelism)
|
||||
- Cloud-specific monitoring tasks
|
||||
|
||||
8. **Beat Worker** (`beat`)
|
||||
8. **User File Processing Worker** (`user_file_processing`)
|
||||
- Processes user-uploaded files
|
||||
- Handles user file indexing and project synchronization
|
||||
- Configurable concurrency
|
||||
|
||||
9. **Beat Worker** (`beat`)
|
||||
- Celery's scheduler for periodic tasks
|
||||
- Uses DynamicTenantScheduler for multi-tenant support
|
||||
- Schedules tasks like:
|
||||
@@ -82,6 +87,31 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
- Monitoring tasks (every 5 minutes)
|
||||
- Cleanup tasks (hourly)
|
||||
|
||||
#### Worker Deployment Modes
|
||||
|
||||
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
|
||||
|
||||
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
|
||||
- Runs a single consolidated `background` worker that handles all background tasks:
|
||||
- Pruning operations (from `heavy` worker)
|
||||
- Knowledge graph processing (from `kg_processing` worker)
|
||||
- Monitoring tasks (from `monitoring` worker)
|
||||
- User file processing (from `user_file_processing` worker)
|
||||
- Lower resource footprint (single worker process)
|
||||
- Suitable for smaller deployments or development environments
|
||||
- Default concurrency: 6 threads
|
||||
|
||||
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
|
||||
- Runs separate specialized workers as documented above (heavy, kg_processing, monitoring, user_file_processing)
|
||||
- Better isolation and scalability
|
||||
- Can scale individual workers independently based on workload
|
||||
- Suitable for production deployments with higher load
|
||||
|
||||
The deployment mode affects:
|
||||
- **Backend**: Worker processes spawned by supervisord or dev scripts
|
||||
- **Helm**: Which Kubernetes deployments are created
|
||||
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
|
||||
|
||||
#### Key Features
|
||||
|
||||
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
|
||||
|
||||
39
CLAUDE.md
39
CLAUDE.md
@@ -70,7 +70,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
- Single thread (monitoring doesn't need parallelism)
|
||||
- Cloud-specific monitoring tasks
|
||||
|
||||
8. **Beat Worker** (`beat`)
|
||||
8. **User File Processing Worker** (`user_file_processing`)
|
||||
- Processes user-uploaded files
|
||||
- Handles user file indexing and project synchronization
|
||||
- Configurable concurrency
|
||||
|
||||
9. **Beat Worker** (`beat`)
|
||||
- Celery's scheduler for periodic tasks
|
||||
- Uses DynamicTenantScheduler for multi-tenant support
|
||||
- Schedules tasks like:
|
||||
@@ -82,11 +87,39 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
- Monitoring tasks (every 5 minutes)
|
||||
- Cleanup tasks (hourly)
|
||||
|
||||
#### Worker Deployment Modes
|
||||
|
||||
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
|
||||
|
||||
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
|
||||
- Runs a single consolidated `background` worker that handles all background tasks:
|
||||
- Light worker tasks (Vespa operations, permissions sync, deletion)
|
||||
- Document processing (indexing pipeline)
|
||||
- Document fetching (connector data retrieval)
|
||||
- Pruning operations (from `heavy` worker)
|
||||
- Knowledge graph processing (from `kg_processing` worker)
|
||||
- Monitoring tasks (from `monitoring` worker)
|
||||
- User file processing (from `user_file_processing` worker)
|
||||
- Lower resource footprint (fewer worker processes)
|
||||
- Suitable for smaller deployments or development environments
|
||||
- Default concurrency: 20 threads (increased to handle combined workload)
|
||||
|
||||
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
|
||||
- Runs separate specialized workers as documented above (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
|
||||
- Better isolation and scalability
|
||||
- Can scale individual workers independently based on workload
|
||||
- Suitable for production deployments with higher load
|
||||
|
||||
The deployment mode affects:
|
||||
- **Backend**: Worker processes spawned by supervisord or dev scripts
|
||||
- **Helm**: Which Kubernetes deployments are created
|
||||
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
|
||||
|
||||
#### Key Features
|
||||
|
||||
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
|
||||
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
|
||||
middleware layer that automatically finds the appropriate tenant ID when sending tasks
|
||||
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
|
||||
middleware layer that automatically finds the appropriate tenant ID when sending tasks
|
||||
via Celery Beat.
|
||||
- **Task Prioritization**: High, Medium, Low priority queues
|
||||
- **Monitoring**: Built-in heartbeat and liveness checking
|
||||
|
||||
@@ -111,6 +111,8 @@ COPY ./static /app/static
|
||||
# Escape hatch scripts
|
||||
COPY ./scripts/debugging /app/scripts/debugging
|
||||
COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
|
||||
COPY ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
|
||||
RUN chmod +x /app/scripts/supervisord_entrypoint.sh
|
||||
|
||||
# Put logo in assets
|
||||
COPY ./assets /app/assets
|
||||
|
||||
@@ -0,0 +1,153 @@
|
||||
"""add permission sync attempt tables
|
||||
|
||||
Revision ID: 03d710ccf29c
|
||||
Revises: 96a5702df6aa
|
||||
Create Date: 2025-09-11 13:30:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "03d710ccf29c" # Generate a new unique ID
|
||||
down_revision = "96a5702df6aa"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create the permission sync status enum
|
||||
permission_sync_status_enum = sa.Enum(
|
||||
"not_started",
|
||||
"in_progress",
|
||||
"success",
|
||||
"canceled",
|
||||
"failed",
|
||||
"completed_with_errors",
|
||||
name="permissionsyncstatus",
|
||||
native_enum=False,
|
||||
)
|
||||
|
||||
# Create doc_permission_sync_attempt table
|
||||
op.create_table(
|
||||
"doc_permission_sync_attempt",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
|
||||
sa.Column("status", permission_sync_status_enum, nullable=False),
|
||||
sa.Column("total_docs_synced", sa.Integer(), nullable=True),
|
||||
sa.Column("docs_with_permission_errors", sa.Integer(), nullable=True),
|
||||
sa.Column("error_message", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("time_started", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("time_finished", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["connector_credential_pair_id"],
|
||||
["connector_credential_pair.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Create indexes for doc_permission_sync_attempt
|
||||
op.create_index(
|
||||
"ix_doc_permission_sync_attempt_time_created",
|
||||
"doc_permission_sync_attempt",
|
||||
["time_created"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_permission_sync_attempt_latest_for_cc_pair",
|
||||
"doc_permission_sync_attempt",
|
||||
["connector_credential_pair_id", "time_created"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_permission_sync_attempt_status_time",
|
||||
"doc_permission_sync_attempt",
|
||||
["status", sa.text("time_finished DESC")],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# Create external_group_permission_sync_attempt table
|
||||
# connector_credential_pair_id is nullable - group syncs can be global (e.g., Confluence)
|
||||
op.create_table(
|
||||
"external_group_permission_sync_attempt",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=True),
|
||||
sa.Column("status", permission_sync_status_enum, nullable=False),
|
||||
sa.Column("total_users_processed", sa.Integer(), nullable=True),
|
||||
sa.Column("total_groups_processed", sa.Integer(), nullable=True),
|
||||
sa.Column("total_group_memberships_synced", sa.Integer(), nullable=True),
|
||||
sa.Column("error_message", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("time_started", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("time_finished", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["connector_credential_pair_id"],
|
||||
["connector_credential_pair.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# Create indexes for external_group_permission_sync_attempt
|
||||
op.create_index(
|
||||
"ix_external_group_permission_sync_attempt_time_created",
|
||||
"external_group_permission_sync_attempt",
|
||||
["time_created"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_group_sync_attempt_cc_pair_time",
|
||||
"external_group_permission_sync_attempt",
|
||||
["connector_credential_pair_id", "time_created"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_group_sync_attempt_status_time",
|
||||
"external_group_permission_sync_attempt",
|
||||
["status", sa.text("time_finished DESC")],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes
|
||||
op.drop_index(
|
||||
"ix_group_sync_attempt_status_time",
|
||||
table_name="external_group_permission_sync_attempt",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_group_sync_attempt_cc_pair_time",
|
||||
table_name="external_group_permission_sync_attempt",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_external_group_permission_sync_attempt_time_created",
|
||||
table_name="external_group_permission_sync_attempt",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_permission_sync_attempt_status_time",
|
||||
table_name="doc_permission_sync_attempt",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_permission_sync_attempt_latest_for_cc_pair",
|
||||
table_name="doc_permission_sync_attempt",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_doc_permission_sync_attempt_time_created",
|
||||
table_name="doc_permission_sync_attempt",
|
||||
)
|
||||
|
||||
# Drop tables
|
||||
op.drop_table("external_group_permission_sync_attempt")
|
||||
op.drop_table("doc_permission_sync_attempt")
|
||||
@@ -0,0 +1,37 @@
|
||||
"""add queries and is web fetch to iteration answer
|
||||
|
||||
Revision ID: 6f4f86aef280
|
||||
Revises: 96a5702df6aa
|
||||
Create Date: 2025-10-14 18:08:30.920123
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6f4f86aef280"
|
||||
down_revision = "03d710ccf29c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add is_web_fetch column
|
||||
op.add_column(
|
||||
"research_agent_iteration_sub_step",
|
||||
sa.Column("is_web_fetch", sa.Boolean(), nullable=True),
|
||||
)
|
||||
|
||||
# Add queries column
|
||||
op.add_column(
|
||||
"research_agent_iteration_sub_step",
|
||||
sa.Column("queries", postgresql.JSONB(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("research_agent_iteration_sub_step", "queries")
|
||||
op.drop_column("research_agent_iteration_sub_step", "is_web_fetch")
|
||||
45
backend/alembic/versions/96a5702df6aa_mcp_tool_enabled.py
Normal file
45
backend/alembic/versions/96a5702df6aa_mcp_tool_enabled.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""mcp_tool_enabled
|
||||
|
||||
Revision ID: 96a5702df6aa
|
||||
Revises: 40926a4dab77
|
||||
Create Date: 2025-10-09 12:10:21.733097
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "96a5702df6aa"
|
||||
down_revision = "40926a4dab77"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
DELETE_DISABLED_TOOLS_SQL = "DELETE FROM tool WHERE enabled = false"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column(
|
||||
"enabled",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.true(),
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_tool_mcp_server_enabled",
|
||||
"tool",
|
||||
["mcp_server_id", "enabled"],
|
||||
)
|
||||
# Remove the server default so application controls defaulting
|
||||
op.alter_column("tool", "enabled", server_default=None)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(DELETE_DISABLED_TOOLS_SQL)
|
||||
op.drop_index("ix_tool_mcp_server_enabled", table_name="tool")
|
||||
op.drop_column("tool", "enabled")
|
||||
@@ -1,8 +1,13 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
@@ -10,6 +15,7 @@ from fastapi import status
|
||||
from jwt import decode as jwt_decode
|
||||
from jwt import InvalidTokenError
|
||||
from jwt import PyJWTError
|
||||
from jwt.algorithms import RSAAlgorithm
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@@ -30,43 +36,156 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_PUBLIC_KEY_FETCH_ATTEMPTS = 2
|
||||
|
||||
|
||||
class PublicKeyFormat(Enum):
|
||||
JWKS = "jwks"
|
||||
PEM = "pem"
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_public_key() -> str | None:
|
||||
def _fetch_public_key_payload() -> tuple[str | dict[str, Any], PublicKeyFormat] | None:
|
||||
"""Fetch and cache the raw JWT verification material."""
|
||||
if JWT_PUBLIC_KEY_URL is None:
|
||||
logger.error("JWT_PUBLIC_KEY_URL is not set")
|
||||
return None
|
||||
|
||||
response = requests.get(JWT_PUBLIC_KEY_URL)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
try:
|
||||
response = requests.get(JWT_PUBLIC_KEY_URL)
|
||||
response.raise_for_status()
|
||||
except requests.RequestException as exc:
|
||||
logger.error(f"Failed to fetch JWT public key: {str(exc)}")
|
||||
return None
|
||||
content_type = response.headers.get("Content-Type", "").lower()
|
||||
raw_body = response.text
|
||||
body_lstripped = raw_body.lstrip()
|
||||
|
||||
if "application/json" in content_type or body_lstripped.startswith("{"):
|
||||
try:
|
||||
data = response.json()
|
||||
except ValueError:
|
||||
logger.error("JWT public key URL returned invalid JSON")
|
||||
return None
|
||||
|
||||
if isinstance(data, dict) and "keys" in data:
|
||||
return data, PublicKeyFormat.JWKS
|
||||
|
||||
logger.error(
|
||||
"JWT public key URL returned JSON but no JWKS 'keys' field was found"
|
||||
)
|
||||
return None
|
||||
|
||||
body = raw_body.strip()
|
||||
if not body:
|
||||
logger.error("JWT public key URL returned an empty response")
|
||||
return None
|
||||
|
||||
return body, PublicKeyFormat.PEM
|
||||
|
||||
|
||||
def get_public_key(token: str) -> RSAPublicKey | str | None:
|
||||
"""Return the concrete public key used to verify the provided JWT token."""
|
||||
payload = _fetch_public_key_payload()
|
||||
if payload is None:
|
||||
logger.error("Failed to retrieve public key payload")
|
||||
return None
|
||||
|
||||
key_material, key_format = payload
|
||||
|
||||
if key_format is PublicKeyFormat.JWKS:
|
||||
jwks_data = cast(dict[str, Any], key_material)
|
||||
return _resolve_public_key_from_jwks(token, jwks_data)
|
||||
|
||||
return cast(str, key_material)
|
||||
|
||||
|
||||
def _resolve_public_key_from_jwks(
|
||||
token: str, jwks_payload: dict[str, Any]
|
||||
) -> RSAPublicKey | None:
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
except PyJWTError as e:
|
||||
logger.error(f"Unable to parse JWT header: {str(e)}")
|
||||
return None
|
||||
|
||||
keys = jwks_payload.get("keys", []) if isinstance(jwks_payload, dict) else []
|
||||
if not keys:
|
||||
logger.error("JWKS payload did not contain any keys")
|
||||
return None
|
||||
|
||||
kid = header.get("kid")
|
||||
thumbprint = header.get("x5t")
|
||||
|
||||
candidates = []
|
||||
if kid:
|
||||
candidates = [k for k in keys if k.get("kid") == kid]
|
||||
if not candidates and thumbprint:
|
||||
candidates = [k for k in keys if k.get("x5t") == thumbprint]
|
||||
if not candidates and len(keys) == 1:
|
||||
candidates = keys
|
||||
|
||||
if not candidates:
|
||||
logger.warning(
|
||||
"No matching JWK found for token header (kid=%s, x5t=%s)", kid, thumbprint
|
||||
)
|
||||
return None
|
||||
|
||||
if len(candidates) > 1:
|
||||
logger.warning(
|
||||
"Multiple JWKs matched token header kid=%s; selecting the first occurrence",
|
||||
kid,
|
||||
)
|
||||
|
||||
jwk = candidates[0]
|
||||
try:
|
||||
return cast(RSAPublicKey, RSAAlgorithm.from_jwk(json.dumps(jwk)))
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to construct RSA key from JWK: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
|
||||
try:
|
||||
public_key_pem = get_public_key()
|
||||
if public_key_pem is None:
|
||||
logger.error("Failed to retrieve public key")
|
||||
for attempt in range(_PUBLIC_KEY_FETCH_ATTEMPTS):
|
||||
public_key = get_public_key(token)
|
||||
if public_key is None:
|
||||
logger.error("Unable to resolve a public key for JWT verification")
|
||||
if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
|
||||
_fetch_public_key_payload.cache_clear()
|
||||
continue
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = jwt_decode(
|
||||
token,
|
||||
public_key,
|
||||
algorithms=["RS256"],
|
||||
options={"verify_aud": False},
|
||||
)
|
||||
except InvalidTokenError as e:
|
||||
logger.error(f"Invalid JWT token: {str(e)}")
|
||||
if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
|
||||
_fetch_public_key_payload.cache_clear()
|
||||
continue
|
||||
return None
|
||||
except PyJWTError as e:
|
||||
logger.error(f"JWT decoding error: {str(e)}")
|
||||
if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
|
||||
_fetch_public_key_payload.cache_clear()
|
||||
continue
|
||||
return None
|
||||
|
||||
payload = jwt_decode(
|
||||
token,
|
||||
public_key_pem,
|
||||
algorithms=["RS256"],
|
||||
audience=None,
|
||||
)
|
||||
email = payload.get("email")
|
||||
if email:
|
||||
result = await async_db_session.execute(
|
||||
select(User).where(func.lower(User.email) == func.lower(email))
|
||||
)
|
||||
return result.scalars().first()
|
||||
except InvalidTokenError:
|
||||
logger.error("Invalid JWT token")
|
||||
get_public_key.cache_clear()
|
||||
except PyJWTError as e:
|
||||
logger.error(f"JWT decoding error: {str(e)}")
|
||||
get_public_key.cache_clear()
|
||||
logger.warning(
|
||||
"JWT token decoded successfully but no email claim found; skipping auth"
|
||||
)
|
||||
break
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
||||
12
backend/ee/onyx/background/celery/apps/background.py
Normal file
12
backend/ee/onyx/background/celery/apps/background.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from onyx.background.celery.apps.background import celery_app
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
@@ -1,123 +1,4 @@
|
||||
import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
|
||||
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
|
||||
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
|
||||
from onyx.background.celery.apps.heavy import celery_app
|
||||
from onyx.background.task_utils import construct_query_history_report_name
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import FileType
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tasks import delete_task_with_id
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import mark_task_as_started_with_id
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
trail=False,
|
||||
)
|
||||
def export_query_history_task(
|
||||
self: Task,
|
||||
*,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
start_time: datetime,
|
||||
# Need to include the tenant_id since the TenantAwareTask needs this
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
if not self.request.id:
|
||||
raise RuntimeError("No task id defined for this task; cannot identify it")
|
||||
|
||||
task_id = self.request.id
|
||||
stream = io.StringIO()
|
||||
writer = csv.DictWriter(
|
||||
stream,
|
||||
fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
|
||||
)
|
||||
writer.writeheader()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
mark_task_as_started_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
snapshot_generator = fetch_and_process_chat_session_history(
|
||||
db_session=db_session,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
for snapshot in snapshot_generator:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
writer.writerows(
|
||||
qa_pair.to_json()
|
||||
for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
|
||||
snapshot
|
||||
)
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Failed to export query history with {task_id=}")
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=False,
|
||||
)
|
||||
raise
|
||||
|
||||
report_name = construct_query_history_report_name(task_id)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
stream.seek(0)
|
||||
get_default_file_store().save_file(
|
||||
content=stream,
|
||||
display_name=report_name,
|
||||
file_origin=FileOrigin.QUERY_HISTORY_CSV,
|
||||
file_type=FileType.CSV,
|
||||
file_metadata={
|
||||
"start": start.isoformat(),
|
||||
"end": end.isoformat(),
|
||||
"start_time": start_time.isoformat(),
|
||||
},
|
||||
file_id=report_name,
|
||||
)
|
||||
|
||||
delete_task_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to save query history export file; {report_name=}"
|
||||
)
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=False,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
@@ -125,5 +6,6 @@ celery_app.autodiscover_tasks(
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -56,6 +56,12 @@ from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.permission_sync_attempt import complete_doc_permission_sync_attempt
|
||||
from onyx.db.permission_sync_attempt import create_doc_permission_sync_attempt
|
||||
from onyx.db.permission_sync_attempt import mark_doc_permission_sync_attempt_failed
|
||||
from onyx.db.permission_sync_attempt import (
|
||||
mark_doc_permission_sync_attempt_in_progress,
|
||||
)
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
@@ -113,6 +119,14 @@ def _get_fence_validation_block_expiration() -> int:
|
||||
"""Jobs / utils for kicking off doc permissions sync tasks."""
|
||||
|
||||
|
||||
def _fail_doc_permission_sync_attempt(attempt_id: int, error_msg: str) -> None:
|
||||
"""Helper to mark a doc permission sync attempt as failed with an error message."""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_doc_permission_sync_attempt_failed(
|
||||
attempt_id, db_session, error_message=error_msg
|
||||
)
|
||||
|
||||
|
||||
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if external doc permissions sync is due."""
|
||||
|
||||
@@ -379,6 +393,15 @@ def connector_permission_sync_generator_task(
|
||||
doc_permission_sync_ctx_dict["request_id"] = self.request.id
|
||||
doc_permission_sync_ctx.set(doc_permission_sync_ctx_dict)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
attempt_id = create_doc_permission_sync_attempt(
|
||||
connector_credential_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
task_logger.info(
|
||||
f"Created doc permission sync attempt: {attempt_id} for cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client()
|
||||
@@ -389,22 +412,28 @@ def connector_permission_sync_generator_task(
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise ValueError(
|
||||
error_msg = (
|
||||
f"connector_permission_sync_generator_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if not redis_connector.permissions.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
error_msg = (
|
||||
f"connector_permission_sync_generator_task - fence not found: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
payload = redis_connector.permissions.payload # The payload must exist
|
||||
if not payload:
|
||||
raise ValueError(
|
||||
error_msg = (
|
||||
"connector_permission_sync_generator_task: payload invalid or not found"
|
||||
)
|
||||
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
logger.info(
|
||||
@@ -432,9 +461,11 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
error_msg = (
|
||||
f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
)
|
||||
task_logger.warning(error_msg)
|
||||
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -470,11 +501,15 @@ def connector_permission_sync_generator_task(
|
||||
source_type = cc_pair.connector.source
|
||||
sync_config = get_source_perm_sync_config(source_type)
|
||||
if sync_config is None:
|
||||
logger.error(f"No sync config found for {source_type}")
|
||||
error_msg = f"No sync config found for {source_type}"
|
||||
logger.error(error_msg)
|
||||
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
|
||||
return None
|
||||
|
||||
if sync_config.doc_sync_config is None:
|
||||
if sync_config.censoring_config:
|
||||
error_msg = f"Doc sync config is None but censoring config exists for {source_type}"
|
||||
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
|
||||
return None
|
||||
|
||||
raise ValueError(
|
||||
@@ -483,6 +518,8 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
|
||||
|
||||
mark_doc_permission_sync_attempt_in_progress(attempt_id, db_session)
|
||||
|
||||
payload = redis_connector.permissions.payload
|
||||
if not payload:
|
||||
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
|
||||
@@ -533,8 +570,9 @@ def connector_permission_sync_generator_task(
|
||||
)
|
||||
|
||||
tasks_generated = 0
|
||||
docs_with_errors = 0
|
||||
for doc_external_access in document_external_accesses:
|
||||
redis_connector.permissions.update_db(
|
||||
result = redis_connector.permissions.update_db(
|
||||
lock=lock,
|
||||
new_permissions=[doc_external_access],
|
||||
source_string=source_type,
|
||||
@@ -542,11 +580,23 @@ def connector_permission_sync_generator_task(
|
||||
credential_id=cc_pair.credential.id,
|
||||
task_logger=task_logger,
|
||||
)
|
||||
tasks_generated += 1
|
||||
tasks_generated += result.num_updated
|
||||
docs_with_errors += result.num_errors
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.permissions.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated} docs_with_errors={docs_with_errors}"
|
||||
)
|
||||
|
||||
complete_doc_permission_sync_attempt(
|
||||
db_session=db_session,
|
||||
attempt_id=attempt_id,
|
||||
total_docs_synced=tasks_generated,
|
||||
docs_with_permission_errors=docs_with_errors,
|
||||
)
|
||||
task_logger.info(
|
||||
f"Completed doc permission sync attempt {attempt_id}: "
|
||||
f"{tasks_generated} docs, {docs_with_errors} errors"
|
||||
)
|
||||
|
||||
redis_connector.permissions.generator_complete = tasks_generated
|
||||
@@ -561,6 +611,11 @@ def connector_permission_sync_generator_task(
|
||||
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_doc_permission_sync_attempt_failed(
|
||||
attempt_id, db_session, error_message=error_msg
|
||||
)
|
||||
|
||||
redis_connector.permissions.generator_clear()
|
||||
redis_connector.permissions.taskset_clear()
|
||||
redis_connector.permissions.set_fence(None)
|
||||
|
||||
@@ -49,6 +49,16 @@ from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.permission_sync_attempt import complete_external_group_sync_attempt
|
||||
from onyx.db.permission_sync_attempt import (
|
||||
create_external_group_sync_attempt,
|
||||
)
|
||||
from onyx.db.permission_sync_attempt import (
|
||||
mark_external_group_sync_attempt_failed,
|
||||
)
|
||||
from onyx.db.permission_sync_attempt import (
|
||||
mark_external_group_sync_attempt_in_progress,
|
||||
)
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
@@ -70,6 +80,14 @@ logger = setup_logger()
|
||||
_EXTERNAL_GROUP_BATCH_SIZE = 100
|
||||
|
||||
|
||||
def _fail_external_group_sync_attempt(attempt_id: int, error_msg: str) -> None:
|
||||
"""Helper to mark an external group sync attempt as failed with an error message."""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_external_group_sync_attempt_failed(
|
||||
attempt_id, db_session, error_message=error_msg
|
||||
)
|
||||
|
||||
|
||||
def _get_fence_validation_block_expiration() -> int:
|
||||
"""
|
||||
Compute the expiration time for the fence validation block signal.
|
||||
@@ -449,6 +467,16 @@ def _perform_external_group_sync(
|
||||
cc_pair_id: int,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
# Create attempt record at the start
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
attempt_id = create_external_group_sync_attempt(
|
||||
connector_credential_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
logger.info(
|
||||
f"Created external group sync attempt: {attempt_id} for cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
@@ -463,11 +491,13 @@ def _perform_external_group_sync(
|
||||
if sync_config is None:
|
||||
msg = f"No sync config found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
_fail_external_group_sync_attempt(attempt_id, msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
if sync_config.group_sync_config is None:
|
||||
msg = f"No group sync config found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
_fail_external_group_sync_attempt(attempt_id, msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
ext_group_sync_func = sync_config.group_sync_config.group_sync_func
|
||||
@@ -477,14 +507,27 @@ def _perform_external_group_sync(
|
||||
)
|
||||
mark_old_external_groups_as_stale(db_session, cc_pair_id)
|
||||
|
||||
# Mark attempt as in progress
|
||||
mark_external_group_sync_attempt_in_progress(attempt_id, db_session)
|
||||
logger.info(f"Marked external group sync attempt {attempt_id} as in progress")
|
||||
|
||||
logger.info(
|
||||
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
external_user_group_batch: list[ExternalUserGroup] = []
|
||||
seen_users: set[str] = set() # Track unique users across all groups
|
||||
total_groups_processed = 0
|
||||
total_group_memberships_synced = 0
|
||||
try:
|
||||
external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
|
||||
for external_user_group in external_user_group_generator:
|
||||
external_user_group_batch.append(external_user_group)
|
||||
|
||||
# Track progress
|
||||
total_groups_processed += 1
|
||||
total_group_memberships_synced += len(external_user_group.user_emails)
|
||||
seen_users = seen_users.union(external_user_group.user_emails)
|
||||
|
||||
if len(external_user_group_batch) >= _EXTERNAL_GROUP_BATCH_SIZE:
|
||||
logger.debug(
|
||||
f"New external user groups: {external_user_group_batch}"
|
||||
@@ -506,6 +549,13 @@ def _perform_external_group_sync(
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = format_error_for_logging(e)
|
||||
|
||||
# Mark as failed (this also updates progress to show partial progress)
|
||||
mark_external_group_sync_attempt_failed(
|
||||
attempt_id, db_session, error_message=error_msg
|
||||
)
|
||||
|
||||
# TODO: add some notification to the admins here
|
||||
logger.exception(
|
||||
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
@@ -517,6 +567,24 @@ def _perform_external_group_sync(
|
||||
)
|
||||
remove_stale_external_groups(db_session, cc_pair_id)
|
||||
|
||||
# Calculate total unique users processed
|
||||
total_users_processed = len(seen_users)
|
||||
|
||||
# Complete the sync attempt with final progress
|
||||
complete_external_group_sync_attempt(
|
||||
db_session=db_session,
|
||||
attempt_id=attempt_id,
|
||||
total_users_processed=total_users_processed,
|
||||
total_groups_processed=total_groups_processed,
|
||||
total_group_memberships_synced=total_group_memberships_synced,
|
||||
errors_encountered=0,
|
||||
)
|
||||
logger.info(
|
||||
f"Completed external group sync attempt {attempt_id}: "
|
||||
f"{total_groups_processed} groups, {total_users_processed} users, "
|
||||
f"{total_group_memberships_synced} memberships"
|
||||
)
|
||||
|
||||
mark_all_relevant_cc_pairs_as_external_group_synced(db_session, cc_pair)
|
||||
|
||||
|
||||
|
||||
119
backend/ee/onyx/background/celery/tasks/query_history/tasks.py
Normal file
119
backend/ee/onyx/background/celery/tasks/query_history/tasks.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
|
||||
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
|
||||
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
|
||||
from onyx.background.task_utils import construct_query_history_report_name
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import FileType
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tasks import delete_task_with_id
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import mark_task_as_started_with_id
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
trail=False,
|
||||
)
|
||||
def export_query_history_task(
|
||||
self: Task,
|
||||
*,
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
start_time: datetime,
|
||||
# Need to include the tenant_id since the TenantAwareTask needs this
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
if not self.request.id:
|
||||
raise RuntimeError("No task id defined for this task; cannot identify it")
|
||||
|
||||
task_id = self.request.id
|
||||
stream = io.StringIO()
|
||||
writer = csv.DictWriter(
|
||||
stream,
|
||||
fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
|
||||
)
|
||||
writer.writeheader()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
mark_task_as_started_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
snapshot_generator = fetch_and_process_chat_session_history(
|
||||
db_session=db_session,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
for snapshot in snapshot_generator:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
writer.writerows(
|
||||
qa_pair.to_json()
|
||||
for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
|
||||
snapshot
|
||||
)
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Failed to export query history with {task_id=}")
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=False,
|
||||
)
|
||||
raise
|
||||
|
||||
report_name = construct_query_history_report_name(task_id)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
stream.seek(0)
|
||||
get_default_file_store().save_file(
|
||||
content=stream,
|
||||
display_name=report_name,
|
||||
file_origin=FileOrigin.QUERY_HISTORY_CSV,
|
||||
file_type=FileType.CSV,
|
||||
file_metadata={
|
||||
"start": start.isoformat(),
|
||||
"end": end.isoformat(),
|
||||
"start_time": start_time.isoformat(),
|
||||
},
|
||||
file_id=report_name,
|
||||
)
|
||||
|
||||
delete_task_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to save query history export file; {report_name=}"
|
||||
)
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=False,
|
||||
)
|
||||
raise
|
||||
@@ -73,6 +73,12 @@ def fetch_per_user_query_analytics(
|
||||
ChatSession.user_id,
|
||||
)
|
||||
.join(ChatSession, ChatSession.id == ChatMessage.chat_session_id)
|
||||
# Include chats that have no explicit feedback instead of dropping them
|
||||
.join(
|
||||
ChatMessageFeedback,
|
||||
ChatMessageFeedback.chat_message_id == ChatMessage.id,
|
||||
isouter=True,
|
||||
)
|
||||
.where(
|
||||
ChatMessage.time_sent >= start,
|
||||
)
|
||||
|
||||
0
backend/ee/onyx/feature_flags/__init__.py
Normal file
0
backend/ee/onyx/feature_flags/__init__.py
Normal file
15
backend/ee/onyx/feature_flags/factory.py
Normal file
15
backend/ee/onyx/feature_flags/factory.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from ee.onyx.feature_flags.posthog_provider import PostHogFeatureFlagProvider
|
||||
from onyx.feature_flags.interface import FeatureFlagProvider
|
||||
|
||||
|
||||
def get_posthog_feature_flag_provider() -> FeatureFlagProvider:
|
||||
"""
|
||||
Get the PostHog feature flag provider instance.
|
||||
|
||||
This is the EE implementation that gets loaded by the versioned
|
||||
implementation loader.
|
||||
|
||||
Returns:
|
||||
PostHogFeatureFlagProvider: The PostHog-based feature flag provider
|
||||
"""
|
||||
return PostHogFeatureFlagProvider()
|
||||
54
backend/ee/onyx/feature_flags/posthog_provider.py
Normal file
54
backend/ee/onyx/feature_flags/posthog_provider.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from ee.onyx.utils.posthog_client import posthog
|
||||
from onyx.feature_flags.interface import FeatureFlagProvider
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class PostHogFeatureFlagProvider(FeatureFlagProvider):
|
||||
"""
|
||||
PostHog-based feature flag provider.
|
||||
|
||||
Uses PostHog's feature flag API to determine if features are enabled
|
||||
for specific users. Only active in multi-tenant mode.
|
||||
"""
|
||||
|
||||
def feature_enabled(
|
||||
self,
|
||||
flag_key: str,
|
||||
user_id: UUID,
|
||||
user_properties: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a feature flag is enabled for a user via PostHog.
|
||||
|
||||
Args:
|
||||
flag_key: The identifier for the feature flag to check
|
||||
user_id: The unique identifier for the user
|
||||
user_properties: Optional dictionary of user properties/attributes
|
||||
that may influence flag evaluation
|
||||
|
||||
Returns:
|
||||
True if the feature is enabled for the user, False otherwise.
|
||||
"""
|
||||
try:
|
||||
posthog.set(
|
||||
distinct_id=user_id,
|
||||
properties=user_properties,
|
||||
)
|
||||
is_enabled = posthog.feature_enabled(
|
||||
flag_key,
|
||||
str(user_id),
|
||||
person_properties=user_properties,
|
||||
)
|
||||
|
||||
return bool(is_enabled) if is_enabled is not None else False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error checking feature flag {flag_key} for user {user_id}: {e}"
|
||||
)
|
||||
return False
|
||||
22
backend/ee/onyx/utils/posthog_client.py
Normal file
22
backend/ee/onyx/utils/posthog_client.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Any
|
||||
|
||||
from posthog import Posthog
|
||||
|
||||
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_HOST
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def posthog_on_error(error: Any, items: Any) -> None:
|
||||
"""Log any PostHog delivery errors."""
|
||||
logger.error(f"PostHog error: {error}, items: {items}")
|
||||
|
||||
|
||||
posthog = Posthog(
|
||||
project_api_key=POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=True,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
@@ -1,27 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
from posthog import Posthog
|
||||
|
||||
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_HOST
|
||||
from ee.onyx.utils.posthog_client import posthog
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def posthog_on_error(error: Any, items: Any) -> None:
|
||||
"""Log any PostHog delivery errors."""
|
||||
logger.error(f"PostHog error: {error}, items: {items}")
|
||||
|
||||
|
||||
posthog = Posthog(
|
||||
project_api_key=POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=True,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
|
||||
|
||||
def event_telemetry(
|
||||
distinct_id: str, event: str, properties: dict | None = None
|
||||
) -> None:
|
||||
|
||||
@@ -100,9 +100,14 @@ class IterationAnswer(BaseModel):
|
||||
response_type: str | None = None
|
||||
data: dict | list | str | int | float | bool | None = None
|
||||
file_ids: list[str] | None = None
|
||||
|
||||
# TODO: This is not ideal, but we'll can rework the schema
|
||||
# for deep research later
|
||||
is_web_fetch: bool = False
|
||||
# for image generation step-types
|
||||
generated_images: list[GeneratedImage] | None = None
|
||||
# for multi-query search tools (v2 web search and internal search)
|
||||
# TODO: Clean this up to be more flexible to tools
|
||||
queries: list[str] | None = None
|
||||
|
||||
|
||||
class AggregatedDRContext(BaseModel):
|
||||
|
||||
@@ -74,6 +74,7 @@ from onyx.prompts.dr_prompts import DEFAULT_DR_SYSTEM_PROMPT
|
||||
from onyx.prompts.dr_prompts import REPEAT_PROMPT
|
||||
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
|
||||
from onyx.prompts.prompt_template import PromptTemplate
|
||||
from onyx.prompts.prompt_utils import handle_company_awareness
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
@@ -120,7 +121,9 @@ def _get_available_tools(
|
||||
else:
|
||||
include_kg = False
|
||||
|
||||
tool_dict: dict[int, Tool] = {tool.id: tool for tool in get_tools(db_session)}
|
||||
tool_dict: dict[int, Tool] = {
|
||||
tool.id: tool for tool in get_tools(db_session, only_enabled=True)
|
||||
}
|
||||
|
||||
for tool in graph_config.tooling.tools:
|
||||
|
||||
@@ -488,6 +491,7 @@ def clarifier(
|
||||
+ PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
+ graph_config.inputs.project_instructions
|
||||
)
|
||||
assistant_system_prompt = handle_company_awareness(assistant_system_prompt)
|
||||
|
||||
chat_history_string = (
|
||||
get_chat_history_string(
|
||||
|
||||
@@ -199,6 +199,7 @@ def save_iteration(
|
||||
else None
|
||||
),
|
||||
additional_data=iteration_answer.additional_data,
|
||||
queries=iteration_answer.queries,
|
||||
)
|
||||
db_session.add(research_agent_iteration_sub_step)
|
||||
|
||||
|
||||
@@ -180,6 +180,7 @@ def save_iteration(
|
||||
else None
|
||||
),
|
||||
additional_data=iteration_answer.additional_data,
|
||||
queries=iteration_answer.queries,
|
||||
)
|
||||
db_session.add(research_agent_iteration_sub_step)
|
||||
|
||||
|
||||
@@ -2,30 +2,28 @@ from exa_py import Exa
|
||||
from exa_py.api import HighlightsContentsOptions
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetContent,
|
||||
WebContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchProvider,
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.configs.chat_configs import EXA_API_KEY
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
# TODO Dependency inject for testing
|
||||
class ExaClient(InternetSearchProvider):
|
||||
class ExaClient(WebSearchProvider):
|
||||
def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
|
||||
self.exa = Exa(api_key=api_key)
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def search(self, query: str) -> list[InternetSearchResult]:
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
response = self.exa.search_and_contents(
|
||||
query,
|
||||
type="fast",
|
||||
livecrawl="never",
|
||||
type="auto",
|
||||
highlights=HighlightsContentsOptions(
|
||||
num_sentences=2,
|
||||
highlights_per_url=1,
|
||||
@@ -34,7 +32,7 @@ class ExaClient(InternetSearchProvider):
|
||||
)
|
||||
|
||||
return [
|
||||
InternetSearchResult(
|
||||
WebSearchResult(
|
||||
title=result.title or "",
|
||||
link=result.url,
|
||||
snippet=result.highlights[0] if result.highlights else "",
|
||||
@@ -49,7 +47,7 @@ class ExaClient(InternetSearchProvider):
|
||||
]
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def contents(self, urls: list[str]) -> list[InternetContent]:
|
||||
def contents(self, urls: list[str]) -> list[WebContent]:
|
||||
response = self.exa.get_contents(
|
||||
urls=urls,
|
||||
text=True,
|
||||
@@ -57,7 +55,7 @@ class ExaClient(InternetSearchProvider):
|
||||
)
|
||||
|
||||
return [
|
||||
InternetContent(
|
||||
WebContent(
|
||||
title=result.title or "",
|
||||
link=result.url,
|
||||
full_content=result.text or "",
|
||||
|
||||
@@ -4,13 +4,13 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
import requests
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetContent,
|
||||
WebContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchProvider,
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.configs.chat_configs import SERPER_API_KEY
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
@@ -20,7 +20,7 @@ SERPER_SEARCH_URL = "https://google.serper.dev/search"
|
||||
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
|
||||
|
||||
|
||||
class SerperClient(InternetSearchProvider):
|
||||
class SerperClient(WebSearchProvider):
|
||||
def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
|
||||
self.headers = {
|
||||
"X-API-KEY": api_key,
|
||||
@@ -28,7 +28,7 @@ class SerperClient(InternetSearchProvider):
|
||||
}
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def search(self, query: str) -> list[InternetSearchResult]:
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
payload = {
|
||||
"q": query,
|
||||
}
|
||||
@@ -45,7 +45,7 @@ class SerperClient(InternetSearchProvider):
|
||||
organic_results = results["organic"]
|
||||
|
||||
return [
|
||||
InternetSearchResult(
|
||||
WebSearchResult(
|
||||
title=result["title"],
|
||||
link=result["link"],
|
||||
snippet=result["snippet"],
|
||||
@@ -55,17 +55,17 @@ class SerperClient(InternetSearchProvider):
|
||||
for result in organic_results
|
||||
]
|
||||
|
||||
def contents(self, urls: list[str]) -> list[InternetContent]:
|
||||
def contents(self, urls: list[str]) -> list[WebContent]:
|
||||
if not urls:
|
||||
return []
|
||||
|
||||
# Serper can responds with 500s regularly. We want to retry,
|
||||
# but in the event of failure, return an unsuccesful scrape.
|
||||
def safe_get_webpage_content(url: str) -> InternetContent:
|
||||
def safe_get_webpage_content(url: str) -> WebContent:
|
||||
try:
|
||||
return self._get_webpage_content(url)
|
||||
except Exception:
|
||||
return InternetContent(
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
@@ -77,7 +77,7 @@ class SerperClient(InternetSearchProvider):
|
||||
return list(e.map(safe_get_webpage_content, urls))
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def _get_webpage_content(self, url: str) -> InternetContent:
|
||||
def _get_webpage_content(self, url: str) -> WebContent:
|
||||
payload = {
|
||||
"url": url,
|
||||
}
|
||||
@@ -90,7 +90,7 @@ class SerperClient(InternetSearchProvider):
|
||||
|
||||
# 400 returned when serper cannot scrape
|
||||
if response.status_code == 400:
|
||||
return InternetContent(
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
@@ -122,7 +122,7 @@ class SerperClient(InternetSearchProvider):
|
||||
except Exception:
|
||||
published_date = None
|
||||
|
||||
return InternetContent(
|
||||
return WebContent(
|
||||
title=title or "",
|
||||
link=response_url,
|
||||
full_content=text or "",
|
||||
|
||||
@@ -7,7 +7,7 @@ from langsmith import traceable
|
||||
|
||||
from onyx.agents.agent_search.dr.models import WebSearchAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
|
||||
get_default_provider,
|
||||
@@ -75,15 +75,15 @@ def web_search(
|
||||
raise ValueError("No internet search provider found")
|
||||
|
||||
@traceable(name="Search Provider API Call")
|
||||
def _search(search_query: str) -> list[InternetSearchResult]:
|
||||
search_results: list[InternetSearchResult] = []
|
||||
def _search(search_query: str) -> list[WebSearchResult]:
|
||||
search_results: list[WebSearchResult] = []
|
||||
try:
|
||||
search_results = provider.search(search_query)
|
||||
except Exception as e:
|
||||
logger.error(f"Error performing search: {e}")
|
||||
return search_results
|
||||
|
||||
search_results: list[InternetSearchResult] = _search(search_query)
|
||||
search_results: list[WebSearchResult] = _search(search_query)
|
||||
search_results_text = "\n\n".join(
|
||||
[
|
||||
f"{i}. {result.title}\n URL: {result.link}\n"
|
||||
|
||||
@@ -4,7 +4,7 @@ from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
|
||||
InternetSearchInput,
|
||||
@@ -23,7 +23,7 @@ def dedup_urls(
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> InternetSearchInput:
|
||||
branch_questions_to_urls: dict[str, list[str]] = defaultdict(list)
|
||||
unique_results_by_link: dict[str, InternetSearchResult] = {}
|
||||
unique_results_by_link: dict[str, WebSearchResult] = {}
|
||||
for query, result in state.results_to_open:
|
||||
branch_questions_to_urls[query].append(result.link)
|
||||
if result.link not in unique_results_by_link:
|
||||
|
||||
@@ -13,7 +13,7 @@ class ProviderType(Enum):
|
||||
EXA = "exa"
|
||||
|
||||
|
||||
class InternetSearchResult(BaseModel):
|
||||
class WebSearchResult(BaseModel):
|
||||
title: str
|
||||
link: str
|
||||
author: str | None = None
|
||||
@@ -21,7 +21,7 @@ class InternetSearchResult(BaseModel):
|
||||
snippet: str | None = None
|
||||
|
||||
|
||||
class InternetContent(BaseModel):
|
||||
class WebContent(BaseModel):
|
||||
title: str
|
||||
link: str
|
||||
full_content: str
|
||||
@@ -29,11 +29,11 @@ class InternetContent(BaseModel):
|
||||
scrape_successful: bool = True
|
||||
|
||||
|
||||
class InternetSearchProvider(ABC):
|
||||
class WebSearchProvider(ABC):
|
||||
@abstractmethod
|
||||
def search(self, query: str) -> list[InternetSearchResult]:
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def contents(self, urls: list[str]) -> list[InternetContent]:
|
||||
def contents(self, urls: list[str]) -> list[WebContent]:
|
||||
pass
|
||||
|
||||
@@ -5,13 +5,13 @@ from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client imp
|
||||
SerperClient,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchProvider,
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.configs.chat_configs import EXA_API_KEY
|
||||
from onyx.configs.chat_configs import SERPER_API_KEY
|
||||
|
||||
|
||||
def get_default_provider() -> InternetSearchProvider | None:
|
||||
def get_default_provider() -> WebSearchProvider | None:
|
||||
if EXA_API_KEY:
|
||||
return ExaClient()
|
||||
if SERPER_API_KEY:
|
||||
|
||||
@@ -4,13 +4,13 @@ from typing import Annotated
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class InternetSearchInput(SubAgentInput):
|
||||
results_to_open: Annotated[list[tuple[str, InternetSearchResult]], add] = []
|
||||
results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
|
||||
parallelization_nr: int = 0
|
||||
branch_question: Annotated[str, lambda x, y: y] = ""
|
||||
branch_questions_to_urls: Annotated[dict[str, list[str]], lambda x, y: y] = {}
|
||||
@@ -18,7 +18,7 @@ class InternetSearchInput(SubAgentInput):
|
||||
|
||||
|
||||
class InternetSearchUpdate(LoggerUpdate):
|
||||
results_to_open: Annotated[list[tuple[str, InternetSearchResult]], add] = []
|
||||
results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
|
||||
|
||||
|
||||
class FetchInput(SubAgentInput):
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetContent,
|
||||
WebContent,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
|
||||
InternetSearchResult,
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
@@ -17,7 +17,7 @@ def truncate_search_result_content(content: str, max_chars: int = 10000) -> str:
|
||||
|
||||
|
||||
def dummy_inference_section_from_internet_content(
|
||||
result: InternetContent,
|
||||
result: WebContent,
|
||||
) -> InferenceSection:
|
||||
truncated_content = truncate_search_result_content(result.full_content)
|
||||
return InferenceSection(
|
||||
@@ -48,7 +48,7 @@ def dummy_inference_section_from_internet_content(
|
||||
|
||||
|
||||
def dummy_inference_section_from_internet_search_result(
|
||||
result: InternetSearchResult,
|
||||
result: WebSearchResult,
|
||||
) -> InferenceSection:
|
||||
return InferenceSection(
|
||||
center_chunk=InferenceChunk(
|
||||
|
||||
@@ -54,6 +54,8 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import nulls_last
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
@@ -103,6 +105,7 @@ from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
@@ -324,8 +327,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
):
|
||||
user_create.role = UserRole.ADMIN
|
||||
|
||||
user_created = False
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
user = await super().create(
|
||||
user_create, safe=safe, request=request
|
||||
) # type: ignore
|
||||
user_created = True
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
|
||||
@@ -351,11 +358,42 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
if user_created:
|
||||
await self._assign_default_pinned_assistants(user, db_session)
|
||||
remove_user_from_invited_users(user_create.email)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
return user
|
||||
|
||||
async def _assign_default_pinned_assistants(
|
||||
self, user: User, db_session: AsyncSession
|
||||
) -> None:
|
||||
if user.pinned_assistants is not None:
|
||||
return
|
||||
|
||||
result = await db_session.execute(
|
||||
select(Persona.id)
|
||||
.where(
|
||||
Persona.is_default_persona.is_(True),
|
||||
Persona.is_public.is_(True),
|
||||
Persona.is_visible.is_(True),
|
||||
Persona.deleted.is_(False),
|
||||
)
|
||||
.order_by(
|
||||
nulls_last(Persona.display_priority.asc()),
|
||||
Persona.id.asc(),
|
||||
)
|
||||
)
|
||||
default_persona_ids = list(result.scalars().all())
|
||||
if not default_persona_ids:
|
||||
return
|
||||
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{"pinned_assistants": default_persona_ids},
|
||||
)
|
||||
user.pinned_assistants = default_persona_ids
|
||||
|
||||
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
|
||||
# Validate password according to configurable security policy (defined via environment variables)
|
||||
if len(password) < PASSWORD_MIN_LENGTH:
|
||||
@@ -476,6 +514,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
await self._assign_default_pinned_assistants(user, db_session)
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
@@ -1040,7 +1079,10 @@ async def optional_user(
|
||||
|
||||
# check if an API key is present
|
||||
if user is None:
|
||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||
try:
|
||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||
except ValueError:
|
||||
hashed_api_key = None
|
||||
if hashed_api_key:
|
||||
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
|
||||
|
||||
|
||||
137
backend/onyx/background/celery/apps/background.py
Normal file
137
backend/onyx/background/celery/apps/background.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.background")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
|
||||
|
||||
logger.info("worker_init signal received for consolidated background worker.")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
|
||||
|
||||
# Initialize Vespa httpx pool (needed for light worker tasks)
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
|
||||
ssl_cert=VESPA_CLOUD_CERT_PATH,
|
||||
ssl_key=VESPA_CLOUD_KEY_PATH,
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None:
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
base_bootsteps = app_base.get_bootsteps()
|
||||
for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
# Original background worker tasks
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
# Light worker tasks
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
# Docprocessing worker tasks
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
# Docfetching worker tasks
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
]
|
||||
)
|
||||
23
backend/onyx/background/celery/configs/background.py
Normal file
23
backend/onyx/background/celery/configs/background.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
# Increased from 1 to 4 to handle fast light worker tasks more efficiently
|
||||
# This allows the worker to prefetch multiple tasks per thread
|
||||
worker_prefetch_multiplier = 4
|
||||
@@ -1,4 +1,5 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_HEAVY_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
@@ -15,6 +16,6 @@ result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = 4
|
||||
worker_concurrency = CELERY_WORKER_HEAVY_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_MONITORING_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
@@ -16,6 +17,6 @@ task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
# Monitoring worker specific settings
|
||||
worker_concurrency = 1 # Single worker is sufficient for monitoring
|
||||
worker_concurrency = CELERY_WORKER_MONITORING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
|
||||
@@ -42,6 +42,12 @@ from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.index_attempt import delete_index_attempts
|
||||
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
|
||||
from onyx.db.permission_sync_attempt import (
|
||||
delete_doc_permission_sync_attempts__no_commit,
|
||||
)
|
||||
from onyx.db.permission_sync_attempt import (
|
||||
delete_external_group_permission_sync_attempts__no_commit,
|
||||
)
|
||||
from onyx.db.search_settings import get_all_search_settings
|
||||
from onyx.db.sync_record import cleanup_sync_records
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
@@ -441,6 +447,16 @@ def monitor_connector_deletion_taskset(
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# permission sync attempts
|
||||
delete_doc_permission_sync_attempts__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
delete_external_group_permission_sync_attempts__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
db_session=db_session,
|
||||
|
||||
10
backend/onyx/background/celery/versioned_apps/background.py
Normal file
10
backend/onyx/background/celery/versioned_apps/background.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from celery import Celery
|
||||
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app: Celery = fetch_versioned_implementation(
|
||||
"onyx.background.celery.apps.background",
|
||||
"celery_app",
|
||||
)
|
||||
@@ -2,11 +2,13 @@ import re
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
from typing import Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.answer import Answer
|
||||
@@ -25,12 +27,16 @@ from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import UserKnowledgeFilePacket
|
||||
from onyx.chat.packet_proccessing.process_streamed_packets import (
|
||||
process_streamed_packets,
|
||||
)
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import (
|
||||
default_build_system_message_v2,
|
||||
)
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message_v2
|
||||
from onyx.chat.turn import fast_chat_turn
|
||||
from onyx.chat.turn.infra.emitter import get_default_emitter
|
||||
from onyx.chat.turn.models import ChatTurnDependencies
|
||||
from onyx.chat.user_files.parse_user_files import parse_user_files
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
@@ -69,6 +75,8 @@ from onyx.db.projects import get_project_instructions
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.feature_flags.factory import get_default_feature_flag_provider
|
||||
from onyx.feature_flags.feature_flags_keys import SIMPLE_AGENT_FRAMEWORK
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.file_store.utils import build_frontend_file_url
|
||||
@@ -81,6 +89,7 @@ from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
@@ -88,6 +97,7 @@ from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.tools.adapter_v1_to_v2 import tools_to_function_tools
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool import Tool
|
||||
@@ -355,14 +365,12 @@ def stream_chat_message_objects(
|
||||
long_term_logger = LongTermLogger(
|
||||
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
|
||||
)
|
||||
|
||||
persona = _get_persona_for_chat_session(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
default_persona=chat_session.persona,
|
||||
)
|
||||
|
||||
# TODO: remove once we have an endpoint for this stuff
|
||||
process_kg_commands(new_msg_req.message, persona.name, tenant_id, db_session)
|
||||
|
||||
@@ -732,15 +740,36 @@ def stream_chat_message_objects(
|
||||
and (file.file_id not in project_file_ids)
|
||||
]
|
||||
)
|
||||
|
||||
prompt_builder = AnswerPromptBuilder(
|
||||
user_message=default_build_user_message(
|
||||
feature_flag_provider = get_default_feature_flag_provider()
|
||||
simple_agent_framework_enabled = (
|
||||
feature_flag_provider.feature_enabled_for_user_tenant(
|
||||
flag_key=SIMPLE_AGENT_FRAMEWORK,
|
||||
user=user,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
and not new_msg_req.use_agentic_search
|
||||
)
|
||||
prompt_user_message = (
|
||||
default_build_user_message_v2(
|
||||
user_query=final_msg.message,
|
||||
prompt_config=prompt_config,
|
||||
files=latest_query_files,
|
||||
single_message_history=single_message_history,
|
||||
),
|
||||
system_message=default_build_system_message(prompt_config, llm.config),
|
||||
)
|
||||
if simple_agent_framework_enabled
|
||||
else default_build_user_message(
|
||||
user_query=final_msg.message,
|
||||
prompt_config=prompt_config,
|
||||
files=latest_query_files,
|
||||
)
|
||||
)
|
||||
system_message = (
|
||||
default_build_system_message_v2(prompt_config, llm.config)
|
||||
if simple_agent_framework_enabled
|
||||
else default_build_system_message(prompt_config, llm.config)
|
||||
)
|
||||
prompt_builder = AnswerPromptBuilder(
|
||||
user_message=prompt_user_message,
|
||||
system_message=system_message,
|
||||
message_history=message_history,
|
||||
llm_config=llm.config,
|
||||
raw_user_query=final_msg.message,
|
||||
@@ -782,11 +811,21 @@ def stream_chat_message_objects(
|
||||
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
|
||||
project_instructions=project_instructions,
|
||||
)
|
||||
if simple_agent_framework_enabled:
|
||||
yield from _fast_message_stream(
|
||||
answer,
|
||||
tools,
|
||||
db_session,
|
||||
get_redis_client(),
|
||||
chat_session_id,
|
||||
reserved_message_id,
|
||||
)
|
||||
else:
|
||||
from onyx.chat.packet_proccessing import process_streamed_packets
|
||||
|
||||
# Process streamed packets using the new packet processing module
|
||||
yield from process_streamed_packets(
|
||||
answer_processed_output=answer.processed_streamed_output,
|
||||
)
|
||||
yield from process_streamed_packets.process_streamed_packets(
|
||||
answer_processed_output=answer.processed_streamed_output,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
@@ -823,6 +862,59 @@ def stream_chat_message_objects(
|
||||
return
|
||||
|
||||
|
||||
# TODO: Refactor this to live somewhere else
|
||||
def _fast_message_stream(
|
||||
answer: Answer,
|
||||
tools: list[Tool],
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
chat_session_id: UUID,
|
||||
reserved_message_id: int,
|
||||
) -> Generator[Packet, None, None]:
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
|
||||
OktaProfileTool,
|
||||
)
|
||||
from onyx.llm.litellm_singleton import LitellmModel
|
||||
|
||||
image_generation_tool_instance = None
|
||||
okta_profile_tool_instance = None
|
||||
for tool in tools:
|
||||
if isinstance(tool, ImageGenerationTool):
|
||||
image_generation_tool_instance = tool
|
||||
elif isinstance(tool, OktaProfileTool):
|
||||
okta_profile_tool_instance = tool
|
||||
converted_message_history = [
|
||||
PreviousMessage.from_langchain_msg(message, 0).to_agent_sdk_msg()
|
||||
for message in answer.graph_inputs.prompt_builder.build()
|
||||
]
|
||||
emitter = get_default_emitter()
|
||||
return fast_chat_turn.fast_chat_turn(
|
||||
messages=converted_message_history,
|
||||
# TODO: Maybe we can use some DI framework here?
|
||||
dependencies=ChatTurnDependencies(
|
||||
llm_model=LitellmModel(
|
||||
model=answer.graph_tooling.primary_llm.config.model_name,
|
||||
base_url=answer.graph_tooling.primary_llm.config.api_base,
|
||||
api_key=answer.graph_tooling.primary_llm.config.api_key,
|
||||
),
|
||||
llm=answer.graph_tooling.primary_llm,
|
||||
tools=tools_to_function_tools(tools),
|
||||
search_pipeline=answer.graph_tooling.search_tool,
|
||||
image_generation_tool=image_generation_tool_instance,
|
||||
okta_profile_tool=okta_profile_tool_instance,
|
||||
db_session=db_session,
|
||||
redis_client=redis_client,
|
||||
emitter=emitter,
|
||||
),
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=reserved_message_id,
|
||||
research_type=answer.graph_config.behavior.research_type,
|
||||
)
|
||||
|
||||
|
||||
@log_generator_function_time()
|
||||
def stream_chat_message(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
|
||||
@@ -21,8 +21,10 @@ from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
|
||||
from onyx.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT_V2
|
||||
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from onyx.prompts.prompt_utils import drop_messages_history_overflow
|
||||
from onyx.prompts.prompt_utils import handle_company_awareness
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
@@ -31,6 +33,33 @@ from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
|
||||
|
||||
def default_build_system_message_v2(
|
||||
prompt_config: PromptConfig,
|
||||
llm_config: LLMConfig,
|
||||
) -> SystemMessage | None:
|
||||
system_prompt = prompt_config.system_prompt.strip()
|
||||
system_prompt += REQUIRE_CITATION_STATEMENT_V2
|
||||
# See https://simonwillison.net/tags/markdown/ for context on this temporary fix
|
||||
# for o-series markdown generation
|
||||
if (
|
||||
llm_config.model_provider == OPENAI_PROVIDER_NAME
|
||||
and llm_config.model_name.startswith("o")
|
||||
):
|
||||
system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
|
||||
tag_handled_prompt = handle_onyx_date_awareness(
|
||||
system_prompt,
|
||||
prompt_config,
|
||||
add_additional_info_if_no_tag=prompt_config.datetime_aware,
|
||||
)
|
||||
|
||||
if not tag_handled_prompt:
|
||||
return None
|
||||
|
||||
tag_handled_prompt = handle_company_awareness(tag_handled_prompt)
|
||||
|
||||
return SystemMessage(content=tag_handled_prompt)
|
||||
|
||||
|
||||
def default_build_system_message(
|
||||
prompt_config: PromptConfig,
|
||||
llm_config: LLMConfig,
|
||||
@@ -52,9 +81,29 @@ def default_build_system_message(
|
||||
if not tag_handled_prompt:
|
||||
return None
|
||||
|
||||
tag_handled_prompt = handle_company_awareness(tag_handled_prompt)
|
||||
|
||||
return SystemMessage(content=tag_handled_prompt)
|
||||
|
||||
|
||||
def default_build_user_message_v2(
|
||||
user_query: str,
|
||||
prompt_config: PromptConfig,
|
||||
files: list[InMemoryChatFile] = [],
|
||||
) -> HumanMessage:
|
||||
user_prompt = user_query
|
||||
user_prompt = user_prompt.strip()
|
||||
tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
|
||||
user_msg = HumanMessage(
|
||||
content=(
|
||||
build_content_with_imgs(tag_handled_prompt, files)
|
||||
if files
|
||||
else tag_handled_prompt
|
||||
)
|
||||
)
|
||||
return user_msg
|
||||
|
||||
|
||||
def default_build_user_message(
|
||||
user_query: str,
|
||||
prompt_config: PromptConfig,
|
||||
|
||||
56
backend/onyx/chat/stop_signal_checker.py
Normal file
56
backend/onyx/chat/stop_signal_checker.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
# Redis key prefixes for chat session stop signals
|
||||
PREFIX = "chatsessionstop"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
|
||||
|
||||
def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
|
||||
"""
|
||||
Set or clear the stop signal fence for a chat session.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: Redis client to use
|
||||
value: True to set the fence (stop signal), False to clear it
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
fence_key = f"{FENCE_PREFIX}_{tenant_id}_{chat_session_id}"
|
||||
if not value:
|
||||
redis_client.delete(fence_key)
|
||||
return
|
||||
|
||||
redis_client.set(fence_key, 0)
|
||||
|
||||
|
||||
def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:
|
||||
"""
|
||||
Check if the chat session should continue (not stopped).
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session to check
|
||||
redis_client: Redis client to use for checking the stop signal
|
||||
|
||||
Returns:
|
||||
True if the session should continue, False if it should stop
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
fence_key = f"{FENCE_PREFIX}_{tenant_id}_{chat_session_id}"
|
||||
return not bool(redis_client.exists(fence_key))
|
||||
|
||||
|
||||
def reset_cancel_status(chat_session_id: UUID, redis_client: Redis) -> None:
|
||||
"""
|
||||
Clear the stop signal for a chat session.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: Redis client to use
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
fence_key = f"{FENCE_PREFIX}_{tenant_id}_{chat_session_id}"
|
||||
redis_client.delete(fence_key)
|
||||
1
backend/onyx/chat/turn/__init__.py
Normal file
1
backend/onyx/chat/turn/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Turn module for chat functionality
|
||||
258
backend/onyx/chat/turn/fast_chat_turn.py
Normal file
258
backend/onyx/chat/turn/fast_chat_turn.py
Normal file
@@ -0,0 +1,258 @@
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from agents import Agent
|
||||
from agents import ModelSettings
|
||||
from agents import RawResponsesStreamEvent
|
||||
from agents import StopAtTools
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.chat.chat_utils import llm_doc_from_inference_section
|
||||
from onyx.chat.stop_signal_checker import is_connected
|
||||
from onyx.chat.stop_signal_checker import reset_cancel_status
|
||||
from onyx.chat.stream_processing.citation_processing import CitationProcessor
|
||||
from onyx.chat.turn.infra.chat_turn_event_stream import unified_event_stream
|
||||
from onyx.chat.turn.infra.session_sink import extract_final_answer_from_packets
|
||||
from onyx.chat.turn.infra.session_sink import save_iteration
|
||||
from onyx.chat.turn.infra.sync_agent_stream_adapter import SyncAgentStream
|
||||
from onyx.chat.turn.models import AgentToolType
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.chat.turn.models import ChatTurnDependencies
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketObj
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.tool_implementations_v2.image_generation import image_generation_tool
|
||||
|
||||
|
||||
def _fast_chat_turn_core(
    messages: list[dict],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
    research_type: ResearchType,
    # Dependency injectable arguments for testing
    starter_global_iteration_responses: list[IterationAnswer] | None = None,
    starter_cited_documents: list[InferenceSection] | None = None,
) -> None:
    """Core fast chat turn logic that allows overriding global_iteration_responses for testing.

    Args:
        messages: List of chat messages
        dependencies: Chat turn dependencies
        chat_session_id: Chat session ID
        message_id: Message ID
        research_type: Research type
        starter_global_iteration_responses: Optional list of iteration answers to inject for testing
        starter_cited_documents: Optional list of cited documents to inject for testing
    """
    # Clear any stale stop signal from a previous run of this session so the
    # new turn does not cancel itself immediately.
    reset_cancel_status(
        chat_session_id,
        dependencies.redis_client,
    )
    ctx = ChatTurnContext(
        run_dependencies=dependencies,
        aggregated_context=AggregatedDRContext(
            context="context",
            cited_documents=starter_cited_documents or [],
            is_internet_marker_dict={},
            global_iteration_responses=starter_global_iteration_responses or [],
        ),
        iteration_instructions=[],
        chat_session_id=chat_session_id,
        message_id=message_id,
        research_type=research_type,
    )
    agent = Agent(
        name="Assistant",
        model=dependencies.llm_model,
        tools=cast(list[AgentToolType], dependencies.tools),
        model_settings=ModelSettings(
            temperature=dependencies.llm.config.temperature,
            include_usage=True,
        ),
        # Image generation terminates the run: the agent stops instead of
        # looping after that tool returns.
        tool_use_behavior=StopAtTools(stop_at_tool_names=[image_generation_tool.name]),
    )
    # By default, the agent can only take 10 turns. For our use case, it should be higher.
    max_turns = 25
    agent_stream: SyncAgentStream = SyncAgentStream(
        agent=agent,
        input=messages,
        context=ctx,
        max_turns=max_turns,
    )
    for ev in agent_stream:
        # Check the Redis stop fence on every event so a user-initiated stop
        # takes effect mid-stream.
        connected = is_connected(
            chat_session_id,
            dependencies.redis_client,
        )
        if not connected:
            _emit_clean_up_packets(dependencies, ctx)
            agent_stream.cancel()
            break
        obj = _default_packet_translation(ev, ctx)
        if obj:
            dependencies.emitter.emit(Packet(ind=ctx.current_run_step, obj=obj))
    # Reassemble the streamed answer text from the packets emitted above.
    final_answer = extract_final_answer_from_packets(
        dependencies.emitter.packet_history
    )

    all_cited_documents = []
    if ctx.aggregated_context.global_iteration_responses:
        context_docs = _gather_context_docs_from_iteration_answers(
            ctx.aggregated_context.global_iteration_responses
        )
        all_cited_documents = context_docs
        if context_docs and final_answer:
            _process_citations_for_final_answer(
                final_answer=final_answer,
                context_docs=context_docs,
                dependencies=dependencies,
                ctx=ctx,
            )

    # Persist the turn: message text, citations, search docs, iteration steps.
    save_iteration(
        db_session=dependencies.db_session,
        message_id=message_id,
        chat_session_id=chat_session_id,
        research_type=research_type,
        ctx=ctx,
        final_answer=final_answer,
        all_cited_documents=all_cited_documents,
    )
    dependencies.emitter.emit(
        Packet(ind=ctx.current_run_step, obj=OverallStop(type="stop"))
    )
|
||||
|
||||
@unified_event_stream
def fast_chat_turn(
    messages: list[dict],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
    research_type: ResearchType,
) -> None:
    """Main fast chat turn function that calls the core logic with default parameters."""
    _fast_chat_turn_core(
        messages=messages,
        dependencies=dependencies,
        chat_session_id=chat_session_id,
        message_id=message_id,
        research_type=research_type,
        starter_global_iteration_responses=None,
    )
|
||||
|
||||
# TODO: Maybe in general there's a cleaner way to handle cancellation in the middle of a tool call?
def _emit_clean_up_packets(
    dependencies: ChatTurnDependencies, ctx: ChatTurnContext
) -> None:
    """Emit packets that close out the stream cleanly after a cancellation.

    If no message is currently open (the last packet is not a message_delta),
    a placeholder "Cancelled" MessageStart is emitted first so the client
    always receives a well-formed message section, then the section is closed.
    """
    history = dependencies.emitter.packet_history
    message_in_progress = bool(history) and history[-1].obj.type == "message_delta"
    if not message_in_progress:
        dependencies.emitter.emit(
            Packet(
                ind=ctx.current_run_step,
                obj=MessageStart(
                    type="message_start", content="Cancelled", final_documents=None
                ),
            )
        )
    dependencies.emitter.emit(
        Packet(ind=ctx.current_run_step, obj=SectionEnd(type="section_end"))
    )
|
||||
|
||||
def _gather_context_docs_from_iteration_answers(
|
||||
iteration_answers: list[IterationAnswer],
|
||||
) -> list[InferenceSection]:
|
||||
"""Gather cited documents from iteration answers for citation processing."""
|
||||
context_docs: list[InferenceSection] = []
|
||||
|
||||
for iteration_answer in iteration_answers:
|
||||
# Extract cited documents from this iteration
|
||||
for inference_section in iteration_answer.cited_documents.values():
|
||||
# Avoid duplicates by checking document_id
|
||||
if not any(
|
||||
doc.center_chunk.document_id
|
||||
== inference_section.center_chunk.document_id
|
||||
for doc in context_docs
|
||||
):
|
||||
context_docs.append(inference_section)
|
||||
|
||||
return context_docs
|
||||
|
||||
|
||||
def _process_citations_for_final_answer(
    final_answer: str,
    context_docs: list[InferenceSection],
    dependencies: ChatTurnDependencies,
    ctx: ChatTurnContext,
) -> None:
    """Process citations in the final answer and emit citation events.

    Runs the assembled final answer back through the CitationProcessor against
    the gathered context documents and, when citations are found, emits a
    CitationStart / CitationDelta / SectionEnd packet sequence as a new step.

    NOTE: the original code placed this docstring after the first statement,
    which made it a no-op string literal rather than the function's docstring;
    it has been moved to the top.

    Args:
        final_answer: Full answer text reassembled from the streamed packets.
        context_docs: Documents the answer may cite, in citation-number order.
        dependencies: Chat turn dependencies (the emitter is used here).
        ctx: Mutable turn context; current_run_step is advanced on emission.
    """
    # Imported locally, presumably to avoid an import cycle — TODO confirm.
    from onyx.chat.stream_processing.utils import DocumentIdOrderMapping

    # Citation packets belong to the step after the current one.
    index = ctx.current_run_step + 1

    # Convert InferenceSection objects to LlmDoc objects for citation processing
    llm_docs = [llm_doc_from_inference_section(section) for section in context_docs]

    # Create document ID to rank mappings (simple 1-based indexing)
    final_doc_id_to_rank_map = DocumentIdOrderMapping(
        order_mapping={doc.document_id: i + 1 for i, doc in enumerate(llm_docs)}
    )
    display_doc_id_to_rank_map = final_doc_id_to_rank_map  # Same mapping for display

    # Initialize citation processor
    citation_processor = CitationProcessor(
        context_docs=llm_docs,
        final_doc_id_to_rank_map=final_doc_id_to_rank_map,
        display_doc_id_to_rank_map=display_doc_id_to_rank_map,
    )

    # Process the final answer through the citation processor, keeping only
    # the CitationInfo parts (identified by their citation_num attribute).
    collected_citations: list = []
    for response_part in citation_processor.process_token(final_answer):
        if hasattr(response_part, "citation_num"):  # It's a CitationInfo
            collected_citations.append(response_part)

    # Emit citation events if we found any citations
    if collected_citations:
        dependencies.emitter.emit(Packet(ind=index, obj=CitationStart()))
        dependencies.emitter.emit(
            Packet(
                ind=index,
                obj=CitationDelta(citations=collected_citations),  # type: ignore[arg-type]
            )
        )
        dependencies.emitter.emit(
            Packet(ind=index, obj=SectionEnd(type="section_end"))
        )
        ctx.current_run_step = index
|
||||
|
||||
def _default_packet_translation(ev: object, ctx: ChatTurnContext) -> PacketObj | None:
    """Translate a raw model stream event into an Onyx packet object.

    Returns None for event types that have no packet representation.
    """
    if not isinstance(ev, RawResponsesStreamEvent):
        return None

    # TODO: might need some variation here for different types of models
    # OpenAI packet translator
    event_type = ev.data.type
    if event_type == "response.content_part.added":
        retrieved_search_docs = convert_inference_sections_to_search_docs(
            ctx.aggregated_context.cited_documents
        )
        return MessageStart(
            type="message_start", content="", final_documents=retrieved_search_docs
        )
    if event_type == "response.output_text.delta":
        return MessageDelta(type="message_delta", content=ev.data.delta)
    if event_type == "response.content_part.done":
        return SectionEnd(type="section_end")
    return None
1
backend/onyx/chat/turn/infra/__init__.py
Normal file
1
backend/onyx/chat/turn/infra/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Infrastructure module for chat turn orchestration
|
||||
57
backend/onyx/chat/turn/infra/chat_turn_event_stream.py
Normal file
57
backend/onyx/chat/turn/infra/chat_turn_event_stream.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
from onyx.chat.turn.models import ChatTurnDependencies
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketException
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
|
||||
def unified_event_stream(
    turn_func: Callable[..., None],
) -> Callable[..., Generator[Packet, None]]:
    """
    Decorator that wraps a turn_func to provide event streaming capabilities.

    The wrapped function runs in a background thread and communicates through
    the emitter's queue; the returned generator yields packets until an
    OverallStop arrives, and re-raises any exception the turn raised.

    Usage:
        @unified_event_stream
        def my_turn_func(messages, dependencies, *args, **kwargs):
            # Your turn logic here
            pass

    Then call it like:
        generator = my_turn_func(messages, dependencies, *args, **kwargs)
    """

    def wrapper(
        messages: List[Dict[str, Any]],
        dependencies: ChatTurnDependencies,
        *args: Any,
        **kwargs: Any
    ) -> Generator[Packet, None]:
        def run_with_exception_capture() -> None:
            # Any exception from the turn is converted into a PacketException
            # packet so the consumer loop below can surface it synchronously.
            try:
                turn_func(messages, dependencies, *args, **kwargs)
            except Exception as e:
                dependencies.emitter.emit(
                    Packet(ind=0, obj=PacketException(type="error", exception=e))
                )

        thread = run_in_background(run_with_exception_capture)
        while True:
            # Blocking read bridges the background thread to this generator.
            pkt: Packet = dependencies.emitter.bus.get()
            if pkt.obj == OverallStop(type="stop"):
                # Forward the stop packet, then exit the loop.
                yield pkt
                break
            elif isinstance(pkt.obj, PacketException):
                raise pkt.obj.exception
            else:
                yield pkt
        wait_on_background(thread)

    return wrapper
21
backend/onyx/chat/turn/infra/emitter.py
Normal file
21
backend/onyx/chat/turn/infra/emitter.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from queue import Queue
|
||||
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
|
||||
|
||||
class Emitter:
    """Use this inside tools to emit arbitrary UI progress."""

    def __init__(self, bus: Queue):
        # Thread-safe queue shared with the consuming generator.
        self.bus = bus
        # Every packet emitted so far, in order; used after streaming to
        # reconstruct the final answer.
        self.packet_history: list[Packet] = []

    def emit(self, packet: Packet) -> None:
        """Publish *packet* to the stream and record it in the history."""
        self.bus.put(packet)
        self.packet_history.append(packet)
|
||||
def get_default_emitter() -> Emitter:
    """Build an Emitter backed by a fresh unbounded queue."""
    packet_bus: Queue[Packet] = Queue()
    return Emitter(packet_bus)
||||
170
backend/onyx/chat/turn/infra/session_sink.py
Normal file
170
backend/onyx/chat/turn/infra/session_sink.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# TODO: Figure out a way to persist information is robust to cancellation,
|
||||
# modular so easily testable in unit tests and evals [likely injecting some higher
|
||||
# level session manager and span sink], potentially has some robustness off the critical path,
|
||||
# and promotes clean separation of concerns.
|
||||
import re
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
|
||||
GeneratedImageFullResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.chat import create_search_doc_from_inference_section
|
||||
from onyx.db.chat import update_db_session_with_messages
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ResearchAgentIteration
|
||||
from onyx.db.models import ResearchAgentIterationSubStep
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
|
||||
|
||||
def save_iteration(
    db_session: Session,
    message_id: int,
    chat_session_id: UUID,
    research_type: ResearchType,
    ctx: ChatTurnContext,
    final_answer: str,
    all_cited_documents: list[InferenceSection],
) -> None:
    """Persist one completed chat turn: search docs, citations, the final
    message text, and the per-iteration agent steps.

    Args:
        db_session: Open SQLAlchemy session; committed once at the end.
        message_id: ID of the chat message row being finalized.
        chat_session_id: Session the message belongs to.
        research_type: Drives the is_agentic flag and the stored research_type.
        ctx: Turn context holding iteration instructions and answers.
        final_answer: Full answer text assembled from the stream.
        all_cited_documents: Deduplicated inference sections cited by the answer.
    """
    # first, insert the search_docs
    is_internet_marker_dict: dict[str, bool] = {}
    search_docs = [
        create_search_doc_from_inference_section(
            inference_section=inference_section,
            is_internet=is_internet_marker_dict.get(
                inference_section.center_chunk.document_id, False
            ),  # TODO: revisit
            db_session=db_session,
            commit=False,
        )
        for inference_section in all_cited_documents
    ]

    # then, map_search_docs to message
    _insert_chat_message_search_doc_pair(
        message_id, [search_doc.id for search_doc in search_docs], db_session
    )

    # lastly, insert the citations
    citation_dict: dict[int, int] = {}
    cited_doc_nrs = _extract_citation_numbers(final_answer)
    if search_docs:
        # NOTE(review): assumes every cited number is a valid 1-based index
        # into search_docs; an out-of-range citation number would raise
        # IndexError here — confirm upstream guarantees this.
        for cited_doc_nr in cited_doc_nrs:
            citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
    llm_tokenizer = get_tokenizer(
        model_name=ctx.run_dependencies.llm.config.model_name,
        provider_type=ctx.run_dependencies.llm.config.model_provider,
    )
    num_tokens = len(llm_tokenizer.encode(final_answer or ""))
    # Update the chat message and its parent message in database
    update_db_session_with_messages(
        db_session=db_session,
        chat_message_id=message_id,
        chat_session_id=chat_session_id,
        is_agentic=research_type == ResearchType.DEEP,
        message=final_answer,
        citations=citation_dict,
        research_type=research_type,
        research_plan={},
        final_documents=search_docs,
        update_parent_message=True,
        research_answer_purpose=ResearchAnswerPurpose.ANSWER,
        token_count=num_tokens,
    )

    # TODO: I don't think this is the ideal schema for all use cases
    # find a better schema to store tool and reasoning calls
    for iteration_preparation in ctx.iteration_instructions:
        research_agent_iteration_step = ResearchAgentIteration(
            primary_question_id=message_id,
            reasoning=iteration_preparation.reasoning,
            purpose=iteration_preparation.purpose,
            iteration_nr=iteration_preparation.iteration_nr,
        )
        db_session.add(research_agent_iteration_step)

    for iteration_answer in ctx.aggregated_context.global_iteration_responses:

        retrieved_search_docs = convert_inference_sections_to_search_docs(
            list(iteration_answer.cited_documents.values())
        )

        # Convert SavedSearchDoc objects to JSON-serializable format
        serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]

        research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
            primary_question_id=message_id,
            iteration_nr=iteration_answer.iteration_nr,
            iteration_sub_step_nr=iteration_answer.parallelization_nr,
            sub_step_instructions=iteration_answer.question,
            sub_step_tool_id=iteration_answer.tool_id,
            sub_answer=iteration_answer.answer,
            reasoning=iteration_answer.reasoning,
            claims=iteration_answer.claims,
            cited_doc_results=serialized_search_docs,
            generated_images=(
                GeneratedImageFullResult(images=iteration_answer.generated_images)
                if iteration_answer.generated_images
                else None
            ),
            additional_data=iteration_answer.additional_data,
            is_web_fetch=iteration_answer.is_web_fetch,
            queries=iteration_answer.queries,
        )
        db_session.add(research_agent_iteration_sub_step)

    db_session.commit()
|
||||
|
||||
def _insert_chat_message_search_doc_pair(
    message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
    """
    Insert message_id/search_doc_id pairs into the chat_message__search_doc table.

    Rows are added to the session but NOT committed here; the caller commits.

    Args:
        message_id: The ID of the chat message
        search_doc_ids: The IDs of the search documents to link to the message
        db_session: The database session
    """
    for search_doc_id in search_doc_ids:
        chat_message_search_doc = ChatMessage__SearchDoc(
            chat_message_id=message_id, search_doc_id=search_doc_id
        )
        db_session.add(chat_message_search_doc)
|
||||
|
||||
def _extract_citation_numbers(text: str) -> list[int]:
|
||||
"""
|
||||
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
|
||||
Returns a list of all unique citation numbers found.
|
||||
"""
|
||||
# Pattern to match [[number]] or [[number1, number2, ...]]
|
||||
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
|
||||
matches = re.findall(pattern, text)
|
||||
|
||||
cited_numbers = []
|
||||
for match in matches:
|
||||
# Split by comma and extract all numbers
|
||||
numbers = [int(num.strip()) for num in match.split(",")]
|
||||
cited_numbers.extend(numbers)
|
||||
|
||||
return list(set(cited_numbers)) # Return unique numbers
|
||||
|
||||
|
||||
def extract_final_answer_from_packets(packet_history: list[Packet]) -> str:
    """Extract the final answer by concatenating all MessageStart/MessageDelta content."""
    return "".join(
        packet.obj.content
        for packet in packet_history
        if isinstance(packet.obj, (MessageDelta, MessageStart))
    )
||||
177
backend/onyx/chat/turn/infra/sync_agent_stream_adapter.py
Normal file
177
backend/onyx/chat/turn/infra/sync_agent_stream_adapter.py
Normal file
@@ -0,0 +1,177 @@
|
||||
import asyncio
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Iterator
|
||||
from typing import Generic
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
from agents import Agent
|
||||
from agents import RunResultStreaming
|
||||
from agents.run import Runner
|
||||
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SyncAgentStream(Generic[T]):
    """
    Convert an async streamed run into a sync iterator with cooperative cancellation.
    Runs the Agent in a background thread.

    Usage:
        adapter = SyncStreamAdapter(
            agent=agent,
            input=input,
            context=context,
            max_turns=100,
            queue_maxsize=0,  # optional backpressure
        )
        for ev in adapter:  # sync iteration
            ...
        # or cancel from elsewhere:
        adapter.cancel()
    """

    # Marker object placed on the queue to signal end-of-stream.
    _SENTINEL = object()

    def __init__(
        self,
        *,
        agent: Agent,
        input: list[dict],
        context: ChatTurnContext,
        max_turns: int = 100,
        queue_maxsize: int = 0,
    ) -> None:
        self._agent = agent
        self._input = input
        self._context = context
        self._max_turns = max_turns

        # Thread-safe channel from the asyncio worker thread to the sync consumer.
        self._q: "queue.Queue[object]" = queue.Queue(maxsize=queue_maxsize)
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None
        self._streamed: RunResultStreaming | None = None
        # Exception raised inside the worker, surfaced to the sync iterator.
        self._exc: Optional[BaseException] = None
        self._cancel_requested = threading.Event()
        self._started = threading.Event()
        self._done = threading.Event()

        self._start_thread()

    # ---------- public sync API ----------

    def __iter__(self) -> Iterator[T]:
        try:
            while True:
                item = self._q.get()
                if item is self._SENTINEL:
                    # If the consumer thread raised, surface it now
                    if self._exc is not None:
                        raise self._exc
                    # Normal completion
                    return
                yield item  # type: ignore[misc,return-value]
        finally:
            # Ensure we fully clean up whether we exited due to exception,
            # StopIteration, or external cancel.
            self.close()

    def cancel(self) -> bool:
        """
        Cooperatively cancel the underlying streamed run and shut down.
        Safe to call multiple times and from any thread.
        """
        self._cancel_requested.set()
        loop = self._loop
        streamed = self._streamed
        if loop is not None and streamed is not None and not self._done.is_set():
            # Schedule the cancel on the worker's own loop thread rather than
            # calling it from this (foreign) thread.
            loop.call_soon_threadsafe(streamed.cancel)
            return True
        return False

    def close(self, *, wait: bool = True) -> None:
        """Idempotent shutdown."""
        self.cancel()
        # ask the loop to stop if it's still running
        loop = self._loop
        if loop is not None and loop.is_running():
            try:
                loop.call_soon_threadsafe(loop.stop)
            except Exception:
                pass
        # join the thread
        if wait and self._thread is not None and self._thread.is_alive():
            self._thread.join(timeout=5.0)

    # ---------- internals ----------

    def _start_thread(self) -> None:
        t = run_in_background(self._thread_main)
        self._thread = t
        # Optionally wait until the loop/worker is started so .cancel() is safe soon after init
        self._started.wait(timeout=1.0)

    def _thread_main(self) -> None:
        # Each adapter owns a private event loop confined to this thread.
        loop = asyncio.new_event_loop()
        self._loop = loop
        asyncio.set_event_loop(loop)

        async def worker() -> None:
            try:
                # Start the streamed run inside the loop thread
                self._streamed = Runner.run_streamed(
                    self._agent,
                    self._input,  # type: ignore[arg-type]
                    context=self._context,
                    max_turns=self._max_turns,
                )

                # If cancel was requested before we created _streamed, honor it now
                if self._cancel_requested.is_set():
                    await self._streamed.cancel()  # type: ignore[func-returns-value]

                # Consume async events and forward into the thread-safe queue
                async for ev in self._streamed.stream_events():
                    # Early exit if a late cancel arrives
                    if self._cancel_requested.is_set():
                        # Try to cancel gracefully; don't break until cancel takes effect
                        try:
                            await self._streamed.cancel()  # type: ignore[func-returns-value]
                        except Exception:
                            pass
                        break
                    # This put() may block if queue_maxsize > 0 (backpressure)
                    self._q.put(ev)

            except BaseException as e:
                # Save exception to surface on the sync iterator side
                self._exc = e
            finally:
                # Signal end-of-stream
                self._q.put(self._SENTINEL)
                self._done.set()

        # Mark started and run the worker to completion
        self._started.set()
        try:
            loop.run_until_complete(worker())
        finally:
            try:
                # Drain pending tasks/callbacks safely
                pending = asyncio.all_tasks(loop=loop)
                for task in pending:
                    task.cancel()
                if pending:
                    loop.run_until_complete(
                        asyncio.gather(*pending, return_exceptions=True)
                    )
            except Exception:
                pass
            finally:
                loop.close()
                self._loop = None
                self._streamed = None
70
backend/onyx/chat/turn/models.py
Normal file
70
backend/onyx/chat/turn/models.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import dataclasses
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from agents import CodeInterpreterTool
|
||||
from agents import ComputerTool
|
||||
from agents import FileSearchTool
|
||||
from agents import FunctionTool
|
||||
from agents import HostedMCPTool
|
||||
from agents import ImageGenerationTool as AgentsImageGenerationTool
|
||||
from agents import LocalShellTool
|
||||
from agents import Model
|
||||
from agents import WebSearchTool
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.chat.turn.infra.emitter import Emitter
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
|
||||
OktaProfileTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
# Type alias for all tool types accepted by the Agent
# (union of every agents-SDK tool class the runner can dispatch to).
AgentToolType = (
    FunctionTool
    | FileSearchTool
    | WebSearchTool
    | ComputerTool
    | HostedMCPTool
    | LocalShellTool
    | AgentsImageGenerationTool
    | CodeInterpreterTool
)
|
||||
|
||||
@dataclass
class ChatTurnDependencies:
    """Injected collaborators needed to run a single chat turn."""

    llm_model: Model  # agents-SDK model driven by the Agent runner
    llm: LLM  # Onyx LLM interface (source of temperature/tokenizer config)
    db_session: Session  # SQLAlchemy session used to persist the turn
    tools: Sequence[FunctionTool]  # tools exposed to the agent
    redis_client: Redis  # used for the stop-signal fence
    emitter: Emitter  # sink for streaming Packet events
    search_pipeline: SearchTool | None = None
    image_generation_tool: ImageGenerationTool | None = None
    okta_profile_tool: OktaProfileTool | None = None
||||
|
||||
@dataclass
class ChatTurnContext:
    """Context class to hold search tool and other dependencies"""

    chat_session_id: UUID
    message_id: int
    research_type: ResearchType
    run_dependencies: ChatTurnDependencies
    # Documents and iteration answers accumulated over the run.
    aggregated_context: AggregatedDRContext
    # Index of the current streaming step; used as the Packet "ind" value.
    current_run_step: int = 0
    iteration_instructions: list[IterationInstructions] = dataclasses.field(
        default_factory=list
    )
    web_fetch_results: list[dict] = dataclasses.field(default_factory=list)
||||
@@ -349,10 +349,6 @@ except ValueError:
|
||||
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT
|
||||
)
|
||||
|
||||
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4
|
||||
)
|
||||
|
||||
CELERY_WORKER_PRIMARY_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_PRIMARY_CONCURRENCY") or 4
|
||||
)
|
||||
@@ -360,18 +356,30 @@ CELERY_WORKER_PRIMARY_CONCURRENCY = int(
|
||||
CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
|
||||
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
|
||||
)
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT = 4
|
||||
try:
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = int(
|
||||
os.environ.get(
|
||||
"CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY",
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT,
|
||||
)
|
||||
)
|
||||
except ValueError:
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = (
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT
|
||||
)
|
||||
|
||||
# Consolidated background worker (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
|
||||
# separate workers' defaults: light=24, docprocessing=6, docfetching=1, heavy=4, kg=2, monitoring=1, user_file=2
|
||||
# Total would be 40, but we use a more conservative default of 20 for the consolidated worker
|
||||
CELERY_WORKER_BACKGROUND_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_BACKGROUND_CONCURRENCY") or 20
|
||||
)
|
||||
|
||||
# Individual worker concurrency settings (used when USE_LIGHTWEIGHT_BACKGROUND_WORKER is False or on Kuberenetes deployments)
|
||||
CELERY_WORKER_HEAVY_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
|
||||
)
|
||||
|
||||
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 2
|
||||
)
|
||||
|
||||
CELERY_WORKER_MONITORING_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_MONITORING_CONCURRENCY") or 1
|
||||
)
|
||||
|
||||
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY") or 2
|
||||
)
|
||||
|
||||
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
|
||||
VESPA_SYNC_MAX_TASKS = 8192
|
||||
|
||||
@@ -72,12 +72,13 @@ POSTGRES_CELERY_APP_NAME = "celery"
|
||||
# Application names recorded in Postgres for each Celery worker variant.
# NOTE: the merged version assigned HEAVY and MONITORING twice with identical
# values (merge artifact); the duplicates are removed here.
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background"
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (
    "celery_worker_user_file_processing"
)
||||
|
||||
@@ -61,21 +61,16 @@ _BASE_EMBEDDING_MODELS = [
|
||||
dim=1536,
|
||||
index_name="danswer_chunk_text_embedding_3_small",
|
||||
),
|
||||
_BaseEmbeddingModel(
|
||||
name="google/gemini-embedding-001",
|
||||
dim=3072,
|
||||
index_name="danswer_chunk_google_gemini_embedding_001",
|
||||
),
|
||||
_BaseEmbeddingModel(
|
||||
name="google/text-embedding-005",
|
||||
dim=768,
|
||||
index_name="danswer_chunk_google_text_embedding_005",
|
||||
),
|
||||
_BaseEmbeddingModel(
|
||||
name="google/textembedding-gecko@003",
|
||||
dim=768,
|
||||
index_name="danswer_chunk_google_textembedding_gecko_003",
|
||||
),
|
||||
_BaseEmbeddingModel(
|
||||
name="google/textembedding-gecko@003",
|
||||
dim=768,
|
||||
index_name="danswer_chunk_textembedding_gecko_003",
|
||||
),
|
||||
_BaseEmbeddingModel(
|
||||
name="voyage/voyage-large-2-instruct",
|
||||
dim=1024,
|
||||
|
||||
@@ -344,6 +344,9 @@ class GoogleDriveConnector(
|
||||
def _get_all_drives_for_user(self, user_email: str) -> set[str]:
|
||||
drive_service = get_drive_service(self.creds, user_email)
|
||||
is_service_account = isinstance(self.creds, ServiceAccountCredentials)
|
||||
logger.info(
|
||||
f"Getting all drives for user {user_email} with service account: {is_service_account}"
|
||||
)
|
||||
all_drive_ids: set[str] = set()
|
||||
for drive in execute_paginated_retrieval(
|
||||
retrieval_function=drive_service.drives().list,
|
||||
|
||||
@@ -423,6 +423,35 @@ class SavedSearchDoc(SearchDoc):
|
||||
"""Create SavedSearchDoc from serialized dictionary data (e.g., from database JSON)"""
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str) -> "SavedSearchDoc":
|
||||
"""Create a SavedSearchDoc from a URL for internet search documents.
|
||||
|
||||
Uses the INTERNET_SEARCH_DOC_ prefix for document_id to match the format
|
||||
used by inference sections created from internet content.
|
||||
"""
|
||||
return cls(
|
||||
# db_doc_id can be a filler value since these docs are not saved to the database.
|
||||
db_doc_id=0,
|
||||
document_id="INTERNET_SEARCH_DOC_" + url,
|
||||
chunk_ind=0,
|
||||
semantic_identifier=url,
|
||||
link=url,
|
||||
blurb="",
|
||||
source_type=DocumentSource.WEB,
|
||||
boost=1,
|
||||
hidden=False,
|
||||
metadata={},
|
||||
score=0.0,
|
||||
is_relevant=None,
|
||||
relevance_explanation=None,
|
||||
match_highlights=[],
|
||||
updated_at=None,
|
||||
primary_owners=None,
|
||||
secondary_owners=None,
|
||||
is_internet=True,
|
||||
)
|
||||
|
||||
def __lt__(self, other: Any) -> bool:
|
||||
if not isinstance(other, SavedSearchDoc):
|
||||
return NotImplemented
|
||||
|
||||
@@ -970,6 +970,7 @@ def translate_db_message_to_chat_message_detail(
|
||||
chat_message.search_docs, remove_doc_content=remove_doc_content
|
||||
),
|
||||
message_type=chat_message.message_type,
|
||||
research_type=chat_message.research_type,
|
||||
time_sent=chat_message.time_sent,
|
||||
citations=chat_message.citations,
|
||||
files=chat_message.files or [],
|
||||
|
||||
@@ -28,7 +28,6 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.kg_configs import KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES
|
||||
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.entities import delete_from_kg_entities__no_commit
|
||||
from onyx.db.entities import delete_from_kg_entities_extraction_staging__no_commit
|
||||
from onyx.db.enums import AccessType
|
||||
@@ -55,7 +54,6 @@ from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.kg.models import KGStage
|
||||
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -304,80 +302,27 @@ def get_document_counts_for_cc_pairs(
|
||||
]
|
||||
|
||||
|
||||
# For use with our thread-level parallelism utils. Note that any relationships
|
||||
# you wish to use MUST be eagerly loaded, as the session will not be available
|
||||
# after this function to allow lazy loading.
|
||||
def get_document_counts_for_cc_pairs_parallel(
|
||||
cc_pairs: list[ConnectorCredentialPairIdentifier],
|
||||
def get_document_counts_for_all_cc_pairs(
|
||||
db_session: Session,
|
||||
) -> Sequence[tuple[int, int, int]]:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
return get_document_counts_for_cc_pairs(db_session, cc_pairs)
|
||||
"""Return (connector_id, credential_id, count) for ALL CC pairs with indexed docs.
|
||||
|
||||
|
||||
def _get_document_counts_for_cc_pairs_batch(
|
||||
batch: list[tuple[int, int]],
|
||||
) -> list[tuple[int, int, int]]:
|
||||
"""Worker for parallel execution: opens its own session per batch."""
|
||||
if not batch:
|
||||
return []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
stmt = (
|
||||
select(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id,
|
||||
func.count(),
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
tuple_(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id,
|
||||
).in_(batch),
|
||||
DocumentByConnectorCredentialPair.has_been_indexed.is_(True),
|
||||
)
|
||||
)
|
||||
.group_by(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id,
|
||||
)
|
||||
)
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
|
||||
|
||||
def get_document_counts_for_cc_pairs_batched_parallel(
|
||||
cc_pairs: list[ConnectorCredentialPairIdentifier],
|
||||
batch_size: int = 1000,
|
||||
max_workers: int | None = None,
|
||||
) -> Sequence[tuple[int, int, int]]:
|
||||
"""Parallel variant that batches the IN-clause and runs batches concurrently.
|
||||
|
||||
Opens an isolated DB session per batch to avoid sharing a session across threads.
|
||||
Executes a single grouped query so Postgres can fully leverage indexes,
|
||||
avoiding large batched IN-lists.
|
||||
"""
|
||||
if not cc_pairs:
|
||||
return []
|
||||
|
||||
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs]
|
||||
|
||||
batches: list[list[tuple[int, int]]] = [
|
||||
cc_ids[i : i + batch_size] for i in range(0, len(cc_ids), batch_size)
|
||||
]
|
||||
|
||||
funcs = [(_get_document_counts_for_cc_pairs_batch, (batch,)) for batch in batches]
|
||||
results = run_functions_tuples_in_parallel(
|
||||
functions_with_args=funcs, max_workers=max_workers
|
||||
stmt = (
|
||||
select(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id,
|
||||
func.count(),
|
||||
)
|
||||
.where(DocumentByConnectorCredentialPair.has_been_indexed.is_(True))
|
||||
.group_by(
|
||||
DocumentByConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id,
|
||||
)
|
||||
)
|
||||
|
||||
aggregated_counts: dict[tuple[int, int], int] = {}
|
||||
for batch_result in results:
|
||||
if not batch_result:
|
||||
continue
|
||||
for connector_id, credential_id, cnt in batch_result:
|
||||
aggregated_counts[(connector_id, credential_id)] = cnt
|
||||
|
||||
return [
|
||||
(connector_id, credential_id, cnt)
|
||||
for (connector_id, credential_id), cnt in aggregated_counts.items()
|
||||
]
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
|
||||
|
||||
def get_access_info_for_document(
|
||||
|
||||
@@ -18,12 +18,14 @@ from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.federated import create_federated_connector_document_set_mapping
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import DocumentSet as DocumentSetDBModel
|
||||
from onyx.db.models import DocumentSet__ConnectorCredentialPair
|
||||
from onyx.db.models import DocumentSet__UserGroup
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
@@ -370,10 +372,6 @@ def update_document_set(
|
||||
db_session.add_all(ds_cc_pairs)
|
||||
|
||||
# Update federated connector mappings
|
||||
from onyx.db.federated import create_federated_connector_document_set_mapping
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from sqlalchemy import delete
|
||||
|
||||
# Delete existing federated connector mappings for this document set
|
||||
delete_stmt = delete(FederatedConnector__DocumentSet).where(
|
||||
FederatedConnector__DocumentSet.document_set_id == document_set_row.id
|
||||
@@ -455,6 +453,13 @@ def mark_document_set_as_to_be_deleted(
|
||||
db_session=db_session, document_set_id=document_set_id
|
||||
)
|
||||
|
||||
# delete all federated connector mappings so the cleanup task can fully
|
||||
# remove the document set once the Vespa sync completes
|
||||
delete_stmt = delete(FederatedConnector__DocumentSet).where(
|
||||
FederatedConnector__DocumentSet.document_set_id == document_set_id
|
||||
)
|
||||
db_session.execute(delete_stmt)
|
||||
|
||||
# delete all private document set information
|
||||
versioned_delete_private_fn = fetch_versioned_implementation(
|
||||
"onyx.db.document_set", "delete_document_set_privacy__no_commit"
|
||||
|
||||
@@ -25,6 +25,32 @@ class IndexingStatus(str, PyEnum):
|
||||
)
|
||||
|
||||
|
||||
class PermissionSyncStatus(str, PyEnum):
|
||||
"""Status enum for permission sync attempts"""
|
||||
|
||||
NOT_STARTED = "not_started"
|
||||
IN_PROGRESS = "in_progress"
|
||||
SUCCESS = "success"
|
||||
CANCELED = "canceled"
|
||||
FAILED = "failed"
|
||||
COMPLETED_WITH_ERRORS = "completed_with_errors"
|
||||
|
||||
def is_terminal(self) -> bool:
|
||||
terminal_states = {
|
||||
PermissionSyncStatus.SUCCESS,
|
||||
PermissionSyncStatus.COMPLETED_WITH_ERRORS,
|
||||
PermissionSyncStatus.CANCELED,
|
||||
PermissionSyncStatus.FAILED,
|
||||
}
|
||||
return self in terminal_states
|
||||
|
||||
def is_successful(self) -> bool:
|
||||
return (
|
||||
self == PermissionSyncStatus.SUCCESS
|
||||
or self == PermissionSyncStatus.COMPLETED_WITH_ERRORS
|
||||
)
|
||||
|
||||
|
||||
class IndexingMode(str, PyEnum):
|
||||
UPDATE = "update"
|
||||
REINDEX = "reindex"
|
||||
|
||||
@@ -74,6 +74,7 @@ from onyx.db.enums import ChatSessionSharedStatus
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.enums import PermissionSyncStatus
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.pydantic_type import PydanticListType, PydanticType
|
||||
from onyx.kg.models import KGEntityTypeAttributes
|
||||
@@ -2473,6 +2474,7 @@ class Tool(Base):
|
||||
mcp_server_id: Mapped[int | None] = mapped_column(
|
||||
Integer, ForeignKey("mcp_server.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
|
||||
# Relationship to Persona through the association table
|
||||
@@ -3435,6 +3437,8 @@ class ResearchAgentIterationSubStep(Base):
|
||||
# for search-based step-types
|
||||
cited_doc_results: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
|
||||
claims: Mapped[list[str] | None] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
is_web_fetch: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
queries: Mapped[list[str] | None] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
|
||||
# for image generation step-types
|
||||
generated_images: Mapped[GeneratedImageFullResult | None] = mapped_column(
|
||||
@@ -3599,3 +3603,145 @@ class MCPConnectionConfig(Base):
|
||||
Index("ix_mcp_connection_config_user_email", "user_email"),
|
||||
Index("ix_mcp_connection_config_server_user", "mcp_server_id", "user_email"),
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
Permission Sync Tables
|
||||
"""
|
||||
|
||||
|
||||
class DocPermissionSyncAttempt(Base):
|
||||
"""
|
||||
Represents an attempt to sync document permissions for a connector credential pair.
|
||||
Similar to IndexAttempt but specifically for document permission syncing operations.
|
||||
"""
|
||||
|
||||
__tablename__ = "doc_permission_sync_attempt"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
connector_credential_pair_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Status of the sync attempt
|
||||
status: Mapped[PermissionSyncStatus] = mapped_column(
|
||||
Enum(PermissionSyncStatus, native_enum=False, index=True)
|
||||
)
|
||||
|
||||
# Counts for tracking progress
|
||||
total_docs_synced: Mapped[int | None] = mapped_column(Integer, default=0)
|
||||
docs_with_permission_errors: Mapped[int | None] = mapped_column(Integer, default=0)
|
||||
|
||||
# Error message if sync fails
|
||||
error_message: Mapped[str | None] = mapped_column(Text, default=None)
|
||||
|
||||
# Timestamps
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
index=True,
|
||||
)
|
||||
time_started: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), default=None
|
||||
)
|
||||
time_finished: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), default=None
|
||||
)
|
||||
|
||||
# Relationships
|
||||
connector_credential_pair: Mapped[ConnectorCredentialPair] = relationship(
|
||||
"ConnectorCredentialPair"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_permission_sync_attempt_latest_for_cc_pair",
|
||||
"connector_credential_pair_id",
|
||||
"time_created",
|
||||
),
|
||||
Index(
|
||||
"ix_permission_sync_attempt_status_time",
|
||||
"status",
|
||||
desc("time_finished"),
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<DocPermissionSyncAttempt(id={self.id!r}, " f"status={self.status!r})>"
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return self.status.is_terminal()
|
||||
|
||||
|
||||
class ExternalGroupPermissionSyncAttempt(Base):
|
||||
"""
|
||||
Represents an attempt to sync external group memberships for users.
|
||||
This tracks the syncing of user-to-external-group mappings across connectors.
|
||||
"""
|
||||
|
||||
__tablename__ = "external_group_permission_sync_attempt"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
# Can be tied to a specific connector or be a global group sync
|
||||
connector_credential_pair_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id"),
|
||||
nullable=True, # Nullable for global group syncs across all connectors
|
||||
)
|
||||
|
||||
# Status of the group sync attempt
|
||||
status: Mapped[PermissionSyncStatus] = mapped_column(
|
||||
Enum(PermissionSyncStatus, native_enum=False, index=True)
|
||||
)
|
||||
|
||||
# Counts for tracking progress
|
||||
total_users_processed: Mapped[int | None] = mapped_column(Integer, default=0)
|
||||
total_groups_processed: Mapped[int | None] = mapped_column(Integer, default=0)
|
||||
total_group_memberships_synced: Mapped[int | None] = mapped_column(
|
||||
Integer, default=0
|
||||
)
|
||||
|
||||
# Error message if sync fails
|
||||
error_message: Mapped[str | None] = mapped_column(Text, default=None)
|
||||
|
||||
# Timestamps
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
index=True,
|
||||
)
|
||||
time_started: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), default=None
|
||||
)
|
||||
time_finished: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), default=None
|
||||
)
|
||||
|
||||
# Relationships
|
||||
connector_credential_pair: Mapped[ConnectorCredentialPair | None] = relationship(
|
||||
"ConnectorCredentialPair"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_group_sync_attempt_cc_pair_time",
|
||||
"connector_credential_pair_id",
|
||||
"time_created",
|
||||
),
|
||||
Index(
|
||||
"ix_group_sync_attempt_status_time",
|
||||
"status",
|
||||
desc("time_finished"),
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<ExternalGroupPermissionSyncAttempt(id={self.id!r}, "
|
||||
f"status={self.status!r})>"
|
||||
)
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
return self.status.is_terminal()
|
||||
|
||||
485
backend/onyx/db/permission_sync_attempt.py
Normal file
485
backend/onyx/db/permission_sync_attempt.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""Permission sync attempt CRUD operations and utilities.
|
||||
|
||||
This module contains all CRUD operations for both DocPermissionSyncAttempt
|
||||
and ExternalGroupPermissionSyncAttempt models, along with shared utilities.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.engine.cursor import CursorResult
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import PermissionSyncStatus
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import DocPermissionSyncAttempt
|
||||
from onyx.db.models import ExternalGroupPermissionSyncAttempt
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DOC PERMISSION SYNC ATTEMPT CRUD
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_doc_permission_sync_attempt(
|
||||
connector_credential_pair_id: int,
|
||||
db_session: Session,
|
||||
) -> int:
|
||||
"""Create a new doc permission sync attempt.
|
||||
|
||||
Args:
|
||||
connector_credential_pair_id: The ID of the connector credential pair
|
||||
db_session: The database session
|
||||
|
||||
Returns:
|
||||
The ID of the created attempt
|
||||
"""
|
||||
attempt = DocPermissionSyncAttempt(
|
||||
connector_credential_pair_id=connector_credential_pair_id,
|
||||
status=PermissionSyncStatus.NOT_STARTED,
|
||||
)
|
||||
db_session.add(attempt)
|
||||
db_session.commit()
|
||||
|
||||
return attempt.id
|
||||
|
||||
|
||||
def get_doc_permission_sync_attempt(
|
||||
db_session: Session,
|
||||
attempt_id: int,
|
||||
eager_load_connector: bool = False,
|
||||
) -> DocPermissionSyncAttempt | None:
|
||||
"""Get a doc permission sync attempt by ID.
|
||||
|
||||
Args:
|
||||
db_session: The database session
|
||||
attempt_id: The ID of the attempt
|
||||
eager_load_connector: If True, eagerly loads the connector and cc_pair relationships
|
||||
|
||||
Returns:
|
||||
The attempt if found, None otherwise
|
||||
"""
|
||||
stmt = select(DocPermissionSyncAttempt).where(
|
||||
DocPermissionSyncAttempt.id == attempt_id
|
||||
)
|
||||
|
||||
if eager_load_connector:
|
||||
stmt = stmt.options(
|
||||
joinedload(DocPermissionSyncAttempt.connector_credential_pair).joinedload(
|
||||
ConnectorCredentialPair.connector
|
||||
)
|
||||
)
|
||||
|
||||
return db_session.scalars(stmt).first()
|
||||
|
||||
|
||||
def get_recent_doc_permission_sync_attempts_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
limit: int,
|
||||
db_session: Session,
|
||||
) -> list[DocPermissionSyncAttempt]:
|
||||
"""Get recent doc permission sync attempts for a cc pair, most recent first."""
|
||||
return list(
|
||||
db_session.execute(
|
||||
select(DocPermissionSyncAttempt)
|
||||
.where(DocPermissionSyncAttempt.connector_credential_pair_id == cc_pair_id)
|
||||
.order_by(DocPermissionSyncAttempt.time_created.desc())
|
||||
.limit(limit)
|
||||
).scalars()
|
||||
)
|
||||
|
||||
|
||||
def mark_doc_permission_sync_attempt_in_progress(
|
||||
attempt_id: int,
|
||||
db_session: Session,
|
||||
) -> DocPermissionSyncAttempt:
|
||||
"""Mark a doc permission sync attempt as IN_PROGRESS.
|
||||
Locks the row during update."""
|
||||
try:
|
||||
attempt = db_session.execute(
|
||||
select(DocPermissionSyncAttempt)
|
||||
.where(DocPermissionSyncAttempt.id == attempt_id)
|
||||
.with_for_update()
|
||||
).scalar_one()
|
||||
|
||||
if attempt.status != PermissionSyncStatus.NOT_STARTED:
|
||||
raise RuntimeError(
|
||||
f"Doc permission sync attempt with ID '{attempt_id}' is not in NOT_STARTED status. "
|
||||
f"Current status is '{attempt.status}'."
|
||||
)
|
||||
|
||||
attempt.status = PermissionSyncStatus.IN_PROGRESS
|
||||
attempt.time_started = func.now() # type: ignore
|
||||
db_session.commit()
|
||||
return attempt
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
logger.exception("mark_doc_permission_sync_attempt_in_progress exceptioned.")
|
||||
raise
|
||||
|
||||
|
||||
def mark_doc_permission_sync_attempt_failed(
|
||||
attempt_id: int,
|
||||
db_session: Session,
|
||||
error_message: str,
|
||||
) -> None:
|
||||
"""Mark a doc permission sync attempt as failed."""
|
||||
try:
|
||||
attempt = db_session.execute(
|
||||
select(DocPermissionSyncAttempt)
|
||||
.where(DocPermissionSyncAttempt.id == attempt_id)
|
||||
.with_for_update()
|
||||
).scalar_one()
|
||||
|
||||
if not attempt.time_started:
|
||||
attempt.time_started = func.now() # type: ignore
|
||||
attempt.status = PermissionSyncStatus.FAILED
|
||||
attempt.time_finished = func.now() # type: ignore
|
||||
attempt.error_message = error_message
|
||||
db_session.commit()
|
||||
|
||||
# Add telemetry for permission sync attempt status change
|
||||
optional_telemetry(
|
||||
record_type=RecordType.PERMISSION_SYNC_COMPLETE,
|
||||
data={
|
||||
"doc_permission_sync_attempt_id": attempt_id,
|
||||
"status": PermissionSyncStatus.FAILED.value,
|
||||
"cc_pair_id": attempt.connector_credential_pair_id,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def complete_doc_permission_sync_attempt(
|
||||
db_session: Session,
|
||||
attempt_id: int,
|
||||
total_docs_synced: int,
|
||||
docs_with_permission_errors: int,
|
||||
) -> DocPermissionSyncAttempt:
|
||||
"""Complete a doc permission sync attempt by updating progress and setting final status.
|
||||
|
||||
This combines the progress update and final status marking into a single operation.
|
||||
If there were permission errors, the attempt is marked as COMPLETED_WITH_ERRORS,
|
||||
otherwise it's marked as SUCCESS.
|
||||
|
||||
Args:
|
||||
db_session: The database session
|
||||
attempt_id: The ID of the attempt
|
||||
total_docs_synced: Total number of documents synced
|
||||
docs_with_permission_errors: Number of documents that had permission errors
|
||||
|
||||
Returns:
|
||||
The completed attempt
|
||||
"""
|
||||
try:
|
||||
attempt = db_session.execute(
|
||||
select(DocPermissionSyncAttempt)
|
||||
.where(DocPermissionSyncAttempt.id == attempt_id)
|
||||
.with_for_update()
|
||||
).scalar_one()
|
||||
|
||||
# Update progress counters
|
||||
attempt.total_docs_synced = (attempt.total_docs_synced or 0) + total_docs_synced
|
||||
attempt.docs_with_permission_errors = (
|
||||
attempt.docs_with_permission_errors or 0
|
||||
) + docs_with_permission_errors
|
||||
|
||||
# Set final status based on whether there were errors
|
||||
if docs_with_permission_errors > 0:
|
||||
attempt.status = PermissionSyncStatus.COMPLETED_WITH_ERRORS
|
||||
else:
|
||||
attempt.status = PermissionSyncStatus.SUCCESS
|
||||
|
||||
attempt.time_finished = func.now() # type: ignore
|
||||
db_session.commit()
|
||||
|
||||
# Add telemetry
|
||||
optional_telemetry(
|
||||
record_type=RecordType.PERMISSION_SYNC_COMPLETE,
|
||||
data={
|
||||
"doc_permission_sync_attempt_id": attempt_id,
|
||||
"status": attempt.status.value,
|
||||
"cc_pair_id": attempt.connector_credential_pair_id,
|
||||
},
|
||||
)
|
||||
return attempt
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
logger.exception("complete_doc_permission_sync_attempt exceptioned.")
|
||||
raise
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# EXTERNAL GROUP PERMISSION SYNC ATTEMPT CRUD
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_external_group_sync_attempt(
|
||||
connector_credential_pair_id: int | None,
|
||||
db_session: Session,
|
||||
) -> int:
|
||||
"""Create a new external group sync attempt.
|
||||
|
||||
Args:
|
||||
connector_credential_pair_id: The ID of the connector credential pair, or None for global syncs
|
||||
db_session: The database session
|
||||
|
||||
Returns:
|
||||
The ID of the created attempt
|
||||
"""
|
||||
attempt = ExternalGroupPermissionSyncAttempt(
|
||||
connector_credential_pair_id=connector_credential_pair_id,
|
||||
status=PermissionSyncStatus.NOT_STARTED,
|
||||
)
|
||||
db_session.add(attempt)
|
||||
db_session.commit()
|
||||
|
||||
return attempt.id
|
||||
|
||||
|
||||
def get_external_group_sync_attempt(
|
||||
db_session: Session,
|
||||
attempt_id: int,
|
||||
eager_load_connector: bool = False,
|
||||
) -> ExternalGroupPermissionSyncAttempt | None:
|
||||
"""Get an external group sync attempt by ID.
|
||||
|
||||
Args:
|
||||
db_session: The database session
|
||||
attempt_id: The ID of the attempt
|
||||
eager_load_connector: If True, eagerly loads the connector and cc_pair relationships
|
||||
|
||||
Returns:
|
||||
The attempt if found, None otherwise
|
||||
"""
|
||||
stmt = select(ExternalGroupPermissionSyncAttempt).where(
|
||||
ExternalGroupPermissionSyncAttempt.id == attempt_id
|
||||
)
|
||||
|
||||
if eager_load_connector:
|
||||
stmt = stmt.options(
|
||||
joinedload(
|
||||
ExternalGroupPermissionSyncAttempt.connector_credential_pair
|
||||
).joinedload(ConnectorCredentialPair.connector)
|
||||
)
|
||||
|
||||
return db_session.scalars(stmt).first()
|
||||
|
||||
|
||||
def get_recent_external_group_sync_attempts_for_cc_pair(
|
||||
cc_pair_id: int | None,
|
||||
limit: int,
|
||||
db_session: Session,
|
||||
) -> list[ExternalGroupPermissionSyncAttempt]:
|
||||
"""Get recent external group sync attempts for a cc pair, most recent first.
|
||||
If cc_pair_id is None, gets global group sync attempts."""
|
||||
stmt = select(ExternalGroupPermissionSyncAttempt)
|
||||
|
||||
if cc_pair_id is not None:
|
||||
stmt = stmt.where(
|
||||
ExternalGroupPermissionSyncAttempt.connector_credential_pair_id
|
||||
== cc_pair_id
|
||||
)
|
||||
else:
|
||||
stmt = stmt.where(
|
||||
ExternalGroupPermissionSyncAttempt.connector_credential_pair_id.is_(None)
|
||||
)
|
||||
|
||||
return list(
|
||||
db_session.execute(
|
||||
stmt.order_by(ExternalGroupPermissionSyncAttempt.time_created.desc()).limit(
|
||||
limit
|
||||
)
|
||||
).scalars()
|
||||
)
|
||||
|
||||
|
||||
def mark_external_group_sync_attempt_in_progress(
|
||||
attempt_id: int,
|
||||
db_session: Session,
|
||||
) -> ExternalGroupPermissionSyncAttempt:
|
||||
"""Mark an external group sync attempt as IN_PROGRESS.
|
||||
Locks the row during update."""
|
||||
try:
|
||||
attempt = db_session.execute(
|
||||
select(ExternalGroupPermissionSyncAttempt)
|
||||
.where(ExternalGroupPermissionSyncAttempt.id == attempt_id)
|
||||
.with_for_update()
|
||||
).scalar_one()
|
||||
|
||||
if attempt.status != PermissionSyncStatus.NOT_STARTED:
|
||||
raise RuntimeError(
|
||||
f"External group sync attempt with ID '{attempt_id}' is not in NOT_STARTED status. "
|
||||
f"Current status is '{attempt.status}'."
|
||||
)
|
||||
|
||||
attempt.status = PermissionSyncStatus.IN_PROGRESS
|
||||
attempt.time_started = func.now() # type: ignore
|
||||
db_session.commit()
|
||||
return attempt
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
logger.exception("mark_external_group_sync_attempt_in_progress exceptioned.")
|
||||
raise
|
||||
|
||||
|
||||
def mark_external_group_sync_attempt_failed(
|
||||
attempt_id: int,
|
||||
db_session: Session,
|
||||
error_message: str,
|
||||
) -> None:
|
||||
"""Mark an external group sync attempt as failed."""
|
||||
try:
|
||||
attempt = db_session.execute(
|
||||
select(ExternalGroupPermissionSyncAttempt)
|
||||
.where(ExternalGroupPermissionSyncAttempt.id == attempt_id)
|
||||
.with_for_update()
|
||||
).scalar_one()
|
||||
|
||||
if not attempt.time_started:
|
||||
attempt.time_started = func.now() # type: ignore
|
||||
attempt.status = PermissionSyncStatus.FAILED
|
||||
attempt.time_finished = func.now() # type: ignore
|
||||
attempt.error_message = error_message
|
||||
db_session.commit()
|
||||
|
||||
# Add telemetry for permission sync attempt status change
|
||||
optional_telemetry(
|
||||
record_type=RecordType.PERMISSION_SYNC_COMPLETE,
|
||||
data={
|
||||
"external_group_sync_attempt_id": attempt_id,
|
||||
"status": PermissionSyncStatus.FAILED.value,
|
||||
"cc_pair_id": attempt.connector_credential_pair_id,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def complete_external_group_sync_attempt(
|
||||
db_session: Session,
|
||||
attempt_id: int,
|
||||
total_users_processed: int,
|
||||
total_groups_processed: int,
|
||||
total_group_memberships_synced: int,
|
||||
errors_encountered: int = 0,
|
||||
) -> ExternalGroupPermissionSyncAttempt:
|
||||
"""Complete an external group sync attempt by updating progress and setting final status.
|
||||
|
||||
This combines the progress update and final status marking into a single operation.
|
||||
If there were errors, the attempt is marked as COMPLETED_WITH_ERRORS,
|
||||
otherwise it's marked as SUCCESS.
|
||||
|
||||
Args:
|
||||
db_session: The database session
|
||||
attempt_id: The ID of the attempt
|
||||
total_users_processed: Total users processed
|
||||
total_groups_processed: Total groups processed
|
||||
total_group_memberships_synced: Total group memberships synced
|
||||
errors_encountered: Number of errors encountered (determines if COMPLETED_WITH_ERRORS)
|
||||
|
||||
Returns:
|
||||
The completed attempt
|
||||
"""
|
||||
try:
|
||||
attempt = db_session.execute(
|
||||
select(ExternalGroupPermissionSyncAttempt)
|
||||
.where(ExternalGroupPermissionSyncAttempt.id == attempt_id)
|
||||
.with_for_update()
|
||||
).scalar_one()
|
||||
|
||||
# Update progress counters
|
||||
attempt.total_users_processed = (
|
||||
attempt.total_users_processed or 0
|
||||
) + total_users_processed
|
||||
attempt.total_groups_processed = (
|
||||
attempt.total_groups_processed or 0
|
||||
) + total_groups_processed
|
||||
attempt.total_group_memberships_synced = (
|
||||
attempt.total_group_memberships_synced or 0
|
||||
) + total_group_memberships_synced
|
||||
|
||||
# Set final status based on whether there were errors
|
||||
if errors_encountered > 0:
|
||||
attempt.status = PermissionSyncStatus.COMPLETED_WITH_ERRORS
|
||||
else:
|
||||
attempt.status = PermissionSyncStatus.SUCCESS
|
||||
|
||||
attempt.time_finished = func.now() # type: ignore
|
||||
db_session.commit()
|
||||
|
||||
# Add telemetry
|
||||
optional_telemetry(
|
||||
record_type=RecordType.PERMISSION_SYNC_COMPLETE,
|
||||
data={
|
||||
"external_group_sync_attempt_id": attempt_id,
|
||||
"status": attempt.status.value,
|
||||
"cc_pair_id": attempt.connector_credential_pair_id,
|
||||
},
|
||||
)
|
||||
return attempt
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
logger.exception("complete_external_group_sync_attempt exceptioned.")
|
||||
raise
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DELETION FUNCTIONS
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def delete_doc_permission_sync_attempts__no_commit(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
) -> int:
|
||||
"""Delete all doc permission sync attempts for a connector credential pair.
|
||||
|
||||
This does not commit the transaction. It should be used within an existing transaction.
|
||||
|
||||
Args:
|
||||
db_session: The database session
|
||||
cc_pair_id: The connector credential pair ID
|
||||
|
||||
Returns:
|
||||
The number of attempts deleted
|
||||
"""
|
||||
stmt = delete(DocPermissionSyncAttempt).where(
|
||||
DocPermissionSyncAttempt.connector_credential_pair_id == cc_pair_id
|
||||
)
|
||||
result = cast(CursorResult[Any], db_session.execute(stmt))
|
||||
return result.rowcount or 0
|
||||
|
||||
|
||||
def delete_external_group_permission_sync_attempts__no_commit(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
) -> int:
|
||||
"""Delete all external group permission sync attempts for a connector credential pair.
|
||||
|
||||
This does not commit the transaction. It should be used within an existing transaction.
|
||||
|
||||
Args:
|
||||
db_session: The database session
|
||||
cc_pair_id: The connector credential pair ID
|
||||
|
||||
Returns:
|
||||
The number of attempts deleted
|
||||
"""
|
||||
stmt = delete(ExternalGroupPermissionSyncAttempt).where(
|
||||
ExternalGroupPermissionSyncAttempt.connector_credential_pair_id == cc_pair_id
|
||||
)
|
||||
result = cast(CursorResult[Any], db_session.execute(stmt))
|
||||
return result.rowcount or 0
|
||||
@@ -19,16 +19,23 @@ if TYPE_CHECKING:
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_tools(db_session: Session) -> list[Tool]:
|
||||
return list(db_session.scalars(select(Tool)).all())
|
||||
def get_tools(db_session: Session, *, only_enabled: bool = False) -> list[Tool]:
|
||||
query = select(Tool)
|
||||
if only_enabled:
|
||||
query = query.where(Tool.enabled.is_(True))
|
||||
return list(db_session.scalars(query).all())
|
||||
|
||||
|
||||
def get_tools_by_mcp_server_id(mcp_server_id: int, db_session: Session) -> list[Tool]:
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(Tool).where(Tool.mcp_server_id == mcp_server_id)
|
||||
).all()
|
||||
)
|
||||
def get_tools_by_mcp_server_id(
|
||||
mcp_server_id: int,
|
||||
db_session: Session,
|
||||
*,
|
||||
only_enabled: bool = False,
|
||||
) -> list[Tool]:
|
||||
query = select(Tool).where(Tool.mcp_server_id == mcp_server_id)
|
||||
if only_enabled:
|
||||
query = query.where(Tool.enabled.is_(True))
|
||||
return list(db_session.scalars(query).all())
|
||||
|
||||
|
||||
def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
|
||||
@@ -53,6 +60,9 @@ def create_tool__no_commit(
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
passthrough_auth: bool,
|
||||
*,
|
||||
mcp_server_id: int | None = None,
|
||||
enabled: bool = True,
|
||||
) -> Tool:
|
||||
new_tool = Tool(
|
||||
name=name,
|
||||
@@ -64,6 +74,8 @@ def create_tool__no_commit(
|
||||
),
|
||||
user_id=user_id,
|
||||
passthrough_auth=passthrough_auth,
|
||||
mcp_server_id=mcp_server_id,
|
||||
enabled=enabled,
|
||||
)
|
||||
db_session.add(new_tool)
|
||||
db_session.flush() # Don't commit yet, let caller decide when to commit
|
||||
|
||||
@@ -1062,7 +1062,7 @@ class VespaIndex(DocumentIndex):
|
||||
*,
|
||||
tenant_id: str,
|
||||
index_name: str,
|
||||
) -> None:
|
||||
) -> int:
|
||||
"""
|
||||
Deletes all entries in the specified index with the given tenant_id.
|
||||
|
||||
@@ -1072,6 +1072,9 @@ class VespaIndex(DocumentIndex):
|
||||
Parameters:
|
||||
tenant_id (str): The tenant ID whose documents are to be deleted.
|
||||
index_name (str): The name of the index from which to delete documents.
|
||||
|
||||
Returns:
|
||||
int: The number of documents deleted.
|
||||
"""
|
||||
logger.info(
|
||||
f"Deleting entries with tenant_id: {tenant_id} from index: {index_name}"
|
||||
@@ -1084,7 +1087,7 @@ class VespaIndex(DocumentIndex):
|
||||
logger.info(
|
||||
f"No documents found with tenant_id: {tenant_id} in index: {index_name}"
|
||||
)
|
||||
return
|
||||
return 0
|
||||
|
||||
# Step 2: Delete documents in batches
|
||||
delete_requests = [
|
||||
@@ -1093,6 +1096,7 @@ class VespaIndex(DocumentIndex):
|
||||
]
|
||||
|
||||
cls._apply_deletes_batched(delete_requests)
|
||||
return len(document_ids)
|
||||
|
||||
@classmethod
|
||||
def _get_all_document_ids_by_tenant_id(
|
||||
|
||||
@@ -14,7 +14,7 @@ The evaluation system uses [Braintrust](https://www.braintrust.dev/) to run auto
|
||||
|
||||
Kick off a remote job
|
||||
```bash
|
||||
onyx/backend$ python onyx/evals/eval_cli.py --remote --api-key <SUPER_CLOUD_USER_API_KEY> --search-permissions-email <email account to reference> --remote --remote-dataset-name Simple
|
||||
onyx/backend$ python -m dotenv -f .vscode/.env run -- python onyx/evals/eval_cli.py --remote --api-key <SUPER_CLOUD_USER_API_KEY> --search-permissions-email <email account to reference> --remote --remote-dataset-name Simple
|
||||
```
|
||||
|
||||
You can also run the CLI directly from the command line:
|
||||
|
||||
@@ -42,6 +42,7 @@ def run_local(
|
||||
local_data_path: str | None,
|
||||
remote_dataset_name: str | None,
|
||||
search_permissions_email: str | None = None,
|
||||
no_send_logs: bool = False,
|
||||
) -> EvalationAck:
|
||||
"""
|
||||
Run evaluation with local configurations.
|
||||
@@ -63,6 +64,7 @@ def run_local(
|
||||
configuration = EvalConfigurationOptions(
|
||||
search_permissions_email=search_permissions_email,
|
||||
dataset_name=remote_dataset_name or "blank",
|
||||
no_send_logs=no_send_logs,
|
||||
)
|
||||
|
||||
if remote_dataset_name:
|
||||
@@ -172,6 +174,13 @@ def main() -> None:
|
||||
help="Email address to impersonate for the evaluation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-send-logs",
|
||||
action="store_true",
|
||||
help="Do not send logs to the remote server",
|
||||
default=False,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.local_data_path:
|
||||
@@ -215,6 +224,7 @@ def main() -> None:
|
||||
local_data_path=args.local_data_path,
|
||||
remote_dataset_name=args.remote_dataset_name,
|
||||
search_permissions_email=args.search_permissions_email,
|
||||
no_send_logs=args.no_send_logs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ class EvalConfigurationOptions(BaseModel):
|
||||
)
|
||||
search_permissions_email: str
|
||||
dataset_name: str
|
||||
no_send_logs: bool = False
|
||||
|
||||
def get_configuration(self, db_session: Session) -> EvalConfiguration:
|
||||
persona_override_config = self.persona_override_config or PersonaOverrideConfig(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from braintrust import Eval
|
||||
from braintrust import EvalCase
|
||||
@@ -23,33 +24,21 @@ class BraintrustEvalProvider(EvalProvider):
|
||||
raise ValueError("Cannot specify both data and remote_dataset_name")
|
||||
if data is None and remote_dataset_name is None:
|
||||
raise ValueError("Must specify either data or remote_dataset_name")
|
||||
|
||||
eval_data: Any = None
|
||||
if remote_dataset_name is not None:
|
||||
eval_data = init_dataset(
|
||||
project=BRAINTRUST_PROJECT, name=remote_dataset_name
|
||||
)
|
||||
Eval(
|
||||
name=BRAINTRUST_PROJECT,
|
||||
data=eval_data,
|
||||
task=task,
|
||||
scores=[],
|
||||
metadata={**configuration.model_dump()},
|
||||
max_concurrency=BRAINTRUST_MAX_CONCURRENCY,
|
||||
)
|
||||
else:
|
||||
if data is None:
|
||||
raise ValueError(
|
||||
"Must specify data when remote_dataset_name is not specified"
|
||||
)
|
||||
eval_cases: list[EvalCase[dict[str, str], str]] = [
|
||||
EvalCase(input=item["input"]) for item in data
|
||||
]
|
||||
Eval(
|
||||
name=BRAINTRUST_PROJECT,
|
||||
data=eval_cases,
|
||||
task=task,
|
||||
scores=[],
|
||||
metadata={**configuration.model_dump()},
|
||||
max_concurrency=BRAINTRUST_MAX_CONCURRENCY,
|
||||
)
|
||||
if data:
|
||||
eval_data = [EvalCase(input=item["input"]) for item in data]
|
||||
Eval(
|
||||
name=BRAINTRUST_PROJECT,
|
||||
data=eval_data, # type: ignore[arg-type]
|
||||
task=task,
|
||||
scores=[],
|
||||
metadata={**configuration.model_dump()},
|
||||
max_concurrency=BRAINTRUST_MAX_CONCURRENCY,
|
||||
no_send_logs=configuration.no_send_logs,
|
||||
)
|
||||
return EvalationAck(success=True)
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import braintrust
|
||||
from agents import set_trace_processors
|
||||
from braintrust.wrappers.openai import BraintrustTracingProcessor
|
||||
from braintrust_langchain import set_global_handler
|
||||
from braintrust_langchain.callbacks import BraintrustCallbackHandler
|
||||
|
||||
from onyx.configs.app_configs import BRAINTRUST_API_KEY
|
||||
from onyx.configs.app_configs import BRAINTRUST_PROJECT
|
||||
|
||||
MASKING_LENGTH = 20000
|
||||
MASKING_LENGTH = int(os.environ.get("BRAINTRUST_MASKING_LENGTH", "20000"))
|
||||
|
||||
|
||||
def _truncate_str(s: str) -> str:
|
||||
@@ -26,10 +29,11 @@ def _mask(data: Any) -> Any:
|
||||
def setup_braintrust() -> None:
|
||||
"""Initialize Braintrust logger and set up global callback handler."""
|
||||
|
||||
braintrust.init_logger(
|
||||
logger = braintrust.init_logger(
|
||||
project=BRAINTRUST_PROJECT,
|
||||
api_key=BRAINTRUST_API_KEY,
|
||||
)
|
||||
braintrust.set_masking_function(_mask)
|
||||
handler = BraintrustCallbackHandler()
|
||||
set_global_handler(handler)
|
||||
set_trace_processors([BraintrustTracingProcessor(logger)])
|
||||
|
||||
0
backend/onyx/feature_flags/__init__.py
Normal file
0
backend/onyx/feature_flags/__init__.py
Normal file
28
backend/onyx/feature_flags/factory.py
Normal file
28
backend/onyx/feature_flags/factory.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from onyx.feature_flags.interface import FeatureFlagProvider
|
||||
from onyx.feature_flags.interface import NoOpFeatureFlagProvider
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
def get_default_feature_flag_provider() -> FeatureFlagProvider:
|
||||
"""
|
||||
Get the default feature flag provider implementation.
|
||||
|
||||
Returns the PostHog-based provider in Enterprise Edition when available,
|
||||
otherwise returns a no-op provider that always returns False.
|
||||
|
||||
This function is designed for dependency injection - callers should
|
||||
use this factory rather than directly instantiating providers.
|
||||
|
||||
Returns:
|
||||
FeatureFlagProvider: The configured feature flag provider instance
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
return fetch_versioned_implementation_with_fallback(
|
||||
module="onyx.feature_flags.factory",
|
||||
attribute="get_posthog_feature_flag_provider",
|
||||
fallback=lambda: NoOpFeatureFlagProvider(),
|
||||
)()
|
||||
return NoOpFeatureFlagProvider()
|
||||
6
backend/onyx/feature_flags/feature_flags_keys.py
Normal file
6
backend/onyx/feature_flags/feature_flags_keys.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Feature flag keys used throughout the application.
|
||||
Centralizes feature flag key definitions to avoid magic strings.
|
||||
"""
|
||||
|
||||
SIMPLE_AGENT_FRAMEWORK = "simple-agent-framework"
|
||||
0
backend/onyx/feature_flags/flags.py
Normal file
0
backend/onyx/feature_flags/flags.py
Normal file
72
backend/onyx/feature_flags/interface.py
Normal file
72
backend/onyx/feature_flags/interface.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import abc
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.db.models import User
|
||||
from shared_configs.configs import ENVIRONMENT
|
||||
|
||||
|
||||
class FeatureFlagProvider(abc.ABC):
|
||||
"""
|
||||
Abstract base class for feature flag providers.
|
||||
|
||||
Implementations should provide vendor-specific logic for checking
|
||||
whether a feature flag is enabled for a given user.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def feature_enabled(
|
||||
self,
|
||||
flag_key: str,
|
||||
user_id: UUID,
|
||||
user_properties: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a feature flag is enabled for a user.
|
||||
|
||||
Args:
|
||||
flag_key: The identifier for the feature flag to check
|
||||
user_id: The unique identifier for the user
|
||||
user_properties: Optional dictionary of user properties/attributes
|
||||
that may influence flag evaluation
|
||||
|
||||
Returns:
|
||||
True if the feature is enabled for the user, False otherwise
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def feature_enabled_for_user_tenant(
|
||||
self, flag_key: str, user: User | None, tenant_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a feature flag is enabled for a user.
|
||||
"""
|
||||
return self.feature_enabled(
|
||||
flag_key,
|
||||
# For local dev with AUTH_TYPE=disabled, we don't have a user, so we use a random UUID
|
||||
user.id if user else UUID("caa1e0cd-6ee6-4550-b1ec-8affaef4bf83"),
|
||||
user_properties={
|
||||
"tenant_id": tenant_id,
|
||||
"email": user.email if user else "anonymous@onyx.app",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class NoOpFeatureFlagProvider(FeatureFlagProvider):
|
||||
"""
|
||||
No-operation feature flag provider that always returns False.
|
||||
|
||||
Used as a fallback when no real feature flag provider is available
|
||||
(e.g., in MIT version without PostHog).
|
||||
"""
|
||||
|
||||
def feature_enabled(
|
||||
self,
|
||||
flag_key: str,
|
||||
user_id: UUID,
|
||||
user_properties: dict[str, Any] | None = None,
|
||||
) -> bool:
|
||||
environment = ENVIRONMENT
|
||||
if environment == "local":
|
||||
return True
|
||||
return False
|
||||
@@ -5,8 +5,8 @@ All other modules should import litellm from here instead of directly.
|
||||
"""
|
||||
|
||||
import litellm
|
||||
from agents.extensions.models.litellm_model import LitellmModel
|
||||
|
||||
from onyx.configs.app_configs import BRAINTRUST_ENABLED
|
||||
|
||||
# Import litellm
|
||||
|
||||
@@ -16,8 +16,5 @@ from onyx.configs.app_configs import BRAINTRUST_ENABLED
|
||||
litellm.drop_params = True
|
||||
litellm.telemetry = False
|
||||
|
||||
if BRAINTRUST_ENABLED:
|
||||
litellm.callbacks = ["braintrust"]
|
||||
|
||||
# Export the configured litellm module
|
||||
__all__ = ["litellm"]
|
||||
__all__ = ["litellm", "LitellmModel"]
|
||||
|
||||
@@ -57,6 +57,22 @@ class PreviousMessage(BaseModel):
|
||||
research_answer_purpose=chat_message.research_answer_purpose,
|
||||
)
|
||||
|
||||
def to_agent_sdk_msg(self) -> dict:
|
||||
message_type_to_agent_sdk_role = {
|
||||
MessageType.USER: "user",
|
||||
MessageType.SYSTEM: "system",
|
||||
MessageType.ASSISTANT: "assistant",
|
||||
}
|
||||
# TODO: Use native format for files and images
|
||||
content = build_content_with_imgs(self.message, self.files)
|
||||
if self.message_type in message_type_to_agent_sdk_role:
|
||||
role = message_type_to_agent_sdk_role[self.message_type]
|
||||
return {
|
||||
"role": role,
|
||||
"content": content,
|
||||
}
|
||||
raise ValueError(f"Unknown message type: {self.message_type}")
|
||||
|
||||
def to_langchain_msg(self) -> BaseMessage:
|
||||
content = build_content_with_imgs(self.message, self.files)
|
||||
if self.message_type == MessageType.USER:
|
||||
@@ -66,6 +82,7 @@ class PreviousMessage(BaseModel):
|
||||
else:
|
||||
return SystemMessage(content=content)
|
||||
|
||||
# TODO: deprecate langchain
|
||||
@classmethod
|
||||
def from_langchain_msg(
|
||||
cls, msg: BaseMessage, token_count: int
|
||||
|
||||
@@ -10,6 +10,13 @@ Avoid using double brackets like [[1]]. To cite multiple documents, use [1], [2]
|
||||
Try to cite inline as opposed to leaving all citations until the very end of the response.
|
||||
""".rstrip()
|
||||
|
||||
REQUIRE_CITATION_STATEMENT_V2 = """
|
||||
Cite relevant statements INLINE using the format [[1]](https://example.com) with the document number (an integer) in between
|
||||
the brackets. To cite multiple documents, use [[1]](https://example.com), [[2]](https://example.com) format instead of \
|
||||
[[1, 2]](https://example.com). \
|
||||
Try to cite inline as opposed to leaving all citations until the very end of the response.
|
||||
""".rstrip()
|
||||
|
||||
NO_CITATION_STATEMENT = """
|
||||
Do not provide any citations even if there are examples in the chat history.
|
||||
""".rstrip()
|
||||
|
||||
@@ -54,6 +54,12 @@ CONVERSATION HISTORY:
|
||||
{GENERAL_SEP_PAT}
|
||||
"""
|
||||
|
||||
COMPANY_NAME_BLOCK = """
|
||||
The user works at {company_name}.
|
||||
"""
|
||||
|
||||
COMPANY_DESCRIPTION_BLOCK = """Organization description: {company_description}
|
||||
"""
|
||||
|
||||
# This has to be doubly escaped due to json containing { } which are also used for format strings
|
||||
EMPTY_SAMPLE_JSON = {
|
||||
|
||||
@@ -13,6 +13,9 @@ from onyx.db.models import Persona
|
||||
from onyx.prompts.chat_prompts import ADDITIONAL_INFO
|
||||
from onyx.prompts.chat_prompts import CITATION_REMINDER
|
||||
from onyx.prompts.constants import CODE_BLOCK_PAT
|
||||
from onyx.prompts.direct_qa_prompts import COMPANY_DESCRIPTION_BLOCK
|
||||
from onyx.prompts.direct_qa_prompts import COMPANY_NAME_BLOCK
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -89,6 +92,23 @@ def handle_onyx_date_awareness(
|
||||
return prompt_str
|
||||
|
||||
|
||||
def handle_company_awareness(prompt_str: str) -> str:
|
||||
try:
|
||||
workspace_settings = load_settings()
|
||||
company_name = workspace_settings.company_name
|
||||
company_description = workspace_settings.company_description
|
||||
if company_name:
|
||||
prompt_str += COMPANY_NAME_BLOCK.format(company_name=company_name)
|
||||
if company_description:
|
||||
prompt_str += COMPANY_DESCRIPTION_BLOCK.format(
|
||||
company_description=company_description
|
||||
)
|
||||
return prompt_str
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling company awareness: {e}")
|
||||
return prompt_str
|
||||
|
||||
|
||||
def build_task_prompt_reminders(
|
||||
prompt: Persona | PromptConfig,
|
||||
use_language_hint: bool,
|
||||
|
||||
@@ -3,6 +3,7 @@ from datetime import datetime
|
||||
from logging import Logger
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import NamedTuple
|
||||
|
||||
import redis
|
||||
from pydantic import BaseModel
|
||||
@@ -16,6 +17,18 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
class PermissionSyncResult(NamedTuple):
|
||||
"""Result of a permission sync operation.
|
||||
|
||||
Attributes:
|
||||
num_updated: Number of documents successfully updated
|
||||
num_errors: Number of documents that failed to update
|
||||
"""
|
||||
|
||||
num_updated: int
|
||||
num_errors: int
|
||||
|
||||
|
||||
class RedisConnectorPermissionSyncPayload(BaseModel):
|
||||
id: str
|
||||
submitted: datetime
|
||||
@@ -159,7 +172,12 @@ class RedisConnectorPermissionSync:
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
task_logger: Logger | None = None,
|
||||
) -> int | None:
|
||||
) -> PermissionSyncResult:
|
||||
"""Update permissions for documents.
|
||||
|
||||
Returns:
|
||||
PermissionSyncResult containing counts of successful updates and errors
|
||||
"""
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
document_update_permissions_fn = fetch_versioned_implementation(
|
||||
@@ -168,6 +186,7 @@ class RedisConnectorPermissionSync:
|
||||
)
|
||||
|
||||
num_permissions = 0
|
||||
num_errors = 0
|
||||
# Create a task for each document permission sync
|
||||
for permissions in new_permissions:
|
||||
current_time = time.monotonic()
|
||||
@@ -201,14 +220,25 @@ class RedisConnectorPermissionSync:
|
||||
# a rare enough case to be acceptable.
|
||||
|
||||
# This can internally exception due to db issues but still continue
|
||||
# we may want to change this
|
||||
document_update_permissions_fn(
|
||||
self.tenant_id, permissions, source_string, connector_id, credential_id
|
||||
)
|
||||
# Catch exceptions per-document to avoid breaking the entire sync
|
||||
try:
|
||||
document_update_permissions_fn(
|
||||
self.tenant_id,
|
||||
permissions,
|
||||
source_string,
|
||||
connector_id,
|
||||
credential_id,
|
||||
)
|
||||
num_permissions += 1
|
||||
except Exception:
|
||||
num_errors += 1
|
||||
if task_logger:
|
||||
task_logger.exception(
|
||||
f"Failed to update permissions for document {permissions.doc_id}"
|
||||
)
|
||||
# Continue processing other documents
|
||||
|
||||
num_permissions += 1
|
||||
|
||||
return num_permissions
|
||||
return PermissionSyncResult(num_updated=num_permissions, num_errors=num_errors)
|
||||
|
||||
def reset(self) -> None:
|
||||
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
|
||||
@@ -91,7 +91,7 @@ from onyx.db.credentials import create_credential
|
||||
from onyx.db.credentials import delete_service_account_credentials
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from onyx.db.document import get_document_counts_for_cc_pairs_batched_parallel
|
||||
from onyx.db.document import get_document_counts_for_all_cc_pairs
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -795,15 +795,7 @@ def get_connector_indexing_status(
|
||||
list[IndexAttempt], latest_finished_index_attempts
|
||||
)
|
||||
|
||||
document_count_info = get_document_counts_for_cc_pairs_batched_parallel(
|
||||
cc_pairs=[
|
||||
ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
for cc_pair in non_editable_cc_pairs + editable_cc_pairs
|
||||
]
|
||||
)
|
||||
document_count_info = get_document_counts_for_all_cc_pairs(db_session)
|
||||
|
||||
# Create lookup dictionaries for efficient access
|
||||
cc_pair_to_document_cnt: dict[tuple[int, int], int] = {
|
||||
|
||||
@@ -22,6 +22,7 @@ from mcp.shared.auth import OAuthClientInformationFull
|
||||
from mcp.shared.auth import OAuthClientMetadata
|
||||
from mcp.shared.auth import OAuthToken
|
||||
from mcp.types import InitializeResult
|
||||
from mcp.types import Tool as MCPLibTool
|
||||
from pydantic import AnyUrl
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -50,6 +51,7 @@ from onyx.db.mcp import update_mcp_server__no_commit
|
||||
from onyx.db.mcp import upsert_user_connection_config
|
||||
from onyx.db.models import MCPConnectionConfig
|
||||
from onyx.db.models import MCPServer as DbMCPServer
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.tools import create_tool__no_commit
|
||||
from onyx.db.tools import delete_tool__no_commit
|
||||
@@ -1044,6 +1046,55 @@ def user_list_mcp_tools_by_id(
|
||||
return _list_mcp_tools_by_id(server_id, db, False, user)
|
||||
|
||||
|
||||
def _upsert_db_tools(
|
||||
discovered_tools: list[MCPLibTool],
|
||||
existing_by_name: dict[str, Tool],
|
||||
processed_names: set[str],
|
||||
mcp_server_id: int,
|
||||
db: Session,
|
||||
) -> bool:
|
||||
db_dirty = False
|
||||
|
||||
for tool in discovered_tools:
|
||||
tool_name = tool.name
|
||||
if not tool_name:
|
||||
continue
|
||||
|
||||
processed_names.add(tool_name)
|
||||
description = tool.description or ""
|
||||
annotations_title = tool.annotations.title if tool.annotations else None
|
||||
display_name = tool.title or annotations_title or tool_name
|
||||
input_schema = tool.inputSchema
|
||||
|
||||
if existing_tool := existing_by_name.get(tool_name):
|
||||
if existing_tool.description != description:
|
||||
existing_tool.description = description
|
||||
db_dirty = True
|
||||
if existing_tool.display_name != display_name:
|
||||
existing_tool.display_name = display_name
|
||||
db_dirty = True
|
||||
if existing_tool.mcp_input_schema != input_schema:
|
||||
existing_tool.mcp_input_schema = input_schema
|
||||
db_dirty = True
|
||||
continue
|
||||
|
||||
new_tool = create_tool__no_commit(
|
||||
name=tool_name,
|
||||
description=description,
|
||||
openapi_schema=None,
|
||||
custom_headers=None,
|
||||
user_id=None,
|
||||
db_session=db,
|
||||
passthrough_auth=False,
|
||||
mcp_server_id=mcp_server_id,
|
||||
enabled=False,
|
||||
)
|
||||
new_tool.display_name = display_name
|
||||
new_tool.mcp_input_schema = input_schema
|
||||
db_dirty = True
|
||||
return db_dirty
|
||||
|
||||
|
||||
def _list_mcp_tools_by_id(
|
||||
server_id: int,
|
||||
db: Session,
|
||||
@@ -1090,18 +1141,35 @@ def _list_mcp_tools_by_id(
|
||||
logger.info(f"Discovering tools for MCP server: {mcp_server.name}: {t1}")
|
||||
# Normalize URL to include trailing slash to avoid redirect/slow path handling
|
||||
server_url = mcp_server.server_url.rstrip("/") + "/"
|
||||
tools = discover_mcp_tools(
|
||||
discovered_tools = discover_mcp_tools(
|
||||
server_url,
|
||||
connection_config.config.get("headers", {}) if connection_config else {},
|
||||
transport=mcp_server.transport,
|
||||
auth=auth,
|
||||
)
|
||||
logger.info(
|
||||
f"Discovered {len(tools)} tools for MCP server: {mcp_server.name}: {time.time() - t1}"
|
||||
f"Discovered {len(discovered_tools)} tools for MCP server: {mcp_server.name}: {time.time() - t1}"
|
||||
)
|
||||
|
||||
if is_admin:
|
||||
existing_tools = get_tools_by_mcp_server_id(mcp_server.id, db)
|
||||
existing_by_name = {db_tool.name: db_tool for db_tool in existing_tools}
|
||||
processed_names: set[str] = set()
|
||||
|
||||
db_dirty = _upsert_db_tools(
|
||||
discovered_tools, existing_by_name, processed_names, mcp_server.id, db
|
||||
)
|
||||
|
||||
for name, db_tool in existing_by_name.items():
|
||||
if name not in processed_names:
|
||||
delete_tool__no_commit(db_tool.id, db)
|
||||
db_dirty = True
|
||||
|
||||
if db_dirty:
|
||||
db.commit()
|
||||
|
||||
# Truncate tool descriptions to prevent overly long responses
|
||||
for tool in tools:
|
||||
for tool in discovered_tools:
|
||||
if tool.description:
|
||||
tool.description = _truncate_description(tool.description)
|
||||
|
||||
@@ -1112,7 +1180,7 @@ def _list_mcp_tools_by_id(
|
||||
server_id=server_id,
|
||||
server_name=mcp_server.name,
|
||||
server_url=mcp_server.server_url,
|
||||
tools=tools,
|
||||
tools=discovered_tools,
|
||||
)
|
||||
|
||||
|
||||
@@ -1324,70 +1392,31 @@ def _upsert_mcp_server(
|
||||
return mcp_server
|
||||
|
||||
|
||||
def _add_tools_to_server(
|
||||
def _sync_tools_for_server(
|
||||
mcp_server: DbMCPServer,
|
||||
selected_tools: list[str],
|
||||
keep_tool_names: set[str],
|
||||
user: User | None,
|
||||
selected_tools: set[str],
|
||||
db_session: Session,
|
||||
) -> int:
|
||||
created_tools = 0
|
||||
# First, discover available tools from the server to get full definitions
|
||||
"""Toggle enabled state for MCP tools that exist for the server.
|
||||
Updates to the db model of a tool all happen when the user Lists Tools.
|
||||
This ensures that the the tools added to the db match what the user sees in the UI,
|
||||
even if the underlying tool has changed on the server after list tools is called.
|
||||
That's a corner case anyways; the admin should go back and update the server by re-listing tools.
|
||||
"""
|
||||
|
||||
connection_config = _get_connection_config(mcp_server, True, user, db_session)
|
||||
headers = connection_config.config.get("headers", {}) if connection_config else {}
|
||||
updated_tools = 0
|
||||
|
||||
auth = None
|
||||
if mcp_server.auth_type == MCPAuthenticationType.OAUTH:
|
||||
user_id = str(user.id) if user else ""
|
||||
assert connection_config
|
||||
auth = make_oauth_provider(
|
||||
mcp_server,
|
||||
user_id,
|
||||
UNUSED_RETURN_PATH,
|
||||
connection_config.id,
|
||||
mcp_server.admin_connection_config_id,
|
||||
)
|
||||
available_tools = discover_mcp_tools(
|
||||
mcp_server.server_url,
|
||||
headers,
|
||||
transport=mcp_server.transport,
|
||||
auth=auth,
|
||||
)
|
||||
tools_by_name = {tool.name: tool for tool in available_tools}
|
||||
existing_tools = get_tools_by_mcp_server_id(mcp_server.id, db_session)
|
||||
existing_by_name = {tool.name: tool for tool in existing_tools}
|
||||
|
||||
for tool_name in selected_tools:
|
||||
if tool_name not in tools_by_name:
|
||||
logger.warning(f"Tool '{tool_name}' not found in MCP server")
|
||||
continue
|
||||
# Disable any existing tools that were not processed above
|
||||
for tool_name, db_tool in existing_by_name.items():
|
||||
should_enable = tool_name in selected_tools
|
||||
if db_tool.enabled != should_enable:
|
||||
db_tool.enabled = should_enable
|
||||
updated_tools += 1
|
||||
|
||||
if tool_name in keep_tool_names:
|
||||
# tool was not deleted earlier and not added now
|
||||
continue
|
||||
|
||||
tool_def = tools_by_name[tool_name]
|
||||
|
||||
# Create Tool entry for each selected tool
|
||||
tool = create_tool__no_commit(
|
||||
name=tool_name,
|
||||
description=_truncate_description(tool_def.description),
|
||||
openapi_schema=None, # MCP tools don't use OpenAPI
|
||||
custom_headers=None,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
passthrough_auth=False,
|
||||
)
|
||||
|
||||
# Update the tool with MCP server ID, display name, and input schema
|
||||
tool.mcp_server_id = mcp_server.id
|
||||
annotations_title = tool_def.annotations.title if tool_def.annotations else None
|
||||
tool.display_name = tool_def.title or annotations_title or tool_name
|
||||
tool.mcp_input_schema = tool_def.inputSchema
|
||||
|
||||
created_tools += 1
|
||||
|
||||
logger.info(f"Created MCP tool '{tool.name}' with ID {tool.id}")
|
||||
return created_tools
|
||||
return updated_tools
|
||||
|
||||
|
||||
@admin_router.get("/servers/{server_id}", response_model=MCPServer)
|
||||
@@ -1560,27 +1589,12 @@ def update_mcp_server_with_tools(
|
||||
status_code=400, detail="MCP server has no admin connection config"
|
||||
)
|
||||
|
||||
# Cleanup: Delete tools for this server that are not in the selected_tools list
|
||||
selected_names = set(request.selected_tools or [])
|
||||
existing_tools = get_tools_by_mcp_server_id(request.server_id, db_session)
|
||||
keep_tool_names = set()
|
||||
updated_tools = 0
|
||||
for tool in existing_tools:
|
||||
if tool.name in selected_names:
|
||||
keep_tool_names.add(tool.name)
|
||||
else:
|
||||
delete_tool__no_commit(tool.id, db_session)
|
||||
updated_tools += 1
|
||||
# If selected_tools is provided, create individual tools for each
|
||||
|
||||
if request.selected_tools:
|
||||
updated_tools += _add_tools_to_server(
|
||||
mcp_server,
|
||||
request.selected_tools,
|
||||
keep_tool_names,
|
||||
user,
|
||||
db_session,
|
||||
)
|
||||
updated_tools = _sync_tools_for_server(
|
||||
mcp_server,
|
||||
selected_names,
|
||||
db_session,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -64,6 +64,7 @@ def create_custom_tool(
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
passthrough_auth=tool_data.passthrough_auth,
|
||||
enabled=True,
|
||||
)
|
||||
db_session.commit()
|
||||
return ToolSnapshot.from_model(tool)
|
||||
@@ -147,7 +148,7 @@ def list_tools(
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_user),
|
||||
) -> list[ToolSnapshot]:
|
||||
tools = get_tools(db_session)
|
||||
tools = get_tools(db_session, only_enabled=True)
|
||||
|
||||
filtered_tools: list[ToolSnapshot] = []
|
||||
for tool in tools:
|
||||
|
||||
@@ -15,6 +15,7 @@ from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
@@ -25,6 +26,7 @@ from onyx.chat.process_message import stream_chat_message
|
||||
from onyx.chat.prompt_builder.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.chat_configs import HARD_DELETE_CHATS
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -59,6 +61,7 @@ from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.secondary_llm_flows.chat_session_naming import (
|
||||
get_renamed_conversation_name,
|
||||
)
|
||||
@@ -816,3 +819,17 @@ async def search_chats(
|
||||
has_more=has_more,
|
||||
next_page=page + 1 if has_more else None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/stop-chat-session/{chat_session_id}")
|
||||
def stop_chat_session(
|
||||
chat_session_id: UUID,
|
||||
user: User | None = Depends(current_user),
|
||||
redis_client: Redis = Depends(get_redis_client),
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Stop a chat session by setting a stop signal in Redis.
|
||||
This endpoint is called by the frontend when the user clicks the stop button.
|
||||
"""
|
||||
set_fence(chat_session_id, redis_client, True)
|
||||
return {"message": "Chat session stopped"}
|
||||
|
||||
@@ -6,6 +6,7 @@ from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import RetrievalDocs
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -255,6 +256,7 @@ class ChatMessageDetail(BaseModel):
|
||||
rephrased_query: str | None = None
|
||||
context_docs: RetrievalDocs | None = None
|
||||
message_type: MessageType
|
||||
research_type: ResearchType | None = None
|
||||
time_sent: datetime
|
||||
overridden_model: str | None
|
||||
alternate_assistant_id: int | None = None
|
||||
|
||||
@@ -36,6 +36,12 @@ class MessageDelta(BaseObj):
|
||||
"""Control Packets"""
|
||||
|
||||
|
||||
class PacketException(BaseObj):
|
||||
type: Literal["error"] = "error"
|
||||
exception: Exception
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
|
||||
class OverallStop(BaseObj):
|
||||
type: Literal["stop"] = "stop"
|
||||
|
||||
@@ -56,8 +62,14 @@ class SearchToolStart(BaseObj):
|
||||
class SearchToolDelta(BaseObj):
|
||||
type: Literal["internal_search_tool_delta"] = "internal_search_tool_delta"
|
||||
|
||||
queries: list[str] | None = None
|
||||
documents: list[SavedSearchDoc] | None = None
|
||||
queries: list[str]
|
||||
documents: list[SavedSearchDoc]
|
||||
|
||||
|
||||
class FetchToolStart(BaseObj):
|
||||
type: Literal["fetch_tool_start"] = "fetch_tool_start"
|
||||
|
||||
documents: list[SavedSearchDoc]
|
||||
|
||||
|
||||
class ImageGenerationToolStart(BaseObj):
|
||||
@@ -182,6 +194,8 @@ PacketObj = Annotated[
|
||||
ReasoningDelta,
|
||||
CitationStart,
|
||||
CitationDelta,
|
||||
PacketException,
|
||||
FetchToolStart,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
@@ -17,12 +17,15 @@ from onyx.db.chat import get_db_search_doc_by_id
|
||||
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.tools import get_tool_by_id
|
||||
from onyx.feature_flags.factory import get_default_feature_flag_provider
|
||||
from onyx.feature_flags.feature_flags_keys import SIMPLE_AGENT_FRAMEWORK
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import EndStepPacketList
|
||||
from onyx.server.query_and_chat.streaming_models import FetchToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
@@ -45,6 +48,7 @@ from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
_CANNOT_SHOW_STEP_RESULTS_STR = "[Cannot display step results]"
|
||||
@@ -213,9 +217,20 @@ def create_custom_tool_packets(
|
||||
return packets
|
||||
|
||||
|
||||
def create_fetch_packets(
|
||||
fetches: list[list[SavedSearchDoc]], step_nr: int
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
for fetch in fetches:
|
||||
packets.append(Packet(ind=step_nr, obj=FetchToolStart(documents=fetch)))
|
||||
packets.append(Packet(ind=step_nr, obj=SectionEnd(type="section_end")))
|
||||
step_nr += 1
|
||||
return packets
|
||||
|
||||
|
||||
def create_search_packets(
|
||||
search_queries: list[str],
|
||||
saved_search_docs: list[SavedSearchDoc] | None,
|
||||
saved_search_docs: list[SavedSearchDoc],
|
||||
is_internet_search: bool,
|
||||
step_nr: int,
|
||||
) -> list[Packet]:
|
||||
@@ -245,12 +260,253 @@ def create_search_packets(
|
||||
return packets
|
||||
|
||||
|
||||
def translate_db_message_to_packets_simple(
|
||||
chat_message: ChatMessage,
|
||||
db_session: Session,
|
||||
remove_doc_content: bool = False,
|
||||
start_step_nr: int = 1,
|
||||
) -> EndStepPacketList:
|
||||
"""
|
||||
Translation function for simple agent framework (ResearchType.FAST).
|
||||
Includes support for FetchToolStart packets for web fetch operations.
|
||||
"""
|
||||
step_nr = start_step_nr
|
||||
packet_list: list[Packet] = []
|
||||
|
||||
if chat_message.message_type == MessageType.ASSISTANT:
|
||||
citations = chat_message.citations
|
||||
|
||||
citation_info_list: list[CitationInfo] = []
|
||||
if citations:
|
||||
for citation_num, search_doc_id in citations.items():
|
||||
search_doc = get_db_search_doc_by_id(search_doc_id, db_session)
|
||||
if search_doc:
|
||||
citation_info_list.append(
|
||||
CitationInfo(
|
||||
citation_num=citation_num,
|
||||
document_id=search_doc.document_id,
|
||||
)
|
||||
)
|
||||
elif chat_message.search_docs:
|
||||
for i, search_doc in enumerate(chat_message.search_docs):
|
||||
citation_info_list.append(
|
||||
CitationInfo(
|
||||
citation_num=i,
|
||||
document_id=search_doc.document_id,
|
||||
)
|
||||
)
|
||||
|
||||
research_iterations = []
|
||||
if chat_message.research_type in [
|
||||
ResearchType.THOUGHTFUL,
|
||||
ResearchType.DEEP,
|
||||
ResearchType.LEGACY_AGENTIC,
|
||||
ResearchType.FAST,
|
||||
]:
|
||||
research_iterations = sorted(
|
||||
chat_message.research_iterations, key=lambda x: x.iteration_nr
|
||||
)
|
||||
for research_iteration in research_iterations:
|
||||
if (
|
||||
research_iteration.iteration_nr > 1
|
||||
and research_iteration.reasoning
|
||||
and chat_message.research_type == ResearchType.DEEP
|
||||
):
|
||||
packet_list.extend(
|
||||
create_reasoning_packets(research_iteration.reasoning, step_nr)
|
||||
)
|
||||
step_nr += 1
|
||||
|
||||
if (
|
||||
research_iteration.purpose
|
||||
and chat_message.research_type == ResearchType.DEEP
|
||||
):
|
||||
packet_list.extend(
|
||||
create_reasoning_packets(research_iteration.purpose, step_nr)
|
||||
)
|
||||
step_nr += 1
|
||||
|
||||
sub_steps = research_iteration.sub_steps
|
||||
tasks: list[str] = []
|
||||
tool_call_ids: list[int | None] = []
|
||||
cited_docs: list[SavedSearchDoc] = []
|
||||
fetches: list[list[SavedSearchDoc]] = []
|
||||
is_web_fetch: bool = False
|
||||
|
||||
for sub_step in sub_steps:
|
||||
# For v2 tools, use the queries field if available, otherwise fall back to sub_step_instructions
|
||||
if sub_step.queries:
|
||||
tasks.extend(sub_step.queries)
|
||||
else:
|
||||
tasks.append(sub_step.sub_step_instructions or "")
|
||||
tool_call_ids.append(sub_step.sub_step_tool_id)
|
||||
|
||||
sub_step_cited_docs = sub_step.cited_doc_results
|
||||
sub_step_saved_search_docs: list[SavedSearchDoc] = []
|
||||
if isinstance(sub_step_cited_docs, list):
|
||||
for doc_data in sub_step_cited_docs:
|
||||
doc_data["db_doc_id"] = 1
|
||||
doc_data["boost"] = 1
|
||||
doc_data["hidden"] = False
|
||||
doc_data["chunk_ind"] = 0
|
||||
|
||||
if (
|
||||
doc_data["updated_at"] is None
|
||||
or doc_data["updated_at"] == "None"
|
||||
):
|
||||
doc_data["updated_at"] = datetime.now()
|
||||
|
||||
sub_step_saved_search_docs.append(
|
||||
SavedSearchDoc.from_dict(doc_data)
|
||||
if isinstance(doc_data, dict)
|
||||
else doc_data
|
||||
)
|
||||
|
||||
cited_docs.extend(sub_step_saved_search_docs)
|
||||
else:
|
||||
if chat_message.research_type == ResearchType.DEEP:
|
||||
packet_list.extend(
|
||||
create_reasoning_packets(
|
||||
_CANNOT_SHOW_STEP_RESULTS_STR, step_nr
|
||||
)
|
||||
)
|
||||
step_nr += 1
|
||||
|
||||
if sub_step.is_web_fetch and len(sub_step_saved_search_docs) > 0:
|
||||
is_web_fetch = True
|
||||
fetches.append(sub_step_saved_search_docs)
|
||||
|
||||
if len(set(tool_call_ids)) > 1:
|
||||
if chat_message.research_type == ResearchType.DEEP:
|
||||
packet_list.extend(
|
||||
create_reasoning_packets(
|
||||
_CANNOT_SHOW_STEP_RESULTS_STR, step_nr
|
||||
)
|
||||
)
|
||||
step_nr += 1
|
||||
|
||||
elif len(sub_steps) == 0:
|
||||
# no sub steps, no tool calls. But iteration can have reasoning or purpose
|
||||
continue
|
||||
|
||||
else:
|
||||
tool_id = tool_call_ids[0]
|
||||
if not tool_id:
|
||||
raise ValueError("Tool ID is required")
|
||||
tool = get_tool_by_id(tool_id, db_session)
|
||||
tool_name = tool.name
|
||||
|
||||
if tool_name in [SearchTool.__name__, KnowledgeGraphTool.__name__]:
|
||||
cited_docs = cast(list[SavedSearchDoc], cited_docs)
|
||||
packet_list.extend(
|
||||
create_search_packets(tasks, cited_docs, False, step_nr)
|
||||
)
|
||||
step_nr += 1
|
||||
|
||||
elif tool_name == WebSearchTool.__name__:
|
||||
cited_docs = cast(list[SavedSearchDoc], cited_docs)
|
||||
if is_web_fetch:
|
||||
packet_list.extend(create_fetch_packets(fetches, step_nr))
|
||||
else:
|
||||
packet_list.extend(
|
||||
create_search_packets(tasks, cited_docs, True, step_nr)
|
||||
)
|
||||
|
||||
step_nr += 1
|
||||
|
||||
elif tool_name == ImageGenerationTool.__name__:
|
||||
if sub_step.generated_images is None:
|
||||
raise ValueError("No generated images found")
|
||||
|
||||
packet_list.extend(
|
||||
create_image_generation_packets(
|
||||
sub_step.generated_images.images, step_nr
|
||||
)
|
||||
)
|
||||
step_nr += 1
|
||||
|
||||
elif tool_name == OktaProfileTool.__name__:
|
||||
packet_list.extend(
|
||||
create_custom_tool_packets(
|
||||
tool_name=tool_name,
|
||||
response_type="text",
|
||||
step_nr=step_nr,
|
||||
data=sub_step.sub_answer,
|
||||
)
|
||||
)
|
||||
step_nr += 1
|
||||
|
||||
else:
|
||||
packet_list.extend(
|
||||
create_custom_tool_packets(
|
||||
tool_name=tool_name,
|
||||
response_type="text",
|
||||
step_nr=step_nr,
|
||||
data=sub_step.sub_answer,
|
||||
)
|
||||
)
|
||||
step_nr += 1
|
||||
|
||||
if chat_message.message:
|
||||
packet_list.extend(
|
||||
create_message_packets(
|
||||
message_text=chat_message.message,
|
||||
final_documents=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in chat_message.search_docs
|
||||
],
|
||||
step_nr=step_nr,
|
||||
is_legacy_agentic=chat_message.research_type
|
||||
== ResearchType.LEGACY_AGENTIC,
|
||||
)
|
||||
)
|
||||
step_nr += 1
|
||||
|
||||
if len(citation_info_list) > 0 and len(research_iterations) == 0:
|
||||
saved_search_docs: list[SavedSearchDoc] = []
|
||||
for citation_info in citation_info_list:
|
||||
cited_doc = get_db_search_doc_by_document_id(
|
||||
citation_info.document_id, db_session
|
||||
)
|
||||
if cited_doc:
|
||||
saved_search_docs.append(
|
||||
translate_db_search_doc_to_server_search_doc(cited_doc)
|
||||
)
|
||||
|
||||
packet_list.extend(
|
||||
create_search_packets([], saved_search_docs, False, step_nr)
|
||||
)
|
||||
|
||||
step_nr += 1
|
||||
|
||||
return EndStepPacketList(packet_list=packet_list, end_step_nr=step_nr)
|
||||
|
||||
|
||||
def translate_db_message_to_packets(
|
||||
chat_message: ChatMessage,
|
||||
db_session: Session,
|
||||
remove_doc_content: bool = False,
|
||||
start_step_nr: int = 1,
|
||||
) -> EndStepPacketList:
|
||||
use_simple_translation = False
|
||||
if chat_message.research_type and chat_message.research_type != ResearchType.DEEP:
|
||||
feature_flag_provider = get_default_feature_flag_provider()
|
||||
tenant_id = get_current_tenant_id()
|
||||
user = chat_message.chat_session.user
|
||||
use_simple_translation = feature_flag_provider.feature_enabled_for_user_tenant(
|
||||
flag_key=SIMPLE_AGENT_FRAMEWORK,
|
||||
user=user,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
if use_simple_translation:
|
||||
return translate_db_message_to_packets_simple(
|
||||
chat_message=chat_message,
|
||||
db_session=db_session,
|
||||
remove_doc_content=remove_doc_content,
|
||||
start_step_nr=start_step_nr,
|
||||
)
|
||||
|
||||
step_nr = start_step_nr
|
||||
packet_list: list[Packet] = []
|
||||
|
||||
@@ -305,7 +561,11 @@ def translate_db_message_to_packets(
|
||||
cited_docs: list[SavedSearchDoc] = []
|
||||
|
||||
for sub_step in sub_steps:
|
||||
tasks.append(sub_step.sub_step_instructions or "")
|
||||
# For v2 tools, use the queries field if available, otherwise fall back to sub_step_instructions
|
||||
if sub_step.queries:
|
||||
tasks.extend(sub_step.queries)
|
||||
else:
|
||||
tasks.append(sub_step.sub_step_instructions or "")
|
||||
tool_call_ids.append(sub_step.sub_step_tool_id)
|
||||
|
||||
sub_step_cited_docs = sub_step.cited_doc_results
|
||||
|
||||
@@ -45,6 +45,8 @@ class Settings(BaseModel):
|
||||
|
||||
# is float to allow for fractional days for easier automated testing
|
||||
maximum_chat_retention_days: float | None = None
|
||||
company_name: str | None = None
|
||||
company_description: str | None = None
|
||||
gpu_enabled: bool | None = None
|
||||
application_status: ApplicationStatus = ApplicationStatus.ACTIVE
|
||||
anonymous_user_enabled: bool | None = None
|
||||
|
||||
123
backend/onyx/tools/adapter_v1_to_v2.py
Normal file
123
backend/onyx/tools/adapter_v1_to_v2.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# create adapter from Tool to FunctionTool
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Union
|
||||
|
||||
from agents import FunctionTool
|
||||
from agents import RunContextWrapper
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.built_in_tools_v2 import BUILT_IN_TOOL_MAP_V2
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
|
||||
from onyx.tools.tool_implementations.mcp.mcp_tool import MCPTool
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
|
||||
|
||||
# Type alias for tools that need custom handling
|
||||
CustomOrMcpTool = Union[CustomTool, MCPTool]
|
||||
|
||||
|
||||
def is_custom_or_mcp_tool(tool: Tool) -> bool:
|
||||
"""Check if a tool is a CustomTool or MCPTool."""
|
||||
return isinstance(tool, CustomTool) or isinstance(tool, MCPTool)
|
||||
|
||||
|
||||
@tool_accounting
|
||||
async def _tool_run_wrapper(
|
||||
run_context: RunContextWrapper[ChatTurnContext], tool: Tool, json_string: str
|
||||
) -> list[Any]:
|
||||
"""
|
||||
Wrapper function to adapt Tool.run() to FunctionTool.on_invoke_tool() signature.
|
||||
"""
|
||||
args = json.loads(json_string) if json_string else {}
|
||||
index = run_context.context.current_run_step
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=CustomToolStart(type="custom_tool_start", tool_name=tool.name),
|
||||
)
|
||||
)
|
||||
results = []
|
||||
run_context.context.iteration_instructions.append(
|
||||
IterationInstructions(
|
||||
iteration_nr=index,
|
||||
plan=f"Running {tool.name}",
|
||||
purpose=f"Running {tool.name}",
|
||||
reasoning=f"Running {tool.name}",
|
||||
)
|
||||
)
|
||||
for result in tool.run(**args):
|
||||
results.append(result)
|
||||
# Extract data from CustomToolCallSummary within the ToolResponse
|
||||
custom_summary = result.response
|
||||
data = None
|
||||
file_ids = None
|
||||
|
||||
# Handle different response types
|
||||
if custom_summary.response_type in ["image", "csv"] and hasattr(
|
||||
custom_summary.tool_result, "file_ids"
|
||||
):
|
||||
file_ids = custom_summary.tool_result.file_ids
|
||||
else:
|
||||
data = custom_summary.tool_result
|
||||
run_context.context.aggregated_context.global_iteration_responses.append(
|
||||
IterationAnswer(
|
||||
tool=tool.name,
|
||||
tool_id=tool.id,
|
||||
iteration_nr=index,
|
||||
parallelization_nr=0,
|
||||
question=json.dumps(args) if args else "",
|
||||
reasoning=f"Running {tool.name}",
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
cited_documents={},
|
||||
additional_data=None,
|
||||
response_type=custom_summary.response_type,
|
||||
answer=str(data) if data else str(file_ids),
|
||||
)
|
||||
)
|
||||
run_context.context.run_dependencies.emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=CustomToolDelta(
|
||||
type="custom_tool_delta",
|
||||
tool_name=tool.name,
|
||||
response_type=custom_summary.response_type,
|
||||
data=data,
|
||||
file_ids=file_ids,
|
||||
),
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def tool_to_function_tool(tool: Tool) -> FunctionTool:
|
||||
return FunctionTool(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
params_json_schema=tool.tool_definition()["function"]["parameters"],
|
||||
on_invoke_tool=lambda context, json_string: _tool_run_wrapper(
|
||||
context, tool, json_string
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def tools_to_function_tools(tools: list[Tool]) -> list[FunctionTool]:
|
||||
onyx_tools: list[list[FunctionTool]] = [
|
||||
BUILT_IN_TOOL_MAP_V2[type(tool).__name__]
|
||||
for tool in tools
|
||||
if type(tool).__name__ in BUILT_IN_TOOL_MAP_V2
|
||||
]
|
||||
flattened_builtin_tools: list[FunctionTool] = [
|
||||
onyx_tool for sublist in onyx_tools for onyx_tool in sublist
|
||||
]
|
||||
custom_and_mcp_tools: list[FunctionTool] = [
|
||||
tool_to_function_tool(tool) for tool in tools if is_custom_or_mcp_tool(tool)
|
||||
]
|
||||
|
||||
return flattened_builtin_tools + custom_and_mcp_tools
|
||||
24
backend/onyx/tools/built_in_tools_v2.py
Normal file
24
backend/onyx/tools/built_in_tools_v2.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from agents import FunctionTool
|
||||
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
|
||||
OktaProfileTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations_v2.image_generation import image_generation_tool
|
||||
from onyx.tools.tool_implementations_v2.internal_search import internal_search_tool
|
||||
from onyx.tools.tool_implementations_v2.okta_profile import okta_profile_tool
|
||||
from onyx.tools.tool_implementations_v2.web import web_fetch_tool
|
||||
from onyx.tools.tool_implementations_v2.web import web_search_tool
|
||||
|
||||
BUILT_IN_TOOL_MAP_V2: dict[str, list[FunctionTool]] = {
|
||||
SearchTool.__name__: [internal_search_tool],
|
||||
ImageGenerationTool.__name__: [image_generation_tool],
|
||||
WebSearchTool.__name__: [web_search_tool, web_fetch_tool],
|
||||
OktaProfileTool.__name__: [okta_profile_tool],
|
||||
}
|
||||
@@ -26,7 +26,7 @@ from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import MCPTransport
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_async_sync
|
||||
from onyx.utils.threadpool_concurrency import run_async_sync_no_cancel
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -203,7 +203,7 @@ def _call_mcp_client_function_sync(
|
||||
function, server_url, connection_headers, transport, auth, **kwargs
|
||||
)
|
||||
try:
|
||||
return run_async_sync(run_client_function())
|
||||
return run_async_sync_no_cancel(run_client_function())
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to call MCP client function: {e}")
|
||||
if isinstance(e, ExceptionGroup):
|
||||
|
||||
162
backend/onyx/tools/tool_implementations_v2/image_generation.py
Normal file
162
backend/onyx/tools/tool_implementations_v2/image_generation.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from typing import cast
|
||||
|
||||
from agents import function_tool
|
||||
from agents import RunContextWrapper
|
||||
|
||||
from onyx.agents.agent_search.dr.models import GeneratedImage
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.file_store.utils import build_frontend_file_url
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@tool_accounting
|
||||
def _image_generation_core(
|
||||
run_context: RunContextWrapper[ChatTurnContext],
|
||||
prompt: str,
|
||||
shape: str,
|
||||
image_generation_tool_instance: ImageGenerationTool,
|
||||
) -> list[GeneratedImage]:
|
||||
index = run_context.context.current_run_step
|
||||
emitter = run_context.context.run_dependencies.emitter
|
||||
|
||||
# Emit start event
|
||||
emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=ImageGenerationToolStart(type="image_generation_tool_start"),
|
||||
)
|
||||
)
|
||||
|
||||
# Prepare tool arguments
|
||||
tool_args = {"prompt": prompt}
|
||||
if shape != "square": # Only include shape if it's not the default
|
||||
tool_args["shape"] = shape
|
||||
|
||||
# Run the actual image generation tool with heartbeat handling
|
||||
generated_images: list[GeneratedImage] = []
|
||||
heartbeat_count = 0
|
||||
|
||||
for tool_response in image_generation_tool_instance.run(
|
||||
**tool_args # type: ignore[arg-type]
|
||||
):
|
||||
# Handle heartbeat responses
|
||||
if tool_response.id == "image_generation_heartbeat":
|
||||
# Emit heartbeat event for every iteration
|
||||
emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=ImageGenerationToolHeartbeat(
|
||||
type="image_generation_tool_heartbeat"
|
||||
),
|
||||
)
|
||||
)
|
||||
heartbeat_count += 1
|
||||
logger.debug(f"Image generation heartbeat #{heartbeat_count}")
|
||||
continue
|
||||
|
||||
# Process the tool response to get the generated images
|
||||
if tool_response.id == "image_generation_response":
|
||||
image_generation_responses = cast(
|
||||
list[ImageGenerationResponse], tool_response.response
|
||||
)
|
||||
file_ids = save_files(
|
||||
urls=[img.url for img in image_generation_responses if img.url],
|
||||
base64_files=[
|
||||
img.image_data
|
||||
for img in image_generation_responses
|
||||
if img.image_data
|
||||
],
|
||||
)
|
||||
generated_images = [
|
||||
GeneratedImage(
|
||||
file_id=file_id,
|
||||
url=img.url if img.url else build_frontend_file_url(file_id),
|
||||
revised_prompt=img.revised_prompt,
|
||||
)
|
||||
for img, file_id in zip(image_generation_responses, file_ids)
|
||||
]
|
||||
break
|
||||
|
||||
if not generated_images:
|
||||
raise RuntimeError("No images were generated")
|
||||
|
||||
run_context.context.iteration_instructions.append(
|
||||
IterationInstructions(
|
||||
iteration_nr=index,
|
||||
plan="Generating images",
|
||||
purpose="Generating images",
|
||||
reasoning="Generating images",
|
||||
)
|
||||
)
|
||||
run_context.context.aggregated_context.global_iteration_responses.append(
|
||||
IterationAnswer(
|
||||
tool=image_generation_tool_instance.name,
|
||||
tool_id=image_generation_tool_instance.id,
|
||||
iteration_nr=run_context.context.current_run_step,
|
||||
parallelization_nr=0,
|
||||
question=prompt,
|
||||
answer="",
|
||||
reasoning="",
|
||||
claims=[],
|
||||
generated_images=generated_images,
|
||||
additional_data={},
|
||||
response_type=None,
|
||||
data=None,
|
||||
file_ids=None,
|
||||
cited_documents={},
|
||||
)
|
||||
)
|
||||
# Emit final result
|
||||
emitter.emit(
|
||||
Packet(
|
||||
ind=index,
|
||||
obj=ImageGenerationToolDelta(
|
||||
type="image_generation_tool_delta", images=generated_images
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return generated_images
|
||||
|
||||
|
||||
# failure_error_function=None causes error to be re-raised instead of passing error
|
||||
# message back to the LLM. This is needed for image_generation since we configure our agent
|
||||
# to stop at this tool.
|
||||
@function_tool(failure_error_function=None)
|
||||
def image_generation_tool(
|
||||
run_context: RunContextWrapper[ChatTurnContext], prompt: str, shape: str = "square"
|
||||
) -> str:
|
||||
"""
|
||||
Generate an image from a text prompt using AI image generation models.
|
||||
|
||||
Args:
|
||||
prompt: The text description of the image to generate
|
||||
shape: The desired image shape - 'square', 'portrait', or 'landscape'
|
||||
"""
|
||||
image_generation_tool_instance = (
|
||||
run_context.context.run_dependencies.image_generation_tool
|
||||
)
|
||||
assert image_generation_tool_instance is not None
|
||||
|
||||
generated_images: list[GeneratedImage] = _image_generation_core(
|
||||
run_context, prompt, shape, image_generation_tool_instance
|
||||
)
|
||||
|
||||
# We should stop after this tool is called, so it doesn't matter what it returns
|
||||
return f"Successfully generated {len(generated_images)} images"
|
||||
197
backend/onyx/tools/tool_implementations_v2/internal_search.py
Normal file
197
backend/onyx/tools/tool_implementations_v2/internal_search.py
Normal file
@@ -0,0 +1,197 @@
|
||||
from typing import cast
|
||||
|
||||
from agents import function_tool
|
||||
from agents import RunContextWrapper
|
||||
|
||||
from onyx.agents.agent_search.dr.models import InferenceSection
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.stop_signal_checker import is_connected
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.search.search_utils import section_to_llm_doc
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
|
||||
from onyx.utils.threadpool_concurrency import FunctionCall
|
||||
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
|
||||
|
||||
@tool_accounting
def _internal_search_core(
    run_context: RunContextWrapper[ChatTurnContext],
    queries: list[str],
    search_tool: SearchTool,
) -> list[LlmDoc]:
    """Core internal search logic that can be tested with dependency injection.

    Runs every query in parallel against the internal (connector-indexed)
    knowledge base via `search_tool`, streams progress packets to the client
    through the context emitter, records iteration bookkeeping on the shared
    aggregated context, and returns the retrieved documents as LlmDocs.

    The @tool_accounting decorator owns the step counter: it increments
    `current_run_step` before this body runs and emits the SectionEnd packet
    afterwards, so `index` below is already this call's step number.
    """
    index = run_context.context.current_run_step
    # Announce the start of an internal (non-internet) search section.
    run_context.context.run_dependencies.emitter.emit(
        Packet(
            ind=index,
            obj=SearchToolStart(
                type="internal_search_tool_start", is_internet_search=False
            ),
        )
    )
    # First delta carries the queries only; documents stream in later deltas.
    run_context.context.run_dependencies.emitter.emit(
        Packet(
            ind=index,
            obj=SearchToolDelta(
                type="internal_search_tool_delta", queries=queries, documents=[]
            ),
        )
    )
    run_context.context.iteration_instructions.append(
        IterationInstructions(
            iteration_nr=index,
            plan="plan",
            purpose="Searching internally for information",
            reasoning=f"I am now using Internal Search to gather information on {queries}",
        )
    )

    def execute_single_query(query: str, parallelization_nr: int) -> list[LlmDoc]:
        """Execute a single query and return the retrieved documents as LlmDocs"""
        retrieved_llm_docs_for_query: list[LlmDoc] = []

        # Each parallel worker gets its own session; the shared context
        # db_session is not safe to use concurrently across threads.
        with get_session_with_current_tenant() as search_db_session:
            for tool_response in search_tool.run(
                query=query,
                override_kwargs=SearchToolOverrideKwargs(
                    force_no_rerank=True,
                    alternate_db_session=search_db_session,
                    skip_query_analysis=True,
                    original_query=query,
                ),
            ):
                # Bail out early if the client has disconnected mid-stream.
                if not is_connected(
                    run_context.context.chat_session_id,
                    run_context.context.run_dependencies.redis_client,
                ):
                    break
                # get retrieved docs to send to the rest of the graph
                if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
                    response = cast(SearchResponseSummary, tool_response.response)
                    # TODO: just a heuristic to not overload context window -- carried over from existing DR flow
                    docs_to_feed_llm = 15
                    retrieved_sections: list[InferenceSection] = response.top_sections[
                        :docs_to_feed_llm
                    ]

                    # Convert InferenceSections to LlmDocs for return value
                    retrieved_llm_docs_for_query = [
                        section_to_llm_doc(section) for section in retrieved_sections
                    ]

                    # Stream the found documents to the client (queries already sent).
                    run_context.context.run_dependencies.emitter.emit(
                        Packet(
                            ind=index,
                            obj=SearchToolDelta(
                                type="internal_search_tool_delta",
                                queries=[],
                                documents=convert_inference_sections_to_search_docs(
                                    retrieved_sections, is_internet=False
                                ),
                            ),
                        )
                    )
                    # NOTE(review): cited_documents/global_iteration_responses are
                    # appended from parallel threads — presumably list.append/extend
                    # atomicity is being relied on; confirm.
                    run_context.context.aggregated_context.cited_documents.extend(
                        retrieved_sections
                    )
                    run_context.context.aggregated_context.global_iteration_responses.append(
                        IterationAnswer(
                            tool=SearchTool.__name__,
                            tool_id=get_tool_by_name(
                                SearchTool.__name__,
                                run_context.context.run_dependencies.db_session,
                            ).id,
                            iteration_nr=index,
                            parallelization_nr=parallelization_nr,
                            question=query,
                            reasoning=f"I am now using Internal Search to gather information on {query}",
                            answer="",
                            cited_documents={
                                i: inference_section
                                for i, inference_section in enumerate(
                                    retrieved_sections
                                )
                            },
                            queries=[query],
                        )
                    )
                    # Only the summary response matters; stop consuming the stream.
                    break

        return retrieved_llm_docs_for_query

    # Execute all queries in parallel using run_functions_in_parallel
    function_calls = [
        FunctionCall(func=execute_single_query, args=(query, i))
        for i, query in enumerate(queries)
    ]
    search_results_dict = run_functions_in_parallel(function_calls)

    # Aggregate all results from all queries
    all_retrieved_docs: list[LlmDoc] = []
    for result_id in search_results_dict:
        retrieved_docs = search_results_dict[result_id]
        if retrieved_docs:
            all_retrieved_docs.extend(retrieved_docs)

    return all_retrieved_docs
|
||||
|
||||
|
||||
@function_tool
def internal_search_tool(
    run_context: RunContextWrapper[ChatTurnContext], queries: list[str]
) -> str:
    """
    Tool for searching over internal knowledge base from the user's connectors.
    The queries will be searched over a vector database where a hybrid search will be performed.
    Will return a combination of keyword and semantic search results.
    ---
    ## Decision boundary
    - MUST call internal_search_tool if the user's query requires internal information, like
      if it references "we" or "us" or "our" or "internal" or if it references
      the organization the user works for.

    ## Usage hints
    - Batch a list of natural-language queries per call.
    - Generally try searching with some semantic queries and some keyword queries
      to give the hybrid search the best chance of finding relevant results.

    ## Args
    - queries (list[str]): The search queries.

    ## Returns (list of LlmDoc objects as string)
    Each LlmDoc contains:
    - document_id: Unique document identifier
    - content: Full document content (combined from all chunks in the section)
    - blurb: Text excerpt from the document
    - semantic_identifier: Human-readable document name
    - source_type: Type of document source (e.g., web, confluence, etc.)
    - metadata: Additional document metadata
    - updated_at: When document was last updated
    - link: Primary URL to the source (may be None). Used for citations.
    - source_links: Dictionary of URLs to the source
    - match_highlights: Highlighted matching text snippets
    """
    # NOTE: the docstring above doubles as the LLM-facing tool description via
    # @function_tool — treat it as part of the runtime contract, not free text.
    search_pipeline_instance = run_context.context.run_dependencies.search_pipeline
    if search_pipeline_instance is None:
        raise RuntimeError("Search tool not available in context")

    # Call the core function
    retrieved_docs = _internal_search_core(
        run_context, queries, search_pipeline_instance
    )

    # Results are serialized via str() rather than JSON; the model consumes
    # this as opaque text.
    return str(retrieved_docs)
|
||||
90
backend/onyx/tools/tool_implementations_v2/okta_profile.py
Normal file
90
backend/onyx/tools/tool_implementations_v2/okta_profile.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from agents import function_tool
|
||||
from agents import RunContextWrapper
|
||||
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@tool_accounting
def _okta_profile_core(
    run_context: RunContextWrapper[ChatTurnContext],
    okta_profile_tool_instance: Any,
) -> dict[str, Any]:
    """Core Okta profile logic that can be tested with dependency injection.

    Streams CustomTool start/delta packets while fetching the current user's
    Okta profile through the injected tool instance, then returns the raw
    profile payload.

    Raises:
        RuntimeError: if no tool instance was provided, or if the tool stream
            yields no "okta_profile" response.
    """
    if okta_profile_tool_instance is None:
        raise RuntimeError("Okta profile tool not available in context")

    # @tool_accounting already advanced the step counter for this call.
    index = run_context.context.current_run_step
    emitter = run_context.context.run_dependencies.emitter  # type: ignore[union-attr]

    # Emit start event
    emitter.emit(  # type: ignore[union-attr]
        Packet(
            ind=index,
            obj=CustomToolStart(type="custom_tool_start", tool_name="Okta Profile"),
        )
    )

    # Emit delta event for fetching profile
    emitter.emit(  # type: ignore[union-attr]
        Packet(
            ind=index,
            obj=CustomToolDelta(
                type="custom_tool_delta",
                tool_name="Okta Profile",
                response_type="text",
                data="Fetching profile information...",
            ),
        )
    )

    # Run the actual Okta profile tool
    profile_data = None
    for tool_response in okta_profile_tool_instance.run():
        if tool_response.id == "okta_profile":
            # Assumes .response is a JSON-serializable dict (callers call
            # json.dumps on it) — TODO confirm against the tool implementation.
            profile_data = tool_response.response
            break

    if profile_data is None:
        raise RuntimeError("No profile data was retrieved from Okta")

    # Emit final result
    emitter.emit(  # type: ignore[union-attr]
        Packet(
            ind=index,
            obj=CustomToolDelta(
                type="custom_tool_delta",
                tool_name="Okta Profile",
                response_type="json",
                data=profile_data,
            ),
        )
    )

    return profile_data
|
||||
|
||||
|
||||
@function_tool
def okta_profile_tool(run_context: RunContextWrapper[ChatTurnContext]) -> str:
    """
    Retrieve the current user's profile information from Okta.

    This tool fetches user profile details including name, email, department,
    location, title, manager, and other profile information from the Okta identity provider.
    """
    # NOTE: the docstring above is the LLM-facing tool description via
    # @function_tool — treat it as part of the runtime contract.
    # Get the Okta profile tool from context
    okta_profile_tool_instance = run_context.context.run_dependencies.okta_profile_tool

    # Call the core function (raises RuntimeError if the tool is unavailable
    # or returns no data).
    profile_data = _okta_profile_core(run_context, okta_profile_tool_instance)

    return json.dumps(profile_data)
|
||||
@@ -0,0 +1,79 @@
|
||||
import functools
|
||||
import inspect
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
from agents import RunContextWrapper
|
||||
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
|
||||
F = TypeVar("F", bound=Callable)
|
||||
|
||||
|
||||
def tool_accounting(func: F) -> F:
    """
    Decorator that adds tool accounting functionality to tool functions.
    Handles both sync and async functions automatically.

    This decorator:
    1. Increments the current_run_step index at the beginning
    2. Emits a section end packet and increments current_run_step at the end
    3. Ensures the cleanup happens even if an exception occurs

    Args:
        func: The function to decorate. Must take a RunContextWrapper[ChatTurnContext] as first argument.

    Returns:
        The decorated function with tool accounting functionality.
    """

    @functools.wraps(func)
    def wrapper(
        run_context: RunContextWrapper[ChatTurnContext], *args: Any, **kwargs: Any
    ) -> Any:
        # Increment current_run_step at the beginning
        run_context.context.current_run_step += 1

        try:
            # Call the original function
            result = func(run_context, *args, **kwargs)

            # If it's a coroutine, we need to handle it differently.
            # NOTE: detection is on the returned value, not on
            # iscoroutinefunction(func) — so an async func's body has not run
            # yet at this point; cleanup is deferred until it is awaited.
            if inspect.iscoroutine(result):
                # For async functions, we need to return a coroutine that handles the cleanup
                async def async_wrapper() -> Any:
                    try:
                        return await result
                    finally:
                        # Runs on success, exception, or cancellation of the
                        # awaited coroutine.
                        _emit_section_end(run_context)

                return async_wrapper()
            else:
                # For sync functions, emit cleanup immediately
                _emit_section_end(run_context)
                return result

        except Exception:
            # Always emit cleanup even if an exception occurred
            # (sync call raised, or the async func raised before returning
            # its coroutine).
            _emit_section_end(run_context)
            raise

    return cast(F, wrapper)
|
||||
|
||||
|
||||
def _emit_section_end(run_context: RunContextWrapper[ChatTurnContext]) -> None:
    """Emit a SectionEnd packet for the current step and advance the step counter."""
    ctx = run_context.context
    packet = Packet(
        ind=ctx.current_run_step,
        obj=SectionEnd(
            type="section_end",
        ),
    )
    ctx.run_dependencies.emitter.emit(packet)
    ctx.current_run_step += 1
|
||||
312
backend/onyx/tools/tool_implementations_v2/web.py
Normal file
312
backend/onyx/tools/tool_implementations_v2/web.py
Normal file
@@ -0,0 +1,312 @@
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from agents import function_tool
|
||||
from agents import RunContextWrapper
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
|
||||
get_default_provider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
|
||||
dummy_inference_section_from_internet_content,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
|
||||
dummy_inference_section_from_internet_search_result,
|
||||
)
|
||||
from onyx.chat.turn.models import ChatTurnContext
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.server.query_and_chat.streaming_models import FetchToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import SavedSearchDoc
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
|
||||
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
|
||||
|
||||
class WebSearchResult(BaseModel):
    """A single web search hit returned to the model (snippet only, no full content)."""

    # Short reference tag used for citations (see short_tag()).
    tag: str
    title: str
    link: str
    snippet: str
    author: Optional[str] = None
    # ISO-8601 string (produced via datetime.isoformat() upstream).
    published_date: Optional[str] = None
|
||||
|
||||
|
||||
class WebSearchResponse(BaseModel):
    """Envelope for web search results; serialized to JSON for the model."""

    results: List[WebSearchResult]
|
||||
|
||||
|
||||
class WebFetchResult(BaseModel):
    """A single fetched web page, including its full extracted content."""

    # Short reference tag used for citations (see short_tag()).
    tag: str
    title: str
    link: str
    full_content: str
    # ISO-8601 string (produced via datetime.isoformat() upstream).
    published_date: Optional[str] = None
|
||||
|
||||
|
||||
class WebFetchResponse(BaseModel):
    """Envelope for web fetch results; serialized to JSON for the model."""

    results: List[WebFetchResult]
|
||||
|
||||
|
||||
def short_tag(link: str, i: int) -> str:
    """Return a short 1-based citation tag for result index *i*.

    The *link* argument is currently unused but kept for interface stability
    (a future tag scheme may derive tags from the URL).
    """
    return str(i + 1)
|
||||
|
||||
|
||||
@tool_accounting
def _web_search_core(
    run_context: RunContextWrapper[ChatTurnContext],
    queries: list[str],
    search_provider: WebSearchProvider,
) -> WebSearchResponse:
    """Run all web search queries in parallel, stream progress packets, and
    record iteration bookkeeping on the shared aggregated context.

    The @tool_accounting decorator increments current_run_step before this
    body runs and emits the SectionEnd packet afterwards.
    """
    # Local import — presumably to avoid a circular import at module load
    # time; TODO confirm and hoist if not needed.
    from onyx.utils.threadpool_concurrency import FunctionCall

    index = run_context.context.current_run_step
    # NOTE(review): the type literal says "internal_search_tool_start" while
    # is_internet_search=True — looks like the literal is shared between both
    # search flavors and the flag disambiguates; confirm against the
    # SearchToolStart model.
    run_context.context.run_dependencies.emitter.emit(
        Packet(
            ind=index,
            obj=SearchToolStart(
                type="internal_search_tool_start", is_internet_search=True
            ),
        )
    )
    # First delta carries the queries only; documents are streamed implicitly
    # via the aggregated context / final response here.
    run_context.context.run_dependencies.emitter.emit(
        Packet(
            ind=index,
            obj=SearchToolDelta(
                type="internal_search_tool_delta", queries=queries, documents=[]
            ),
        )
    )

    queries_str = ", ".join(queries)
    run_context.context.iteration_instructions.append(
        IterationInstructions(
            iteration_nr=index,
            plan="plan",
            purpose="Searching the web for information",
            reasoning=f"I am now using Web Search to gather information on {queries_str}",
        )
    )

    # Search all queries in parallel
    function_calls = [
        FunctionCall(func=search_provider.search, args=(query,)) for query in queries
    ]
    search_results_dict = run_functions_in_parallel(function_calls)

    # Aggregate all results from all queries
    all_hits = []
    for result_id in search_results_dict:
        hits = search_results_dict[result_id]
        if hits:
            all_hits.extend(hits)

    # Convert hits to WebSearchResult objects
    results = []
    for i, r in enumerate(all_hits):
        results.append(
            WebSearchResult(
                tag=short_tag(r.link, i),
                title=r.title,
                link=r.link,
                snippet=r.snippet or "",
                author=r.author,
                published_date=(
                    r.published_date.isoformat() if r.published_date else None
                ),
            )
        )

    # Create inference sections from search results and add to cited documents
    inference_sections = [
        dummy_inference_section_from_internet_search_result(r) for r in all_hits
    ]
    run_context.context.aggregated_context.cited_documents.extend(inference_sections)

    run_context.context.aggregated_context.global_iteration_responses.append(
        IterationAnswer(
            tool=WebSearchTool.__name__,
            tool_id=get_tool_by_name(
                WebSearchTool.__name__, run_context.context.run_dependencies.db_session
            ).id,
            iteration_nr=index,
            parallelization_nr=0,
            question=queries_str,
            reasoning=f"I am now using Web Search to gather information on {queries_str}",
            answer="",
            cited_documents={
                i: inference_section
                for i, inference_section in enumerate(inference_sections)
            },
            claims=[],
            queries=queries,
        )
    )
    return WebSearchResponse(results=results)
|
||||
|
||||
|
||||
@function_tool
def web_search_tool(
    run_context: RunContextWrapper[ChatTurnContext], queries: list[str]
) -> str:
    """
    Tool for searching the public internet. Useful for up to date information on PUBLIC knowledge.
    ---
    ## Decision boundary
    - You MUST call `web_search_tool` to discover sources when the request involves:
      - Fresh/unstable info (news, prices, laws, schedules, product specs, scores, exchange rates).
      - Recommendations, or any query where the specific sources matter.
      - Verifiable claims, quotes, or citations.
    - After ANY successful `web_search_tool` call that yields candidate URLs, you MUST call
      `web_fetch_tool` on the selected URLs BEFORE answering. Do NOT answer from snippets.

    ## When NOT to use
    - Casual chat, rewriting/summarizing user-provided text, or translation.
    - When the user already provided URLs (go straight to `web_fetch_tool`).

    ## Usage hints
    - Batch a list of natural-language queries per call.
    - Prefer searches for distinct intents; then batch-fetch best URLs.
    - Deduplicate domains/near-duplicates. Prefer recent, authoritative sources.

    ## Args
    - queries (list[str]): The search queries.

    ## Returns (JSON string)
    {
      "results": [
        {
          "tag": "short_ref",
          "title": "...",
          "link": "https://...",
          "author": "...",
          "published_date": "2025-10-01T12:34:56Z"
          // intentionally NO full content
        }
      ]
    }
    """
    # NOTE: the docstring above is the LLM-facing tool description via
    # @function_tool — treat it as part of the runtime contract.
    search_provider = get_default_provider()
    if search_provider is None:
        raise ValueError("No search provider found")
    response = _web_search_core(run_context, queries, search_provider)
    return response.model_dump_json()
|
||||
|
||||
|
||||
@tool_accounting
def _web_fetch_core(
    run_context: RunContextWrapper[ChatTurnContext],
    urls: List[str],
    search_provider: WebSearchProvider,
) -> WebFetchResponse:
    """Fetch full page content for *urls* via the search provider, stream a
    FetchToolStart packet, and record iteration bookkeeping on the shared
    aggregated context.

    The @tool_accounting decorator increments current_run_step before this
    body runs and emits the SectionEnd packet afterwards.
    """
    # TODO: Find better way to track index that isn't so implicit
    # based on number of tool calls
    index = run_context.context.current_run_step

    # Create SavedSearchDoc objects from URLs for the FetchToolStart event
    saved_search_docs = [SavedSearchDoc.from_url(url) for url in urls]

    run_context.context.run_dependencies.emitter.emit(
        Packet(
            ind=index,
            obj=FetchToolStart(type="fetch_tool_start", documents=saved_search_docs),
        )
    )

    # Blocking batch fetch; assumes provider returns one entry per URL —
    # TODO confirm ordering/missing-URL behavior against the provider.
    docs = search_provider.contents(urls)
    out = []
    for i, d in enumerate(docs):
        out.append(
            WebFetchResult(
                tag=short_tag(d.link, i),  # <-- add a tag
                title=d.title,
                link=d.link,
                full_content=d.full_content,
                published_date=(
                    d.published_date.isoformat() if d.published_date else None
                ),
            )
        )
    run_context.context.iteration_instructions.append(
        IterationInstructions(
            iteration_nr=index,
            plan="plan",
            purpose="Fetching content from URLs",
            reasoning=f"I am now using Web Fetch to gather information on {', '.join(urls)}",
        )
    )

    run_context.context.aggregated_context.global_iteration_responses.append(
        IterationAnswer(
            # TODO: For now, we're using the web_search_tool_name since the web_fetch_tool_name is not a built-in tool
            tool=WebSearchTool.__name__,
            tool_id=get_tool_by_name(
                WebSearchTool.__name__, run_context.context.run_dependencies.db_session
            ).id,
            iteration_nr=index,
            parallelization_nr=0,
            question=f"Fetch content from URLs: {', '.join(urls)}",
            reasoning=f"I am now using Web Fetch to gather information on {', '.join(urls)}",
            answer="",
            cited_documents={
                i: dummy_inference_section_from_internet_content(d)
                for i, d in enumerate(docs)
            },
            claims=[],
            is_web_fetch=True,
        )
    )

    return WebFetchResponse(results=out)
|
||||
|
||||
|
||||
@function_tool
def web_fetch_tool(
    run_context: RunContextWrapper[ChatTurnContext], urls: List[str]
) -> str:
    """
    Tool for fetching and extracting full content from web pages.

    ---
    ## Decision boundary
    - You MUST use `web_fetch_tool` before quoting, citing, or relying on page content.
    - Use it whenever you already have URLs (from the user or from `web_search_tool`).
    - Do NOT answer questions based on search snippets alone.

    ## When NOT to use
    - If you do not yet have URLs (search first).

    ## Usage hints
    - Avoid many tiny calls; batch URLs (1–20) in one request.
    - Prefer primary, recent, and reputable sources.
    - If PDFs/long docs appear, still fetch; you may summarize sections explicitly.

    ## Args
    - urls (List[str]): Absolute URLs to retrieve.

    ## Returns (JSON string)
    {
      "results": [
        {
          "tag": "short_ref",
          "title": "...",
          "link": "https://...",
          "full_content": "...",
          "published_date": "2025-10-01T12:34:56Z"
        }
      ]
    }
    """
    # NOTE: the docstring above is the LLM-facing tool description via
    # @function_tool — treat it as part of the runtime contract.
    search_provider = get_default_provider()
    if search_provider is None:
        raise ValueError("No search provider found")
    response = _web_fetch_core(run_context, urls, search_provider)
    return response.model_dump_json()
|
||||
@@ -283,7 +283,7 @@ def run_functions_in_parallel(
|
||||
return results
|
||||
|
||||
|
||||
def run_async_sync(coro: Awaitable[T]) -> T:
|
||||
def run_async_sync_no_cancel(coro: Awaitable[T]) -> T:
|
||||
"""
|
||||
async-to-sync converter. Basically just executes asyncio.run in a separate thread.
|
||||
Which is probably somehow inefficient or not ideal but fine for now.
|
||||
|
||||
@@ -49,7 +49,7 @@ msal==1.28.0
|
||||
nltk==3.9.1
|
||||
Office365-REST-Python-Client==2.5.9
|
||||
oauthlib==3.2.2
|
||||
openai==1.99.5
|
||||
openai==1.107.1
|
||||
openpyxl==3.1.5
|
||||
passlib==1.7.4
|
||||
playwright==1.55.0
|
||||
@@ -105,5 +105,6 @@ sendgrid==6.11.0
|
||||
voyageai==0.2.3
|
||||
cohere==5.6.1
|
||||
exa_py==1.15.4
|
||||
braintrust==0.2.6
|
||||
braintrust-langchain==0.0.4
|
||||
braintrust[openai-agents]==0.2.6
|
||||
braintrust-langchain==0.0.4
|
||||
openai-agents==0.3.3
|
||||
@@ -34,3 +34,4 @@ types-retry==0.9.9.3
|
||||
types-setuptools==68.0.0.3
|
||||
types-urllib3==1.26.25.11
|
||||
voyageai==0.2.3
|
||||
ipykernel==6.29.5
|
||||
@@ -4,7 +4,7 @@ cohere==5.6.1
|
||||
fastapi==0.116.1
|
||||
google-cloud-aiplatform==1.58.0
|
||||
numpy==1.26.4
|
||||
openai==1.99.5
|
||||
openai==1.107.1
|
||||
pydantic==2.11.7
|
||||
retry==0.9.2
|
||||
safetensors==0.5.3
|
||||
|
||||
@@ -16,6 +16,9 @@ def monitor_process(process_name: str, process: subprocess.Popen) -> None:
|
||||
|
||||
|
||||
def run_jobs() -> None:
|
||||
# Check if we should use lightweight mode, defaults to True, change to False to use separate background workers
|
||||
use_lightweight = True
|
||||
|
||||
# command setup
|
||||
cmd_worker_primary = [
|
||||
"celery",
|
||||
@@ -45,20 +48,6 @@ def run_jobs() -> None:
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup",
|
||||
]
|
||||
|
||||
cmd_worker_heavy = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=6",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation",
|
||||
]
|
||||
|
||||
cmd_worker_docprocessing = [
|
||||
"celery",
|
||||
"-A",
|
||||
@@ -72,45 +61,6 @@ def run_jobs() -> None:
|
||||
"--queues=docprocessing",
|
||||
]
|
||||
|
||||
cmd_worker_user_files = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.user_file_processing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_file_processing@%n",
|
||||
"--queues=user_file_processing,user_file_project_sync",
|
||||
]
|
||||
|
||||
cmd_worker_monitoring = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"--queues=monitoring",
|
||||
]
|
||||
|
||||
cmd_worker_kg_processing = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.kg_processing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=kg_processing@%n",
|
||||
"--queues=kg_processing",
|
||||
]
|
||||
|
||||
cmd_worker_docfetching = [
|
||||
"celery",
|
||||
"-A",
|
||||
@@ -132,6 +82,84 @@ def run_jobs() -> None:
|
||||
"--loglevel=INFO",
|
||||
]
|
||||
|
||||
# Prepare background worker commands based on mode
|
||||
if use_lightweight:
|
||||
print("Starting workers in LIGHTWEIGHT mode (single background worker)")
|
||||
cmd_worker_background = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.background",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=6",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=background@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync",
|
||||
]
|
||||
background_workers = [("BACKGROUND", cmd_worker_background)]
|
||||
else:
|
||||
print("Starting workers in STANDARD mode (separate background workers)")
|
||||
cmd_worker_heavy = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning",
|
||||
]
|
||||
cmd_worker_kg_processing = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.kg_processing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=2",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=kg_processing@%n",
|
||||
"-Q",
|
||||
"kg_processing",
|
||||
]
|
||||
cmd_worker_monitoring = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring",
|
||||
]
|
||||
cmd_worker_user_file_processing = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.user_file_processing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=2",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_file_processing@%n",
|
||||
"-Q",
|
||||
"user_file_processing,user_file_project_sync,connector_doc_permissions_sync,connector_external_group_sync,csv_generation",
|
||||
]
|
||||
background_workers = [
|
||||
("HEAVY", cmd_worker_heavy),
|
||||
("KG_PROCESSING", cmd_worker_kg_processing),
|
||||
("MONITORING", cmd_worker_monitoring),
|
||||
("USER_FILE_PROCESSING", cmd_worker_user_file_processing),
|
||||
]
|
||||
|
||||
# spawn processes
|
||||
worker_primary_process = subprocess.Popen(
|
||||
cmd_worker_primary, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
@@ -141,10 +169,6 @@ def run_jobs() -> None:
|
||||
cmd_worker_light, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
|
||||
worker_heavy_process = subprocess.Popen(
|
||||
cmd_worker_heavy, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
|
||||
worker_docprocessing_process = subprocess.Popen(
|
||||
cmd_worker_docprocessing,
|
||||
stdout=subprocess.PIPE,
|
||||
@@ -152,27 +176,6 @@ def run_jobs() -> None:
|
||||
text=True,
|
||||
)
|
||||
|
||||
worker_user_file_process = subprocess.Popen(
|
||||
cmd_worker_user_files,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
|
||||
worker_monitoring_process = subprocess.Popen(
|
||||
cmd_worker_monitoring,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
|
||||
worker_kg_processing_process = subprocess.Popen(
|
||||
cmd_worker_kg_processing,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
|
||||
worker_docfetching_process = subprocess.Popen(
|
||||
cmd_worker_docfetching,
|
||||
stdout=subprocess.PIPE,
|
||||
@@ -184,6 +187,14 @@ def run_jobs() -> None:
|
||||
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
|
||||
# Spawn background worker processes based on mode
|
||||
background_processes = []
|
||||
for name, cmd in background_workers:
|
||||
process = subprocess.Popen(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
background_processes.append((name, process))
|
||||
|
||||
# monitor threads
|
||||
worker_primary_thread = threading.Thread(
|
||||
target=monitor_process, args=("PRIMARY", worker_primary_process)
|
||||
@@ -191,47 +202,40 @@ def run_jobs() -> None:
|
||||
worker_light_thread = threading.Thread(
|
||||
target=monitor_process, args=("LIGHT", worker_light_process)
|
||||
)
|
||||
worker_heavy_thread = threading.Thread(
|
||||
target=monitor_process, args=("HEAVY", worker_heavy_process)
|
||||
)
|
||||
worker_docprocessing_thread = threading.Thread(
|
||||
target=monitor_process, args=("DOCPROCESSING", worker_docprocessing_process)
|
||||
)
|
||||
worker_user_file_thread = threading.Thread(
|
||||
target=monitor_process,
|
||||
args=("USER_FILE_PROCESSING", worker_user_file_process),
|
||||
)
|
||||
worker_monitoring_thread = threading.Thread(
|
||||
target=monitor_process, args=("MONITORING", worker_monitoring_process)
|
||||
)
|
||||
worker_kg_processing_thread = threading.Thread(
|
||||
target=monitor_process, args=("KG_PROCESSING", worker_kg_processing_process)
|
||||
)
|
||||
worker_docfetching_thread = threading.Thread(
|
||||
target=monitor_process, args=("DOCFETCHING", worker_docfetching_process)
|
||||
)
|
||||
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
|
||||
|
||||
# Create monitor threads for background workers
|
||||
background_threads = []
|
||||
for name, process in background_processes:
|
||||
thread = threading.Thread(target=monitor_process, args=(name, process))
|
||||
background_threads.append(thread)
|
||||
|
||||
# Start all threads
|
||||
worker_primary_thread.start()
|
||||
worker_light_thread.start()
|
||||
worker_heavy_thread.start()
|
||||
worker_docprocessing_thread.start()
|
||||
worker_user_file_thread.start()
|
||||
worker_monitoring_thread.start()
|
||||
worker_kg_processing_thread.start()
|
||||
worker_docfetching_thread.start()
|
||||
beat_thread.start()
|
||||
|
||||
for thread in background_threads:
|
||||
thread.start()
|
||||
|
||||
# Wait for all threads
|
||||
worker_primary_thread.join()
|
||||
worker_light_thread.join()
|
||||
worker_heavy_thread.join()
|
||||
worker_docprocessing_thread.join()
|
||||
worker_user_file_thread.join()
|
||||
worker_monitoring_thread.join()
|
||||
worker_kg_processing_thread.join()
|
||||
worker_docfetching_thread.join()
|
||||
beat_thread.join()
|
||||
|
||||
for thread in background_threads:
|
||||
thread.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_jobs()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user