Compare commits


2 Commits

Author  SHA1        Message       Date
Weves   ca3db17b08  add restart   2025-12-17 12:48:46 -08:00
Weves   ffd13b1104  dump scripts  2025-12-17 12:48:46 -08:00
311 changed files with 10260 additions and 12409 deletions

View File

@@ -4,14 +4,7 @@ concurrency:
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
push:
tags:
- "v*.*.*"
permissions:
contents: read

View File

@@ -4,14 +4,7 @@ concurrency:
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
push:
tags:
- "v*.*.*"
permissions:
contents: read

View File

@@ -8,6 +8,10 @@ repos:
# From: https://github.com/astral-sh/uv-pre-commit/pull/53/commits/d30b4298e4fb63ce8609e29acdbcf4c9018a483c
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
hooks:
- id: uv-run
name: Check lazy imports
args: ["--with=onyx-devtools", "ods", "check-lazy-imports"]
files: ^backend/(?!\.venv/).*\.py$
- id: uv-sync
args: ["--locked", "--all-extras"]
- id: uv-lock
@@ -28,10 +32,6 @@ repos:
name: uv-export model_server.txt
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "model_server", "-o", "backend/requirements/model_server.txt"]
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
- id: uv-run
name: Check lazy imports
args: ["--with=onyx-devtools", "ods", "check-lazy-imports"]
files: ^backend/(?!\.venv/).*\.py$
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
# - id: uv-run
# name: mypy

View File

@@ -1,27 +0,0 @@
"""add last refreshed at mcp server
Revision ID: 2a391f840e85
Revises: 4cebcbc9b2ae
Create Date: 2025-12-06 15:19:59.766066
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2a391f840e85"
down_revision = "4cebcbc9b2ae"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"mcp_server",
sa.Column("last_refreshed_at", sa.DateTime(timezone=True), nullable=True),
)
def downgrade() -> None:
op.drop_column("mcp_server", "last_refreshed_at")

View File

@@ -1,27 +0,0 @@
"""add tab_index to tool_call
Revision ID: 4cebcbc9b2ae
Revises: a1b2c3d4e5f6
Create Date: 2025-12-16
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4cebcbc9b2ae"
down_revision = "a1b2c3d4e5f6"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"tool_call",
sa.Column("tab_index", sa.Integer(), nullable=False, server_default="0"),
)
def downgrade() -> None:
op.drop_column("tool_call", "tab_index")

View File

@@ -1,49 +0,0 @@
"""add license table
Revision ID: a1b2c3d4e5f6
Revises: a01bf2971c5d
Create Date: 2025-12-04 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f6"
down_revision = "a01bf2971c5d"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"license",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("license_data", sa.Text(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
)
# Singleton pattern - only ever one row in this table
op.create_index(
"idx_license_singleton",
"license",
[sa.text("(true)")],
unique=True,
)
def downgrade() -> None:
op.drop_index("idx_license_singleton", table_name="license")
op.drop_table("license")
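The unique index above on the constant expression (true) is what enforces the singleton: every row maps to the same index key, so Postgres rejects any second insert. A minimal sketch of the effect, assuming an engine bound to the migrated schema and the License model from onyx.db.models:
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from onyx.db.models import License
with Session(engine) as session:  # `engine` is assumed to exist
    session.add(License(license_data="first-blob"))
    session.commit()
    session.add(License(license_data="second-blob"))
    try:
        session.commit()  # collides with idx_license_singleton
    except IntegrityError:
        session.rollback()  # expected: at most one license row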

View File

@@ -1,27 +0,0 @@
"""Remove fast_default_model_name from llm_provider
Revision ID: a2b3c4d5e6f7
Revises: 2a391f840e85
Create Date: 2024-12-17
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a2b3c4d5e6f7"
down_revision = "2a391f840e85"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.drop_column("llm_provider", "fast_default_model_name")
def downgrade() -> None:
op.add_column(
"llm_provider",
sa.Column("fast_default_model_name", sa.String(), nullable=True),
)

View File

@@ -1,46 +0,0 @@
"""Drop milestone table
Revision ID: b8c9d0e1f2a3
Revises: a2b3c4d5e6f7
Create Date: 2025-12-18
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "b8c9d0e1f2a3"
down_revision = "a2b3c4d5e6f7"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_table("milestone")
def downgrade() -> None:
op.create_table(
"milestone",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("tenant_id", sa.String(), nullable=True),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column("event_type", sa.String(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("event_tracker", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("event_type", name="uq_milestone_event_type"),
)

View File

@@ -1,52 +0,0 @@
"""add_deep_research_tool
Revision ID: c1d2e3f4a5b6
Revises: b8c9d0e1f2a3
Create Date: 2025-12-18 16:00:00.000000
"""
from alembic import op
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c1d2e3f4a5b6"
down_revision = "b8c9d0e1f2a3"
branch_labels = None
depends_on = None
DEEP_RESEARCH_TOOL = {
"name": RESEARCH_AGENT_DB_NAME,
"display_name": "Research Agent",
"description": "The Research Agent is a sub-agent that conducts research on a specific topic.",
"in_code_tool_id": "ResearchAgent",
}
def upgrade() -> None:
conn = op.get_bind()
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
VALUES (:name, :display_name, :description, :in_code_tool_id, false)
"""
),
DEEP_RESEARCH_TOOL,
)
def downgrade() -> None:
conn = op.get_bind()
conn.execute(
sa.text(
"""
DELETE FROM tool
WHERE in_code_tool_id = :in_code_tool_id
"""
),
{"in_code_tool_id": DEEP_RESEARCH_TOOL["in_code_tool_id"]},
)

View File

@@ -1,278 +0,0 @@
"""Database and cache operations for the license table."""
from datetime import datetime
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from ee.onyx.server.license.models import LicenseMetadata
from ee.onyx.server.license.models import LicensePayload
from ee.onyx.server.license.models import LicenseSource
from onyx.db.models import License
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
LICENSE_METADATA_KEY = "license:metadata"
LICENSE_CACHE_TTL_SECONDS = 86400 # 24 hours
# -----------------------------------------------------------------------------
# Database CRUD Operations
# -----------------------------------------------------------------------------
def get_license(db_session: Session) -> License | None:
"""
Get the current license (singleton pattern - only one row).
Args:
db_session: Database session
Returns:
License object if exists, None otherwise
"""
return db_session.execute(select(License)).scalars().first()
def upsert_license(db_session: Session, license_data: str) -> License:
"""
Insert or update the license (singleton pattern).
Args:
db_session: Database session
license_data: Base64-encoded signed license blob
Returns:
The created or updated License object
"""
existing = get_license(db_session)
if existing:
existing.license_data = license_data
db_session.commit()
db_session.refresh(existing)
logger.info("License updated")
return existing
new_license = License(license_data=license_data)
db_session.add(new_license)
db_session.commit()
db_session.refresh(new_license)
logger.info("License created")
return new_license
def delete_license(db_session: Session) -> bool:
"""
Delete the current license.
Args:
db_session: Database session
Returns:
True if deleted, False if no license existed
"""
existing = get_license(db_session)
if existing:
db_session.delete(existing)
db_session.commit()
logger.info("License deleted")
return True
return False
# -----------------------------------------------------------------------------
# Seat Counting
# -----------------------------------------------------------------------------
def get_used_seats(tenant_id: str | None = None) -> int:
"""
Get current seat usage.
For multi-tenant: counts users in UserTenantMapping for this tenant.
For self-hosted: counts all active users (includes both Onyx UI users
and Slack users who have been converted to Onyx users).
"""
if MULTI_TENANT:
from ee.onyx.server.tenants.user_mapping import get_tenant_count
return get_tenant_count(tenant_id or get_current_tenant_id())
else:
# Self-hosted: count all active users (Onyx + converted Slack users)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
with get_session_with_current_tenant() as db_session:
result = db_session.execute(
select(func.count()).select_from(User).where(User.is_active) # type: ignore
)
return result.scalar() or 0
# -----------------------------------------------------------------------------
# Redis Cache Operations
# -----------------------------------------------------------------------------
def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata | None:
"""
Get license metadata from Redis cache.
Args:
tenant_id: Tenant ID (for multi-tenant deployments)
Returns:
LicenseMetadata if cached, None otherwise
"""
tenant = tenant_id or get_current_tenant_id()
redis_client = get_redis_replica_client(tenant_id=tenant)
cached = redis_client.get(LICENSE_METADATA_KEY)
if cached:
try:
cached_str: str
if isinstance(cached, bytes):
cached_str = cached.decode("utf-8")
else:
cached_str = str(cached)
return LicenseMetadata.model_validate_json(cached_str)
except Exception as e:
logger.warning(f"Failed to parse cached license metadata: {e}")
return None
return None
def invalidate_license_cache(tenant_id: str | None = None) -> None:
"""
Invalidate the license metadata cache (not the license itself).
This deletes the cached LicenseMetadata from Redis. The actual license
in the database is not affected. Redis delete is idempotent - if the
key doesn't exist, this is a no-op.
Args:
tenant_id: Tenant ID (for multi-tenant deployments)
"""
tenant = tenant_id or get_current_tenant_id()
redis_client = get_redis_client(tenant_id=tenant)
redis_client.delete(LICENSE_METADATA_KEY)
logger.info("License cache invalidated")
def update_license_cache(
payload: LicensePayload,
source: LicenseSource | None = None,
grace_period_end: datetime | None = None,
tenant_id: str | None = None,
) -> LicenseMetadata:
"""
Update the Redis cache with license metadata.
We cache all license statuses (ACTIVE, GRACE_PERIOD, GATED_ACCESS) because:
1. Frontend needs status to show appropriate UI/banners
2. Caching avoids repeated DB + crypto verification on every request
3. Status enforcement happens at the feature level, not here
Args:
payload: Verified license payload
source: How the license was obtained
grace_period_end: Optional grace period end time
tenant_id: Tenant ID (for multi-tenant deployments)
Returns:
The cached LicenseMetadata
"""
from ee.onyx.utils.license import get_license_status
tenant = tenant_id or get_current_tenant_id()
redis_client = get_redis_client(tenant_id=tenant)
used_seats = get_used_seats(tenant)
status = get_license_status(payload, grace_period_end)
metadata = LicenseMetadata(
tenant_id=payload.tenant_id,
organization_name=payload.organization_name,
seats=payload.seats,
used_seats=used_seats,
plan_type=payload.plan_type,
issued_at=payload.issued_at,
expires_at=payload.expires_at,
grace_period_end=grace_period_end,
status=status,
source=source,
stripe_subscription_id=payload.stripe_subscription_id,
)
redis_client.setex(
LICENSE_METADATA_KEY,
LICENSE_CACHE_TTL_SECONDS,
metadata.model_dump_json(),
)
logger.info(f"License cache updated: {metadata.seats} seats, status={status.value}")
return metadata
def refresh_license_cache(
db_session: Session,
tenant_id: str | None = None,
) -> LicenseMetadata | None:
"""
Refresh the license cache from the database.
Args:
db_session: Database session
tenant_id: Tenant ID (for multi-tenant deployments)
Returns:
LicenseMetadata if license exists, None otherwise
"""
from ee.onyx.utils.license import verify_license_signature
license_record = get_license(db_session)
if not license_record:
invalidate_license_cache(tenant_id)
return None
try:
payload = verify_license_signature(license_record.license_data)
return update_license_cache(
payload,
source=LicenseSource.AUTO_FETCH,
tenant_id=tenant_id,
)
except ValueError as e:
logger.error(f"Failed to verify license during cache refresh: {e}")
invalidate_license_cache(tenant_id)
return None
def get_license_metadata(
db_session: Session,
tenant_id: str | None = None,
) -> LicenseMetadata | None:
"""
Get license metadata, using cache if available.
Args:
db_session: Database session
tenant_id: Tenant ID (for multi-tenant deployments)
Returns:
LicenseMetadata if license exists, None otherwise
"""
# Try cache first
cached = get_cached_license_metadata(tenant_id)
if cached:
return cached
# Refresh from database
return refresh_license_cache(db_session, tenant_id)
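The read path in this module is cache-first: get_license_metadata returns the Redis copy when present and otherwise re-verifies the stored blob and repopulates the cache via refresh_license_cache. A minimal sketch of a caller-side seat check built on top of it (the helper itself is hypothetical, not part of this module):
from onyx.db.engine.sql_engine import get_session_with_current_tenant
def has_available_seat() -> bool:
    # Hypothetical helper: a cache hit skips the DB read and signature
    # verification; a miss falls back to refresh_license_cache().
    with get_session_with_current_tenant() as db_session:
        metadata = get_license_metadata(db_session)
        if metadata is None:
            return False  # no license on record
        return metadata.used_seats < metadata.seats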

View File

@@ -14,7 +14,6 @@ from ee.onyx.server.enterprise_settings.api import (
basic_router as enterprise_settings_router,
)
from ee.onyx.server.evals.api import router as evals_router
from ee.onyx.server.license.api import router as license_router
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import (
add_api_server_tenant_id_middleware,
@@ -140,8 +139,6 @@ def get_application() -> FastAPI:
)
include_router_with_global_prefix_prepended(application, enterprise_settings_router)
include_router_with_global_prefix_prepended(application, usage_export_router)
# License management
include_router_with_global_prefix_prepended(application, license_router)
if MULTI_TENANT:
# Tenant management

View File

@@ -1,246 +0,0 @@
"""License API endpoints."""
import requests
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import HTTPException
from fastapi import UploadFile
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_admin_user
from ee.onyx.db.license import delete_license as db_delete_license
from ee.onyx.db.license import get_license_metadata
from ee.onyx.db.license import invalidate_license_cache
from ee.onyx.db.license import refresh_license_cache
from ee.onyx.db.license import update_license_cache
from ee.onyx.db.license import upsert_license
from ee.onyx.server.license.models import LicenseResponse
from ee.onyx.server.license.models import LicenseSource
from ee.onyx.server.license.models import LicenseStatusResponse
from ee.onyx.server.license.models import LicenseUploadResponse
from ee.onyx.server.license.models import SeatUsageResponse
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.utils.license import verify_license_signature
from onyx.auth.users import User
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.db.engine.sql_engine import get_session
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/license")
@router.get("")
async def get_license_status(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseStatusResponse:
"""Get current license status and seat usage."""
metadata = get_license_metadata(db_session)
if not metadata:
return LicenseStatusResponse(has_license=False)
return LicenseStatusResponse(
has_license=True,
seats=metadata.seats,
used_seats=metadata.used_seats,
plan_type=metadata.plan_type,
issued_at=metadata.issued_at,
expires_at=metadata.expires_at,
grace_period_end=metadata.grace_period_end,
status=metadata.status,
source=metadata.source,
)
@router.get("/seats")
async def get_seat_usage(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> SeatUsageResponse:
"""Get detailed seat usage information."""
metadata = get_license_metadata(db_session)
if not metadata:
return SeatUsageResponse(
total_seats=0,
used_seats=0,
available_seats=0,
)
return SeatUsageResponse(
total_seats=metadata.seats,
used_seats=metadata.used_seats,
available_seats=max(0, metadata.seats - metadata.used_seats),
)
@router.post("/fetch")
async def fetch_license(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseResponse:
"""
Fetch license from control plane.
Used after Stripe checkout completion to retrieve the new license.
"""
tenant_id = get_current_tenant_id()
try:
token = generate_data_plane_token()
except ValueError as e:
logger.error(f"Failed to generate data plane token: {e}")
raise HTTPException(
status_code=500, detail="Authentication configuration error"
)
try:
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/license/{tenant_id}"
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
data = response.json()
if not isinstance(data, dict) or "license" not in data:
raise HTTPException(
status_code=502, detail="Invalid response from control plane"
)
license_data = data["license"]
if not license_data:
raise HTTPException(status_code=404, detail="No license found")
# Verify signature before persisting
payload = verify_license_signature(license_data)
# Verify the fetched license is for this tenant
if payload.tenant_id != tenant_id:
logger.error(
f"License tenant mismatch: expected {tenant_id}, got {payload.tenant_id}"
)
raise HTTPException(
status_code=400,
detail="License tenant ID mismatch - control plane returned wrong license",
)
# Persist to DB and update cache atomically
upsert_license(db_session, license_data)
try:
update_license_cache(payload, source=LicenseSource.AUTO_FETCH)
except Exception as cache_error:
# Log but don't fail - DB is source of truth, cache will refresh on next read
logger.warning(f"Failed to update license cache: {cache_error}")
return LicenseResponse(success=True, license=payload)
except requests.HTTPError as e:
status_code = e.response.status_code if e.response is not None else 502
logger.error(f"Control plane returned error: {status_code}")
raise HTTPException(
status_code=status_code,
detail="Failed to fetch license from control plane",
)
except ValueError as e:
logger.error(f"License verification failed: {type(e).__name__}")
raise HTTPException(status_code=400, detail=str(e))
except requests.RequestException:
logger.exception("Failed to fetch license from control plane")
raise HTTPException(
status_code=502, detail="Failed to connect to control plane"
)
@router.post("/upload")
async def upload_license(
license_file: UploadFile = File(...),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseUploadResponse:
"""
Upload a license file manually.
Used for air-gapped deployments where control plane is not accessible.
"""
try:
content = await license_file.read()
license_data = content.decode("utf-8").strip()
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="Invalid license file format")
try:
payload = verify_license_signature(license_data)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
tenant_id = get_current_tenant_id()
if payload.tenant_id != tenant_id:
raise HTTPException(
status_code=400,
detail=f"License tenant ID mismatch. Expected {tenant_id}, got {payload.tenant_id}",
)
# Persist to DB and update cache
upsert_license(db_session, license_data)
try:
update_license_cache(payload, source=LicenseSource.MANUAL_UPLOAD)
except Exception as cache_error:
# Log but don't fail - DB is source of truth, cache will refresh on next read
logger.warning(f"Failed to update license cache: {cache_error}")
return LicenseUploadResponse(
success=True,
message=f"License uploaded successfully. {payload.seats} seats, expires {payload.expires_at.date()}",
)
@router.post("/refresh")
async def refresh_license_cache_endpoint(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseStatusResponse:
"""
Force refresh the license cache from the database.
Useful after manual database changes or to verify license validity.
"""
metadata = refresh_license_cache(db_session)
if not metadata:
return LicenseStatusResponse(has_license=False)
return LicenseStatusResponse(
has_license=True,
seats=metadata.seats,
used_seats=metadata.used_seats,
plan_type=metadata.plan_type,
issued_at=metadata.issued_at,
expires_at=metadata.expires_at,
grace_period_end=metadata.grace_period_end,
status=metadata.status,
source=metadata.source,
)
@router.delete("")
async def delete_license(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, bool]:
"""
Delete the current license.
Admin only - removes license and invalidates cache.
"""
# Invalidate cache first - if DB delete fails, stale cache is worse than no cache
try:
invalidate_license_cache()
except Exception as cache_error:
logger.warning(f"Failed to invalidate license cache: {cache_error}")
deleted = db_delete_license(db_session)
return {"deleted": deleted}
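A minimal sketch of exercising the upload endpoint from a client. The /api prefix, host, and authentication are assumptions (the router itself only declares the /license prefix); the multipart field name license_file matches the endpoint signature above:
import requests
with open("onyx_license.txt", "rb") as f:  # hypothetical license file
    resp = requests.post(
        "http://localhost:8080/api/license/upload",
        files={"license_file": ("onyx_license.txt", f)},
        # admin session cookie / auth headers omitted here
    )
resp.raise_for_status()
print(resp.json())  # e.g. {"success": true, "message": "License uploaded successfully. ..."}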

View File

@@ -1,92 +0,0 @@
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from onyx.server.settings.models import ApplicationStatus
class PlanType(str, Enum):
MONTHLY = "monthly"
ANNUAL = "annual"
class LicenseSource(str, Enum):
AUTO_FETCH = "auto_fetch"
MANUAL_UPLOAD = "manual_upload"
class LicensePayload(BaseModel):
"""The payload portion of a signed license."""
version: str
tenant_id: str
organization_name: str | None = None
issued_at: datetime
expires_at: datetime
seats: int
plan_type: PlanType
billing_cycle: str | None = None
grace_period_days: int = 30
stripe_subscription_id: str | None = None
stripe_customer_id: str | None = None
class LicenseData(BaseModel):
"""Full signed license structure."""
payload: LicensePayload
signature: str
class LicenseMetadata(BaseModel):
"""Cached license metadata stored in Redis."""
tenant_id: str
organization_name: str | None = None
seats: int
used_seats: int
plan_type: PlanType
issued_at: datetime
expires_at: datetime
grace_period_end: datetime | None = None
status: ApplicationStatus
source: LicenseSource | None = None
stripe_subscription_id: str | None = None
class LicenseStatusResponse(BaseModel):
"""Response for license status API."""
has_license: bool
seats: int = 0
used_seats: int = 0
plan_type: PlanType | None = None
issued_at: datetime | None = None
expires_at: datetime | None = None
grace_period_end: datetime | None = None
status: ApplicationStatus | None = None
source: LicenseSource | None = None
class LicenseResponse(BaseModel):
"""Response after license fetch/upload."""
success: bool
message: str | None = None
license: LicensePayload | None = None
class LicenseUploadResponse(BaseModel):
"""Response after license upload."""
success: bool
message: str | None = None
class SeatUsageResponse(BaseModel):
"""Response for seat usage API."""
total_seats: int
used_seats: int
available_seats: int
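Taken together, these models describe the decoded license blob: a LicensePayload wrapped with its detached signature in LicenseData, which the verification utilities later in this diff expect base64-encoded. A minimal sketch with purely illustrative values:
from datetime import datetime, timezone
from ee.onyx.server.license.models import LicenseData, LicensePayload, PlanType
payload = LicensePayload(
    version="1",  # all values here are illustrative
    tenant_id="tenant_abc",
    issued_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
    expires_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
    seats=50,
    plan_type=PlanType.ANNUAL,
)
license_obj = LicenseData(payload=payload, signature="<base64 RSA-PSS signature>")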

View File

@@ -20,7 +20,7 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llms_for_persona
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.utils.logger import setup_logger
@@ -158,7 +158,7 @@ def handle_send_message_simple_with_history(
persona_id=req.persona_id,
)
llm = get_llm_for_persona(persona=chat_session.persona, user=user)
llm, _ = get_llms_for_persona(persona=chat_session.persona, user=user)
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,

View File

@@ -45,7 +45,7 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRe
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.setup import setup_onyx
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
@@ -269,6 +269,7 @@ def configure_default_api_keys(db_session: Session) -> None:
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_configurations=[
ModelConfigurationUpsertRequest(
name=name,
@@ -295,6 +296,7 @@ def configure_default_api_keys(db_session: Session) -> None:
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
fast_default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name,
@@ -560,11 +562,17 @@ async def assign_tenant_to_user(
try:
add_users_to_tenant([email], tenant_id)
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=email,
event=MilestoneRecordType.TENANT_CREATED,
)
# Create milestone record in the same transaction context as the tenant assignment
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
except Exception:
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
raise Exception("Failed to assign tenant to user")

View File

@@ -249,17 +249,6 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
)
raise
# Remove from invited users list since they've accepted
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
invited_users = get_invited_users()
if email in invited_users:
invited_users.remove(email)
write_invited_users(invited_users)
logger.info(f"Removed {email} from invited users list after acceptance")
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def deny_user_invite(email: str, tenant_id: str) -> None:
"""

View File

@@ -1,126 +0,0 @@
"""RSA-4096 license signature verification utilities."""
import base64
import json
import os
from datetime import datetime
from datetime import timezone
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from ee.onyx.server.license.models import LicenseData
from ee.onyx.server.license.models import LicensePayload
from onyx.server.settings.models import ApplicationStatus
from onyx.utils.logger import setup_logger
logger = setup_logger()
# RSA-4096 Public Key for license verification
# Load from environment variable - key is generated on the control plane
# In production, inject via Kubernetes secrets or secrets manager
LICENSE_PUBLIC_KEY_PEM = os.environ.get("LICENSE_PUBLIC_KEY_PEM", "")
def _get_public_key() -> RSAPublicKey:
"""Load the public key from environment variable."""
if not LICENSE_PUBLIC_KEY_PEM:
raise ValueError(
"LICENSE_PUBLIC_KEY_PEM environment variable not set. "
"License verification requires the control plane public key."
)
key = serialization.load_pem_public_key(LICENSE_PUBLIC_KEY_PEM.encode())
if not isinstance(key, RSAPublicKey):
raise ValueError("Expected RSA public key")
return key
def verify_license_signature(license_data: str) -> LicensePayload:
"""
Verify RSA-4096 signature and return payload if valid.
Args:
license_data: Base64-encoded JSON containing payload and signature
Returns:
LicensePayload if signature is valid
Raises:
ValueError: If license data is invalid or signature verification fails
"""
try:
# Decode the license data
decoded = json.loads(base64.b64decode(license_data))
license_obj = LicenseData(**decoded)
payload_json = json.dumps(
license_obj.payload.model_dump(mode="json"), sort_keys=True
)
signature_bytes = base64.b64decode(license_obj.signature)
# Verify signature using PSS padding (modern standard)
public_key = _get_public_key()
public_key.verify(
signature_bytes,
payload_json.encode(),
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
hashes.SHA256(),
)
return license_obj.payload
except InvalidSignature:
logger.error("License signature verification failed")
raise ValueError("Invalid license signature")
except json.JSONDecodeError:
logger.error("Failed to decode license JSON")
raise ValueError("Invalid license format: not valid JSON")
except (ValueError, KeyError, TypeError) as e:
logger.error(f"License data validation error: {type(e).__name__}")
raise ValueError(f"Invalid license format: {type(e).__name__}")
except Exception:
logger.exception("Unexpected error during license verification")
raise ValueError("License verification failed: unexpected error")
def get_license_status(
payload: LicensePayload,
grace_period_end: datetime | None = None,
) -> ApplicationStatus:
"""
Determine current license status based on expiry.
Args:
payload: The verified license payload
grace_period_end: Optional grace period end datetime
Returns:
ApplicationStatus indicating current license state
"""
now = datetime.now(timezone.utc)
# Check if grace period has expired
if grace_period_end and now > grace_period_end:
return ApplicationStatus.GATED_ACCESS
# Check if license has expired
if now > payload.expires_at:
if grace_period_end and now <= grace_period_end:
return ApplicationStatus.GRACE_PERIOD
return ApplicationStatus.GATED_ACCESS
# License is valid
return ApplicationStatus.ACTIVE
def is_license_valid(payload: LicensePayload) -> bool:
"""Check if a license is currently valid (not expired)."""
now = datetime.now(timezone.utc)
return now <= payload.expires_at
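verify_license_signature expects a base64-encoded JSON object holding payload and signature, with the signature computed over the sort-keys-canonicalized payload JSON using RSA-PSS with SHA-256 and maximum salt length. A minimal sketch of the matching signing side (e.g. for test fixtures), assuming access to the corresponding private key, which would normally live only on the control plane:
import base64
import json
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
def sign_license(payload_dict: dict, private_key_pem: bytes) -> str:
    # The payload dict must serialize exactly as LicensePayload.model_dump(mode="json")
    # does (e.g. ISO 8601 datetimes), or the canonical JSON will not match on verify.
    payload_json = json.dumps(payload_dict, sort_keys=True)
    private_key = serialization.load_pem_private_key(private_key_pem, password=None)
    signature = private_key.sign(
        payload_json.encode(),
        padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
        hashes.SHA256(),
    )
    blob = {"payload": payload_dict, "signature": base64.b64encode(signature).decode()}
    return base64.b64encode(json.dumps(blob).encode()).decode()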

View File

@@ -117,7 +117,7 @@ from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.timing import log_function_time
@@ -653,11 +653,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_count = await get_user_count()
logger.debug(f"Current tenant user count: {user_count}")
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=user.email,
event=MilestoneRecordType.USER_SIGNED_UP,
)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
event_type = (
MilestoneRecordType.USER_SIGNED_UP
if user_count == 1
else MilestoneRecordType.MULTIPLE_USERS
)
create_milestone_and_report(
user=user,
distinct_id=user.email,
event_type=event_type,
properties=None,
db_session=db_session,
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

View File

@@ -45,7 +45,6 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -109,7 +108,6 @@ from onyx.redis.redis_utils import is_fence
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
from onyx.utils.middleware import make_randomized_onyx_request_id
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
@@ -549,12 +547,6 @@ def check_indexing_completion(
)
db_session.commit()
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=tenant_id,
event=MilestoneRecordType.CONNECTOR_SUCCEEDED,
)
# Clear repeated error state on success
if cc_pair.in_repeated_error_state:
cc_pair.in_repeated_error_state = False

View File

@@ -127,6 +127,12 @@ def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
f"available, setting to first model in list: {available_models[0]}"
)
default_provider.default_model_name = available_models[0]
if default_provider.fast_default_model_name not in available_models:
task_logger.info(
f"Fast default model {default_provider.fast_default_model_name} "
f"not available, setting to first model in list: {available_models[0]}"
)
default_provider.fast_default_model_name = available_models[0]
db_session.commit()
if added_models or removed_models:

View File

@@ -1,6 +1,7 @@
import sys
import time
import traceback
from collections import defaultdict
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -20,6 +21,7 @@ from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -30,8 +32,11 @@ from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
from onyx.connectors.models import DocExtractionContext
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import TextSection
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
@@ -44,16 +49,34 @@ from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.index_attempt import mark_attempt_partially_succeeded
from onyx.db.index_attempt import mark_attempt_succeeded
from onyx.db.index_attempt import transition_attempt_to_in_progress
from onyx.db.index_attempt import update_docs_indexed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.document_index.factory import get_default_document_index
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.adapters.document_indexing_adapter import (
DocumentIndexingBatchAdapter,
)
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.utils.logger import setup_logger
from onyx.utils.middleware import make_randomized_onyx_request_id
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
@@ -249,6 +272,583 @@ def _check_failure_threshold(
)
# NOTE: this is the old run_indexing function that the new decoupled approach
# is based on. Leaving this for comparison purposes, but if you see this comment
# has been here for >2 months, please delete this function.
def _run_indexing(
db_session: Session,
index_attempt_id: int,
tenant_id: str,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
"""
1. Get documents which are either new or updated from the specified application
2. Embed and index these documents into the chosen datastore (vespa)
3. Update Postgres to record the indexed documents + the outcome of this run
"""
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
start_time = time.monotonic() # just used for logging
with get_session_with_current_tenant() as db_session_temp:
index_attempt_start = get_index_attempt(
db_session_temp,
index_attempt_id,
eager_load_cc_pair=True,
eager_load_search_settings=True,
)
if not index_attempt_start:
raise ValueError(
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
)
if index_attempt_start.search_settings is None:
raise ValueError(
"Search settings must be set for indexing. This should not be possible."
)
db_connector = index_attempt_start.connector_credential_pair.connector
db_credential = index_attempt_start.connector_credential_pair.credential
is_primary = (
index_attempt_start.search_settings.status == IndexModelStatus.PRESENT
)
from_beginning = index_attempt_start.from_beginning
has_successful_attempt = (
index_attempt_start.connector_credential_pair.last_successful_index_time
is not None
)
ctx = DocExtractionContext(
index_name=index_attempt_start.search_settings.index_name,
cc_pair_id=index_attempt_start.connector_credential_pair.id,
connector_id=db_connector.id,
credential_id=db_credential.id,
source=db_connector.source,
earliest_index_time=(
db_connector.indexing_start.timestamp()
if db_connector.indexing_start
else 0
),
from_beginning=from_beginning,
# Only update cc-pair status for primary index jobs
# Secondary index syncs at the end when swapping
is_primary=is_primary,
should_fetch_permissions_during_indexing=(
index_attempt_start.connector_credential_pair.access_type
== AccessType.SYNC
and source_should_fetch_permissions_during_indexing(db_connector.source)
and is_primary
# if we've already successfully indexed, let the doc_sync job
# take care of doc-level permissions
and (from_beginning or not has_successful_attempt)
),
search_settings_status=index_attempt_start.search_settings.status,
doc_extraction_complete_batch_num=None,
)
last_successful_index_poll_range_end = (
ctx.earliest_index_time
if ctx.from_beginning
else get_last_successful_attempt_poll_range_end(
cc_pair_id=ctx.cc_pair_id,
earliest_index=ctx.earliest_index_time,
search_settings=index_attempt_start.search_settings,
db_session=db_session_temp,
)
)
if last_successful_index_poll_range_end > POLL_CONNECTOR_OFFSET:
window_start = datetime.fromtimestamp(
last_successful_index_poll_range_end, tz=timezone.utc
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
else:
# don't go into "negative" time if we've never indexed before
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
most_recent_attempt = next(
iter(
get_recent_completed_attempts_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt_start.search_settings_id,
db_session=db_session_temp,
limit=1,
)
),
None,
)
# if the last attempt failed, try and use the same window. This is necessary
# to ensure correctness with checkpointing. If we don't do this, things like
# new slack channels could be missed (since existing slack channels are
# cached as part of the checkpoint).
if (
most_recent_attempt
and most_recent_attempt.poll_range_end
and (
most_recent_attempt.status == IndexingStatus.FAILED
or most_recent_attempt.status == IndexingStatus.CANCELED
)
):
window_end = most_recent_attempt.poll_range_end
else:
window_end = datetime.now(tz=timezone.utc)
# add start/end now that they have been set
index_attempt_start.poll_range_start = window_start
index_attempt_start.poll_range_end = window_end
db_session_temp.add(index_attempt_start)
db_session_temp.commit()
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=index_attempt_start.search_settings,
callback=callback,
)
information_content_classification_model = InformationContentClassificationModel()
document_index = get_default_document_index(
index_attempt_start.search_settings,
None,
httpx_client=HttpxPool.get("vespa"),
)
# Initialize memory tracer. NOTE: won't actually do anything if
# `INDEXING_TRACER_INTERVAL` is 0.
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
memory_tracer.start()
index_attempt_md = IndexAttemptMetadata(
attempt_id=index_attempt_id,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
)
total_failures = 0
batch_num = 0
net_doc_change = 0
document_count = 0
chunk_count = 0
index_attempt: IndexAttempt | None = None
try:
with get_session_with_current_tenant() as db_session_temp:
index_attempt = get_index_attempt(
db_session_temp, index_attempt_id, eager_load_cc_pair=True
)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
connector_runner = _get_connector_runner(
db_session=db_session_temp,
attempt=index_attempt,
batch_size=INDEX_BATCH_SIZE,
start_time=window_start,
end_time=window_end,
include_permissions=ctx.should_fetch_permissions_during_indexing,
)
# don't use a checkpoint if we're explicitly indexing from
# the beginning in order to avoid weird interactions between
# checkpointing / failure handling
# OR
# if the last attempt was successful
if index_attempt.from_beginning or (
most_recent_attempt and most_recent_attempt.status.is_successful()
):
checkpoint = connector_runner.connector.build_dummy_checkpoint()
else:
checkpoint, _ = get_latest_valid_checkpoint(
db_session=db_session_temp,
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
window_start=window_start,
window_end=window_end,
connector=connector_runner.connector,
)
# save the initial checkpoint to have a proper record of the
# "last used checkpoint"
save_checkpoint(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
checkpoint=checkpoint,
)
unresolved_errors = get_index_attempt_errors_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
unresolved_only=True,
db_session=db_session_temp,
)
doc_id_to_unresolved_errors: dict[str, list[IndexAttemptError]] = (
defaultdict(list)
)
for error in unresolved_errors:
if error.document_id:
doc_id_to_unresolved_errors[error.document_id].append(error)
entity_based_unresolved_errors = [
error for error in unresolved_errors if error.entity_id
]
while checkpoint.has_more:
logger.info(
f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
)
for document_batch, failure, next_checkpoint in connector_runner.run(
checkpoint
):
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise ConnectorStopSignal("Connector stop signal detected")
# NOTE: this progress callback runs on every loop. We've seen cases
# where we loop many times with no new documents and eventually time
# out, so only doing the callback after indexing isn't sufficient.
callback.progress("_run_indexing", 0)
# TODO: should we move this into the above callback instead?
with get_session_with_current_tenant() as db_session_temp:
# will exception if the connector/index attempt is marked as paused/failed
_check_connector_and_attempt_status(
db_session_temp,
ctx.cc_pair_id,
ctx.search_settings_status,
index_attempt_id,
)
# save record of any failures at the connector level
if failure is not None:
total_failures += 1
with get_session_with_current_tenant() as db_session_temp:
create_index_attempt_error(
index_attempt_id,
ctx.cc_pair_id,
failure,
db_session_temp,
)
_check_failure_threshold(
total_failures, document_count, batch_num, failure
)
# save the new checkpoint (if one is provided)
if next_checkpoint:
checkpoint = next_checkpoint
# below is all document processing logic, so if no batch we can just continue
if document_batch is None:
continue
batch_description = []
# Generate an ID that can be used to correlate activity between here
# and the embedding model server
doc_batch_cleaned = strip_null_characters(document_batch)
for doc in doc_batch_cleaned:
batch_description.append(doc.to_short_descriptor())
doc_size = 0
for section in doc.sections:
if (
isinstance(section, TextSection)
and section.text is not None
):
doc_size += len(section.text)
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
logger.warning(
f"Document size: doc='{doc.to_short_descriptor()}' "
f"size={doc_size} "
f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
)
logger.debug(f"Indexing batch of documents: {batch_description}")
index_attempt_md.request_id = make_randomized_onyx_request_id("CIX")
index_attempt_md.structured_id = (
f"{tenant_id}:{ctx.cc_pair_id}:{index_attempt_id}:{batch_num}"
)
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
adapter = DocumentIndexingBatchAdapter(
db_session=db_session,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
tenant_id=tenant_id,
index_attempt_metadata=index_attempt_md,
)
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=(
ctx.from_beginning
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
),
db_session=db_session,
tenant_id=tenant_id,
document_batch=doc_batch_cleaned,
request_id=index_attempt_md.request_id,
adapter=adapter,
)
batch_num += 1
net_doc_change += index_pipeline_result.new_docs
chunk_count += index_pipeline_result.total_chunks
document_count += index_pipeline_result.total_docs
# resolve errors for documents that were successfully indexed
failed_document_ids = [
failure.failed_document.document_id
for failure in index_pipeline_result.failures
if failure.failed_document
]
successful_document_ids = [
document.id
for document in document_batch
if document.id not in failed_document_ids
]
for document_id in successful_document_ids:
with get_session_with_current_tenant() as db_session_temp:
if document_id in doc_id_to_unresolved_errors:
logger.info(
f"Resolving IndexAttemptError for document '{document_id}'"
)
for error in doc_id_to_unresolved_errors[document_id]:
error.is_resolved = True
db_session_temp.add(error)
db_session_temp.commit()
# add brand new failures
if index_pipeline_result.failures:
total_failures += len(index_pipeline_result.failures)
with get_session_with_current_tenant() as db_session_temp:
for failure in index_pipeline_result.failures:
create_index_attempt_error(
index_attempt_id,
ctx.cc_pair_id,
failure,
db_session_temp,
)
_check_failure_threshold(
total_failures,
document_count,
batch_num,
index_pipeline_result.failures[-1],
)
# This new value is updated every batch, so UI can refresh per batch update
with get_session_with_current_tenant() as db_session_temp:
# NOTE: Postgres uses the start of the transaction when computing `NOW()`
# so we need either to commit() or to use a new session
update_docs_indexed(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
docs_removed_from_index=0,
)
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
# Add telemetry for indexing progress
optional_telemetry(
record_type=RecordType.INDEXING_PROGRESS,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"current_docs_indexed": document_count,
"current_chunks_indexed": chunk_count,
"source": ctx.source.value,
},
tenant_id=tenant_id,
)
memory_tracer.increment_and_maybe_trace()
# make sure the checkpoints aren't getting too large at some regular interval
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
check_checkpoint_size(checkpoint)
# save latest checkpoint
with get_session_with_current_tenant() as db_session_temp:
save_checkpoint(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
checkpoint=checkpoint,
)
optional_telemetry(
record_type=RecordType.INDEXING_COMPLETE,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"total_docs_indexed": document_count,
"total_chunks": chunk_count,
"time_elapsed_seconds": time.monotonic() - start_time,
"source": ctx.source.value,
},
tenant_id=tenant_id,
)
except Exception as e:
logger.exception(
"Connector run exceptioned after elapsed time: "
f"{time.monotonic() - start_time} seconds"
)
if isinstance(e, ConnectorValidationError):
# On validation errors during indexing, we want to cancel the indexing attempt
# and mark the CCPair as invalid. This prevents the connector from being
# used in the future until the credentials are updated.
with get_session_with_current_tenant() as db_session_temp:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to validation error."
)
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
)
if ctx.is_primary:
if not index_attempt:
# should always be set by now
raise RuntimeError("Should never happen.")
VALIDATION_ERROR_THRESHOLD = 5
recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
limit=VALIDATION_ERROR_THRESHOLD,
db_session=db_session_temp,
)
num_validation_errors = len(
[
index_attempt
for index_attempt in recent_index_attempts
if index_attempt.error_msg
and index_attempt.error_msg.startswith(
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
)
]
)
if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
logger.warning(
f"Connector {ctx.connector_id} has {num_validation_errors} consecutive validation"
f" errors. Marking the CC Pair as invalid."
)
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
status=ConnectorCredentialPairStatus.INVALID,
)
memory_tracer.stop()
raise e
elif isinstance(e, ConnectorStopSignal):
with get_session_with_current_tenant() as db_session_temp:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to stop signal."
)
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
memory_tracer.stop()
raise e
else:
with get_session_with_current_tenant() as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
memory_tracer.stop()
raise e
memory_tracer.stop()
# we know index attempt is successful (at least partially) at this point,
# all other cases have been short-circuited
elapsed_time = time.monotonic() - start_time
with get_session_with_current_tenant() as db_session_temp:
# resolve entity-based errors
for error in entity_based_unresolved_errors:
logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
error.is_resolved = True
db_session_temp.add(error)
db_session_temp.commit()
if total_failures == 0:
mark_attempt_succeeded(index_attempt_id, db_session_temp)
create_milestone_and_report(
user=None,
distinct_id=tenant_id or "N/A",
event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
properties=None,
db_session=db_session_temp,
)
logger.info(
f"Connector succeeded: "
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
)
else:
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
logger.info(
f"Connector completed with some errors: "
f"failures={total_failures} "
f"batches={batch_num} "
f"docs={document_count} "
f"chunks={chunk_count} "
f"elapsed={elapsed_time:.2f}s"
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
run_dt=window_end,
)
if ctx.should_fetch_permissions_during_indexing:
mark_cc_pair_as_permissions_synced(
db_session=db_session_temp,
cc_pair_id=ctx.cc_pair_id,
start_time=window_end,
)
def run_docfetching_entrypoint(
app: Celery,
index_attempt_id: int,

View File

@@ -0,0 +1,64 @@
"""
Module for handling chat-related milestone tracking and telemetry.
"""
from sqlalchemy.orm import Session
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.milestone import update_user_assistant_milestone
from onyx.db.models import User
from onyx.utils.telemetry import mt_cloud_telemetry
def process_multi_assistant_milestone(
user: User | None,
assistant_id: int,
tenant_id: str,
db_session: Session,
) -> None:
"""
Process the multi-assistant milestone for a user.
This function:
1. Creates or retrieves the multi-assistant milestone
2. Updates the milestone with the current assistant usage
3. Checks if the milestone was just achieved
4. Sends telemetry if the milestone was just hit
Args:
user: The user for whom to process the milestone (can be None for anonymous users)
assistant_id: The ID of the assistant being used
tenant_id: The current tenant ID
db_session: Database session for queries
"""
# Create or retrieve the multi-assistant milestone
multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
user=user,
event_type=MilestoneRecordType.MULTIPLE_ASSISTANTS,
db_session=db_session,
)
# Update the milestone with the current assistant usage
update_user_assistant_milestone(
milestone=multi_assistant_milestone,
user_id=str(user.id) if user else NO_AUTH_USER_ID,
assistant_id=assistant_id,
db_session=db_session,
)
# Check if the milestone was just achieved
_, just_hit_multi_assistant_milestone = check_multi_assistant_milestone(
milestone=multi_assistant_milestone,
db_session=db_session,
)
# Send telemetry if the milestone was just hit
if just_hit_multi_assistant_milestone:
mt_cloud_telemetry(
distinct_id=tenant_id,
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
properties=None,
)
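A minimal usage sketch, assuming this runs in the chat path after the assistant for the turn has been resolved; the surrounding variables (current_user, persona_id) are illustrative:
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from shared_configs.contextvars import get_current_tenant_id
with get_session_with_current_tenant() as db_session:
    process_multi_assistant_milestone(
        user=current_user,        # may be None for anonymous users
        assistant_id=persona_id,  # hypothetical: id of the assistant in use
        tenant_id=get_current_tenant_id(),
        db_session=db_session,
    )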

View File

@@ -1,4 +1,3 @@
import threading
from collections.abc import Callable
from collections.abc import Generator
from queue import Empty
@@ -19,14 +18,9 @@ class ChatStateContainer:
This container holds the partial state that can be saved to the database
if the generation is stopped by the user or completes normally.
Thread-safe: All write operations are protected by a lock to ensure safe
concurrent access from multiple threads. For thread-safe reads, use the
getter methods. Direct attribute access is not thread-safe.
"""
def __init__(self) -> None:
self._lock = threading.Lock()
self.tool_calls: list[ToolCallInfo] = []
self.reasoning_tokens: str | None = None
self.answer_tokens: str | None = None
@@ -37,53 +31,23 @@ class ChatStateContainer:
def add_tool_call(self, tool_call: ToolCallInfo) -> None:
"""Add a tool call to the accumulated state."""
with self._lock:
self.tool_calls.append(tool_call)
self.tool_calls.append(tool_call)
def set_reasoning_tokens(self, reasoning: str | None) -> None:
"""Set the reasoning tokens from the final answer generation."""
with self._lock:
self.reasoning_tokens = reasoning
self.reasoning_tokens = reasoning
def set_answer_tokens(self, answer: str | None) -> None:
"""Set the answer tokens from the final answer generation."""
with self._lock:
self.answer_tokens = answer
self.answer_tokens = answer
def set_citation_mapping(self, citation_to_doc: dict[int, Any]) -> None:
"""Set the citation mapping from citation processor."""
with self._lock:
self.citation_to_doc = citation_to_doc
self.citation_to_doc = citation_to_doc
def set_is_clarification(self, is_clarification: bool) -> None:
"""Set whether this turn is a clarification question."""
with self._lock:
self.is_clarification = is_clarification
def get_answer_tokens(self) -> str | None:
"""Thread-safe getter for answer_tokens."""
with self._lock:
return self.answer_tokens
def get_reasoning_tokens(self) -> str | None:
"""Thread-safe getter for reasoning_tokens."""
with self._lock:
return self.reasoning_tokens
def get_tool_calls(self) -> list[ToolCallInfo]:
"""Thread-safe getter for tool_calls (returns a copy)."""
with self._lock:
return self.tool_calls.copy()
def get_citation_to_doc(self) -> dict[int, SearchDoc]:
"""Thread-safe getter for citation_to_doc (returns a copy)."""
with self._lock:
return self.citation_to_doc.copy()
def get_is_clarification(self) -> bool:
"""Thread-safe getter for is_clarification."""
with self._lock:
return self.is_clarification
self.is_clarification = is_clarification
def run_chat_llm_with_state_containers(

View File

@@ -49,10 +49,8 @@ from onyx.llm.override_models import LLMOverride
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.prompts.chat_prompts import ADDITIONAL_CONTEXT_PROMPT
from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE
from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
@@ -731,38 +729,3 @@ def is_last_assistant_message_clarification(chat_history: list[ChatMessage]) ->
if message.message_type == MessageType.ASSISTANT:
return message.is_clarification
return False
def create_tool_call_failure_messages(
tool_call: ToolCallKickoff, token_counter: Callable[[str], int]
) -> list[ChatMessageSimple]:
"""Create ChatMessageSimple objects for a failed tool call.
Creates two messages:
1. The tool call message itself
2. A failure response message indicating the tool call failed
Args:
tool_call: The ToolCallKickoff object representing the failed tool call
token_counter: Function to count tokens in a message string
Returns:
List containing two ChatMessageSimple objects: tool call message and failure response
"""
tool_call_msg = ChatMessageSimple(
message=tool_call.to_msg_str(),
token_count=token_counter(tool_call.to_msg_str()),
message_type=MessageType.TOOL_CALL,
tool_call_id=tool_call.tool_call_id,
image_files=None,
)
failure_response_msg = ChatMessageSimple(
message=TOOL_CALL_FAILURE_PROMPT,
token_count=token_counter(TOOL_CALL_FAILURE_PROMPT),
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,
)
return [tool_call_msg, failure_response_msg]
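A brief usage sketch for the helper above, assuming the calling loop has tool_calls, tool_responses, token_counter, and simple_chat_history available: when a tool call produced no responses, the two failure messages are appended so the LLM can retry on the next cycle.
# Hedged sketch of the intended call pattern; all names come from the surrounding loop.
if tool_calls and not tool_responses:
    failure_messages = create_tool_call_failure_messages(tool_calls[0], token_counter)
    simple_chat_history.extend(failure_messages)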

View File

@@ -1,12 +1,15 @@
import json
from collections.abc import Callable
from typing import cast
from sqlalchemy.orm import Session
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import create_tool_call_failure_messages
from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.emitter import Emitter
from onyx.chat.llm_step import run_llm_step
from onyx.chat.llm_step import TOOL_CALL_MSG_ARGUMENTS
from onyx.chat.llm_step import TOOL_CALL_MSG_FUNC_NAME
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import LlmStepResult
@@ -322,6 +325,7 @@ def run_llm_loop(
# Pass the total budget to construct_message_history, which will handle token allocation
available_tokens = llm.config.max_input_tokens
tool_choice: ToolChoiceOptions = ToolChoiceOptions.AUTO
collected_tool_calls: list[ToolCallInfo] = []
# Initialize gathered_documents with project files if present
gathered_documents: list[SearchDoc] | None = (
list(project_citation_mapping.values())
@@ -339,8 +343,12 @@ def run_llm_loop(
has_called_search_tool: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
reasoning_cycles = 0
current_tool_call_index = (
0 # TODO: just use the cycle count after parallel tool calls are supported
)
for llm_cycle_count in range(MAX_LLM_CYCLES):
if forced_tool_id:
# Must be narrowed to the single tool because the "required" tool choice is just a binary flag and doesn't name a specific tool
final_tools = [tool for tool in tools if tool.id == forced_tool_id]
@@ -437,13 +445,12 @@ def run_llm_loop(
# This calls the LLM, yields packets (reasoning, answers, etc.) and returns the result
# It also pre-processes the tool calls in preparation for running them
llm_step_result, has_reasoned = run_llm_step(
emitter=emitter,
step_generator = run_llm_step(
history=truncated_message_history,
tool_definitions=[tool.tool_definition() for tool in final_tools],
tool_choice=tool_choice,
llm=llm,
turn_index=llm_cycle_count + reasoning_cycles,
turn_index=current_tool_call_index,
citation_processor=citation_processor,
state_container=state_container,
# The rich docs representation is passed in so that when yielding the answer, it can also
@@ -452,8 +459,18 @@ def run_llm_loop(
final_documents=gathered_documents,
user_identity=user_identity,
)
if has_reasoned:
reasoning_cycles += 1
# Consume the generator, emitting packets and capturing the final result
while True:
try:
packet = next(step_generator)
emitter.emit(packet)
except StopIteration as e:
llm_step_result, current_tool_call_index = e.value
break
# Type narrowing: generator always returns a result, so this can't be None
llm_step_result = cast(LlmStepResult, llm_step_result)
# Save citation mapping after each LLM step for incremental state updates
state_container.set_citation_mapping(citation_processor.citation_to_doc)
@@ -463,39 +480,20 @@ def run_llm_loop(
tool_responses: list[ToolResponse] = []
tool_calls = llm_step_result.tool_calls or []
# Quick note for why citation_mapping and citation_processors are both needed:
# 1. Tools return lightweight string mappings, not SearchDoc objects
# 2. The SearchDoc resolution is deliberately deferred to llm_loop.py
# 3. The citation_processor operates on SearchDoc objects and can't provide a complete reverse URL lookup for
# in-flight citations
# It can be cleaned up but not super trivial or worthwhile right now
just_ran_web_search = False
tool_responses, citation_mapping = run_tool_calls(
tool_calls=tool_calls,
tools=final_tools,
message_history=truncated_message_history,
memories=memories,
user_info=None, # TODO, this is part of memories right now, might want to separate it out
citation_mapping=citation_mapping,
citation_processor=citation_processor,
skip_search_query_expansion=has_called_search_tool,
)
# Failure case, give something reasonable to the LLM to try again
if tool_calls and not tool_responses:
failure_messages = create_tool_call_failure_messages(
tool_calls[0], token_counter
for tool_call in tool_calls:
# TODO replace the [tool_call] with the list of tool calls once parallel tool calls are supported
tool_responses, citation_mapping = run_tool_calls(
tool_calls=[tool_call],
tools=final_tools,
turn_index=current_tool_call_index,
message_history=truncated_message_history,
memories=memories,
user_info=None, # TODO, this is part of memories right now, might want to separate it out
citation_mapping=citation_mapping,
citation_processor=citation_processor,
skip_search_query_expansion=has_called_search_tool,
)
simple_chat_history.extend(failure_messages)
continue
for tool_response in tool_responses:
# Extract tool_call from the response (set by run_tool_calls)
if tool_response.tool_call is None:
raise ValueError("Tool response missing tool_call reference")
tool_call = tool_response.tool_call
tab_index = tool_call.tab_index
# Track if search tool was called (for skipping query expansion on subsequent calls)
if tool_call.tool_name == SearchTool.NAME:
@@ -504,103 +502,110 @@ def run_llm_loop(
# Build a mapping of tool names to tool objects for getting tool_id
tools_by_name = {tool.name: tool for tool in final_tools}
# Add the results to the chat history. Even though tools may run in parallel,
# LLM APIs require linear history, so results are added sequentially.
# Get the tool object to retrieve tool_id
tool = tools_by_name.get(tool_call.tool_name)
if not tool:
raise ValueError(
f"Tool '{tool_call.tool_name}' not found in tools list"
)
# Add the results to the chat history, note that even if the tools were run in parallel, this isn't supported
# as all the LLM APIs require linear history, so these will just be included sequentially
for tool_call, tool_response in zip([tool_call], tool_responses):
# Get the tool object to retrieve tool_id
tool = tools_by_name.get(tool_call.tool_name)
if not tool:
raise ValueError(
f"Tool '{tool_call.tool_name}' not found in tools list"
)
# Extract search_docs if this is a search tool response
search_docs = None
if isinstance(tool_response.rich_response, SearchDocsResponse):
search_docs = tool_response.rich_response.search_docs
if gathered_documents:
gathered_documents.extend(search_docs)
else:
gathered_documents = search_docs
# This is used for the Open URL reminder in the next cycle
# only do this if the web search tool yielded results
if search_docs and tool_call.tool_name == WebSearchTool.NAME:
just_ran_web_search = True
# Extract generated_images if this is an image generation tool response
generated_images = None
if isinstance(
tool_response.rich_response, FinalImageGenerationResponse
):
generated_images = tool_response.rich_response.generated_images
tool_call_info = ToolCallInfo(
parent_tool_call_id=None, # Top-level tool calls are attached to the chat message
turn_index=llm_cycle_count + reasoning_cycles,
tab_index=tab_index,
tool_name=tool_call.tool_name,
tool_call_id=tool_call.tool_call_id,
tool_id=tool.id,
reasoning_tokens=llm_step_result.reasoning, # All tool calls from this loop share the same reasoning
tool_call_arguments=tool_call.tool_args,
tool_call_response=tool_response.llm_facing_response,
search_docs=search_docs,
generated_images=generated_images,
)
# Add to state container for partial save support
state_container.add_tool_call(tool_call_info)
# Store tool call with function name and arguments in separate layers
tool_call_message = tool_call.to_msg_str()
tool_call_token_count = token_counter(tool_call_message)
tool_call_msg = ChatMessageSimple(
message=tool_call_message,
token_count=tool_call_token_count,
message_type=MessageType.TOOL_CALL,
tool_call_id=tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(tool_call_msg)
tool_response_message = tool_response.llm_facing_response
tool_response_token_count = token_counter(tool_response_message)
tool_response_msg = ChatMessageSimple(
message=tool_response_message,
token_count=tool_response_token_count,
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(tool_response_msg)
# Update citation processor if this was a search tool
if tool_call.tool_name in citeable_tools_names:
# Check if the rich_response is a SearchDocsResponse
# Extract search_docs if this is a search tool response
search_docs = None
if isinstance(tool_response.rich_response, SearchDocsResponse):
search_response = tool_response.rich_response
search_docs = tool_response.rich_response.search_docs
if gathered_documents:
gathered_documents.extend(search_docs)
else:
gathered_documents = search_docs
# Create mapping from citation number to SearchDoc
citation_to_doc: dict[int, SearchDoc] = {}
for (
citation_num,
doc_id,
) in search_response.citation_mapping.items():
# Find the SearchDoc with this doc_id
matching_doc = next(
(
doc
for doc in search_response.search_docs
if doc.document_id == doc_id
),
None,
)
if matching_doc:
citation_to_doc[citation_num] = matching_doc
# This is used for the Open URL reminder in the next cycle
# only do this if the web search tool yielded results
if search_docs and tool_call.tool_name == WebSearchTool.NAME:
just_ran_web_search = True
# Update the citation processor
citation_processor.update_citation_mapping(citation_to_doc)
# Extract generated_images if this is an image generation tool response
generated_images = None
if isinstance(
tool_response.rich_response, FinalImageGenerationResponse
):
generated_images = tool_response.rich_response.generated_images
tool_call_info = ToolCallInfo(
parent_tool_call_id=None, # Top-level tool calls are attached to the chat message
turn_index=current_tool_call_index,
tool_name=tool_call.tool_name,
tool_call_id=tool_call.tool_call_id,
tool_id=tool.id,
reasoning_tokens=llm_step_result.reasoning, # All tool calls from this loop share the same reasoning
tool_call_arguments=tool_call.tool_args,
tool_call_response=tool_response.llm_facing_response,
search_docs=search_docs,
generated_images=generated_images,
)
collected_tool_calls.append(tool_call_info)
# Add to state container for partial save support
state_container.add_tool_call(tool_call_info)
# Store tool call with function name and arguments in separate layers
tool_call_data = {
TOOL_CALL_MSG_FUNC_NAME: tool_call.tool_name,
TOOL_CALL_MSG_ARGUMENTS: tool_call.tool_args,
}
tool_call_message = json.dumps(tool_call_data)
tool_call_token_count = token_counter(tool_call_message)
tool_call_msg = ChatMessageSimple(
message=tool_call_message,
token_count=tool_call_token_count,
message_type=MessageType.TOOL_CALL,
tool_call_id=tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(tool_call_msg)
tool_response_message = tool_response.llm_facing_response
tool_response_token_count = token_counter(tool_response_message)
tool_response_msg = ChatMessageSimple(
message=tool_response_message,
token_count=tool_response_token_count,
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(tool_response_msg)
# Update citation processor if this was a search tool
if tool_call.tool_name in citeable_tools_names:
# Check if the rich_response is a SearchDocsResponse
if isinstance(tool_response.rich_response, SearchDocsResponse):
search_response = tool_response.rich_response
# Create mapping from citation number to SearchDoc
citation_to_doc: dict[int, SearchDoc] = {}
for (
citation_num,
doc_id,
) in search_response.citation_mapping.items():
# Find the SearchDoc with this doc_id
matching_doc = next(
(
doc
for doc in search_response.search_docs
if doc.document_id == doc_id
),
None,
)
if matching_doc:
citation_to_doc[citation_num] = matching_doc
# Update the citation processor
citation_processor.update_citation_mapping(citation_to_doc)
current_tool_call_index += 1
# If no tool calls, then it must have answered, wrap up
if not llm_step_result.tool_calls or len(llm_step_result.tool_calls) == 0:
@@ -624,8 +629,5 @@ def run_llm_loop(
raise RuntimeError("LLM did not return an answer.")
emitter.emit(
Packet(
turn_index=llm_cycle_count + reasoning_cycles,
obj=OverallStop(type="stop"),
)
Packet(turn_index=current_tool_call_index, obj=OverallStop(type="stop"))
)
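The loop above consumes a generator that both yields packets and returns a final value; in Python, that return value travels back on StopIteration. A self-contained sketch of the pattern with a hypothetical step() generator (not the real run_llm_step):
from collections.abc import Generator

def step() -> Generator[str, None, tuple[str, int]]:
    # Yields intermediate packets, then returns a (result, next_index) tuple.
    yield "packet-1"
    yield "packet-2"
    return "final-result", 3

gen = step()
while True:
    try:
        packet = next(gen)
        print(packet)  # stand-in for emitter.emit(packet)
    except StopIteration as e:
        result, next_index = e.value  # the generator's return value
        break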

View File

@@ -1,16 +1,10 @@
import json
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Mapping
from collections.abc import Sequence
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from onyx.chat.emitter import Emitter
from onyx.llm.models import ReasoningEffort
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.models import ChatMessageSimple
@@ -23,7 +17,6 @@ from onyx.llm.interfaces import LanguageModelInput
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.model_response import Delta
from onyx.llm.models import AssistantMessage
from onyx.llm.models import ChatCompletionMessage
from onyx.llm.models import FunctionCall
@@ -41,8 +34,6 @@ from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningDone
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.tools.models import TOOL_CALL_MSG_ARGUMENTS
from onyx.tools.models import TOOL_CALL_MSG_FUNC_NAME
from onyx.tools.models import ToolCallKickoff
from onyx.tracing.framework.create import generation_span
from onyx.utils.b64 import get_image_type_from_bytes
@@ -52,43 +43,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]:
"""Parse tool arguments into a dict.
Normal case:
- raw_args == '{"queries":[...]}' -> dict via json.loads
Defensive case (JSON string literal of an object):
- raw_args == '"{\\"queries\\":[...]}"' -> json.loads -> str -> json.loads -> dict
Anything else returns {}.
"""
if raw_args is None:
return {}
if isinstance(raw_args, dict):
return raw_args
if not isinstance(raw_args, str):
return {}
try:
parsed1: Any = json.loads(raw_args)
except json.JSONDecodeError:
return {}
if isinstance(parsed1, dict):
return parsed1
if isinstance(parsed1, str):
try:
parsed2: Any = json.loads(parsed1)
except json.JSONDecodeError:
return {}
return parsed2 if isinstance(parsed2, dict) else {}
return {}
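A few illustrative inputs for the helper above (a sketch only; the assertions assume the function is available as shown):
# Normal case: a JSON object string parses straight to a dict.
assert _parse_tool_args_to_dict('{"queries": ["foo"]}') == {"queries": ["foo"]}
# Defensive case: a JSON string literal wrapping an object is unwrapped twice.
assert _parse_tool_args_to_dict('"{\\"queries\\": [\\"foo\\"]}"') == {"queries": ["foo"]}
# Anything else falls back to an empty dict.
assert _parse_tool_args_to_dict(None) == {}
assert _parse_tool_args_to_dict(42) == {}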
TOOL_CALL_MSG_FUNC_NAME = "function_name"
TOOL_CALL_MSG_ARGUMENTS = "arguments"
def _format_message_history_for_logging(
@@ -197,19 +153,21 @@ def _update_tool_call_with_delta(
def _extract_tool_call_kickoffs(
id_to_tool_call_map: dict[int, dict[str, Any]],
turn_index: int,
) -> list[ToolCallKickoff]:
"""Extract ToolCallKickoff objects from the tool call map.
Returns a list of ToolCallKickoff objects for valid tool calls (those with both id and name).
Each tool call is assigned the given turn_index and a tab_index based on its order.
"""
tool_calls: list[ToolCallKickoff] = []
tab_index = 0
for tool_call_data in id_to_tool_call_map.values():
if tool_call_data.get("id") and tool_call_data.get("name"):
try:
tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments"))
# Parse arguments JSON string to dict
tool_args = (
json.loads(tool_call_data["arguments"])
if tool_call_data["arguments"]
else {}
)
except json.JSONDecodeError:
# If parsing fails, fall back to an empty dict; most tools would fail with it though
logger.error(
@@ -222,11 +180,8 @@ def _extract_tool_call_kickoffs(
tool_call_id=tool_call_data["id"],
tool_name=tool_call_data["name"],
tool_args=tool_args,
turn_index=turn_index,
tab_index=tab_index,
)
)
tab_index += 1
return tool_calls
@@ -317,19 +272,13 @@ def translate_history_to_llm_format(
function_name = tool_call_data.get(
TOOL_CALL_MSG_FUNC_NAME, "unknown"
)
raw_args = tool_call_data.get(TOOL_CALL_MSG_ARGUMENTS, {})
tool_args = tool_call_data.get(TOOL_CALL_MSG_ARGUMENTS, {})
else:
function_name = "unknown"
raw_args = (
tool_args = (
tool_call_data if isinstance(tool_call_data, dict) else {}
)
# IMPORTANT: `FunctionCall.arguments` must be a JSON object string.
# If `raw_args` is accidentally a JSON string literal of an object
# (e.g. '"{\\"queries\\":[...]}"'), calling `json.dumps(raw_args)`
# would produce a quoted JSON literal and break Anthropic tool parsing.
tool_args = _parse_tool_args_to_dict(raw_args)
# NOTE: if the model is trained on a different tool call format, this may slightly interfere
# with the future tool calls, if it doesn't look like this. Almost certainly not a big deal.
tool_call = ToolCall(
@@ -375,25 +324,20 @@ def translate_history_to_llm_format(
return messages
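For orientation, a hedged sketch of the round-trip used for stored tool-call messages: the writing side serializes a small JSON object keyed by TOOL_CALL_MSG_FUNC_NAME / TOOL_CALL_MSG_ARGUMENTS, and translate_history_to_llm_format reads it back (the tool name below is made up):
import json

# Writing side (see the llm loop): store function name and arguments together.
stored = json.dumps({"function_name": "run_search", "arguments": {"queries": ["foo"]}})

# Reading side: recover both fields, defaulting when a key is missing.
data = json.loads(stored)
function_name = data.get("function_name", "unknown")
tool_args = data.get("arguments", {})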
def run_llm_step_pkt_generator(
def run_llm_step(
history: list[ChatMessageSimple],
tool_definitions: list[dict],
tool_choice: ToolChoiceOptions,
llm: LLM,
turn_index: int,
citation_processor: DynamicCitationProcessor,
state_container: ChatStateContainer,
citation_processor: DynamicCitationProcessor | None,
reasoning_effort: ReasoningEffort | None = None,
final_documents: list[SearchDoc] | None = None,
user_identity: LLMUserIdentity | None = None,
custom_token_processor: (
Callable[[Delta | None, Any], tuple[Delta | None, Any]] | None
) = None,
) -> Generator[Packet, None, tuple[LlmStepResult, int]]:
# The second return value is for the turn index because reasoning counts on the frontend as a turn
# TODO this is maybe ok but does not align too well with the backend logic
llm_msg_history = translate_history_to_llm_format(history)
has_reasoned = 0
# Set LOG_ONYX_MODEL_INTERACTIONS to log the entire message history to the console
if LOG_ONYX_MODEL_INTERACTIONS:
@@ -407,8 +351,6 @@ def run_llm_step_pkt_generator(
accumulated_reasoning = ""
accumulated_answer = ""
processor_state: Any = None
with generation_span(
model=llm.config.model_name,
model_config={
@@ -424,7 +366,7 @@ def run_llm_step_pkt_generator(
tools=tool_definitions,
tool_choice=tool_choice,
structured_response_format=None, # TODO
reasoning_effort=reasoning_effort,
# reasoning_effort=ReasoningEffort.OFF, # Can set this for dev/testing.
user_identity=user_identity,
):
if packet.usage:
@@ -437,17 +379,6 @@ def run_llm_step_pkt_generator(
}
delta = packet.choice.delta
if custom_token_processor:
# The custom token processor can modify the deltas for specific custom logic
# It can also return a state so that it can handle aggregated delta logic etc.
# Loosely typed so the function can be flexible
modified_delta, processor_state = custom_token_processor(
delta, processor_state
)
if modified_delta is None:
continue
delta = modified_delta
# Should only happen once, frontend does not expect multiple
# ReasoningStart or ReasoningDone packets.
if delta.reasoning_content:
@@ -471,42 +402,32 @@ def run_llm_step_pkt_generator(
turn_index=turn_index,
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index += 1
reasoning_start = False
if not answer_start:
yield Packet(
turn_index=turn_index + has_reasoned,
turn_index=turn_index,
obj=AgentResponseStart(
final_documents=final_documents,
),
)
answer_start = True
if citation_processor:
for result in citation_processor.process_token(delta.content):
if isinstance(result, str):
accumulated_answer += result
# Save answer incrementally to state container
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
turn_index=turn_index + has_reasoned,
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
turn_index=turn_index + has_reasoned,
obj=result,
)
else:
# When citation_processor is None, use delta.content directly without modification
accumulated_answer += delta.content
# Save answer incrementally to state container
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
turn_index=turn_index + has_reasoned,
obj=AgentResponseDelta(content=delta.content),
)
for result in citation_processor.process_token(delta.content):
if isinstance(result, str):
accumulated_answer += result
# Save answer incrementally to state container
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
turn_index=turn_index,
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
turn_index=turn_index,
obj=result,
)
if delta.tool_calls:
if reasoning_start:
@@ -514,22 +435,13 @@ def run_llm_step_pkt_generator(
turn_index=turn_index,
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index += 1
reasoning_start = False
for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
# Flush custom token processor to get any final tool calls
if custom_token_processor:
flush_delta, processor_state = custom_token_processor(None, processor_state)
if flush_delta and flush_delta.tool_calls:
for tool_call_delta in flush_delta.tool_calls:
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
tool_calls = _extract_tool_call_kickoffs(
id_to_tool_call_map, turn_index + has_reasoned
)
tool_calls = _extract_tool_call_kickoffs(id_to_tool_call_map)
if tool_calls:
tool_calls_list: list[ToolCall] = [
ToolCall(
@@ -562,7 +474,7 @@ def run_llm_step_pkt_generator(
turn_index=turn_index,
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index += 1
# Flush any remaining content from citation processor
if citation_processor:
@@ -572,12 +484,12 @@ def run_llm_step_pkt_generator(
# Save answer incrementally to state container
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
turn_index=turn_index + has_reasoned,
turn_index=turn_index,
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
turn_index=turn_index + has_reasoned,
turn_index=turn_index,
obj=result,
)
@@ -602,49 +514,5 @@ def run_llm_step_pkt_generator(
answer=accumulated_answer if accumulated_answer else None,
tool_calls=tool_calls if tool_calls else None,
),
bool(has_reasoned),
turn_index,
)
def run_llm_step(
emitter: "Emitter",
history: list[ChatMessageSimple],
tool_definitions: list[dict],
tool_choice: ToolChoiceOptions,
llm: LLM,
turn_index: int,
state_container: ChatStateContainer,
citation_processor: DynamicCitationProcessor | None,
reasoning_effort: ReasoningEffort | None = None,
final_documents: list[SearchDoc] | None = None,
user_identity: LLMUserIdentity | None = None,
custom_token_processor: (
Callable[[Delta | None, Any], tuple[Delta | None, Any]] | None
) = None,
) -> tuple[LlmStepResult, bool]:
"""Wrapper around run_llm_step_pkt_generator that consumes packets and emits them.
Returns:
tuple[LlmStepResult, bool]: The LLM step result and whether reasoning occurred.
"""
step_generator = run_llm_step_pkt_generator(
history=history,
tool_definitions=tool_definitions,
tool_choice=tool_choice,
llm=llm,
turn_index=turn_index,
state_container=state_container,
citation_processor=citation_processor,
reasoning_effort=reasoning_effort,
final_documents=final_documents,
user_identity=user_identity,
custom_token_processor=custom_token_processor,
)
while True:
try:
packet = next(step_generator)
emitter.emit(packet)
except StopIteration as e:
llm_step_result, has_reasoned = e.value
return llm_step_result, bool(has_reasoned)

View File

@@ -7,6 +7,7 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.chat.chat_milestones import process_multi_assistant_milestone
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_state import run_chat_llm_with_state_containers
from onyx.chat.chat_utils import convert_chat_history
@@ -31,7 +32,6 @@ from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.context.search.models import CitationDocInfo
from onyx.context.search.models import SearchDoc
from onyx.db.chat import create_new_chat_message
@@ -51,8 +51,8 @@ from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import verify_user_files
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.utils import litellm_exception_to_error_msg
@@ -72,7 +72,6 @@ from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_constructor import SearchToolUsage
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import get_current_tenant_id
@@ -368,10 +367,11 @@ def stream_chat_message_objects(
)
# Milestone tracking, most devs using the API don't need to understand this
mt_cloud_telemetry(
process_multi_assistant_milestone(
user=user,
assistant_id=persona.id,
tenant_id=tenant_id,
distinct_id=user.email if user else tenant_id,
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
db_session=db_session,
)
if reference_doc_ids is None and retrieval_options is None:
@@ -379,7 +379,7 @@ def stream_chat_message_objects(
"Must specify a set of documents for chat or specify search options"
)
llm = get_llm_for_persona(
llm, fast_llm = get_llms_for_persona(
persona=persona,
user=user,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
@@ -475,6 +475,7 @@ def stream_chat_message_objects(
emitter=emitter,
user=user,
llm=llm,
fast_llm=fast_llm,
search_tool_config=SearchToolConfig(
user_selected_filters=user_selected_filters,
project_id=(

View File

@@ -102,7 +102,6 @@ def _create_and_link_tool_calls(
if tool_call_info.generated_images
else None
),
tab_index=tool_call_info.tab_index,
add_only=True,
)
@@ -220,8 +219,8 @@ def save_chat_turn(
search_doc_key_to_id[search_doc_key] = db_search_doc.id
search_doc_ids_for_tool.append(db_search_doc.id)
tool_call_to_search_doc_ids[tool_call_info.tool_call_id] = list(
set(search_doc_ids_for_tool)
tool_call_to_search_doc_ids[tool_call_info.tool_call_id] = (
search_doc_ids_for_tool
)
# 3. Collect all unique SearchDoc IDs from all tool calls to link to ChatMessage

View File

@@ -332,6 +332,7 @@ class FileType(str, Enum):
class MilestoneRecordType(str, Enum):
TENANT_CREATED = "tenant_created"
USER_SIGNED_UP = "user_signed_up"
MULTIPLE_USERS = "multiple_users"
VISITED_ADMIN_PAGE = "visited_admin_page"
CREATED_CONNECTOR = "created_connector"
CONNECTOR_SUCCEEDED = "connector_succeeded"

View File

@@ -51,9 +51,10 @@ CROSS_ENCODER_RANGE_MIN = 0
# Generative AI Model Configs
#####
# NOTE: the 2 below should only be used for dev.
# NOTE: the 3 below should only be used for dev.
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY")
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION")
FAST_GEN_AI_MODEL_VERSION = os.environ.get("FAST_GEN_AI_MODEL_VERSION")
# Override the auto-detection of LLM max context length
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None

View File

@@ -40,7 +40,8 @@ from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.logger import setup_logger
@@ -409,7 +410,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
continue
# Handle image files
if file_ext in OnyxFileExtensions.IMAGE_EXTENSIONS:
if is_accepted_file_ext(file_ext, OnyxExtensionType.Multimedia):
if not self._allow_images:
logger.debug(
f"Skipping image file: {key} (image processing not enabled)"

View File

@@ -84,12 +84,6 @@ ONE_DAY = ONE_HOUR * 24
MAX_CACHED_IDS = 100
def _get_page_id(page: dict[str, Any], allow_missing: bool = False) -> str:
if allow_missing and "id" not in page:
return "unknown"
return str(page["id"])
class ConfluenceCheckpoint(ConnectorCheckpoint):
next_page_url: str | None
@@ -305,7 +299,7 @@ class ConfluenceConnector(
page_id = page_url = ""
try:
# Extract basic page information
page_id = _get_page_id(page)
page_id = page["id"]
page_title = page["title"]
logger.info(f"Converting page {page_title} to document")
page_url = build_confluence_document_id(
@@ -388,9 +382,7 @@ class ConfluenceConnector(
this function. The returned documents/ConnectorFailures are for non-inline attachments
and those at the end of the page.
"""
attachment_query = self._construct_attachment_query(
_get_page_id(page), start, end
)
attachment_query = self._construct_attachment_query(page["id"], start, end)
attachment_failures: list[ConnectorFailure] = []
attachment_docs: list[Document] = []
page_url = ""
@@ -438,7 +430,7 @@ class ConfluenceConnector(
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
attachment=attachment,
page_id=_get_page_id(page),
page_id=page["id"],
allow_images=self.allow_images,
)
if response is None:
@@ -523,21 +515,14 @@ class ConfluenceConnector(
except HTTPError as e:
# If we get a 403 after all retries, the user likely doesn't have permission
# to access attachments on this page. Log and skip rather than failing the whole job.
page_id = _get_page_id(page, allow_missing=True)
page_title = page.get("title", "unknown")
if e.response and e.response.status_code in [401, 403]:
failure_message_prefix = (
"Invalid credentials (401)"
if e.response.status_code == 401
else "Permission denied (403)"
)
failure_message = (
f"{failure_message_prefix} when fetching attachments for page '{page_title}' "
if e.response and e.response.status_code == 403:
page_title = page.get("title", "unknown")
page_id = page.get("id", "unknown")
logger.warning(
f"Permission denied (403) when fetching attachments for page '{page_title}' "
f"(ID: {page_id}). The user may not have permission to query attachments on this page. "
"Skipping attachments for this page."
)
logger.warning(failure_message)
# Build the page URL for the failure record
try:
page_url = build_confluence_document_id(
@@ -552,7 +537,7 @@ class ConfluenceConnector(
document_id=page_id,
document_link=page_url,
),
failure_message=failure_message,
failure_message=f"Permission denied (403) when fetching attachments for page '{page_title}'",
exception=e,
)
]
@@ -723,7 +708,7 @@ class ConfluenceConnector(
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
page_id = _get_page_id(page)
page_id = page["id"]
page_restrictions = page.get("restrictions") or {}
page_space_key = page.get("space", {}).get("key")
page_ancestors = page.get("ancestors", [])
@@ -743,7 +728,7 @@ class ConfluenceConnector(
)
# Query attachments for each page
attachment_query = self._construct_attachment_query(_get_page_id(page))
attachment_query = self._construct_attachment_query(page["id"])
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_query,
expand=restrictions_expand,

View File

@@ -24,9 +24,9 @@ from onyx.configs.app_configs import (
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.configs.constants import FileOrigin
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.logger import setup_logger
@@ -56,13 +56,15 @@ def validate_attachment_filetype(
"""
media_type = attachment.get("metadata", {}).get("mediaType", "")
if media_type.startswith("image/"):
return media_type in OnyxMimeTypes.IMAGE_MIME_TYPES
return is_valid_image_type(media_type)
# For non-image files, check if we support the extension
title = attachment.get("title", "")
extension = get_file_ext(title)
extension = Path(title).suffix.lstrip(".").lower() if "." in title else ""
return extension in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS
return is_accepted_file_ext(
"." + extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
)
class AttachmentProcessingResult(BaseModel):

View File

@@ -28,8 +28,10 @@ from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
@@ -68,15 +70,14 @@ def _process_egnyte_file(
file_name = file_metadata["name"]
extension = get_file_ext(file_name)
# Explicitly excluding image extensions here. TODO: consider allowing images
if extension not in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS:
if not is_accepted_file_ext(
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return None
# Extract text content based on file type
# TODO @wenxi-onyx: convert to extract_text_and_images
if extension in OnyxFileExtensions.PLAIN_TEXT_EXTENSIONS:
if is_text_file_extension(file_name):
encoding = detect_encoding(file_content)
file_content_raw, file_metadata = read_text_file(
file_content, encoding=encoding, ignore_onyx_metadata=False
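The Plain | Document combination above reads like a flag-style enum; purely as a hypothetical illustration (the real OnyxExtensionType and is_accepted_file_ext may be defined differently), such a check could look like:
from enum import Flag

class ExtensionType(Flag):  # hypothetical stand-in for OnyxExtensionType
    Plain = 1
    Document = 2
    Multimedia = 4
    All = Plain | Document | Multimedia

# Illustrative extension sets only; not the project's actual lists.
_EXTS_BY_TYPE = {
    ExtensionType.Plain: {".txt", ".md"},
    ExtensionType.Document: {".pdf", ".docx"},
    ExtensionType.Multimedia: {".png", ".jpg"},
}

def is_accepted(ext: str, accepted: ExtensionType) -> bool:
    # An extension passes if any flag set in `accepted` lists it.
    return any(accepted & flag and ext in exts for flag, exts in _EXTS_BY_TYPE.items())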

View File

@@ -18,7 +18,8 @@ from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
@@ -89,7 +90,7 @@ def _process_file(
# Get file extension and determine file type
extension = get_file_ext(file_name)
if extension not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
if not is_accepted_file_ext(extension, OnyxExtensionType.All):
logger.warning(
f"Skipping file '{file_name}' with unrecognized extension '{extension}'"
)
@@ -110,7 +111,7 @@ def _process_file(
title = metadata.get("title") or file_display_name
# 1) If the file itself is an image, handle that scenario quickly
if extension in OnyxFileExtensions.IMAGE_EXTENSIONS:
if extension in LoadConnector.IMAGE_EXTENSIONS:
# Read the image data
image_data = file.read()
if not image_data:

View File

@@ -29,14 +29,14 @@ from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import docx_to_text_and_images
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import pptx_to_text
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import xlsx_to_text
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import (
@@ -114,6 +114,14 @@ def onyx_document_id_from_drive_file(file: GoogleDriveFileType) -> str:
return urlunparse(parsed_url)
def is_gdrive_image_mime_type(mime_type: str) -> bool:
"""
Return True if the mime_type is a common image type in GDrive.
(e.g. 'image/png', 'image/jpeg')
"""
return is_valid_image_type(mime_type)
def download_request(
service: GoogleDriveService, file_id: str, size_threshold: int
) -> bytes:
@@ -165,7 +173,7 @@ def _download_and_extract_sections_basic(
def response_call() -> bytes:
return download_request(service, file_id, size_threshold)
if mime_type in OnyxMimeTypes.IMAGE_MIME_TYPES:
if is_gdrive_image_mime_type(mime_type):
# Skip images if not explicitly enabled
if not allow_images:
return []
@@ -252,7 +260,7 @@ def _download_and_extract_sections_basic(
# Final attempt at extracting text
file_ext = get_file_ext(file.get("name", ""))
if file_ext not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
if file_ext not in ALL_ACCEPTED_FILE_EXTENSIONS:
logger.warning(f"Skipping file {file.get('name')} due to extension.")
return []

View File

@@ -23,8 +23,9 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ACCEPTED_DOCUMENT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -308,7 +309,10 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync)
elif (
is_valid_format
and file_extension in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS
and (
file_extension in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
or file_extension in ACCEPTED_DOCUMENT_FILE_EXTENSIONS
)
and can_download
):
content_response = self.client.get_item_content(item_id)

View File

@@ -27,6 +27,8 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)
class BaseConnector(abc.ABC, Generic[CT]):
REDIS_KEY_PREFIX = "da_connector_data:"
# Common image file extensions supported across connectors
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
@abc.abstractmethod
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:

View File

@@ -1,5 +1,4 @@
import copy
import json
import os
from collections.abc import Callable
from collections.abc import Iterable
@@ -10,7 +9,6 @@ from datetime import timezone
from typing import Any
from jira import JIRA
from jira.exceptions import JIRAError
from jira.resources import Issue
from more_itertools import chunked
from typing_extensions import override
@@ -136,80 +134,6 @@ def _perform_jql_search(
return _perform_jql_search_v2(jira_client, jql, start, max_results, fields)
def _handle_jira_search_error(e: Exception, jql: str) -> None:
"""Handle common Jira search errors and raise appropriate exceptions.
Args:
e: The exception raised by the Jira API
jql: The JQL query that caused the error
Raises:
ConnectorValidationError: For HTTP 400 errors (invalid JQL or project)
CredentialExpiredError: For HTTP 401 errors
InsufficientPermissionsError: For HTTP 403 errors
Exception: Re-raises the original exception for other error types
"""
# Extract error information from the exception
error_text = ""
status_code = None
def _format_error_text(error_payload: Any) -> str:
error_messages = (
error_payload.get("errorMessages", [])
if isinstance(error_payload, dict)
else []
)
if error_messages:
return (
"; ".join(error_messages)
if isinstance(error_messages, list)
else str(error_messages)
)
return str(error_payload)
# Try to get status code and error text from JIRAError or requests response
if hasattr(e, "status_code"):
status_code = e.status_code
raw_text = getattr(e, "text", "")
if isinstance(raw_text, str):
try:
error_text = _format_error_text(json.loads(raw_text))
except Exception:
error_text = raw_text
else:
error_text = str(raw_text)
elif hasattr(e, "response") and e.response is not None:
status_code = e.response.status_code
# Try JSON first, fall back to text
try:
error_json = e.response.json()
error_text = _format_error_text(error_json)
except Exception:
error_text = e.response.text
# Handle specific status codes
if status_code == 400:
if "does not exist for the field 'project'" in error_text:
raise ConnectorValidationError(
f"The specified Jira project does not exist or you don't have access to it. "
f"JQL query: {jql}. Error: {error_text}"
)
raise ConnectorValidationError(
f"Invalid JQL query. JQL: {jql}. Error: {error_text}"
)
elif status_code == 401:
raise CredentialExpiredError(
"Jira credentials are expired or invalid (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
f"Insufficient permissions to execute JQL query. JQL: {jql}"
)
# Re-raise for other error types
raise e
def enhanced_search_ids(
jira_client: JIRA, jql: str, nextPageToken: str | None = None
) -> tuple[list[str], str | None]:
@@ -225,15 +149,8 @@ def enhanced_search_ids(
"nextPageToken": nextPageToken,
"fields": "id",
}
try:
response = jira_client._session.get(enhanced_search_path, params=params)
response.raise_for_status()
response_json = response.json()
except Exception as e:
_handle_jira_search_error(e, jql)
raise # Explicitly re-raise for type checker, should never reach here
return [str(issue["id"]) for issue in response_json["issues"]], response_json.get(
response = jira_client._session.get(enhanced_search_path, params=params).json()
return [str(issue["id"]) for issue in response["issues"]], response.get(
"nextPageToken"
)
@@ -315,16 +232,12 @@ def _perform_jql_search_v2(
f"Fetching Jira issues with JQL: {jql}, "
f"starting at {start}, max results: {max_results}"
)
try:
issues = jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields,
)
except JIRAError as e:
_handle_jira_search_error(e, jql)
raise # Explicitly re-raise for type checker, should never reach here
issues = jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields,
)
for issue in issues:
if isinstance(issue, Issue):

View File

@@ -55,10 +55,12 @@ from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.sharepoint.connector_utils import get_sharepoint_external_access
from onyx.file_processing.extract_file_text import ACCEPTED_IMAGE_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.file_validation import EXCLUDED_IMAGE_TYPES
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
@@ -326,7 +328,7 @@ def _convert_driveitem_to_document_with_permissions(
try:
item_json = driveitem.to_json()
mime_type = item_json.get("file", {}).get("mimeType")
if not mime_type or mime_type in OnyxMimeTypes.EXCLUDED_IMAGE_TYPES:
if not mime_type or mime_type in EXCLUDED_IMAGE_TYPES:
# NOTE: this function should be refactored to look like Drive doc_conversion.py pattern
# for now, this skip must happen before we download the file
# Similar to Google Drive, we'll just semi-silently skip excluded image types
@@ -386,14 +388,14 @@ def _convert_driveitem_to_document_with_permissions(
return None
sections: list[TextSection | ImageSection] = []
file_ext = get_file_ext(driveitem.name)
file_ext = driveitem.name.split(".")[-1]
if not content_bytes:
logger.warning(
f"Zero-length content for '{driveitem.name}'. Skipping text/image extraction."
)
elif file_ext in OnyxFileExtensions.IMAGE_EXTENSIONS:
# NOTE: this if should probably check mime_type instead
elif "." + file_ext in ACCEPTED_IMAGE_FILE_EXTENSIONS:
# NOTE: this if should use is_valid_image_type instead with mime_type
image_section, _ = store_image_and_create_section(
image_data=content_bytes,
file_id=driveitem.id,
@@ -416,7 +418,7 @@ def _convert_driveitem_to_document_with_permissions(
# The only mime type that would be returned by get_image_type_from_bytes that is in
# EXCLUDED_IMAGE_TYPES is image/gif.
if mime_type in OnyxMimeTypes.EXCLUDED_IMAGE_TYPES:
if mime_type in EXCLUDED_IMAGE_TYPES:
logger.debug(
"Skipping embedded image of excluded type %s for %s",
mime_type,
@@ -1504,7 +1506,7 @@ class SharepointConnector(
)
for driveitem in driveitems:
driveitem_extension = get_file_ext(driveitem.name)
if driveitem_extension not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
if not is_accepted_file_ext(driveitem_extension, OnyxExtensionType.All):
logger.warning(
f"Skipping {driveitem.web_url} as it is not a supported file type"
)
@@ -1512,7 +1514,7 @@ class SharepointConnector(
# Only yield empty documents if they are PDFs or images
should_yield_if_empty = (
driveitem_extension in OnyxFileExtensions.IMAGE_EXTENSIONS
driveitem_extension in ACCEPTED_IMAGE_FILE_EXTENSIONS
or driveitem_extension == ".pdf"
)

View File

@@ -18,7 +18,6 @@ from oauthlib.oauth2 import BackendApplicationClient
from playwright.sync_api import BrowserContext
from playwright.sync_api import Playwright
from playwright.sync_api import sync_playwright
from playwright.sync_api import TimeoutError
from requests_oauthlib import OAuth2Session # type:ignore
from urllib3.exceptions import MaxRetryError
@@ -87,8 +86,6 @@ WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
IFRAME_TEXT_LENGTH_THRESHOLD = 700
# Message indicating JavaScript is disabled, which often appears when scraping fails
JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser"
# Grace period after page navigation to allow bot-detection challenges to complete
BOT_DETECTION_GRACE_PERIOD_MS = 5000
# Define common headers that mimic a real browser
DEFAULT_USER_AGENT = (
@@ -557,17 +554,12 @@ class WebConnector(LoadConnector):
page = session_ctx.playwright_context.new_page()
try:
# Use "commit" instead of "domcontentloaded" to avoid hanging on bot-detection pages
# that may never fire domcontentloaded. "commit" waits only for navigation to be
# committed (response received), then we add a short wait for initial rendering.
# Can't use wait_until="networkidle" because it interferes with the scrolling behavior
page_response = page.goto(
initial_url,
timeout=30000, # 30 seconds
wait_until="commit", # Wait for navigation to commit
wait_until="domcontentloaded", # Wait for DOM to be ready
)
# Give the page a moment to start rendering after navigation commits.
# Allows CloudFlare and other bot-detection challenges to complete.
page.wait_for_timeout(BOT_DETECTION_GRACE_PERIOD_MS)
last_modified = (
page_response.header_value("Last-Modified") if page_response else None
@@ -592,15 +584,8 @@ class WebConnector(LoadConnector):
previous_height = page.evaluate("document.body.scrollHeight")
while scroll_attempts < WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS:
page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
# Wait for content to load, but catch timeout if page never reaches networkidle
# (e.g., CloudFlare protection keeps making requests)
try:
page.wait_for_load_state(
"networkidle", timeout=BOT_DETECTION_GRACE_PERIOD_MS
)
except TimeoutError:
# If networkidle times out, just give it a moment for content to render
time.sleep(1)
# wait for the content to load if we scrolled
page.wait_for_load_state("networkidle", timeout=30000)
time.sleep(0.5) # let javascript run
new_height = page.evaluate("document.body.scrollHeight")

View File

@@ -39,7 +39,7 @@ from onyx.federated_connectors.slack.models import SlackEntities
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.onyxbot.slack.models import ChannelType
from onyx.onyxbot.slack.models import SlackContext
from onyx.redis.redis_pool import get_redis_client
@@ -830,8 +830,8 @@ def slack_retrieval(
)
# Query slack with entity filtering
llm = get_default_llm()
query_strings = build_slack_queries(query, llm, entities, available_channels)
_, fast_llm = get_default_llms()
query_strings = build_slack_queries(query, fast_llm, entities, available_channels)
# Determine filtering based on entities OR context (bot)
include_dm = False

View File

@@ -234,6 +234,9 @@ def upsert_llm_provider(
existing_llm_provider.default_model_name = (
llm_provider_upsert_request.default_model_name
)
existing_llm_provider.fast_default_model_name = (
llm_provider_upsert_request.fast_default_model_name
)
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name

View File

@@ -1,9 +1,7 @@
import datetime
from typing import cast
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified
@@ -26,9 +24,7 @@ logger = setup_logger()
# MCPServer operations
def get_all_mcp_servers(db_session: Session) -> list[MCPServer]:
"""Get all MCP servers"""
return list(
db_session.scalars(select(MCPServer).order_by(MCPServer.created_at)).all()
)
return list(db_session.scalars(select(MCPServer)).all())
def get_mcp_server_by_id(server_id: int, db_session: Session) -> MCPServer:
@@ -128,7 +124,6 @@ def update_mcp_server__no_commit(
auth_performer: MCPAuthenticationPerformer | None = None,
transport: MCPTransport | None = None,
status: MCPServerStatus | None = None,
last_refreshed_at: datetime.datetime | None = None,
) -> MCPServer:
"""Update an existing MCP server"""
server = get_mcp_server_by_id(server_id, db_session)
@@ -149,8 +144,6 @@ def update_mcp_server__no_commit(
server.transport = transport
if status is not None:
server.status = status
if last_refreshed_at is not None:
server.last_refreshed_at = last_refreshed_at
db_session.flush() # Don't commit yet, let caller decide when to commit
return server
@@ -337,15 +330,3 @@ def delete_user_connection_configs_for_server(
db_session.delete(config)
db_session.commit()
def delete_all_user_connection_configs_for_server_no_commit(
server_id: int, db_session: Session
) -> None:
"""Delete all user connection configs for a specific MCP server"""
db_session.execute(
delete(MCPConnectionConfig).where(
MCPConnectionConfig.mcp_server_id == server_id
)
)
db_session.flush() # Don't commit yet, let caller decide when to commit
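A short sketch of the *_no_commit convention above, assuming an open db_session and an existing server_id: the helpers only flush, so the caller groups several changes into one transaction and commits once.
# Hedged sketch: db_session and server_id are assumed to exist.
delete_all_user_connection_configs_for_server_no_commit(server_id, db_session)
# ... other flushed changes in the same transaction ...
db_session.commit()  # single commit decided by the caller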

View File

@@ -0,0 +1,99 @@
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified
from onyx.configs.constants import MilestoneRecordType
from onyx.db.models import Milestone
from onyx.db.models import User
USER_ASSISTANT_PREFIX = "user_assistants_used_"
MULTI_ASSISTANT_USED = "multi_assistant_used"
def create_milestone(
user: User | None,
event_type: MilestoneRecordType,
db_session: Session,
) -> Milestone:
milestone = Milestone(
event_type=event_type,
user_id=user.id if user else None,
)
db_session.add(milestone)
db_session.commit()
return milestone
def create_milestone_if_not_exists(
user: User | None, event_type: MilestoneRecordType, db_session: Session
) -> tuple[Milestone, bool]:
# Check if it exists
milestone = db_session.execute(
select(Milestone).where(Milestone.event_type == event_type)
).scalar_one_or_none()
if milestone is not None:
return milestone, False
# If it doesn't exist, try to create it.
try:
milestone = create_milestone(user, event_type, db_session)
return milestone, True
except IntegrityError:
# Another thread or process inserted it in the meantime
db_session.rollback()
# Fetch again to return the existing record
milestone = db_session.execute(
select(Milestone).where(Milestone.event_type == event_type)
).scalar_one() # Now should exist
return milestone, False
def update_user_assistant_milestone(
milestone: Milestone,
user_id: str | None,
assistant_id: int,
db_session: Session,
) -> None:
event_tracker = milestone.event_tracker
if event_tracker is None:
milestone.event_tracker = event_tracker = {}
if event_tracker.get(MULTI_ASSISTANT_USED):
# No need to keep tracking and populating if the milestone has already been hit
return
user_key = f"{USER_ASSISTANT_PREFIX}{user_id}"
if event_tracker.get(user_key) is None:
event_tracker[user_key] = [assistant_id]
elif assistant_id not in event_tracker[user_key]:
event_tracker[user_key].append(assistant_id)
flag_modified(milestone, "event_tracker")
db_session.commit()
def check_multi_assistant_milestone(
milestone: Milestone,
db_session: Session,
) -> tuple[bool, bool]:
"""Returns if the milestone was hit and if it was just hit for the first time"""
event_tracker = milestone.event_tracker
if event_tracker is None:
return False, False
if event_tracker.get(MULTI_ASSISTANT_USED):
return True, False
for key, value in event_tracker.items():
if key.startswith(USER_ASSISTANT_PREFIX) and len(value) > 1:
event_tracker[MULTI_ASSISTANT_USED] = True
flag_modified(milestone, "event_tracker")
db_session.commit()
return True, True
return False, False
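
A minimal sketch of how these milestone helpers compose; the wrapper function and its argument names below are hypothetical and assume a Session, a User, and the relevant MilestoneRecordType member are already in scope:

def record_assistant_usage(
    user: User | None,
    assistant_id: int,
    event_type: MilestoneRecordType,  # whichever member tracks multi-assistant usage
    db_session: Session,
) -> bool:
    # Returns True only the first time the multi-assistant milestone is hit.
    milestone, _created = create_milestone_if_not_exists(user, event_type, db_session)
    update_user_assistant_milestone(
        milestone=milestone,
        user_id=str(user.id) if user else None,
        assistant_id=assistant_id,
        db_session=db_session,
    )
    _hit, just_hit = check_multi_assistant_milestone(milestone, db_session)
    return just_hit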

View File

@@ -2215,8 +2215,6 @@ class ToolCall(Base):
# The tools with the same turn number (and parent) were called in parallel
# Ones with different turn numbers (and same parent) were called sequentially
turn_number: Mapped[int] = mapped_column(Integer)
# Index order of tool calls from the LLM for parallel tool calls
tab_index: Mapped[int] = mapped_column(Integer, default=0)
# Not a FK because we want to be able to delete the tool without deleting
# this entry
@@ -2384,6 +2382,7 @@ class LLMProvider(Base):
postgresql.JSONB(), nullable=True
)
default_model_name: Mapped[str] = mapped_column(String)
fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
@@ -3674,9 +3673,6 @@ class MCPServer(Base):
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
last_refreshed_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
# Relationships
admin_connection_config: Mapped["MCPConnectionConfig | None"] = relationship(
@@ -3689,7 +3685,6 @@ class MCPServer(Base):
"MCPConnectionConfig",
foreign_keys="MCPConnectionConfig.mcp_server_id",
back_populates="mcp_server",
passive_deletes=True,
)
current_actions: Mapped[list["Tool"]] = relationship(
"Tool", back_populates="mcp_server", cascade="all, delete-orphan"
@@ -3918,22 +3913,3 @@ class ExternalGroupPermissionSyncAttempt(Base):
def is_finished(self) -> bool:
return self.status.is_terminal()
class License(Base):
"""Stores the signed license blob (singleton pattern - only one row)."""
__tablename__ = "license"
__table_args__ = (
# Singleton pattern - unique index on constant ensures only one row
Index("idx_license_singleton", text("(true)"), unique=True),
)
id: Mapped[int] = mapped_column(primary_key=True)
license_data: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
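
A hedged sketch of what the singleton unique index above enforces: a second INSERT into license violates idx_license_singleton, so callers read the single row and update it in place (signed_blob is a hypothetical variable):

existing_license = db_session.scalar(select(License))
if existing_license is None:
    db_session.add(License(license_data=signed_blob))
else:
    existing_license.license_data = signed_blob
db_session.commit()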

View File

@@ -221,7 +221,6 @@ def create_tool_call_no_commit(
parent_tool_call_id: int | None = None,
reasoning_tokens: str | None = None,
generated_images: list[dict] | None = None,
tab_index: int = 0,
add_only: bool = True,
) -> ToolCall:
"""
@@ -240,7 +239,6 @@ def create_tool_call_no_commit(
parent_tool_call_id: Optional parent tool call ID (for nested tool calls)
reasoning_tokens: Optional reasoning tokens
generated_images: Optional list of generated image metadata for replay
tab_index: Index order of tool calls from the LLM for parallel tool calls
commit: If True, commit the transaction; if False, flush only
Returns:
@@ -251,7 +249,6 @@ def create_tool_call_no_commit(
parent_chat_message_id=parent_chat_message_id,
parent_tool_call_id=parent_tool_call_id,
turn_number=turn_number,
tab_index=tab_index,
tool_id=tool_id,
tool_call_id=tool_call_id,
reasoning_tokens=reasoning_tokens,

View File

@@ -42,15 +42,6 @@ def fetch_web_search_provider_by_name(
return db_session.scalars(stmt).first()
def fetch_web_search_provider_by_type(
provider_type: WebSearchProviderType, db_session: Session
) -> InternetSearchProvider | None:
stmt = select(InternetSearchProvider).where(
InternetSearchProvider.provider_type == provider_type.value
)
return db_session.scalars(stmt).first()
def _ensure_unique_search_name(
name: str, provider_id: int | None, db_session: Session
) -> None:
@@ -198,15 +189,6 @@ def fetch_web_content_provider_by_name(
return db_session.scalars(stmt).first()
def fetch_web_content_provider_by_type(
provider_type: WebContentProviderType, db_session: Session
) -> InternetContentProvider | None:
stmt = select(InternetContentProvider).where(
InternetContentProvider.provider_type == provider_type.value
)
return db_session.scalars(stmt).first()
def _ensure_unique_content_name(
name: str, provider_id: int | None, db_session: Session
) -> None:

View File

@@ -12,30 +12,18 @@ from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.emitter import Emitter
from onyx.chat.llm_loop import construct_message_history
from onyx.chat.llm_step import run_llm_step
from onyx.chat.llm_step import run_llm_step_pkt_generator
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import LlmStepResult
from onyx.configs.constants import MessageType
from onyx.db.tools import get_tool_by_name
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_TOOL_NAME
from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_MESSAGE
from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_TOKEN_COUNT
from onyx.deep_research.utils import check_special_tool_calls
from onyx.deep_research.utils import create_think_tool_token_processor
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.models import ReasoningEffort
from onyx.llm.models import ToolChoiceOptions
from onyx.llm.utils import model_is_reasoning_model
from onyx.prompts.deep_research.orchestration_layer import CLARIFICATION_PROMPT
from onyx.prompts.deep_research.orchestration_layer import FINAL_REPORT_PROMPT
from onyx.prompts.deep_research.orchestration_layer import ORCHESTRATOR_PROMPT
from onyx.prompts.deep_research.orchestration_layer import ORCHESTRATOR_PROMPT_REASONING
from onyx.prompts.deep_research.orchestration_layer import RESEARCH_PLAN_PROMPT
from onyx.prompts.deep_research.orchestration_layer import USER_FINAL_REPORT_QUERY
from onyx.prompts.prompt_utils import get_current_llm_day_time
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
@@ -43,10 +31,6 @@ from onyx.server.query_and_chat.streaming_models import DeepResearchPlanDelta
from onyx.server.query_and_chat.streaming_models import DeepResearchPlanStart
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.fake_tools.research_agent import run_research_agent_calls
from onyx.tools.models import ToolCallInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
@@ -56,71 +40,7 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
MAX_USER_MESSAGES_FOR_CONTEXT = 5
# Might be something like:
# 1. Research 1-2
# 2. Think
# 3. Research 3-4
# 4. Think
# 5. Research 5-6
# 6. Think
# 7. Research, possibly something new or different from the plan
# 8. Think
# 9. Generate report
MAX_ORCHESTRATOR_CYCLES = 9
# Similar but without the 4 thinking tool calls
MAX_ORCHESTRATOR_CYCLES_REASONING = 5
def generate_final_report(
history: list[ChatMessageSimple],
llm: LLM,
token_counter: Callable[[str], int],
state_container: ChatStateContainer,
emitter: Emitter,
user_identity: LLMUserIdentity | None,
) -> None:
final_report_prompt = FINAL_REPORT_PROMPT.format(
current_datetime=get_current_llm_day_time(full_sentence=False),
)
system_prompt = ChatMessageSimple(
message=final_report_prompt,
token_count=token_counter(final_report_prompt),
message_type=MessageType.SYSTEM,
)
reminder_message = ChatMessageSimple(
message=USER_FINAL_REPORT_QUERY,
token_count=token_counter(USER_FINAL_REPORT_QUERY),
message_type=MessageType.USER,
)
final_report_history = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=history,
reminder_message=reminder_message,
project_files=None,
available_tokens=llm.config.max_input_tokens,
)
llm_step_result, _ = run_llm_step(
emitter=emitter,
history=final_report_history,
tool_definitions=[],
tool_choice=ToolChoiceOptions.NONE,
llm=llm,
turn_index=999, # TODO
citation_processor=None,
state_container=state_container,
reasoning_effort=ReasoningEffort.LOW,
final_documents=None,
user_identity=user_identity,
)
final_report = llm_step_result.answer
if final_report is None:
raise ValueError("LLM failed to generate the final deep research report")
state_container.set_answer_tokens(final_report)
MAX_ORCHESTRATOR_CYCLES = 8
def run_deep_research_llm_loop(
@@ -153,8 +73,7 @@ def run_deep_research_llm_loop(
# Filter tools to only allow web search, internal search, and open URL
allowed_tool_names = {SearchTool.NAME, WebSearchTool.NAME, OpenURLTool.NAME}
allowed_tools = [tool for tool in tools if tool.name in allowed_tool_names]
orchestrator_start_turn_index = 0
[tool for tool in tools if tool.name in allowed_tool_names]
#########################################################
# CLARIFICATION STEP (optional)
@@ -179,8 +98,7 @@ def run_deep_research_llm_loop(
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
)
llm_step_result, _ = run_llm_step(
emitter=emitter,
step_generator = run_llm_step(
history=truncated_message_history,
tool_definitions=get_clarification_tool_definitions(),
tool_choice=ToolChoiceOptions.AUTO,
@@ -188,12 +106,24 @@ def run_deep_research_llm_loop(
turn_index=0,
# No citations in this step, it should just pass through all
# tokens directly so initialized as an empty citation processor
citation_processor=None,
citation_processor=DynamicCitationProcessor(),
state_container=state_container,
final_documents=None,
user_identity=user_identity,
)
# Consume the generator, emitting packets and capturing the final result
while True:
try:
packet = next(step_generator)
emitter.emit(packet)
except StopIteration as e:
llm_step_result, _ = e.value
break
# Type narrowing: generator always returns a result, so this can't be None
llm_step_result = cast(LlmStepResult, llm_step_result)
if not llm_step_result.tool_calls:
# Mark this turn as a clarification question
state_container.set_is_clarification(True)
@@ -224,13 +154,15 @@ def run_deep_research_llm_loop(
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
)
research_plan_generator = run_llm_step_pkt_generator(
research_plan_generator = run_llm_step(
history=truncated_message_history,
tool_definitions=[],
tool_choice=ToolChoiceOptions.NONE,
llm=llm,
turn_index=0,
citation_processor=None,
# No citations in this step, it should just pass through all
# tokens directly so initialized as an empty citation processor
citation_processor=DynamicCitationProcessor(),
state_container=state_container,
final_documents=None,
user_identity=user_identity,
@@ -240,7 +172,6 @@ def run_deep_research_llm_loop(
try:
packet = next(research_plan_generator)
# Translate AgentResponseStart/Delta packets to DeepResearchPlanStart/Delta
# The LLM response from this prompt is the research plan
if isinstance(packet.obj, AgentResponseStart):
emitter.emit(
Packet(
@@ -259,16 +190,7 @@ def run_deep_research_llm_loop(
# Pass through other packet types (e.g., ReasoningStart, ReasoningDelta, etc.)
emitter.emit(packet)
except StopIteration as e:
llm_step_result, orchestrator_start_turn_index = e.value
# TODO: All that is done with the plan is for streaming to the frontend and informing the flow
# Currently not saved. It would have to be saved as a ToolCall for a new tool type.
emitter.emit(
Packet(
# Marks the last turn end which should be the plan generation
turn_index=orchestrator_start_turn_index - 1,
obj=SectionEnd(),
)
)
llm_step_result, _ = e.value
break
llm_step_result = cast(LlmStepResult, llm_step_result)
@@ -281,12 +203,6 @@ def run_deep_research_llm_loop(
llm.config.model_name, llm.config.model_provider
)
max_orchestrator_cycles = (
MAX_ORCHESTRATOR_CYCLES
if not is_reasoning_model
else MAX_ORCHESTRATOR_CYCLES_REASONING
)
orchestrator_prompt_template = (
ORCHESTRATOR_PROMPT if not is_reasoning_model else ORCHESTRATOR_PROMPT_REASONING
)
@@ -294,19 +210,16 @@ def run_deep_research_llm_loop(
token_count_prompt = orchestrator_prompt_template.format(
current_datetime=get_current_llm_day_time(full_sentence=False),
current_cycle_count=1,
max_cycles=max_orchestrator_cycles,
max_cycles=MAX_ORCHESTRATOR_CYCLES,
research_plan=research_plan,
)
orchestration_tokens = token_counter(token_count_prompt)
reasoning_cycles = 0
for cycle in range(max_orchestrator_cycles):
research_agent_calls: list[ToolCallKickoff] = []
for cycle in range(MAX_ORCHESTRATOR_CYCLES):
orchestrator_prompt = orchestrator_prompt_template.format(
current_datetime=get_current_llm_day_time(full_sentence=False),
current_cycle_count=cycle,
max_cycles=max_orchestrator_cycles,
max_cycles=MAX_ORCHESTRATOR_CYCLES,
research_plan=research_plan,
)
@@ -326,163 +239,16 @@ def run_deep_research_llm_loop(
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
)
# Use think tool processor for non-reasoning models to convert
# think_tool calls to reasoning content
custom_processor = (
create_think_tool_token_processor() if not is_reasoning_model else None
)
llm_step_result, has_reasoned = run_llm_step(
emitter=emitter,
research_plan_generator = run_llm_step(
history=truncated_message_history,
tool_definitions=get_orchestrator_tools(
include_think_tool=not is_reasoning_model
),
tool_choice=ToolChoiceOptions.REQUIRED,
tool_definitions=[],
tool_choice=ToolChoiceOptions.AUTO,
llm=llm,
turn_index=orchestrator_start_turn_index + cycle + reasoning_cycles,
turn_index=cycle,
# No citations in this step, it should just pass through all
# tokens directly so initialized as an empty citation processor
citation_processor=DynamicCitationProcessor(),
state_container=state_container,
final_documents=None,
user_identity=user_identity,
custom_token_processor=custom_processor,
)
if has_reasoned:
reasoning_cycles += 1
tool_calls = llm_step_result.tool_calls or []
if not tool_calls and cycle == 0:
raise RuntimeError(
"Deep Research failed to generate any research tasks for the agents."
)
if not tool_calls:
logger.warning("No tool calls found, this should not happen.")
generate_final_report(
history=simple_chat_history,
llm=llm,
token_counter=token_counter,
state_container=state_container,
emitter=emitter,
user_identity=user_identity,
)
break
most_recent_reasoning: str | None = None
special_tool_calls = check_special_tool_calls(tool_calls=tool_calls)
if special_tool_calls.generate_report_tool_call:
generate_final_report(
history=simple_chat_history,
llm=llm,
token_counter=token_counter,
state_container=state_container,
emitter=emitter,
user_identity=user_identity,
)
break
elif special_tool_calls.think_tool_call:
think_tool_call = special_tool_calls.think_tool_call
# Only process the THINK_TOOL and skip all other tool calls
# This will not actually get saved to the db as a tool call but we'll attach it to the tool(s) called after
# it as if it were just a reasoning model doing it. In the chat history, because it happens in 2 steps,
# we will show it as a separate message.
most_recent_reasoning = state_container.reasoning_tokens
tool_call_message = think_tool_call.to_msg_str()
think_tool_msg = ChatMessageSimple(
message=tool_call_message,
token_count=token_counter(tool_call_message),
message_type=MessageType.TOOL_CALL,
tool_call_id=think_tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(think_tool_msg)
think_tool_response_msg = ChatMessageSimple(
message=THINK_TOOL_RESPONSE_MESSAGE,
token_count=THINK_TOOL_RESPONSE_TOKEN_COUNT,
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=think_tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(think_tool_response_msg)
reasoning_cycles += 1
continue
else:
for tool_call in tool_calls:
if tool_call.tool_name != RESEARCH_AGENT_TOOL_NAME:
logger.warning(f"Unexpected tool call: {tool_call.tool_name}")
continue
research_agent_calls.append(tool_call)
if not research_agent_calls:
logger.warning(
"No research agent tool calls found, this should not happen."
)
generate_final_report(
history=simple_chat_history,
llm=llm,
token_counter=token_counter,
state_container=state_container,
emitter=emitter,
user_identity=user_identity,
)
break
research_results = run_research_agent_calls(
research_agent_calls=research_agent_calls,
tools=allowed_tools,
emitter=emitter,
state_container=state_container,
llm=llm,
is_reasoning_model=is_reasoning_model,
token_counter=token_counter,
user_identity=user_identity,
)
for tab_index, research_result in enumerate(research_results):
tool_call_info = ToolCallInfo(
parent_tool_call_id=None,
turn_index=orchestrator_start_turn_index + cycle + reasoning_cycles,
tab_index=tab_index,
tool_name=tool_call.tool_name,
tool_call_id=tool_call.tool_call_id,
tool_id=get_tool_by_name(
tool_name=RESEARCH_AGENT_DB_NAME, db_session=db_session
).id,
reasoning_tokens=most_recent_reasoning,
tool_call_arguments=tool_call.tool_args,
tool_call_response=research_result.intermediate_report,
search_docs=research_result.search_docs,
generated_images=None,
)
state_container.add_tool_call(tool_call_info)
tool_call_message = tool_call.to_msg_str()
tool_call_token_count = token_counter(tool_call_message)
tool_call_msg = ChatMessageSimple(
message=tool_call_message,
token_count=tool_call_token_count,
message_type=MessageType.TOOL_CALL,
tool_call_id=tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(tool_call_msg)
tool_call_response_msg = ChatMessageSimple(
message=research_result.intermediate_report,
token_count=token_counter(research_result.intermediate_report),
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(tool_call_response_msg)
if not special_tool_calls.think_tool_call:
most_recent_reasoning = None

View File

@@ -1,139 +1,18 @@
GENERATE_PLAN_TOOL_NAME = "generate_plan"
RESEARCH_AGENT_DB_NAME = "ResearchAgent"
RESEARCH_AGENT_TOOL_NAME = "research_agent"
RESEARCH_AGENT_TASK_KEY = "task"
GENERATE_REPORT_TOOL_NAME = "generate_report"
THINK_TOOL_NAME = "think_tool"
# ruff: noqa: E501, W605 start
GENERATE_PLAN_TOOL_DESCRIPTION = {
"type": "function",
"function": {
"name": GENERATE_PLAN_TOOL_NAME,
"description": "No clarification needed, generate a research plan for the user's query.",
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
},
}
RESEARCH_AGENT_TOOL_DESCRIPTION = {
"type": "function",
"function": {
"name": RESEARCH_AGENT_TOOL_NAME,
"description": "Conduct research on a specific topic.",
"parameters": {
"type": "object",
"properties": {
RESEARCH_AGENT_TASK_KEY: {
"type": "string",
"description": "The research task to investigate, should be 1-2 descriptive sentences outlining the direction of investigation.",
}
},
"required": [RESEARCH_AGENT_TASK_KEY],
},
},
}
GENERATE_REPORT_TOOL_DESCRIPTION = {
"type": "function",
"function": {
"name": GENERATE_REPORT_TOOL_NAME,
"description": "Generate the final research report from all of the findings. Should be called when all aspects of the user's query have been researched, or maximum cycles are reached.",
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
},
}
THINK_TOOL_DESCRIPTION = {
"type": "function",
"function": {
"name": THINK_TOOL_NAME,
"description": "Use this for reasoning between research_agent calls and before calling generate_report. Think deeply about key results, identify knowledge gaps, and plan next steps.",
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": "Your chain of thought reasoning, use paragraph format, no lists.",
}
},
"required": ["reasoning"],
},
},
}
RESEARCH_AGENT_THINK_TOOL_DESCRIPTION = {
"type": "function",
"function": {
"name": "think_tool",
"description": "Use this for reasoning between research steps. Think deeply about key results, identify knowledge gaps, and plan next steps.",
"parameters": {
"type": "object",
"properties": {
"reasoning": {
"type": "string",
"description": "Your chain of thought reasoning, can be as long as a lengthy paragraph.",
}
},
"required": ["reasoning"],
},
},
}
RESEARCH_AGENT_GENERATE_REPORT_TOOL_DESCRIPTION = {
"type": "function",
"function": {
"name": "generate_report",
"description": "Generate the final research report from all findings. Should be called when research is complete.",
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
},
}
THINK_TOOL_RESPONSE_MESSAGE = "Acknowledged, please continue."
THINK_TOOL_RESPONSE_TOKEN_COUNT = 10
def get_clarification_tool_definitions() -> list[dict]:
return [GENERATE_PLAN_TOOL_DESCRIPTION]
def get_orchestrator_tools(include_think_tool: bool) -> list[dict]:
tools = [
RESEARCH_AGENT_TOOL_DESCRIPTION,
GENERATE_REPORT_TOOL_DESCRIPTION,
return [
{
"type": "function",
"function": {
"name": GENERATE_PLAN_TOOL_NAME,
"description": "No clarification needed, generate a research plan for the user's query.",
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
},
}
]
if include_think_tool:
tools.append(THINK_TOOL_DESCRIPTION)
return tools
def get_research_agent_additional_tool_definitions(
include_think_tool: bool,
) -> list[dict]:
tools = [GENERATE_REPORT_TOOL_DESCRIPTION]
if include_think_tool:
tools.append(RESEARCH_AGENT_THINK_TOOL_DESCRIPTION)
return tools
# ruff: noqa: E501, W605 end
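
For reference, a small sketch of how the fuller variant shown above (the 139-line side of this hunk) behaves; it assumes that variant, where the orchestrator tool set depends on whether the model already reasons natively:

orchestrator_tools = get_orchestrator_tools(include_think_tool=True)
assert {t["function"]["name"] for t in orchestrator_tools} == {
    RESEARCH_AGENT_TOOL_NAME,
    GENERATE_REPORT_TOOL_NAME,
    THINK_TOOL_NAME,
}
# Reasoning models skip the explicit think_tool and get only
# research_agent + generate_report.
assert len(get_orchestrator_tools(include_think_tool=False)) == 2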

View File

@@ -1,14 +0,0 @@
from pydantic import BaseModel
from onyx.context.search.models import SearchDoc
from onyx.tools.models import ToolCallKickoff
class SpecialToolCalls(BaseModel):
think_tool_call: ToolCallKickoff | None = None
generate_report_tool_call: ToolCallKickoff | None = None
class ResearchAgentCallResult(BaseModel):
intermediate_report: str
search_docs: list[SearchDoc]

View File

@@ -1,168 +0,0 @@
from collections.abc import Callable
from typing import Any
from pydantic import BaseModel
from onyx.deep_research.dr_mock_tools import GENERATE_REPORT_TOOL_NAME
from onyx.deep_research.dr_mock_tools import THINK_TOOL_NAME
from onyx.deep_research.models import SpecialToolCalls
from onyx.llm.model_response import ChatCompletionDeltaToolCall
from onyx.llm.model_response import Delta
from onyx.llm.model_response import FunctionCall
from onyx.tools.models import ToolCallKickoff
# JSON prefixes to detect in think_tool arguments
# The schema is: {"reasoning": "...content..."}
JSON_PREFIX_WITH_SPACE = '{"reasoning": "'
JSON_PREFIX_NO_SPACE = '{"reasoning":"'
class ThinkToolProcessorState(BaseModel):
"""State for tracking think tool processing across streaming deltas."""
think_tool_found: bool = False
think_tool_index: int | None = None
think_tool_id: str | None = None
full_arguments: str = "" # Full accumulated arguments for final tool call
accumulated_args: str = "" # Working buffer for JSON parsing
json_prefix_stripped: bool = False
# Buffer holds content that might be the JSON suffix "}
# We hold back 2 chars to avoid emitting the closing "}
buffer: str = ""
def _extract_reasoning_chunk(state: ThinkToolProcessorState) -> str | None:
"""
Extract reasoning content from accumulated arguments, stripping JSON wrapper.
Returns the next chunk of reasoning to emit, or None if nothing to emit yet.
"""
# If we haven't found the JSON prefix yet, look for it
if not state.json_prefix_stripped:
# Try both prefix variants
for prefix in [JSON_PREFIX_WITH_SPACE, JSON_PREFIX_NO_SPACE]:
prefix_pos = state.accumulated_args.find(prefix)
if prefix_pos != -1:
# Found prefix - extract content after it
content_start = prefix_pos + len(prefix)
state.buffer = state.accumulated_args[content_start:]
state.accumulated_args = ""
state.json_prefix_stripped = True
break
if not state.json_prefix_stripped:
# Haven't seen full prefix yet, keep accumulating
return None
else:
# Already stripped prefix, add new content to buffer
state.buffer += state.accumulated_args
state.accumulated_args = ""
# Hold back last 2 chars in case they're the JSON suffix "}
if len(state.buffer) <= 2:
return None
# Emit everything except last 2 chars
to_emit = state.buffer[:-2]
state.buffer = state.buffer[-2:]
return to_emit if to_emit else None
def create_think_tool_token_processor() -> (
Callable[[Delta | None, Any], tuple[Delta | None, Any]]
):
"""
Create a custom token processor that converts think_tool calls to reasoning content.
When the think_tool is detected:
- Tool call arguments are converted to reasoning_content (JSON wrapper stripped)
- All other deltas (content, other tool calls) are dropped
This allows non-reasoning models to emit chain-of-thought via the think_tool,
which gets displayed as reasoning tokens in the UI.
Returns:
A function compatible with run_llm_step_pkt_generator's custom_token_processor parameter.
The function takes (Delta, state) and returns (modified Delta | None, new state).
"""
def process_token(delta: Delta | None, state: Any) -> tuple[Delta | None, Any]:
if state is None:
state = ThinkToolProcessorState()
# Handle flush signal (delta=None) - emit the complete tool call
if delta is None:
if state.think_tool_found and state.think_tool_id:
# Return the complete think tool call
complete_tool_call = ChatCompletionDeltaToolCall(
id=state.think_tool_id,
index=state.think_tool_index or 0,
type="function",
function=FunctionCall(
name=THINK_TOOL_NAME,
arguments=state.full_arguments,
),
)
return Delta(tool_calls=[complete_tool_call]), state
return None, state
# Check for think tool in tool_calls
if delta.tool_calls:
for tool_call in delta.tool_calls:
# Detect think tool by name
if tool_call.function and tool_call.function.name == THINK_TOOL_NAME:
state.think_tool_found = True
state.think_tool_index = tool_call.index
# Capture tool call id when available
if (
state.think_tool_found
and tool_call.index == state.think_tool_index
and tool_call.id
):
state.think_tool_id = tool_call.id
# Accumulate arguments for the think tool
if (
state.think_tool_found
and tool_call.index == state.think_tool_index
and tool_call.function
and tool_call.function.arguments
):
# Track full arguments for final tool call
state.full_arguments += tool_call.function.arguments
# Also accumulate for JSON parsing
state.accumulated_args += tool_call.function.arguments
# Try to extract reasoning content
reasoning_chunk = _extract_reasoning_chunk(state)
if reasoning_chunk:
# Return delta with reasoning_content to trigger reasoning streaming
return Delta(reasoning_content=reasoning_chunk), state
# If think tool found, drop all other content
if state.think_tool_found:
return None, state
# No think tool detected, pass through original delta
return delta, state
return process_token
def check_special_tool_calls(tool_calls: list[ToolCallKickoff]) -> SpecialToolCalls:
think_tool_call: ToolCallKickoff | None = None
generate_report_tool_call: ToolCallKickoff | None = None
for tool_call in tool_calls:
if tool_call.tool_name == THINK_TOOL_NAME:
think_tool_call = tool_call
elif tool_call.tool_name == GENERATE_REPORT_TOOL_NAME:
generate_report_tool_call = tool_call
return SpecialToolCalls(
think_tool_call=think_tool_call,
generate_report_tool_call=generate_report_tool_call,
)
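
A minimal sketch of the processor in action, assuming the Delta, ChatCompletionDeltaToolCall, and FunctionCall constructors accept the keyword arguments used in the code above:

processor = create_think_tool_token_processor()

think_delta = Delta(
    tool_calls=[
        ChatCompletionDeltaToolCall(
            id="call_0",
            index=0,
            type="function",
            function=FunctionCall(
                name=THINK_TOOL_NAME,
                arguments='{"reasoning": "Key findings so far..."}',
            ),
        )
    ]
)

# The JSON wrapper is stripped and the reasoning text is re-emitted as
# reasoning_content (the last two characters are held back in case they
# are the closing quote/brace).
out, state = processor(think_delta, None)
assert out is not None and out.reasoning_content

# A flush call (delta=None) re-emits the complete think_tool call so it can
# still be recorded downstream.
flushed, state = processor(None, state)
assert flushed is not None and flushed.tool_calls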

View File

@@ -1,4 +1,6 @@
import abc
from collections.abc import Iterator
from typing import Any
from pydantic import BaseModel
@@ -10,11 +12,13 @@ from onyx.db.enums import EmbeddingPrecision
from onyx.indexing.models import DocMetadataAwareIndexChunk
from shared_configs.model_server_models import Embedding
# NOTE: "Document" in the naming convention is used to refer to the entire
# document as represented in Onyx. What is actually stored in the index is the
# document chunks. By the terminology of most search engines / vector databases,
# the individual objects stored are called documents, but in this case it refers
# to a chunk.
# NOTE: "Document" in the naming convention is used to refer to the entire document as represented in Onyx.
# What is actually stored in the index is the document chunks. By the terminology of most search engines / vector
# databases, the individual objects stored are called documents, but in this case it refers to a chunk.
# Outside of searching and update capabilities, the document index must also implement the ability to port all of
# the documents over to a secondary index. This allows for embedding models to be updated and for porting documents
# to happen in the background while the primary index still serves the main traffic.
__all__ = [
@@ -38,7 +42,7 @@ __all__ = [
class DocumentInsertionRecord(BaseModel):
"""
Result of indexing a document.
Result of indexing a document
"""
model_config = {"frozen": True}
@@ -48,11 +52,10 @@ class DocumentInsertionRecord(BaseModel):
class DocumentSectionRequest(BaseModel):
"""Request for a document section or whole document.
If no min_chunk_ind is provided it should start at the beginning of the
document.
If no max_chunk_ind is provided it should go to the end of the document.
"""
Request for a document section or whole document
If no min_chunk_ind is provided it should start at the beginning of the document
If no max_chunk_ind is provided it should go to the end of the document
"""
model_config = {"frozen": True}
@@ -64,37 +67,24 @@ class DocumentSectionRequest(BaseModel):
class IndexingMetadata(BaseModel):
"""
Information about chunk counts for efficient cleaning / updating of document
chunks.
A common pattern to ensure that no chunks are left over is to delete all of
the chunks for a document and then re-index the document. This information
allows us to only delete the extra "tail" chunks when the document has
gotten shorter.
Information about chunk counts for efficient cleaning / updating of document chunks. A common pattern to ensure
that no chunks are left over is to delete all of the chunks for a document and then re-index the document. This
information allows us to only delete the extra "tail" chunks when the document has gotten shorter.
"""
class ChunkCounts(BaseModel):
model_config = {"frozen": True}
old_chunk_cnt: int
new_chunk_cnt: int
model_config = {"frozen": True}
doc_id_to_chunk_cnt_diff: dict[str, ChunkCounts]
# The tuple is (old_chunk_cnt, new_chunk_cnt)
doc_id_to_chunk_cnt_diff: dict[str, tuple[int, int]]
class MetadataUpdateRequest(BaseModel):
"""
Updates to the documents that can happen without there being an update to
the contents of the document.
Updates to the documents that can happen without there being an update to the contents of the document.
"""
document_ids: list[str]
# Passed in to help with potential optimizations of the implementation. The
# keys should be redundant with document_ids.
# Passed in to help with potential optimizations of the implementation
doc_id_to_chunk_cnt: dict[str, int]
# For the ones that are None, there is no update required to that field.
# For the ones that are None, there is no update required to that field
access: DocumentAccess | None = None
document_sets: set[str] | None = None
boost: float | None = None
@@ -110,6 +100,17 @@ class SchemaVerifiable(abc.ABC):
all valid in the schema.
"""
def __init__(
self,
index_name: str,
tenant_id: int | None,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.index_name = index_name
self.tenant_id = tenant_id
@abc.abstractmethod
def verify_and_create_index_if_necessary(
self,
@@ -130,40 +131,34 @@ class SchemaVerifiable(abc.ABC):
class Indexable(abc.ABC):
"""
Class must implement the ability to index document chunks.
Class must implement the ability to index document chunks
"""
@abc.abstractmethod
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
chunks: Iterator[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
) -> set[DocumentInsertionRecord]:
"""
Takes a list of document chunks and indexes them in the document index.
Takes a list of document chunks and indexes them in the document index. This is often a batch operation
including chunks from multiple documents.
This is often a batch operation including chunks from multiple
documents.
NOTE: When a document is reindexed/updated here and has gotten shorter, it is important to delete the extra
chunks at the end to ensure there are no stale chunks in the index.
NOTE: When a document is reindexed/updated here and has gotten shorter,
it is important to delete the extra chunks at the end to ensure there
are no stale chunks in the index. The implementation should do this.
NOTE: The chunks of a document are never separated into separate index() calls. So there is
no worry of receiving the first 0 through n chunks in one index call and the next n through
m chunks of a document in the next index call.
NOTE: The chunks of a document are never separated into separate index()
calls. So there is no worry of receiving the first 0 through n chunks in
one index call and the next n through m chunks of a document in the next
index call.
Args:
chunks: Document chunks with all of the information needed for
indexing to the document index.
indexing_metadata: Information about chunk counts for efficient
cleaning / updating.
Parameters:
- chunks: Document chunks with all of the information needed for indexing to the document index.
- indexing_metadata: Information about chunk counts for efficient cleaning / updating
Returns:
List of document ids which map to unique documents and are used for
deduping chunks when updating, as well as if the document is newly
indexed or already existed and just updated.
List of document ids which map to unique documents and are used for deduping chunks
when updating, as well as if the document is newly indexed or already existed and
just updated
"""
raise NotImplementedError
@@ -178,6 +173,7 @@ class Deletable(abc.ABC):
def delete(
self,
db_doc_id: str,
*,
# Passed in in case it helps the efficiency of the delete implementation
chunk_count: int | None,
) -> int:
@@ -315,10 +311,10 @@ class DocumentIndex(
abc.ABC,
):
"""
A valid document index that can plug into all Onyx flows must implement all
of these functionalities.
A valid document index that can plug into all Onyx flows must implement all of these
functionalities.
As a high-level summary, document indices need to be able to:
As a high level summary, document indices need to be able to
- Verify the schema definition is valid
- Index new documents
- Update specific attributes of existing documents

View File

@@ -39,14 +39,6 @@ def delete_vespa_chunks(
http_client: httpx.Client,
executor: concurrent.futures.ThreadPoolExecutor | None = None,
) -> None:
"""Deletes a list of chunks from a Vespa index in parallel.
Args:
doc_chunk_ids: List of chunk IDs to delete.
index_name: Name of the index to delete from.
http_client: HTTP client to use for the request.
executor: Executor to use for the request.
"""
external_executor = True
if not executor:

View File

@@ -36,9 +36,7 @@ from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.document_index_utils import get_document_chunk_ids
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import (
DocumentInsertionRecord as OldDocumentInsertionRecord,
)
from onyx.document_index.interfaces import DocumentInsertionRecord
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
from onyx.document_index.interfaces import IndexBatchParams
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
@@ -46,7 +44,6 @@ from onyx.document_index.interfaces import UpdateRequest
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.vespa.chunk_retrieval import batch_search_api_retrieval
from onyx.document_index.vespa.chunk_retrieval import (
parallel_visit_api_retrieval,
@@ -54,7 +51,9 @@ from onyx.document_index.vespa.chunk_retrieval import (
from onyx.document_index.vespa.chunk_retrieval import query_vespa
from onyx.document_index.vespa.deletion import delete_vespa_chunks
from onyx.document_index.vespa.indexing_utils import BaseHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import batch_index_vespa_chunks
from onyx.document_index.vespa.indexing_utils import check_for_final_chunk_existence
from onyx.document_index.vespa.indexing_utils import clean_chunk_id_copy
from onyx.document_index.vespa.indexing_utils import GlobalHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import TemporaryHTTPXClientContext
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
@@ -64,8 +63,6 @@ from onyx.document_index.vespa.shared_utils.utils import (
from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
build_vespa_filters,
)
from onyx.document_index.vespa.vespa_document_index import TenantState
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import BATCH_SIZE
from onyx.document_index.vespa_constants import BOOST
@@ -256,10 +253,8 @@ class VespaIndex(DocumentIndex):
self.multitenant = multitenant
# Temporary until we refactor the entirety of this class.
self.httpx_client = httpx_client
self.httpx_client_context: BaseHTTPXClientContext
if httpx_client:
self.httpx_client_context = GlobalHTTPXClientContext(httpx_client)
else:
@@ -480,45 +475,92 @@ class VespaIndex(DocumentIndex):
self,
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
if len(index_batch_params.doc_id_to_previous_chunk_cnt) != len(
index_batch_params.doc_id_to_new_chunk_cnt
):
raise ValueError("Bug: Length of doc ID to chunk maps does not match.")
doc_id_to_chunk_cnt_diff = {
doc_id: IndexingMetadata.ChunkCounts(
old_chunk_cnt=index_batch_params.doc_id_to_previous_chunk_cnt[doc_id],
new_chunk_cnt=index_batch_params.doc_id_to_new_chunk_cnt[doc_id],
) -> set[DocumentInsertionRecord]:
"""Receive a list of chunks from a batch of documents and index the chunks into Vespa along
with updating the associated permissions. Assumes that a document will not be split into
multiple chunk batches calling this function multiple times, otherwise only the last set of
chunks will be kept"""
doc_id_to_previous_chunk_cnt = index_batch_params.doc_id_to_previous_chunk_cnt
doc_id_to_new_chunk_cnt = index_batch_params.doc_id_to_new_chunk_cnt
tenant_id = index_batch_params.tenant_id
large_chunks_enabled = index_batch_params.large_chunks_enabled
# IMPORTANT: This must be done one index at a time, do not use secondary index here
cleaned_chunks = [clean_chunk_id_copy(chunk) for chunk in chunks]
# needed so the final DocumentInsertionRecord returned can have the original document ID
new_document_id_to_original_document_id: dict[str, str] = {}
for ind, chunk in enumerate(cleaned_chunks):
old_chunk = chunks[ind]
new_document_id_to_original_document_id[chunk.source_document.id] = (
old_chunk.source_document.id
)
for doc_id in index_batch_params.doc_id_to_previous_chunk_cnt.keys()
}
indexing_metadata = IndexingMetadata(
doc_id_to_chunk_cnt_diff=doc_id_to_chunk_cnt_diff,
)
vespa_document_index = VespaDocumentIndex(
index_name=self.index_name,
tenant_state=TenantState(
tenant_id=index_batch_params.tenant_id,
multitenant=self.multitenant,
),
large_chunks_enabled=self.large_chunks_enabled,
httpx_client=self.httpx_client,
)
# This conversion from list to set, only for it to be converted back to a
# list upstream, is suboptimal and only temporary until we refactor the
# entirety of this class.
document_insertion_records = vespa_document_index.index(
chunks, indexing_metadata
)
return set(
[
OldDocumentInsertionRecord(
document_id=doc_insertion_record.document_id,
already_existed=doc_insertion_record.already_existed,
existing_docs: set[str] = set()
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
# indexing / updates / deletes since we have to make a large volume of requests.
with (
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
self.httpx_client_context as http_client,
):
# We require the start and end index for each document in order to
# know precisely which chunks to delete. This information exists for
# documents that have `chunk_count` in the database, but not for
# `old_version` documents.
enriched_doc_infos: list[EnrichedDocumentIndexingInfo] = [
VespaIndex.enrich_basic_chunk_info(
index_name=self.index_name,
http_client=http_client,
document_id=doc_id,
previous_chunk_count=doc_id_to_previous_chunk_cnt.get(doc_id, 0),
new_chunk_count=doc_id_to_new_chunk_cnt.get(doc_id, 0),
)
for doc_insertion_record in document_insertion_records
for doc_id in doc_id_to_new_chunk_cnt.keys()
]
)
for cleaned_doc_info in enriched_doc_infos:
# If the document has previously indexed chunks, we know it previously existed
if cleaned_doc_info.chunk_end_index:
existing_docs.add(cleaned_doc_info.doc_id)
# Now, for each doc, we know exactly where to start and end our deletion
# So let's generate the chunk IDs for each chunk to delete
chunks_to_delete = get_document_chunk_ids(
enriched_document_info_list=enriched_doc_infos,
tenant_id=tenant_id,
large_chunks_enabled=large_chunks_enabled,
)
# Delete old Vespa documents
for doc_chunk_ids_batch in batch_generator(chunks_to_delete, BATCH_SIZE):
delete_vespa_chunks(
doc_chunk_ids=doc_chunk_ids_batch,
index_name=self.index_name,
http_client=http_client,
executor=executor,
)
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
batch_index_vespa_chunks(
chunks=chunk_batch,
index_name=self.index_name,
http_client=http_client,
multitenant=self.multitenant,
executor=executor,
)
all_cleaned_doc_ids = {chunk.source_document.id for chunk in cleaned_chunks}
return {
DocumentInsertionRecord(
document_id=new_document_id_to_original_document_id[cleaned_doc_id],
already_existed=cleaned_doc_id in existing_docs,
)
for cleaned_doc_id in all_cleaned_doc_ids
}
@classmethod
def _apply_updates_batched(

View File

@@ -244,15 +244,6 @@ def batch_index_vespa_chunks(
multitenant: bool,
executor: concurrent.futures.ThreadPoolExecutor | None = None,
) -> None:
"""Indexes a list of chunks in a Vespa index in parallel.
Args:
chunks: List of chunks to index.
index_name: Name of the index to index into.
http_client: HTTP client to use for the request.
multitenant: Whether the index is multitenant.
executor: Executor to use for the request.
"""
external_executor = True
if not executor:

View File

@@ -64,9 +64,10 @@ def remove_invalid_unicode_chars(text: str) -> str:
def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx.Client:
"""
Configures and returns an HTTP client for communicating with Vespa,
Configure and return an HTTP client for communicating with Vespa,
including authentication if needed.
"""
return httpx.Client(
cert=(
cast(tuple[str, str], (VESPA_CLOUD_CERT_PATH, VESPA_CLOUD_KEY_PATH))

View File

@@ -1,287 +0,0 @@
import concurrent.futures
import httpx
from pydantic import BaseModel
from onyx.context.search.enums import QueryType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.document_index_utils import get_document_chunk_ids
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
from onyx.document_index.interfaces_new import DocumentIndex
from onyx.document_index.interfaces_new import DocumentInsertionRecord
from onyx.document_index.interfaces_new import DocumentSectionRequest
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import MetadataUpdateRequest
from onyx.document_index.vespa.deletion import delete_vespa_chunks
from onyx.document_index.vespa.indexing_utils import BaseHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import batch_index_vespa_chunks
from onyx.document_index.vespa.indexing_utils import check_for_final_chunk_existence
from onyx.document_index.vespa.indexing_utils import clean_chunk_id_copy
from onyx.document_index.vespa.indexing_utils import GlobalHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import TemporaryHTTPXClientContext
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa_constants import BATCH_SIZE
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.utils.batching import batch_generator
from shared_configs.model_server_models import Embedding
class TenantState(BaseModel):
"""
Captures the tenant-related state for an instance of VespaDocumentIndex.
TODO(andrei): If we find that we need this for Opensearch too, just move
this to interfaces_new.py.
"""
model_config = {"frozen": True}
tenant_id: str
multitenant: bool
def _enrich_basic_chunk_info(
index_name: str,
http_client: httpx.Client,
document_id: str,
previous_chunk_count: int | None,
new_chunk_count: int,
) -> EnrichedDocumentIndexingInfo:
"""Determines which chunks need to be deleted during document reindexing.
When a document is reindexed, it may have fewer chunks than before. This
function identifies the range of old chunks that need to be deleted by
comparing the new chunk count with the previous chunk count.
Example:
If a document previously had 10 chunks (0-9) and now has 7 chunks (0-6),
this function identifies that chunks 7-9 need to be deleted.
Args:
index_name: The Vespa index/schema name.
http_client: HTTP client for making requests to Vespa.
document_id: The Vespa-sanitized ID of the document being reindexed.
previous_chunk_count: The total number of chunks the document had before
reindexing. None for documents using the legacy chunk ID system.
new_chunk_count: The total number of chunks the document has after
reindexing. This becomes the starting index for deletion since
chunks are 0-indexed.
Returns:
EnrichedDocumentIndexingInfo with chunk_start_index set to
new_chunk_count (where deletion begins) and chunk_end_index set to
previous_chunk_count (where deletion ends).
"""
# Technically last indexed chunk index +1.
last_indexed_chunk = previous_chunk_count
# If the document has no `chunk_count` in the database, we know that it
# has the old chunk ID system and we must check for the final chunk index.
is_old_version = False
if last_indexed_chunk is None:
is_old_version = True
minimal_doc_info = MinimalDocumentIndexingInfo(
doc_id=document_id, chunk_start_index=new_chunk_count
)
last_indexed_chunk = check_for_final_chunk_existence(
minimal_doc_info=minimal_doc_info,
start_index=new_chunk_count,
index_name=index_name,
http_client=http_client,
)
enriched_doc_info = EnrichedDocumentIndexingInfo(
doc_id=document_id,
chunk_start_index=new_chunk_count,
chunk_end_index=last_indexed_chunk,
old_version=is_old_version,
)
return enriched_doc_info
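
A hedged example mirroring the docstring above (the index name is hypothetical, and no HTTP call is made on this path because previous_chunk_count is known): a document that shrank from 10 chunks to 7 should have chunks 7 through 9 deleted.

info = _enrich_basic_chunk_info(
    index_name="onyx_chunk_index",  # hypothetical index/schema name
    http_client=http_client,        # assumed to be an open httpx.Client in scope
    document_id="doc-1",
    previous_chunk_count=10,
    new_chunk_count=7,
)
# Deletion range: start at the new count, end at the old count.
assert (info.chunk_start_index, info.chunk_end_index) == (7, 10)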
class VespaDocumentIndex(DocumentIndex):
"""Vespa-specific implementation of the DocumentIndex interface.
This class provides document indexing, retrieval, and management operations
for a Vespa search engine instance. It handles the complete lifecycle of
document chunks within a specific Vespa index/schema.
"""
def __init__(
self,
index_name: str,
tenant_state: TenantState,
large_chunks_enabled: bool,
httpx_client: httpx.Client | None = None,
) -> None:
self._index_name = index_name
self._tenant_id = tenant_state.tenant_id
self._large_chunks_enabled = large_chunks_enabled
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This
# is beneficial for indexing / updates / deletes since we have to make a
# large volume of requests.
self._httpx_client_context: BaseHTTPXClientContext
if httpx_client:
# Use the provided client. Because this client is presumed global,
# it does not close after exiting a context manager.
self._httpx_client_context = GlobalHTTPXClientContext(httpx_client)
else:
# We did not receive a client, so create one that will close after
# exiting a context manager.
self._httpx_client_context = TemporaryHTTPXClientContext(
get_vespa_http_client
)
self._multitenant = tenant_state.multitenant
def verify_and_create_index_if_necessary(
self, embedding_dim: int, embedding_precision: EmbeddingPrecision
) -> None:
raise NotImplementedError
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
indexing_metadata: IndexingMetadata,
) -> list[DocumentInsertionRecord]:
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
doc_id_to_previous_chunk_cnt = {
doc_id: chunk_cnt_diff.old_chunk_cnt
for doc_id, chunk_cnt_diff in doc_id_to_chunk_cnt_diff.items()
}
doc_id_to_new_chunk_cnt = {
doc_id: chunk_cnt_diff.new_chunk_cnt
for doc_id, chunk_cnt_diff in doc_id_to_chunk_cnt_diff.items()
}
assert (
len(doc_id_to_chunk_cnt_diff)
== len(doc_id_to_previous_chunk_cnt)
== len(doc_id_to_new_chunk_cnt)
), "Bug: Doc ID to chunk maps have different lengths."
# Vespa has restrictions on valid characters, yet document IDs come from
# external w.r.t. this class. We need to sanitize them.
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
clean_chunk_id_copy(chunk) for chunk in chunks
]
assert len(cleaned_chunks) == len(
chunks
), "Bug: Cleaned chunks and input chunks have different lengths."
# Needed so the final DocumentInsertionRecord returned can have the
# original document ID. cleaned_chunks might not contain IDs exactly as
# callers supplied them.
new_document_id_to_original_document_id: dict[str, str] = dict()
for i, cleaned_chunk in enumerate(cleaned_chunks):
old_chunk = chunks[i]
new_document_id_to_original_document_id[
cleaned_chunk.source_document.id
] = old_chunk.source_document.id
existing_docs: set[str] = set()
with (
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
self._httpx_client_context as http_client,
):
# We require the start and end index for each document in order to
# know precisely which chunks to delete. This information exists for
# documents that have `chunk_count` in the database, but not for
# `old_version` documents.
enriched_doc_infos: list[EnrichedDocumentIndexingInfo] = [
_enrich_basic_chunk_info(
index_name=self._index_name,
http_client=http_client,
document_id=doc_id,
previous_chunk_count=doc_id_to_previous_chunk_cnt[doc_id],
new_chunk_count=doc_id_to_new_chunk_cnt[doc_id],
)
for doc_id in doc_id_to_chunk_cnt_diff.keys()
]
for enriched_doc_info in enriched_doc_infos:
# If the document has previously indexed chunks, we know it
# previously existed and this is a reindex.
if enriched_doc_info.chunk_end_index:
existing_docs.add(enriched_doc_info.doc_id)
# Now, for each doc, we know exactly where to start and end our
# deletion. So let's generate the chunk IDs for each chunk to
# delete.
# WARNING: This code seems to use
# indexing_metadata.doc_id_to_chunk_cnt_diff as the source of truth
# for which chunks to delete. This implies that the onus is on the
# caller to ensure doc_id_to_chunk_cnt_diff only contains docs
# relevant to the chunks argument to this method. This should not be
# the contract of DocumentIndex; and this code is only a refactor
# from old code. It would seem we should use all_cleaned_doc_ids as
# the source of truth.
chunks_to_delete = get_document_chunk_ids(
enriched_document_info_list=enriched_doc_infos,
tenant_id=self._tenant_id,  # TODO: Figure out the expected tenant_id typing here.
large_chunks_enabled=self._large_chunks_enabled,
)
# Delete old Vespa documents.
for doc_chunk_ids_batch in batch_generator(chunks_to_delete, BATCH_SIZE):
delete_vespa_chunks(
doc_chunk_ids=doc_chunk_ids_batch,
index_name=self._index_name,
http_client=http_client,
executor=executor,
)
# Insert new Vespa documents.
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
batch_index_vespa_chunks(
chunks=chunk_batch,
index_name=self._index_name,
http_client=http_client,
multitenant=self._multitenant,
executor=executor,
)
all_cleaned_doc_ids: set[str] = {
chunk.source_document.id for chunk in cleaned_chunks
}
return [
DocumentInsertionRecord(
document_id=new_document_id_to_original_document_id[cleaned_doc_id],
already_existed=cleaned_doc_id in existing_docs,
)
for cleaned_doc_id in all_cleaned_doc_ids
]
def delete(self, db_doc_id: str, chunk_count: int | None) -> int:
raise NotImplementedError
def update(self, update_requests: list[MetadataUpdateRequest]) -> None:
raise NotImplementedError
def id_based_retrieval(
self, chunk_requests: list[DocumentSectionRequest]
) -> list[InferenceChunk]:
raise NotImplementedError
def hybrid_retrieval(
self,
query: str,
query_embedding: Embedding,
final_keywords: list[str] | None,
query_type: QueryType,
filters: IndexFilters,
num_to_retrieve: int,
offset: int = 0,
) -> list[InferenceChunk]:
raise NotImplementedError
def random_retrieval(
self,
filters: IndexFilters | None = None,
num_to_retrieve: int = 100,
dirty: bool | None = None,
) -> list[InferenceChunk]:
raise NotImplementedError

View File

@@ -26,9 +26,9 @@ DOCUMENT_ID_ENDPOINT = (
SEARCH_ENDPOINT = f"{VESPA_APP_CONTAINER_URL}/search/"
# Since Vespa doesn't allow batching of inserts / updates, we use threads to
# parallelize the operations.
NUM_THREADS = 32
NUM_THREADS = (
32 # since Vespa doesn't allow batching of inserts / updates, we use threads
)
MAX_ID_SEARCH_QUERY_SIZE = 400
# Suspect that adding too many "or" conditions will cause Vespa to timeout and return
# an empty list of hits (with no error status and coverage: 0 and degraded)
@@ -37,9 +37,7 @@ MAX_OR_CONDITIONS = 10
# in the long term, we are looking to improve the performance of Vespa
# so that we can bring this back to default
VESPA_TIMEOUT = "10s"
# The size of the batch to use for batched operations like inserts / updates.
# The batch will likely be sent to a threadpool of size NUM_THREADS.
BATCH_SIZE = 128
BATCH_SIZE = 128 # Specific to Vespa
TENANT_ID = "tenant_id"
DOCUMENT_ID = "document_id"

View File

@@ -8,6 +8,8 @@ from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from email.parser import Parser as EmailParser
from enum import auto
from enum import IntFlag
from io import BytesIO
from pathlib import Path
from typing import Any
@@ -23,21 +25,65 @@ from PIL import Image
from onyx.configs.constants import ONYX_METADATA_FILENAME
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_processing.file_types import PRESENTATION_MIME_TYPE
from onyx.file_processing.file_types import WORD_PROCESSING_MIME_TYPE
from onyx.file_processing.file_validation import TEXT_MIME_TYPE
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.file_processing.unstructured import get_unstructured_api_key
from onyx.file_processing.unstructured import unstructured_to_text
from onyx.utils.file_types import PRESENTATION_MIME_TYPE
from onyx.utils.file_types import WORD_PROCESSING_MIME_TYPE
from onyx.utils.logger import setup_logger
if TYPE_CHECKING:
from markitdown import MarkItDown
logger = setup_logger()
# NOTE(rkuo): Unify this with upload_files_for_chat and file_validation.py
TEXT_SECTION_SEPARATOR = "\n\n"
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
".txt",
".md",
".mdx",
".conf",
".log",
".json",
".csv",
".tsv",
".xml",
".yml",
".yaml",
".sql",
]
ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
".pdf",
".docx",
".pptx",
".xlsx",
".eml",
".epub",
".html",
]
ACCEPTED_IMAGE_FILE_EXTENSIONS = [
".png",
".jpg",
".jpeg",
".webp",
]
ALL_ACCEPTED_FILE_EXTENSIONS = (
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
+ ACCEPTED_DOCUMENT_FILE_EXTENSIONS
+ ACCEPTED_IMAGE_FILE_EXTENSIONS
)
IMAGE_MEDIA_TYPES = [
"image/png",
"image/jpeg",
"image/webp",
]
_MARKITDOWN_CONVERTER: Optional["MarkItDown"] = None
KNOWN_OPENPYXL_BUGS = [
@@ -56,11 +102,42 @@ def get_markitdown_converter() -> "MarkItDown":
return _MARKITDOWN_CONVERTER
class OnyxExtensionType(IntFlag):
Plain = auto()
Document = auto()
Multimedia = auto()
All = Plain | Document | Multimedia
def is_text_file_extension(file_name: str) -> bool:
return any(file_name.endswith(ext) for ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS)
def get_file_ext(file_path_or_name: str | Path) -> str:
_, extension = os.path.splitext(file_path_or_name)
return extension.lower()
def is_valid_media_type(media_type: str) -> bool:
return media_type in IMAGE_MEDIA_TYPES
def is_accepted_file_ext(ext: str, ext_type: OnyxExtensionType) -> bool:
if ext_type & OnyxExtensionType.Plain:
if ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS:
return True
if ext_type & OnyxExtensionType.Document:
if ext in ACCEPTED_DOCUMENT_FILE_EXTENSIONS:
return True
if ext_type & OnyxExtensionType.Multimedia:
if ext in ACCEPTED_IMAGE_FILE_EXTENSIONS:
return True
return False
def is_text_file(file: IO[bytes]) -> bool:
"""
checks if the first 1024 bytes only contain printable or whitespace characters
@@ -497,7 +574,9 @@ def extract_file_text(
if extension is None:
extension = get_file_ext(file_name)
if extension in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS:
if is_accepted_file_ext(
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
):
func = extension_to_function.get(extension, file_io_to_text)
file.seek(0)
return func(file)
@@ -548,23 +627,6 @@ def extract_text_and_images(
"""
Primary new function for the updated connector.
Returns structured extraction result with text content, embedded images, and metadata.
Args:
file: File-like object to extract content from.
file_name: Name of the file (used to determine extension/type).
pdf_pass: Optional password for encrypted PDFs.
content_type: Optional MIME type override for the file.
image_callback: Optional callback for streaming image extraction. When provided,
embedded images are passed to this callback one at a time as (bytes, filename)
instead of being accumulated in the returned ExtractionResult.embedded_images
list. This is a memory optimization for large documents with many images -
the caller can process/store each image immediately rather than holding all
images in memory. When using a callback, ExtractionResult.embedded_images
will be an empty list.
Returns:
ExtractionResult containing text_content, embedded_images (empty if callback used),
and metadata extracted from the file.
"""
res = _extract_text_and_images(
file, file_name, pdf_pass, content_type, image_callback
@@ -601,7 +663,7 @@ def _extract_text_and_images(
# with content types in UploadMimeTypes.DOCUMENT_MIME_TYPES as plain text files.
# As a result, the file name extension may differ from the original content type.
# We process files with a plain text content type first to handle this scenario.
if content_type in OnyxMimeTypes.TEXT_MIME_TYPES:
if content_type == TEXT_MIME_TYPE:
return extract_result_from_text_file(file)
# Default processing
@@ -663,7 +725,7 @@ def _extract_text_and_images(
)
# If we reach here and it's a recognized text extension
if extension in OnyxFileExtensions.PLAIN_TEXT_EXTENSIONS:
if is_text_file_extension(file_name):
return extract_result_from_text_file(file)
# If it's an image file or something else, we do not parse embedded images from them
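For orientation only (not part of the diff): a minimal sketch of how the IntFlag-based extension check introduced in this file is meant to be combined by callers, assuming is_accepted_file_ext and OnyxExtensionType live in onyx.file_processing.extract_file_text alongside the helpers imported elsewhere in this diff.

from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext

ext = get_file_ext("quarterly_report.pdf")  # -> ".pdf" (lowercased, with dot)
# Flags compose with bitwise OR, mirroring the extract_file_text() call above
is_accepted_file_ext(ext, OnyxExtensionType.Plain | OnyxExtensionType.Document)  # True
is_accepted_file_ext(ext, OnyxExtensionType.Multimedia)  # False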

View File

@@ -1,84 +0,0 @@
PRESENTATION_MIME_TYPE = (
"application/vnd.openxmlformats-officedocument.presentationml.presentation"
)
SPREADSHEET_MIME_TYPE = (
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
WORD_PROCESSING_MIME_TYPE = (
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
)
PDF_MIME_TYPE = "application/pdf"
PLAIN_TEXT_MIME_TYPE = "text/plain"
class OnyxMimeTypes:
IMAGE_MIME_TYPES = {"image/jpg", "image/jpeg", "image/png", "image/webp"}
CSV_MIME_TYPES = {"text/csv"}
TEXT_MIME_TYPES = {
PLAIN_TEXT_MIME_TYPE,
"text/markdown",
"text/x-markdown",
"text/x-config",
"text/tab-separated-values",
"application/json",
"application/xml",
"text/xml",
"application/x-yaml",
}
DOCUMENT_MIME_TYPES = {
PDF_MIME_TYPE,
WORD_PROCESSING_MIME_TYPE,
PRESENTATION_MIME_TYPE,
SPREADSHEET_MIME_TYPE,
"message/rfc822",
"application/epub+zip",
}
ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union(
TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, CSV_MIME_TYPES
)
EXCLUDED_IMAGE_TYPES = {
"image/bmp",
"image/tiff",
"image/gif",
"image/svg+xml",
"image/avif",
}
class OnyxFileExtensions:
PLAIN_TEXT_EXTENSIONS = {
".txt",
".md",
".mdx",
".conf",
".log",
".json",
".csv",
".tsv",
".xml",
".yml",
".yaml",
".sql",
}
DOCUMENT_EXTENSIONS = {
".pdf",
".docx",
".pptx",
".xlsx",
".eml",
".epub",
".html",
}
IMAGE_EXTENSIONS = {
".png",
".jpg",
".jpeg",
".webp",
}
TEXT_AND_DOCUMENT_EXTENSIONS = PLAIN_TEXT_EXTENSIONS.union(DOCUMENT_EXTENSIONS)
ALL_ALLOWED_EXTENSIONS = TEXT_AND_DOCUMENT_EXTENSIONS.union(IMAGE_EXTENSIONS)

View File

@@ -0,0 +1,55 @@
"""
Centralized file type validation utilities.
"""
# NOTE(rkuo): Unify this with upload_files_for_chat and extract_file_text
# Standard image MIME types supported by most vision LLMs
IMAGE_MIME_TYPES = [
"image/png",
"image/jpeg",
"image/jpg",
"image/webp",
]
# Image types that should be excluded from processing
EXCLUDED_IMAGE_TYPES = [
"image/bmp",
"image/tiff",
"image/gif",
"image/svg+xml",
"image/avif",
]
# Text MIME types
TEXT_MIME_TYPE = "text/plain"
def is_valid_image_type(mime_type: str) -> bool:
"""
Check if mime_type is a valid image type.
Args:
mime_type: The MIME type to check
Returns:
True if the MIME type is a valid image type, False otherwise
"""
return (
bool(mime_type)
and mime_type.startswith("image/")
and mime_type not in EXCLUDED_IMAGE_TYPES
)
def is_supported_by_vision_llm(mime_type: str) -> bool:
"""
Check if this image type can be processed by vision LLMs.
Args:
mime_type: The MIME type to check
Returns:
True if the MIME type is supported by vision LLMs, False otherwise
"""
return mime_type in IMAGE_MIME_TYPES
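A small illustrative usage of the two helpers above (not part of the diff; the module path onyx.file_processing.file_validation is assumed from the import earlier in this diff).

from onyx.file_processing.file_validation import is_supported_by_vision_llm
from onyx.file_processing.file_validation import is_valid_image_type

is_valid_image_type("image/png")          # True
is_valid_image_type("image/gif")          # False, explicitly excluded above
is_supported_by_vision_llm("image/webp")  # True, listed in IMAGE_MIME_TYPES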

View File

@@ -126,15 +126,6 @@ class FileStore(ABC):
Read the file record by the ID
"""
@abstractmethod
def get_file_size(
self, file_id: str, db_session: Session | None = None
) -> int | None:
"""
Get the size of a file in bytes.
Optionally provide a db_session for database access.
"""
@abstractmethod
def delete_file(self, file_id: str) -> None:
"""
@@ -431,27 +422,6 @@ class S3BackedFileStore(FileStore):
)
return file_record
def get_file_size(
self, file_id: str, db_session: Session | None = None
) -> int | None:
"""
Get the size of a file in bytes by querying S3 metadata.
"""
try:
with get_session_with_current_tenant_if_none(db_session) as db_session:
file_record = get_filerecord_by_file_id(
file_id=file_id, db_session=db_session
)
s3_client = self._get_s3_client()
response = s3_client.head_object(
Bucket=file_record.bucket_name, Key=file_record.object_key
)
return response.get("ContentLength")
except Exception as e:
logger.warning(f"Error getting file size for {file_id}: {e}")
return None
def delete_file(self, file_id: str, db_session: Session | None = None) -> None:
with get_session_with_current_tenant_if_none(db_session) as db_session:
try:

View File

@@ -23,7 +23,7 @@ from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.utils.logger import setup_logger
@@ -140,7 +140,7 @@ class UserFileIndexingAdapter:
# Initialize tokenizer used for token count calculation
try:
llm = get_default_llm()
llm, _ = get_default_llms()
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
provider_type=llm.config.model_provider,

View File

@@ -28,7 +28,7 @@ from onyx.kg.utils.formatting_utils import make_entity_id
from onyx.kg.utils.formatting_utils import make_relationship_id
from onyx.kg.utils.formatting_utils import make_relationship_type_id
from onyx.kg.vespa.vespa_interactions import get_document_vespa_contents
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.llm.utils import llm_response_to_string
from onyx.prompts.kg_prompts import CALL_CHUNK_PREPROCESSING_PROMPT
from onyx.prompts.kg_prompts import CALL_DOCUMENT_CLASSIFICATION_PROMPT
@@ -415,9 +415,9 @@ def kg_classify_document(
)
# classify with LLM
llm = get_default_llm()
primary_llm, _ = get_default_llms()
try:
raw_classification_result = llm_response_to_string(llm.invoke(prompt))
raw_classification_result = llm_response_to_string(primary_llm.invoke(prompt))
classification_result = (
raw_classification_result.replace("```json", "").replace("```", "").strip()
)
@@ -479,9 +479,9 @@ def kg_deep_extract_chunks(
).replace("---content---", llm_context)
# extract with LLM
llm = get_default_llm()
_, fast_llm = get_default_llms()
try:
raw_extraction_result = llm_response_to_string(llm.invoke(prompt))
raw_extraction_result = llm_response_to_string(fast_llm.invoke(prompt))
cleaned_response = (
raw_extraction_result.replace("{{", "{")
.replace("}}", "}")

View File

@@ -258,9 +258,6 @@ class LitellmLLM(LLM):
)
# Needed to get reasoning tokens from the model
# NOTE: OpenAI Responses API is disabled for parallel tool calls because LiteLLM's transformation layer
# doesn't properly pass parallel_tool_calls to the API, causing the model to
# always return sequential tool calls. For this reason parallel tool calls won't work with OpenAI models
if (
is_true_openai_model(self.config.model_provider, self.config.model_name)
or self.config.model_provider == AZURE_PROVIDER_NAME
@@ -293,7 +290,7 @@ class LitellmLLM(LLM):
completion_kwargs["metadata"] = metadata
try:
response = litellm.completion(
return litellm.completion(
mock_response=MOCK_LLM_RESPONSE,
# model choice
# model="openai/gpt-4",
@@ -314,9 +311,24 @@ class LitellmLLM(LLM):
temperature=(1 if is_reasoning else self._temperature),
timeout=timeout_override or self._timeout,
**({"stream_options": {"include_usage": True}} if stream else {}),
# NOTE: we can't pass parallel_tool_calls if tools are not specified
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
**({"parallel_tool_calls": parallel_tool_calls} if tools else {}),
**(
{"parallel_tool_calls": parallel_tool_calls}
if tools
and self.config.model_name
not in [
"o3-mini",
"o3-preview",
"o1",
"o1-preview",
"o1-mini",
"o1-mini-2024-09-12",
"o3-mini-2025-01-31",
]
else {}
),
# Anthropic Claude uses `thinking` with budget_tokens for extended thinking
# This applies to Claude models on any provider (anthropic, vertex_ai, bedrock)
**(
@@ -338,7 +350,10 @@ class LitellmLLM(LLM):
# (litellm maps this to thinking_level for Gemini 3 models)
**(
{"reasoning_effort": OPENAI_REASONING_EFFORT[reasoning_effort]}
if is_reasoning and "claude" not in self.config.model_name.lower()
if reasoning_effort
and reasoning_effort != ReasoningEffort.OFF
and is_reasoning
and "claude" not in self.config.model_name.lower()
else {}
),
**(
@@ -349,7 +364,6 @@ class LitellmLLM(LLM):
**({self._max_token_param: max_tokens} if max_tokens else {}),
**completion_kwargs,
)
return response
except Exception as e:
self._record_error(prompt, e)

View File

@@ -54,6 +54,12 @@ def _build_provider_extra_headers(
return {}
def get_main_llm_from_tuple(
llms: tuple[LLM, LLM],
) -> LLM:
return llms[0]
def get_llm_config_for_persona(
persona: Persona,
db_session: Session,
@@ -101,16 +107,16 @@ def get_llm_config_for_persona(
)
def get_llm_for_persona(
def get_llms_for_persona(
persona: Persona | PersonaOverrideConfig | None,
user: User | None,
llm_override: LLMOverride | None = None,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> LLM:
) -> tuple[LLM, LLM]:
if persona is None:
logger.warning("No persona provided, using default LLM")
return get_default_llm()
logger.warning("No persona provided, using default LLMs")
return get_default_llms()
provider_name_override = llm_override.model_provider if llm_override else None
model_version_override = llm_override.model_version if llm_override else None
@@ -118,7 +124,7 @@ def get_llm_for_persona(
provider_name = provider_name_override or persona.llm_model_provider_override
if not provider_name:
return get_default_llm(
return get_default_llms(
temperature=temperature_override or GEN_AI_TEMPERATURE,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
@@ -147,7 +153,7 @@ def get_llm_for_persona(
getattr(persona_model, "id", None),
provider_model.name,
)
return get_default_llm(
return get_default_llms(
temperature=temperature_override or GEN_AI_TEMPERATURE,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
@@ -156,24 +162,30 @@ def get_llm_for_persona(
llm_provider = LLMProviderView.from_model(provider_model)
model = model_version_override or persona.llm_model_version_override
fast_model = llm_provider.fast_default_model_name or llm_provider.default_model_name
if not model:
raise ValueError("No model name found")
if not fast_model:
raise ValueError("No fast model name found")
return get_llm(
provider=llm_provider.provider,
model=model,
deployment_name=llm_provider.deployment_name,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
temperature=temperature_override,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
max_input_tokens=get_max_input_tokens_from_llm_provider(
llm_provider=llm_provider, model_name=model
),
)
def _create_llm(model: str) -> LLM:
return get_llm(
provider=llm_provider.provider,
model=model,
deployment_name=llm_provider.deployment_name,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
temperature=temperature_override,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
max_input_tokens=get_max_input_tokens_from_llm_provider(
llm_provider=llm_provider, model_name=model
),
)
return _create_llm(model), _create_llm(fast_model)
def get_default_llm_with_vision(
@@ -235,7 +247,7 @@ def get_default_llm_with_vision(
):
return create_vision_llm(provider_view, provider.default_vision_model)
# If no model-configurations are specified, try default model
# If no model-configurations are specified, try default models in priority order
if not provider.model_configurations:
# Try default_model_name
if provider.default_model_name and model_supports_image_input(
@@ -243,6 +255,14 @@ def get_default_llm_with_vision(
):
return create_vision_llm(provider_view, provider.default_model_name)
# Try fast_default_model_name
if provider.fast_default_model_name and model_supports_image_input(
provider.fast_default_model_name, provider.provider
):
return create_vision_llm(
provider_view, provider.fast_default_model_name
)
# Otherwise, if model-configurations are specified, check each model
else:
for model_configuration in provider.model_configurations:
@@ -291,12 +311,12 @@ def get_llm_for_contextual_rag(model_name: str, model_provider: str) -> LLM:
)
def get_default_llm(
def get_default_llms(
timeout: int | None = None,
temperature: float | None = None,
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> LLM:
) -> tuple[LLM, LLM]:
with get_session_with_current_tenant() as db_session:
llm_provider = fetch_default_provider(db_session)
@@ -304,17 +324,25 @@ def get_default_llm(
raise ValueError("No default LLM provider found")
model_name = llm_provider.default_model_name
fast_model_name = (
llm_provider.fast_default_model_name or llm_provider.default_model_name
)
if not model_name:
raise ValueError("No default model name found")
if not fast_model_name:
raise ValueError("No fast default model name found")
return llm_from_provider(
model_name=model_name,
llm_provider=llm_provider,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
)
def _create_llm(model: str) -> LLM:
return llm_from_provider(
model_name=model,
llm_provider=llm_provider,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
)
return _create_llm(model_name), _create_llm(fast_model_name)
def get_llm(
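Not part of the diff: since the central change in this file is replacing the single-LLM factory with a (primary, fast) pair, here is a minimal sketch of how callers elsewhere in this diff unpack the tuple.

from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llms_for_persona

# Primary and fast models come back as a tuple; the fast slot falls back to
# the default model name when no fast default is configured.
llm, fast_llm = get_default_llms(timeout=5)
draft = fast_llm.invoke("Rephrase this message: hello")  # cheap/fast path
answer = llm.invoke("Answer carefully: ...")             # primary path

# The persona-scoped variant has the same shape; with no persona it defers to the defaults.
primary, fast = get_llms_for_persona(persona=None, user=None)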

View File

@@ -47,6 +47,7 @@ class WellKnownLLMProviderDescriptor(BaseModel):
custom_config_keys: list[CustomConfigKey] | None = None
model_configurations: list[ModelConfigurationView]
default_model: str | None = None
default_fast_model: str | None = None
default_api_base: str | None = None
# set for providers like Azure, which require a deployment name.
deployment_name_required: bool = False
@@ -147,6 +148,7 @@ VERTEXAI_PROVIDER_NAME = "vertex_ai"
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
VERTEX_LOCATION_KWARG = "vertex_location"
VERTEXAI_DEFAULT_MODEL = "gemini-2.5-flash"
VERTEXAI_DEFAULT_FAST_MODEL = "gemini-2.5-flash-lite"
# Curated list of Vertex AI models to show by default in the UI
VERTEXAI_VISIBLE_MODEL_NAMES = {
"gemini-2.5-flash",
@@ -333,6 +335,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
OPENAI_PROVIDER_NAME
),
default_model="gpt-4o",
default_fast_model="gpt-4o-mini",
),
WellKnownLLMProviderDescriptor(
name=OLLAMA_PROVIDER_NAME,
@@ -354,6 +357,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
OLLAMA_PROVIDER_NAME
),
default_model=None,
default_fast_model=None,
default_api_base="http://127.0.0.1:11434",
),
WellKnownLLMProviderDescriptor(
@@ -368,6 +372,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
ANTHROPIC_PROVIDER_NAME
),
default_model="claude-sonnet-4-5-20250929",
default_fast_model="claude-sonnet-4-20250514",
),
WellKnownLLMProviderDescriptor(
name=AZURE_PROVIDER_NAME,
@@ -450,6 +455,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
BEDROCK_PROVIDER_NAME
),
default_model=None,
default_fast_model=None,
),
WellKnownLLMProviderDescriptor(
name=VERTEXAI_PROVIDER_NAME,
@@ -482,6 +488,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
),
],
default_model=VERTEXAI_DEFAULT_MODEL,
default_fast_model=VERTEXAI_DEFAULT_FAST_MODEL,
),
WellKnownLLMProviderDescriptor(
name=OPENROUTER_PROVIDER_NAME,
@@ -495,6 +502,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
OPENROUTER_PROVIDER_NAME
),
default_model=None,
default_fast_model=None,
default_api_base="https://openrouter.ai/api/v1",
),
]

View File

@@ -35,11 +35,8 @@ CLAUDE_REASONING_BUDGET_TOKENS: dict[ReasoningEffort, int] = {
}
# OpenAI reasoning effort mapping (direct string values)
# TODO this needs to be cleaned up, there is a lot of jank and unnecessary slowness
# Also there should be auto for reasoning level which is not used here.
OPENAI_REASONING_EFFORT: dict[ReasoningEffort | None, str] = {
None: "medium", # Seems there is no auto mode in this version unfortunately
ReasoningEffort.OFF: "low", # Issues with 5.2 models not supporting minimal or off with this version of litellm
OPENAI_REASONING_EFFORT: dict[ReasoningEffort, str] = {
ReasoningEffort.OFF: "none", # this only works for the 5 series though
ReasoningEffort.LOW: "low",
ReasoningEffort.MEDIUM: "medium",
ReasoningEffort.HIGH: "high",
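Illustrative only (not in the diff): after this change the map is keyed strictly by ReasoningEffort members, so the earlier None and OFF-to-"low" special cases are gone and look-ups read as follows.

OPENAI_REASONING_EFFORT[ReasoningEffort.HIGH]  # "high"
OPENAI_REASONING_EFFORT[ReasoningEffort.OFF]   # "none" -- per the comment above, only honored by the 5-series models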

View File

@@ -34,7 +34,7 @@ from onyx.configs.onyxbot_configs import (
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.users import get_user_by_email
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.llm.utils import llm_response_to_string
from onyx.onyxbot.slack.constants import FeedbackVisibility
from onyx.onyxbot.slack.models import ChannelType
@@ -141,7 +141,7 @@ def check_message_limit() -> bool:
def rephrase_slack_message(msg: str) -> str:
llm = get_default_llm(timeout=5)
llm, _ = get_default_llms(timeout=5)
prompt = SLACK_LANGUAGE_REPHRASE_PROMPT.format(query=msg)
model_output = llm_response_to_string(llm.invoke(prompt))
logger.debug(model_output)

View File

@@ -6,8 +6,7 @@ from onyx.prompts.deep_research.dr_tool_prompts import THINK_TOOL_NAME
# ruff: noqa: E501, W605 start
CLARIFICATION_PROMPT = f"""
You are a clarification agent that runs prior to deep research. Assess whether you need to ask clarifying questions, or if the user has already provided enough information for you to start research. \
CRITICAL - Never directly answer the user's query, you must only ask clarifying questions or call the `{GENERATE_PLAN_TOOL_NAME}` tool.
You are a clarification agent that runs prior to deep research. Assess whether you need to ask clarifying questions, or if the user has already provided enough information for you to start research. Clarifications are generally helpful.
If the user query is already very detailed or lengthy (more than 3 sentences), do not ask for clarification and instead call the `{GENERATE_PLAN_TOOL_NAME}` tool.
@@ -32,7 +31,7 @@ Focus on providing a thorough research of the user's query over being helpful.
For context, the date is {current_datetime}.
The research plan should be formatted as a numbered list of steps and have 6 or less individual steps.
The research plan should be formatted as a numbered list of steps and have less than 7 individual steps.
Each step should be a standalone exploration question or topic that can be researched independently but may build on previous steps.
@@ -110,7 +109,7 @@ Provide inline citations in the format [1], [2], [3], etc. based on the citation
"""
USER_FINAL_REPORT_QUERY = f"""
USER_FINAL_REPORT_QUERY = """
Provide a comprehensive answer to my previous query. CRITICAL: be as detailed as possible, stay on topic, and provide clear organization in your response.
Ignore the format styles of the intermediate {RESEARCH_AGENT_TOOL_NAME} reports, those are not end user facing and different from your task.

View File

@@ -60,8 +60,4 @@ GENERATE_IMAGE_GUIDANCE = """
## generate_image
NEVER use generate_image unless the user specifically requests an image.
"""
TOOL_CALL_FAILURE_PROMPT = """
LLM attempted to call a tool but failed. Most likely the tool name was misspelled.
""".strip()
# ruff: noqa: E501, W605 end

View File

@@ -1,4 +1,4 @@
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.llm.utils import llm_response_to_string
from onyx.prompts.answer_validation import ANSWER_VALIDITY_PROMPT
from onyx.utils.logger import setup_logger
@@ -17,7 +17,7 @@ def get_answer_validity(
return False
return True # If something is wrong, let's not toss away the answer
llm = get_default_llm()
llm, _ = get_default_llms()
prompt = ANSWER_VALIDITY_PROMPT.format(user_query=query, llm_answer=answer)
model_output = llm_response_to_string(llm.invoke(prompt))

View File

@@ -1,7 +1,7 @@
from collections.abc import Callable
from onyx.configs.constants import MessageType
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.llm.interfaces import LLM
from onyx.llm.models import AssistantMessage
from onyx.llm.models import ChatCompletionMessage
@@ -224,11 +224,11 @@ def keyword_query_expansion(
def llm_multilingual_query_expansion(query: str, language: str) -> str:
llm = get_default_llm(timeout=5)
_, fast_llm = get_default_llms(timeout=5)
prompt = LANGUAGE_REPHRASE_PROMPT.format(query=query, target_language=language)
model_output = llm_response_to_string(
llm.invoke(prompt, reasoning_effort=ReasoningEffort.OFF)
fast_llm.invoke(prompt, reasoning_effort=ReasoningEffort.OFF)
)
logger.debug(model_output)

View File

@@ -15,7 +15,7 @@ from onyx.db.models import StarterMessage
from onyx.db.models import User
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.llm.utils import llm_response_to_string
from onyx.prompts.starter_messages import format_persona_starter_message_prompt
from onyx.prompts.starter_messages import PERSONA_CATEGORY_GENERATION_PROMPT
@@ -74,7 +74,7 @@ def generate_start_message_prompts(
categories: List[str | None],
chunk_contents: str,
supports_structured_output: bool,
llm: Any,
fast_llm: Any,
) -> List[FunctionCall]:
"""
Generates the list of FunctionCall objects for starter message generation.
@@ -99,7 +99,7 @@ def generate_start_message_prompts(
functions.append(
FunctionCall(
llm.invoke,
fast_llm.invoke,
(start_message_generation_prompt,),
)
)
@@ -119,12 +119,12 @@ def generate_starter_messages(
Generates starter messages by first obtaining categories and then generating messages for each category.
On failure, returns an empty list (or list with processed starter messages if some messages are processed successfully).
"""
llm = get_default_llm(temperature=0.5)
_, fast_llm = get_default_llms(temperature=0.5)
from litellm.utils import get_supported_openai_params
provider = llm.config.model_provider
model = llm.config.model_name
provider = fast_llm.config.model_provider
model = fast_llm.config.model_name
params = get_supported_openai_params(model=model, custom_llm_provider=provider)
supports_structured_output = (
@@ -142,7 +142,7 @@ def generate_starter_messages(
num_categories=generation_count,
)
category_response = llm.invoke(category_generation_prompt)
category_response = fast_llm.invoke(category_generation_prompt)
response_content = llm_response_to_string(category_response)
categories = parse_categories(response_content)
@@ -179,7 +179,7 @@ def generate_starter_messages(
categories,
chunk_contents,
supports_structured_output,
llm,
fast_llm,
)
# Run LLM calls in parallel

View File

@@ -1,3 +1,4 @@
import io
import json
import math
import mimetypes
@@ -9,8 +10,6 @@ from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import Form
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
@@ -24,9 +23,6 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
)
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.app_configs import ENABLED_CONNECTOR_TYPES
@@ -111,15 +107,13 @@ from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import AuthStatus
from onyx.server.documents.models import AuthUrl
from onyx.server.documents.models import ConnectorBase
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.server.documents.models import ConnectorFileInfo
from onyx.server.documents.models import ConnectorFilesResponse
from onyx.server.documents.models import ConnectorIndexingStatusLite
from onyx.server.documents.models import ConnectorIndexingStatusLiteResponse
from onyx.server.documents.models import ConnectorSnapshot
@@ -142,8 +136,9 @@ from onyx.server.documents.models import RunConnectorRequest
from onyx.server.documents.models import SourceSummary
from onyx.server.federated.models import FederatedConnectorStatus
from onyx.server.models import StatusResponse
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.threadpool_concurrency import CallableProtocol
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -461,6 +456,9 @@ def is_zip_file(file: UploadFile) -> bool:
def upload_files(
files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR
) -> FileUploadResponse:
for file in files:
if not file.filename:
raise HTTPException(status_code=400, detail="File name cannot be empty")
# Skip directories and known macOS metadata entries
def should_process_file(file_path: str) -> bool:
@@ -474,10 +472,6 @@ def upload_files(
file_store = get_default_file_store()
seen_zip = False
for file in files:
if not file.filename:
logger.warning("File has no filename, skipping")
continue
if is_zip_file(file):
if seen_zip:
raise HTTPException(status_code=400, detail=SEEN_ZIP_DETAIL)
@@ -507,6 +501,24 @@ def upload_files(
deduped_file_names.append(os.path.basename(file_info))
continue
# For mypy, actual check happens at start of function
assert file.filename is not None
# Special handling for doc files - only store the plaintext version
file_type = mime_type_to_chat_file_type(file.content_type)
if file_type == ChatFileType.DOC:
extracted_text = extract_file_text(file.file, file.filename or "")
text_file_id = file_store.save_file(
content=io.BytesIO(extracted_text.encode()),
display_name=file.filename,
file_origin=file_origin,
file_type="text/plain",
)
deduped_file_paths.append(text_file_id)
deduped_file_names.append(file.filename)
continue
# Default handling for all other file types
file_id = file_store.save_file(
content=file.file,
display_name=file.filename,
@@ -525,17 +537,6 @@ def upload_files(
)
def _normalize_file_names_for_backwards_compatibility(
file_locations: list[str], file_names: list[str]
) -> list[str]:
"""
Ensures file_names list is the same length as file_locations for backwards compatibility.
In legacy data, file_names might not exist or be shorter than file_locations.
If file_names is shorter, pads it with corresponding file_locations values.
"""
return file_names + file_locations[len(file_names) :]
@router.post("/admin/connector/file/upload")
def upload_files_api(
files: list[UploadFile],
@@ -544,229 +545,6 @@ def upload_files_api(
return upload_files(files, FileOrigin.OTHER)
@router.get("/admin/connector/{connector_id}/files")
def list_connector_files(
connector_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> ConnectorFilesResponse:
"""List all files in a file connector."""
connector = fetch_connector_by_id(connector_id, db_session)
if connector is None:
raise HTTPException(status_code=404, detail="Connector not found")
if connector.source != DocumentSource.FILE:
raise HTTPException(
status_code=400, detail="This endpoint only works with file connectors"
)
file_locations = connector.connector_specific_config.get("file_locations", [])
file_names = connector.connector_specific_config.get("file_names", [])
# Normalize file_names for backwards compatibility with legacy data
file_names = _normalize_file_names_for_backwards_compatibility(
file_locations, file_names
)
file_store = get_default_file_store()
files = []
for file_id, file_name in zip(file_locations, file_names):
try:
file_record = file_store.read_file_record(file_id)
file_size = None
upload_date = None
if file_record:
file_size = file_store.get_file_size(file_id)
upload_date = (
file_record.created_at.isoformat()
if file_record.created_at
else None
)
files.append(
ConnectorFileInfo(
file_id=file_id,
file_name=file_name,
file_size=file_size,
upload_date=upload_date,
)
)
except Exception as e:
logger.warning(f"Error reading file record for {file_id}: {e}")
# Include file with basic info even if record fetch fails
files.append(
ConnectorFileInfo(
file_id=file_id,
file_name=file_name,
)
)
return ConnectorFilesResponse(files=files)
@router.post("/admin/connector/{connector_id}/files/update")
def update_connector_files(
connector_id: int,
files: list[UploadFile] | None = File(None),
file_ids_to_remove: str = Form("[]"),
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> FileUploadResponse:
"""
Update files in a connector by adding new files and/or removing existing ones.
This is an atomic operation that validates, updates the connector config, and triggers indexing.
"""
files = files or []
connector = fetch_connector_by_id(connector_id, db_session)
if connector is None:
raise HTTPException(status_code=404, detail="Connector not found")
if connector.source != DocumentSource.FILE:
raise HTTPException(
status_code=400, detail="This endpoint only works with file connectors"
)
# Get the connector-credential pair for indexing/pruning triggers
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
if cc_pair is None:
raise HTTPException(
status_code=404,
detail="No Connector-Credential Pair found for this connector",
)
# Parse file IDs to remove
try:
file_ids_list = json.loads(file_ids_to_remove)
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="Invalid file_ids_to_remove format")
if not isinstance(file_ids_list, list):
raise HTTPException(
status_code=400,
detail="file_ids_to_remove must be a JSON-encoded list",
)
# Get current connector config
current_config = connector.connector_specific_config
current_file_locations = current_config.get("file_locations", [])
current_file_names = current_config.get("file_names", [])
current_zip_metadata = current_config.get("zip_metadata", {})
# Upload new files if any
new_file_paths = []
new_file_names_list = []
new_zip_metadata = {}
if files and len(files) > 0:
upload_response = upload_files(files, FileOrigin.CONNECTOR)
new_file_paths = upload_response.file_paths
new_file_names_list = upload_response.file_names
new_zip_metadata = upload_response.zip_metadata
# Remove specified files
files_to_remove_set = set(file_ids_list)
# Normalize file_names for backwards compatibility with legacy data
current_file_names = _normalize_file_names_for_backwards_compatibility(
current_file_locations, current_file_names
)
remaining_file_locations = []
remaining_file_names = []
for file_id, file_name in zip(current_file_locations, current_file_names):
if file_id not in files_to_remove_set:
remaining_file_locations.append(file_id)
remaining_file_names.append(file_name)
# Combine remaining files with new files
final_file_locations = remaining_file_locations + new_file_paths
final_file_names = remaining_file_names + new_file_names_list
# Validate that at least one file remains
if not final_file_locations:
raise HTTPException(
status_code=400,
detail="Cannot remove all files from connector. At least one file must remain.",
)
# Update zip metadata
final_zip_metadata = {
key: value
for key, value in current_zip_metadata.items()
if key not in files_to_remove_set
}
final_zip_metadata.update(new_zip_metadata)
# Update connector config
updated_config = {
**current_config,
"file_locations": final_file_locations,
"file_names": final_file_names,
"zip_metadata": final_zip_metadata,
}
connector_base = ConnectorBase(
name=connector.name,
source=connector.source,
input_type=connector.input_type,
connector_specific_config=updated_config,
refresh_freq=connector.refresh_freq,
prune_freq=connector.prune_freq,
indexing_start=connector.indexing_start,
)
updated_connector = update_connector(connector_id, connector_base, db_session)
if updated_connector is None:
raise HTTPException(
status_code=500, detail="Failed to update connector configuration"
)
# Trigger re-indexing for new files and pruning for removed files
try:
tenant_id = get_current_tenant_id()
# If files were added, mark for UPDATE indexing (only new docs)
if new_file_paths:
mark_ccpair_with_indexing_trigger(
cc_pair.id, IndexingMode.UPDATE, db_session
)
# Send task to check for indexing immediately
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
kwargs={"tenant_id": tenant_id},
priority=OnyxCeleryPriority.HIGH,
)
logger.info(
f"Marked cc_pair {cc_pair.id} for UPDATE indexing (new files) for connector {connector_id}"
)
# If files were removed, trigger pruning immediately
if file_ids_list:
r = get_redis_client()
payload_id = try_creating_prune_generator_task(
client_app, cc_pair, db_session, r, tenant_id
)
if payload_id:
logger.info(
f"Triggered pruning for cc_pair {cc_pair.id} (removed files) for connector "
f"{connector_id}, payload_id={payload_id}"
)
else:
logger.warning(
f"Failed to trigger pruning for cc_pair {cc_pair.id} (removed files) for connector {connector_id}"
)
except Exception as e:
logger.error(f"Failed to trigger re-indexing after file update: {e}")
return FileUploadResponse(
file_paths=final_file_locations,
file_names=final_file_names,
zip_metadata=final_zip_metadata,
)
@router.get("/admin/connector")
def get_connectors_by_credential(
_: User = Depends(current_curator_or_admin_user),
@@ -1156,10 +934,12 @@ def get_connector_indexing_status(
].total_docs_indexed += connector_status.docs_indexed
# Track admin page visit for analytics
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=user.email if user else tenant_id,
event=MilestoneRecordType.VISITED_ADMIN_PAGE,
create_milestone_and_report(
user=user,
distinct_id=user.email if user else tenant_id or "N/A",
event_type=MilestoneRecordType.VISITED_ADMIN_PAGE,
properties=None,
db_session=db_session,
)
# Group statuses by source for pagination
@@ -1371,10 +1151,12 @@ def create_connector_from_model(
connector_data=connector_base,
)
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=user.email if user else tenant_id,
event=MilestoneRecordType.CREATED_CONNECTOR,
create_milestone_and_report(
user=user,
distinct_id=user.email if user else tenant_id or "N/A",
event_type=MilestoneRecordType.CREATED_CONNECTOR,
properties=None,
db_session=db_session,
)
return connector_response
@@ -1450,10 +1232,12 @@ def create_connector_with_mock_credential(
f"cc_pair={response.data}"
)
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=user.email if user else tenant_id,
event=MilestoneRecordType.CREATED_CONNECTOR,
create_milestone_and_report(
user=user,
distinct_id=user.email if user else tenant_id or "N/A",
event_type=MilestoneRecordType.CREATED_CONNECTOR,
properties=None,
db_session=db_session,
)
return response

View File

@@ -571,17 +571,6 @@ class FileUploadResponse(BaseModel):
zip_metadata: dict[str, Any]
class ConnectorFileInfo(BaseModel):
file_id: str
file_name: str
file_size: int | None = None
upload_date: str | None = None
class ConnectorFilesResponse(BaseModel):
files: list[ConnectorFileInfo]
class ObjectCreationIdResponse(BaseModel):
id: int
credential: CredentialSnapshot | None = None

View File

@@ -1,6 +1,5 @@
import asyncio
import base64
import datetime
import hashlib
import json
from collections.abc import Awaitable
@@ -41,7 +40,6 @@ from onyx.db.enums import MCPServerStatus
from onyx.db.enums import MCPTransport
from onyx.db.mcp import create_connection_config
from onyx.db.mcp import create_mcp_server__no_commit
from onyx.db.mcp import delete_all_user_connection_configs_for_server_no_commit
from onyx.db.mcp import delete_connection_config
from onyx.db.mcp import delete_mcp_server
from onyx.db.mcp import delete_user_connection_configs_for_server
@@ -1014,7 +1012,6 @@ def _db_mcp_server_to_api_mcp_server(
is_authenticated=is_authenticated,
user_authenticated=user_authenticated,
status=db_server.status,
last_refreshed_at=db_server.last_refreshed_at,
tool_count=tool_count,
auth_template=auth_template,
user_credentials=user_credentials,
@@ -1136,7 +1133,6 @@ def get_mcp_server_tools_snapshots(
server_id=server_id,
db_session=db,
status=MCPServerStatus.CONNECTED,
last_refreshed_at=datetime.datetime.now(datetime.timezone.utc),
)
db.commit()
except Exception as e:
@@ -1209,7 +1205,7 @@ def _upsert_db_tools(
db_session=db,
passthrough_auth=False,
mcp_server_id=mcp_server_id,
enabled=True,
enabled=False,
)
new_tool.display_name = display_name
new_tool.mcp_input_schema = input_schema
@@ -1387,21 +1383,7 @@ def _upsert_mcp_server(
)
# Cleanup: Delete existing connection configs
# If the auth type is OAUTH, delete all user connection configs
# If the auth type is API_TOKEN, delete the admin connection config and the admin user connection configs
if (
changing_connection_config
and mcp_server.admin_connection_config_id
and request.auth_type == MCPAuthenticationType.OAUTH
):
delete_all_user_connection_configs_for_server_no_commit(
mcp_server.id, db_session
)
elif (
changing_connection_config
and mcp_server.admin_connection_config_id
and request.auth_type == MCPAuthenticationType.API_TOKEN
):
if changing_connection_config and mcp_server.admin_connection_config_id:
delete_connection_config(mcp_server.admin_connection_config_id, db_session)
if user and user.email:
delete_user_connection_configs_for_server(

View File

@@ -1,4 +1,3 @@
import datetime
from enum import Enum
from typing import Any
from typing import List
@@ -298,7 +297,6 @@ class MCPServer(BaseModel):
is_authenticated: bool
user_authenticated: Optional[bool] = None
status: MCPServerStatus
last_refreshed_at: Optional[datetime.datetime] = None
tool_count: int = Field(
default=0, description="Number of tools associated with this server"
)

View File

@@ -59,7 +59,7 @@ from onyx.server.manage.llm.api import get_valid_model_names_for_persona
from onyx.server.models import DisplayPriorityRequest
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -282,10 +282,12 @@ def create_persona(
user=user,
db_session=db_session,
)
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=user.email if user else tenant_id,
event=MilestoneRecordType.CREATED_ASSISTANT,
create_milestone_and_report(
user=user,
distinct_id=tenant_id or "N/A",
event_type=MilestoneRecordType.CREATED_ASSISTANT,
properties=None,
db_session=db_session,
)
return persona_snapshot

View File

@@ -8,10 +8,11 @@ from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from onyx.file_processing.extract_file_text import ACCEPTED_IMAGE_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -123,7 +124,7 @@ def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
"""
results = CategorizedFiles()
llm = get_default_llm()
llm, _ = get_default_llms()
tokenizer = get_tokenizer(
model_name=llm.config.model_name, provider_type=llm.config.model_provider
@@ -154,7 +155,7 @@ def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
extension = get_file_ext(filename)
# If image, estimate tokens via dedicated method first
if extension in OnyxFileExtensions.IMAGE_EXTENSIONS:
if extension in ACCEPTED_IMAGE_FILE_EXTENSIONS:
try:
token_count = estimate_image_tokens_for_upload(upload)
except (UnidentifiedImageError, OSError) as e:
@@ -172,7 +173,10 @@ def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
continue
# Otherwise, handle as text/document: extract text and count tokens
elif extension in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
if (
extension in ALL_ACCEPTED_FILE_EXTENSIONS
and extension not in ACCEPTED_IMAGE_FILE_EXTENSIONS
):
text_content = extract_file_text(
file=upload.file,
file_name=filename,

View File

@@ -30,7 +30,7 @@ from onyx.db.models import User
from onyx.file_store.file_store import get_default_file_store
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.llm.utils import test_llm
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.server.manage.models import BoostDoc
@@ -121,7 +121,7 @@ def validate_existing_genai_api_key(
pass
try:
llm = get_default_llm(timeout=10)
llm, __ = get_default_llms(timeout=10)
except ValueError:
raise HTTPException(status_code=404, detail="LLM not setup")

View File

@@ -1,4 +1,5 @@
import os
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
@@ -31,13 +32,14 @@ from onyx.db.llm import upsert_llm_provider
from onyx.db.llm import validate_persona_ids_exist
from onyx.db.models import User
from onyx.db.persona import user_can_access_persona
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llm
from onyx.llm.factory import get_max_input_tokens_from_llm_provider
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from onyx.llm.utils import get_bedrock_token_limit
from onyx.llm.utils import get_llm_contextual_cost
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.llm.utils import model_supports_image_input
from onyx.llm.utils import test_llm
from onyx.server.manage.llm.models import BedrockFinalModelResponse
@@ -62,6 +64,7 @@ from onyx.server.manage.llm.utils import is_valid_bedrock_model
from onyx.server.manage.llm.utils import ModelMetadata
from onyx.server.manage.llm.utils import strip_openrouter_vendor_prefix
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
@@ -89,7 +92,7 @@ def test_llm_configuration(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
"""Test LLM configuration settings"""
"""Test regular llm and fast llm settings"""
# the api key is sanitized if we are testing a provider already in the system
@@ -120,10 +123,36 @@ def test_llm_configuration(
max_input_tokens=max_input_tokens,
)
error_msg = test_llm(llm)
functions_with_args: list[tuple[Callable, tuple]] = [(test_llm, (llm,))]
if (
test_llm_request.fast_default_model_name
and test_llm_request.fast_default_model_name
!= test_llm_request.default_model_name
):
fast_llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.fast_default_model_name,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
custom_config=test_llm_request.custom_config,
deployment_name=test_llm_request.deployment_name,
max_input_tokens=max_input_tokens,
)
functions_with_args.append((test_llm, (fast_llm,)))
if error_msg:
raise HTTPException(status_code=400, detail=error_msg)
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
error = parallel_results[0] or (
parallel_results[1] if len(parallel_results) > 1 else None
)
if error:
client_error_msg, _error_code, _is_retryable = litellm_exception_to_error_msg(
error, llm, fallback_to_error_msg=True
)
raise HTTPException(status_code=400, detail=client_error_msg)
@admin_router.post("/test/default")
@@ -131,12 +160,21 @@ def test_default_provider(
_: User | None = Depends(current_admin_user),
) -> None:
try:
llm = get_default_llm()
llm, fast_llm = get_default_llms()
except ValueError:
logger.exception("Failed to fetch default LLM Provider")
raise HTTPException(status_code=400, detail="No LLM Provider setup")
error = test_llm(llm)
functions_with_args: list[tuple[Callable, tuple]] = [
(test_llm, (llm,)),
(test_llm, (fast_llm,)),
]
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
error = parallel_results[0] or (
parallel_results[1] if len(parallel_results) > 1 else None
)
if error:
raise HTTPException(status_code=400, detail=str(error))
@@ -216,18 +254,34 @@ def put_llm_provider(
llm_provider_upsert_request.personas = deduplicated_personas
default_model_found = False
default_fast_model_found = False
for model_configuration in llm_provider_upsert_request.model_configurations:
if model_configuration.name == llm_provider_upsert_request.default_model_name:
model_configuration.is_visible = True
default_model_found = True
if (
llm_provider_upsert_request.fast_default_model_name
and llm_provider_upsert_request.fast_default_model_name
== model_configuration.name
):
model_configuration.is_visible = True
default_fast_model_found = True
default_inserts = set()
if not default_model_found:
llm_provider_upsert_request.model_configurations.append(
ModelConfigurationUpsertRequest(
name=llm_provider_upsert_request.default_model_name, is_visible=True
)
)
default_inserts.add(llm_provider_upsert_request.default_model_name)
if (
llm_provider_upsert_request.fast_default_model_name
and not default_fast_model_found
):
default_inserts.add(llm_provider_upsert_request.fast_default_model_name)
llm_provider_upsert_request.model_configurations.extend(
ModelConfigurationUpsertRequest(name=name, is_visible=True)
for name in default_inserts
)
# the llm api key is sanitized when returned to clients, so the only time we
# should get a real key is when it is explicitly changed

View File

@@ -32,6 +32,7 @@ class TestLLMRequest(BaseModel):
# model level
default_model_name: str
fast_default_model_name: str | None = None
deployment_name: str | None = None
model_configurations: list["ModelConfigurationUpsertRequest"]
@@ -54,6 +55,7 @@ class LLMProviderDescriptor(BaseModel):
provider: str
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
default_model_name: str
fast_default_model_name: str | None
is_default_provider: bool | None
is_default_vision_provider: bool | None
default_vision_model: str | None
@@ -73,6 +75,7 @@ class LLMProviderDescriptor(BaseModel):
provider=provider,
provider_display_name=get_provider_display_name(provider),
default_model_name=llm_provider_model.default_model_name,
fast_default_model_name=llm_provider_model.fast_default_model_name,
is_default_provider=llm_provider_model.is_default_provider,
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
default_vision_model=llm_provider_model.default_vision_model,
@@ -90,6 +93,7 @@ class LLMProvider(BaseModel):
api_version: str | None = None
custom_config: dict[str, str] | None = None
default_model_name: str
fast_default_model_name: str | None = None
is_public: bool = True
groups: list[int] = Field(default_factory=list)
personas: list[int] = Field(default_factory=list)
@@ -146,6 +150,7 @@ class LLMProviderView(LLMProvider):
api_version=llm_provider_model.api_version,
custom_config=llm_provider_model.custom_config,
default_model_name=llm_provider_model.default_model_name,
fast_default_model_name=llm_provider_model.fast_default_model_name,
is_default_provider=llm_provider_model.is_default_provider,
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
default_vision_model=llm_provider_model.default_vision_model,

View File

@@ -30,7 +30,7 @@ from onyx.server.manage.validate_tokens import validate_app_token
from onyx.server.manage.validate_tokens import validate_bot_token
from onyx.server.manage.validate_tokens import validate_user_token
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import get_current_tenant_id
SLACK_API_CHANNELS_PER_PAGE = 100
@@ -274,10 +274,12 @@ def create_bot(
is_default=True,
)
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=tenant_id,
event=MilestoneRecordType.CREATED_ONYX_BOT,
create_milestone_and_report(
user=None,
distinct_id=tenant_id or "N/A",
event_type=MilestoneRecordType.CREATED_ONYX_BOT,
properties=None,
db_session=db_session,
)
return SlackBot.from_model(slack_bot_model)

View File

@@ -203,17 +203,10 @@ def list_accepted_users(
@router.get("/manage/users/invited")
def list_invited_users(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[InvitedUserSnapshot]:
invited_emails = get_invited_users()
# Filter out users who are already active in the system
active_user_emails = {user.email for user in get_all_users(db_session)}
filtered_invited_emails = [
email for email in invited_emails if email not in active_user_emails
]
return [InvitedUserSnapshot(email=email) for email in filtered_invited_emails]
return [InvitedUserSnapshot(email=email) for email in invited_emails]
@router.get("/manage/users")
@@ -238,13 +231,6 @@ def list_all_users(
accepted_emails = {user.email for user in accepted_users}
slack_users_emails = {user.email for user in slack_users}
invited_emails = get_invited_users()
# Filter out users who are already active (either accepted or slack users)
all_active_emails = accepted_emails | slack_users_emails
invited_emails = [
email for email in invited_emails if email not in all_active_emails
]
if q:
invited_emails = [
email for email in invited_emails if re.search(r"{}".format(q), email, re.I)

View File

@@ -14,10 +14,8 @@ from onyx.db.web_search import deactivate_web_search_provider
from onyx.db.web_search import delete_web_content_provider
from onyx.db.web_search import delete_web_search_provider
from onyx.db.web_search import fetch_web_content_provider_by_name
from onyx.db.web_search import fetch_web_content_provider_by_type
from onyx.db.web_search import fetch_web_content_providers
from onyx.db.web_search import fetch_web_search_provider_by_name
from onyx.db.web_search import fetch_web_search_provider_by_type
from onyx.db.web_search import fetch_web_search_providers
from onyx.db.web_search import set_active_web_content_provider
from onyx.db.web_search import set_active_web_search_provider
@@ -149,33 +147,11 @@ def deactivate_search_provider(
def test_search_provider(
request: WebSearchProviderTestRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
provider_requires_api_key = request.provider_type != WebSearchProviderType.SEARXNG
# Determine which API key to use
api_key = request.api_key
if request.use_stored_key and provider_requires_api_key:
existing_provider = fetch_web_search_provider_by_type(
request.provider_type, db_session
)
if existing_provider is None or not existing_provider.api_key:
raise HTTPException(
status_code=400,
detail="No stored API key found for this provider type.",
)
api_key = existing_provider.api_key
if provider_requires_api_key and not api_key:
raise HTTPException(
status_code=400,
detail="API key is required. Either provide api_key or set use_stored_key to true.",
)
try:
provider = build_search_provider_from_config(
provider_type=request.provider_type,
api_key=api_key or "",
api_key=request.api_key,
config=request.config or {},
)
except ValueError as exc:
@@ -315,31 +291,11 @@ def deactivate_content_provider(
def test_content_provider(
request: WebContentProviderTestRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
# Determine which API key to use
api_key = request.api_key
if request.use_stored_key:
existing_provider = fetch_web_content_provider_by_type(
request.provider_type, db_session
)
if existing_provider is None or not existing_provider.api_key:
raise HTTPException(
status_code=400,
detail="No stored API key found for this provider type.",
)
api_key = existing_provider.api_key
if not api_key:
raise HTTPException(
status_code=400,
detail="API key is required. Either provide api_key or set use_stored_key to true.",
)
try:
provider = build_content_provider_from_config(
provider_type=request.provider_type,
api_key=api_key,
api_key=request.api_key,
config=request.config,
)
except ValueError as exc:

View File

@@ -60,25 +60,11 @@ class WebContentProviderUpsertRequest(BaseModel):
class WebSearchProviderTestRequest(BaseModel):
provider_type: WebSearchProviderType
api_key: str | None = Field(
default=None,
description="API key for testing. If not provided, use_stored_key must be true.",
)
use_stored_key: bool = Field(
default=False,
description="If true, use the stored API key for this provider type instead of api_key.",
)
api_key: str
config: dict[str, Any] | None = None
class WebContentProviderTestRequest(BaseModel):
provider_type: WebContentProviderType
api_key: str | None = Field(
default=None,
description="API key for testing. If not provided, use_stored_key must be true.",
)
use_stored_key: bool = Field(
default=False,
description="If true, use the stored API key for this provider type instead of api_key.",
)
api_key: str
config: WebContentProviderConfig

View File

@@ -44,6 +44,7 @@ from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.feedback import create_chat_message_feedback
from onyx.db.feedback import create_doc_retrieval_feedback
from onyx.db.feedback import remove_chat_message_feedback
@@ -54,9 +55,9 @@ from onyx.db.projects import check_project_ownership
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.file_processing.extract_file_text import docx_to_txt_filename
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.factory import get_llms_for_persona
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.redis.redis_pool import get_redis_client
from onyx.secondary_llm_flows.chat_session_naming import (
@@ -88,7 +89,7 @@ from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.token_limit import check_token_rate_limits
from onyx.utils.headers import get_custom_tool_additional_request_headers
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -116,7 +117,7 @@ def _get_available_tokens_for_persona(
- default_reserved_tokens
)
llm = get_llm_for_persona(persona=persona, user=user)
llm, _ = get_llms_for_persona(persona=persona, user=user)
token_counter = get_llm_token_counter(llm)
system_prompt = get_default_base_system_prompt(db_session)
@@ -353,7 +354,7 @@ def rename_chat_session(
chat_session_id=chat_session_id, db_session=db_session
)
llm = get_default_llm(
llm, _ = get_default_llms(
additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
)
@@ -450,11 +451,14 @@ def handle_new_chat_message(
if not chat_message_req.message and not chat_message_req.use_existing_user_message:
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=user.email if user else tenant_id,
event=MilestoneRecordType.RAN_QUERY,
)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=user,
distinct_id=user.email if user else tenant_id or "N/A",
event_type=MilestoneRecordType.RAN_QUERY,
properties=None,
db_session=db_session,
)
def stream_generator() -> Generator[str, None, None]:
try:
@@ -666,7 +670,7 @@ def seed_chat(
root_message = get_or_create_root_message(
chat_session_id=new_chat_session.id, db_session=db_session
)
llm = get_llm_for_persona(
llm, _fast_llm = get_llms_for_persona(
persona=new_chat_session.persona,
user=user,
)

View File

@@ -1,18 +1,18 @@
from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_store.models import ChatFileType
from onyx.utils.file_types import UploadMimeTypes
def mime_type_to_chat_file_type(mime_type: str | None) -> ChatFileType:
if mime_type is None:
return ChatFileType.PLAIN_TEXT
if mime_type in OnyxMimeTypes.IMAGE_MIME_TYPES:
if mime_type in UploadMimeTypes.IMAGE_MIME_TYPES:
return ChatFileType.IMAGE
if mime_type in OnyxMimeTypes.CSV_MIME_TYPES:
if mime_type in UploadMimeTypes.CSV_MIME_TYPES:
return ChatFileType.CSV
if mime_type in OnyxMimeTypes.DOCUMENT_MIME_TYPES:
if mime_type in UploadMimeTypes.DOCUMENT_MIME_TYPES:
return ChatFileType.DOC
return ChatFileType.PLAIN_TEXT
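A minimal usage sketch of the helper above, assumed to run in the same module; the exact MIME strings registered in UploadMimeTypes are assumptions, not taken from this diff:

from onyx.file_store.models import ChatFileType

# Assuming "image/png" and "text/csv" are members of the corresponding
# UploadMimeTypes sets; unknown or missing MIME types fall back to plain text.
assert mime_type_to_chat_file_type("image/png") == ChatFileType.IMAGE
assert mime_type_to_chat_file_type("text/csv") == ChatFileType.CSV
assert mime_type_to_chat_file_type(None) == ChatFileType.PLAIN_TEXT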

View File

@@ -48,8 +48,9 @@ from onyx.db.search_settings import get_current_search_settings
from onyx.db.tag import find_tags
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.vespa.index import VespaIndex
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.models import AdminSearchRequest
from onyx.server.query_and_chat.models import AdminSearchResponse
@@ -107,7 +108,7 @@ def handle_search_request(
query = search_request.message
logger.notice(f"Received document search query: {query}")
llm = get_default_llm()
llm, _ = get_default_llms()
pagination_limit, pagination_offset = _normalize_pagination(
limit=search_request.retrieval_options.limit,
offset=search_request.retrieval_options.offset,
@@ -228,7 +229,7 @@ def get_answer_stream(
is_for_edit=False,
)
llm = get_llm_for_persona(persona=persona_info, user=user)
llm = get_main_llm_from_tuple(get_llms_for_persona(persona=persona_info, user=user))
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,

View File

@@ -124,14 +124,13 @@ def create_reasoning_packets(reasoning_text: str, turn_index: int) -> list[Packe
def create_image_generation_packets(
images: list[GeneratedImage], turn_index: int, tab_index: int = 0
images: list[GeneratedImage], turn_index: int
) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(
turn_index=turn_index,
tab_index=tab_index,
obj=ImageGenerationToolStart(),
)
)
@@ -139,12 +138,11 @@ def create_image_generation_packets(
packets.append(
Packet(
turn_index=turn_index,
tab_index=tab_index,
obj=ImageGenerationFinal(images=images),
),
)
packets.append(Packet(turn_index=turn_index, tab_index=tab_index, obj=SectionEnd()))
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
return packets
@@ -153,7 +151,6 @@ def create_custom_tool_packets(
tool_name: str,
response_type: str,
turn_index: int,
tab_index: int = 0,
data: dict | list | str | int | float | bool | None = None,
file_ids: list[str] | None = None,
) -> list[Packet]:
@@ -162,7 +159,6 @@ def create_custom_tool_packets(
packets.append(
Packet(
turn_index=turn_index,
tab_index=tab_index,
obj=CustomToolStart(tool_name=tool_name),
)
)
@@ -170,7 +166,6 @@ def create_custom_tool_packets(
packets.append(
Packet(
turn_index=turn_index,
tab_index=tab_index,
obj=CustomToolDelta(
tool_name=tool_name,
response_type=response_type,
@@ -180,7 +175,7 @@ def create_custom_tool_packets(
),
)
packets.append(Packet(turn_index=turn_index, tab_index=tab_index, obj=SectionEnd()))
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
return packets
@@ -189,14 +184,12 @@ def create_fetch_packets(
fetch_docs: list[SavedSearchDoc],
urls: list[str],
turn_index: int,
tab_index: int = 0,
) -> list[Packet]:
packets: list[Packet] = []
# Emit start packet
packets.append(
Packet(
turn_index=turn_index,
tab_index=tab_index,
obj=OpenUrlStart(),
)
)
@@ -204,7 +197,6 @@ def create_fetch_packets(
packets.append(
Packet(
turn_index=turn_index,
tab_index=tab_index,
obj=OpenUrlUrls(urls=urls),
)
)
@@ -212,13 +204,12 @@ def create_fetch_packets(
packets.append(
Packet(
turn_index=turn_index,
tab_index=tab_index,
obj=OpenUrlDocuments(
documents=[SearchDoc(**doc.model_dump()) for doc in fetch_docs]
),
)
)
packets.append(Packet(turn_index=turn_index, tab_index=tab_index, obj=SectionEnd()))
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
return packets
@@ -227,14 +218,12 @@ def create_search_packets(
search_docs: list[SavedSearchDoc],
is_internet_search: bool,
turn_index: int,
tab_index: int = 0,
) -> list[Packet]:
packets: list[Packet] = []
packets.append(
Packet(
turn_index=turn_index,
tab_index=tab_index,
obj=SearchToolStart(
is_internet_search=is_internet_search,
),
@@ -246,7 +235,6 @@ def create_search_packets(
packets.append(
Packet(
turn_index=turn_index,
tab_index=tab_index,
obj=SearchToolQueriesDelta(queries=search_queries),
),
)
@@ -259,7 +247,6 @@ def create_search_packets(
packets.append(
Packet(
turn_index=turn_index,
tab_index=tab_index,
obj=SearchToolDocumentsDelta(
documents=[
SearchDoc(**doc.model_dump()) for doc in sorted_search_docs
@@ -268,7 +255,7 @@ def create_search_packets(
),
)
packets.append(Packet(turn_index=turn_index, tab_index=tab_index, obj=SectionEnd()))
packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
return packets
@@ -324,7 +311,6 @@ def translate_assistant_message_to_packets(
is_internet_search=tool.in_code_tool_id
== WebSearchTool.__name__,
turn_index=turn_num,
tab_index=tool_call.tab_index,
)
)
@@ -338,12 +324,7 @@ def translate_assistant_message_to_packets(
list[str], tool_call.tool_call_arguments.get("urls", [])
)
packet_list.extend(
create_fetch_packets(
fetch_docs,
urls,
turn_num,
tab_index=tool_call.tab_index,
)
create_fetch_packets(fetch_docs, urls, turn_num)
)
elif tool.in_code_tool_id == ImageGenerationTool.__name__:
@@ -353,9 +334,7 @@ def translate_assistant_message_to_packets(
for img in tool_call.generated_images
]
packet_list.extend(
create_image_generation_packets(
images, turn_num, tab_index=tool_call.tab_index
)
create_image_generation_packets(images, turn_num)
)
else:
@@ -365,7 +344,6 @@ def translate_assistant_message_to_packets(
tool_name=tool.display_name or tool.name,
response_type="text",
turn_index=turn_num,
tab_index=tool_call.tab_index,
data=tool_call.tool_call_response,
)
)
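For orientation, a rough sketch (not part of the diff) of what the simplified create_custom_tool_packets emits once tab_index is gone and only turn_index is carried on each packet; the tool name and data are illustrative:

packets = create_custom_tool_packets(
    tool_name="my_tool",   # hypothetical tool name
    response_type="text",
    turn_index=2,
    data={"result": "ok"},
)
# Expected shape: CustomToolStart, CustomToolDelta, SectionEnd, all on turn 2
assert [type(p.obj).__name__ for p in packets] == [
    "CustomToolStart",
    "CustomToolDelta",
    "SectionEnd",
]
assert all(p.turn_index == 2 for p in packets)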

View File

@@ -36,16 +36,15 @@ class StreamingType(Enum):
DEEP_RESEARCH_PLAN_START = "deep_research_plan_start"
DEEP_RESEARCH_PLAN_DELTA = "deep_research_plan_delta"
RESEARCH_AGENT_START = "research_agent_start"
class BaseObj(BaseModel):
type: str = ""
################################################
# Reasoning Packets
################################################
"""Reasoning Packets"""
# Tells the frontend to display the reasoning block
class ReasoningStart(BaseObj):
type: Literal["reasoning_start"] = StreamingType.REASONING_START.value
@@ -62,9 +61,9 @@ class ReasoningDone(BaseObj):
type: Literal["reasoning_done"] = StreamingType.REASONING_DONE.value
################################################
# Final Agent Response Packets
################################################
"""Final Agent Response Packets"""
# Start of the final answer
class AgentResponseStart(BaseObj):
type: Literal["message_start"] = StreamingType.MESSAGE_START.value
@@ -91,9 +90,9 @@ class CitationInfo(BaseObj):
document_id: str
################################################
# Control Packets
################################################
"""Control Packets"""
# This one isn't strictly necessary, remove in the future
class SectionEnd(BaseObj):
type: Literal["section_end"] = "section_end"
@@ -110,9 +109,9 @@ class OverallStop(BaseObj):
type: Literal["stop"] = StreamingType.STOP.value
################################################
# Tool Packets
################################################
"""Tool Packets"""
# Search tool is called and the UI block needs to start
class SearchToolStart(BaseObj):
type: Literal["search_tool_start"] = StreamingType.SEARCH_TOOL_START.value
@@ -226,9 +225,6 @@ class CustomToolDelta(BaseObj):
file_ids: list[str] | None = None
################################################
# Deep Research Packets
################################################
class DeepResearchPlanStart(BaseObj):
type: Literal["deep_research_plan_start"] = (
StreamingType.DEEP_RESEARCH_PLAN_START.value
@@ -243,14 +239,8 @@ class DeepResearchPlanDelta(BaseObj):
content: str
class ResearchAgentStart(BaseObj):
type: Literal["research_agent_start"] = StreamingType.RESEARCH_AGENT_START.value
research_task: str
"""Packet"""
################################################
# Packet Object
################################################
# Discriminated union of all possible packet object types
PacketObj = Union[
# Agent Response Packets
@@ -284,24 +274,15 @@ PacketObj = Union[
# Deep Research Packets
DeepResearchPlanStart,
DeepResearchPlanDelta,
ResearchAgentStart,
]
class Placement(BaseModel):
# Which iterative block in the UI this is part of; these are ordered, and smaller indices happened first
turn_index: int
# For parallel tool calls to preserve order of execution
tab_index: int
# Used for tools/agents that call other tools; this currently doesn't support nested agents but can be added later
sub_turn_index: int
class Packet(BaseModel):
turn_index: int | None
# For parallel tool calls to preserve order of execution
tab_index: int = 0
# Used for tools/agents that call other tools; this currently doesn't support nested agents but can be added later
sub_turn_index: int | None = None
obj: Annotated[PacketObj, Field(discriminator="type")]
# This is for replaying it back from the DB to the frontend
class EndStepPacketList(BaseModel):
turn_index: int
packet_list: list[Packet]
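A quick sketch of how the trimmed-down Packet is meant to be constructed after this change (values are illustrative; only turn_index and obj remain):

start = Packet(turn_index=0, obj=SearchToolStart(is_internet_search=False))
end = Packet(turn_index=0, obj=SectionEnd())
stop = Packet(turn_index=None, obj=OverallStop())

# The discriminator ("type") survives serialization, so the frontend can
# switch on obj.type when replaying packets from the DB.
print(start.model_dump_json())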

View File

@@ -15,10 +15,9 @@ class PageType(str, Enum):
class ApplicationStatus(str, Enum):
ACTIVE = "active"
PAYMENT_REMINDER = "payment_reminder"
GRACE_PERIOD = "grace_period"
GATED_ACCESS = "gated_access"
ACTIVE = "active"
class Notification(BaseModel):

View File

@@ -11,6 +11,7 @@ from onyx.configs.constants import KV_REINDEX_KEY
from onyx.configs.constants import KV_SEARCH_SETTINGS
from onyx.configs.embedding_configs import SUPPORTED_EMBEDDING_MODELS
from onyx.configs.embedding_configs import SupportedEmbeddingModel
from onyx.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
from onyx.configs.model_configs import GEN_AI_API_KEY
from onyx.configs.model_configs import GEN_AI_MODEL_VERSION
from onyx.context.search.models import SavedSearchSettings
@@ -294,6 +295,7 @@ def setup_postgres(db_session: Session) -> None:
logger.notice("Setting up default OpenAI LLM for dev.")
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
fast_model = FAST_GEN_AI_MODEL_VERSION or "gpt-4o-mini"
model_req = LLMProviderUpsertRequest(
name="DevEnvPresetOpenAI",
provider="openai",
@@ -302,6 +304,7 @@ def setup_postgres(db_session: Session) -> None:
api_version=None,
custom_config=None,
default_model_name=llm_model,
fast_default_model_name=fast_model,
is_public=True,
groups=[],
model_configurations=[

View File

@@ -0,0 +1,3 @@
from typing import Any
BUILT_IN_TOOL_MAP_V2: dict[str, Any] = {}

View File

@@ -1,432 +0,0 @@
from collections.abc import Callable
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import create_tool_call_failure_messages
from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.emitter import Emitter
from onyx.chat.llm_loop import construct_message_history
from onyx.chat.llm_step import run_llm_step
from onyx.chat.models import ChatMessageSimple
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchDocsResponse
from onyx.deep_research.dr_mock_tools import (
get_research_agent_additional_tool_definitions,
)
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_TASK_KEY
from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_MESSAGE
from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_TOKEN_COUNT
from onyx.deep_research.models import ResearchAgentCallResult
from onyx.deep_research.utils import check_special_tool_calls
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.models import ReasoningEffort
from onyx.llm.models import ToolChoiceOptions
from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
from onyx.prompts.deep_research.dr_tool_prompts import OPEN_URLS_TOOL_DESCRIPTION
from onyx.prompts.deep_research.dr_tool_prompts import (
OPEN_URLS_TOOL_DESCRIPTION_REASONING,
)
from onyx.prompts.deep_research.dr_tool_prompts import WEB_SEARCH_TOOL_DESCRIPTION
from onyx.prompts.deep_research.research_agent import RESEARCH_AGENT_PROMPT
from onyx.prompts.deep_research.research_agent import RESEARCH_AGENT_PROMPT_REASONING
from onyx.prompts.deep_research.research_agent import RESEARCH_REPORT_PROMPT
from onyx.prompts.deep_research.research_agent import USER_REPORT_QUERY
from onyx.prompts.prompt_utils import get_current_llm_day_time
from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import Placement
from onyx.server.query_and_chat.streaming_models import ResearchAgentStart
from onyx.tools.models import ToolCallInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from onyx.tools.tool_runner import run_tool_calls
from onyx.tools.utils import generate_tools_description
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
RESEARCH_CYCLE_CAP = 3
def generate_intermediate_report(
research_topic: str,
history: list[ChatMessageSimple],
llm: LLM,
token_counter: Callable[[str], int],
user_identity: LLMUserIdentity | None,
state_container: ChatStateContainer,
emitter: Emitter,
placement: Placement,
) -> str:
system_prompt = ChatMessageSimple(
message=RESEARCH_REPORT_PROMPT,
token_count=token_counter(RESEARCH_REPORT_PROMPT),
message_type=MessageType.SYSTEM,
)
reminder_str = USER_REPORT_QUERY.format(research_topic=research_topic)
reminder_message = ChatMessageSimple(
message=reminder_str,
token_count=token_counter(reminder_str),
message_type=MessageType.USER,
)
research_history = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=history,
reminder_message=reminder_message,
project_files=None,
available_tokens=llm.config.max_input_tokens,
)
llm_step_result, _ = run_llm_step(
emitter=emitter,
history=research_history,
tool_definitions=[],
tool_choice=ToolChoiceOptions.NONE,
llm=llm,
turn_index=999, # TODO
citation_processor=DynamicCitationProcessor(),
state_container=state_container,
reasoning_effort=ReasoningEffort.LOW,
final_documents=None,
user_identity=user_identity,
)
final_report = llm_step_result.answer
if final_report is None:
raise ValueError(
f"LLM failed to generate a report for research task: {research_topic}"
)
# emitter.emit(
# Packet(
# obj=ResearchAgentStart(research_task=research_topic),
# placement=placement,
# )
# )
return final_report
def run_research_agent_call(
research_agent_call: ToolCallKickoff,
tools: list[Tool],
emitter: Emitter,
state_container: ChatStateContainer,
llm: LLM,
is_reasoning_model: bool,
token_counter: Callable[[str], int],
user_identity: LLMUserIdentity | None,
) -> ResearchAgentCallResult:
cycle_count = 0
llm_cycle_count = 0
current_tools = tools
gathered_documents: list[SearchDoc] | None = None
reasoning_cycles = 0
just_ran_web_search = False
turn_index = research_agent_call.turn_index
tab_index = research_agent_call.tab_index
# If this fails to parse, we can't run the loop anyway; let this one fail in that case
research_topic = research_agent_call.tool_args[RESEARCH_AGENT_TASK_KEY]
emitter.emit(
Packet(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=0,
obj=ResearchAgentStart(research_task=research_topic),
)
)
initial_user_message = ChatMessageSimple(
message=research_topic,
token_count=token_counter(research_topic),
message_type=MessageType.USER,
)
msg_history: list[ChatMessageSimple] = [initial_user_message]
citation_mapping: dict[int, str] = {}
while cycle_count <= RESEARCH_CYCLE_CAP:
if cycle_count == RESEARCH_CYCLE_CAP:
current_tools = [
tool
for tool in tools
if tool.name not in {SearchTool.NAME, WebSearchTool.NAME}
]
tools_by_name = {tool.name: tool for tool in current_tools}
tools_description = generate_tools_description(current_tools)
internal_search_tip = (
INTERNAL_SEARCH_GUIDANCE
if any(isinstance(tool, SearchTool) for tool in current_tools)
else ""
)
web_search_tip = (
WEB_SEARCH_TOOL_DESCRIPTION
if any(isinstance(tool, WebSearchTool) for tool in current_tools)
else ""
)
open_urls_tip = (
OPEN_URLS_TOOL_DESCRIPTION
if any(isinstance(tool, OpenURLTool) for tool in current_tools)
else ""
)
if is_reasoning_model and open_urls_tip:
open_urls_tip = OPEN_URLS_TOOL_DESCRIPTION_REASONING
system_prompt_template = (
RESEARCH_AGENT_PROMPT_REASONING
if is_reasoning_model
else RESEARCH_AGENT_PROMPT
)
system_prompt_str = system_prompt_template.format(
available_tools=tools_description,
current_datetime=get_current_llm_day_time(full_sentence=False),
current_cycle_count=cycle_count,
optional_internal_search_tool_description=internal_search_tip,
optional_web_search_tool_description=web_search_tip,
optional_open_urls_tool_description=open_urls_tip,
)
system_prompt = ChatMessageSimple(
message=system_prompt_str,
token_count=token_counter(system_prompt_str),
message_type=MessageType.SYSTEM,
)
if just_ran_web_search:
reminder_message = ChatMessageSimple(
message=OPEN_URL_REMINDER,
token_count=token_counter(OPEN_URL_REMINDER),
message_type=MessageType.USER,
)
else:
reminder_message = None
constructed_history = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=msg_history,
reminder_message=reminder_message,
project_files=None,
available_tokens=llm.config.max_input_tokens,
)
research_agent_tools = get_research_agent_additional_tool_definitions(
include_think_tool=not is_reasoning_model
)
llm_step_result, has_reasoned = run_llm_step(
emitter=emitter,
history=constructed_history,
tool_definitions=[tool.tool_definition() for tool in current_tools]
+ research_agent_tools,
tool_choice=ToolChoiceOptions.REQUIRED,
llm=llm,
turn_index=llm_cycle_count + reasoning_cycles,
citation_processor=DynamicCitationProcessor(),
state_container=state_container,
reasoning_effort=ReasoningEffort.LOW,
final_documents=None,
user_identity=user_identity,
)
if has_reasoned:
reasoning_cycles += 1
tool_responses: list[ToolResponse] = []
tool_calls = llm_step_result.tool_calls or []
just_ran_web_search = False
special_tool_calls = check_special_tool_calls(tool_calls=tool_calls)
if special_tool_calls.generate_report_tool_call:
final_report = generate_intermediate_report(
research_topic=research_topic,
history=msg_history,
llm=llm,
token_counter=token_counter,
user_identity=user_identity,
state_container=state_container,
emitter=emitter,
placement=Placement(
turn_index=turn_index, tab_index=tab_index, sub_turn_index=0
), # TODO
)
return ResearchAgentCallResult(
intermediate_report=final_report, search_docs=[]
)
elif special_tool_calls.think_tool_call:
think_tool_call = special_tool_calls.think_tool_call
tool_call_message = think_tool_call.to_msg_str()
think_tool_msg = ChatMessageSimple(
message=tool_call_message,
token_count=token_counter(tool_call_message),
message_type=MessageType.TOOL_CALL,
tool_call_id=think_tool_call.tool_call_id,
image_files=None,
)
msg_history.append(think_tool_msg)
think_tool_response_msg = ChatMessageSimple(
message=THINK_TOOL_RESPONSE_MESSAGE,
token_count=THINK_TOOL_RESPONSE_TOKEN_COUNT,
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=think_tool_call.tool_call_id,
image_files=None,
)
msg_history.append(think_tool_response_msg)
reasoning_cycles += 1
continue
else:
tool_responses, citation_mapping = run_tool_calls(
tool_calls=tool_calls,
tools=current_tools,
message_history=msg_history,
memories=None,
user_info=None,
citation_mapping=citation_mapping,
citation_processor=DynamicCitationProcessor(),
# May be better to not do this step, hard to say, needs to be tested
skip_search_query_expansion=False,
)
if tool_calls and not tool_responses:
failure_messages = create_tool_call_failure_messages(
tool_calls[0], token_counter
)
msg_history.extend(failure_messages)
continue
for tool_response in tool_responses:
# Extract tool_call from the response (set by run_tool_calls)
if tool_response.tool_call is None:
raise ValueError("Tool response missing tool_call reference")
tool_call = tool_response.tool_call
tab_index = tool_call.tab_index
tool = tools_by_name.get(tool_call.tool_name)
if not tool:
raise ValueError(
f"Tool '{tool_call.tool_name}' not found in tools list"
)
# Extract search_docs if this is a search tool response
search_docs = None
if isinstance(tool_response.rich_response, SearchDocsResponse):
search_docs = tool_response.rich_response.search_docs
if gathered_documents:
gathered_documents.extend(search_docs)
else:
gathered_documents = search_docs
# This is used for the Open URL reminder in the next cycle
# only do this if the web search tool yielded results
if search_docs and tool_call.tool_name == WebSearchTool.NAME:
just_ran_web_search = True
tool_call_info = ToolCallInfo(
parent_tool_call_id=None, # TODO
turn_index=llm_cycle_count
+ reasoning_cycles, # TODO (subturn index also)
tab_index=tab_index,
tool_name=tool_call.tool_name,
tool_call_id=tool_call.tool_call_id,
tool_id=tool.id,
reasoning_tokens=llm_step_result.reasoning,
tool_call_arguments=tool_call.tool_args,
tool_call_response=tool_response.llm_facing_response,
search_docs=search_docs,
generated_images=None,
)
# Add to state container for partial save support
state_container.add_tool_call(tool_call_info)
# Store tool call with function name and arguments in separate layers
tool_call_message = tool_call.to_msg_str()
tool_call_token_count = token_counter(tool_call_message)
tool_call_msg = ChatMessageSimple(
message=tool_call_message,
token_count=tool_call_token_count,
message_type=MessageType.TOOL_CALL,
tool_call_id=tool_call.tool_call_id,
image_files=None,
)
msg_history.append(tool_call_msg)
tool_response_message = tool_response.llm_facing_response
tool_response_token_count = token_counter(tool_response_message)
tool_response_msg = ChatMessageSimple(
message=tool_response_message,
token_count=tool_response_token_count,
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,
)
msg_history.append(tool_response_msg)
llm_cycle_count += 1
# If we've run out of cycles, just try to generate a report from everything so far
final_report = generate_intermediate_report(
research_topic=research_topic,
history=msg_history,
llm=llm,
token_counter=token_counter,
user_identity=user_identity,
state_container=state_container,
emitter=emitter,
placement=Placement(
turn_index=turn_index, tab_index=tab_index, sub_turn_index=0
), # TODO
)
return ResearchAgentCallResult(intermediate_report=final_report, search_docs=[])
def run_research_agent_calls(
research_agent_calls: list[ToolCallKickoff],
tools: list[Tool],
emitter: Emitter,
state_container: ChatStateContainer,
llm: LLM,
is_reasoning_model: bool,
token_counter: Callable[[str], int],
user_identity: LLMUserIdentity | None = None,
) -> list[ResearchAgentCallResult]:
# Run all research agent calls in parallel
functions_with_args = [
(
run_research_agent_call,
(
research_agent_call,
tools,
emitter,
state_container,
llm,
is_reasoning_model,
token_counter,
user_identity,
),
)
for research_agent_call in research_agent_calls
]
return run_functions_tuples_in_parallel(
functions_with_args,
allow_failures=True, # Continue even if some research agent calls fail
)

View File

@@ -0,0 +1,27 @@
from typing import Any
from pydantic import BaseModel
from onyx.tools.tool import Tool
class ForceUseTool(BaseModel):
# The tool may not be forced but still have args; in that case, if the tool
# is called, those args are applied instead of the ones the LLM wanted to
# call it with
force_use: bool
tool_name: str
args: dict[str, Any] | None = None
def build_openai_tool_choice_dict(self) -> dict[str, Any]:
"""Build dict in the format that OpenAI expects which tells them to use this tool."""
return {"type": "function", "name": self.tool_name}
def filter_tools_for_force_tool_use(
tools: list[Tool], force_use_tool: ForceUseTool
) -> list[Tool]:
if not force_use_tool.force_use:
return tools
return [tool for tool in tools if tool.name == force_use_tool.tool_name]
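A hypothetical usage sketch of the new module; the tool name and the all_tools list are made up for illustration:

force = ForceUseTool(force_use=True, tool_name="run_search", args={"query": "onyx"})

# Passed along to the LLM call as the tool_choice payload
tool_choice = force.build_openai_tool_choice_dict()
# -> {"type": "function", "name": "run_search"}

# Only the forced tool is kept; with force_use=False the list is returned unchanged
remaining = filter_tools_for_force_tool_use(all_tools, force)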

View File

@@ -0,0 +1,46 @@
import json
from typing import Any
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.tool import ToolCall
from langchain_core.messages.tool import ToolMessage
from pydantic import BaseModel
from pydantic import ConfigDict
from onyx.natural_language_processing.utils import BaseTokenizer
# Langchain has its own version of pydantic, which is version 1
def build_tool_message(
tool_call: ToolCall, tool_content: str | list[str | dict[str, Any]]
) -> ToolMessage:
return ToolMessage(
tool_call_id=tool_call["id"] or "",
name=tool_call["name"],
content=tool_content,
)
class ToolCallSummary(BaseModel):
tool_call_request: AIMessage
tool_call_result: ToolMessage
# This is a workaround to allow arbitrary types in the model
# TODO: Remove this once we have a better solution
model_config = ConfigDict(arbitrary_types_allowed=True)
def tool_call_tokens(
tool_call_summary: ToolCallSummary, llm_tokenizer: BaseTokenizer
) -> int:
request_tokens = len(
llm_tokenizer.encode(
json.dumps(tool_call_summary.tool_call_request.tool_calls[0]["args"])
)
)
result_tokens = len(
llm_tokenizer.encode(json.dumps(tool_call_summary.tool_call_result.content))
)
return request_tokens + result_tokens
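A rough sketch of counting the tokens for one tool-call round trip; the tool name, message contents, and the llm_tokenizer instance (a BaseTokenizer) are placeholders:

request = AIMessage(
    content="",
    tool_calls=[{"id": "call_1", "name": "run_search", "args": {"query": "onyx"}}],
)
summary = ToolCallSummary(
    tool_call_request=request,
    tool_call_result=build_tool_message(request.tool_calls[0], "no results found"),
)
# Counts the serialized tool-call args plus the serialized tool result
total = tool_call_tokens(summary, llm_tokenizer)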

Some files were not shown because too many files have changed in this diff