Compare commits

...

1 Commits

Author SHA1 Message Date
Wenxi
cd96de146e hallo (#5738)
Co-authored-by: Richard Guan <41275416+rguan72@users.noreply.github.com>
Co-authored-by: Justin Tahara <105671973+justin-tahara@users.noreply.github.com>
Co-authored-by: Chris Weaver <chris@onyx.app>
Co-authored-by: Evan Lohn <evan@danswer.ai>
Co-authored-by: SubashMohan <subashmohan75@gmail.com>
Co-authored-by: Raunak Bhagat <r@rabh.io>
Co-authored-by: edwin-onyx <edwin@onyx.app>
Co-authored-by: Edwin Luo <edwin@parafin.com>
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
Co-authored-by: Nikolas Garza <nikolas@Nikolass-MacBook-Pro.attlocal.net>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: Jessica Singh <86633231+jessicasingh7@users.noreply.github.com>
2025-10-15 18:34:12 -07:00
264 changed files with 15328 additions and 2493 deletions

View File

@@ -23,12 +23,10 @@
"Slack Bot",
"Celery primary",
"Celery light",
"Celery heavy",
"Celery background",
"Celery docfetching",
"Celery docprocessing",
"Celery beat",
"Celery monitoring",
"Celery user file processing"
"Celery beat"
],
"presentation": {
"group": "1"
@@ -42,16 +40,29 @@
}
},
{
"name": "Celery (all)",
"name": "Celery (lightweight mode)",
"configurations": [
"Celery primary",
"Celery background",
"Celery beat"
],
"presentation": {
"group": "1"
},
"stopAll": true
},
{
"name": "Celery (standard mode)",
"configurations": [
"Celery primary",
"Celery light",
"Celery heavy",
"Celery kg_processing",
"Celery monitoring",
"Celery user_file_processing",
"Celery docfetching",
"Celery docprocessing",
"Celery beat",
"Celery monitoring",
"Celery user file processing"
"Celery beat"
],
"presentation": {
"group": "1"
@@ -199,6 +210,35 @@
},
"consoleTitle": "Celery light Console"
},
{
"name": "Celery background",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.background",
"worker",
"--pool=threads",
"--concurrency=20",
"--prefetch-multiplier=4",
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,user_files_indexing,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery background Console"
},
{
"name": "Celery heavy",
"type": "debugpy",
@@ -221,13 +261,100 @@
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync"
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery heavy Console"
},
{
"name": "Celery kg_processing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.kg_processing",
"worker",
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=kg_processing@%n",
"-Q",
"kg_processing"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery kg_processing Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery monitoring Console"
},
{
"name": "Celery user_file_processing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"-Q",
"user_file_processing,user_file_project_sync"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user_file_processing Console"
},
{
"name": "Celery docfetching",
"type": "debugpy",
@@ -311,58 +438,6 @@
},
"consoleTitle": "Celery beat Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=solo",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery monitoring Console"
},
{
"name": "Celery user file processing",
"type": "debugpy",
"request": "launch",
"module": "celery",
"args": [
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"--pool=threads",
"-Q",
"user_file_processing,user_file_project_sync"
],
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user file processing Console"
},
{
"name": "Pytest",
"consoleName": "Pytest",

View File

@@ -70,7 +70,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Single thread (monitoring doesn't need parallelism)
- Cloud-specific monitoring tasks
8. **Beat Worker** (`beat`)
8. **User File Processing Worker** (`user_file_processing`)
- Processes user-uploaded files
- Handles user file indexing and project synchronization
- Configurable concurrency
9. **Beat Worker** (`beat`)
- Celery's scheduler for periodic tasks
- Uses DynamicTenantScheduler for multi-tenant support
- Schedules tasks like:
@@ -82,6 +87,31 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Monitoring tasks (every 5 minutes)
- Cleanup tasks (hourly)
#### Worker Deployment Modes
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
- Runs a single consolidated `background` worker that handles all background tasks:
- Pruning operations (from `heavy` worker)
- Knowledge graph processing (from `kg_processing` worker)
- Monitoring tasks (from `monitoring` worker)
- User file processing (from `user_file_processing` worker)
- Lower resource footprint (single worker process)
- Suitable for smaller deployments or development environments
- Default concurrency: 6 threads
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
- Runs separate specialized workers as documented above (heavy, kg_processing, monitoring, user_file_processing)
- Better isolation and scalability
- Can scale individual workers independently based on workload
- Suitable for production deployments with higher load
The deployment mode affects:
- **Backend**: Worker processes spawned by supervisord or dev scripts
- **Helm**: Which Kubernetes deployments are created
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
#### Key Features
- **Thread-based Workers**: All workers use thread pools (not processes) for stability

View File

@@ -70,7 +70,12 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Single thread (monitoring doesn't need parallelism)
- Cloud-specific monitoring tasks
8. **Beat Worker** (`beat`)
8. **User File Processing Worker** (`user_file_processing`)
- Processes user-uploaded files
- Handles user file indexing and project synchronization
- Configurable concurrency
9. **Beat Worker** (`beat`)
- Celery's scheduler for periodic tasks
- Uses DynamicTenantScheduler for multi-tenant support
- Schedules tasks like:
@@ -82,11 +87,39 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Monitoring tasks (every 5 minutes)
- Cleanup tasks (hourly)
#### Worker Deployment Modes
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
- Runs a single consolidated `background` worker that handles all background tasks:
- Light worker tasks (Vespa operations, permissions sync, deletion)
- Document processing (indexing pipeline)
- Document fetching (connector data retrieval)
- Pruning operations (from `heavy` worker)
- Knowledge graph processing (from `kg_processing` worker)
- Monitoring tasks (from `monitoring` worker)
- User file processing (from `user_file_processing` worker)
- Lower resource footprint (fewer worker processes)
- Suitable for smaller deployments or development environments
- Default concurrency: 20 threads (increased to handle combined workload)
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
- Runs separate specialized workers as documented above (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
- Better isolation and scalability
- Can scale individual workers independently based on workload
- Suitable for production deployments with higher load
The deployment mode affects:
- **Backend**: Worker processes spawned by supervisord or dev scripts
- **Helm**: Which Kubernetes deployments are created
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
#### Key Features
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
middleware layer that automatically finds the appropriate tenant ID when sending tasks
- **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a
middleware layer that automatically finds the appropriate tenant ID when sending tasks
via Celery Beat.
- **Task Prioritization**: High, Medium, Low priority queues
- **Monitoring**: Built-in heartbeat and liveness checking

View File

@@ -111,6 +111,8 @@ COPY ./static /app/static
# Escape hatch scripts
COPY ./scripts/debugging /app/scripts/debugging
COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
COPY ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
RUN chmod +x /app/scripts/supervisord_entrypoint.sh
# Put logo in assets
COPY ./assets /app/assets

View File

@@ -0,0 +1,153 @@
"""add permission sync attempt tables
Revision ID: 03d710ccf29c
Revises: 96a5702df6aa
Create Date: 2025-09-11 13:30:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "03d710ccf29c"
down_revision = "96a5702df6aa"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the two permission-sync attempt tracking tables and their indexes.

    Adds ``doc_permission_sync_attempt`` (per-document permission syncs) and
    ``external_group_permission_sync_attempt`` (external group membership
    syncs), both keyed to ``connector_credential_pair`` and timestamped for
    "latest attempt" queries.
    """
    # Create the permission sync status enum
    # native_enum=False: stored as VARCHAR with a constraint instead of a
    # Postgres ENUM type, so no separate DB type needs managing on downgrade.
    permission_sync_status_enum = sa.Enum(
        "not_started",
        "in_progress",
        "success",
        "canceled",
        "failed",
        "completed_with_errors",
        name="permissionsyncstatus",
        native_enum=False,
    )
    # Create doc_permission_sync_attempt table
    op.create_table(
        "doc_permission_sync_attempt",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
        sa.Column("status", permission_sync_status_enum, nullable=False),
        sa.Column("total_docs_synced", sa.Integer(), nullable=True),
        sa.Column("docs_with_permission_errors", sa.Integer(), nullable=True),
        sa.Column("error_message", sa.Text(), nullable=True),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("time_started", sa.DateTime(timezone=True), nullable=True),
        sa.Column("time_finished", sa.DateTime(timezone=True), nullable=True),
        sa.ForeignKeyConstraint(
            ["connector_credential_pair_id"],
            ["connector_credential_pair.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    # Create indexes for doc_permission_sync_attempt
    op.create_index(
        "ix_doc_permission_sync_attempt_time_created",
        "doc_permission_sync_attempt",
        ["time_created"],
        unique=False,
    )
    # Composite index supports "most recent attempt for a cc_pair" lookups.
    op.create_index(
        "ix_permission_sync_attempt_latest_for_cc_pair",
        "doc_permission_sync_attempt",
        ["connector_credential_pair_id", "time_created"],
        unique=False,
    )
    # DESC on time_finished favors "newest finished attempts by status" scans.
    op.create_index(
        "ix_permission_sync_attempt_status_time",
        "doc_permission_sync_attempt",
        ["status", sa.text("time_finished DESC")],
        unique=False,
    )
    # Create external_group_permission_sync_attempt table
    # connector_credential_pair_id is nullable - group syncs can be global (e.g., Confluence)
    op.create_table(
        "external_group_permission_sync_attempt",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("connector_credential_pair_id", sa.Integer(), nullable=True),
        sa.Column("status", permission_sync_status_enum, nullable=False),
        sa.Column("total_users_processed", sa.Integer(), nullable=True),
        sa.Column("total_groups_processed", sa.Integer(), nullable=True),
        sa.Column("total_group_memberships_synced", sa.Integer(), nullable=True),
        sa.Column("error_message", sa.Text(), nullable=True),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("time_started", sa.DateTime(timezone=True), nullable=True),
        sa.Column("time_finished", sa.DateTime(timezone=True), nullable=True),
        sa.ForeignKeyConstraint(
            ["connector_credential_pair_id"],
            ["connector_credential_pair.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    # Create indexes for external_group_permission_sync_attempt
    op.create_index(
        "ix_external_group_permission_sync_attempt_time_created",
        "external_group_permission_sync_attempt",
        ["time_created"],
        unique=False,
    )
    op.create_index(
        "ix_group_sync_attempt_cc_pair_time",
        "external_group_permission_sync_attempt",
        ["connector_credential_pair_id", "time_created"],
        unique=False,
    )
    op.create_index(
        "ix_group_sync_attempt_status_time",
        "external_group_permission_sync_attempt",
        ["status", sa.text("time_finished DESC")],
        unique=False,
    )
def downgrade() -> None:
    """Drop the permission-sync attempt tables and their indexes.

    Indexes are dropped before their tables, in reverse creation order.
    Because the status enum was created with native_enum=False, there is no
    standalone database ENUM type to drop here.
    """
    # Drop indexes
    op.drop_index(
        "ix_group_sync_attempt_status_time",
        table_name="external_group_permission_sync_attempt",
    )
    op.drop_index(
        "ix_group_sync_attempt_cc_pair_time",
        table_name="external_group_permission_sync_attempt",
    )
    op.drop_index(
        "ix_external_group_permission_sync_attempt_time_created",
        table_name="external_group_permission_sync_attempt",
    )
    op.drop_index(
        "ix_permission_sync_attempt_status_time",
        table_name="doc_permission_sync_attempt",
    )
    op.drop_index(
        "ix_permission_sync_attempt_latest_for_cc_pair",
        table_name="doc_permission_sync_attempt",
    )
    op.drop_index(
        "ix_doc_permission_sync_attempt_time_created",
        table_name="doc_permission_sync_attempt",
    )
    # Drop tables
    op.drop_table("external_group_permission_sync_attempt")
    op.drop_table("doc_permission_sync_attempt")

View File

@@ -0,0 +1,37 @@
"""add queries and is web fetch to iteration answer
Revision ID: 6f4f86aef280
Revises: 03d710ccf29c
Create Date: 2025-10-14 18:08:30.920123
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "6f4f86aef280"
down_revision = "03d710ccf29c"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add nullable `is_web_fetch` and `queries` columns to the
    research_agent_iteration_sub_step table.

    Both columns are nullable so existing rows need no backfill.
    """
    target_table = "research_agent_iteration_sub_step"
    new_columns = (
        sa.Column("is_web_fetch", sa.Boolean(), nullable=True),
        sa.Column("queries", postgresql.JSONB(), nullable=True),
    )
    for new_column in new_columns:
        op.add_column(target_table, new_column)
def downgrade() -> None:
    """Remove the columns added by upgrade(), in reverse order."""
    for column_name in ("queries", "is_web_fetch"):
        op.drop_column("research_agent_iteration_sub_step", column_name)

View File

@@ -0,0 +1,45 @@
"""mcp_tool_enabled
Revision ID: 96a5702df6aa
Revises: 40926a4dab77
Create Date: 2025-10-09 12:10:21.733097
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "96a5702df6aa"
down_revision = "40926a4dab77"
branch_labels = None
depends_on = None
DELETE_DISABLED_TOOLS_SQL = "DELETE FROM tool WHERE enabled = false"
def upgrade() -> None:
    """Add an `enabled` flag to the tool table, indexed per MCP server.

    The column is created NOT NULL with a temporary server default of TRUE so
    existing rows are backfilled as enabled; the default is then removed so
    the application layer controls defaulting going forward.
    """
    op.add_column(
        "tool",
        sa.Column(
            "enabled",
            sa.Boolean(),
            nullable=False,
            server_default=sa.true(),
        ),
    )
    # Supports filtering a server's tools by enabled state.
    op.create_index(
        "ix_tool_mcp_server_enabled",
        "tool",
        ["mcp_server_id", "enabled"],
    )
    # Remove the server default so application controls defaulting
    op.alter_column("tool", "enabled", server_default=None)
def downgrade() -> None:
    """Remove the `enabled` flag from the tool table.

    Disabled tools cannot be represented once the column is gone, so their
    rows are deleted first (the DELETE references the column, so it must run
    before drop_column).
    """
    op.execute(DELETE_DISABLED_TOOLS_SQL)
    op.drop_index("ix_tool_mcp_server_enabled", table_name="tool")
    op.drop_column("tool", "enabled")

View File

@@ -1,8 +1,13 @@
import json
from datetime import datetime
from enum import Enum
from functools import lru_cache
from typing import Any
from typing import cast
import jwt
import requests
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
@@ -10,6 +15,7 @@ from fastapi import status
from jwt import decode as jwt_decode
from jwt import InvalidTokenError
from jwt import PyJWTError
from jwt.algorithms import RSAAlgorithm
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -30,43 +36,156 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_PUBLIC_KEY_FETCH_ATTEMPTS = 2
class PublicKeyFormat(Enum):
JWKS = "jwks"
PEM = "pem"
@lru_cache()
def get_public_key() -> str | None:
def _fetch_public_key_payload() -> tuple[str | dict[str, Any], PublicKeyFormat] | None:
"""Fetch and cache the raw JWT verification material."""
if JWT_PUBLIC_KEY_URL is None:
logger.error("JWT_PUBLIC_KEY_URL is not set")
return None
response = requests.get(JWT_PUBLIC_KEY_URL)
response.raise_for_status()
return response.text
try:
response = requests.get(JWT_PUBLIC_KEY_URL)
response.raise_for_status()
except requests.RequestException as exc:
logger.error(f"Failed to fetch JWT public key: {str(exc)}")
return None
content_type = response.headers.get("Content-Type", "").lower()
raw_body = response.text
body_lstripped = raw_body.lstrip()
if "application/json" in content_type or body_lstripped.startswith("{"):
try:
data = response.json()
except ValueError:
logger.error("JWT public key URL returned invalid JSON")
return None
if isinstance(data, dict) and "keys" in data:
return data, PublicKeyFormat.JWKS
logger.error(
"JWT public key URL returned JSON but no JWKS 'keys' field was found"
)
return None
body = raw_body.strip()
if not body:
logger.error("JWT public key URL returned an empty response")
return None
return body, PublicKeyFormat.PEM
def get_public_key(token: str) -> RSAPublicKey | str | None:
"""Return the concrete public key used to verify the provided JWT token."""
payload = _fetch_public_key_payload()
if payload is None:
logger.error("Failed to retrieve public key payload")
return None
key_material, key_format = payload
if key_format is PublicKeyFormat.JWKS:
jwks_data = cast(dict[str, Any], key_material)
return _resolve_public_key_from_jwks(token, jwks_data)
return cast(str, key_material)
def _resolve_public_key_from_jwks(
token: str, jwks_payload: dict[str, Any]
) -> RSAPublicKey | None:
try:
header = jwt.get_unverified_header(token)
except PyJWTError as e:
logger.error(f"Unable to parse JWT header: {str(e)}")
return None
keys = jwks_payload.get("keys", []) if isinstance(jwks_payload, dict) else []
if not keys:
logger.error("JWKS payload did not contain any keys")
return None
kid = header.get("kid")
thumbprint = header.get("x5t")
candidates = []
if kid:
candidates = [k for k in keys if k.get("kid") == kid]
if not candidates and thumbprint:
candidates = [k for k in keys if k.get("x5t") == thumbprint]
if not candidates and len(keys) == 1:
candidates = keys
if not candidates:
logger.warning(
"No matching JWK found for token header (kid=%s, x5t=%s)", kid, thumbprint
)
return None
if len(candidates) > 1:
logger.warning(
"Multiple JWKs matched token header kid=%s; selecting the first occurrence",
kid,
)
jwk = candidates[0]
try:
return cast(RSAPublicKey, RSAAlgorithm.from_jwk(json.dumps(jwk)))
except ValueError as e:
logger.error(f"Failed to construct RSA key from JWK: {str(e)}")
return None
async def verify_jwt_token(token: str, async_db_session: AsyncSession) -> User | None:
try:
public_key_pem = get_public_key()
if public_key_pem is None:
logger.error("Failed to retrieve public key")
for attempt in range(_PUBLIC_KEY_FETCH_ATTEMPTS):
public_key = get_public_key(token)
if public_key is None:
logger.error("Unable to resolve a public key for JWT verification")
if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
_fetch_public_key_payload.cache_clear()
continue
return None
try:
payload = jwt_decode(
token,
public_key,
algorithms=["RS256"],
options={"verify_aud": False},
)
except InvalidTokenError as e:
logger.error(f"Invalid JWT token: {str(e)}")
if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
_fetch_public_key_payload.cache_clear()
continue
return None
except PyJWTError as e:
logger.error(f"JWT decoding error: {str(e)}")
if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
_fetch_public_key_payload.cache_clear()
continue
return None
payload = jwt_decode(
token,
public_key_pem,
algorithms=["RS256"],
audience=None,
)
email = payload.get("email")
if email:
result = await async_db_session.execute(
select(User).where(func.lower(User.email) == func.lower(email))
)
return result.scalars().first()
except InvalidTokenError:
logger.error("Invalid JWT token")
get_public_key.cache_clear()
except PyJWTError as e:
logger.error(f"JWT decoding error: {str(e)}")
get_public_key.cache_clear()
logger.warning(
"JWT token decoded successfully but no email claim found; skipping auth"
)
break
return None

View File

@@ -0,0 +1,12 @@
from onyx.background.celery.apps.background import celery_app

# EE task packages the consolidated background worker must register.
_EE_BACKGROUND_TASK_PACKAGES = [
    "ee.onyx.background.celery.tasks.doc_permission_syncing",
    "ee.onyx.background.celery.tasks.external_group_syncing",
    "ee.onyx.background.celery.tasks.cleanup",
    "ee.onyx.background.celery.tasks.tenant_provisioning",
    "ee.onyx.background.celery.tasks.query_history",
]

celery_app.autodiscover_tasks(_EE_BACKGROUND_TASK_PACKAGES)

View File

@@ -1,123 +1,4 @@
import csv
import io
from datetime import datetime
from celery import shared_task
from celery import Task
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.background.celery.apps.heavy import celery_app
from onyx.background.task_utils import construct_query_history_report_name
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import QueryHistoryType
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.tasks import delete_task_with_id
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import mark_task_as_started_with_id
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
    name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    trail=False,
)
def export_query_history_task(
    self: Task,
    *,
    start: datetime,
    end: datetime,
    start_time: datetime,
    # Need to include the tenant_id since the TenantAwareTask needs this
    tenant_id: str,
) -> None:
    """Export chat-session query history in [start, end] to a CSV file.

    Streams Q/A pair snapshots into an in-memory CSV (anonymizing user emails
    when ONYX_QUERY_HISTORY_TYPE is ANONYMIZED), then persists the CSV to the
    default file store under a name derived from the Celery task id. The
    tracking task row is marked started up front, deleted on success, and
    marked finished-with-failure on any error before re-raising.

    Raises:
        RuntimeError: if the Celery request carries no task id (the id names
            both the tracking row and the report file).
    """
    if not self.request.id:
        raise RuntimeError("No task id defined for this task; cannot identify it")
    task_id = self.request.id

    stream = io.StringIO()
    # CSV columns mirror the QuestionAnswerPairSnapshot model fields.
    writer = csv.DictWriter(
        stream,
        fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
    )
    writer.writeheader()
    with get_session_with_current_tenant() as db_session:
        try:
            mark_task_as_started_with_id(
                db_session=db_session,
                task_id=task_id,
            )
            snapshot_generator = fetch_and_process_chat_session_history(
                db_session=db_session,
                start=start,
                end=end,
            )
            for snapshot in snapshot_generator:
                # Scrub user identity when query history is anonymized.
                if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
                    snapshot.user_email = ONYX_ANONYMIZED_EMAIL
                writer.writerows(
                    qa_pair.to_json()
                    for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
                        snapshot
                    )
                )
        except Exception:
            logger.exception(f"Failed to export query history with {task_id=}")
            # Record the failure on the tracking row, then propagate.
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=False,
            )
            raise

    report_name = construct_query_history_report_name(task_id)
    # Second session: persist the CSV and clean up the tracking row.
    with get_session_with_current_tenant() as db_session:
        try:
            stream.seek(0)
            get_default_file_store().save_file(
                content=stream,
                display_name=report_name,
                file_origin=FileOrigin.QUERY_HISTORY_CSV,
                file_type=FileType.CSV,
                file_metadata={
                    "start": start.isoformat(),
                    "end": end.isoformat(),
                    "start_time": start_time.isoformat(),
                },
                file_id=report_name,
            )
            delete_task_with_id(
                db_session=db_session,
                task_id=task_id,
            )
        except Exception:
            logger.exception(
                f"Failed to save query history export file; {report_name=}"
            )
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=False,
            )
            raise
celery_app.autodiscover_tasks(
@@ -125,5 +6,6 @@ celery_app.autodiscover_tasks(
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.query_history",
]
)

View File

@@ -56,6 +56,12 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.permission_sync_attempt import complete_doc_permission_sync_attempt
from onyx.db.permission_sync_attempt import create_doc_permission_sync_attempt
from onyx.db.permission_sync_attempt import mark_doc_permission_sync_attempt_failed
from onyx.db.permission_sync_attempt import (
mark_doc_permission_sync_attempt_in_progress,
)
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
@@ -113,6 +119,14 @@ def _get_fence_validation_block_expiration() -> int:
"""Jobs / utils for kicking off doc permissions sync tasks."""
def _fail_doc_permission_sync_attempt(attempt_id: int, error_msg: str) -> None:
"""Helper to mark a doc permission sync attempt as failed with an error message."""
with get_session_with_current_tenant() as db_session:
mark_doc_permission_sync_attempt_failed(
attempt_id, db_session, error_message=error_msg
)
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if external doc permissions sync is due."""
@@ -379,6 +393,15 @@ def connector_permission_sync_generator_task(
doc_permission_sync_ctx_dict["request_id"] = self.request.id
doc_permission_sync_ctx.set(doc_permission_sync_ctx_dict)
with get_session_with_current_tenant() as db_session:
attempt_id = create_doc_permission_sync_attempt(
connector_credential_pair_id=cc_pair_id,
db_session=db_session,
)
task_logger.info(
f"Created doc permission sync attempt: {attempt_id} for cc_pair={cc_pair_id}"
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client()
@@ -389,22 +412,28 @@ def connector_permission_sync_generator_task(
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
error_msg = (
f"connector_permission_sync_generator_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.permissions.fence_key}"
)
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
raise ValueError(error_msg)
if not redis_connector.permissions.fenced: # The fence must exist
raise ValueError(
error_msg = (
f"connector_permission_sync_generator_task - fence not found: "
f"fence={redis_connector.permissions.fence_key}"
)
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
raise ValueError(error_msg)
payload = redis_connector.permissions.payload # The payload must exist
if not payload:
raise ValueError(
error_msg = (
"connector_permission_sync_generator_task: payload invalid or not found"
)
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
raise ValueError(error_msg)
if payload.celery_task_id is None:
logger.info(
@@ -432,9 +461,11 @@ def connector_permission_sync_generator_task(
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
error_msg = (
f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}"
)
task_logger.warning(error_msg)
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
return None
try:
@@ -470,11 +501,15 @@ def connector_permission_sync_generator_task(
source_type = cc_pair.connector.source
sync_config = get_source_perm_sync_config(source_type)
if sync_config is None:
logger.error(f"No sync config found for {source_type}")
error_msg = f"No sync config found for {source_type}"
logger.error(error_msg)
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
return None
if sync_config.doc_sync_config is None:
if sync_config.censoring_config:
error_msg = f"Doc sync config is None but censoring config exists for {source_type}"
_fail_doc_permission_sync_attempt(attempt_id, error_msg)
return None
raise ValueError(
@@ -483,6 +518,8 @@ def connector_permission_sync_generator_task(
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
mark_doc_permission_sync_attempt_in_progress(attempt_id, db_session)
payload = redis_connector.permissions.payload
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
@@ -533,8 +570,9 @@ def connector_permission_sync_generator_task(
)
tasks_generated = 0
docs_with_errors = 0
for doc_external_access in document_external_accesses:
redis_connector.permissions.update_db(
result = redis_connector.permissions.update_db(
lock=lock,
new_permissions=[doc_external_access],
source_string=source_type,
@@ -542,11 +580,23 @@ def connector_permission_sync_generator_task(
credential_id=cc_pair.credential.id,
task_logger=task_logger,
)
tasks_generated += 1
tasks_generated += result.num_updated
docs_with_errors += result.num_errors
task_logger.info(
f"RedisConnector.permissions.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated} docs_with_errors={docs_with_errors}"
)
complete_doc_permission_sync_attempt(
db_session=db_session,
attempt_id=attempt_id,
total_docs_synced=tasks_generated,
docs_with_permission_errors=docs_with_errors,
)
task_logger.info(
f"Completed doc permission sync attempt {attempt_id}: "
f"{tasks_generated} docs, {docs_with_errors} errors"
)
redis_connector.permissions.generator_complete = tasks_generated
@@ -561,6 +611,11 @@ def connector_permission_sync_generator_task(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}"
)
with get_session_with_current_tenant() as db_session:
mark_doc_permission_sync_attempt_failed(
attempt_id, db_session, error_message=error_msg
)
redis_connector.permissions.generator_clear()
redis_connector.permissions.taskset_clear()
redis_connector.permissions.set_fence(None)

View File

@@ -49,6 +49,16 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.permission_sync_attempt import complete_external_group_sync_attempt
from onyx.db.permission_sync_attempt import (
create_external_group_sync_attempt,
)
from onyx.db.permission_sync_attempt import (
mark_external_group_sync_attempt_failed,
)
from onyx.db.permission_sync_attempt import (
mark_external_group_sync_attempt_in_progress,
)
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
@@ -70,6 +80,14 @@ logger = setup_logger()
_EXTERNAL_GROUP_BATCH_SIZE = 100
def _fail_external_group_sync_attempt(attempt_id: int, error_msg: str) -> None:
    """Persist a failure on the given external group sync attempt.

    Opens a short-lived tenant-scoped session solely to record the error
    message against the attempt row.
    """
    with get_session_with_current_tenant() as session:
        mark_external_group_sync_attempt_failed(
            attempt_id,
            session,
            error_message=error_msg,
        )
def _get_fence_validation_block_expiration() -> int:
"""
Compute the expiration time for the fence validation block signal.
@@ -449,6 +467,16 @@ def _perform_external_group_sync(
cc_pair_id: int,
tenant_id: str,
) -> None:
# Create attempt record at the start
with get_session_with_current_tenant() as db_session:
attempt_id = create_external_group_sync_attempt(
connector_credential_pair_id=cc_pair_id,
db_session=db_session,
)
logger.info(
f"Created external group sync attempt: {attempt_id} for cc_pair={cc_pair_id}"
)
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
@@ -463,11 +491,13 @@ def _perform_external_group_sync(
if sync_config is None:
msg = f"No sync config found for {source_type} for cc_pair: {cc_pair_id}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
_fail_external_group_sync_attempt(attempt_id, msg)
raise ValueError(msg)
if sync_config.group_sync_config is None:
msg = f"No group sync config found for {source_type} for cc_pair: {cc_pair_id}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
_fail_external_group_sync_attempt(attempt_id, msg)
raise ValueError(msg)
ext_group_sync_func = sync_config.group_sync_config.group_sync_func
@@ -477,14 +507,27 @@ def _perform_external_group_sync(
)
mark_old_external_groups_as_stale(db_session, cc_pair_id)
# Mark attempt as in progress
mark_external_group_sync_attempt_in_progress(attempt_id, db_session)
logger.info(f"Marked external group sync attempt {attempt_id} as in progress")
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_group_batch: list[ExternalUserGroup] = []
seen_users: set[str] = set() # Track unique users across all groups
total_groups_processed = 0
total_group_memberships_synced = 0
try:
external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
for external_user_group in external_user_group_generator:
external_user_group_batch.append(external_user_group)
# Track progress
total_groups_processed += 1
total_group_memberships_synced += len(external_user_group.user_emails)
seen_users = seen_users.union(external_user_group.user_emails)
if len(external_user_group_batch) >= _EXTERNAL_GROUP_BATCH_SIZE:
logger.debug(
f"New external user groups: {external_user_group_batch}"
@@ -506,6 +549,13 @@ def _perform_external_group_sync(
source=cc_pair.connector.source,
)
except Exception as e:
error_msg = format_error_for_logging(e)
# Mark as failed (this also updates progress to show partial progress)
mark_external_group_sync_attempt_failed(
attempt_id, db_session, error_message=error_msg
)
# TODO: add some notification to the admins here
logger.exception(
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
@@ -517,6 +567,24 @@ def _perform_external_group_sync(
)
remove_stale_external_groups(db_session, cc_pair_id)
# Calculate total unique users processed
total_users_processed = len(seen_users)
# Complete the sync attempt with final progress
complete_external_group_sync_attempt(
db_session=db_session,
attempt_id=attempt_id,
total_users_processed=total_users_processed,
total_groups_processed=total_groups_processed,
total_group_memberships_synced=total_group_memberships_synced,
errors_encountered=0,
)
logger.info(
f"Completed external group sync attempt {attempt_id}: "
f"{total_groups_processed} groups, {total_users_processed} users, "
f"{total_group_memberships_synced} memberships"
)
mark_all_relevant_cc_pairs_as_external_group_synced(db_session, cc_pair)

View File

@@ -0,0 +1,119 @@
import csv
import io
from datetime import datetime
from celery import shared_task
from celery import Task
from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.background.task_utils import construct_query_history_report_name
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import QueryHistoryType
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.tasks import delete_task_with_id
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import mark_task_as_started_with_id
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
    name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    trail=False,
)
def export_query_history_task(
    self: Task,
    *,
    start: datetime,
    end: datetime,
    start_time: datetime,
    # Need to include the tenant_id since the TenantAwareTask needs this
    tenant_id: str,
) -> None:
    """Export chat session history in [start, end] to a CSV file in the file store.

    Lifecycle: marks the tracking task row as started, streams Q/A snapshots
    into an in-memory CSV, saves the CSV to the default file store, then
    deletes the task row (success). On any failure the task row is marked
    finished with success=False and the exception is re-raised so Celery
    records the failure.
    """
    # The Celery request id doubles as the tracking-task id and is embedded
    # in the report name, so we cannot proceed without it.
    if not self.request.id:
        raise RuntimeError("No task id defined for this task; cannot identify it")
    task_id = self.request.id

    # Build the whole CSV in memory; header columns mirror the snapshot model.
    stream = io.StringIO()
    writer = csv.DictWriter(
        stream,
        fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
    )
    writer.writeheader()

    with get_session_with_current_tenant() as db_session:
        try:
            mark_task_as_started_with_id(
                db_session=db_session,
                task_id=task_id,
            )
            snapshot_generator = fetch_and_process_chat_session_history(
                db_session=db_session,
                start=start,
                end=end,
            )
            for snapshot in snapshot_generator:
                # Scrub user emails when the deployment is configured for
                # anonymized query history.
                if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
                    snapshot.user_email = ONYX_ANONYMIZED_EMAIL

                writer.writerows(
                    qa_pair.to_json()
                    for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
                        snapshot
                    )
                )
        except Exception:
            logger.exception(f"Failed to export query history with {task_id=}")
            # Record the failure on the tracking row before propagating.
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=False,
            )
            raise

    report_name = construct_query_history_report_name(task_id)
    # Fresh session for the save/cleanup phase so the export phase's session
    # state cannot leak into it.
    with get_session_with_current_tenant() as db_session:
        try:
            # Rewind before handing the buffer to the file store.
            stream.seek(0)
            get_default_file_store().save_file(
                content=stream,
                display_name=report_name,
                file_origin=FileOrigin.QUERY_HISTORY_CSV,
                file_type=FileType.CSV,
                file_metadata={
                    "start": start.isoformat(),
                    "end": end.isoformat(),
                    "start_time": start_time.isoformat(),
                },
                file_id=report_name,
            )

            # On success the tracking row is deleted (not marked finished) —
            # the saved file itself is the durable record of the export.
            delete_task_with_id(
                db_session=db_session,
                task_id=task_id,
            )
        except Exception:
            logger.exception(
                f"Failed to save query history export file; {report_name=}"
            )
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=False,
            )
            raise

View File

@@ -73,6 +73,12 @@ def fetch_per_user_query_analytics(
ChatSession.user_id,
)
.join(ChatSession, ChatSession.id == ChatMessage.chat_session_id)
# Include chats that have no explicit feedback instead of dropping them
.join(
ChatMessageFeedback,
ChatMessageFeedback.chat_message_id == ChatMessage.id,
isouter=True,
)
.where(
ChatMessage.time_sent >= start,
)

View File

@@ -0,0 +1,15 @@
from ee.onyx.feature_flags.posthog_provider import PostHogFeatureFlagProvider
from onyx.feature_flags.interface import FeatureFlagProvider
def get_posthog_feature_flag_provider() -> FeatureFlagProvider:
    """Build the EE feature flag provider backed by PostHog.

    This is the enterprise implementation resolved by the versioned
    implementation loader in place of the default provider.

    Returns:
        PostHogFeatureFlagProvider: The PostHog-based feature flag provider
    """
    provider: FeatureFlagProvider = PostHogFeatureFlagProvider()
    return provider

View File

@@ -0,0 +1,54 @@
from typing import Any
from uuid import UUID
from ee.onyx.utils.posthog_client import posthog
from onyx.feature_flags.interface import FeatureFlagProvider
from onyx.utils.logger import setup_logger
logger = setup_logger()
class PostHogFeatureFlagProvider(FeatureFlagProvider):
    """
    PostHog-based feature flag provider.

    Uses PostHog's feature flag API to determine if features are enabled
    for specific users. Only active in multi-tenant mode.
    """

    def feature_enabled(
        self,
        flag_key: str,
        user_id: UUID,
        user_properties: dict[str, Any] | None = None,
    ) -> bool:
        """
        Check if a feature flag is enabled for a user via PostHog.

        Args:
            flag_key: The identifier for the feature flag to check
            user_id: The unique identifier for the user
            user_properties: Optional dictionary of user properties/attributes
                             that may influence flag evaluation

        Returns:
            True if the feature is enabled for the user, False otherwise.
            Fails closed: any evaluation error returns False.
        """
        try:
            # Use one stringified distinct_id for both calls. Previously the
            # raw UUID object was sent to `set` while `feature_enabled`
            # received str(user_id); inconsistent distinct IDs can split the
            # person profile in PostHog.
            distinct_id = str(user_id)
            posthog.set(
                distinct_id=distinct_id,
                properties=user_properties,
            )
            is_enabled = posthog.feature_enabled(
                flag_key,
                distinct_id,
                person_properties=user_properties,
            )
            # The client may return None (e.g. unknown flag or transport
            # failure); treat that as disabled.
            return bool(is_enabled) if is_enabled is not None else False
        except Exception as e:
            logger.error(
                f"Error checking feature flag {flag_key} for user {user_id}: {e}"
            )
            return False

View File

@@ -0,0 +1,22 @@
from typing import Any
from posthog import Posthog
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_HOST
from onyx.utils.logger import setup_logger
logger = setup_logger()
def posthog_on_error(error: Any, items: Any) -> None:
    """Log any PostHog delivery errors."""
    # Invoked by the PostHog client when a batch fails to deliver; `items`
    # is the batch that could not be sent.
    logger.error(f"PostHog error: {error}, items: {items}")
# Shared module-level PostHog client, configured from the EE app configs.
# NOTE(review): debug=True makes the client log verbosely — presumably
# intentional for telemetry troubleshooting; confirm before production use.
posthog = Posthog(
    project_api_key=POSTHOG_API_KEY,
    host=POSTHOG_HOST,
    debug=True,
    on_error=posthog_on_error,
)

View File

@@ -1,27 +1,9 @@
from typing import Any
from posthog import Posthog
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_HOST
from ee.onyx.utils.posthog_client import posthog
from onyx.utils.logger import setup_logger
logger = setup_logger()
def posthog_on_error(error: Any, items: Any) -> None:
"""Log any PostHog delivery errors."""
logger.error(f"PostHog error: {error}, items: {items}")
posthog = Posthog(
project_api_key=POSTHOG_API_KEY,
host=POSTHOG_HOST,
debug=True,
on_error=posthog_on_error,
)
def event_telemetry(
distinct_id: str, event: str, properties: dict | None = None
) -> None:

View File

@@ -100,9 +100,14 @@ class IterationAnswer(BaseModel):
response_type: str | None = None
data: dict | list | str | int | float | bool | None = None
file_ids: list[str] | None = None
# TODO: This is not ideal, but we can rework the schema
# for deep research later
is_web_fetch: bool = False
# for image generation step-types
generated_images: list[GeneratedImage] | None = None
# for multi-query search tools (v2 web search and internal search)
# TODO: Clean this up to be more flexible to tools
queries: list[str] | None = None
class AggregatedDRContext(BaseModel):

View File

@@ -74,6 +74,7 @@ from onyx.prompts.dr_prompts import DEFAULT_DR_SYSTEM_PROMPT
from onyx.prompts.dr_prompts import REPEAT_PROMPT
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
from onyx.prompts.prompt_template import PromptTemplate
from onyx.prompts.prompt_utils import handle_company_awareness
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import SectionEnd
@@ -120,7 +121,9 @@ def _get_available_tools(
else:
include_kg = False
tool_dict: dict[int, Tool] = {tool.id: tool for tool in get_tools(db_session)}
tool_dict: dict[int, Tool] = {
tool.id: tool for tool in get_tools(db_session, only_enabled=True)
}
for tool in graph_config.tooling.tools:
@@ -488,6 +491,7 @@ def clarifier(
+ PROJECT_INSTRUCTIONS_SEPARATOR
+ graph_config.inputs.project_instructions
)
assistant_system_prompt = handle_company_awareness(assistant_system_prompt)
chat_history_string = (
get_chat_history_string(

View File

@@ -199,6 +199,7 @@ def save_iteration(
else None
),
additional_data=iteration_answer.additional_data,
queries=iteration_answer.queries,
)
db_session.add(research_agent_iteration_sub_step)

View File

@@ -180,6 +180,7 @@ def save_iteration(
else None
),
additional_data=iteration_answer.additional_data,
queries=iteration_answer.queries,
)
db_session.add(research_agent_iteration_sub_step)

View File

@@ -2,30 +2,28 @@ from exa_py import Exa
from exa_py.api import HighlightsContentsOptions
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetContent,
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchProvider,
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
WebSearchResult,
)
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.utils.retry_wrapper import retry_builder
# TODO Dependency inject for testing
class ExaClient(InternetSearchProvider):
class ExaClient(WebSearchProvider):
def __init__(self, api_key: str | None = EXA_API_KEY) -> None:
self.exa = Exa(api_key=api_key)
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[InternetSearchResult]:
def search(self, query: str) -> list[WebSearchResult]:
response = self.exa.search_and_contents(
query,
type="fast",
livecrawl="never",
type="auto",
highlights=HighlightsContentsOptions(
num_sentences=2,
highlights_per_url=1,
@@ -34,7 +32,7 @@ class ExaClient(InternetSearchProvider):
)
return [
InternetSearchResult(
WebSearchResult(
title=result.title or "",
link=result.url,
snippet=result.highlights[0] if result.highlights else "",
@@ -49,7 +47,7 @@ class ExaClient(InternetSearchProvider):
]
@retry_builder(tries=3, delay=1, backoff=2)
def contents(self, urls: list[str]) -> list[InternetContent]:
def contents(self, urls: list[str]) -> list[WebContent]:
response = self.exa.get_contents(
urls=urls,
text=True,
@@ -57,7 +55,7 @@ class ExaClient(InternetSearchProvider):
)
return [
InternetContent(
WebContent(
title=result.title or "",
link=result.url,
full_content=result.text or "",

View File

@@ -4,13 +4,13 @@ from concurrent.futures import ThreadPoolExecutor
import requests
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetContent,
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchProvider,
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
WebSearchResult,
)
from onyx.configs.chat_configs import SERPER_API_KEY
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
@@ -20,7 +20,7 @@ SERPER_SEARCH_URL = "https://google.serper.dev/search"
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
class SerperClient(InternetSearchProvider):
class SerperClient(WebSearchProvider):
def __init__(self, api_key: str | None = SERPER_API_KEY) -> None:
self.headers = {
"X-API-KEY": api_key,
@@ -28,7 +28,7 @@ class SerperClient(InternetSearchProvider):
}
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[InternetSearchResult]:
def search(self, query: str) -> list[WebSearchResult]:
payload = {
"q": query,
}
@@ -45,7 +45,7 @@ class SerperClient(InternetSearchProvider):
organic_results = results["organic"]
return [
InternetSearchResult(
WebSearchResult(
title=result["title"],
link=result["link"],
snippet=result["snippet"],
@@ -55,17 +55,17 @@ class SerperClient(InternetSearchProvider):
for result in organic_results
]
def contents(self, urls: list[str]) -> list[InternetContent]:
def contents(self, urls: list[str]) -> list[WebContent]:
if not urls:
return []
# Serper can respond with 500s regularly. We want to retry,
# but in the event of failure, return an unsuccessful scrape.
def safe_get_webpage_content(url: str) -> InternetContent:
def safe_get_webpage_content(url: str) -> WebContent:
try:
return self._get_webpage_content(url)
except Exception:
return InternetContent(
return WebContent(
title="",
link=url,
full_content="",
@@ -77,7 +77,7 @@ class SerperClient(InternetSearchProvider):
return list(e.map(safe_get_webpage_content, urls))
@retry_builder(tries=3, delay=1, backoff=2)
def _get_webpage_content(self, url: str) -> InternetContent:
def _get_webpage_content(self, url: str) -> WebContent:
payload = {
"url": url,
}
@@ -90,7 +90,7 @@ class SerperClient(InternetSearchProvider):
# 400 returned when serper cannot scrape
if response.status_code == 400:
return InternetContent(
return WebContent(
title="",
link=url,
full_content="",
@@ -122,7 +122,7 @@ class SerperClient(InternetSearchProvider):
except Exception:
published_date = None
return InternetContent(
return WebContent(
title=title or "",
link=response_url,
full_content=text or "",

View File

@@ -7,7 +7,7 @@ from langsmith import traceable
from onyx.agents.agent_search.dr.models import WebSearchAnswer
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
WebSearchResult,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_provider,
@@ -75,15 +75,15 @@ def web_search(
raise ValueError("No internet search provider found")
@traceable(name="Search Provider API Call")
def _search(search_query: str) -> list[InternetSearchResult]:
search_results: list[InternetSearchResult] = []
def _search(search_query: str) -> list[WebSearchResult]:
search_results: list[WebSearchResult] = []
try:
search_results = provider.search(search_query)
except Exception as e:
logger.error(f"Error performing search: {e}")
return search_results
search_results: list[InternetSearchResult] = _search(search_query)
search_results: list[WebSearchResult] = _search(search_query)
search_results_text = "\n\n".join(
[
f"{i}. {result.title}\n URL: {result.link}\n"

View File

@@ -4,7 +4,7 @@ from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
WebSearchResult,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.states import (
InternetSearchInput,
@@ -23,7 +23,7 @@ def dedup_urls(
writer: StreamWriter = lambda _: None,
) -> InternetSearchInput:
branch_questions_to_urls: dict[str, list[str]] = defaultdict(list)
unique_results_by_link: dict[str, InternetSearchResult] = {}
unique_results_by_link: dict[str, WebSearchResult] = {}
for query, result in state.results_to_open:
branch_questions_to_urls[query].append(result.link)
if result.link not in unique_results_by_link:

View File

@@ -13,7 +13,7 @@ class ProviderType(Enum):
EXA = "exa"
class InternetSearchResult(BaseModel):
class WebSearchResult(BaseModel):
title: str
link: str
author: str | None = None
@@ -21,7 +21,7 @@ class InternetSearchResult(BaseModel):
snippet: str | None = None
class InternetContent(BaseModel):
class WebContent(BaseModel):
title: str
link: str
full_content: str
@@ -29,11 +29,11 @@ class InternetContent(BaseModel):
scrape_successful: bool = True
class InternetSearchProvider(ABC):
class WebSearchProvider(ABC):
@abstractmethod
def search(self, query: str) -> list[InternetSearchResult]:
def search(self, query: str) -> list[WebSearchResult]:
pass
@abstractmethod
def contents(self, urls: list[str]) -> list[InternetContent]:
def contents(self, urls: list[str]) -> list[WebContent]:
pass

View File

@@ -5,13 +5,13 @@ from onyx.agents.agent_search.dr.sub_agents.web_search.clients.serper_client imp
SerperClient,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchProvider,
WebSearchProvider,
)
from onyx.configs.chat_configs import EXA_API_KEY
from onyx.configs.chat_configs import SERPER_API_KEY
def get_default_provider() -> InternetSearchProvider | None:
def get_default_provider() -> WebSearchProvider | None:
if EXA_API_KEY:
return ExaClient()
if SERPER_API_KEY:

View File

@@ -4,13 +4,13 @@ from typing import Annotated
from onyx.agents.agent_search.dr.states import LoggerUpdate
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
WebSearchResult,
)
from onyx.context.search.models import InferenceSection
class InternetSearchInput(SubAgentInput):
results_to_open: Annotated[list[tuple[str, InternetSearchResult]], add] = []
results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
parallelization_nr: int = 0
branch_question: Annotated[str, lambda x, y: y] = ""
branch_questions_to_urls: Annotated[dict[str, list[str]], lambda x, y: y] = {}
@@ -18,7 +18,7 @@ class InternetSearchInput(SubAgentInput):
class InternetSearchUpdate(LoggerUpdate):
results_to_open: Annotated[list[tuple[str, InternetSearchResult]], add] = []
results_to_open: Annotated[list[tuple[str, WebSearchResult]], add] = []
class FetchInput(SubAgentInput):

View File

@@ -1,8 +1,8 @@
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetContent,
WebContent,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.models import (
InternetSearchResult,
WebSearchResult,
)
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
@@ -17,7 +17,7 @@ def truncate_search_result_content(content: str, max_chars: int = 10000) -> str:
def dummy_inference_section_from_internet_content(
result: InternetContent,
result: WebContent,
) -> InferenceSection:
truncated_content = truncate_search_result_content(result.full_content)
return InferenceSection(
@@ -48,7 +48,7 @@ def dummy_inference_section_from_internet_content(
def dummy_inference_section_from_internet_search_result(
result: InternetSearchResult,
result: WebSearchResult,
) -> InferenceSection:
return InferenceSection(
center_chunk=InferenceChunk(

View File

@@ -54,6 +54,8 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import nulls_last
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.auth.api_key import get_hashed_api_key_from_request
@@ -103,6 +105,7 @@ from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
@@ -324,8 +327,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
):
user_create.role = UserRole.ADMIN
user_created = False
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
user = await super().create(
user_create, safe=safe, request=request
) # type: ignore
user_created = True
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
@@ -351,11 +358,42 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
role=user_create.role,
)
user = await self.update(user_update, user)
if user_created:
await self._assign_default_pinned_assistants(user, db_session)
remove_user_from_invited_users(user_create.email)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
async def _assign_default_pinned_assistants(
self, user: User, db_session: AsyncSession
) -> None:
if user.pinned_assistants is not None:
return
result = await db_session.execute(
select(Persona.id)
.where(
Persona.is_default_persona.is_(True),
Persona.is_public.is_(True),
Persona.is_visible.is_(True),
Persona.deleted.is_(False),
)
.order_by(
nulls_last(Persona.display_priority.asc()),
Persona.id.asc(),
)
)
default_persona_ids = list(result.scalars().all())
if not default_persona_ids:
return
await self.user_db.update(
user,
{"pinned_assistants": default_persona_ids},
)
user.pinned_assistants = default_persona_ids
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
# Validate password according to configurable security policy (defined via environment variables)
if len(password) < PASSWORD_MIN_LENGTH:
@@ -476,6 +514,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user = await self.user_db.create(user_dict)
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self._assign_default_pinned_assistants(user, db_session)
await self.on_after_register(user, request)
else:
@@ -1040,7 +1079,10 @@ async def optional_user(
# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
try:
hashed_api_key = get_hashed_api_key_from_request(request)
except ValueError:
hashed_api_key = None
if hashed_api_key:
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)

View File

@@ -0,0 +1,137 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()

# Consolidated "background" worker app: a single Celery app that runs the
# task sets previously split across multiple dedicated workers (see the
# autodiscover list at the bottom of this module).
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.background")
celery_app.Task = app_base.TenantAwareTask  # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    # Delegate to the shared handler so per-task setup matches every other
    # Onyx worker.
    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)


@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    # Delegate to the shared per-task teardown handler.
    app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)


@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
    # One-time daemon initialization shared across all Onyx workers.
    app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
    """Per-worker startup: size the DB and Vespa connection pools to the
    worker's concurrency, then block until backing services are reachable."""
    EXTRA_CONCURRENCY = 8  # small extra fudge factor for connection limits

    logger.info("worker_init signal received for consolidated background worker.")

    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
    # Pool size tracks worker concurrency so every thread can hold a
    # connection without exhausting the pool.
    pool_size = cast(int, sender.concurrency)  # type: ignore
    SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)

    # Initialize Vespa httpx pool (needed for light worker tasks)
    if MANAGED_VESPA:
        httpx_init_vespa_pool(
            sender.concurrency + EXTRA_CONCURRENCY,  # type: ignore
            ssl_cert=VESPA_CLOUD_CERT_PATH,
            ssl_key=VESPA_CLOUD_KEY_PATH,
        )
    else:
        httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY)  # type: ignore

    # Wait for dependencies before accepting tasks; an unreachable Vespa
    # shuts the worker down instead of letting tasks fail later.
    app_base.wait_for_redis(sender, **kwargs)
    app_base.wait_for_db(sender, **kwargs)
    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)

    # Less startup checks in multi-tenant case
    if MULTI_TENANT:
        return

    app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    # Shared readiness hook (liveness/readiness signaling).
    app_base.on_worker_ready(sender, **kwargs)


@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    # Shared shutdown hook for cleanup.
    app_base.on_worker_shutdown(sender, **kwargs)


@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
    # Pool child processes must not reuse the parent's SQLAlchemy engine
    # (connections are not fork-safe); reset so each process reconnects.
    SqlEngine.reset_engine()


@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    # Take over Celery's logging setup so Onyx's logger configuration wins.
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
# Install the shared worker bootsteps used by all Onyx workers.
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
    celery_app.steps["worker"].add(bootstep)

# This consolidated worker serves the union of the task modules formerly
# handled by separate background/light/docprocessing/docfetching workers.
celery_app.autodiscover_tasks(
    [
        # Original background worker tasks
        "onyx.background.celery.tasks.pruning",
        "onyx.background.celery.tasks.kg_processing",
        "onyx.background.celery.tasks.monitoring",
        "onyx.background.celery.tasks.user_file_processing",
        # Light worker tasks
        "onyx.background.celery.tasks.shared",
        "onyx.background.celery.tasks.vespa",
        "onyx.background.celery.tasks.connector_deletion",
        "onyx.background.celery.tasks.doc_permission_syncing",
        # Docprocessing worker tasks
        "onyx.background.celery.tasks.docprocessing",
        # Docfetching worker tasks
        "onyx.background.celery.tasks.docfetching",
    ]
)

View File

@@ -0,0 +1,23 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY
# Celery settings for the consolidated "background" worker. Most knobs are
# inherited from the shared base config; only concurrency and pool behavior
# are specialized here.
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options

redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval

result_backend = shared_config.result_backend
result_expires = shared_config.result_expires  # 86400 seconds is the default

task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

# Concurrency is environment-configurable; a thread pool is used since the
# served tasks are predominantly I/O bound.
worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
worker_pool = "threads"

# Increased from 1 to 4 to handle fast light worker tasks more efficiently
# This allows the worker to prefetch multiple tasks per thread
worker_prefetch_multiplier = 4

View File

@@ -1,4 +1,5 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_HEAVY_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
@@ -15,6 +16,6 @@ result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
worker_concurrency = 4
worker_concurrency = CELERY_WORKER_HEAVY_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -1,4 +1,5 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_MONITORING_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
@@ -16,6 +17,6 @@ task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# Monitoring worker specific settings
worker_concurrency = 1 # Single worker is sufficient for monitoring
worker_concurrency = CELERY_WORKER_MONITORING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -42,6 +42,12 @@ from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.index_attempt import delete_index_attempts
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
from onyx.db.permission_sync_attempt import (
delete_doc_permission_sync_attempts__no_commit,
)
from onyx.db.permission_sync_attempt import (
delete_external_group_permission_sync_attempts__no_commit,
)
from onyx.db.search_settings import get_all_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
@@ -441,6 +447,16 @@ def monitor_connector_deletion_taskset(
cc_pair_id=cc_pair_id,
)
# permission sync attempts
delete_doc_permission_sync_attempts__no_commit(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
delete_external_group_permission_sync_attempts__no_commit(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,

View File

@@ -0,0 +1,10 @@
"""Entrypoint module for the background Celery worker app.

Resolves the EE vs. MIT implementation of the Celery app at import time so
``celery -A`` invocations can point at a single stable module path.
"""
from celery import Celery
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable

# Must run before fetching the versioned implementation so the EE/MIT switch
# is decided from the environment.
set_is_ee_based_on_env_variable()

# Re-export the actual Celery app defined in
# onyx.background.celery.apps.background (or its EE override).
app: Celery = fetch_versioned_implementation(
    "onyx.background.celery.apps.background",
    "celery_app",
)

View File

@@ -2,11 +2,13 @@ import re
import time
import traceback
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from typing import cast
from typing import Protocol
from uuid import UUID
from redis.client import Redis
from sqlalchemy.orm import Session
from onyx.chat.answer import Answer
@@ -25,12 +27,16 @@ from onyx.chat.models import PromptConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.models import UserKnowledgeFilePacket
from onyx.chat.packet_proccessing.process_streamed_packets import (
process_streamed_packets,
)
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import (
default_build_system_message_v2,
)
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message_v2
from onyx.chat.turn import fast_chat_turn
from onyx.chat.turn.infra.emitter import get_default_emitter
from onyx.chat.turn.models import ChatTurnDependencies
from onyx.chat.user_files.parse_user_files import parse_user_files
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
@@ -69,6 +75,8 @@ from onyx.db.projects import get_project_instructions
from onyx.db.projects import get_user_files_from_project
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.feature_flags.factory import get_default_feature_flag_provider
from onyx.feature_flags.feature_flags_keys import SIMPLE_AGENT_FRAMEWORK
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import build_frontend_file_url
@@ -81,6 +89,7 @@ from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.redis.redis_pool import get_redis_client
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationInfo
@@ -88,6 +97,7 @@ from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.utils import get_json_line
from onyx.tools.adapter_v1_to_v2 import tools_to_function_tools
from onyx.tools.force import ForceUseTool
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
@@ -355,14 +365,12 @@ def stream_chat_message_objects(
long_term_logger = LongTermLogger(
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
)
persona = _get_persona_for_chat_session(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
default_persona=chat_session.persona,
)
# TODO: remove once we have an endpoint for this stuff
process_kg_commands(new_msg_req.message, persona.name, tenant_id, db_session)
@@ -732,15 +740,36 @@ def stream_chat_message_objects(
and (file.file_id not in project_file_ids)
]
)
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
feature_flag_provider = get_default_feature_flag_provider()
simple_agent_framework_enabled = (
feature_flag_provider.feature_enabled_for_user_tenant(
flag_key=SIMPLE_AGENT_FRAMEWORK,
user=user,
tenant_id=tenant_id,
)
and not new_msg_req.use_agentic_search
)
prompt_user_message = (
default_build_user_message_v2(
user_query=final_msg.message,
prompt_config=prompt_config,
files=latest_query_files,
single_message_history=single_message_history,
),
system_message=default_build_system_message(prompt_config, llm.config),
)
if simple_agent_framework_enabled
else default_build_user_message(
user_query=final_msg.message,
prompt_config=prompt_config,
files=latest_query_files,
)
)
system_message = (
default_build_system_message_v2(prompt_config, llm.config)
if simple_agent_framework_enabled
else default_build_system_message(prompt_config, llm.config)
)
prompt_builder = AnswerPromptBuilder(
user_message=prompt_user_message,
system_message=system_message,
message_history=message_history,
llm_config=llm.config,
raw_user_query=final_msg.message,
@@ -782,11 +811,21 @@ def stream_chat_message_objects(
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
project_instructions=project_instructions,
)
if simple_agent_framework_enabled:
yield from _fast_message_stream(
answer,
tools,
db_session,
get_redis_client(),
chat_session_id,
reserved_message_id,
)
else:
from onyx.chat.packet_proccessing import process_streamed_packets
# Process streamed packets using the new packet processing module
yield from process_streamed_packets(
answer_processed_output=answer.processed_streamed_output,
)
yield from process_streamed_packets.process_streamed_packets(
answer_processed_output=answer.processed_streamed_output,
)
except ValueError as e:
logger.exception("Failed to process chat message.")
@@ -823,6 +862,59 @@ def stream_chat_message_objects(
return
# TODO: Refactor this to live somewhere else
def _fast_message_stream(
    answer: Answer,
    tools: list[Tool],
    db_session: Session,
    redis_client: Redis,
    chat_session_id: UUID,
    reserved_message_id: int,
) -> Generator[Packet, None, None]:
    """Stream a chat answer through the simple agent framework (fast chat turn).

    Bridges the pre-built ``Answer`` into ``fast_chat_turn`` by locating the
    special-cased tool instances, converting the prompt history into
    agent-SDK messages, and assembling the dependency bundle.

    Args:
        answer: Fully-configured Answer holding the prompt builder and tooling.
        tools: Tools available for this turn; scanned for the image-generation
            and Okta-profile tools, which are wired in explicitly.
        db_session: Database session used to persist the turn's results.
        redis_client: Redis client used for stop-signal (cancellation) checks.
        chat_session_id: Chat session the streamed message belongs to.
        reserved_message_id: Pre-reserved DB id for the assistant message.

    Returns:
        Generator of streaming ``Packet`` objects produced by the fast chat
        turn (``fast_chat_turn`` is decorated to return a generator).
    """
    # Local imports — presumably to avoid circular imports at module load
    # time; TODO(review): confirm and hoist to module level if safe.
    from onyx.tools.tool_implementations.images.image_generation_tool import (
        ImageGenerationTool,
    )
    from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
        OktaProfileTool,
    )
    from onyx.llm.litellm_singleton import LitellmModel

    # Pick out the two tools the fast turn handles specially (may stay None).
    image_generation_tool_instance = None
    okta_profile_tool_instance = None
    for tool in tools:
        if isinstance(tool, ImageGenerationTool):
            image_generation_tool_instance = tool
        elif isinstance(tool, OktaProfileTool):
            okta_profile_tool_instance = tool

    # Convert the langchain-style prompt history into agent-SDK message dicts.
    converted_message_history = [
        PreviousMessage.from_langchain_msg(message, 0).to_agent_sdk_msg()
        for message in answer.graph_inputs.prompt_builder.build()
    ]
    emitter = get_default_emitter()
    return fast_chat_turn.fast_chat_turn(
        messages=converted_message_history,
        # TODO: Maybe we can use some DI framework here?
        dependencies=ChatTurnDependencies(
            llm_model=LitellmModel(
                model=answer.graph_tooling.primary_llm.config.model_name,
                base_url=answer.graph_tooling.primary_llm.config.api_base,
                api_key=answer.graph_tooling.primary_llm.config.api_key,
            ),
            llm=answer.graph_tooling.primary_llm,
            tools=tools_to_function_tools(tools),
            search_pipeline=answer.graph_tooling.search_tool,
            image_generation_tool=image_generation_tool_instance,
            okta_profile_tool=okta_profile_tool_instance,
            db_session=db_session,
            redis_client=redis_client,
            emitter=emitter,
        ),
        chat_session_id=chat_session_id,
        message_id=reserved_message_id,
        research_type=answer.graph_config.behavior.research_type,
    )
@log_generator_function_time()
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,

View File

@@ -21,8 +21,10 @@ from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT_V2
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_company_awareness
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
@@ -31,6 +33,33 @@ from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
def default_build_system_message_v2(
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
) -> SystemMessage | None:
    """Build the system message for the v2 (simple agent framework) prompt path.

    Appends the v2 citation requirement to the persona's system prompt,
    applies the o-series markdown workaround, and runs date/company awareness
    tag handling.

    Args:
        prompt_config: Persona prompt configuration (system prompt, flags).
        llm_config: Active LLM configuration, used to detect o-series models.

    Returns:
        The assembled SystemMessage, or None if tag handling yields an empty
        prompt.
    """
    system_prompt = prompt_config.system_prompt.strip()
    system_prompt += REQUIRE_CITATION_STATEMENT_V2
    # See https://simonwillison.net/tags/markdown/ for context on this temporary fix
    # for o-series markdown generation
    if (
        llm_config.model_provider == OPENAI_PROVIDER_NAME
        and llm_config.model_name.startswith("o")
    ):
        system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
    tag_handled_prompt = handle_onyx_date_awareness(
        system_prompt,
        prompt_config,
        add_additional_info_if_no_tag=prompt_config.datetime_aware,
    )
    # An empty prompt after tag handling means "no system message".
    if not tag_handled_prompt:
        return None
    tag_handled_prompt = handle_company_awareness(tag_handled_prompt)
    return SystemMessage(content=tag_handled_prompt)
def default_build_system_message(
prompt_config: PromptConfig,
llm_config: LLMConfig,
@@ -52,9 +81,29 @@ def default_build_system_message(
if not tag_handled_prompt:
return None
tag_handled_prompt = handle_company_awareness(tag_handled_prompt)
return SystemMessage(content=tag_handled_prompt)
def default_build_user_message_v2(
    user_query: str,
    prompt_config: PromptConfig,
    files: list[InMemoryChatFile] | None = None,
) -> HumanMessage:
    """Build the user message for the v2 (simple agent framework) prompt path.

    Args:
        user_query: Raw user query text; leading/trailing whitespace stripped.
        prompt_config: Persona prompt configuration used for date-awareness
            tag handling.
        files: Optional files attached to the query; when present, the
            message content embeds them (e.g. images) alongside the text.

    Returns:
        HumanMessage with plain text content, or combined content when files
        are attached.
    """
    # Default was previously a mutable `[]` shared across calls; use None
    # as the sentinel instead (behavior is unchanged — empty list and None
    # both take the text-only branch below).
    files = files or []
    user_prompt = user_query.strip()
    tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
    return HumanMessage(
        content=(
            build_content_with_imgs(tag_handled_prompt, files)
            if files
            else tag_handled_prompt
        )
    )
def default_build_user_message(
user_query: str,
prompt_config: PromptConfig,

View File

@@ -0,0 +1,56 @@
"""Redis-backed stop ("fence") signals for cancelling in-flight chat sessions."""
from uuid import UUID

from redis.client import Redis

from shared_configs.contextvars import get_current_tenant_id

# Redis key prefixes for chat session stop signals
PREFIX = "chatsessionstop"
FENCE_PREFIX = f"{PREFIX}_fence"


def _fence_key(chat_session_id: UUID) -> str:
    """Build the tenant-scoped Redis key for a chat session's stop fence.

    Centralizes the key format that was previously duplicated in every
    function of this module.
    """
    tenant_id = get_current_tenant_id()
    return f"{FENCE_PREFIX}_{tenant_id}_{chat_session_id}"


def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
    """
    Set or clear the stop signal fence for a chat session.

    Args:
        chat_session_id: The UUID of the chat session
        redis_client: Redis client to use
        value: True to set the fence (stop signal), False to clear it
    """
    fence_key = _fence_key(chat_session_id)
    if not value:
        redis_client.delete(fence_key)
        return
    # NOTE(review): no TTL is set, so an orphaned fence persists until
    # explicitly cleared — confirm this is intentional.
    redis_client.set(fence_key, 0)


def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:
    """
    Check if the chat session should continue (not stopped).

    Args:
        chat_session_id: The UUID of the chat session to check
        redis_client: Redis client to use for checking the stop signal

    Returns:
        True if the session should continue, False if it should stop
    """
    # The fence's existence (not its value) is the stop signal.
    return not bool(redis_client.exists(_fence_key(chat_session_id)))


def reset_cancel_status(chat_session_id: UUID, redis_client: Redis) -> None:
    """
    Clear the stop signal for a chat session.

    Args:
        chat_session_id: The UUID of the chat session
        redis_client: Redis client to use
    """
    redis_client.delete(_fence_key(chat_session_id))

View File

@@ -0,0 +1 @@
# Turn module for chat functionality

View File

@@ -0,0 +1,258 @@
from typing import cast
from uuid import UUID
from agents import Agent
from agents import ModelSettings
from agents import RawResponsesStreamEvent
from agents import StopAtTools
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.chat.stop_signal_checker import is_connected
from onyx.chat.stop_signal_checker import reset_cancel_status
from onyx.chat.stream_processing.citation_processing import CitationProcessor
from onyx.chat.turn.infra.chat_turn_event_stream import unified_event_stream
from onyx.chat.turn.infra.session_sink import extract_final_answer_from_packets
from onyx.chat.turn.infra.session_sink import save_iteration
from onyx.chat.turn.infra.sync_agent_stream_adapter import SyncAgentStream
from onyx.chat.turn.models import AgentToolType
from onyx.chat.turn.models import ChatTurnContext
from onyx.chat.turn.models import ChatTurnDependencies
from onyx.context.search.models import InferenceSection
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketObj
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.tool_implementations_v2.image_generation import image_generation_tool
def _fast_chat_turn_core(
    messages: list[dict],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
    research_type: ResearchType,
    # Dependency injectable arguments for testing
    starter_global_iteration_responses: list[IterationAnswer] | None = None,
    starter_cited_documents: list[InferenceSection] | None = None,
) -> None:
    """Core fast chat turn logic; supports injecting starter state for testing.

    Runs an agent-SDK streamed turn, translating raw stream events into
    streaming ``Packet``s, honoring cooperative cancellation via the Redis
    stop fence, and persisting the turn's results at the end.

    Args:
        messages: List of agent-SDK chat message dicts.
        dependencies: Chat turn dependencies (LLM, tools, db, redis, emitter).
        chat_session_id: Chat session ID.
        message_id: Reserved DB id of the assistant message being produced.
        research_type: Research type for this turn.
        starter_global_iteration_responses: Optional iteration answers to seed
            the context with (testing hook).
        starter_cited_documents: Optional cited documents to seed the context
            with (testing hook).
    """
    # Clear any stale stop signal from a previous turn on this session.
    reset_cancel_status(
        chat_session_id,
        dependencies.redis_client,
    )
    ctx = ChatTurnContext(
        run_dependencies=dependencies,
        aggregated_context=AggregatedDRContext(
            context="context",
            cited_documents=starter_cited_documents or [],
            is_internet_marker_dict={},
            global_iteration_responses=starter_global_iteration_responses or [],
        ),
        iteration_instructions=[],
        chat_session_id=chat_session_id,
        message_id=message_id,
        research_type=research_type,
    )
    agent = Agent(
        name="Assistant",
        model=dependencies.llm_model,
        tools=cast(list[AgentToolType], dependencies.tools),
        model_settings=ModelSettings(
            temperature=dependencies.llm.config.temperature,
            include_usage=True,
        ),
        # Image generation is terminal: the run stops once that tool fires.
        tool_use_behavior=StopAtTools(stop_at_tool_names=[image_generation_tool.name]),
    )
    # By default, the agent can only take 10 turns. For our use case, it should be higher.
    max_turns = 25
    agent_stream: SyncAgentStream = SyncAgentStream(
        agent=agent,
        input=messages,
        context=ctx,
        max_turns=max_turns,
    )
    for ev in agent_stream:
        # Cancellation check once per event; on stop, emit closing packets
        # and cancel the underlying agent run.
        connected = is_connected(
            chat_session_id,
            dependencies.redis_client,
        )
        if not connected:
            _emit_clean_up_packets(dependencies, ctx)
            agent_stream.cancel()
            break
        obj = _default_packet_translation(ev, ctx)
        if obj:
            dependencies.emitter.emit(Packet(ind=ctx.current_run_step, obj=obj))
    # Reconstruct the final answer text from everything emitted so far.
    final_answer = extract_final_answer_from_packets(
        dependencies.emitter.packet_history
    )
    all_cited_documents = []
    if ctx.aggregated_context.global_iteration_responses:
        context_docs = _gather_context_docs_from_iteration_answers(
            ctx.aggregated_context.global_iteration_responses
        )
        all_cited_documents = context_docs
        if context_docs and final_answer:
            _process_citations_for_final_answer(
                final_answer=final_answer,
                context_docs=context_docs,
                dependencies=dependencies,
                ctx=ctx,
            )
    # Persist the message, citations, search docs, and iteration sub-steps.
    save_iteration(
        db_session=dependencies.db_session,
        message_id=message_id,
        chat_session_id=chat_session_id,
        research_type=research_type,
        ctx=ctx,
        final_answer=final_answer,
        all_cited_documents=all_cited_documents,
    )
    # The OverallStop packet ends the consumer-side generator loop.
    dependencies.emitter.emit(
        Packet(ind=ctx.current_run_step, obj=OverallStop(type="stop"))
    )
@unified_event_stream
def fast_chat_turn(
    messages: list[dict],
    dependencies: ChatTurnDependencies,
    chat_session_id: UUID,
    message_id: int,
    research_type: ResearchType,
) -> None:
    """Main fast chat turn function that calls the core logic with default parameters.

    NOTE: although annotated ``-> None``, the ``unified_event_stream``
    decorator runs this in a background thread and hands callers a
    ``Generator[Packet, ...]`` that streams the emitted packets.
    """
    _fast_chat_turn_core(
        messages,
        dependencies,
        chat_session_id,
        message_id,
        research_type,
        starter_global_iteration_responses=None,
    )
# TODO: Maybe in general there's a cleaner way to handle cancellation in the middle of a tool call?
def _emit_clean_up_packets(
    dependencies: ChatTurnDependencies, ctx: ChatTurnContext
) -> None:
    """Emit the closing packets after a mid-stream cancellation.

    If the last emitted packet was not a message delta (i.e. no answer text
    was in flight), a synthetic "Cancelled" message is started first so the
    client has something to close; a SectionEnd is always emitted.
    """
    emitter = dependencies.emitter
    step = ctx.current_run_step
    history = emitter.packet_history
    mid_message = bool(history) and history[-1].obj.type == "message_delta"
    if not mid_message:
        cancelled_start = MessageStart(
            type="message_start", content="Cancelled", final_documents=None
        )
        emitter.emit(Packet(ind=step, obj=cancelled_start))
    emitter.emit(Packet(ind=step, obj=SectionEnd(type="section_end")))
def _gather_context_docs_from_iteration_answers(
    iteration_answers: list[IterationAnswer],
) -> list[InferenceSection]:
    """Gather unique cited documents from iteration answers for citation processing.

    Deduplicates by ``center_chunk.document_id``, keeping the first occurrence
    in iteration order. Uses a seen-id set for O(n) total work (the previous
    ``any(...)`` scan over the accumulated list was O(n^2)).
    """
    context_docs: list[InferenceSection] = []
    seen_document_ids = set()
    for iteration_answer in iteration_answers:
        # Extract cited documents from this iteration
        for inference_section in iteration_answer.cited_documents.values():
            document_id = inference_section.center_chunk.document_id
            if document_id not in seen_document_ids:
                seen_document_ids.add(document_id)
                context_docs.append(inference_section)
    return context_docs
def _process_citations_for_final_answer(
    final_answer: str,
    context_docs: list[InferenceSection],
    dependencies: ChatTurnDependencies,
    ctx: ChatTurnContext,
) -> None:
    """Process citations in the final answer and emit citation events.

    (This string was previously placed after the first statement, so it never
    registered as the function's docstring.)

    Args:
        final_answer: Full answer text to scan for citation markers.
        context_docs: Documents the answer may cite, in rank order.
        dependencies: Chat turn dependencies; only the emitter is used here.
        ctx: Turn context; ``current_run_step`` is advanced past the citation
            section on completion.
    """
    # Citations are emitted as their own run step after the answer.
    index = ctx.current_run_step + 1
    from onyx.chat.stream_processing.utils import DocumentIdOrderMapping

    # Convert InferenceSection objects to LlmDoc objects for citation processing
    llm_docs = [llm_doc_from_inference_section(section) for section in context_docs]
    # Create document ID to rank mappings (simple 1-based indexing)
    final_doc_id_to_rank_map = DocumentIdOrderMapping(
        order_mapping={doc.document_id: i + 1 for i, doc in enumerate(llm_docs)}
    )
    display_doc_id_to_rank_map = final_doc_id_to_rank_map  # Same mapping for display
    # Initialize citation processor
    citation_processor = CitationProcessor(
        context_docs=llm_docs,
        final_doc_id_to_rank_map=final_doc_id_to_rank_map,
        display_doc_id_to_rank_map=display_doc_id_to_rank_map,
    )
    # Process the final answer through citation processor
    collected_citations: list = []
    for response_part in citation_processor.process_token(final_answer):
        if hasattr(response_part, "citation_num"):  # It's a CitationInfo
            collected_citations.append(response_part)
    # Emit citation events if we found any citations
    if collected_citations:
        dependencies.emitter.emit(Packet(ind=index, obj=CitationStart()))
        dependencies.emitter.emit(
            Packet(
                ind=index,
                obj=CitationDelta(citations=collected_citations),  # type: ignore[arg-type]
            )
        )
    dependencies.emitter.emit(Packet(ind=index, obj=SectionEnd(type="section_end")))
    ctx.current_run_step = index
def _default_packet_translation(ev: object, ctx: ChatTurnContext) -> PacketObj | None:
    """Translate a raw agent-SDK stream event into a streaming PacketObj.

    Only ``RawResponsesStreamEvent``s are handled; every other event type —
    and any unrecognized raw-response subtype — yields None.
    """
    # TODO: might need some variation here for different types of models
    # OpenAI packet translator
    if not isinstance(ev, RawResponsesStreamEvent):
        return None
    event_type = ev.data.type
    if event_type == "response.content_part.added":
        final_docs = convert_inference_sections_to_search_docs(
            ctx.aggregated_context.cited_documents
        )
        return MessageStart(
            type="message_start", content="", final_documents=final_docs
        )
    if event_type == "response.output_text.delta":
        return MessageDelta(type="message_delta", content=ev.data.delta)
    if event_type == "response.content_part.done":
        return SectionEnd(type="section_end")
    return None

View File

@@ -0,0 +1 @@
# Infrastructure module for chat turn orchestration

View File

@@ -0,0 +1,57 @@
"""Bridge between a synchronous packet-emitting turn function and a generator."""
import functools
from collections.abc import Callable
from collections.abc import Generator
from typing import Any

from onyx.chat.turn.models import ChatTurnDependencies
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import wait_on_background


def unified_event_stream(
    turn_func: Callable[..., None],
) -> Callable[..., Generator[Packet, None, None]]:
    """
    Decorator that wraps a turn_func to provide event streaming capabilities.

    The wrapped function runs in a background thread and emits Packets via
    ``dependencies.emitter``; the wrapper yields them to the caller until an
    OverallStop packet arrives. Exceptions raised in the turn are captured as
    PacketException packets and re-raised on the consumer side.

    Usage:
        @unified_event_stream
        def my_turn_func(messages, dependencies, *args, **kwargs):
            # Your turn logic here
            pass

    Then call it like:
        generator = my_turn_func(messages, dependencies, *args, **kwargs)
    """

    # functools.wraps preserves the wrapped function's name/docstring (it was
    # previously missing, so all decorated turns reported as "wrapper").
    @functools.wraps(turn_func)
    def wrapper(
        messages: list[dict[str, Any]],
        dependencies: ChatTurnDependencies,
        *args: Any,
        **kwargs: Any,
    ) -> Generator[Packet, None, None]:
        def run_with_exception_capture() -> None:
            try:
                turn_func(messages, dependencies, *args, **kwargs)
            except Exception as e:
                # Forward the exception through the packet bus so the consumer
                # thread can re-raise it.
                dependencies.emitter.emit(
                    Packet(ind=0, obj=PacketException(type="error", exception=e))
                )

        thread = run_in_background(run_with_exception_capture)
        while True:
            pkt: Packet = dependencies.emitter.bus.get()
            if pkt.obj == OverallStop(type="stop"):
                yield pkt
                break
            elif isinstance(pkt.obj, PacketException):
                raise pkt.obj.exception
            else:
                yield pkt
        wait_on_background(thread)

    return wrapper

View File

@@ -0,0 +1,21 @@
"""Packet emitter shared between turn logic (producer) and stream consumer."""
from queue import Queue

from onyx.server.query_and_chat.streaming_models import Packet


class Emitter:
    """Use this inside tools to emit arbitrary UI progress."""

    def __init__(self, bus: Queue[Packet]) -> None:
        # Thread-safe queue the streaming consumer reads from.
        self.bus = bus
        # Every packet ever emitted, in order — lets post-processing inspect
        # prior output (e.g. reconstructing the final answer text).
        self.packet_history: list[Packet] = []

    def emit(self, packet: Packet) -> None:
        """Publish *packet* to the consumer queue and record it in history."""
        self.bus.put(packet)
        self.packet_history.append(packet)


def get_default_emitter() -> Emitter:
    """Create an Emitter backed by a fresh unbounded queue."""
    bus: Queue[Packet] = Queue()
    emitter = Emitter(bus)
    return emitter

View File

@@ -0,0 +1,170 @@
# TODO: Figure out a way to persist information is robust to cancellation,
# modular so easily testable in unit tests and evals [likely injecting some higher
# level session manager and span sink], potentially has some robustness off the critical path,
# and promotes clean separation of concerns.
import re
from uuid import UUID
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
GeneratedImageFullResult,
)
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.chat.turn.models import ChatTurnContext
from onyx.context.search.models import InferenceSection
from onyx.db.chat import create_search_doc_from_inference_section
from onyx.db.chat import update_db_session_with_messages
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ResearchAgentIteration
from onyx.db.models import ResearchAgentIterationSubStep
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.streaming_models import MessageDelta
from onyx.server.query_and_chat.streaming_models import MessageStart
from onyx.server.query_and_chat.streaming_models import Packet
def save_iteration(
    db_session: Session,
    message_id: int,
    chat_session_id: UUID,
    research_type: ResearchType,
    ctx: ChatTurnContext,
    final_answer: str,
    all_cited_documents: list[InferenceSection],
) -> None:
    """Persist the results of a completed fast chat turn in one commit.

    Inserts search docs for all cited documents, links them to the chat
    message, records the citation mapping, updates the chat message (and its
    parent), and stores the per-iteration research agent rows.

    Args:
        db_session: DB session; committed once at the end.
        message_id: DB id of the assistant chat message.
        chat_session_id: Chat session the message belongs to.
        research_type: Research type of the turn (DEEP marks it agentic).
        ctx: Turn context holding iteration instructions/responses.
        final_answer: Full answer text (citation markers included).
        all_cited_documents: Documents cited across the turn, in rank order.
    """
    # first, insert the search_docs
    is_internet_marker_dict: dict[str, bool] = {}
    search_docs = [
        create_search_doc_from_inference_section(
            inference_section=inference_section,
            is_internet=is_internet_marker_dict.get(
                inference_section.center_chunk.document_id, False
            ),  # TODO: revisit
            db_session=db_session,
            commit=False,
        )
        for inference_section in all_cited_documents
    ]
    # then, map_search_docs to message
    _insert_chat_message_search_doc_pair(
        message_id, [search_doc.id for search_doc in search_docs], db_session
    )
    # lastly, insert the citations
    citation_dict: dict[int, int] = {}
    cited_doc_nrs = _extract_citation_numbers(final_answer)
    if search_docs:
        for cited_doc_nr in cited_doc_nrs:
            # Citation numbers are 1-based ranks into search_docs. Guard the
            # range: the model can emit numbers beyond the known documents,
            # which previously raised an IndexError and lost the whole save.
            if 1 <= cited_doc_nr <= len(search_docs):
                citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
    llm_tokenizer = get_tokenizer(
        model_name=ctx.run_dependencies.llm.config.model_name,
        provider_type=ctx.run_dependencies.llm.config.model_provider,
    )
    num_tokens = len(llm_tokenizer.encode(final_answer or ""))
    # Update the chat message and its parent message in database
    update_db_session_with_messages(
        db_session=db_session,
        chat_message_id=message_id,
        chat_session_id=chat_session_id,
        is_agentic=research_type == ResearchType.DEEP,
        message=final_answer,
        citations=citation_dict,
        research_type=research_type,
        research_plan={},
        final_documents=search_docs,
        update_parent_message=True,
        research_answer_purpose=ResearchAnswerPurpose.ANSWER,
        token_count=num_tokens,
    )
    # TODO: I don't think this is the ideal schema for all use cases
    # find a better schema to store tool and reasoning calls
    for iteration_preparation in ctx.iteration_instructions:
        research_agent_iteration_step = ResearchAgentIteration(
            primary_question_id=message_id,
            reasoning=iteration_preparation.reasoning,
            purpose=iteration_preparation.purpose,
            iteration_nr=iteration_preparation.iteration_nr,
        )
        db_session.add(research_agent_iteration_step)
    for iteration_answer in ctx.aggregated_context.global_iteration_responses:
        retrieved_search_docs = convert_inference_sections_to_search_docs(
            list(iteration_answer.cited_documents.values())
        )
        # Convert SavedSearchDoc objects to JSON-serializable format
        serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
        research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
            primary_question_id=message_id,
            iteration_nr=iteration_answer.iteration_nr,
            iteration_sub_step_nr=iteration_answer.parallelization_nr,
            sub_step_instructions=iteration_answer.question,
            sub_step_tool_id=iteration_answer.tool_id,
            sub_answer=iteration_answer.answer,
            reasoning=iteration_answer.reasoning,
            claims=iteration_answer.claims,
            cited_doc_results=serialized_search_docs,
            generated_images=(
                GeneratedImageFullResult(images=iteration_answer.generated_images)
                if iteration_answer.generated_images
                else None
            ),
            additional_data=iteration_answer.additional_data,
            is_web_fetch=iteration_answer.is_web_fetch,
            queries=iteration_answer.queries,
        )
        db_session.add(research_agent_iteration_sub_step)
    db_session.commit()
def _insert_chat_message_search_doc_pair(
    message_id: int, search_doc_ids: list[int], db_session: Session
) -> None:
    """
    Insert one chat_message__search_doc row per search doc id for a message.

    Rows are added to the session but NOT committed — the caller commits.

    Args:
        message_id: The ID of the chat message
        search_doc_ids: The IDs of the search documents to link to the message
        db_session: The database session
    """
    for search_doc_id in search_doc_ids:
        chat_message_search_doc = ChatMessage__SearchDoc(
            chat_message_id=message_id, search_doc_id=search_doc_id
        )
        db_session.add(chat_message_search_doc)
def _extract_citation_numbers(text: str) -> list[int]:
"""
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
Returns a list of all unique citation numbers found.
"""
# Pattern to match [[number]] or [[number1, number2, ...]]
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
matches = re.findall(pattern, text)
cited_numbers = []
for match in matches:
# Split by comma and extract all numbers
numbers = [int(num.strip()) for num in match.split(",")]
cited_numbers.extend(numbers)
return list(set(cited_numbers)) # Return unique numbers
def extract_final_answer_from_packets(packet_history: list[Packet]) -> str:
    """Reassemble the final answer text from the emitted packet history.

    Concatenates the ``content`` of every MessageStart and MessageDelta
    packet, in emission order.
    """
    return "".join(
        packet.obj.content
        for packet in packet_history
        if isinstance(packet.obj, (MessageDelta, MessageStart))
    )

View File

@@ -0,0 +1,177 @@
import asyncio
import queue
import threading
from collections.abc import Iterator
from typing import Generic
from typing import Optional
from typing import TypeVar
from agents import Agent
from agents import RunResultStreaming
from agents.run import Runner
from onyx.chat.turn.models import ChatTurnContext
from onyx.utils.threadpool_concurrency import run_in_background
T = TypeVar("T")
class SyncAgentStream(Generic[T]):
    """
    Convert an async streamed agent run into a sync iterator with cooperative
    cancellation. Runs the Agent in a background thread that owns a private
    asyncio event loop; events are forwarded through a thread-safe queue.

    Usage:
        adapter = SyncAgentStream(
            agent=agent,
            input=input,
            context=context,
            max_turns=100,
            queue_maxsize=0,  # optional backpressure
        )
        for ev in adapter:  # sync iteration
            ...
        # or cancel from elsewhere:
        adapter.cancel()
    """

    # Unique marker the worker pushes onto the queue to signal end-of-stream.
    _SENTINEL = object()

    def __init__(
        self,
        *,
        agent: Agent,
        input: list[dict],
        context: ChatTurnContext,
        max_turns: int = 100,
        queue_maxsize: int = 0,
    ) -> None:
        """Start the background run immediately.

        Args:
            agent: The agent to execute.
            input: Input items forwarded verbatim to Runner.run_streamed.
            context: Per-turn dependency bundle passed to the run.
            max_turns: Upper bound on agent turns for the streamed run.
            queue_maxsize: Event queue capacity; 0 means unbounded (no
                backpressure on the producer).
        """
        self._agent = agent
        self._input = input
        self._context = context
        self._max_turns = max_turns
        # Thread-safe handoff of events from the loop thread to the consumer.
        self._q: "queue.Queue[object]" = queue.Queue(maxsize=queue_maxsize)
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None
        self._streamed: RunResultStreaming | None = None
        # Exception captured in the worker, re-raised on the sync side.
        self._exc: Optional[BaseException] = None
        self._cancel_requested = threading.Event()
        self._started = threading.Event()
        self._done = threading.Event()
        self._start_thread()

    # ---------- public sync API ----------
    def __iter__(self) -> Iterator[T]:
        """Yield streamed events until the worker signals completion.

        Re-raises any exception that the worker thread captured. Always
        closes the adapter on exit, however iteration ends.
        """
        try:
            while True:
                item = self._q.get()
                if item is self._SENTINEL:
                    # If the worker raised, surface it now
                    if self._exc is not None:
                        raise self._exc
                    # Normal completion
                    return
                yield item  # type: ignore[misc,return-value]
        finally:
            # Ensure we fully clean up whether we exited due to exception,
            # StopIteration, or external cancel.
            self.close()

    def cancel(self) -> bool:
        """
        Cooperatively cancel the underlying streamed run and shut down.
        Safe to call multiple times and from any thread.

        Returns:
            True if a cancel was scheduled on the loop, False otherwise
            (e.g. the run already finished or never started).
        """
        self._cancel_requested.set()
        loop = self._loop
        streamed = self._streamed
        if loop is not None and streamed is not None and not self._done.is_set():
            # Schedule the cancel on the loop's own thread; call_soon_threadsafe
            # is the only safe way to poke the loop from outside it.
            loop.call_soon_threadsafe(streamed.cancel)
            return True
        return False

    def close(self, *, wait: bool = True) -> None:
        """Idempotent shutdown: cancel the run, stop the loop, join the thread.

        Args:
            wait: When True, block (up to 5s) for the background thread to exit.
        """
        self.cancel()
        # ask the loop to stop if it's still running
        loop = self._loop
        if loop is not None and loop.is_running():
            try:
                loop.call_soon_threadsafe(loop.stop)
            except Exception:
                pass
        # join the thread
        if wait and self._thread is not None and self._thread.is_alive():
            self._thread.join(timeout=5.0)

    # ---------- internals ----------
    def _start_thread(self) -> None:
        # Launch _thread_main via the shared background-thread helper.
        t = run_in_background(self._thread_main)
        self._thread = t
        # Optionally wait until the loop/worker is started so .cancel() is safe soon after init
        self._started.wait(timeout=1.0)

    def _thread_main(self) -> None:
        # Runs on the background thread; owns a private event loop for the
        # lifetime of the streamed run.
        loop = asyncio.new_event_loop()
        self._loop = loop
        asyncio.set_event_loop(loop)

        async def worker() -> None:
            try:
                # Start the streamed run inside the loop thread
                self._streamed = Runner.run_streamed(
                    self._agent,
                    self._input,  # type: ignore[arg-type]
                    context=self._context,
                    max_turns=self._max_turns,
                )
                # If cancel was requested before we created _streamed, honor it now
                if self._cancel_requested.is_set():
                    await self._streamed.cancel()  # type: ignore[func-returns-value]
                # Consume async events and forward into the thread-safe queue
                async for ev in self._streamed.stream_events():
                    # Early exit if a late cancel arrives
                    if self._cancel_requested.is_set():
                        # Try to cancel gracefully; don't break until cancel takes effect
                        try:
                            await self._streamed.cancel()  # type: ignore[func-returns-value]
                        except Exception:
                            pass
                        break
                    # This put() may block if queue_maxsize > 0 (backpressure)
                    self._q.put(ev)
            except BaseException as e:
                # Save exception to surface on the sync iterator side
                self._exc = e
            finally:
                # Signal end-of-stream
                self._q.put(self._SENTINEL)
                self._done.set()

        # Mark started and run the worker to completion
        self._started.set()
        try:
            loop.run_until_complete(worker())
        finally:
            try:
                # Drain pending tasks/callbacks safely
                pending = asyncio.all_tasks(loop=loop)
                for task in pending:
                    task.cancel()
                if pending:
                    loop.run_until_complete(
                        asyncio.gather(*pending, return_exceptions=True)
                    )
            except Exception:
                pass
            finally:
                loop.close()
                self._loop = None
                self._streamed = None

View File

@@ -0,0 +1,70 @@
import dataclasses
from collections.abc import Sequence
from dataclasses import dataclass
from uuid import UUID
from agents import CodeInterpreterTool
from agents import ComputerTool
from agents import FileSearchTool
from agents import FunctionTool
from agents import HostedMCPTool
from agents import ImageGenerationTool as AgentsImageGenerationTool
from agents import LocalShellTool
from agents import Model
from agents import WebSearchTool
from redis.client import Redis
from sqlalchemy.orm import Session
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.agents.agent_search.dr.models import AggregatedDRContext
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.chat.turn.infra.emitter import Emitter
from onyx.llm.interfaces import LLM
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
OktaProfileTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
# Type alias for all tool types accepted by the Agent.
# Union of the hosted/built-in tool classes exposed by the agents SDK plus
# plain FunctionTool; kept in one place so call sites share a single alias.
AgentToolType = (
    FunctionTool
    | FileSearchTool
    | WebSearchTool
    | ComputerTool
    | HostedMCPTool
    | LocalShellTool
    | AgentsImageGenerationTool
    | CodeInterpreterTool
)
@dataclass
class ChatTurnDependencies:
    """Bundle of per-turn dependencies injected when executing a chat turn."""

    llm_model: Model  # agents-SDK Model used to run the agent
    llm: LLM  # Onyx LLM interface
    db_session: Session
    tools: Sequence[FunctionTool]  # function tools made available to the agent
    redis_client: Redis
    emitter: Emitter
    # Optional tool implementations; None when the feature is not enabled
    # for this turn.
    search_pipeline: SearchTool | None = None
    image_generation_tool: ImageGenerationTool | None = None
    okta_profile_tool: OktaProfileTool | None = None
@dataclass
class ChatTurnContext:
    """Context class to hold search tool and other dependencies.

    Mutable per-turn state threaded through the agent run: identifies the
    session/message being answered and accumulates iteration results.
    """

    chat_session_id: UUID
    message_id: int
    research_type: ResearchType
    run_dependencies: ChatTurnDependencies
    aggregated_context: AggregatedDRContext
    current_run_step: int = 0  # step counter, advanced as the run progresses
    # Instructions recorded for each research iteration of this turn.
    iteration_instructions: list[IterationInstructions] = dataclasses.field(
        default_factory=list
    )
    # Raw results of web fetches performed during the turn.
    web_fetch_results: list[dict] = dataclasses.field(default_factory=list)

View File

@@ -349,10 +349,6 @@ except ValueError:
CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT
)
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 4
)
CELERY_WORKER_PRIMARY_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_PRIMARY_CONCURRENCY") or 4
)
@@ -360,18 +356,30 @@ CELERY_WORKER_PRIMARY_CONCURRENCY = int(
CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
)
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT = 4
try:
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = int(
os.environ.get(
"CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY",
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT,
)
)
except ValueError:
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = (
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY_DEFAULT
)
# Consolidated background worker (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
# separate workers' defaults: light=24, docprocessing=6, docfetching=1, heavy=4, kg=2, monitoring=1, user_file=2
# Total would be 40, but we use a more conservative default of 20 for the consolidated worker
CELERY_WORKER_BACKGROUND_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_BACKGROUND_CONCURRENCY") or 20
)
# Individual worker concurrency settings (used when USE_LIGHTWEIGHT_BACKGROUND_WORKER is False or on Kubernetes deployments)
CELERY_WORKER_HEAVY_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
)
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 2
)
CELERY_WORKER_MONITORING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_MONITORING_CONCURRENCY") or 1
)
CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY") or 2
)
# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
VESPA_SYNC_MAX_TASKS = 8192

View File

@@ -72,12 +72,13 @@ POSTGRES_CELERY_APP_NAME = "celery"
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (
"celery_worker_user_file_processing"
)

View File

@@ -61,21 +61,16 @@ _BASE_EMBEDDING_MODELS = [
dim=1536,
index_name="danswer_chunk_text_embedding_3_small",
),
_BaseEmbeddingModel(
name="google/gemini-embedding-001",
dim=3072,
index_name="danswer_chunk_google_gemini_embedding_001",
),
_BaseEmbeddingModel(
name="google/text-embedding-005",
dim=768,
index_name="danswer_chunk_google_text_embedding_005",
),
_BaseEmbeddingModel(
name="google/textembedding-gecko@003",
dim=768,
index_name="danswer_chunk_google_textembedding_gecko_003",
),
_BaseEmbeddingModel(
name="google/textembedding-gecko@003",
dim=768,
index_name="danswer_chunk_textembedding_gecko_003",
),
_BaseEmbeddingModel(
name="voyage/voyage-large-2-instruct",
dim=1024,

View File

@@ -344,6 +344,9 @@ class GoogleDriveConnector(
def _get_all_drives_for_user(self, user_email: str) -> set[str]:
drive_service = get_drive_service(self.creds, user_email)
is_service_account = isinstance(self.creds, ServiceAccountCredentials)
logger.info(
f"Getting all drives for user {user_email} with service account: {is_service_account}"
)
all_drive_ids: set[str] = set()
for drive in execute_paginated_retrieval(
retrieval_function=drive_service.drives().list,

View File

@@ -423,6 +423,35 @@ class SavedSearchDoc(SearchDoc):
"""Create SavedSearchDoc from serialized dictionary data (e.g., from database JSON)"""
return cls(**data)
@classmethod
def from_url(cls, url: str) -> "SavedSearchDoc":
    """Build a SavedSearchDoc representing an internet search result.

    The document_id carries the INTERNET_SEARCH_DOC_ prefix so it matches
    the format used by inference sections created from internet content.
    These docs are never persisted, so db_doc_id is a filler value.
    """
    prefixed_doc_id = "INTERNET_SEARCH_DOC_" + url
    return cls(
        db_doc_id=0,  # filler: internet docs are not saved to the database
        document_id=prefixed_doc_id,
        chunk_ind=0,
        semantic_identifier=url,
        link=url,
        blurb="",
        source_type=DocumentSource.WEB,
        boost=1,
        hidden=False,
        metadata={},
        score=0.0,
        is_relevant=None,
        relevance_explanation=None,
        match_highlights=[],
        updated_at=None,
        primary_owners=None,
        secondary_owners=None,
        is_internet=True,
    )
def __lt__(self, other: Any) -> bool:
if not isinstance(other, SavedSearchDoc):
return NotImplemented

View File

@@ -970,6 +970,7 @@ def translate_db_message_to_chat_message_detail(
chat_message.search_docs, remove_doc_content=remove_doc_content
),
message_type=chat_message.message_type,
research_type=chat_message.research_type,
time_sent=chat_message.time_sent,
citations=chat_message.citations,
files=chat_message.files or [],

View File

@@ -28,7 +28,6 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.kg_configs import KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.entities import delete_from_kg_entities__no_commit
from onyx.db.entities import delete_from_kg_entities_extraction_staging__no_commit
from onyx.db.enums import AccessType
@@ -55,7 +54,6 @@ from onyx.document_index.interfaces import DocumentMetadata
from onyx.kg.models import KGStage
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
@@ -304,80 +302,27 @@ def get_document_counts_for_cc_pairs(
]
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_document_counts_for_cc_pairs_parallel(
cc_pairs: list[ConnectorCredentialPairIdentifier],
def get_document_counts_for_all_cc_pairs(
db_session: Session,
) -> Sequence[tuple[int, int, int]]:
with get_session_with_current_tenant() as db_session:
return get_document_counts_for_cc_pairs(db_session, cc_pairs)
"""Return (connector_id, credential_id, count) for ALL CC pairs with indexed docs.
def _get_document_counts_for_cc_pairs_batch(
batch: list[tuple[int, int]],
) -> list[tuple[int, int, int]]:
"""Worker for parallel execution: opens its own session per batch."""
if not batch:
return []
with get_session_with_current_tenant() as db_session:
stmt = (
select(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
func.count(),
)
.where(
and_(
tuple_(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
).in_(batch),
DocumentByConnectorCredentialPair.has_been_indexed.is_(True),
)
)
.group_by(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
)
)
return db_session.execute(stmt).all() # type: ignore
def get_document_counts_for_cc_pairs_batched_parallel(
cc_pairs: list[ConnectorCredentialPairIdentifier],
batch_size: int = 1000,
max_workers: int | None = None,
) -> Sequence[tuple[int, int, int]]:
"""Parallel variant that batches the IN-clause and runs batches concurrently.
Opens an isolated DB session per batch to avoid sharing a session across threads.
Executes a single grouped query so Postgres can fully leverage indexes,
avoiding large batched IN-lists.
"""
if not cc_pairs:
return []
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs]
batches: list[list[tuple[int, int]]] = [
cc_ids[i : i + batch_size] for i in range(0, len(cc_ids), batch_size)
]
funcs = [(_get_document_counts_for_cc_pairs_batch, (batch,)) for batch in batches]
results = run_functions_tuples_in_parallel(
functions_with_args=funcs, max_workers=max_workers
stmt = (
select(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
func.count(),
)
.where(DocumentByConnectorCredentialPair.has_been_indexed.is_(True))
.group_by(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
)
)
aggregated_counts: dict[tuple[int, int], int] = {}
for batch_result in results:
if not batch_result:
continue
for connector_id, credential_id, cnt in batch_result:
aggregated_counts[(connector_id, credential_id)] = cnt
return [
(connector_id, credential_id, cnt)
for (connector_id, credential_id), cnt in aggregated_counts.items()
]
return db_session.execute(stmt).all() # type: ignore
def get_access_info_for_document(

View File

@@ -18,12 +18,14 @@ from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.federated import create_federated_connector_document_set_mapping
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import DocumentSet as DocumentSetDBModel
from onyx.db.models import DocumentSet__ConnectorCredentialPair
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserRole
@@ -370,10 +372,6 @@ def update_document_set(
db_session.add_all(ds_cc_pairs)
# Update federated connector mappings
from onyx.db.federated import create_federated_connector_document_set_mapping
from onyx.db.models import FederatedConnector__DocumentSet
from sqlalchemy import delete
# Delete existing federated connector mappings for this document set
delete_stmt = delete(FederatedConnector__DocumentSet).where(
FederatedConnector__DocumentSet.document_set_id == document_set_row.id
@@ -455,6 +453,13 @@ def mark_document_set_as_to_be_deleted(
db_session=db_session, document_set_id=document_set_id
)
# delete all federated connector mappings so the cleanup task can fully
# remove the document set once the Vespa sync completes
delete_stmt = delete(FederatedConnector__DocumentSet).where(
FederatedConnector__DocumentSet.document_set_id == document_set_id
)
db_session.execute(delete_stmt)
# delete all private document set information
versioned_delete_private_fn = fetch_versioned_implementation(
"onyx.db.document_set", "delete_document_set_privacy__no_commit"

View File

@@ -25,6 +25,32 @@ class IndexingStatus(str, PyEnum):
)
class PermissionSyncStatus(str, PyEnum):
    """Status enum for permission sync attempts."""

    NOT_STARTED = "not_started"
    IN_PROGRESS = "in_progress"
    SUCCESS = "success"
    CANCELED = "canceled"
    FAILED = "failed"
    COMPLETED_WITH_ERRORS = "completed_with_errors"

    def is_terminal(self) -> bool:
        # An attempt is finished once it is neither waiting nor running;
        # every other state (SUCCESS, COMPLETED_WITH_ERRORS, CANCELED,
        # FAILED) is terminal.
        return self not in (
            PermissionSyncStatus.NOT_STARTED,
            PermissionSyncStatus.IN_PROGRESS,
        )

    def is_successful(self) -> bool:
        # COMPLETED_WITH_ERRORS still counts as an overall success.
        return self in (
            PermissionSyncStatus.SUCCESS,
            PermissionSyncStatus.COMPLETED_WITH_ERRORS,
        )
class IndexingMode(str, PyEnum):
    """How an indexing run should be triggered (update vs. full reindex)."""

    UPDATE = "update"
    REINDEX = "reindex"

View File

@@ -74,6 +74,7 @@ from onyx.db.enums import ChatSessionSharedStatus
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.enums import PermissionSyncStatus
from onyx.db.enums import TaskStatus
from onyx.db.pydantic_type import PydanticListType, PydanticType
from onyx.kg.models import KGEntityTypeAttributes
@@ -2473,6 +2474,7 @@ class Tool(Base):
mcp_server_id: Mapped[int | None] = mapped_column(
Integer, ForeignKey("mcp_server.id", ondelete="CASCADE"), nullable=True
)
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
# Relationship to Persona through the association table
@@ -3435,6 +3437,8 @@ class ResearchAgentIterationSubStep(Base):
# for search-based step-types
cited_doc_results: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
claims: Mapped[list[str] | None] = mapped_column(postgresql.JSONB(), nullable=True)
is_web_fetch: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
queries: Mapped[list[str] | None] = mapped_column(postgresql.JSONB(), nullable=True)
# for image generation step-types
generated_images: Mapped[GeneratedImageFullResult | None] = mapped_column(
@@ -3599,3 +3603,145 @@ class MCPConnectionConfig(Base):
Index("ix_mcp_connection_config_user_email", "user_email"),
Index("ix_mcp_connection_config_server_user", "mcp_server_id", "user_email"),
)
"""
Permission Sync Tables
"""
class DocPermissionSyncAttempt(Base):
    """
    Represents an attempt to sync document permissions for a connector credential pair.
    Similar to IndexAttempt but specifically for document permission syncing operations.
    """

    __tablename__ = "doc_permission_sync_attempt"

    id: Mapped[int] = mapped_column(primary_key=True)
    connector_credential_pair_id: Mapped[int] = mapped_column(
        ForeignKey("connector_credential_pair.id"),
        nullable=False,
    )

    # Status of the sync attempt (stored as a string, not a native DB enum)
    status: Mapped[PermissionSyncStatus] = mapped_column(
        Enum(PermissionSyncStatus, native_enum=False, index=True)
    )

    # Counts for tracking progress
    total_docs_synced: Mapped[int | None] = mapped_column(Integer, default=0)
    docs_with_permission_errors: Mapped[int | None] = mapped_column(Integer, default=0)

    # Error message if sync fails
    error_message: Mapped[str | None] = mapped_column(Text, default=None)

    # Timestamps: created on insert; started/finished filled in as the
    # attempt transitions through its lifecycle.
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        index=True,
    )
    time_started: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), default=None
    )
    time_finished: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), default=None
    )

    # Relationships
    connector_credential_pair: Mapped[ConnectorCredentialPair] = relationship(
        "ConnectorCredentialPair"
    )

    __table_args__ = (
        # Supports "latest attempt for a cc_pair" lookups
        Index(
            "ix_permission_sync_attempt_latest_for_cc_pair",
            "connector_credential_pair_id",
            "time_created",
        ),
        # Supports filtering by status ordered by most recent finish time
        Index(
            "ix_permission_sync_attempt_status_time",
            "status",
            desc("time_finished"),
        ),
    )

    def __repr__(self) -> str:
        return f"<DocPermissionSyncAttempt(id={self.id!r}, " f"status={self.status!r})>"

    def is_finished(self) -> bool:
        # Delegates to the status enum's terminal-state check.
        return self.status.is_terminal()
class ExternalGroupPermissionSyncAttempt(Base):
    """
    Represents an attempt to sync external group memberships for users.
    This tracks the syncing of user-to-external-group mappings across connectors.
    """

    __tablename__ = "external_group_permission_sync_attempt"

    id: Mapped[int] = mapped_column(primary_key=True)

    # Can be tied to a specific connector or be a global group sync
    connector_credential_pair_id: Mapped[int | None] = mapped_column(
        ForeignKey("connector_credential_pair.id"),
        nullable=True,  # Nullable for global group syncs across all connectors
    )

    # Status of the group sync attempt (stored as a string, not a native DB enum)
    status: Mapped[PermissionSyncStatus] = mapped_column(
        Enum(PermissionSyncStatus, native_enum=False, index=True)
    )

    # Counts for tracking progress
    total_users_processed: Mapped[int | None] = mapped_column(Integer, default=0)
    total_groups_processed: Mapped[int | None] = mapped_column(Integer, default=0)
    total_group_memberships_synced: Mapped[int | None] = mapped_column(
        Integer, default=0
    )

    # Error message if sync fails
    error_message: Mapped[str | None] = mapped_column(Text, default=None)

    # Timestamps: created on insert; started/finished filled in as the
    # attempt transitions through its lifecycle.
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        index=True,
    )
    time_started: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), default=None
    )
    time_finished: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), default=None
    )

    # Relationships
    connector_credential_pair: Mapped[ConnectorCredentialPair | None] = relationship(
        "ConnectorCredentialPair"
    )

    __table_args__ = (
        # Supports "latest attempt for a cc_pair" lookups
        Index(
            "ix_group_sync_attempt_cc_pair_time",
            "connector_credential_pair_id",
            "time_created",
        ),
        # Supports filtering by status ordered by most recent finish time
        Index(
            "ix_group_sync_attempt_status_time",
            "status",
            desc("time_finished"),
        ),
    )

    def __repr__(self) -> str:
        return (
            f"<ExternalGroupPermissionSyncAttempt(id={self.id!r}, "
            f"status={self.status!r})>"
        )

    def is_finished(self) -> bool:
        # Delegates to the status enum's terminal-state check.
        return self.status.is_terminal()

View File

@@ -0,0 +1,485 @@
"""Permission sync attempt CRUD operations and utilities.
This module contains all CRUD operations for both DocPermissionSyncAttempt
and ExternalGroupPermissionSyncAttempt models, along with shared utilities.
"""
from typing import Any
from typing import cast
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.db.enums import PermissionSyncStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import DocPermissionSyncAttempt
from onyx.db.models import ExternalGroupPermissionSyncAttempt
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
logger = setup_logger()
# =============================================================================
# DOC PERMISSION SYNC ATTEMPT CRUD
# =============================================================================
def create_doc_permission_sync_attempt(
    connector_credential_pair_id: int,
    db_session: Session,
) -> int:
    """Insert a new doc permission sync attempt in the NOT_STARTED state.

    Args:
        connector_credential_pair_id: The ID of the connector credential pair
        db_session: The database session

    Returns:
        The ID of the created attempt
    """
    new_attempt = DocPermissionSyncAttempt(
        connector_credential_pair_id=connector_credential_pair_id,
        status=PermissionSyncStatus.NOT_STARTED,
    )
    db_session.add(new_attempt)
    # Commit so the generated primary key is populated before we return it.
    db_session.commit()
    return new_attempt.id
def get_doc_permission_sync_attempt(
    db_session: Session,
    attempt_id: int,
    eager_load_connector: bool = False,
) -> DocPermissionSyncAttempt | None:
    """Fetch a doc permission sync attempt by primary key.

    Args:
        db_session: The database session
        attempt_id: The ID of the attempt
        eager_load_connector: If True, eagerly loads the cc_pair relationship
            and its connector (avoids lazy loads after the session closes)

    Returns:
        The attempt if found, None otherwise
    """
    query = select(DocPermissionSyncAttempt).where(
        DocPermissionSyncAttempt.id == attempt_id
    )
    if eager_load_connector:
        eager_option = joinedload(
            DocPermissionSyncAttempt.connector_credential_pair
        ).joinedload(ConnectorCredentialPair.connector)
        query = query.options(eager_option)
    return db_session.scalars(query).first()
def get_recent_doc_permission_sync_attempts_for_cc_pair(
    cc_pair_id: int,
    limit: int,
    db_session: Session,
) -> list[DocPermissionSyncAttempt]:
    """Return up to `limit` doc permission sync attempts for a cc pair,
    most recent first."""
    stmt = (
        select(DocPermissionSyncAttempt)
        .where(DocPermissionSyncAttempt.connector_credential_pair_id == cc_pair_id)
        .order_by(DocPermissionSyncAttempt.time_created.desc())
        .limit(limit)
    )
    return list(db_session.scalars(stmt))
def mark_doc_permission_sync_attempt_in_progress(
    attempt_id: int,
    db_session: Session,
) -> DocPermissionSyncAttempt:
    """Mark a doc permission sync attempt as IN_PROGRESS.

    Locks the row during update (SELECT ... FOR UPDATE) to serialize
    concurrent status transitions.

    Raises:
        RuntimeError: if the attempt is not currently NOT_STARTED.
        Exception: any DB error; the session is rolled back and the
            error re-raised after logging.
    """
    try:
        attempt = db_session.execute(
            select(DocPermissionSyncAttempt)
            .where(DocPermissionSyncAttempt.id == attempt_id)
            .with_for_update()
        ).scalar_one()
        # Only a NOT_STARTED attempt may transition to IN_PROGRESS.
        if attempt.status != PermissionSyncStatus.NOT_STARTED:
            raise RuntimeError(
                f"Doc permission sync attempt with ID '{attempt_id}' is not in NOT_STARTED status. "
                f"Current status is '{attempt.status}'."
            )
        attempt.status = PermissionSyncStatus.IN_PROGRESS
        # Use the DB clock so timestamps are consistent across workers.
        attempt.time_started = func.now()  # type: ignore
        db_session.commit()
        return attempt
    except Exception:
        db_session.rollback()
        logger.exception("mark_doc_permission_sync_attempt_in_progress exceptioned.")
        raise
def mark_doc_permission_sync_attempt_failed(
    attempt_id: int,
    db_session: Session,
    error_message: str,
) -> None:
    """Mark a doc permission sync attempt as failed.

    Locks the row during update, records the error message, and stamps the
    finish time. If the attempt never transitioned to IN_PROGRESS, the start
    time is backfilled so started/finished remain consistent.

    Args:
        attempt_id: The ID of the attempt
        db_session: The database session
        error_message: Human-readable reason for the failure

    Raises:
        Exception: any DB error; the session is rolled back and the error
            re-raised after logging.
    """
    try:
        attempt = db_session.execute(
            select(DocPermissionSyncAttempt)
            .where(DocPermissionSyncAttempt.id == attempt_id)
            .with_for_update()
        ).scalar_one()
        # Backfill time_started for attempts that failed before starting.
        if not attempt.time_started:
            attempt.time_started = func.now()  # type: ignore
        attempt.status = PermissionSyncStatus.FAILED
        attempt.time_finished = func.now()  # type: ignore
        attempt.error_message = error_message
        db_session.commit()

        # Add telemetry for permission sync attempt status change
        optional_telemetry(
            record_type=RecordType.PERMISSION_SYNC_COMPLETE,
            data={
                "doc_permission_sync_attempt_id": attempt_id,
                "status": PermissionSyncStatus.FAILED.value,
                "cc_pair_id": attempt.connector_credential_pair_id,
            },
        )
    except Exception:
        db_session.rollback()
        # Log before re-raising, consistent with the other mutators in this
        # module (mark_..._in_progress / complete_...).
        logger.exception("mark_doc_permission_sync_attempt_failed exceptioned.")
        raise
def complete_doc_permission_sync_attempt(
    db_session: Session,
    attempt_id: int,
    total_docs_synced: int,
    docs_with_permission_errors: int,
) -> DocPermissionSyncAttempt:
    """Complete a doc permission sync attempt by updating progress and setting final status.

    This combines the progress update and final status marking into a single operation.
    If there were permission errors, the attempt is marked as COMPLETED_WITH_ERRORS,
    otherwise it's marked as SUCCESS.

    Args:
        db_session: The database session
        attempt_id: The ID of the attempt
        total_docs_synced: Total number of documents synced
        docs_with_permission_errors: Number of documents that had permission errors

    Returns:
        The completed attempt

    Raises:
        Exception: any DB error; the session is rolled back and the error
            re-raised after logging.
    """
    try:
        # Lock the row so concurrent updates cannot interleave.
        attempt = db_session.execute(
            select(DocPermissionSyncAttempt)
            .where(DocPermissionSyncAttempt.id == attempt_id)
            .with_for_update()
        ).scalar_one()

        # Update progress counters (additive: existing counts are preserved)
        attempt.total_docs_synced = (attempt.total_docs_synced or 0) + total_docs_synced
        attempt.docs_with_permission_errors = (
            attempt.docs_with_permission_errors or 0
        ) + docs_with_permission_errors

        # Set final status based on whether there were errors
        if docs_with_permission_errors > 0:
            attempt.status = PermissionSyncStatus.COMPLETED_WITH_ERRORS
        else:
            attempt.status = PermissionSyncStatus.SUCCESS
        # Use the DB clock so timestamps are consistent across workers.
        attempt.time_finished = func.now()  # type: ignore
        db_session.commit()

        # Add telemetry
        optional_telemetry(
            record_type=RecordType.PERMISSION_SYNC_COMPLETE,
            data={
                "doc_permission_sync_attempt_id": attempt_id,
                "status": attempt.status.value,
                "cc_pair_id": attempt.connector_credential_pair_id,
            },
        )
        return attempt
    except Exception:
        db_session.rollback()
        logger.exception("complete_doc_permission_sync_attempt exceptioned.")
        raise
# =============================================================================
# EXTERNAL GROUP PERMISSION SYNC ATTEMPT CRUD
# =============================================================================
def create_external_group_sync_attempt(
    connector_credential_pair_id: int | None,
    db_session: Session,
) -> int:
    """Insert a new external group sync attempt in the NOT_STARTED state.

    Args:
        connector_credential_pair_id: The ID of the connector credential pair,
            or None for global syncs
        db_session: The database session

    Returns:
        The ID of the created attempt
    """
    new_attempt = ExternalGroupPermissionSyncAttempt(
        connector_credential_pair_id=connector_credential_pair_id,
        status=PermissionSyncStatus.NOT_STARTED,
    )
    db_session.add(new_attempt)
    # Commit so the generated primary key is populated before we return it.
    db_session.commit()
    return new_attempt.id
def get_external_group_sync_attempt(
    db_session: Session,
    attempt_id: int,
    eager_load_connector: bool = False,
) -> ExternalGroupPermissionSyncAttempt | None:
    """Fetch an external group sync attempt by primary key.

    Args:
        db_session: The database session
        attempt_id: The ID of the attempt
        eager_load_connector: If True, eagerly loads the cc_pair relationship
            and its connector (avoids lazy loads after the session closes)

    Returns:
        The attempt if found, None otherwise
    """
    query = select(ExternalGroupPermissionSyncAttempt).where(
        ExternalGroupPermissionSyncAttempt.id == attempt_id
    )
    if eager_load_connector:
        eager_option = joinedload(
            ExternalGroupPermissionSyncAttempt.connector_credential_pair
        ).joinedload(ConnectorCredentialPair.connector)
        query = query.options(eager_option)
    return db_session.scalars(query).first()
def get_recent_external_group_sync_attempts_for_cc_pair(
    cc_pair_id: int | None,
    limit: int,
    db_session: Session,
) -> list[ExternalGroupPermissionSyncAttempt]:
    """Return up to `limit` external group sync attempts, most recent first.

    A cc_pair_id of None selects the global (connector-agnostic) attempts.
    """
    if cc_pair_id is None:
        cc_pair_filter = (
            ExternalGroupPermissionSyncAttempt.connector_credential_pair_id.is_(None)
        )
    else:
        cc_pair_filter = (
            ExternalGroupPermissionSyncAttempt.connector_credential_pair_id
            == cc_pair_id
        )
    stmt = (
        select(ExternalGroupPermissionSyncAttempt)
        .where(cc_pair_filter)
        .order_by(ExternalGroupPermissionSyncAttempt.time_created.desc())
        .limit(limit)
    )
    return list(db_session.scalars(stmt))
def mark_external_group_sync_attempt_in_progress(
    attempt_id: int,
    db_session: Session,
) -> ExternalGroupPermissionSyncAttempt:
    """Mark an external group sync attempt as IN_PROGRESS.

    Locks the row during update (SELECT ... FOR UPDATE) to serialize
    concurrent status transitions.

    Raises:
        RuntimeError: if the attempt is not currently NOT_STARTED.
        Exception: any DB error; the session is rolled back and the
            error re-raised after logging.
    """
    try:
        attempt = db_session.execute(
            select(ExternalGroupPermissionSyncAttempt)
            .where(ExternalGroupPermissionSyncAttempt.id == attempt_id)
            .with_for_update()
        ).scalar_one()
        # Only a NOT_STARTED attempt may transition to IN_PROGRESS.
        if attempt.status != PermissionSyncStatus.NOT_STARTED:
            raise RuntimeError(
                f"External group sync attempt with ID '{attempt_id}' is not in NOT_STARTED status. "
                f"Current status is '{attempt.status}'."
            )
        attempt.status = PermissionSyncStatus.IN_PROGRESS
        # Use the DB clock so timestamps are consistent across workers.
        attempt.time_started = func.now()  # type: ignore
        db_session.commit()
        return attempt
    except Exception:
        db_session.rollback()
        logger.exception("mark_external_group_sync_attempt_in_progress exceptioned.")
        raise
def mark_external_group_sync_attempt_failed(
    attempt_id: int,
    db_session: Session,
    error_message: str,
) -> None:
    """Mark an external group sync attempt as failed.

    Locks the row during update, records the error message, and stamps the
    finish time. If the attempt never transitioned to IN_PROGRESS, the start
    time is backfilled so started/finished remain consistent.

    Args:
        attempt_id: The ID of the attempt
        db_session: The database session
        error_message: Human-readable reason for the failure

    Raises:
        Exception: any DB error; the session is rolled back and the error
            re-raised after logging.
    """
    try:
        attempt = db_session.execute(
            select(ExternalGroupPermissionSyncAttempt)
            .where(ExternalGroupPermissionSyncAttempt.id == attempt_id)
            .with_for_update()
        ).scalar_one()
        # Backfill time_started for attempts that failed before starting.
        if not attempt.time_started:
            attempt.time_started = func.now()  # type: ignore
        attempt.status = PermissionSyncStatus.FAILED
        attempt.time_finished = func.now()  # type: ignore
        attempt.error_message = error_message
        db_session.commit()

        # Add telemetry for permission sync attempt status change
        optional_telemetry(
            record_type=RecordType.PERMISSION_SYNC_COMPLETE,
            data={
                "external_group_sync_attempt_id": attempt_id,
                "status": PermissionSyncStatus.FAILED.value,
                "cc_pair_id": attempt.connector_credential_pair_id,
            },
        )
    except Exception:
        db_session.rollback()
        # Log before re-raising, consistent with the other mutators in this
        # module (mark_..._in_progress / complete_...).
        logger.exception("mark_external_group_sync_attempt_failed exceptioned.")
        raise
def complete_external_group_sync_attempt(
    db_session: Session,
    attempt_id: int,
    total_users_processed: int,
    total_groups_processed: int,
    total_group_memberships_synced: int,
    errors_encountered: int = 0,
) -> ExternalGroupPermissionSyncAttempt:
    """Finalize an external group sync attempt in a single operation.

    Folds the supplied progress counters into whatever was already recorded on
    the attempt, then sets the terminal status: COMPLETED_WITH_ERRORS when any
    errors were encountered, SUCCESS otherwise. The row is locked
    (SELECT ... FOR UPDATE) for the duration of the update, and a telemetry
    record is emitted after the commit.

    Args:
        db_session: The database session
        attempt_id: The ID of the attempt
        total_users_processed: Total users processed
        total_groups_processed: Total groups processed
        total_group_memberships_synced: Total group memberships synced
        errors_encountered: Number of errors encountered (determines if
            COMPLETED_WITH_ERRORS)

    Returns:
        The completed attempt
    """
    try:
        attempt = db_session.execute(
            select(ExternalGroupPermissionSyncAttempt)
            .where(ExternalGroupPermissionSyncAttempt.id == attempt_id)
            .with_for_update()
        ).scalar_one()

        # Fold the new counts into any previously recorded progress
        counter_deltas = {
            "total_users_processed": total_users_processed,
            "total_groups_processed": total_groups_processed,
            "total_group_memberships_synced": total_group_memberships_synced,
        }
        for field_name, delta in counter_deltas.items():
            previous = getattr(attempt, field_name) or 0
            setattr(attempt, field_name, previous + delta)

        # Terminal status depends on whether any errors occurred
        attempt.status = (
            PermissionSyncStatus.COMPLETED_WITH_ERRORS
            if errors_encountered > 0
            else PermissionSyncStatus.SUCCESS
        )
        attempt.time_finished = func.now()  # type: ignore
        db_session.commit()

        # Add telemetry
        optional_telemetry(
            record_type=RecordType.PERMISSION_SYNC_COMPLETE,
            data={
                "external_group_sync_attempt_id": attempt_id,
                "status": attempt.status.value,
                "cc_pair_id": attempt.connector_credential_pair_id,
            },
        )
        return attempt
    except Exception:
        db_session.rollback()
        logger.exception("complete_external_group_sync_attempt exceptioned.")
        raise
# =============================================================================
# DELETION FUNCTIONS
# =============================================================================
def delete_doc_permission_sync_attempts__no_commit(
    db_session: Session,
    cc_pair_id: int,
) -> int:
    """Remove every doc permission sync attempt tied to a connector credential pair.

    Does not commit — callers are expected to commit (or roll back) the
    surrounding transaction themselves.

    Args:
        db_session: The database session
        cc_pair_id: The connector credential pair ID

    Returns:
        The number of attempts deleted
    """
    delete_result = cast(
        CursorResult[Any],
        db_session.execute(
            delete(DocPermissionSyncAttempt).where(
                DocPermissionSyncAttempt.connector_credential_pair_id == cc_pair_id
            )
        ),
    )
    return delete_result.rowcount or 0
def delete_external_group_permission_sync_attempts__no_commit(
    db_session: Session,
    cc_pair_id: int,
) -> int:
    """Remove every external group permission sync attempt for a connector credential pair.

    Does not commit — callers are expected to commit (or roll back) the
    surrounding transaction themselves.

    Args:
        db_session: The database session
        cc_pair_id: The connector credential pair ID

    Returns:
        The number of attempts deleted
    """
    delete_result = cast(
        CursorResult[Any],
        db_session.execute(
            delete(ExternalGroupPermissionSyncAttempt).where(
                ExternalGroupPermissionSyncAttempt.connector_credential_pair_id
                == cc_pair_id
            )
        ),
    )
    return delete_result.rowcount or 0

View File

@@ -19,16 +19,23 @@ if TYPE_CHECKING:
logger = setup_logger()
def get_tools(db_session: Session) -> list[Tool]:
return list(db_session.scalars(select(Tool)).all())
def get_tools(db_session: Session, *, only_enabled: bool = False) -> list[Tool]:
query = select(Tool)
if only_enabled:
query = query.where(Tool.enabled.is_(True))
return list(db_session.scalars(query).all())
def get_tools_by_mcp_server_id(mcp_server_id: int, db_session: Session) -> list[Tool]:
return list(
db_session.scalars(
select(Tool).where(Tool.mcp_server_id == mcp_server_id)
).all()
)
def get_tools_by_mcp_server_id(
mcp_server_id: int,
db_session: Session,
*,
only_enabled: bool = False,
) -> list[Tool]:
query = select(Tool).where(Tool.mcp_server_id == mcp_server_id)
if only_enabled:
query = query.where(Tool.enabled.is_(True))
return list(db_session.scalars(query).all())
def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
@@ -53,6 +60,9 @@ def create_tool__no_commit(
user_id: UUID | None,
db_session: Session,
passthrough_auth: bool,
*,
mcp_server_id: int | None = None,
enabled: bool = True,
) -> Tool:
new_tool = Tool(
name=name,
@@ -64,6 +74,8 @@ def create_tool__no_commit(
),
user_id=user_id,
passthrough_auth=passthrough_auth,
mcp_server_id=mcp_server_id,
enabled=enabled,
)
db_session.add(new_tool)
db_session.flush() # Don't commit yet, let caller decide when to commit

View File

@@ -1062,7 +1062,7 @@ class VespaIndex(DocumentIndex):
*,
tenant_id: str,
index_name: str,
) -> None:
) -> int:
"""
Deletes all entries in the specified index with the given tenant_id.
@@ -1072,6 +1072,9 @@ class VespaIndex(DocumentIndex):
Parameters:
tenant_id (str): The tenant ID whose documents are to be deleted.
index_name (str): The name of the index from which to delete documents.
Returns:
int: The number of documents deleted.
"""
logger.info(
f"Deleting entries with tenant_id: {tenant_id} from index: {index_name}"
@@ -1084,7 +1087,7 @@ class VespaIndex(DocumentIndex):
logger.info(
f"No documents found with tenant_id: {tenant_id} in index: {index_name}"
)
return
return 0
# Step 2: Delete documents in batches
delete_requests = [
@@ -1093,6 +1096,7 @@ class VespaIndex(DocumentIndex):
]
cls._apply_deletes_batched(delete_requests)
return len(document_ids)
@classmethod
def _get_all_document_ids_by_tenant_id(

View File

@@ -14,7 +14,7 @@ The evaluation system uses [Braintrust](https://www.braintrust.dev/) to run auto
Kick off a remote job
```bash
onyx/backend$ python onyx/evals/eval_cli.py --remote --api-key <SUPER_CLOUD_USER_API_KEY> --search-permissions-email <email account to reference> --remote --remote-dataset-name Simple
onyx/backend$ python -m dotenv -f .vscode/.env run -- python onyx/evals/eval_cli.py --remote --api-key <SUPER_CLOUD_USER_API_KEY> --search-permissions-email <email account to reference> --remote --remote-dataset-name Simple
```
You can also run the CLI directly from the command line:

View File

@@ -42,6 +42,7 @@ def run_local(
local_data_path: str | None,
remote_dataset_name: str | None,
search_permissions_email: str | None = None,
no_send_logs: bool = False,
) -> EvalationAck:
"""
Run evaluation with local configurations.
@@ -63,6 +64,7 @@ def run_local(
configuration = EvalConfigurationOptions(
search_permissions_email=search_permissions_email,
dataset_name=remote_dataset_name or "blank",
no_send_logs=no_send_logs,
)
if remote_dataset_name:
@@ -172,6 +174,13 @@ def main() -> None:
help="Email address to impersonate for the evaluation",
)
parser.add_argument(
"--no-send-logs",
action="store_true",
help="Do not send logs to the remote server",
default=False,
)
args = parser.parse_args()
if args.local_data_path:
@@ -215,6 +224,7 @@ def main() -> None:
local_data_path=args.local_data_path,
remote_dataset_name=args.remote_dataset_name,
search_permissions_email=args.search_permissions_email,
no_send_logs=args.no_send_logs,
)

View File

@@ -36,6 +36,7 @@ class EvalConfigurationOptions(BaseModel):
)
search_permissions_email: str
dataset_name: str
no_send_logs: bool = False
def get_configuration(self, db_session: Session) -> EvalConfiguration:
persona_override_config = self.persona_override_config or PersonaOverrideConfig(

View File

@@ -1,4 +1,5 @@
from collections.abc import Callable
from typing import Any
from braintrust import Eval
from braintrust import EvalCase
@@ -23,33 +24,21 @@ class BraintrustEvalProvider(EvalProvider):
raise ValueError("Cannot specify both data and remote_dataset_name")
if data is None and remote_dataset_name is None:
raise ValueError("Must specify either data or remote_dataset_name")
eval_data: Any = None
if remote_dataset_name is not None:
eval_data = init_dataset(
project=BRAINTRUST_PROJECT, name=remote_dataset_name
)
Eval(
name=BRAINTRUST_PROJECT,
data=eval_data,
task=task,
scores=[],
metadata={**configuration.model_dump()},
max_concurrency=BRAINTRUST_MAX_CONCURRENCY,
)
else:
if data is None:
raise ValueError(
"Must specify data when remote_dataset_name is not specified"
)
eval_cases: list[EvalCase[dict[str, str], str]] = [
EvalCase(input=item["input"]) for item in data
]
Eval(
name=BRAINTRUST_PROJECT,
data=eval_cases,
task=task,
scores=[],
metadata={**configuration.model_dump()},
max_concurrency=BRAINTRUST_MAX_CONCURRENCY,
)
if data:
eval_data = [EvalCase(input=item["input"]) for item in data]
Eval(
name=BRAINTRUST_PROJECT,
data=eval_data, # type: ignore[arg-type]
task=task,
scores=[],
metadata={**configuration.model_dump()},
max_concurrency=BRAINTRUST_MAX_CONCURRENCY,
no_send_logs=configuration.no_send_logs,
)
return EvalationAck(success=True)

View File

@@ -1,13 +1,16 @@
import os
from typing import Any
import braintrust
from agents import set_trace_processors
from braintrust.wrappers.openai import BraintrustTracingProcessor
from braintrust_langchain import set_global_handler
from braintrust_langchain.callbacks import BraintrustCallbackHandler
from onyx.configs.app_configs import BRAINTRUST_API_KEY
from onyx.configs.app_configs import BRAINTRUST_PROJECT
MASKING_LENGTH = 20000
MASKING_LENGTH = int(os.environ.get("BRAINTRUST_MASKING_LENGTH", "20000"))
def _truncate_str(s: str) -> str:
@@ -26,10 +29,11 @@ def _mask(data: Any) -> Any:
def setup_braintrust() -> None:
"""Initialize Braintrust logger and set up global callback handler."""
braintrust.init_logger(
logger = braintrust.init_logger(
project=BRAINTRUST_PROJECT,
api_key=BRAINTRUST_API_KEY,
)
braintrust.set_masking_function(_mask)
handler = BraintrustCallbackHandler()
set_global_handler(handler)
set_trace_processors([BraintrustTracingProcessor(logger)])

View File

View File

@@ -0,0 +1,28 @@
from onyx.feature_flags.interface import FeatureFlagProvider
from onyx.feature_flags.interface import NoOpFeatureFlagProvider
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from shared_configs.configs import MULTI_TENANT
def get_default_feature_flag_provider() -> FeatureFlagProvider:
"""
Get the default feature flag provider implementation.
Returns the PostHog-based provider in Enterprise Edition when available,
otherwise returns a no-op provider that always returns False.
This function is designed for dependency injection - callers should
use this factory rather than directly instantiating providers.
Returns:
FeatureFlagProvider: The configured feature flag provider instance
"""
if MULTI_TENANT:
return fetch_versioned_implementation_with_fallback(
module="onyx.feature_flags.factory",
attribute="get_posthog_feature_flag_provider",
fallback=lambda: NoOpFeatureFlagProvider(),
)()
return NoOpFeatureFlagProvider()

View File

@@ -0,0 +1,6 @@
"""
Feature flag keys used throughout the application.
Centralizes feature flag key definitions to avoid magic strings.
"""
SIMPLE_AGENT_FRAMEWORK = "simple-agent-framework"

View File

View File

@@ -0,0 +1,72 @@
import abc
from typing import Any
from uuid import UUID
from onyx.db.models import User
from shared_configs.configs import ENVIRONMENT
class FeatureFlagProvider(abc.ABC):
"""
Abstract base class for feature flag providers.
Implementations should provide vendor-specific logic for checking
whether a feature flag is enabled for a given user.
"""
@abc.abstractmethod
def feature_enabled(
self,
flag_key: str,
user_id: UUID,
user_properties: dict[str, Any] | None = None,
) -> bool:
"""
Check if a feature flag is enabled for a user.
Args:
flag_key: The identifier for the feature flag to check
user_id: The unique identifier for the user
user_properties: Optional dictionary of user properties/attributes
that may influence flag evaluation
Returns:
True if the feature is enabled for the user, False otherwise
"""
raise NotImplementedError
def feature_enabled_for_user_tenant(
self, flag_key: str, user: User | None, tenant_id: str
) -> bool:
"""
Check if a feature flag is enabled for a user.
"""
return self.feature_enabled(
flag_key,
# For local dev with AUTH_TYPE=disabled, we don't have a user, so we use a random UUID
user.id if user else UUID("caa1e0cd-6ee6-4550-b1ec-8affaef4bf83"),
user_properties={
"tenant_id": tenant_id,
"email": user.email if user else "anonymous@onyx.app",
},
)
class NoOpFeatureFlagProvider(FeatureFlagProvider):
"""
No-operation feature flag provider that always returns False.
Used as a fallback when no real feature flag provider is available
(e.g., in MIT version without PostHog).
"""
def feature_enabled(
self,
flag_key: str,
user_id: UUID,
user_properties: dict[str, Any] | None = None,
) -> bool:
environment = ENVIRONMENT
if environment == "local":
return True
return False

View File

@@ -5,8 +5,8 @@ All other modules should import litellm from here instead of directly.
"""
import litellm
from agents.extensions.models.litellm_model import LitellmModel
from onyx.configs.app_configs import BRAINTRUST_ENABLED
# Import litellm
@@ -16,8 +16,5 @@ from onyx.configs.app_configs import BRAINTRUST_ENABLED
litellm.drop_params = True
litellm.telemetry = False
if BRAINTRUST_ENABLED:
litellm.callbacks = ["braintrust"]
# Export the configured litellm module
__all__ = ["litellm"]
__all__ = ["litellm", "LitellmModel"]

View File

@@ -57,6 +57,22 @@ class PreviousMessage(BaseModel):
research_answer_purpose=chat_message.research_answer_purpose,
)
def to_agent_sdk_msg(self) -> dict:
message_type_to_agent_sdk_role = {
MessageType.USER: "user",
MessageType.SYSTEM: "system",
MessageType.ASSISTANT: "assistant",
}
# TODO: Use native format for files and images
content = build_content_with_imgs(self.message, self.files)
if self.message_type in message_type_to_agent_sdk_role:
role = message_type_to_agent_sdk_role[self.message_type]
return {
"role": role,
"content": content,
}
raise ValueError(f"Unknown message type: {self.message_type}")
def to_langchain_msg(self) -> BaseMessage:
content = build_content_with_imgs(self.message, self.files)
if self.message_type == MessageType.USER:
@@ -66,6 +82,7 @@ class PreviousMessage(BaseModel):
else:
return SystemMessage(content=content)
# TODO: deprecate langchain
@classmethod
def from_langchain_msg(
cls, msg: BaseMessage, token_count: int

View File

@@ -10,6 +10,13 @@ Avoid using double brackets like [[1]]. To cite multiple documents, use [1], [2]
Try to cite inline as opposed to leaving all citations until the very end of the response.
""".rstrip()
REQUIRE_CITATION_STATEMENT_V2 = """
Cite relevant statements INLINE using the format [[1]](https://example.com) with the document number (an integer) in between
the brackets. To cite multiple documents, use [[1]](https://example.com), [[2]](https://example.com) format instead of \
[[1, 2]](https://example.com). \
Try to cite inline as opposed to leaving all citations until the very end of the response.
""".rstrip()
NO_CITATION_STATEMENT = """
Do not provide any citations even if there are examples in the chat history.
""".rstrip()

View File

@@ -54,6 +54,12 @@ CONVERSATION HISTORY:
{GENERAL_SEP_PAT}
"""
COMPANY_NAME_BLOCK = """
The user works at {company_name}.
"""
COMPANY_DESCRIPTION_BLOCK = """Organization description: {company_description}
"""
# This has to be doubly escaped due to json containing { } which are also used for format strings
EMPTY_SAMPLE_JSON = {

View File

@@ -13,6 +13,9 @@ from onyx.db.models import Persona
from onyx.prompts.chat_prompts import ADDITIONAL_INFO
from onyx.prompts.chat_prompts import CITATION_REMINDER
from onyx.prompts.constants import CODE_BLOCK_PAT
from onyx.prompts.direct_qa_prompts import COMPANY_DESCRIPTION_BLOCK
from onyx.prompts.direct_qa_prompts import COMPANY_NAME_BLOCK
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
@@ -89,6 +92,23 @@ def handle_onyx_date_awareness(
return prompt_str
def handle_company_awareness(prompt_str: str) -> str:
try:
workspace_settings = load_settings()
company_name = workspace_settings.company_name
company_description = workspace_settings.company_description
if company_name:
prompt_str += COMPANY_NAME_BLOCK.format(company_name=company_name)
if company_description:
prompt_str += COMPANY_DESCRIPTION_BLOCK.format(
company_description=company_description
)
return prompt_str
except Exception as e:
logger.error(f"Error handling company awareness: {e}")
return prompt_str
def build_task_prompt_reminders(
prompt: Persona | PromptConfig,
use_language_hint: bool,

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from logging import Logger
from typing import Any
from typing import cast
from typing import NamedTuple
import redis
from pydantic import BaseModel
@@ -16,6 +17,18 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.utils.variable_functionality import fetch_versioned_implementation
class PermissionSyncResult(NamedTuple):
"""Result of a permission sync operation.
Attributes:
num_updated: Number of documents successfully updated
num_errors: Number of documents that failed to update
"""
num_updated: int
num_errors: int
class RedisConnectorPermissionSyncPayload(BaseModel):
id: str
submitted: datetime
@@ -159,7 +172,12 @@ class RedisConnectorPermissionSync:
connector_id: int,
credential_id: int,
task_logger: Logger | None = None,
) -> int | None:
) -> PermissionSyncResult:
"""Update permissions for documents.
Returns:
PermissionSyncResult containing counts of successful updates and errors
"""
last_lock_time = time.monotonic()
document_update_permissions_fn = fetch_versioned_implementation(
@@ -168,6 +186,7 @@ class RedisConnectorPermissionSync:
)
num_permissions = 0
num_errors = 0
# Create a task for each document permission sync
for permissions in new_permissions:
current_time = time.monotonic()
@@ -201,14 +220,25 @@ class RedisConnectorPermissionSync:
# a rare enough case to be acceptable.
# This can internally exception due to db issues but still continue
# we may want to change this
document_update_permissions_fn(
self.tenant_id, permissions, source_string, connector_id, credential_id
)
# Catch exceptions per-document to avoid breaking the entire sync
try:
document_update_permissions_fn(
self.tenant_id,
permissions,
source_string,
connector_id,
credential_id,
)
num_permissions += 1
except Exception:
num_errors += 1
if task_logger:
task_logger.exception(
f"Failed to update permissions for document {permissions.doc_id}"
)
# Continue processing other documents
num_permissions += 1
return num_permissions
return PermissionSyncResult(num_updated=num_permissions, num_errors=num_errors)
def reset(self) -> None:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)

View File

@@ -91,7 +91,7 @@ from onyx.db.credentials import create_credential
from onyx.db.credentials import delete_service_account_credentials
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.deletion_attempt import check_deletion_attempt_is_allowed
from onyx.db.document import get_document_counts_for_cc_pairs_batched_parallel
from onyx.db.document import get_document_counts_for_all_cc_pairs
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -795,15 +795,7 @@ def get_connector_indexing_status(
list[IndexAttempt], latest_finished_index_attempts
)
document_count_info = get_document_counts_for_cc_pairs_batched_parallel(
cc_pairs=[
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
for cc_pair in non_editable_cc_pairs + editable_cc_pairs
]
)
document_count_info = get_document_counts_for_all_cc_pairs(db_session)
# Create lookup dictionaries for efficient access
cc_pair_to_document_cnt: dict[tuple[int, int], int] = {

View File

@@ -22,6 +22,7 @@ from mcp.shared.auth import OAuthClientInformationFull
from mcp.shared.auth import OAuthClientMetadata
from mcp.shared.auth import OAuthToken
from mcp.types import InitializeResult
from mcp.types import Tool as MCPLibTool
from pydantic import AnyUrl
from pydantic import BaseModel
from sqlalchemy.orm import Session
@@ -50,6 +51,7 @@ from onyx.db.mcp import update_mcp_server__no_commit
from onyx.db.mcp import upsert_user_connection_config
from onyx.db.models import MCPConnectionConfig
from onyx.db.models import MCPServer as DbMCPServer
from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.tools import create_tool__no_commit
from onyx.db.tools import delete_tool__no_commit
@@ -1044,6 +1046,55 @@ def user_list_mcp_tools_by_id(
return _list_mcp_tools_by_id(server_id, db, False, user)
def _upsert_db_tools(
discovered_tools: list[MCPLibTool],
existing_by_name: dict[str, Tool],
processed_names: set[str],
mcp_server_id: int,
db: Session,
) -> bool:
db_dirty = False
for tool in discovered_tools:
tool_name = tool.name
if not tool_name:
continue
processed_names.add(tool_name)
description = tool.description or ""
annotations_title = tool.annotations.title if tool.annotations else None
display_name = tool.title or annotations_title or tool_name
input_schema = tool.inputSchema
if existing_tool := existing_by_name.get(tool_name):
if existing_tool.description != description:
existing_tool.description = description
db_dirty = True
if existing_tool.display_name != display_name:
existing_tool.display_name = display_name
db_dirty = True
if existing_tool.mcp_input_schema != input_schema:
existing_tool.mcp_input_schema = input_schema
db_dirty = True
continue
new_tool = create_tool__no_commit(
name=tool_name,
description=description,
openapi_schema=None,
custom_headers=None,
user_id=None,
db_session=db,
passthrough_auth=False,
mcp_server_id=mcp_server_id,
enabled=False,
)
new_tool.display_name = display_name
new_tool.mcp_input_schema = input_schema
db_dirty = True
return db_dirty
def _list_mcp_tools_by_id(
server_id: int,
db: Session,
@@ -1090,18 +1141,35 @@ def _list_mcp_tools_by_id(
logger.info(f"Discovering tools for MCP server: {mcp_server.name}: {t1}")
# Normalize URL to include trailing slash to avoid redirect/slow path handling
server_url = mcp_server.server_url.rstrip("/") + "/"
tools = discover_mcp_tools(
discovered_tools = discover_mcp_tools(
server_url,
connection_config.config.get("headers", {}) if connection_config else {},
transport=mcp_server.transport,
auth=auth,
)
logger.info(
f"Discovered {len(tools)} tools for MCP server: {mcp_server.name}: {time.time() - t1}"
f"Discovered {len(discovered_tools)} tools for MCP server: {mcp_server.name}: {time.time() - t1}"
)
if is_admin:
existing_tools = get_tools_by_mcp_server_id(mcp_server.id, db)
existing_by_name = {db_tool.name: db_tool for db_tool in existing_tools}
processed_names: set[str] = set()
db_dirty = _upsert_db_tools(
discovered_tools, existing_by_name, processed_names, mcp_server.id, db
)
for name, db_tool in existing_by_name.items():
if name not in processed_names:
delete_tool__no_commit(db_tool.id, db)
db_dirty = True
if db_dirty:
db.commit()
# Truncate tool descriptions to prevent overly long responses
for tool in tools:
for tool in discovered_tools:
if tool.description:
tool.description = _truncate_description(tool.description)
@@ -1112,7 +1180,7 @@ def _list_mcp_tools_by_id(
server_id=server_id,
server_name=mcp_server.name,
server_url=mcp_server.server_url,
tools=tools,
tools=discovered_tools,
)
@@ -1324,70 +1392,31 @@ def _upsert_mcp_server(
return mcp_server
def _add_tools_to_server(
def _sync_tools_for_server(
mcp_server: DbMCPServer,
selected_tools: list[str],
keep_tool_names: set[str],
user: User | None,
selected_tools: set[str],
db_session: Session,
) -> int:
created_tools = 0
# First, discover available tools from the server to get full definitions
"""Toggle enabled state for MCP tools that exist for the server.
Updates to the db model of a tool all happen when the user Lists Tools.
This ensures that the tools added to the db match what the user sees in the UI,
even if the underlying tool has changed on the server after list tools is called.
That's a corner case anyways; the admin should go back and update the server by re-listing tools.
"""
connection_config = _get_connection_config(mcp_server, True, user, db_session)
headers = connection_config.config.get("headers", {}) if connection_config else {}
updated_tools = 0
auth = None
if mcp_server.auth_type == MCPAuthenticationType.OAUTH:
user_id = str(user.id) if user else ""
assert connection_config
auth = make_oauth_provider(
mcp_server,
user_id,
UNUSED_RETURN_PATH,
connection_config.id,
mcp_server.admin_connection_config_id,
)
available_tools = discover_mcp_tools(
mcp_server.server_url,
headers,
transport=mcp_server.transport,
auth=auth,
)
tools_by_name = {tool.name: tool for tool in available_tools}
existing_tools = get_tools_by_mcp_server_id(mcp_server.id, db_session)
existing_by_name = {tool.name: tool for tool in existing_tools}
for tool_name in selected_tools:
if tool_name not in tools_by_name:
logger.warning(f"Tool '{tool_name}' not found in MCP server")
continue
# Disable any existing tools that were not processed above
for tool_name, db_tool in existing_by_name.items():
should_enable = tool_name in selected_tools
if db_tool.enabled != should_enable:
db_tool.enabled = should_enable
updated_tools += 1
if tool_name in keep_tool_names:
# tool was not deleted earlier and not added now
continue
tool_def = tools_by_name[tool_name]
# Create Tool entry for each selected tool
tool = create_tool__no_commit(
name=tool_name,
description=_truncate_description(tool_def.description),
openapi_schema=None, # MCP tools don't use OpenAPI
custom_headers=None,
user_id=user.id if user else None,
db_session=db_session,
passthrough_auth=False,
)
# Update the tool with MCP server ID, display name, and input schema
tool.mcp_server_id = mcp_server.id
annotations_title = tool_def.annotations.title if tool_def.annotations else None
tool.display_name = tool_def.title or annotations_title or tool_name
tool.mcp_input_schema = tool_def.inputSchema
created_tools += 1
logger.info(f"Created MCP tool '{tool.name}' with ID {tool.id}")
return created_tools
return updated_tools
@admin_router.get("/servers/{server_id}", response_model=MCPServer)
@@ -1560,27 +1589,12 @@ def update_mcp_server_with_tools(
status_code=400, detail="MCP server has no admin connection config"
)
# Cleanup: Delete tools for this server that are not in the selected_tools list
selected_names = set(request.selected_tools or [])
existing_tools = get_tools_by_mcp_server_id(request.server_id, db_session)
keep_tool_names = set()
updated_tools = 0
for tool in existing_tools:
if tool.name in selected_names:
keep_tool_names.add(tool.name)
else:
delete_tool__no_commit(tool.id, db_session)
updated_tools += 1
# If selected_tools is provided, create individual tools for each
if request.selected_tools:
updated_tools += _add_tools_to_server(
mcp_server,
request.selected_tools,
keep_tool_names,
user,
db_session,
)
updated_tools = _sync_tools_for_server(
mcp_server,
selected_names,
db_session,
)
db_session.commit()

View File

@@ -64,6 +64,7 @@ def create_custom_tool(
user_id=user.id if user else None,
db_session=db_session,
passthrough_auth=tool_data.passthrough_auth,
enabled=True,
)
db_session.commit()
return ToolSnapshot.from_model(tool)
@@ -147,7 +148,7 @@ def list_tools(
db_session: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> list[ToolSnapshot]:
tools = get_tools(db_session)
tools = get_tools(db_session, only_enabled=True)
filtered_tools: list[ToolSnapshot] = []
for tool in tools:

View File

@@ -15,6 +15,7 @@ from fastapi import Request
from fastapi import Response
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from redis.client import Redis
from sqlalchemy.orm import Session
from onyx.auth.users import current_chat_accessible_user
@@ -25,6 +26,7 @@ from onyx.chat.process_message import stream_chat_message
from onyx.chat.prompt_builder.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from onyx.chat.stop_signal_checker import set_fence
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.chat_configs import HARD_DELETE_CHATS
from onyx.configs.constants import MessageType
@@ -59,6 +61,7 @@ from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llms_for_persona
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.redis.redis_pool import get_redis_client
from onyx.secondary_llm_flows.chat_session_naming import (
get_renamed_conversation_name,
)
@@ -816,3 +819,17 @@ async def search_chats(
has_more=has_more,
next_page=page + 1 if has_more else None,
)
@router.post("/stop-chat-session/{chat_session_id}")
def stop_chat_session(
chat_session_id: UUID,
user: User | None = Depends(current_user),
redis_client: Redis = Depends(get_redis_client),
) -> dict[str, str]:
"""
Stop a chat session by setting a stop signal in Redis.
This endpoint is called by the frontend when the user clicks the stop button.
"""
set_fence(chat_session_id, redis_client, True)
return {"message": "Chat session stopped"}

View File

@@ -6,6 +6,7 @@ from uuid import UUID
from pydantic import BaseModel
from pydantic import model_validator
from onyx.agents.agent_search.dr.enums import ResearchType
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import RetrievalDocs
from onyx.configs.constants import DocumentSource
@@ -255,6 +256,7 @@ class ChatMessageDetail(BaseModel):
rephrased_query: str | None = None
context_docs: RetrievalDocs | None = None
message_type: MessageType
research_type: ResearchType | None = None
time_sent: datetime
overridden_model: str | None
alternate_assistant_id: int | None = None

View File

@@ -36,6 +36,12 @@ class MessageDelta(BaseObj):
"""Control Packets"""
class PacketException(BaseObj):
type: Literal["error"] = "error"
exception: Exception
model_config = {"arbitrary_types_allowed": True}
class OverallStop(BaseObj):
type: Literal["stop"] = "stop"
@@ -56,8 +62,14 @@ class SearchToolStart(BaseObj):
class SearchToolDelta(BaseObj):
type: Literal["internal_search_tool_delta"] = "internal_search_tool_delta"
queries: list[str] | None = None
documents: list[SavedSearchDoc] | None = None
queries: list[str]
documents: list[SavedSearchDoc]
class FetchToolStart(BaseObj):
type: Literal["fetch_tool_start"] = "fetch_tool_start"
documents: list[SavedSearchDoc]
class ImageGenerationToolStart(BaseObj):
@@ -182,6 +194,8 @@ PacketObj = Annotated[
ReasoningDelta,
CitationStart,
CitationDelta,
PacketException,
FetchToolStart,
],
Field(discriminator="type"),
]

View File

@@ -17,12 +17,15 @@ from onyx.db.chat import get_db_search_doc_by_id
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.models import ChatMessage
from onyx.db.tools import get_tool_by_id
from onyx.feature_flags.factory import get_default_feature_flag_provider
from onyx.feature_flags.feature_flags_keys import SIMPLE_AGENT_FRAMEWORK
from onyx.server.query_and_chat.streaming_models import CitationDelta
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import CitationStart
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import EndStepPacketList
from onyx.server.query_and_chat.streaming_models import FetchToolStart
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.server.query_and_chat.streaming_models import MessageDelta
@@ -45,6 +48,7 @@ from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from shared_configs.contextvars import get_current_tenant_id
_CANNOT_SHOW_STEP_RESULTS_STR = "[Cannot display step results]"
@@ -213,9 +217,20 @@ def create_custom_tool_packets(
return packets
def create_fetch_packets(
    fetches: list[list[SavedSearchDoc]], step_nr: int
) -> list[Packet]:
    """Build the packet stream for a sequence of web-fetch operations.

    Each fetch occupies its own step: a FetchToolStart packet carrying the
    fetched documents, immediately closed by a SectionEnd packet. Steps are
    numbered consecutively starting at `step_nr`.
    """
    result: list[Packet] = []
    for offset, fetched_docs in enumerate(fetches):
        ind = step_nr + offset
        result.append(Packet(ind=ind, obj=FetchToolStart(documents=fetched_docs)))
        result.append(Packet(ind=ind, obj=SectionEnd(type="section_end")))
    return result
def create_search_packets(
search_queries: list[str],
saved_search_docs: list[SavedSearchDoc] | None,
saved_search_docs: list[SavedSearchDoc],
is_internet_search: bool,
step_nr: int,
) -> list[Packet]:
@@ -245,12 +260,253 @@ def create_search_packets(
return packets
def translate_db_message_to_packets_simple(
    chat_message: ChatMessage,
    db_session: Session,
    remove_doc_content: bool = False,
    start_step_nr: int = 1,
) -> EndStepPacketList:
    """
    Translation function for simple agent framework (ResearchType.FAST).
    Includes support for FetchToolStart packets for web fetch operations.

    Replays a persisted assistant message (research iterations, final message,
    citations) as the ordered packet stream the frontend expects, returning the
    packets together with the next free step number.
    """
    # NOTE(review): `remove_doc_content` is accepted for signature parity with
    # translate_db_message_to_packets but is not used in this body — confirm.
    step_nr = start_step_nr
    packet_list: list[Packet] = []

    if chat_message.message_type == MessageType.ASSISTANT:
        # Build citation info from explicit citations or, as a fallback, by
        # enumerating the message's attached search docs.
        citations = chat_message.citations
        citation_info_list: list[CitationInfo] = []
        if citations:
            for citation_num, search_doc_id in citations.items():
                search_doc = get_db_search_doc_by_id(search_doc_id, db_session)
                if search_doc:
                    citation_info_list.append(
                        CitationInfo(
                            citation_num=citation_num,
                            document_id=search_doc.document_id,
                        )
                    )
        elif chat_message.search_docs:
            for i, search_doc in enumerate(chat_message.search_docs):
                citation_info_list.append(
                    CitationInfo(
                        citation_num=i,
                        document_id=search_doc.document_id,
                    )
                )

        # Replay research iterations in iteration order for supported types.
        research_iterations = []
        if chat_message.research_type in [
            ResearchType.THOUGHTFUL,
            ResearchType.DEEP,
            ResearchType.LEGACY_AGENTIC,
            ResearchType.FAST,
        ]:
            research_iterations = sorted(
                chat_message.research_iterations, key=lambda x: x.iteration_nr
            )

        for research_iteration in research_iterations:
            # DEEP research surfaces per-iteration reasoning/purpose as
            # dedicated reasoning steps (reasoning only from iteration 2 on).
            if (
                research_iteration.iteration_nr > 1
                and research_iteration.reasoning
                and chat_message.research_type == ResearchType.DEEP
            ):
                packet_list.extend(
                    create_reasoning_packets(research_iteration.reasoning, step_nr)
                )
                step_nr += 1
            if (
                research_iteration.purpose
                and chat_message.research_type == ResearchType.DEEP
            ):
                packet_list.extend(
                    create_reasoning_packets(research_iteration.purpose, step_nr)
                )
                step_nr += 1

            sub_steps = research_iteration.sub_steps
            tasks: list[str] = []
            tool_call_ids: list[int | None] = []
            cited_docs: list[SavedSearchDoc] = []
            fetches: list[list[SavedSearchDoc]] = []
            is_web_fetch: bool = False
            for sub_step in sub_steps:
                # For v2 tools, use the queries field if available, otherwise
                # fall back to sub_step_instructions
                if sub_step.queries:
                    tasks.extend(sub_step.queries)
                else:
                    tasks.append(sub_step.sub_step_instructions or "")
                tool_call_ids.append(sub_step.sub_step_tool_id)

                sub_step_cited_docs = sub_step.cited_doc_results
                sub_step_saved_search_docs: list[SavedSearchDoc] = []
                if isinstance(sub_step_cited_docs, list):
                    for doc_data in sub_step_cited_docs:
                        # Backfill fields SavedSearchDoc requires but the
                        # persisted cited-doc payload does not carry.
                        doc_data["db_doc_id"] = 1
                        doc_data["boost"] = 1
                        doc_data["hidden"] = False
                        doc_data["chunk_ind"] = 0
                        # updated_at may be persisted as None or the string "None"
                        if (
                            doc_data["updated_at"] is None
                            or doc_data["updated_at"] == "None"
                        ):
                            doc_data["updated_at"] = datetime.now()
                        sub_step_saved_search_docs.append(
                            SavedSearchDoc.from_dict(doc_data)
                            if isinstance(doc_data, dict)
                            else doc_data
                        )
                    cited_docs.extend(sub_step_saved_search_docs)
                else:
                    # Results not renderable; only DEEP surfaces a placeholder.
                    if chat_message.research_type == ResearchType.DEEP:
                        packet_list.extend(
                            create_reasoning_packets(
                                _CANNOT_SHOW_STEP_RESULTS_STR, step_nr
                            )
                        )
                        step_nr += 1
                if sub_step.is_web_fetch and len(sub_step_saved_search_docs) > 0:
                    is_web_fetch = True
                    fetches.append(sub_step_saved_search_docs)

            if len(set(tool_call_ids)) > 1:
                # Mixed tool calls within one iteration cannot be rendered.
                if chat_message.research_type == ResearchType.DEEP:
                    packet_list.extend(
                        create_reasoning_packets(
                            _CANNOT_SHOW_STEP_RESULTS_STR, step_nr
                        )
                    )
                    step_nr += 1
            elif len(sub_steps) == 0:
                # no sub steps, no tool calls. But iteration can have reasoning or purpose
                continue
            else:
                tool_id = tool_call_ids[0]
                if not tool_id:
                    raise ValueError("Tool ID is required")
                tool = get_tool_by_id(tool_id, db_session)
                tool_name = tool.name
                if tool_name in [SearchTool.__name__, KnowledgeGraphTool.__name__]:
                    cited_docs = cast(list[SavedSearchDoc], cited_docs)
                    packet_list.extend(
                        create_search_packets(tasks, cited_docs, False, step_nr)
                    )
                    step_nr += 1
                elif tool_name == WebSearchTool.__name__:
                    cited_docs = cast(list[SavedSearchDoc], cited_docs)
                    # Web fetch sub-steps render as fetch packets; otherwise
                    # render as an internet search.
                    if is_web_fetch:
                        packet_list.extend(create_fetch_packets(fetches, step_nr))
                    else:
                        packet_list.extend(
                            create_search_packets(tasks, cited_docs, True, step_nr)
                        )
                    step_nr += 1
                elif tool_name == ImageGenerationTool.__name__:
                    # NOTE(review): relies on `sub_step` still bound to the last
                    # element of the loop above — confirm iterations with this
                    # tool always have exactly one sub step.
                    if sub_step.generated_images is None:
                        raise ValueError("No generated images found")
                    packet_list.extend(
                        create_image_generation_packets(
                            sub_step.generated_images.images, step_nr
                        )
                    )
                    step_nr += 1
                elif tool_name == OktaProfileTool.__name__:
                    packet_list.extend(
                        create_custom_tool_packets(
                            tool_name=tool_name,
                            response_type="text",
                            step_nr=step_nr,
                            data=sub_step.sub_answer,
                        )
                    )
                    step_nr += 1
                else:
                    # Any other tool is rendered as a generic custom-tool step.
                    packet_list.extend(
                        create_custom_tool_packets(
                            tool_name=tool_name,
                            response_type="text",
                            step_nr=step_nr,
                            data=sub_step.sub_answer,
                        )
                    )
                    step_nr += 1

        # Final assistant message text, if any, becomes message packets.
        if chat_message.message:
            packet_list.extend(
                create_message_packets(
                    message_text=chat_message.message,
                    final_documents=[
                        translate_db_search_doc_to_server_search_doc(doc)
                        for doc in chat_message.search_docs
                    ],
                    step_nr=step_nr,
                    is_legacy_agentic=chat_message.research_type
                    == ResearchType.LEGACY_AGENTIC,
                )
            )
            step_nr += 1

        # Citations without research iterations (e.g. plain search flow):
        # emit a single search step carrying the cited documents.
        if len(citation_info_list) > 0 and len(research_iterations) == 0:
            saved_search_docs: list[SavedSearchDoc] = []
            for citation_info in citation_info_list:
                cited_doc = get_db_search_doc_by_document_id(
                    citation_info.document_id, db_session
                )
                if cited_doc:
                    saved_search_docs.append(
                        translate_db_search_doc_to_server_search_doc(cited_doc)
                    )
            packet_list.extend(
                create_search_packets([], saved_search_docs, False, step_nr)
            )
            step_nr += 1

    return EndStepPacketList(packet_list=packet_list, end_step_nr=step_nr)
def translate_db_message_to_packets(
chat_message: ChatMessage,
db_session: Session,
remove_doc_content: bool = False,
start_step_nr: int = 1,
) -> EndStepPacketList:
use_simple_translation = False
if chat_message.research_type and chat_message.research_type != ResearchType.DEEP:
feature_flag_provider = get_default_feature_flag_provider()
tenant_id = get_current_tenant_id()
user = chat_message.chat_session.user
use_simple_translation = feature_flag_provider.feature_enabled_for_user_tenant(
flag_key=SIMPLE_AGENT_FRAMEWORK,
user=user,
tenant_id=tenant_id,
)
if use_simple_translation:
return translate_db_message_to_packets_simple(
chat_message=chat_message,
db_session=db_session,
remove_doc_content=remove_doc_content,
start_step_nr=start_step_nr,
)
step_nr = start_step_nr
packet_list: list[Packet] = []
@@ -305,7 +561,11 @@ def translate_db_message_to_packets(
cited_docs: list[SavedSearchDoc] = []
for sub_step in sub_steps:
tasks.append(sub_step.sub_step_instructions or "")
# For v2 tools, use the queries field if available, otherwise fall back to sub_step_instructions
if sub_step.queries:
tasks.extend(sub_step.queries)
else:
tasks.append(sub_step.sub_step_instructions or "")
tool_call_ids.append(sub_step.sub_step_tool_id)
sub_step_cited_docs = sub_step.cited_doc_results

View File

@@ -45,6 +45,8 @@ class Settings(BaseModel):
# is float to allow for fractional days for easier automated testing
maximum_chat_retention_days: float | None = None
company_name: str | None = None
company_description: str | None = None
gpu_enabled: bool | None = None
application_status: ApplicationStatus = ApplicationStatus.ACTIVE
anonymous_user_enabled: bool | None = None

View File

@@ -0,0 +1,123 @@
# create adapter from Tool to FunctionTool
import json
from typing import Any
from typing import Union
from agents import FunctionTool
from agents import RunContextWrapper
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.chat.turn.models import ChatTurnContext
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.built_in_tools_v2 import BUILT_IN_TOOL_MAP_V2
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
from onyx.tools.tool_implementations.mcp.mcp_tool import MCPTool
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
# Type alias for tools that need custom handling
CustomOrMcpTool = Union[CustomTool, MCPTool]
def is_custom_or_mcp_tool(tool: Tool) -> bool:
    """Return True when the tool is a CustomTool or an MCPTool."""
    return isinstance(tool, (CustomTool, MCPTool))
@tool_accounting
async def _tool_run_wrapper(
    run_context: RunContextWrapper[ChatTurnContext], tool: Tool, json_string: str
) -> list[Any]:
    """
    Wrapper function to adapt Tool.run() to FunctionTool.on_invoke_tool() signature.

    Emits a CustomToolStart packet, records an IterationInstructions entry,
    then for every ToolResponse yielded by tool.run(): records an
    IterationAnswer and emits a CustomToolDelta. Returns the raw responses.
    """
    # Tool arguments arrive as a JSON string from the agents framework.
    args = json.loads(json_string) if json_string else {}
    index = run_context.context.current_run_step
    run_context.context.run_dependencies.emitter.emit(
        Packet(
            ind=index,
            obj=CustomToolStart(type="custom_tool_start", tool_name=tool.name),
        )
    )
    results = []
    run_context.context.iteration_instructions.append(
        IterationInstructions(
            iteration_nr=index,
            plan=f"Running {tool.name}",
            purpose=f"Running {tool.name}",
            reasoning=f"Running {tool.name}",
        )
    )
    for result in tool.run(**args):
        results.append(result)
        # Extract data from CustomToolCallSummary within the ToolResponse
        custom_summary = result.response
        data = None
        file_ids = None

        # Handle different response types: file-backed results (image/csv)
        # carry file ids; everything else is passed through as data.
        if custom_summary.response_type in ["image", "csv"] and hasattr(
            custom_summary.tool_result, "file_ids"
        ):
            file_ids = custom_summary.tool_result.file_ids
        else:
            data = custom_summary.tool_result

        run_context.context.aggregated_context.global_iteration_responses.append(
            IterationAnswer(
                tool=tool.name,
                tool_id=tool.id,
                iteration_nr=index,
                parallelization_nr=0,
                question=json.dumps(args) if args else "",
                reasoning=f"Running {tool.name}",
                data=data,
                file_ids=file_ids,
                cited_documents={},
                additional_data=None,
                response_type=custom_summary.response_type,
                answer=str(data) if data else str(file_ids),
            )
        )
        run_context.context.run_dependencies.emitter.emit(
            Packet(
                ind=index,
                obj=CustomToolDelta(
                    type="custom_tool_delta",
                    tool_name=tool.name,
                    response_type=custom_summary.response_type,
                    data=data,
                    file_ids=file_ids,
                ),
            )
        )
    return results
def tool_to_function_tool(tool: Tool) -> FunctionTool:
    """Adapt a single Onyx Tool into an agents-framework FunctionTool.

    The returned FunctionTool reuses the tool's name, description, and JSON
    parameter schema, and routes invocations through _tool_run_wrapper.
    """

    async def _on_invoke(
        context: RunContextWrapper[ChatTurnContext], json_string: str
    ) -> list[Any]:
        # Bind `tool` from the enclosing scope; the framework supplies the
        # run context and the serialized arguments.
        return await _tool_run_wrapper(context, tool, json_string)

    schema = tool.tool_definition()["function"]["parameters"]
    return FunctionTool(
        name=tool.name,
        description=tool.description,
        params_json_schema=schema,
        on_invoke_tool=_on_invoke,
    )
def tools_to_function_tools(tools: list[Tool]) -> list[FunctionTool]:
    """Convert Onyx tools to FunctionTools.

    Built-in tools are replaced by their registered v2 implementations (a
    single v1 tool may map to several FunctionTools); Custom and MCP tools
    are wrapped generically. Built-ins come first, preserving input order.
    """
    converted: list[FunctionTool] = []
    for tool in tools:
        mapped = BUILT_IN_TOOL_MAP_V2.get(type(tool).__name__)
        if mapped is not None:
            converted.extend(mapped)
    for tool in tools:
        if is_custom_or_mcp_tool(tool):
            converted.append(tool_to_function_tool(tool))
    return converted

View File

@@ -0,0 +1,24 @@
from agents import FunctionTool
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
OktaProfileTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.web_search_tool import (
WebSearchTool,
)
from onyx.tools.tool_implementations_v2.image_generation import image_generation_tool
from onyx.tools.tool_implementations_v2.internal_search import internal_search_tool
from onyx.tools.tool_implementations_v2.okta_profile import okta_profile_tool
from onyx.tools.tool_implementations_v2.web import web_fetch_tool
from onyx.tools.tool_implementations_v2.web import web_search_tool
# Maps a v1 Tool class name to the v2 FunctionTool implementations that
# replace it in the simple agent framework. A single v1 tool may expand into
# several FunctionTools (WebSearchTool becomes search + fetch).
BUILT_IN_TOOL_MAP_V2: dict[str, list[FunctionTool]] = {
    SearchTool.__name__: [internal_search_tool],
    ImageGenerationTool.__name__: [image_generation_tool],
    WebSearchTool.__name__: [web_search_tool, web_fetch_tool],
    OktaProfileTool.__name__: [okta_profile_tool],
}

View File

@@ -26,7 +26,7 @@ from pydantic import BaseModel
from onyx.db.enums import MCPTransport
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_async_sync
from onyx.utils.threadpool_concurrency import run_async_sync_no_cancel
logger = setup_logger()
@@ -203,7 +203,7 @@ def _call_mcp_client_function_sync(
function, server_url, connection_headers, transport, auth, **kwargs
)
try:
return run_async_sync(run_client_function())
return run_async_sync_no_cancel(run_client_function())
except Exception as e:
logger.error(f"Failed to call MCP client function: {e}")
if isinstance(e, ExceptionGroup):

View File

@@ -0,0 +1,162 @@
from typing import cast
from agents import function_tool
from agents import RunContextWrapper
from onyx.agents.agent_search.dr.models import GeneratedImage
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.chat.turn.models import ChatTurnContext
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import save_files
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
from onyx.utils.logger import setup_logger
logger = setup_logger()
@tool_accounting
def _image_generation_core(
    run_context: RunContextWrapper[ChatTurnContext],
    prompt: str,
    shape: str,
    image_generation_tool_instance: ImageGenerationTool,
) -> list[GeneratedImage]:
    """Run the image-generation tool, streaming heartbeats and the result.

    Raises:
        RuntimeError: if the tool finished without producing any images.
    """
    index = run_context.context.current_run_step
    emitter = run_context.context.run_dependencies.emitter

    # Emit start event
    emitter.emit(
        Packet(
            ind=index,
            obj=ImageGenerationToolStart(type="image_generation_tool_start"),
        )
    )

    # Prepare tool arguments
    tool_args = {"prompt": prompt}
    if shape != "square":  # Only include shape if it's not the default
        tool_args["shape"] = shape

    # Run the actual image generation tool with heartbeat handling
    generated_images: list[GeneratedImage] = []
    heartbeat_count = 0

    for tool_response in image_generation_tool_instance.run(
        **tool_args  # type: ignore[arg-type]
    ):
        # Handle heartbeat responses (keep-alive while generation is pending)
        if tool_response.id == "image_generation_heartbeat":
            # Emit heartbeat event for every iteration
            emitter.emit(
                Packet(
                    ind=index,
                    obj=ImageGenerationToolHeartbeat(
                        type="image_generation_tool_heartbeat"
                    ),
                )
            )
            heartbeat_count += 1
            logger.debug(f"Image generation heartbeat #{heartbeat_count}")
            continue

        # Process the tool response to get the generated images
        if tool_response.id == "image_generation_response":
            image_generation_responses = cast(
                list[ImageGenerationResponse], tool_response.response
            )
            file_ids = save_files(
                urls=[img.url for img in image_generation_responses if img.url],
                base64_files=[
                    img.image_data
                    for img in image_generation_responses
                    if img.image_data
                ],
            )
            # NOTE(review): the zip assumes save_files returns file ids aligned
            # 1:1 and in order with image_generation_responses — confirm this
            # holds when responses mix url- and base64-backed images.
            generated_images = [
                GeneratedImage(
                    file_id=file_id,
                    url=img.url if img.url else build_frontend_file_url(file_id),
                    revised_prompt=img.revised_prompt,
                )
                for img, file_id in zip(image_generation_responses, file_ids)
            ]
            break

    if not generated_images:
        raise RuntimeError("No images were generated")

    run_context.context.iteration_instructions.append(
        IterationInstructions(
            iteration_nr=index,
            plan="Generating images",
            purpose="Generating images",
            reasoning="Generating images",
        )
    )
    run_context.context.aggregated_context.global_iteration_responses.append(
        IterationAnswer(
            tool=image_generation_tool_instance.name,
            tool_id=image_generation_tool_instance.id,
            iteration_nr=run_context.context.current_run_step,
            parallelization_nr=0,
            question=prompt,
            answer="",
            reasoning="",
            claims=[],
            generated_images=generated_images,
            additional_data={},
            response_type=None,
            data=None,
            file_ids=None,
            cited_documents={},
        )
    )

    # Emit final result
    emitter.emit(
        Packet(
            ind=index,
            obj=ImageGenerationToolDelta(
                type="image_generation_tool_delta", images=generated_images
            ),
        )
    )

    return generated_images
# failure_error_function=None causes error to be re-raised instead of passing error
# message back to the LLM. This is needed for image_generation since we configure our agent
# to stop at this tool.
@function_tool(failure_error_function=None)
def image_generation_tool(
    run_context: RunContextWrapper[ChatTurnContext], prompt: str, shape: str = "square"
) -> str:
    """
    Generate an image from a text prompt using AI image generation models.

    Args:
        prompt: The text description of the image to generate
        shape: The desired image shape - 'square', 'portrait', or 'landscape'
    """
    image_generation_tool_instance = (
        run_context.context.run_dependencies.image_generation_tool
    )
    # `assert` is stripped under `python -O`, so raise explicitly instead —
    # this mirrors internal_search_tool's handling of a missing dependency.
    if image_generation_tool_instance is None:
        raise RuntimeError("Image generation tool not available in context")

    generated_images: list[GeneratedImage] = _image_generation_core(
        run_context, prompt, shape, image_generation_tool_instance
    )
    # We should stop after this tool is called, so it doesn't matter what it returns
    return f"Successfully generated {len(generated_images)} images"

View File

@@ -0,0 +1,197 @@
from typing import cast
from agents import function_tool
from agents import RunContextWrapper
from onyx.agents.agent_search.dr.models import InferenceSection
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
from onyx.chat.models import LlmDoc
from onyx.chat.stop_signal_checker import is_connected
from onyx.chat.turn.models import ChatTurnContext
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.tools import get_tool_by_name
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.search.search_utils import section_to_llm_doc
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
@tool_accounting
def _internal_search_core(
    run_context: RunContextWrapper[ChatTurnContext],
    queries: list[str],
    search_tool: SearchTool,
) -> list[LlmDoc]:
    """Core internal search logic that can be tested with dependency injection.

    Emits start/delta packets, runs all queries in parallel against the
    search tool, streams the retrieved documents, records per-query
    IterationAnswers, and returns the aggregated LlmDocs.
    """
    index = run_context.context.current_run_step
    # Announce the search section, then stream the queries being run.
    run_context.context.run_dependencies.emitter.emit(
        Packet(
            ind=index,
            obj=SearchToolStart(
                type="internal_search_tool_start", is_internet_search=False
            ),
        )
    )
    run_context.context.run_dependencies.emitter.emit(
        Packet(
            ind=index,
            obj=SearchToolDelta(
                type="internal_search_tool_delta", queries=queries, documents=[]
            ),
        )
    )
    run_context.context.iteration_instructions.append(
        IterationInstructions(
            iteration_nr=index,
            plan="plan",
            purpose="Searching internally for information",
            reasoning=f"I am now using Internal Search to gather information on {queries}",
        )
    )

    def execute_single_query(query: str, parallelization_nr: int) -> list[LlmDoc]:
        """Execute a single query and return the retrieved documents as LlmDocs"""
        retrieved_llm_docs_for_query: list[LlmDoc] = []
        # Each parallel query gets its own db session; sessions are not
        # shared across threads.
        with get_session_with_current_tenant() as search_db_session:
            for tool_response in search_tool.run(
                query=query,
                override_kwargs=SearchToolOverrideKwargs(
                    force_no_rerank=True,
                    alternate_db_session=search_db_session,
                    skip_query_analysis=True,
                    original_query=query,
                ),
            ):
                # Stop early if the client has disconnected.
                if not is_connected(
                    run_context.context.chat_session_id,
                    run_context.context.run_dependencies.redis_client,
                ):
                    break

                # get retrieved docs to send to the rest of the graph
                if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
                    response = cast(SearchResponseSummary, tool_response.response)
                    # TODO: just a heuristic to not overload context window -- carried over from existing DR flow
                    docs_to_feed_llm = 15
                    retrieved_sections: list[InferenceSection] = response.top_sections[
                        :docs_to_feed_llm
                    ]
                    # Convert InferenceSections to LlmDocs for return value
                    retrieved_llm_docs_for_query = [
                        section_to_llm_doc(section) for section in retrieved_sections
                    ]
                    # Stream the retrieved documents to the client.
                    run_context.context.run_dependencies.emitter.emit(
                        Packet(
                            ind=index,
                            obj=SearchToolDelta(
                                type="internal_search_tool_delta",
                                queries=[],
                                documents=convert_inference_sections_to_search_docs(
                                    retrieved_sections, is_internet=False
                                ),
                            ),
                        )
                    )
                    run_context.context.aggregated_context.cited_documents.extend(
                        retrieved_sections
                    )
                    run_context.context.aggregated_context.global_iteration_responses.append(
                        IterationAnswer(
                            tool=SearchTool.__name__,
                            tool_id=get_tool_by_name(
                                SearchTool.__name__,
                                run_context.context.run_dependencies.db_session,
                            ).id,
                            iteration_nr=index,
                            parallelization_nr=parallelization_nr,
                            question=query,
                            reasoning=f"I am now using Internal Search to gather information on {query}",
                            answer="",
                            cited_documents={
                                i: inference_section
                                for i, inference_section in enumerate(
                                    retrieved_sections
                                )
                            },
                            queries=[query],
                        )
                    )
                    # Only the summary response is needed; stop consuming.
                    break
        return retrieved_llm_docs_for_query

    # Execute all queries in parallel using run_functions_in_parallel
    function_calls = [
        FunctionCall(func=execute_single_query, args=(query, i))
        for i, query in enumerate(queries)
    ]
    search_results_dict = run_functions_in_parallel(function_calls)

    # Aggregate all results from all queries
    all_retrieved_docs: list[LlmDoc] = []
    for result_id in search_results_dict:
        retrieved_docs = search_results_dict[result_id]
        if retrieved_docs:
            all_retrieved_docs.extend(retrieved_docs)

    return all_retrieved_docs
# NOTE: the docstring below doubles as the tool description shown to the LLM
# by the agents framework — treat it as runtime behavior, not documentation.
@function_tool
def internal_search_tool(
    run_context: RunContextWrapper[ChatTurnContext], queries: list[str]
) -> str:
    """
    Tool for searching over internal knowledge base from the user's connectors.
    The queries will be searched over a vector database where a hybrid search will be performed.
    Will return a combination of keyword and semantic search results.

    ---

    ## Decision boundary
    - MUST call internal_search_tool if the user's query requires internal information, like
      if it references "we" or "us" or "our" or "internal" or if it references
      the organization the user works for.

    ## Usage hints
    - Batch a list of natural-language queries per call.
    - Generally try searching with some semantic queries and some keyword queries
      to give the hybrid search the best chance of finding relevant results.

    ## Args
    - queries (list[str]): The search queries.

    ## Returns (list of LlmDoc objects as string)
    Each LlmDoc contains:
    - document_id: Unique document identifier
    - content: Full document content (combined from all chunks in the section)
    - blurb: Text excerpt from the document
    - semantic_identifier: Human-readable document name
    - source_type: Type of document source (e.g., web, confluence, etc.)
    - metadata: Additional document metadata
    - updated_at: When document was last updated
    - link: Primary URL to the source (may be None). Used for citations.
    - source_links: Dictionary of URLs to the source
    - match_highlights: Highlighted matching text snippets
    """
    search_pipeline_instance = run_context.context.run_dependencies.search_pipeline
    if search_pipeline_instance is None:
        raise RuntimeError("Search tool not available in context")

    # Call the core function
    retrieved_docs = _internal_search_core(
        run_context, queries, search_pipeline_instance
    )
    # NOTE(review): returns the repr of the LlmDoc list via str() — presumably
    # consumed free-form by the LLM; confirm this is the intended serialization.
    return str(retrieved_docs)

View File

@@ -0,0 +1,90 @@
import json
from typing import Any
from agents import function_tool
from agents import RunContextWrapper
from onyx.chat.turn.models import ChatTurnContext
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
from onyx.server.query_and_chat.streaming_models import CustomToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
from onyx.utils.logger import setup_logger
logger = setup_logger()
@tool_accounting
def _okta_profile_core(
    run_context: RunContextWrapper[ChatTurnContext],
    okta_profile_tool_instance: Any,
) -> dict[str, Any]:
    """Core Okta profile logic that can be tested with dependency injection.

    Emits start/progress/result packets around the tool run.

    Raises:
        RuntimeError: if the tool is missing or returned no profile data.
    """
    if okta_profile_tool_instance is None:
        raise RuntimeError("Okta profile tool not available in context")

    index = run_context.context.current_run_step
    emitter = run_context.context.run_dependencies.emitter  # type: ignore[union-attr]

    # Emit start event
    emitter.emit(  # type: ignore[union-attr]
        Packet(
            ind=index,
            obj=CustomToolStart(type="custom_tool_start", tool_name="Okta Profile"),
        )
    )

    # Emit delta event for fetching profile
    emitter.emit(  # type: ignore[union-attr]
        Packet(
            ind=index,
            obj=CustomToolDelta(
                type="custom_tool_delta",
                tool_name="Okta Profile",
                response_type="text",
                data="Fetching profile information...",
            ),
        )
    )

    # Run the actual Okta profile tool; only the first "okta_profile"
    # response is used.
    profile_data = None
    for tool_response in okta_profile_tool_instance.run():
        if tool_response.id == "okta_profile":
            profile_data = tool_response.response
            break

    if profile_data is None:
        raise RuntimeError("No profile data was retrieved from Okta")

    # Emit final result
    emitter.emit(  # type: ignore[union-attr]
        Packet(
            ind=index,
            obj=CustomToolDelta(
                type="custom_tool_delta",
                tool_name="Okta Profile",
                response_type="json",
                data=profile_data,
            ),
        )
    )

    return profile_data
# NOTE: the docstring below doubles as the tool description shown to the LLM
# by the agents framework — treat it as runtime behavior, not documentation.
@function_tool
def okta_profile_tool(run_context: RunContextWrapper[ChatTurnContext]) -> str:
    """
    Retrieve the current user's profile information from Okta.

    This tool fetches user profile details including name, email, department,
    location, title, manager, and other profile information from the Okta identity provider.
    """
    # Get the Okta profile tool from context
    okta_profile_tool_instance = run_context.context.run_dependencies.okta_profile_tool

    # Call the core function (raises if the tool is missing or returns nothing)
    profile_data = _okta_profile_core(run_context, okta_profile_tool_instance)

    # Serialize the profile dict for the LLM.
    return json.dumps(profile_data)

View File

@@ -0,0 +1,79 @@
import functools
import inspect
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import TypeVar
from agents import RunContextWrapper
from onyx.chat.turn.models import ChatTurnContext
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SectionEnd
F = TypeVar("F", bound=Callable)
def tool_accounting(func: F) -> F:
    """
    Decorator that adds tool accounting functionality to tool functions.
    Handles both sync and async functions automatically.

    This decorator:
    1. Increments the current_run_step index at the beginning
    2. Emits a section end packet and increments current_run_step at the end
    3. Ensures the cleanup happens even if an exception occurs

    Args:
        func: The function to decorate. Must take a RunContextWrapper[ChatTurnContext] as first argument.

    Returns:
        The decorated function with tool accounting functionality.
    """
    # NOTE(review): the returned wrapper is a plain (sync) function even when
    # `func` is async — callers relying on inspect.iscoroutinefunction of the
    # decorated object would misclassify it; confirm the framework awaits the
    # returned coroutine instead.

    @functools.wraps(func)
    def wrapper(
        run_context: RunContextWrapper[ChatTurnContext], *args: Any, **kwargs: Any
    ) -> Any:
        # Increment current_run_step at the beginning
        run_context.context.current_run_step += 1
        try:
            # Call the original function
            result = func(run_context, *args, **kwargs)

            # If it's a coroutine, we need to handle it differently
            if inspect.iscoroutine(result):
                # For async functions, we need to return a coroutine that
                # handles the cleanup when the awaited body finishes/raises.
                async def async_wrapper() -> Any:
                    try:
                        return await result
                    finally:
                        _emit_section_end(run_context)

                return async_wrapper()
            else:
                # For sync functions, emit cleanup immediately
                _emit_section_end(run_context)
                return result
        except Exception:
            # Always emit cleanup even if an exception occurred
            # (sync path only; the async path cleans up in its own finally).
            _emit_section_end(run_context)
            raise

    return cast(F, wrapper)
def _emit_section_end(run_context: RunContextWrapper[ChatTurnContext]) -> None:
    """Emit a SectionEnd packet at the current step, then advance the step."""
    ctx = run_context.context
    section_end_packet = Packet(
        ind=ctx.current_run_step,
        obj=SectionEnd(type="section_end"),
    )
    ctx.run_dependencies.emitter.emit(section_end_packet)
    ctx.current_run_step += 1

View File

@@ -0,0 +1,312 @@
from typing import List
from typing import Optional
from agents import function_tool
from agents import RunContextWrapper
from pydantic import BaseModel
from onyx.agents.agent_search.dr.models import IterationAnswer
from onyx.agents.agent_search.dr.models import IterationInstructions
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
get_default_provider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.providers import (
WebSearchProvider,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
dummy_inference_section_from_internet_content,
)
from onyx.agents.agent_search.dr.sub_agents.web_search.utils import (
dummy_inference_section_from_internet_search_result,
)
from onyx.chat.turn.models import ChatTurnContext
from onyx.db.tools import get_tool_by_name
from onyx.server.query_and_chat.streaming_models import FetchToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SavedSearchDoc
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
from onyx.server.query_and_chat.streaming_models import SearchToolStart
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from onyx.tools.tool_implementations_v2.tool_accounting import tool_accounting
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
class WebSearchResult(BaseModel):
    """A single web search hit, tagged for citation by the LLM."""

    # Short citation tag; see short_tag() — currently the 1-based result index.
    tag: str
    title: str
    link: str
    snippet: str
    author: Optional[str] = None
    # ISO-8601 string (built via isoformat()) when the provider supplies one.
    published_date: Optional[str] = None
class WebSearchResponse(BaseModel):
    """Aggregated hits across all queries of a single web-search call."""

    results: List[WebSearchResult]
class WebFetchResult(BaseModel):
    """The fully fetched content of a single page."""

    tag: str
    title: str
    link: str
    full_content: str
    published_date: Optional[str] = None
class WebFetchResponse(BaseModel):
    """Aggregated fetch results for a single web-fetch call."""

    results: List[WebFetchResult]
def short_tag(link: str, i: int) -> str:
    """Return a short citation tag for the i-th result (1-based index).

    NOTE(review): `link` is currently unused — the tag is purely positional.
    """
    return str(i + 1)
@tool_accounting
def _web_search_core(
    run_context: RunContextWrapper[ChatTurnContext],
    queries: list[str],
    search_provider: WebSearchProvider,
) -> WebSearchResponse:
    """Run all *queries* against *search_provider* in parallel, recording the step.

    Side effects, in order:
      1. Emits a SearchToolStart then a SearchToolDelta packet so the UI can
         render an in-progress internet search for this run step.
      2. Appends an IterationInstructions entry for this step.
      3. Extends aggregated_context.cited_documents with dummy inference
         sections built from the raw hits, and appends an IterationAnswer to
         aggregated_context.global_iteration_responses.

    Returns:
        WebSearchResponse with one WebSearchResult per hit (metadata and
        snippet only - full content comes from _web_fetch_core).
    """
    # NOTE(review): FunctionCall is imported lazily although
    # run_functions_in_parallel from the same module is imported at top
    # level - presumably deliberate (import cycle?); confirm before hoisting.
    from onyx.utils.threadpool_concurrency import FunctionCall
    # The run step doubles as the packet/iteration index for this tool call.
    index = run_context.context.current_run_step
    run_context.context.run_dependencies.emitter.emit(
        Packet(
            ind=index,
            obj=SearchToolStart(
                type="internal_search_tool_start", is_internet_search=True
            ),
        )
    )
    # Surface the query list to the UI immediately; documents stream later.
    run_context.context.run_dependencies.emitter.emit(
        Packet(
            ind=index,
            obj=SearchToolDelta(
                type="internal_search_tool_delta", queries=queries, documents=[]
            ),
        )
    )
    queries_str = ", ".join(queries)
    run_context.context.iteration_instructions.append(
        IterationInstructions(
            iteration_nr=index,
            plan="plan",
            purpose="Searching the web for information",
            reasoning=f"I am now using Web Search to gather information on {queries_str}",
        )
    )
    # Search all queries in parallel (one provider call per query).
    function_calls = [
        FunctionCall(func=search_provider.search, args=(query,)) for query in queries
    ]
    search_results_dict = run_functions_in_parallel(function_calls)
    # Aggregate all results from all queries into one flat list; tags are
    # assigned over this combined ordering, not per query.
    all_hits = []
    for result_id in search_results_dict:
        hits = search_results_dict[result_id]
        if hits:
            all_hits.extend(hits)
    # Convert hits to WebSearchResult objects (snippet-only payload for the LLM).
    results = []
    for i, r in enumerate(all_hits):
        results.append(
            WebSearchResult(
                tag=short_tag(r.link, i),
                title=r.title,
                link=r.link,
                snippet=r.snippet or "",
                author=r.author,
                published_date=(
                    r.published_date.isoformat() if r.published_date else None
                ),
            )
        )
    # Create inference sections from search results and add to cited documents
    # so later citation resolution can reference these hits.
    inference_sections = [
        dummy_inference_section_from_internet_search_result(r) for r in all_hits
    ]
    run_context.context.aggregated_context.cited_documents.extend(inference_sections)
    run_context.context.aggregated_context.global_iteration_responses.append(
        IterationAnswer(
            tool=WebSearchTool.__name__,
            tool_id=get_tool_by_name(
                WebSearchTool.__name__, run_context.context.run_dependencies.db_session
            ).id,
            iteration_nr=index,
            parallelization_nr=0,
            question=queries_str,
            reasoning=f"I am now using Web Search to gather information on {queries_str}",
            answer="",
            cited_documents={
                i: inference_section
                for i, inference_section in enumerate(inference_sections)
            },
            claims=[],
            queries=queries,
        )
    )
    return WebSearchResponse(results=results)
@function_tool
def web_search_tool(
    run_context: RunContextWrapper[ChatTurnContext], queries: list[str]
) -> str:
    """
    Tool for searching the public internet. Useful for up to date information on PUBLIC knowledge.
    ---
    ## Decision boundary
    - You MUST call `web_search_tool` to discover sources when the request involves:
      - Fresh/unstable info (news, prices, laws, schedules, product specs, scores, exchange rates).
      - Recommendations, or any query where the specific sources matter.
      - Verifiable claims, quotes, or citations.
    - After ANY successful `web_search_tool` call that yields candidate URLs, you MUST call
      `web_fetch_tool` on the selected URLs BEFORE answering. Do NOT answer from snippets.
    ## When NOT to use
    - Casual chat, rewriting/summarizing user-provided text, or translation.
    - When the user already provided URLs (go straight to `web_fetch_tool`).
    ## Usage hints
    - Batch a list of natural-language queries per call.
    - Prefer searches for distinct intents; then batch-fetch best URLs.
    - Deduplicate domains/near-duplicates. Prefer recent, authoritative sources.
    ## Args
    - queries (list[str]): The search queries.
    ## Returns (JSON string)
    {
      "results": [
        {
          "tag": "short_ref",
          "title": "...",
          "link": "https://...",
          "author": "...",
          "published_date": "2025-10-01T12:34:56Z"
          // intentionally NO full content
        }
      ]
    }
    """
    # NOTE: the docstring above is runtime behavior, not documentation - the
    # `agents` function_tool decorator ships it to the LLM as the tool
    # description. Edit it as prompt text, not as a docstring.
    # (Minor drift: the Returns example omits "snippet", which
    # WebSearchResult does include - consider syncing.)
    search_provider = get_default_provider()
    # Fail loudly when no web-search provider is configured for this deployment.
    if search_provider is None:
        raise ValueError("No search provider found")
    response = _web_search_core(run_context, queries, search_provider)
    # Tool outputs must be strings; serialize the pydantic model to JSON.
    return response.model_dump_json()
@tool_accounting
def _web_fetch_core(
    run_context: RunContextWrapper[ChatTurnContext],
    urls: List[str],
    search_provider: WebSearchProvider,
) -> WebFetchResponse:
    """Fetch full content for *urls* via *search_provider*, recording the step.

    Side effects, in order:
      1. Emits a FetchToolStart packet listing the URLs being fetched.
      2. Appends an IterationInstructions entry for this step.
      3. Appends an IterationAnswer (flagged is_web_fetch=True) whose cited
         documents are dummy inference sections built from the fetched pages.

    Returns:
        WebFetchResponse with one WebFetchResult (including full_content)
        per document returned by the provider.
    """
    # TODO: Find better way to track index that isn't so implicit
    # based on number of tool calls
    index = run_context.context.current_run_step
    # Create SavedSearchDoc objects from URLs for the FetchToolStart event
    saved_search_docs = [SavedSearchDoc.from_url(url) for url in urls]
    run_context.context.run_dependencies.emitter.emit(
        Packet(
            ind=index,
            obj=FetchToolStart(type="fetch_tool_start", documents=saved_search_docs),
        )
    )
    # Single batched provider call for all URLs.
    docs = search_provider.contents(urls)
    out = []
    for i, d in enumerate(docs):
        out.append(
            WebFetchResult(
                tag=short_tag(d.link, i),  # tag by position, consistent with search results
                title=d.title,
                link=d.link,
                full_content=d.full_content,
                published_date=(
                    d.published_date.isoformat() if d.published_date else None
                ),
            )
        )
    run_context.context.iteration_instructions.append(
        IterationInstructions(
            iteration_nr=index,
            plan="plan",
            purpose="Fetching content from URLs",
            reasoning=f"I am now using Web Fetch to gather information on {', '.join(urls)}",
        )
    )
    run_context.context.aggregated_context.global_iteration_responses.append(
        IterationAnswer(
            # TODO: For now, we're using the web_search_tool_name since the web_fetch_tool_name is not a built-in tool
            tool=WebSearchTool.__name__,
            tool_id=get_tool_by_name(
                WebSearchTool.__name__, run_context.context.run_dependencies.db_session
            ).id,
            iteration_nr=index,
            parallelization_nr=0,
            question=f"Fetch content from URLs: {', '.join(urls)}",
            reasoning=f"I am now using Web Fetch to gather information on {', '.join(urls)}",
            answer="",
            cited_documents={
                i: dummy_inference_section_from_internet_content(d)
                for i, d in enumerate(docs)
            },
            claims=[],
            is_web_fetch=True,
        )
    )
    return WebFetchResponse(results=out)
@function_tool
def web_fetch_tool(
    run_context: RunContextWrapper[ChatTurnContext], urls: List[str]
) -> str:
    """
    Tool for fetching and extracting full content from web pages.
    ---
    ## Decision boundary
    - You MUST use `web_fetch_tool` before quoting, citing, or relying on page content.
    - Use it whenever you already have URLs (from the user or from `web_search_tool`).
    - Do NOT answer questions based on search snippets alone.
    ## When NOT to use
    - If you do not yet have URLs (search first).
    ## Usage hints
    - Avoid many tiny calls; batch URLs (120) in one request.
    - Prefer primary, recent, and reputable sources.
    - If PDFs/long docs appear, still fetch; you may summarize sections explicitly.
    ## Args
    - urls (List[str]): Absolute URLs to retrieve.
    ## Returns (JSON string)
    {
      "results": [
        {
          "tag": "short_ref",
          "title": "...",
          "link": "https://...",
          "full_content": "...",
          "published_date": "2025-10-01T12:34:56Z"
        }
      ]
    }
    """
    # NOTE: the docstring above is runtime behavior - function_tool sends it
    # to the LLM as the tool description; edit it as prompt text.
    # NOTE(review): "batch URLs (120)" looks like a mangled "(1-20)" (lost
    # dash/en-dash) - confirm the intended batch-size hint and fix there.
    search_provider = get_default_provider()
    # Fail loudly when no web-search provider is configured for this deployment.
    if search_provider is None:
        raise ValueError("No search provider found")
    response = _web_fetch_core(run_context, urls, search_provider)
    # Tool outputs must be strings; serialize the pydantic model to JSON.
    return response.model_dump_json()

View File

@@ -283,7 +283,7 @@ def run_functions_in_parallel(
return results
def run_async_sync(coro: Awaitable[T]) -> T:
def run_async_sync_no_cancel(coro: Awaitable[T]) -> T:
"""
async-to-sync converter. Basically just executes asyncio.run in a separate thread.
Which is probably somehow inefficient or not ideal but fine for now.

View File

@@ -49,7 +49,7 @@ msal==1.28.0
nltk==3.9.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==1.99.5
openai==1.107.1
openpyxl==3.1.5
passlib==1.7.4
playwright==1.55.0
@@ -105,5 +105,6 @@ sendgrid==6.11.0
voyageai==0.2.3
cohere==5.6.1
exa_py==1.15.4
braintrust==0.2.6
braintrust-langchain==0.0.4
braintrust[openai-agents]==0.2.6
braintrust-langchain==0.0.4
openai-agents==0.3.3

View File

@@ -34,3 +34,4 @@ types-retry==0.9.9.3
types-setuptools==68.0.0.3
types-urllib3==1.26.25.11
voyageai==0.2.3
ipykernel==6.29.5

View File

@@ -4,7 +4,7 @@ cohere==5.6.1
fastapi==0.116.1
google-cloud-aiplatform==1.58.0
numpy==1.26.4
openai==1.99.5
openai==1.107.1
pydantic==2.11.7
retry==0.9.2
safetensors==0.5.3

View File

@@ -16,6 +16,9 @@ def monitor_process(process_name: str, process: subprocess.Popen) -> None:
def run_jobs() -> None:
# Check if we should use lightweight mode, defaults to True, change to False to use separate background workers
use_lightweight = True
# command setup
cmd_worker_primary = [
"celery",
@@ -45,20 +48,6 @@ def run_jobs() -> None:
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup",
]
cmd_worker_heavy = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation",
]
cmd_worker_docprocessing = [
"celery",
"-A",
@@ -72,45 +61,6 @@ def run_jobs() -> None:
"--queues=docprocessing",
]
cmd_worker_user_files = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"--queues=user_file_processing,user_file_project_sync",
]
cmd_worker_monitoring = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"--queues=monitoring",
]
cmd_worker_kg_processing = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.kg_processing",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=kg_processing@%n",
"--queues=kg_processing",
]
cmd_worker_docfetching = [
"celery",
"-A",
@@ -132,6 +82,84 @@ def run_jobs() -> None:
"--loglevel=INFO",
]
# Prepare background worker commands based on mode
if use_lightweight:
print("Starting workers in LIGHTWEIGHT mode (single background worker)")
cmd_worker_background = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.background",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync",
]
background_workers = [("BACKGROUND", cmd_worker_background)]
else:
print("Starting workers in STANDARD mode (separate background workers)")
cmd_worker_heavy = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning",
]
cmd_worker_kg_processing = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.kg_processing",
"worker",
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=kg_processing@%n",
"-Q",
"kg_processing",
]
cmd_worker_monitoring = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
]
cmd_worker_user_file_processing = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"-Q",
"user_file_processing,user_file_project_sync,connector_doc_permissions_sync,connector_external_group_sync,csv_generation",
]
background_workers = [
("HEAVY", cmd_worker_heavy),
("KG_PROCESSING", cmd_worker_kg_processing),
("MONITORING", cmd_worker_monitoring),
("USER_FILE_PROCESSING", cmd_worker_user_file_processing),
]
# spawn processes
worker_primary_process = subprocess.Popen(
cmd_worker_primary, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
@@ -141,10 +169,6 @@ def run_jobs() -> None:
cmd_worker_light, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
worker_heavy_process = subprocess.Popen(
cmd_worker_heavy, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
worker_docprocessing_process = subprocess.Popen(
cmd_worker_docprocessing,
stdout=subprocess.PIPE,
@@ -152,27 +176,6 @@ def run_jobs() -> None:
text=True,
)
worker_user_file_process = subprocess.Popen(
cmd_worker_user_files,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
worker_monitoring_process = subprocess.Popen(
cmd_worker_monitoring,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
worker_kg_processing_process = subprocess.Popen(
cmd_worker_kg_processing,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
worker_docfetching_process = subprocess.Popen(
cmd_worker_docfetching,
stdout=subprocess.PIPE,
@@ -184,6 +187,14 @@ def run_jobs() -> None:
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
# Spawn background worker processes based on mode
background_processes = []
for name, cmd in background_workers:
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
background_processes.append((name, process))
# monitor threads
worker_primary_thread = threading.Thread(
target=monitor_process, args=("PRIMARY", worker_primary_process)
@@ -191,47 +202,40 @@ def run_jobs() -> None:
worker_light_thread = threading.Thread(
target=monitor_process, args=("LIGHT", worker_light_process)
)
worker_heavy_thread = threading.Thread(
target=monitor_process, args=("HEAVY", worker_heavy_process)
)
worker_docprocessing_thread = threading.Thread(
target=monitor_process, args=("DOCPROCESSING", worker_docprocessing_process)
)
worker_user_file_thread = threading.Thread(
target=monitor_process,
args=("USER_FILE_PROCESSING", worker_user_file_process),
)
worker_monitoring_thread = threading.Thread(
target=monitor_process, args=("MONITORING", worker_monitoring_process)
)
worker_kg_processing_thread = threading.Thread(
target=monitor_process, args=("KG_PROCESSING", worker_kg_processing_process)
)
worker_docfetching_thread = threading.Thread(
target=monitor_process, args=("DOCFETCHING", worker_docfetching_process)
)
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
# Create monitor threads for background workers
background_threads = []
for name, process in background_processes:
thread = threading.Thread(target=monitor_process, args=(name, process))
background_threads.append(thread)
# Start all threads
worker_primary_thread.start()
worker_light_thread.start()
worker_heavy_thread.start()
worker_docprocessing_thread.start()
worker_user_file_thread.start()
worker_monitoring_thread.start()
worker_kg_processing_thread.start()
worker_docfetching_thread.start()
beat_thread.start()
for thread in background_threads:
thread.start()
# Wait for all threads
worker_primary_thread.join()
worker_light_thread.join()
worker_heavy_thread.join()
worker_docprocessing_thread.join()
worker_user_file_thread.join()
worker_monitoring_thread.join()
worker_kg_processing_thread.join()
worker_docfetching_thread.join()
beat_thread.join()
for thread in background_threads:
thread.join()
if __name__ == "__main__":
run_jobs()

Some files were not shown because too many files have changed in this diff Show More