Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-20 01:05:46 +00:00)

Compare commits: jamison/me...dump-scrip (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | ca3db17b08 |  |
|  | ffd13b1104 |  |
.github/workflows/pr-jest-tests.yml (vendored, 7 changed lines)
@@ -4,14 +4,7 @@ concurrency:
```yaml
  cancel-in-progress: true

on:
  merge_group:
  pull_request:
    branches:
      - main
      - "release/**"
  push:
    tags:
      - "v*.*.*"

permissions:
  contents: read
```
.github/workflows/pr-playwright-tests.yml (vendored, 7 changed lines)
@@ -4,14 +4,7 @@ concurrency:
```yaml
  cancel-in-progress: true

on:
  merge_group:
  pull_request:
    branches:
      - main
      - "release/**"
  push:
    tags:
      - "v*.*.*"

permissions:
  contents: read
```
@@ -8,6 +8,10 @@ repos:
```yaml
    # From: https://github.com/astral-sh/uv-pre-commit/pull/53/commits/d30b4298e4fb63ce8609e29acdbcf4c9018a483c
    rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
    hooks:
      - id: uv-run
        name: Check lazy imports
        args: ["--with=onyx-devtools", "ods", "check-lazy-imports"]
        files: ^backend/(?!\.venv/).*\.py$
      - id: uv-sync
        args: ["--locked", "--all-extras"]
      - id: uv-lock
```

@@ -28,10 +32,6 @@ repos:
```yaml
        name: uv-export model_server.txt
        args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "model_server", "-o", "backend/requirements/model_server.txt"]
        files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
      - id: uv-run
        name: Check lazy imports
        args: ["--with=onyx-devtools", "ods", "check-lazy-imports"]
        files: ^backend/(?!\.venv/).*\.py$
      # NOTE: This takes ~6s on a single, large module which is prohibitively slow.
      # - id: uv-run
      #   name: mypy
```
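The `files` pattern on the check-lazy-imports hook uses a negative lookahead to exclude the virtualenv. A quick sketch of how that pattern behaves (pre-commit matches `files` with Python regex semantics; the paths below are made up for illustration):

```python
import re

# Same pattern as the hook's `files` key: match backend Python files,
# but refuse anything whose path starts with backend/.venv/.
pattern = re.compile(r"^backend/(?!\.venv/).*\.py$")

assert pattern.match("backend/onyx/utils/logger.py")   # regular backend module: matched
assert not pattern.match("backend/.venv/lib/site.py")  # inside the venv: skipped
assert not pattern.match("web/src/index.ts")           # not backend, not .py: skipped
```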
@@ -1,27 +0,0 @@
```python
"""add last refreshed at mcp server

Revision ID: 2a391f840e85
Revises: 4cebcbc9b2ae
Create Date: 2025-12-06 15:19:59.766066

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "2a391f840e85"
down_revision = "4cebcbc9b2ae"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "mcp_server",
        sa.Column("last_refreshed_at", sa.DateTime(timezone=True), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("mcp_server", "last_refreshed_at")
```
@@ -1,27 +0,0 @@
```python
"""add tab_index to tool_call

Revision ID: 4cebcbc9b2ae
Revises: a1b2c3d4e5f6
Create Date: 2025-12-16

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "4cebcbc9b2ae"
down_revision = "a1b2c3d4e5f6"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "tool_call",
        sa.Column("tab_index", sa.Integer(), nullable=False, server_default="0"),
    )


def downgrade() -> None:
    op.drop_column("tool_call", "tab_index")
```
@@ -1,49 +0,0 @@
```python
"""add license table

Revision ID: a1b2c3d4e5f6
Revises: a01bf2971c5d
Create Date: 2025-12-04 10:00:00.000000

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f6"
down_revision = "a01bf2971c5d"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "license",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("license_data", sa.Text(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
    )

    # Singleton pattern - only ever one row in this table
    op.create_index(
        "idx_license_singleton",
        "license",
        [sa.text("(true)")],
        unique=True,
    )


def downgrade() -> None:
    op.drop_index("idx_license_singleton", table_name="license")
    op.drop_table("license")
```
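The `idx_license_singleton` index enforces the one-row invariant at the database level: every row indexes the same constant expression `(true)`, so a second insert collides with the first on the unique key. A minimal sketch of the effect (hypothetical connection string; the expression index is Postgres-specific):

```python
import sqlalchemy as sa

engine = sa.create_engine("postgresql://localhost/onyx_example")  # hypothetical DSN

with engine.begin() as conn:
    conn.execute(sa.text("INSERT INTO license (license_data) VALUES ('blob-one')"))
    # Both rows map to the same index key, true, so this second insert
    # raises a unique-violation (surfaced as sqlalchemy.exc.IntegrityError).
    conn.execute(sa.text("INSERT INTO license (license_data) VALUES ('blob-two')"))
```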
@@ -1,27 +0,0 @@
```python
"""Remove fast_default_model_name from llm_provider

Revision ID: a2b3c4d5e6f7
Revises: 2a391f840e85
Create Date: 2024-12-17

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "a2b3c4d5e6f7"
down_revision = "2a391f840e85"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.drop_column("llm_provider", "fast_default_model_name")


def downgrade() -> None:
    op.add_column(
        "llm_provider",
        sa.Column("fast_default_model_name", sa.String(), nullable=True),
    )
```
@@ -1,46 +0,0 @@
```python
"""Drop milestone table

Revision ID: b8c9d0e1f2a3
Revises: a2b3c4d5e6f7
Create Date: 2025-12-18

"""

from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision = "b8c9d0e1f2a3"
down_revision = "a2b3c4d5e6f7"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_table("milestone")


def downgrade() -> None:
    op.create_table(
        "milestone",
        sa.Column("id", sa.UUID(), nullable=False),
        sa.Column("tenant_id", sa.String(), nullable=True),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=True,
        ),
        sa.Column("event_type", sa.String(), nullable=False),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("event_tracker", postgresql.JSONB(), nullable=True),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("event_type", name="uq_milestone_event_type"),
    )
```
@@ -1,52 +0,0 @@
```python
"""add_deep_research_tool

Revision ID: c1d2e3f4a5b6
Revises: b8c9d0e1f2a3
Create Date: 2025-12-18 16:00:00.000000

"""

from alembic import op
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "c1d2e3f4a5b6"
down_revision = "b8c9d0e1f2a3"
branch_labels = None
depends_on = None


DEEP_RESEARCH_TOOL = {
    "name": RESEARCH_AGENT_DB_NAME,
    "display_name": "Research Agent",
    "description": "The Research Agent is a sub-agent that conducts research on a specific topic.",
    "in_code_tool_id": "ResearchAgent",
}


def upgrade() -> None:
    conn = op.get_bind()
    conn.execute(
        sa.text(
            """
            INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
            VALUES (:name, :display_name, :description, :in_code_tool_id, false)
            """
        ),
        DEEP_RESEARCH_TOOL,
    )


def downgrade() -> None:
    conn = op.get_bind()
    conn.execute(
        sa.text(
            """
            DELETE FROM tool
            WHERE in_code_tool_id = :in_code_tool_id
            """
        ),
        {"in_code_tool_id": DEEP_RESEARCH_TOOL["in_code_tool_id"]},
    )
```
@@ -1,278 +0,0 @@
```python
"""Database and cache operations for the license table."""

from datetime import datetime

from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session

from ee.onyx.server.license.models import LicenseMetadata
from ee.onyx.server.license.models import LicensePayload
from ee.onyx.server.license.models import LicenseSource
from onyx.db.models import License
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

LICENSE_METADATA_KEY = "license:metadata"
LICENSE_CACHE_TTL_SECONDS = 86400  # 24 hours


# -----------------------------------------------------------------------------
# Database CRUD Operations
# -----------------------------------------------------------------------------


def get_license(db_session: Session) -> License | None:
    """
    Get the current license (singleton pattern - only one row).

    Args:
        db_session: Database session

    Returns:
        License object if exists, None otherwise
    """
    return db_session.execute(select(License)).scalars().first()


def upsert_license(db_session: Session, license_data: str) -> License:
    """
    Insert or update the license (singleton pattern).

    Args:
        db_session: Database session
        license_data: Base64-encoded signed license blob

    Returns:
        The created or updated License object
    """
    existing = get_license(db_session)

    if existing:
        existing.license_data = license_data
        db_session.commit()
        db_session.refresh(existing)
        logger.info("License updated")
        return existing

    new_license = License(license_data=license_data)
    db_session.add(new_license)
    db_session.commit()
    db_session.refresh(new_license)
    logger.info("License created")
    return new_license


def delete_license(db_session: Session) -> bool:
    """
    Delete the current license.

    Args:
        db_session: Database session

    Returns:
        True if deleted, False if no license existed
    """
    existing = get_license(db_session)
    if existing:
        db_session.delete(existing)
        db_session.commit()
        logger.info("License deleted")
        return True
    return False


# -----------------------------------------------------------------------------
# Seat Counting
# -----------------------------------------------------------------------------


def get_used_seats(tenant_id: str | None = None) -> int:
    """
    Get current seat usage.

    For multi-tenant: counts users in UserTenantMapping for this tenant.
    For self-hosted: counts all active users (includes both Onyx UI users
    and Slack users who have been converted to Onyx users).
    """
    if MULTI_TENANT:
        from ee.onyx.server.tenants.user_mapping import get_tenant_count

        return get_tenant_count(tenant_id or get_current_tenant_id())
    else:
        # Self-hosted: count all active users (Onyx + converted Slack users)
        from onyx.db.engine.sql_engine import get_session_with_current_tenant

        with get_session_with_current_tenant() as db_session:
            result = db_session.execute(
                select(func.count()).select_from(User).where(User.is_active)  # type: ignore
            )
            return result.scalar() or 0


# -----------------------------------------------------------------------------
# Redis Cache Operations
# -----------------------------------------------------------------------------


def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata | None:
    """
    Get license metadata from Redis cache.

    Args:
        tenant_id: Tenant ID (for multi-tenant deployments)

    Returns:
        LicenseMetadata if cached, None otherwise
    """
    tenant = tenant_id or get_current_tenant_id()
    redis_client = get_redis_replica_client(tenant_id=tenant)

    cached = redis_client.get(LICENSE_METADATA_KEY)
    if cached:
        try:
            cached_str: str
            if isinstance(cached, bytes):
                cached_str = cached.decode("utf-8")
            else:
                cached_str = str(cached)
            return LicenseMetadata.model_validate_json(cached_str)
        except Exception as e:
            logger.warning(f"Failed to parse cached license metadata: {e}")
            return None
    return None


def invalidate_license_cache(tenant_id: str | None = None) -> None:
    """
    Invalidate the license metadata cache (not the license itself).

    This deletes the cached LicenseMetadata from Redis. The actual license
    in the database is not affected. Redis delete is idempotent - if the
    key doesn't exist, this is a no-op.

    Args:
        tenant_id: Tenant ID (for multi-tenant deployments)
    """
    tenant = tenant_id or get_current_tenant_id()
    redis_client = get_redis_client(tenant_id=tenant)
    redis_client.delete(LICENSE_METADATA_KEY)
    logger.info("License cache invalidated")


def update_license_cache(
    payload: LicensePayload,
    source: LicenseSource | None = None,
    grace_period_end: datetime | None = None,
    tenant_id: str | None = None,
) -> LicenseMetadata:
    """
    Update the Redis cache with license metadata.

    We cache all license statuses (ACTIVE, GRACE_PERIOD, GATED_ACCESS) because:
    1. Frontend needs status to show appropriate UI/banners
    2. Caching avoids repeated DB + crypto verification on every request
    3. Status enforcement happens at the feature level, not here

    Args:
        payload: Verified license payload
        source: How the license was obtained
        grace_period_end: Optional grace period end time
        tenant_id: Tenant ID (for multi-tenant deployments)

    Returns:
        The cached LicenseMetadata
    """
    from ee.onyx.utils.license import get_license_status

    tenant = tenant_id or get_current_tenant_id()
    redis_client = get_redis_client(tenant_id=tenant)

    used_seats = get_used_seats(tenant)
    status = get_license_status(payload, grace_period_end)

    metadata = LicenseMetadata(
        tenant_id=payload.tenant_id,
        organization_name=payload.organization_name,
        seats=payload.seats,
        used_seats=used_seats,
        plan_type=payload.plan_type,
        issued_at=payload.issued_at,
        expires_at=payload.expires_at,
        grace_period_end=grace_period_end,
        status=status,
        source=source,
        stripe_subscription_id=payload.stripe_subscription_id,
    )

    redis_client.setex(
        LICENSE_METADATA_KEY,
        LICENSE_CACHE_TTL_SECONDS,
        metadata.model_dump_json(),
    )

    logger.info(f"License cache updated: {metadata.seats} seats, status={status.value}")
    return metadata


def refresh_license_cache(
    db_session: Session,
    tenant_id: str | None = None,
) -> LicenseMetadata | None:
    """
    Refresh the license cache from the database.

    Args:
        db_session: Database session
        tenant_id: Tenant ID (for multi-tenant deployments)

    Returns:
        LicenseMetadata if license exists, None otherwise
    """
    from ee.onyx.utils.license import verify_license_signature

    license_record = get_license(db_session)
    if not license_record:
        invalidate_license_cache(tenant_id)
        return None

    try:
        payload = verify_license_signature(license_record.license_data)
        return update_license_cache(
            payload,
            source=LicenseSource.AUTO_FETCH,
            tenant_id=tenant_id,
        )
    except ValueError as e:
        logger.error(f"Failed to verify license during cache refresh: {e}")
        invalidate_license_cache(tenant_id)
        return None


def get_license_metadata(
    db_session: Session,
    tenant_id: str | None = None,
) -> LicenseMetadata | None:
    """
    Get license metadata, using cache if available.

    Args:
        db_session: Database session
        tenant_id: Tenant ID (for multi-tenant deployments)

    Returns:
        LicenseMetadata if license exists, None otherwise
    """
    # Try cache first
    cached = get_cached_license_metadata(tenant_id)
    if cached:
        return cached

    # Refresh from database
    return refresh_license_cache(db_session, tenant_id)
```
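Putting the read path together: callers go through `get_license_metadata`, which serves from the Redis cache when possible and otherwise verifies the stored blob and re-caches it for 24 hours. A hedged usage sketch (the session helper is the one the module imports lazily; the surrounding context is made up):

```python
from onyx.db.engine.sql_engine import get_session_with_current_tenant

# Cache hit: one Redis round trip. Cache miss: DB read + RSA signature
# verification via refresh_license_cache, then re-cached with the TTL above.
with get_session_with_current_tenant() as db_session:
    metadata = get_license_metadata(db_session)
    if metadata is None:
        print("no license installed")
    else:
        print(f"{metadata.used_seats}/{metadata.seats} seats, status={metadata.status.value}")
```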
```diff
@@ -14,7 +14,6 @@ from ee.onyx.server.enterprise_settings.api import (
     basic_router as enterprise_settings_router,
 )
 from ee.onyx.server.evals.api import router as evals_router
-from ee.onyx.server.license.api import router as license_router
 from ee.onyx.server.manage.standard_answer import router as standard_answer_router
 from ee.onyx.server.middleware.tenant_tracking import (
     add_api_server_tenant_id_middleware,
```

```diff
@@ -140,8 +139,6 @@ def get_application() -> FastAPI:
     )
     include_router_with_global_prefix_prepended(application, enterprise_settings_router)
     include_router_with_global_prefix_prepended(application, usage_export_router)
-    # License management
-    include_router_with_global_prefix_prepended(application, license_router)
 
     if MULTI_TENANT:
         # Tenant management
```
@@ -1,246 +0,0 @@
```python
"""License API endpoints."""

import requests
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import HTTPException
from fastapi import UploadFile
from sqlalchemy.orm import Session

from ee.onyx.auth.users import current_admin_user
from ee.onyx.db.license import delete_license as db_delete_license
from ee.onyx.db.license import get_license_metadata
from ee.onyx.db.license import invalidate_license_cache
from ee.onyx.db.license import refresh_license_cache
from ee.onyx.db.license import update_license_cache
from ee.onyx.db.license import upsert_license
from ee.onyx.server.license.models import LicenseResponse
from ee.onyx.server.license.models import LicenseSource
from ee.onyx.server.license.models import LicenseStatusResponse
from ee.onyx.server.license.models import LicenseUploadResponse
from ee.onyx.server.license.models import SeatUsageResponse
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.utils.license import verify_license_signature
from onyx.auth.users import User
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.db.engine.sql_engine import get_session
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

router = APIRouter(prefix="/license")


@router.get("")
async def get_license_status(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> LicenseStatusResponse:
    """Get current license status and seat usage."""
    metadata = get_license_metadata(db_session)

    if not metadata:
        return LicenseStatusResponse(has_license=False)

    return LicenseStatusResponse(
        has_license=True,
        seats=metadata.seats,
        used_seats=metadata.used_seats,
        plan_type=metadata.plan_type,
        issued_at=metadata.issued_at,
        expires_at=metadata.expires_at,
        grace_period_end=metadata.grace_period_end,
        status=metadata.status,
        source=metadata.source,
    )


@router.get("/seats")
async def get_seat_usage(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> SeatUsageResponse:
    """Get detailed seat usage information."""
    metadata = get_license_metadata(db_session)

    if not metadata:
        return SeatUsageResponse(
            total_seats=0,
            used_seats=0,
            available_seats=0,
        )

    return SeatUsageResponse(
        total_seats=metadata.seats,
        used_seats=metadata.used_seats,
        available_seats=max(0, metadata.seats - metadata.used_seats),
    )


@router.post("/fetch")
async def fetch_license(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> LicenseResponse:
    """
    Fetch license from control plane.
    Used after Stripe checkout completion to retrieve the new license.
    """
    tenant_id = get_current_tenant_id()

    try:
        token = generate_data_plane_token()
    except ValueError as e:
        logger.error(f"Failed to generate data plane token: {e}")
        raise HTTPException(
            status_code=500, detail="Authentication configuration error"
        )

    try:
        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        url = f"{CONTROL_PLANE_API_BASE_URL}/license/{tenant_id}"
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()

        data = response.json()
        if not isinstance(data, dict) or "license" not in data:
            raise HTTPException(
                status_code=502, detail="Invalid response from control plane"
            )

        license_data = data["license"]
        if not license_data:
            raise HTTPException(status_code=404, detail="No license found")

        # Verify signature before persisting
        payload = verify_license_signature(license_data)

        # Verify the fetched license is for this tenant
        if payload.tenant_id != tenant_id:
            logger.error(
                f"License tenant mismatch: expected {tenant_id}, got {payload.tenant_id}"
            )
            raise HTTPException(
                status_code=400,
                detail="License tenant ID mismatch - control plane returned wrong license",
            )

        # Persist to DB and update cache atomically
        upsert_license(db_session, license_data)
        try:
            update_license_cache(payload, source=LicenseSource.AUTO_FETCH)
        except Exception as cache_error:
            # Log but don't fail - DB is source of truth, cache will refresh on next read
            logger.warning(f"Failed to update license cache: {cache_error}")

        return LicenseResponse(success=True, license=payload)

    except requests.HTTPError as e:
        status_code = e.response.status_code if e.response is not None else 502
        logger.error(f"Control plane returned error: {status_code}")
        raise HTTPException(
            status_code=status_code,
            detail="Failed to fetch license from control plane",
        )
    except ValueError as e:
        logger.error(f"License verification failed: {type(e).__name__}")
        raise HTTPException(status_code=400, detail=str(e))
    except requests.RequestException:
        logger.exception("Failed to fetch license from control plane")
        raise HTTPException(
            status_code=502, detail="Failed to connect to control plane"
        )


@router.post("/upload")
async def upload_license(
    license_file: UploadFile = File(...),
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> LicenseUploadResponse:
    """
    Upload a license file manually.
    Used for air-gapped deployments where control plane is not accessible.
    """
    try:
        content = await license_file.read()
        license_data = content.decode("utf-8").strip()
    except UnicodeDecodeError:
        raise HTTPException(status_code=400, detail="Invalid license file format")

    try:
        payload = verify_license_signature(license_data)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    tenant_id = get_current_tenant_id()
    if payload.tenant_id != tenant_id:
        raise HTTPException(
            status_code=400,
            detail=f"License tenant ID mismatch. Expected {tenant_id}, got {payload.tenant_id}",
        )

    # Persist to DB and update cache
    upsert_license(db_session, license_data)
    try:
        update_license_cache(payload, source=LicenseSource.MANUAL_UPLOAD)
    except Exception as cache_error:
        # Log but don't fail - DB is source of truth, cache will refresh on next read
        logger.warning(f"Failed to update license cache: {cache_error}")

    return LicenseUploadResponse(
        success=True,
        message=f"License uploaded successfully. {payload.seats} seats, expires {payload.expires_at.date()}",
    )


@router.post("/refresh")
async def refresh_license_cache_endpoint(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> LicenseStatusResponse:
    """
    Force refresh the license cache from the database.
    Useful after manual database changes or to verify license validity.
    """
    metadata = refresh_license_cache(db_session)

    if not metadata:
        return LicenseStatusResponse(has_license=False)

    return LicenseStatusResponse(
        has_license=True,
        seats=metadata.seats,
        used_seats=metadata.used_seats,
        plan_type=metadata.plan_type,
        issued_at=metadata.issued_at,
        expires_at=metadata.expires_at,
        grace_period_end=metadata.grace_period_end,
        status=metadata.status,
        source=metadata.source,
    )


@router.delete("")
async def delete_license(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict[str, bool]:
    """
    Delete the current license.
    Admin only - removes license and invalidates cache.
    """
    # Invalidate cache first - if DB delete fails, stale cache is worse than no cache
    try:
        invalidate_license_cache()
    except Exception as cache_error:
        logger.warning(f"Failed to invalidate license cache: {cache_error}")

    deleted = db_delete_license(db_session)

    return {"deleted": deleted}
```
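For orientation, a hedged sketch of how an admin client might exercise these endpoints with `requests`. The base URL and the authenticated session are assumptions (the real deployment may prepend a global prefix, and `current_admin_user` expects an admin auth cookie or token):

```python
import requests

BASE_URL = "http://localhost:8080/license"  # hypothetical deployment URL
session = requests.Session()  # assumed to already carry admin credentials

status = session.get(BASE_URL).json()               # GET ""       -> LicenseStatusResponse
seats = session.get(f"{BASE_URL}/seats").json()     # GET "/seats" -> SeatUsageResponse

# Connected flow: pull the license from the control plane after checkout.
fetched = session.post(f"{BASE_URL}/fetch").json()

# Air-gapped flow: upload a license file instead.
with open("license.key", "rb") as f:
    uploaded = session.post(f"{BASE_URL}/upload", files={"license_file": f}).json()
```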
@@ -1,92 +0,0 @@
```python
from datetime import datetime
from enum import Enum

from pydantic import BaseModel

from onyx.server.settings.models import ApplicationStatus


class PlanType(str, Enum):
    MONTHLY = "monthly"
    ANNUAL = "annual"


class LicenseSource(str, Enum):
    AUTO_FETCH = "auto_fetch"
    MANUAL_UPLOAD = "manual_upload"


class LicensePayload(BaseModel):
    """The payload portion of a signed license."""

    version: str
    tenant_id: str
    organization_name: str | None = None
    issued_at: datetime
    expires_at: datetime
    seats: int
    plan_type: PlanType
    billing_cycle: str | None = None
    grace_period_days: int = 30
    stripe_subscription_id: str | None = None
    stripe_customer_id: str | None = None


class LicenseData(BaseModel):
    """Full signed license structure."""

    payload: LicensePayload
    signature: str


class LicenseMetadata(BaseModel):
    """Cached license metadata stored in Redis."""

    tenant_id: str
    organization_name: str | None = None
    seats: int
    used_seats: int
    plan_type: PlanType
    issued_at: datetime
    expires_at: datetime
    grace_period_end: datetime | None = None
    status: ApplicationStatus
    source: LicenseSource | None = None
    stripe_subscription_id: str | None = None


class LicenseStatusResponse(BaseModel):
    """Response for license status API."""

    has_license: bool
    seats: int = 0
    used_seats: int = 0
    plan_type: PlanType | None = None
    issued_at: datetime | None = None
    expires_at: datetime | None = None
    grace_period_end: datetime | None = None
    status: ApplicationStatus | None = None
    source: LicenseSource | None = None


class LicenseResponse(BaseModel):
    """Response after license fetch/upload."""

    success: bool
    message: str | None = None
    license: LicensePayload | None = None


class LicenseUploadResponse(BaseModel):
    """Response after license upload."""

    success: bool
    message: str | None = None


class SeatUsageResponse(BaseModel):
    """Response for seat usage API."""

    total_seats: int
    used_seats: int
    available_seats: int
```
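The wire format these models imply is a base64 blob wrapping `{"payload": ..., "signature": ...}`, which is exactly what `verify_license_signature` decodes. A hedged fixture sketch (every value below is made up):

```python
import base64
import json
from datetime import datetime, timezone

from ee.onyx.server.license.models import LicensePayload
from ee.onyx.server.license.models import PlanType

payload = LicensePayload(
    version="1",
    tenant_id="tenant-abc",  # must match the deployment's tenant ID
    issued_at=datetime(2025, 1, 1, tzinfo=timezone.utc),
    expires_at=datetime(2026, 1, 1, tzinfo=timezone.utc),
    seats=25,
    plan_type=PlanType.ANNUAL,
)

# The envelope the verifier expects; the signature here is a placeholder.
blob = base64.b64encode(
    json.dumps(
        {"payload": payload.model_dump(mode="json"), "signature": "<base64-sig>"}
    ).encode()
).decode()
```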
```diff
@@ -20,7 +20,7 @@ from onyx.db.chat import create_new_chat_message
 from onyx.db.chat import get_or_create_root_message
 from onyx.db.engine.sql_engine import get_session
 from onyx.db.models import User
-from onyx.llm.factory import get_llm_for_persona
+from onyx.llm.factory import get_llms_for_persona
 from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.server.query_and_chat.models import CreateChatMessageRequest
 from onyx.utils.logger import setup_logger
```

```diff
@@ -158,7 +158,7 @@ def handle_send_message_simple_with_history(
         persona_id=req.persona_id,
     )
 
-    llm = get_llm_for_persona(persona=chat_session.persona, user=user)
+    llm, _ = get_llms_for_persona(persona=chat_session.persona, user=user)
 
     llm_tokenizer = get_tokenizer(
         model_name=llm.config.model_name,
```
```diff
@@ -45,7 +45,7 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRe
 from onyx.server.manage.llm.models import LLMProviderUpsertRequest
 from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
 from onyx.setup import setup_onyx
-from onyx.utils.telemetry import mt_cloud_telemetry
+from onyx.utils.telemetry import create_milestone_and_report
 from shared_configs.configs import MULTI_TENANT
 from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
 from shared_configs.configs import TENANT_ID_PREFIX
```

```diff
@@ -269,6 +269,7 @@ def configure_default_api_keys(db_session: Session) -> None:
         provider=ANTHROPIC_PROVIDER_NAME,
         api_key=ANTHROPIC_DEFAULT_API_KEY,
         default_model_name="claude-3-7-sonnet-20250219",
+        fast_default_model_name="claude-3-5-sonnet-20241022",
         model_configurations=[
             ModelConfigurationUpsertRequest(
                 name=name,
```

```diff
@@ -295,6 +296,7 @@ def configure_default_api_keys(db_session: Session) -> None:
         provider=OPENAI_PROVIDER_NAME,
         api_key=OPENAI_DEFAULT_API_KEY,
         default_model_name="gpt-4o",
+        fast_default_model_name="gpt-4o-mini",
         model_configurations=[
             ModelConfigurationUpsertRequest(
                 name=model_name,
```

```diff
@@ -560,11 +562,17 @@ async def assign_tenant_to_user(
     try:
         add_users_to_tenant([email], tenant_id)
 
-        mt_cloud_telemetry(
-            tenant_id=tenant_id,
-            distinct_id=email,
-            event=MilestoneRecordType.TENANT_CREATED,
-        )
+        # Create milestone record in the same transaction context as the tenant assignment
+        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+            create_milestone_and_report(
+                user=None,
+                distinct_id=tenant_id,
+                event_type=MilestoneRecordType.TENANT_CREATED,
+                properties={
+                    "email": email,
+                },
+                db_session=db_session,
+            )
     except Exception:
         logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
         raise Exception("Failed to assign tenant to user")
```
```diff
@@ -249,17 +249,6 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
         )
         raise
 
-    # Remove from invited users list since they've accepted
-    token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
-    try:
-        invited_users = get_invited_users()
-        if email in invited_users:
-            invited_users.remove(email)
-            write_invited_users(invited_users)
-            logger.info(f"Removed {email} from invited users list after acceptance")
-    finally:
-        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
-
 
 def deny_user_invite(email: str, tenant_id: str) -> None:
     """
```
@@ -1,126 +0,0 @@
```python
"""RSA-4096 license signature verification utilities."""

import base64
import json
import os
from datetime import datetime
from datetime import timezone

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey

from ee.onyx.server.license.models import LicenseData
from ee.onyx.server.license.models import LicensePayload
from onyx.server.settings.models import ApplicationStatus
from onyx.utils.logger import setup_logger

logger = setup_logger()


# RSA-4096 Public Key for license verification
# Load from environment variable - key is generated on the control plane
# In production, inject via Kubernetes secrets or secrets manager
LICENSE_PUBLIC_KEY_PEM = os.environ.get("LICENSE_PUBLIC_KEY_PEM", "")


def _get_public_key() -> RSAPublicKey:
    """Load the public key from environment variable."""
    if not LICENSE_PUBLIC_KEY_PEM:
        raise ValueError(
            "LICENSE_PUBLIC_KEY_PEM environment variable not set. "
            "License verification requires the control plane public key."
        )
    key = serialization.load_pem_public_key(LICENSE_PUBLIC_KEY_PEM.encode())
    if not isinstance(key, RSAPublicKey):
        raise ValueError("Expected RSA public key")
    return key


def verify_license_signature(license_data: str) -> LicensePayload:
    """
    Verify RSA-4096 signature and return payload if valid.

    Args:
        license_data: Base64-encoded JSON containing payload and signature

    Returns:
        LicensePayload if signature is valid

    Raises:
        ValueError: If license data is invalid or signature verification fails
    """
    try:
        # Decode the license data
        decoded = json.loads(base64.b64decode(license_data))
        license_obj = LicenseData(**decoded)

        payload_json = json.dumps(
            license_obj.payload.model_dump(mode="json"), sort_keys=True
        )
        signature_bytes = base64.b64decode(license_obj.signature)

        # Verify signature using PSS padding (modern standard)
        public_key = _get_public_key()
        public_key.verify(
            signature_bytes,
            payload_json.encode(),
            padding.PSS(
                mgf=padding.MGF1(hashes.SHA256()),
                salt_length=padding.PSS.MAX_LENGTH,
            ),
            hashes.SHA256(),
        )

        return license_obj.payload

    except InvalidSignature:
        logger.error("License signature verification failed")
        raise ValueError("Invalid license signature")
    except json.JSONDecodeError:
        logger.error("Failed to decode license JSON")
        raise ValueError("Invalid license format: not valid JSON")
    except (ValueError, KeyError, TypeError) as e:
        logger.error(f"License data validation error: {type(e).__name__}")
        raise ValueError(f"Invalid license format: {type(e).__name__}")
    except Exception:
        logger.exception("Unexpected error during license verification")
        raise ValueError("License verification failed: unexpected error")


def get_license_status(
    payload: LicensePayload,
    grace_period_end: datetime | None = None,
) -> ApplicationStatus:
    """
    Determine current license status based on expiry.

    Args:
        payload: The verified license payload
        grace_period_end: Optional grace period end datetime

    Returns:
        ApplicationStatus indicating current license state
    """
    now = datetime.now(timezone.utc)

    # Check if grace period has expired
    if grace_period_end and now > grace_period_end:
        return ApplicationStatus.GATED_ACCESS

    # Check if license has expired
    if now > payload.expires_at:
        if grace_period_end and now <= grace_period_end:
            return ApplicationStatus.GRACE_PERIOD
        return ApplicationStatus.GATED_ACCESS

    # License is valid
    return ApplicationStatus.ACTIVE


def is_license_valid(payload: LicensePayload) -> bool:
    """Check if a license is currently valid (not expired)."""
    now = datetime.now(timezone.utc)
    return now <= payload.expires_at
```
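For context, a sketch of what the matching signing step on the control plane might look like: canonical JSON with sorted keys (so it byte-matches what `verify_license_signature` reconstructs), RSA-PSS over SHA-256, then the payload and signature wrapped in a base64 envelope. This is an illustration, not the control plane's actual code; the key-loading details are assumptions.

```python
import base64
import json

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding


def sign_license(payload: dict, private_key_pem: bytes) -> str:
    """Hypothetical control-plane counterpart to verify_license_signature."""
    private_key = serialization.load_pem_private_key(private_key_pem, password=None)

    # Must match the verifier exactly: canonical JSON with sorted keys.
    payload_json = json.dumps(payload, sort_keys=True)

    signature = private_key.sign(
        payload_json.encode(),
        padding.PSS(
            mgf=padding.MGF1(hashes.SHA256()),
            salt_length=padding.PSS.MAX_LENGTH,
        ),
        hashes.SHA256(),
    )

    # Wrap in the {"payload": ..., "signature": ...} envelope the verifier expects.
    envelope = {"payload": payload, "signature": base64.b64encode(signature).decode()}
    return base64.b64encode(json.dumps(envelope).encode()).decode()
```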
```diff
@@ -117,7 +117,7 @@ from onyx.redis.redis_pool import get_async_redis_connection
 from onyx.redis.redis_pool import get_redis_client
 from onyx.server.utils import BasicAuthenticationError
 from onyx.utils.logger import setup_logger
-from onyx.utils.telemetry import mt_cloud_telemetry
+from onyx.utils.telemetry import create_milestone_and_report
 from onyx.utils.telemetry import optional_telemetry
 from onyx.utils.telemetry import RecordType
 from onyx.utils.timing import log_function_time
```

```diff
@@ -653,11 +653,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
                 user_count = await get_user_count()
                 logger.debug(f"Current tenant user count: {user_count}")
 
-                mt_cloud_telemetry(
-                    tenant_id=tenant_id,
-                    distinct_id=user.email,
-                    event=MilestoneRecordType.USER_SIGNED_UP,
-                )
+                with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+                    event_type = (
+                        MilestoneRecordType.USER_SIGNED_UP
+                        if user_count == 1
+                        else MilestoneRecordType.MULTIPLE_USERS
+                    )
+                    create_milestone_and_report(
+                        user=user,
+                        distinct_id=user.email,
+                        event_type=event_type,
+                        properties=None,
+                        db_session=db_session,
+                    )
 
             finally:
                 CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
```
```diff
@@ -45,7 +45,6 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
 from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
 from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
-from onyx.configs.constants import MilestoneRecordType
 from onyx.configs.constants import OnyxCeleryPriority
 from onyx.configs.constants import OnyxCeleryQueues
 from onyx.configs.constants import OnyxCeleryTask
```

```diff
@@ -109,7 +108,6 @@ from onyx.redis.redis_utils import is_fence
 from onyx.server.runtime.onyx_runtime import OnyxRuntime
 from onyx.utils.logger import setup_logger
 from onyx.utils.middleware import make_randomized_onyx_request_id
-from onyx.utils.telemetry import mt_cloud_telemetry
 from onyx.utils.telemetry import optional_telemetry
 from onyx.utils.telemetry import RecordType
 from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
```

```diff
@@ -549,12 +547,6 @@ def check_indexing_completion(
         )
         db_session.commit()
 
-        mt_cloud_telemetry(
-            tenant_id=tenant_id,
-            distinct_id=tenant_id,
-            event=MilestoneRecordType.CONNECTOR_SUCCEEDED,
-        )
-
         # Clear repeated error state on success
         if cc_pair.in_repeated_error_state:
             cc_pair.in_repeated_error_state = False
```
```diff
@@ -127,6 +127,12 @@ def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
                 f"available, setting to first model in list: {available_models[0]}"
             )
             default_provider.default_model_name = available_models[0]
+        if default_provider.fast_default_model_name not in available_models:
+            task_logger.info(
+                f"Fast default model {default_provider.fast_default_model_name} "
+                f"not available, setting to first model in list: {available_models[0]}"
+            )
+            default_provider.fast_default_model_name = available_models[0]
         db_session.commit()
 
     if added_models or removed_models:
```
@@ -1,6 +1,7 @@
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -20,6 +21,7 @@ from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
|
||||
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
|
||||
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -30,8 +32,11 @@ from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
from onyx.connectors.models import DocExtractionContext
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
|
||||
@@ -44,16 +49,34 @@ from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
|
||||
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from onyx.db.index_attempt import mark_attempt_succeeded
|
||||
from onyx.db.index_attempt import transition_attempt_to_in_progress
|
||||
from onyx.db.index_attempt import update_docs_indexed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.adapters.document_indexing_adapter import (
|
||||
DocumentIndexingBatchAdapter,
|
||||
)
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
|
||||
@@ -249,6 +272,583 @@ def _check_failure_threshold(
|
||||
)
|
||||
|
||||
|
||||
# NOTE: this is the old run_indexing function that the new decoupled approach
|
||||
# is based on. Leaving this for comparison purposes, but if you see this comment
|
||||
# has been here for >2 month, please delete this function.
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
tenant_id: str,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
1. Get documents which are either new or updated from specified application
|
||||
2. Embed and index these documents into the chosen datastore (vespa)
|
||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||
"""
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
|
||||
start_time = time.monotonic() # jsut used for logging
|
||||
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
index_attempt_start = get_index_attempt(
|
||||
db_session_temp,
|
||||
index_attempt_id,
|
||||
eager_load_cc_pair=True,
|
||||
eager_load_search_settings=True,
|
||||
)
|
||||
if not index_attempt_start:
|
||||
raise ValueError(
|
||||
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
|
||||
)
|
||||
|
||||
if index_attempt_start.search_settings is None:
|
||||
raise ValueError(
|
||||
"Search settings must be set for indexing. This should not be possible."
|
||||
)
|
||||
|
||||
db_connector = index_attempt_start.connector_credential_pair.connector
|
||||
db_credential = index_attempt_start.connector_credential_pair.credential
|
||||
is_primary = (
|
||||
index_attempt_start.search_settings.status == IndexModelStatus.PRESENT
|
||||
)
|
||||
from_beginning = index_attempt_start.from_beginning
|
||||
has_successful_attempt = (
|
||||
index_attempt_start.connector_credential_pair.last_successful_index_time
|
||||
is not None
|
||||
)
|
||||
ctx = DocExtractionContext(
|
||||
index_name=index_attempt_start.search_settings.index_name,
|
||||
cc_pair_id=index_attempt_start.connector_credential_pair.id,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
source=db_connector.source,
|
||||
earliest_index_time=(
|
||||
db_connector.indexing_start.timestamp()
|
||||
if db_connector.indexing_start
|
||||
else 0
|
||||
),
|
||||
from_beginning=from_beginning,
|
||||
# Only update cc-pair status for primary index jobs
|
||||
# Secondary index syncs at the end when swapping
|
||||
is_primary=is_primary,
|
||||
should_fetch_permissions_during_indexing=(
|
||||
index_attempt_start.connector_credential_pair.access_type
|
||||
== AccessType.SYNC
|
||||
and source_should_fetch_permissions_during_indexing(db_connector.source)
|
||||
and is_primary
|
||||
# if we've already successfully indexed, let the doc_sync job
|
||||
# take care of doc-level permissions
|
||||
and (from_beginning or not has_successful_attempt)
|
||||
),
|
||||
search_settings_status=index_attempt_start.search_settings.status,
|
||||
doc_extraction_complete_batch_num=None,
|
||||
)
|
||||
|
||||
last_successful_index_poll_range_end = (
|
||||
ctx.earliest_index_time
|
||||
if ctx.from_beginning
|
||||
else get_last_successful_attempt_poll_range_end(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
earliest_index=ctx.earliest_index_time,
|
||||
search_settings=index_attempt_start.search_settings,
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
)
|
||||
if last_successful_index_poll_range_end > POLL_CONNECTOR_OFFSET:
|
||||
window_start = datetime.fromtimestamp(
|
||||
last_successful_index_poll_range_end, tz=timezone.utc
|
||||
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
|
||||
else:
|
||||
# don't go into "negative" time if we've never indexed before
|
||||
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
|
||||
|
||||
most_recent_attempt = next(
|
||||
iter(
|
||||
get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt_start.search_settings_id,
|
||||
db_session=db_session_temp,
|
||||
limit=1,
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# if the last attempt failed, try and use the same window. This is necessary
|
||||
# to ensure correctness with checkpointing. If we don't do this, things like
|
||||
# new slack channels could be missed (since existing slack channels are
|
||||
# cached as part of the checkpoint).
|
||||
if (
|
||||
most_recent_attempt
|
||||
and most_recent_attempt.poll_range_end
|
||||
and (
|
||||
most_recent_attempt.status == IndexingStatus.FAILED
|
||||
or most_recent_attempt.status == IndexingStatus.CANCELED
|
||||
)
|
||||
):
|
||||
window_end = most_recent_attempt.poll_range_end
|
||||
else:
|
||||
window_end = datetime.now(tz=timezone.utc)
|
||||
|
||||
# add start/end now that they have been set
|
||||
index_attempt_start.poll_range_start = window_start
|
||||
index_attempt_start.poll_range_end = window_end
|
||||
db_session_temp.add(index_attempt_start)
|
||||
db_session_temp.commit()
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=index_attempt_start.search_settings,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
information_content_classification_model = InformationContentClassificationModel()
|
||||
|
||||
document_index = get_default_document_index(
|
||||
index_attempt_start.search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
# Initialize memory tracer. NOTE: won't actually do anything if
|
||||
# `INDEXING_TRACER_INTERVAL` is 0.
|
||||
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
|
||||
memory_tracer.start()
|
||||
|
||||
index_attempt_md = IndexAttemptMetadata(
|
||||
attempt_id=index_attempt_id,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
)
|
||||
|
||||
total_failures = 0
|
||||
batch_num = 0
|
||||
net_doc_change = 0
|
||||
document_count = 0
|
||||
chunk_count = 0
|
||||
index_attempt: IndexAttempt | None = None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session_temp, index_attempt_id, eager_load_cc_pair=True
|
||||
)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
|
||||
connector_runner = _get_connector_runner(
|
||||
db_session=db_session_temp,
|
||||
attempt=index_attempt,
|
||||
batch_size=INDEX_BATCH_SIZE,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
include_permissions=ctx.should_fetch_permissions_during_indexing,
|
||||
)
|
||||
|
||||
# don't use a checkpoint if we're explicitly indexing from
|
||||
# the beginning in order to avoid weird interactions between
|
||||
# checkpointing / failure handling
|
||||
# OR
|
||||
# if the last attempt was successful
|
||||
if index_attempt.from_beginning or (
|
||||
most_recent_attempt and most_recent_attempt.status.is_successful()
|
||||
):
|
||||
checkpoint = connector_runner.connector.build_dummy_checkpoint()
|
||||
else:
|
||||
checkpoint, _ = get_latest_valid_checkpoint(
|
||||
db_session=db_session_temp,
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
connector=connector_runner.connector,
|
||||
)
|
||||
|
||||
# save the initial checkpoint to have a proper record of the
|
||||
# "last used checkpoint"
|
||||
save_checkpoint(
|
||||
db_session=db_session_temp,
|
||||
                index_attempt_id=index_attempt_id,
                checkpoint=checkpoint,
            )

            unresolved_errors = get_index_attempt_errors_for_cc_pair(
                cc_pair_id=ctx.cc_pair_id,
                unresolved_only=True,
                db_session=db_session_temp,
            )
            doc_id_to_unresolved_errors: dict[str, list[IndexAttemptError]] = (
                defaultdict(list)
            )
            for error in unresolved_errors:
                if error.document_id:
                    doc_id_to_unresolved_errors[error.document_id].append(error)

            entity_based_unresolved_errors = [
                error for error in unresolved_errors if error.entity_id
            ]

        while checkpoint.has_more:
            logger.info(
                f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
            )
            for document_batch, failure, next_checkpoint in connector_runner.run(
                checkpoint
            ):
                # Check if connector is disabled mid run and stop if so unless it's the secondary
                # index being built. We want to populate it even for paused connectors
                # Often paused connectors are sources that aren't updated frequently but the
                # contents still need to be initially pulled.
                if callback:
                    if callback.should_stop():
                        raise ConnectorStopSignal("Connector stop signal detected")

                    # NOTE: this progress callback runs on every loop. We've seen cases
                    # where we loop many times with no new documents and eventually time
                    # out, so only doing the callback after indexing isn't sufficient.
                    callback.progress("_run_indexing", 0)

                # TODO: should we move this into the above callback instead?
                with get_session_with_current_tenant() as db_session_temp:
                    # will exception if the connector/index attempt is marked as paused/failed
                    _check_connector_and_attempt_status(
                        db_session_temp,
                        ctx.cc_pair_id,
                        ctx.search_settings_status,
                        index_attempt_id,
                    )

                # save record of any failures at the connector level
                if failure is not None:
                    total_failures += 1
                    with get_session_with_current_tenant() as db_session_temp:
                        create_index_attempt_error(
                            index_attempt_id,
                            ctx.cc_pair_id,
                            failure,
                            db_session_temp,
                        )

                    _check_failure_threshold(
                        total_failures, document_count, batch_num, failure
                    )

                # save the new checkpoint (if one is provided)
                if next_checkpoint:
                    checkpoint = next_checkpoint

                # below is all document processing logic, so if no batch we can just continue
                if document_batch is None:
                    continue

                batch_description = []

                # Generate an ID that can be used to correlate activity between here
                # and the embedding model server
                doc_batch_cleaned = strip_null_characters(document_batch)
                for doc in doc_batch_cleaned:
                    batch_description.append(doc.to_short_descriptor())

                    doc_size = 0
                    for section in doc.sections:
                        if (
                            isinstance(section, TextSection)
                            and section.text is not None
                        ):
                            doc_size += len(section.text)

                    if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
                        logger.warning(
                            f"Document size: doc='{doc.to_short_descriptor()}' "
                            f"size={doc_size} "
                            f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
                        )

                logger.debug(f"Indexing batch of documents: {batch_description}")

                index_attempt_md.request_id = make_randomized_onyx_request_id("CIX")
                index_attempt_md.structured_id = (
                    f"{tenant_id}:{ctx.cc_pair_id}:{index_attempt_id}:{batch_num}"
                )
                index_attempt_md.batch_num = batch_num + 1  # use 1-index for this

                # real work happens here!
                adapter = DocumentIndexingBatchAdapter(
                    db_session=db_session,
                    connector_id=ctx.connector_id,
                    credential_id=ctx.credential_id,
                    tenant_id=tenant_id,
                    index_attempt_metadata=index_attempt_md,
                )
                index_pipeline_result = run_indexing_pipeline(
                    embedder=embedding_model,
                    information_content_classification_model=information_content_classification_model,
                    document_index=document_index,
                    ignore_time_skip=(
                        ctx.from_beginning
                        or (ctx.search_settings_status == IndexModelStatus.FUTURE)
                    ),
                    db_session=db_session,
                    tenant_id=tenant_id,
                    document_batch=doc_batch_cleaned,
                    request_id=index_attempt_md.request_id,
                    adapter=adapter,
                )

                batch_num += 1
                net_doc_change += index_pipeline_result.new_docs
                chunk_count += index_pipeline_result.total_chunks
                document_count += index_pipeline_result.total_docs

                # resolve errors for documents that were successfully indexed
                failed_document_ids = [
                    failure.failed_document.document_id
                    for failure in index_pipeline_result.failures
                    if failure.failed_document
                ]
                successful_document_ids = [
                    document.id
                    for document in document_batch
                    if document.id not in failed_document_ids
                ]
                for document_id in successful_document_ids:
                    with get_session_with_current_tenant() as db_session_temp:
                        if document_id in doc_id_to_unresolved_errors:
                            logger.info(
                                f"Resolving IndexAttemptError for document '{document_id}'"
                            )
                            for error in doc_id_to_unresolved_errors[document_id]:
                                error.is_resolved = True
                                db_session_temp.add(error)
                            db_session_temp.commit()

                # add brand new failures
                if index_pipeline_result.failures:
                    total_failures += len(index_pipeline_result.failures)
                    with get_session_with_current_tenant() as db_session_temp:
                        for failure in index_pipeline_result.failures:
                            create_index_attempt_error(
                                index_attempt_id,
                                ctx.cc_pair_id,
                                failure,
                                db_session_temp,
                            )

                    _check_failure_threshold(
                        total_failures,
                        document_count,
                        batch_num,
                        index_pipeline_result.failures[-1],
                    )
                # This new value is updated every batch, so UI can refresh per batch update
                with get_session_with_current_tenant() as db_session_temp:
                    # NOTE: Postgres uses the start of the transaction when computing `NOW()`
                    # so we need either to commit() or to use a new session
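                    # (Illustration, not part of this diff: in Postgres, now() is
                    # transaction_timestamp() and stays frozen for the whole transaction,
                    # while clock_timestamp() keeps moving.)
                    #   BEGIN;
                    #   SELECT now();              -- t0
                    #   SELECT pg_sleep(5);
                    #   SELECT now();              -- still t0
                    #   SELECT clock_timestamp();  -- t0 + ~5s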
                    update_docs_indexed(
                        db_session=db_session_temp,
                        index_attempt_id=index_attempt_id,
                        total_docs_indexed=document_count,
                        new_docs_indexed=net_doc_change,
                        docs_removed_from_index=0,
                    )

                if callback:
                    callback.progress("_run_indexing", len(doc_batch_cleaned))

                # Add telemetry for indexing progress
                optional_telemetry(
                    record_type=RecordType.INDEXING_PROGRESS,
                    data={
                        "index_attempt_id": index_attempt_id,
                        "cc_pair_id": ctx.cc_pair_id,
                        "current_docs_indexed": document_count,
                        "current_chunks_indexed": chunk_count,
                        "source": ctx.source.value,
                    },
                    tenant_id=tenant_id,
                )

                memory_tracer.increment_and_maybe_trace()

                # make sure the checkpoints aren't getting too large at some regular interval
                CHECKPOINT_SIZE_CHECK_INTERVAL = 100
                if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
                    check_checkpoint_size(checkpoint)

            # save latest checkpoint
            with get_session_with_current_tenant() as db_session_temp:
                save_checkpoint(
                    db_session=db_session_temp,
                    index_attempt_id=index_attempt_id,
                    checkpoint=checkpoint,
                )

        optional_telemetry(
            record_type=RecordType.INDEXING_COMPLETE,
            data={
                "index_attempt_id": index_attempt_id,
                "cc_pair_id": ctx.cc_pair_id,
                "total_docs_indexed": document_count,
                "total_chunks": chunk_count,
                "time_elapsed_seconds": time.monotonic() - start_time,
                "source": ctx.source.value,
            },
            tenant_id=tenant_id,
        )

    except Exception as e:
        logger.exception(
            "Connector run exceptioned after elapsed time: "
            f"{time.monotonic() - start_time} seconds"
        )
        if isinstance(e, ConnectorValidationError):
            # On validation errors during indexing, we want to cancel the indexing attempt
            # and mark the CCPair as invalid. This prevents the connector from being
            # used in the future until the credentials are updated.
            with get_session_with_current_tenant() as db_session_temp:
                logger.exception(
                    f"Marking attempt {index_attempt_id} as canceled due to validation error."
                )
                mark_attempt_canceled(
                    index_attempt_id,
                    db_session_temp,
                    reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
                )

                if ctx.is_primary:
                    if not index_attempt:
                        # should always be set by now
                        raise RuntimeError("Should never happen.")

                    VALIDATION_ERROR_THRESHOLD = 5

                    recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
                        cc_pair_id=ctx.cc_pair_id,
                        search_settings_id=index_attempt.search_settings_id,
                        limit=VALIDATION_ERROR_THRESHOLD,
                        db_session=db_session_temp,
                    )
                    num_validation_errors = len(
                        [
                            index_attempt
                            for index_attempt in recent_index_attempts
                            if index_attempt.error_msg
                            and index_attempt.error_msg.startswith(
                                CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
                            )
                        ]
                    )

                    if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
                        logger.warning(
                            f"Connector {ctx.connector_id} has {num_validation_errors} consecutive validation"
                            f" errors. Marking the CC Pair as invalid."
                        )
                        update_connector_credential_pair(
                            db_session=db_session_temp,
                            connector_id=ctx.connector_id,
                            credential_id=ctx.credential_id,
                            status=ConnectorCredentialPairStatus.INVALID,
                        )
            memory_tracer.stop()
            raise e

        elif isinstance(e, ConnectorStopSignal):
            with get_session_with_current_tenant() as db_session_temp:
                logger.exception(
                    f"Marking attempt {index_attempt_id} as canceled due to stop signal."
                )
                mark_attempt_canceled(
                    index_attempt_id,
                    db_session_temp,
                    reason=str(e),
                )

                if ctx.is_primary:
                    update_connector_credential_pair(
                        db_session=db_session_temp,
                        connector_id=ctx.connector_id,
                        credential_id=ctx.credential_id,
                        net_docs=net_doc_change,
                    )

            memory_tracer.stop()
            raise e
        else:
            with get_session_with_current_tenant() as db_session_temp:
                mark_attempt_failed(
                    index_attempt_id,
                    db_session_temp,
                    failure_reason=str(e),
                    full_exception_trace=traceback.format_exc(),
                )

                if ctx.is_primary:
                    update_connector_credential_pair(
                        db_session=db_session_temp,
                        connector_id=ctx.connector_id,
                        credential_id=ctx.credential_id,
                        net_docs=net_doc_change,
                    )

            memory_tracer.stop()
            raise e

    memory_tracer.stop()

    # we know index attempt is successful (at least partially) at this point,
    # all other cases have been short-circuited
    elapsed_time = time.monotonic() - start_time
    with get_session_with_current_tenant() as db_session_temp:
        # resolve entity-based errors
        for error in entity_based_unresolved_errors:
            logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
            error.is_resolved = True
            db_session_temp.add(error)
            db_session_temp.commit()

        if total_failures == 0:
            mark_attempt_succeeded(index_attempt_id, db_session_temp)

            create_milestone_and_report(
                user=None,
                distinct_id=tenant_id or "N/A",
                event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
                properties=None,
                db_session=db_session_temp,
            )

            logger.info(
                f"Connector succeeded: "
                f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
            )

        else:
            mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
            logger.info(
                f"Connector completed with some errors: "
                f"failures={total_failures} "
                f"batches={batch_num} "
                f"docs={document_count} "
                f"chunks={chunk_count} "
                f"elapsed={elapsed_time:.2f}s"
            )

        if ctx.is_primary:
            update_connector_credential_pair(
                db_session=db_session_temp,
                connector_id=ctx.connector_id,
                credential_id=ctx.credential_id,
                run_dt=window_end,
            )
            if ctx.should_fetch_permissions_during_indexing:
                mark_cc_pair_as_permissions_synced(
                    db_session=db_session_temp,
                    cc_pair_id=ctx.cc_pair_id,
                    start_time=window_end,
                )


def run_docfetching_entrypoint(
    app: Celery,
    index_attempt_id: int,

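The loop above follows a checkpoint-resume pattern: adopt the newest checkpoint per batch, persist after each pass, and periodically guard against oversized checkpoints. A minimal standalone sketch with stand-in types (illustrative only, not onyx APIs):

from dataclasses import dataclass

@dataclass
class Checkpoint:
    cursor: int = 0
    has_more: bool = True

def run(cp: Checkpoint):
    # Stand-in connector: yields (batch, next_checkpoint) pairs.
    for i in range(3):
        nxt = Checkpoint(cursor=cp.cursor + i + 1, has_more=i < 2)
        yield ([f"doc-{nxt.cursor}"], nxt)

def save(cp: Checkpoint) -> None:
    print(f"saved checkpoint at cursor={cp.cursor}")

cp = Checkpoint()
while cp.has_more:
    for batch, nxt in run(cp):
        if nxt:
            cp = nxt  # always adopt the newest checkpoint
        if batch:
            print("indexed", batch)
    save(cp)  # persist progress so a restart resumes from here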
64
backend/onyx/chat/chat_milestones.py
Normal file
@@ -0,0 +1,64 @@
"""
Module for handling chat-related milestone tracking and telemetry.
"""

from sqlalchemy.orm import Session

from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.milestone import update_user_assistant_milestone
from onyx.db.models import User
from onyx.utils.telemetry import mt_cloud_telemetry


def process_multi_assistant_milestone(
    user: User | None,
    assistant_id: int,
    tenant_id: str,
    db_session: Session,
) -> None:
    """
    Process the multi-assistant milestone for a user.

    This function:
    1. Creates or retrieves the multi-assistant milestone
    2. Updates the milestone with the current assistant usage
    3. Checks if the milestone was just achieved
    4. Sends telemetry if the milestone was just hit

    Args:
        user: The user for whom to process the milestone (can be None for anonymous users)
        assistant_id: The ID of the assistant being used
        tenant_id: The current tenant ID
        db_session: Database session for queries
    """
    # Create or retrieve the multi-assistant milestone
    multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
        user=user,
        event_type=MilestoneRecordType.MULTIPLE_ASSISTANTS,
        db_session=db_session,
    )

    # Update the milestone with the current assistant usage
    update_user_assistant_milestone(
        milestone=multi_assistant_milestone,
        user_id=str(user.id) if user else NO_AUTH_USER_ID,
        assistant_id=assistant_id,
        db_session=db_session,
    )

    # Check if the milestone was just achieved
    _, just_hit_multi_assistant_milestone = check_multi_assistant_milestone(
        milestone=multi_assistant_milestone,
        db_session=db_session,
    )

    # Send telemetry if the milestone was just hit
    if just_hit_multi_assistant_milestone:
        mt_cloud_telemetry(
            distinct_id=tenant_id,
            event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
            properties=None,
        )
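A minimal usage sketch for the new helper; the session factory import path and the in-scope names (current_user, persona, tenant_id) are assumptions, and the real call site appears in stream_chat_message_objects further down:

from onyx.chat.chat_milestones import process_multi_assistant_milestone
from onyx.db.engine import get_session_with_current_tenant  # assumed import path

with get_session_with_current_tenant() as db_session:
    process_multi_assistant_milestone(
        user=current_user,        # User | None; None for anonymous users
        assistant_id=persona.id,  # the assistant/persona used this turn
        tenant_id=tenant_id,
        db_session=db_session,
    )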
@@ -1,4 +1,3 @@
import threading
from collections.abc import Callable
from collections.abc import Generator
from queue import Empty
@@ -19,14 +18,9 @@ class ChatStateContainer:

    This container holds the partial state that can be saved to the database
    if the generation is stopped by the user or completes normally.

    Thread-safe: All write operations are protected by a lock to ensure safe
    concurrent access from multiple threads. For thread-safe reads, use the
    getter methods. Direct attribute access is not thread-safe.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self.tool_calls: list[ToolCallInfo] = []
        self.reasoning_tokens: str | None = None
        self.answer_tokens: str | None = None
@@ -37,53 +31,23 @@ class ChatStateContainer:

    def add_tool_call(self, tool_call: ToolCallInfo) -> None:
        """Add a tool call to the accumulated state."""
        with self._lock:
            self.tool_calls.append(tool_call)
        self.tool_calls.append(tool_call)

    def set_reasoning_tokens(self, reasoning: str | None) -> None:
        """Set the reasoning tokens from the final answer generation."""
        with self._lock:
            self.reasoning_tokens = reasoning
        self.reasoning_tokens = reasoning

    def set_answer_tokens(self, answer: str | None) -> None:
        """Set the answer tokens from the final answer generation."""
        with self._lock:
            self.answer_tokens = answer
        self.answer_tokens = answer

    def set_citation_mapping(self, citation_to_doc: dict[int, Any]) -> None:
        """Set the citation mapping from citation processor."""
        with self._lock:
            self.citation_to_doc = citation_to_doc
        self.citation_to_doc = citation_to_doc

    def set_is_clarification(self, is_clarification: bool) -> None:
        """Set whether this turn is a clarification question."""
        with self._lock:
            self.is_clarification = is_clarification

    def get_answer_tokens(self) -> str | None:
        """Thread-safe getter for answer_tokens."""
        with self._lock:
            return self.answer_tokens

    def get_reasoning_tokens(self) -> str | None:
        """Thread-safe getter for reasoning_tokens."""
        with self._lock:
            return self.reasoning_tokens

    def get_tool_calls(self) -> list[ToolCallInfo]:
        """Thread-safe getter for tool_calls (returns a copy)."""
        with self._lock:
            return self.tool_calls.copy()

    def get_citation_to_doc(self) -> dict[int, SearchDoc]:
        """Thread-safe getter for citation_to_doc (returns a copy)."""
        with self._lock:
            return self.citation_to_doc.copy()

    def get_is_clarification(self) -> bool:
        """Thread-safe getter for is_clarification."""
        with self._lock:
            return self.is_clarification
        self.is_clarification = is_clarification


def run_chat_llm_with_state_containers(

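ChatStateContainer's lock-and-copy approach (serialize writes behind one lock, return copies from getters) is a standard pattern; a minimal standalone sketch, not onyx code:

import threading

class SafeAccumulator:
    """Lock-guarded partial state, copied out on read (same pattern as above)."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._items: list[str] = []

    def add(self, item: str) -> None:
        with self._lock:
            self._items.append(item)

    def snapshot(self) -> list[str]:
        # Return a copy so callers never iterate a list mutated by another thread.
        with self._lock:
            return self._items.copy()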
@@ -49,10 +49,8 @@ from onyx.llm.override_models import LLMOverride
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.prompts.chat_prompts import ADDITIONAL_CONTEXT_PROMPT
from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE
from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool_implementations.custom.custom_tool import (
    build_custom_tools_from_openapi_schema_and_headers,
)
@@ -731,38 +729,3 @@ def is_last_assistant_message_clarification(chat_history: list[ChatMessage]) ->
        if message.message_type == MessageType.ASSISTANT:
            return message.is_clarification
    return False


def create_tool_call_failure_messages(
    tool_call: ToolCallKickoff, token_counter: Callable[[str], int]
) -> list[ChatMessageSimple]:
    """Create ChatMessageSimple objects for a failed tool call.

    Creates two messages:
    1. The tool call message itself
    2. A failure response message indicating the tool call failed

    Args:
        tool_call: The ToolCallKickoff object representing the failed tool call
        token_counter: Function to count tokens in a message string

    Returns:
        List containing two ChatMessageSimple objects: tool call message and failure response
    """
    tool_call_msg = ChatMessageSimple(
        message=tool_call.to_msg_str(),
        token_count=token_counter(tool_call.to_msg_str()),
        message_type=MessageType.TOOL_CALL,
        tool_call_id=tool_call.tool_call_id,
        image_files=None,
    )

    failure_response_msg = ChatMessageSimple(
        message=TOOL_CALL_FAILURE_PROMPT,
        token_count=token_counter(TOOL_CALL_FAILURE_PROMPT),
        message_type=MessageType.TOOL_CALL_RESPONSE,
        tool_call_id=tool_call.tool_call_id,
        image_files=None,
    )

    return [tool_call_msg, failure_response_msg]

@@ -1,12 +1,15 @@
import json
from collections.abc import Callable
from typing import cast

from sqlalchemy.orm import Session

from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import create_tool_call_failure_messages
from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.emitter import Emitter
from onyx.chat.llm_step import run_llm_step
from onyx.chat.llm_step import TOOL_CALL_MSG_ARGUMENTS
from onyx.chat.llm_step import TOOL_CALL_MSG_FUNC_NAME
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import LlmStepResult
@@ -322,6 +325,7 @@ def run_llm_loop(
    # Pass the total budget to construct_message_history, which will handle token allocation
    available_tokens = llm.config.max_input_tokens
    tool_choice: ToolChoiceOptions = ToolChoiceOptions.AUTO
    collected_tool_calls: list[ToolCallInfo] = []
    # Initialize gathered_documents with project files if present
    gathered_documents: list[SearchDoc] | None = (
        list(project_citation_mapping.values())
@@ -339,8 +343,12 @@
    has_called_search_tool: bool = False
    citation_mapping: dict[int, str] = {}  # Maps citation_num -> document_id/URL

    reasoning_cycles = 0
    current_tool_call_index = (
        0  # TODO: just use the cycle count after parallel tool calls are supported
    )

    for llm_cycle_count in range(MAX_LLM_CYCLES):

        if forced_tool_id:
            # Needs to be just the single one because the "required" currently doesn't have a specified tool, just a binary
            final_tools = [tool for tool in tools if tool.id == forced_tool_id]
@@ -437,13 +445,12 @@

        # This calls the LLM, yields packets (reasoning, answers, etc.) and returns the result
        # It also pre-processes the tool calls in preparation for running them
        llm_step_result, has_reasoned = run_llm_step(
            emitter=emitter,
        step_generator = run_llm_step(
            history=truncated_message_history,
            tool_definitions=[tool.tool_definition() for tool in final_tools],
            tool_choice=tool_choice,
            llm=llm,
            turn_index=llm_cycle_count + reasoning_cycles,
            turn_index=current_tool_call_index,
            citation_processor=citation_processor,
            state_container=state_container,
            # The rich docs representation is passed in so that when yielding the answer, it can also
@@ -452,8 +459,18 @@
            final_documents=gathered_documents,
            user_identity=user_identity,
        )
        if has_reasoned:
            reasoning_cycles += 1

        # Consume the generator, emitting packets and capturing the final result
        while True:
            try:
                packet = next(step_generator)
                emitter.emit(packet)
            except StopIteration as e:
                llm_step_result, current_tool_call_index = e.value
                break

        # Type narrowing: generator always returns a result, so this can't be None
        llm_step_result = cast(LlmStepResult, llm_step_result)

        # Save citation mapping after each LLM step for incremental state updates
        state_container.set_citation_mapping(citation_processor.citation_to_doc)
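The consume loop above relies on standard generator semantics (PEP 380): a generator's return value travels on StopIteration.value. A standalone sketch, independent of onyx:

from collections.abc import Generator

def step() -> Generator[str, None, tuple[str, int]]:
    yield "packet-1"
    yield "packet-2"
    return ("final-result", 3)  # carried on StopIteration.value

gen = step()
while True:
    try:
        print(next(gen))  # emit each streamed packet
    except StopIteration as e:
        result, next_index = e.value
        break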
@@ -463,39 +480,20 @@
        tool_responses: list[ToolResponse] = []
        tool_calls = llm_step_result.tool_calls or []

        # Quick note for why citation_mapping and citation_processors are both needed:
        # 1. Tools return lightweight string mappings, not SearchDoc objects
        # 2. The SearchDoc resolution is deliberately deferred to llm_loop.py
        # 3. The citation_processor operates on SearchDoc objects and can't provide a complete reverse URL lookup for
        # in-flight citations
        # It can be cleaned up but not super trivial or worthwhile right now
        just_ran_web_search = False
        tool_responses, citation_mapping = run_tool_calls(
            tool_calls=tool_calls,
            tools=final_tools,
            message_history=truncated_message_history,
            memories=memories,
            user_info=None,  # TODO, this is part of memories right now, might want to separate it out
            citation_mapping=citation_mapping,
            citation_processor=citation_processor,
            skip_search_query_expansion=has_called_search_tool,
        )

        # Failure case, give something reasonable to the LLM to try again
        if tool_calls and not tool_responses:
            failure_messages = create_tool_call_failure_messages(
                tool_calls[0], token_counter
        for tool_call in tool_calls:
            # TODO replace the [tool_call] with the list of tool calls once parallel tool calls are supported
            tool_responses, citation_mapping = run_tool_calls(
                tool_calls=[tool_call],
                tools=final_tools,
                turn_index=current_tool_call_index,
                message_history=truncated_message_history,
                memories=memories,
                user_info=None,  # TODO, this is part of memories right now, might want to separate it out
                citation_mapping=citation_mapping,
                citation_processor=citation_processor,
                skip_search_query_expansion=has_called_search_tool,
            )
            simple_chat_history.extend(failure_messages)
            continue

            for tool_response in tool_responses:
                # Extract tool_call from the response (set by run_tool_calls)
                if tool_response.tool_call is None:
                    raise ValueError("Tool response missing tool_call reference")

                tool_call = tool_response.tool_call
                tab_index = tool_call.tab_index

                # Track if search tool was called (for skipping query expansion on subsequent calls)
                if tool_call.tool_name == SearchTool.NAME:
@@ -504,103 +502,110 @@

                # Build a mapping of tool names to tool objects for getting tool_id
                tools_by_name = {tool.name: tool for tool in final_tools}

                # Add the results to the chat history. Even though tools may run in parallel,
                # LLM APIs require linear history, so results are added sequentially.
                # Get the tool object to retrieve tool_id
                tool = tools_by_name.get(tool_call.tool_name)
                if not tool:
                    raise ValueError(
                        f"Tool '{tool_call.tool_name}' not found in tools list"
                    )
            # Add the results to the chat history, note that even if the tools were run in parallel, this isn't supported
            # as all the LLM APIs require linear history, so these will just be included sequentially
            for tool_call, tool_response in zip([tool_call], tool_responses):
                # Get the tool object to retrieve tool_id
                tool = tools_by_name.get(tool_call.tool_name)
                if not tool:
                    raise ValueError(
                        f"Tool '{tool_call.tool_name}' not found in tools list"
                    )

                # Extract search_docs if this is a search tool response
                search_docs = None
                if isinstance(tool_response.rich_response, SearchDocsResponse):
                    search_docs = tool_response.rich_response.search_docs
                    if gathered_documents:
                        gathered_documents.extend(search_docs)
                    else:
                        gathered_documents = search_docs

                # This is used for the Open URL reminder in the next cycle
                # only do this if the web search tool yielded results
                if search_docs and tool_call.tool_name == WebSearchTool.NAME:
                    just_ran_web_search = True

                # Extract generated_images if this is an image generation tool response
                generated_images = None
                if isinstance(
                    tool_response.rich_response, FinalImageGenerationResponse
                ):
                    generated_images = tool_response.rich_response.generated_images

                tool_call_info = ToolCallInfo(
                    parent_tool_call_id=None,  # Top-level tool calls are attached to the chat message
                    turn_index=llm_cycle_count + reasoning_cycles,
                    tab_index=tab_index,
                    tool_name=tool_call.tool_name,
                    tool_call_id=tool_call.tool_call_id,
                    tool_id=tool.id,
                    reasoning_tokens=llm_step_result.reasoning,  # All tool calls from this loop share the same reasoning
                    tool_call_arguments=tool_call.tool_args,
                    tool_call_response=tool_response.llm_facing_response,
                    search_docs=search_docs,
                    generated_images=generated_images,
                )
                # Add to state container for partial save support
                state_container.add_tool_call(tool_call_info)

                # Store tool call with function name and arguments in separate layers
                tool_call_message = tool_call.to_msg_str()
                tool_call_token_count = token_counter(tool_call_message)

                tool_call_msg = ChatMessageSimple(
                    message=tool_call_message,
                    token_count=tool_call_token_count,
                    message_type=MessageType.TOOL_CALL,
                    tool_call_id=tool_call.tool_call_id,
                    image_files=None,
                )
                simple_chat_history.append(tool_call_msg)

                tool_response_message = tool_response.llm_facing_response
                tool_response_token_count = token_counter(tool_response_message)

                tool_response_msg = ChatMessageSimple(
                    message=tool_response_message,
                    token_count=tool_response_token_count,
                    message_type=MessageType.TOOL_CALL_RESPONSE,
                    tool_call_id=tool_call.tool_call_id,
                    image_files=None,
                )
                simple_chat_history.append(tool_response_msg)

                # Update citation processor if this was a search tool
                if tool_call.tool_name in citeable_tools_names:
                    # Check if the rich_response is a SearchDocsResponse
                    # Extract search_docs if this is a search tool response
                    search_docs = None
                    if isinstance(tool_response.rich_response, SearchDocsResponse):
                        search_response = tool_response.rich_response
                        search_docs = tool_response.rich_response.search_docs
                        if gathered_documents:
                            gathered_documents.extend(search_docs)
                        else:
                            gathered_documents = search_docs

                        # Create mapping from citation number to SearchDoc
                        citation_to_doc: dict[int, SearchDoc] = {}
                        for (
                            citation_num,
                            doc_id,
                        ) in search_response.citation_mapping.items():
                            # Find the SearchDoc with this doc_id
                            matching_doc = next(
                                (
                                    doc
                                    for doc in search_response.search_docs
                                    if doc.document_id == doc_id
                                ),
                                None,
                            )
                            if matching_doc:
                                citation_to_doc[citation_num] = matching_doc
                    # This is used for the Open URL reminder in the next cycle
                    # only do this if the web search tool yielded results
                    if search_docs and tool_call.tool_name == WebSearchTool.NAME:
                        just_ran_web_search = True

                        # Update the citation processor
                        citation_processor.update_citation_mapping(citation_to_doc)
                # Extract generated_images if this is an image generation tool response
                generated_images = None
                if isinstance(
                    tool_response.rich_response, FinalImageGenerationResponse
                ):
                    generated_images = tool_response.rich_response.generated_images

                tool_call_info = ToolCallInfo(
                    parent_tool_call_id=None,  # Top-level tool calls are attached to the chat message
                    turn_index=current_tool_call_index,
                    tool_name=tool_call.tool_name,
                    tool_call_id=tool_call.tool_call_id,
                    tool_id=tool.id,
                    reasoning_tokens=llm_step_result.reasoning,  # All tool calls from this loop share the same reasoning
                    tool_call_arguments=tool_call.tool_args,
                    tool_call_response=tool_response.llm_facing_response,
                    search_docs=search_docs,
                    generated_images=generated_images,
                )
                collected_tool_calls.append(tool_call_info)
                # Add to state container for partial save support
                state_container.add_tool_call(tool_call_info)

                # Store tool call with function name and arguments in separate layers
                tool_call_data = {
                    TOOL_CALL_MSG_FUNC_NAME: tool_call.tool_name,
                    TOOL_CALL_MSG_ARGUMENTS: tool_call.tool_args,
                }
                tool_call_message = json.dumps(tool_call_data)
                tool_call_token_count = token_counter(tool_call_message)

                tool_call_msg = ChatMessageSimple(
                    message=tool_call_message,
                    token_count=tool_call_token_count,
                    message_type=MessageType.TOOL_CALL,
                    tool_call_id=tool_call.tool_call_id,
                    image_files=None,
                )
                simple_chat_history.append(tool_call_msg)

                tool_response_message = tool_response.llm_facing_response
                tool_response_token_count = token_counter(tool_response_message)

                tool_response_msg = ChatMessageSimple(
                    message=tool_response_message,
                    token_count=tool_response_token_count,
                    message_type=MessageType.TOOL_CALL_RESPONSE,
                    tool_call_id=tool_call.tool_call_id,
                    image_files=None,
                )
                simple_chat_history.append(tool_response_msg)

                # Update citation processor if this was a search tool
                if tool_call.tool_name in citeable_tools_names:
                    # Check if the rich_response is a SearchDocsResponse
                    if isinstance(tool_response.rich_response, SearchDocsResponse):
                        search_response = tool_response.rich_response

                        # Create mapping from citation number to SearchDoc
                        citation_to_doc: dict[int, SearchDoc] = {}
                        for (
                            citation_num,
                            doc_id,
                        ) in search_response.citation_mapping.items():
                            # Find the SearchDoc with this doc_id
                            matching_doc = next(
                                (
                                    doc
                                    for doc in search_response.search_docs
                                    if doc.document_id == doc_id
                                ),
                                None,
                            )
                            if matching_doc:
                                citation_to_doc[citation_num] = matching_doc

                        # Update the citation processor
                        citation_processor.update_citation_mapping(citation_to_doc)

            current_tool_call_index += 1

        # If no tool calls, then it must have answered, wrap up
        if not llm_step_result.tool_calls or len(llm_step_result.tool_calls) == 0:
@@ -624,8 +629,5 @@ def run_llm_loop(
            raise RuntimeError("LLM did not return an answer.")

        emitter.emit(
            Packet(
                turn_index=llm_cycle_count + reasoning_cycles,
                obj=OverallStop(type="stop"),
            )
            Packet(turn_index=current_tool_call_index, obj=OverallStop(type="stop"))
        )

@@ -1,16 +1,10 @@
import json
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Mapping
from collections.abc import Sequence
from typing import Any
from typing import cast
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from onyx.chat.emitter import Emitter

from onyx.llm.models import ReasoningEffort
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.models import ChatMessageSimple
@@ -23,7 +17,6 @@ from onyx.llm.interfaces import LanguageModelInput
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.model_response import Delta
from onyx.llm.models import AssistantMessage
from onyx.llm.models import ChatCompletionMessage
from onyx.llm.models import FunctionCall
@@ -41,8 +34,6 @@ from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningDone
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.tools.models import TOOL_CALL_MSG_ARGUMENTS
from onyx.tools.models import TOOL_CALL_MSG_FUNC_NAME
from onyx.tools.models import ToolCallKickoff
from onyx.tracing.framework.create import generation_span
from onyx.utils.b64 import get_image_type_from_bytes
@@ -52,43 +43,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()


def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]:
    """Parse tool arguments into a dict.

    Normal case:
    - raw_args == '{"queries":[...]}' -> dict via json.loads

    Defensive case (JSON string literal of an object):
    - raw_args == '"{\\"queries\\":[...]}"' -> json.loads -> str -> json.loads -> dict

    Anything else returns {}.
    """

    if raw_args is None:
        return {}

    if isinstance(raw_args, dict):
        return raw_args

    if not isinstance(raw_args, str):
        return {}

    try:
        parsed1: Any = json.loads(raw_args)
    except json.JSONDecodeError:
        return {}

    if isinstance(parsed1, dict):
        return parsed1

    if isinstance(parsed1, str):
        try:
            parsed2: Any = json.loads(parsed1)
        except json.JSONDecodeError:
            return {}
        return parsed2 if isinstance(parsed2, dict) else {}

    return {}
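For reference, the defensive double-decode that _parse_tool_args_to_dict above performs, shown standalone:

import json

raw = '{"queries": ["a"]}'                      # normal case: an object string
assert json.loads(raw) == {"queries": ["a"]}

wrapped = json.dumps(raw)                       # a JSON string literal of that object
first = json.loads(wrapped)                     # -> the inner string, not a dict
assert isinstance(first, str)
assert json.loads(first) == {"queries": ["a"]}  # second decode recovers the dict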
TOOL_CALL_MSG_FUNC_NAME = "function_name"
TOOL_CALL_MSG_ARGUMENTS = "arguments"


def _format_message_history_for_logging(
@@ -197,19 +153,21 @@

def _extract_tool_call_kickoffs(
    id_to_tool_call_map: dict[int, dict[str, Any]],
    turn_index: int,
) -> list[ToolCallKickoff]:
    """Extract ToolCallKickoff objects from the tool call map.

    Returns a list of ToolCallKickoff objects for valid tool calls (those with both id and name).
    Each tool call is assigned the given turn_index and a tab_index based on its order.
    """
    tool_calls: list[ToolCallKickoff] = []
    tab_index = 0
    for tool_call_data in id_to_tool_call_map.values():
        if tool_call_data.get("id") and tool_call_data.get("name"):
            try:
                tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments"))
                # Parse arguments JSON string to dict
                tool_args = (
                    json.loads(tool_call_data["arguments"])
                    if tool_call_data["arguments"]
                    else {}
                )
            except json.JSONDecodeError:
                # If parsing fails, try empty dict, most tools would fail though
                logger.error(
@@ -222,11 +180,8 @@
                    tool_call_id=tool_call_data["id"],
                    tool_name=tool_call_data["name"],
                    tool_args=tool_args,
                    turn_index=turn_index,
                    tab_index=tab_index,
                )
            )
            tab_index += 1
    return tool_calls


@@ -317,19 +272,13 @@ def translate_history_to_llm_format(
                    function_name = tool_call_data.get(
                        TOOL_CALL_MSG_FUNC_NAME, "unknown"
                    )
                    raw_args = tool_call_data.get(TOOL_CALL_MSG_ARGUMENTS, {})
                    tool_args = tool_call_data.get(TOOL_CALL_MSG_ARGUMENTS, {})
                else:
                    function_name = "unknown"
                    raw_args = (
                    tool_args = (
                        tool_call_data if isinstance(tool_call_data, dict) else {}
                    )

                # IMPORTANT: `FunctionCall.arguments` must be a JSON object string.
                # If `raw_args` is accidentally a JSON string literal of an object
                # (e.g. '"{\\"queries\\":[...]}"'), calling `json.dumps(raw_args)`
                # would produce a quoted JSON literal and break Anthropic tool parsing.
                tool_args = _parse_tool_args_to_dict(raw_args)
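                # (Illustration, not part of this diff.)
                #   json.dumps({"q": 1})    -> '{"q": 1}'        object string: OK
                #   json.dumps('{"q": 1}')  -> '"{\\"q\\": 1}"'  quoted literal: breaks parsing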

                # NOTE: if the model is trained on a different tool call format, this may slightly interfere
                # with the future tool calls, if it doesn't look like this. Almost certainly not a big deal.
                tool_call = ToolCall(
@@ -375,25 +324,20 @@
    return messages


def run_llm_step_pkt_generator(
def run_llm_step(
    history: list[ChatMessageSimple],
    tool_definitions: list[dict],
    tool_choice: ToolChoiceOptions,
    llm: LLM,
    turn_index: int,
    citation_processor: DynamicCitationProcessor,
    state_container: ChatStateContainer,
    citation_processor: DynamicCitationProcessor | None,
    reasoning_effort: ReasoningEffort | None = None,
    final_documents: list[SearchDoc] | None = None,
    user_identity: LLMUserIdentity | None = None,
    custom_token_processor: (
        Callable[[Delta | None, Any], tuple[Delta | None, Any]] | None
    ) = None,
) -> Generator[Packet, None, tuple[LlmStepResult, int]]:
    # The second return value is for the turn index because reasoning counts on the frontend as a turn
    # TODO this is maybe ok but does not align too well with the backend logic
    llm_msg_history = translate_history_to_llm_format(history)
    has_reasoned = 0

    # Uncomment the line below to log the entire message history to the console
    if LOG_ONYX_MODEL_INTERACTIONS:
@@ -407,8 +351,6 @@ def run_llm_step_pkt_generator(
    accumulated_reasoning = ""
    accumulated_answer = ""

    processor_state: Any = None

    with generation_span(
        model=llm.config.model_name,
        model_config={
@@ -424,7 +366,7 @@
        tools=tool_definitions,
        tool_choice=tool_choice,
        structured_response_format=None,  # TODO
        reasoning_effort=reasoning_effort,
        # reasoning_effort=ReasoningEffort.OFF,  # Can set this for dev/testing.
        user_identity=user_identity,
    ):
        if packet.usage:
@@ -437,17 +379,6 @@
            }
        delta = packet.choice.delta

        if custom_token_processor:
            # The custom token processor can modify the deltas for specific custom logic
            # It can also return a state so that it can handle aggregated delta logic etc.
            # Loosely typed so the function can be flexible
            modified_delta, processor_state = custom_token_processor(
                delta, processor_state
            )
            if modified_delta is None:
                continue
            delta = modified_delta

        # Should only happen once, frontend does not expect multiple
        # ReasoningStart or ReasoningDone packets.
        if delta.reasoning_content:
@@ -471,42 +402,32 @@
                    turn_index=turn_index,
                    obj=ReasoningDone(),
                )
                has_reasoned = 1
                turn_index += 1
                reasoning_start = False

            if not answer_start:
                yield Packet(
                    turn_index=turn_index + has_reasoned,
                    turn_index=turn_index,
                    obj=AgentResponseStart(
                        final_documents=final_documents,
                    ),
                )
                answer_start = True

            if citation_processor:
                for result in citation_processor.process_token(delta.content):
                    if isinstance(result, str):
                        accumulated_answer += result
                        # Save answer incrementally to state container
                        state_container.set_answer_tokens(accumulated_answer)
                        yield Packet(
                            turn_index=turn_index + has_reasoned,
                            obj=AgentResponseDelta(content=result),
                        )
                    elif isinstance(result, CitationInfo):
                        yield Packet(
                            turn_index=turn_index + has_reasoned,
                            obj=result,
                        )
            else:
                # When citation_processor is None, use delta.content directly without modification
                accumulated_answer += delta.content
                # Save answer incrementally to state container
                state_container.set_answer_tokens(accumulated_answer)
                yield Packet(
                    turn_index=turn_index + has_reasoned,
                    obj=AgentResponseDelta(content=delta.content),
                )
            for result in citation_processor.process_token(delta.content):
                if isinstance(result, str):
                    accumulated_answer += result
                    # Save answer incrementally to state container
                    state_container.set_answer_tokens(accumulated_answer)
                    yield Packet(
                        turn_index=turn_index,
                        obj=AgentResponseDelta(content=result),
                    )
                elif isinstance(result, CitationInfo):
                    yield Packet(
                        turn_index=turn_index,
                        obj=result,
                    )

        if delta.tool_calls:
            if reasoning_start:
@@ -514,22 +435,13 @@
                    turn_index=turn_index,
                    obj=ReasoningDone(),
                )
                has_reasoned = 1
                turn_index += 1
                reasoning_start = False

            for tool_call_delta in delta.tool_calls:
                _update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)

    # Flush custom token processor to get any final tool calls
    if custom_token_processor:
        flush_delta, processor_state = custom_token_processor(None, processor_state)
        if flush_delta and flush_delta.tool_calls:
            for tool_call_delta in flush_delta.tool_calls:
                _update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)

    tool_calls = _extract_tool_call_kickoffs(
        id_to_tool_call_map, turn_index + has_reasoned
    )
    tool_calls = _extract_tool_call_kickoffs(id_to_tool_call_map)
    if tool_calls:
        tool_calls_list: list[ToolCall] = [
            ToolCall(
@@ -562,7 +474,7 @@
                turn_index=turn_index,
                obj=ReasoningDone(),
            )
            has_reasoned = 1
            turn_index += 1

    # Flush any remaining content from citation processor
    if citation_processor:
@@ -572,12 +484,12 @@
                # Save answer incrementally to state container
                state_container.set_answer_tokens(accumulated_answer)
                yield Packet(
                    turn_index=turn_index + has_reasoned,
                    turn_index=turn_index,
                    obj=AgentResponseDelta(content=result),
                )
            elif isinstance(result, CitationInfo):
                yield Packet(
                    turn_index=turn_index + has_reasoned,
                    turn_index=turn_index,
                    obj=result,
                )

@@ -602,49 +514,5 @@
            answer=accumulated_answer if accumulated_answer else None,
            tool_calls=tool_calls if tool_calls else None,
        ),
        bool(has_reasoned),
        turn_index,
    )


def run_llm_step(
    emitter: "Emitter",
    history: list[ChatMessageSimple],
    tool_definitions: list[dict],
    tool_choice: ToolChoiceOptions,
    llm: LLM,
    turn_index: int,
    state_container: ChatStateContainer,
    citation_processor: DynamicCitationProcessor | None,
    reasoning_effort: ReasoningEffort | None = None,
    final_documents: list[SearchDoc] | None = None,
    user_identity: LLMUserIdentity | None = None,
    custom_token_processor: (
        Callable[[Delta | None, Any], tuple[Delta | None, Any]] | None
    ) = None,
) -> tuple[LlmStepResult, bool]:
    """Wrapper around run_llm_step_pkt_generator that consumes packets and emits them.

    Returns:
        tuple[LlmStepResult, bool]: The LLM step result and whether reasoning occurred.
    """
    step_generator = run_llm_step_pkt_generator(
        history=history,
        tool_definitions=tool_definitions,
        tool_choice=tool_choice,
        llm=llm,
        turn_index=turn_index,
        state_container=state_container,
        citation_processor=citation_processor,
        reasoning_effort=reasoning_effort,
        final_documents=final_documents,
        user_identity=user_identity,
        custom_token_processor=custom_token_processor,
    )

    while True:
        try:
            packet = next(step_generator)
            emitter.emit(packet)
        except StopIteration as e:
            llm_step_result, has_reasoned = e.value
            return llm_step_result, bool(has_reasoned)

@@ -7,6 +7,7 @@ from uuid import UUID

from sqlalchemy.orm import Session

from onyx.chat.chat_milestones import process_multi_assistant_milestone
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_state import run_chat_llm_with_state_containers
from onyx.chat.chat_utils import convert_chat_history
@@ -31,7 +32,6 @@ from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.context.search.models import CitationDocInfo
from onyx.context.search.models import SearchDoc
from onyx.db.chat import create_new_chat_message
@@ -51,8 +51,8 @@ from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import verify_user_files
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.utils import litellm_exception_to_error_msg
@@ -72,7 +72,6 @@ from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_constructor import SearchToolUsage
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import get_current_tenant_id
@@ -368,10 +367,11 @@ def stream_chat_message_objects(
        )

        # Milestone tracking, most devs using the API don't need to understand this
        mt_cloud_telemetry(
        process_multi_assistant_milestone(
            user=user,
            assistant_id=persona.id,
            tenant_id=tenant_id,
            distinct_id=user.email if user else tenant_id,
            event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
            db_session=db_session,
        )

        if reference_doc_ids is None and retrieval_options is None:
@@ -379,7 +379,7 @@
                "Must specify a set of documents for chat or specify search options"
            )

        llm = get_llm_for_persona(
        llm, fast_llm = get_llms_for_persona(
            persona=persona,
            user=user,
            llm_override=new_msg_req.llm_override or chat_session.llm_override,
@@ -475,6 +475,7 @@
            emitter=emitter,
            user=user,
            llm=llm,
            fast_llm=fast_llm,
            search_tool_config=SearchToolConfig(
                user_selected_filters=user_selected_filters,
                project_id=(

@@ -102,7 +102,6 @@ def _create_and_link_tool_calls(
            if tool_call_info.generated_images
            else None
        ),
        tab_index=tool_call_info.tab_index,
        add_only=True,
    )

@@ -220,8 +219,8 @@ def save_chat_turn(
            search_doc_key_to_id[search_doc_key] = db_search_doc.id
            search_doc_ids_for_tool.append(db_search_doc.id)

        tool_call_to_search_doc_ids[tool_call_info.tool_call_id] = list(
            set(search_doc_ids_for_tool)
        tool_call_to_search_doc_ids[tool_call_info.tool_call_id] = (
            search_doc_ids_for_tool
        )

    # 3. Collect all unique SearchDoc IDs from all tool calls to link to ChatMessage

@@ -332,6 +332,7 @@ class FileType(str, Enum):
class MilestoneRecordType(str, Enum):
    TENANT_CREATED = "tenant_created"
    USER_SIGNED_UP = "user_signed_up"
    MULTIPLE_USERS = "multiple_users"
    VISITED_ADMIN_PAGE = "visited_admin_page"
    CREATED_CONNECTOR = "created_connector"
    CONNECTOR_SUCCEEDED = "connector_succeeded"

@@ -51,9 +51,10 @@ CROSS_ENCODER_RANGE_MIN = 0
# Generative AI Model Configs
#####

# NOTE: the 2 below should only be used for dev.
# NOTE: the 3 below should only be used for dev.
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY")
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION")
FAST_GEN_AI_MODEL_VERSION = os.environ.get("FAST_GEN_AI_MODEL_VERSION")

# Override the auto-detection of LLM max context length
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None

@@ -40,7 +40,8 @@ from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.extract_file_text import is_accepted_file_ext
|
||||
from onyx.file_processing.extract_file_text import OnyxExtensionType
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -409,7 +410,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
continue
|
||||
|
||||
# Handle image files
|
||||
if file_ext in OnyxFileExtensions.IMAGE_EXTENSIONS:
|
||||
if is_accepted_file_ext(file_ext, OnyxExtensionType.Multimedia):
|
||||
if not self._allow_images:
|
||||
logger.debug(
|
||||
f"Skipping image file: {key} (image processing not enabled)"
|
||||
|
||||
@@ -84,12 +84,6 @@ ONE_DAY = ONE_HOUR * 24
|
||||
MAX_CACHED_IDS = 100
|
||||
|
||||
|
||||
def _get_page_id(page: dict[str, Any], allow_missing: bool = False) -> str:
|
||||
if allow_missing and "id" not in page:
|
||||
return "unknown"
|
||||
return str(page["id"])
|
||||
|
||||
|
||||
class ConfluenceCheckpoint(ConnectorCheckpoint):
|
||||
|
||||
next_page_url: str | None
|
||||
@@ -305,7 +299,7 @@ class ConfluenceConnector(
|
||||
page_id = page_url = ""
|
||||
try:
|
||||
# Extract basic page information
|
||||
page_id = _get_page_id(page)
|
||||
page_id = page["id"]
|
||||
page_title = page["title"]
|
||||
logger.info(f"Converting page {page_title} to document")
|
||||
page_url = build_confluence_document_id(
|
||||
@@ -388,9 +382,7 @@ class ConfluenceConnector(
|
||||
this function. The returned documents/connectorfailures are for non-inline attachments
|
||||
and those at the end of the page.
|
||||
"""
|
||||
attachment_query = self._construct_attachment_query(
|
||||
_get_page_id(page), start, end
|
||||
)
|
||||
attachment_query = self._construct_attachment_query(page["id"], start, end)
|
||||
attachment_failures: list[ConnectorFailure] = []
|
||||
attachment_docs: list[Document] = []
|
||||
page_url = ""
|
||||
@@ -438,7 +430,7 @@ class ConfluenceConnector(
|
||||
response = convert_attachment_to_content(
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=attachment,
|
||||
page_id=_get_page_id(page),
|
||||
page_id=page["id"],
|
||||
allow_images=self.allow_images,
|
||||
)
|
||||
if response is None:
|
||||
@@ -523,21 +515,14 @@ class ConfluenceConnector(
|
||||
except HTTPError as e:
|
||||
# If we get a 403 after all retries, the user likely doesn't have permission
|
||||
# to access attachments on this page. Log and skip rather than failing the whole job.
|
||||
page_id = _get_page_id(page, allow_missing=True)
|
||||
page_title = page.get("title", "unknown")
|
||||
if e.response and e.response.status_code in [401, 403]:
|
||||
failure_message_prefix = (
|
||||
"Invalid credentials (401)"
|
||||
if e.response.status_code == 401
|
||||
else "Permission denied (403)"
|
||||
)
|
||||
failure_message = (
|
||||
f"{failure_message_prefix} when fetching attachments for page '{page_title}' "
|
||||
if e.response and e.response.status_code == 403:
|
||||
page_title = page.get("title", "unknown")
|
||||
page_id = page.get("id", "unknown")
|
||||
logger.warning(
|
||||
f"Permission denied (403) when fetching attachments for page '{page_title}' "
|
||||
f"(ID: {page_id}). The user may not have permission to query attachments on this page. "
|
||||
"Skipping attachments for this page."
|
||||
)
|
||||
logger.warning(failure_message)
|
||||
|
||||
# Build the page URL for the failure record
|
||||
try:
|
||||
page_url = build_confluence_document_id(
|
||||
@@ -552,7 +537,7 @@ class ConfluenceConnector(
|
||||
document_id=page_id,
|
||||
document_link=page_url,
|
||||
),
|
||||
failure_message=failure_message,
|
||||
failure_message=f"Permission denied (403) when fetching attachments for page '{page_title}'",
|
||||
exception=e,
|
||||
)
|
||||
]
|
||||
@@ -723,7 +708,7 @@ class ConfluenceConnector(
|
||||
expand=restrictions_expand,
|
||||
limit=_SLIM_DOC_BATCH_SIZE,
|
||||
):
|
||||
page_id = _get_page_id(page)
|
||||
page_id = page["id"]
|
||||
page_restrictions = page.get("restrictions") or {}
|
||||
page_space_key = page.get("space", {}).get("key")
|
||||
page_ancestors = page.get("ancestors", [])
|
||||
@@ -743,7 +728,7 @@ class ConfluenceConnector(
|
||||
)
|
||||
|
||||
# Query attachments for each page
|
||||
attachment_query = self._construct_attachment_query(_get_page_id(page))
|
||||
attachment_query = self._construct_attachment_query(page["id"])
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_query,
|
||||
expand=restrictions_expand,
|
||||
|
||||
@@ -24,9 +24,9 @@ from onyx.configs.app_configs import (
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.configs.constants import FileOrigin
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.logger import setup_logger

@@ -56,13 +56,15 @@ def validate_attachment_filetype(
    """
    media_type = attachment.get("metadata", {}).get("mediaType", "")
    if media_type.startswith("image/"):
        return media_type in OnyxMimeTypes.IMAGE_MIME_TYPES
        return is_valid_image_type(media_type)

    # For non-image files, check if we support the extension
    title = attachment.get("title", "")
    extension = get_file_ext(title)
    extension = Path(title).suffix.lstrip(".").lower() if "." in title else ""

    return extension in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS
    return is_accepted_file_ext(
        "." + extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
    )


class AttachmentProcessingResult(BaseModel):
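The validate_attachment_filetype change above gates images by MIME type and everything else by extension. A minimal self-contained sketch of that two-stage gate, with stand-in constants (the real sets live in OnyxMimeTypes / OnyxFileExtensions, so the values below are assumptions):

# Hedged sketch of the filetype gate, not the exact Onyx API.
from pathlib import Path

IMAGE_MIME_TYPES = {"image/png", "image/jpeg", "image/webp"}   # stand-in set
ALLOWED_EXTENSIONS = {".txt", ".md", ".pdf", ".docx"}          # stand-in set

def validate_filetype(media_type: str, title: str) -> bool:
    # Images are validated by MIME type; other files by dot-prefixed extension.
    if media_type.startswith("image/"):
        return media_type in IMAGE_MIME_TYPES
    return Path(title).suffix.lower() in ALLOWED_EXTENSIONS

assert validate_filetype("image/png", "diagram.png")
assert not validate_filetype("image/tiff", "scan.tiff")
assert validate_filetype("application/pdf", "report.PDF")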
@@ -28,8 +28,10 @@ from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries

@@ -68,15 +70,14 @@ def _process_egnyte_file(

    file_name = file_metadata["name"]
    extension = get_file_ext(file_name)

    # Explicitly excluding image extensions here. TODO: consider allowing images
    if extension not in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS:
    if not is_accepted_file_ext(
        extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
    ):
        logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
        return None

    # Extract text content based on file type
    # TODO @wenxi-onyx: convert to extract_text_and_images
    if extension in OnyxFileExtensions.PLAIN_TEXT_EXTENSIONS:
    if is_text_file_extension(file_name):
        encoding = detect_encoding(file_content)
        file_content_raw, file_metadata = read_text_file(
            file_content, encoding=encoding, ignore_onyx_metadata=False
@@ -18,7 +18,8 @@ from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
@@ -89,7 +90,7 @@ def _process_file(
    # Get file extension and determine file type
    extension = get_file_ext(file_name)

    if extension not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
    if not is_accepted_file_ext(extension, OnyxExtensionType.All):
        logger.warning(
            f"Skipping file '{file_name}' with unrecognized extension '{extension}'"
        )
@@ -110,7 +111,7 @@ def _process_file(
    title = metadata.get("title") or file_display_name

    # 1) If the file itself is an image, handle that scenario quickly
    if extension in OnyxFileExtensions.IMAGE_EXTENSIONS:
    if extension in LoadConnector.IMAGE_EXTENSIONS:
        # Read the image data
        image_data = file.read()
        if not image_data:
@@ -29,14 +29,14 @@ from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import docx_to_text_and_images
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import pptx_to_text
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import xlsx_to_text
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import (
@@ -114,6 +114,14 @@ def onyx_document_id_from_drive_file(file: GoogleDriveFileType) -> str:
    return urlunparse(parsed_url)


def is_gdrive_image_mime_type(mime_type: str) -> bool:
    """
    Return True if the mime_type is a common image type in GDrive.
    (e.g. 'image/png', 'image/jpeg')
    """
    return is_valid_image_type(mime_type)


def download_request(
    service: GoogleDriveService, file_id: str, size_threshold: int
) -> bytes:
@@ -165,7 +173,7 @@ def _download_and_extract_sections_basic(
    def response_call() -> bytes:
        return download_request(service, file_id, size_threshold)

    if mime_type in OnyxMimeTypes.IMAGE_MIME_TYPES:
    if is_gdrive_image_mime_type(mime_type):
        # Skip images if not explicitly enabled
        if not allow_images:
            return []
@@ -252,7 +260,7 @@ def _download_and_extract_sections_basic(

    # Final attempt at extracting text
    file_ext = get_file_ext(file.get("name", ""))
    if file_ext not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
    if file_ext not in ALL_ACCEPTED_FILE_EXTENSIONS:
        logger.warning(f"Skipping file {file.get('name')} due to extension.")
        return []
@@ -23,8 +23,9 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ACCEPTED_DOCUMENT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

@@ -308,7 +309,10 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync)

            elif (
                is_valid_format
                and file_extension in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS
                and (
                    file_extension in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
                    or file_extension in ACCEPTED_DOCUMENT_FILE_EXTENSIONS
                )
                and can_download
            ):
                content_response = self.client.get_item_content(item_id)
@@ -27,6 +27,8 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)

class BaseConnector(abc.ABC, Generic[CT]):
    REDIS_KEY_PREFIX = "da_connector_data:"
    # Common image file extensions supported across connectors
    IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}

    @abc.abstractmethod
    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -1,5 +1,4 @@
import copy
import json
import os
from collections.abc import Callable
from collections.abc import Iterable
@@ -10,7 +9,6 @@ from datetime import timezone
from typing import Any

from jira import JIRA
from jira.exceptions import JIRAError
from jira.resources import Issue
from more_itertools import chunked
from typing_extensions import override
@@ -136,80 +134,6 @@ def _perform_jql_search(
    return _perform_jql_search_v2(jira_client, jql, start, max_results, fields)


def _handle_jira_search_error(e: Exception, jql: str) -> None:
    """Handle common Jira search errors and raise appropriate exceptions.

    Args:
        e: The exception raised by the Jira API
        jql: The JQL query that caused the error

    Raises:
        ConnectorValidationError: For HTTP 400 errors (invalid JQL or project)
        CredentialExpiredError: For HTTP 401 errors
        InsufficientPermissionsError: For HTTP 403 errors
        Exception: Re-raises the original exception for other error types
    """
    # Extract error information from the exception
    error_text = ""
    status_code = None

    def _format_error_text(error_payload: Any) -> str:
        error_messages = (
            error_payload.get("errorMessages", [])
            if isinstance(error_payload, dict)
            else []
        )
        if error_messages:
            return (
                "; ".join(error_messages)
                if isinstance(error_messages, list)
                else str(error_messages)
            )
        return str(error_payload)

    # Try to get status code and error text from JIRAError or requests response
    if hasattr(e, "status_code"):
        status_code = e.status_code
        raw_text = getattr(e, "text", "")
        if isinstance(raw_text, str):
            try:
                error_text = _format_error_text(json.loads(raw_text))
            except Exception:
                error_text = raw_text
        else:
            error_text = str(raw_text)
    elif hasattr(e, "response") and e.response is not None:
        status_code = e.response.status_code
        # Try JSON first, fall back to text
        try:
            error_json = e.response.json()
            error_text = _format_error_text(error_json)
        except Exception:
            error_text = e.response.text

    # Handle specific status codes
    if status_code == 400:
        if "does not exist for the field 'project'" in error_text:
            raise ConnectorValidationError(
                f"The specified Jira project does not exist or you don't have access to it. "
                f"JQL query: {jql}. Error: {error_text}"
            )
        raise ConnectorValidationError(
            f"Invalid JQL query. JQL: {jql}. Error: {error_text}"
        )
    elif status_code == 401:
        raise CredentialExpiredError(
            "Jira credentials are expired or invalid (HTTP 401)."
        )
    elif status_code == 403:
        raise InsufficientPermissionsError(
            f"Insufficient permissions to execute JQL query. JQL: {jql}"
        )

    # Re-raise for other error types
    raise e


def enhanced_search_ids(
    jira_client: JIRA, jql: str, nextPageToken: str | None = None
) -> tuple[list[str], str | None]:
@@ -225,15 +149,8 @@ def enhanced_search_ids(
        "nextPageToken": nextPageToken,
        "fields": "id",
    }
    try:
        response = jira_client._session.get(enhanced_search_path, params=params)
        response.raise_for_status()
        response_json = response.json()
    except Exception as e:
        _handle_jira_search_error(e, jql)
        raise  # Explicitly re-raise for type checker, should never reach here

    return [str(issue["id"]) for issue in response_json["issues"]], response_json.get(
    response = jira_client._session.get(enhanced_search_path, params=params).json()
    return [str(issue["id"]) for issue in response["issues"]], response.get(
        "nextPageToken"
    )

@@ -315,16 +232,12 @@ def _perform_jql_search_v2(
        f"Fetching Jira issues with JQL: {jql}, "
        f"starting at {start}, max results: {max_results}"
    )
    try:
        issues = jira_client.search_issues(
            jql_str=jql,
            startAt=start,
            maxResults=max_results,
            fields=fields,
        )
    except JIRAError as e:
        _handle_jira_search_error(e, jql)
        raise  # Explicitly re-raise for type checker, should never reach here
    issues = jira_client.search_issues(
        jql_str=jql,
        startAt=start,
        maxResults=max_results,
        fields=fields,
    )

    for issue in issues:
        if isinstance(issue, Issue):
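The _handle_jira_search_error helper above maps HTTP status codes onto connector exceptions. A runnable sketch of just that mapping, with locally defined stand-ins for the Onyx exception classes:

# Hedged sketch of the status-code mapping; exception classes are stand-ins.
class ConnectorValidationError(Exception):
    pass

class CredentialExpiredError(Exception):
    pass

class InsufficientPermissionsError(Exception):
    pass

def map_jira_error(status_code: int, error_text: str, jql: str) -> None:
    # 400 -> bad JQL / missing project, 401 -> credentials, 403 -> permissions
    if status_code == 400:
        raise ConnectorValidationError(f"Invalid JQL query. JQL: {jql}. Error: {error_text}")
    if status_code == 401:
        raise CredentialExpiredError("Jira credentials are expired or invalid (HTTP 401).")
    if status_code == 403:
        raise InsufficientPermissionsError(f"Insufficient permissions to execute JQL query. JQL: {jql}")

try:
    map_jira_error(401, "", "project = ONYX")
except CredentialExpiredError as exc:
    print(f"caught: {exc}")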
@@ -55,10 +55,12 @@ from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.sharepoint.connector_utils import get_sharepoint_external_access
from onyx.file_processing.extract_file_text import ACCEPTED_IMAGE_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.file_validation import EXCLUDED_IMAGE_TYPES
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
@@ -326,7 +328,7 @@ def _convert_driveitem_to_document_with_permissions(
    try:
        item_json = driveitem.to_json()
        mime_type = item_json.get("file", {}).get("mimeType")
        if not mime_type or mime_type in OnyxMimeTypes.EXCLUDED_IMAGE_TYPES:
        if not mime_type or mime_type in EXCLUDED_IMAGE_TYPES:
            # NOTE: this function should be refactored to look like Drive doc_conversion.py pattern
            # for now, this skip must happen before we download the file
            # Similar to Google Drive, we'll just semi-silently skip excluded image types
@@ -386,14 +388,14 @@ def _convert_driveitem_to_document_with_permissions(
        return None

    sections: list[TextSection | ImageSection] = []
    file_ext = get_file_ext(driveitem.name)
    file_ext = driveitem.name.split(".")[-1]

    if not content_bytes:
        logger.warning(
            f"Zero-length content for '{driveitem.name}'. Skipping text/image extraction."
        )
    elif file_ext in OnyxFileExtensions.IMAGE_EXTENSIONS:
    # NOTE: this if should probably check mime_type instead
    elif "." + file_ext in ACCEPTED_IMAGE_FILE_EXTENSIONS:
    # NOTE: this if should use is_valid_image_type instead with mime_type
        image_section, _ = store_image_and_create_section(
            image_data=content_bytes,
            file_id=driveitem.id,
@@ -416,7 +418,7 @@ def _convert_driveitem_to_document_with_permissions(

    # The only mime type that would be returned by get_image_type_from_bytes that is in
    # EXCLUDED_IMAGE_TYPES is image/gif.
    if mime_type in OnyxMimeTypes.EXCLUDED_IMAGE_TYPES:
    if mime_type in EXCLUDED_IMAGE_TYPES:
        logger.debug(
            "Skipping embedded image of excluded type %s for %s",
            mime_type,
@@ -1504,7 +1506,7 @@ class SharepointConnector(
    )
    for driveitem in driveitems:
        driveitem_extension = get_file_ext(driveitem.name)
        if driveitem_extension not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
        if not is_accepted_file_ext(driveitem_extension, OnyxExtensionType.All):
            logger.warning(
                f"Skipping {driveitem.web_url} as it is not a supported file type"
            )
@@ -1512,7 +1514,7 @@ class SharepointConnector(

        # Only yield empty documents if they are PDFs or images
        should_yield_if_empty = (
            driveitem_extension in OnyxFileExtensions.IMAGE_EXTENSIONS
            driveitem_extension in ACCEPTED_IMAGE_FILE_EXTENSIONS
            or driveitem_extension == ".pdf"
        )
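Worth noting in the SharePoint hunk above: get_file_ext (as used throughout this diff) appears to return a dot-prefixed, lowercased extension, while the older driveitem.name.split(".")[-1] yields a bare, case-preserved one, which is why the old branch compares against "." + file_ext. A quick illustration with pathlib, which is presumably what get_file_ext wraps:

# Hedged illustration; assumes get_file_ext behaves like Path(...).suffix.lower().
from pathlib import Path

name = "Quarterly Report.PDF"
assert Path(name).suffix.lower() == ".pdf"   # dot-prefixed, normalized
assert name.split(".")[-1] == "PDF"          # bare and case-preserved
assert "." + name.split(".")[-1].lower() == Path(name).suffix.lower()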
@@ -18,7 +18,6 @@ from oauthlib.oauth2 import BackendApplicationClient
from playwright.sync_api import BrowserContext
from playwright.sync_api import Playwright
from playwright.sync_api import sync_playwright
from playwright.sync_api import TimeoutError
from requests_oauthlib import OAuth2Session  # type:ignore
from urllib3.exceptions import MaxRetryError

@@ -87,8 +86,6 @@ WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
IFRAME_TEXT_LENGTH_THRESHOLD = 700
# Message indicating JavaScript is disabled, which often appears when scraping fails
JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser"
# Grace period after page navigation to allow bot-detection challenges to complete
BOT_DETECTION_GRACE_PERIOD_MS = 5000

# Define common headers that mimic a real browser
DEFAULT_USER_AGENT = (
@@ -557,17 +554,12 @@ class WebConnector(LoadConnector):

        page = session_ctx.playwright_context.new_page()
        try:
            # Use "commit" instead of "domcontentloaded" to avoid hanging on bot-detection pages
            # that may never fire domcontentloaded. "commit" waits only for navigation to be
            # committed (response received), then we add a short wait for initial rendering.
            # Can't use wait_until="networkidle" because it interferes with the scrolling behavior
            page_response = page.goto(
                initial_url,
                timeout=30000,  # 30 seconds
                wait_until="commit",  # Wait for navigation to commit
                wait_until="domcontentloaded",  # Wait for DOM to be ready
            )
            # Give the page a moment to start rendering after navigation commits.
            # Allows CloudFlare and other bot-detection challenges to complete.
            page.wait_for_timeout(BOT_DETECTION_GRACE_PERIOD_MS)

            last_modified = (
                page_response.header_value("Last-Modified") if page_response else None
@@ -592,15 +584,8 @@ class WebConnector(LoadConnector):
            previous_height = page.evaluate("document.body.scrollHeight")
            while scroll_attempts < WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS:
                page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
                # Wait for content to load, but catch timeout if page never reaches networkidle
                # (e.g., CloudFlare protection keeps making requests)
                try:
                    page.wait_for_load_state(
                        "networkidle", timeout=BOT_DETECTION_GRACE_PERIOD_MS
                    )
                except TimeoutError:
                    # If networkidle times out, just give it a moment for content to render
                    time.sleep(1)
                # wait for the content to load if we scrolled
                page.wait_for_load_state("networkidle", timeout=30000)
                time.sleep(0.5)  # let javascript run

                new_height = page.evaluate("document.body.scrollHeight")
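A standalone sketch of the navigation pattern the comments above describe: wait only for the navigation to commit, give bot-detection challenges a fixed grace period, and tolerate networkidle timeouts. The URL and timings below are placeholders, not the connector's configuration:

# Hedged sketch, not the connector itself.
from playwright.sync_api import sync_playwright, TimeoutError as PlaywrightTimeout

BOT_DETECTION_GRACE_PERIOD_MS = 5000

with sync_playwright() as p:
    browser = p.chromium.launch()
    page = browser.new_page()
    # "commit" fires as soon as the navigation response is received, so this
    # cannot hang on pages that never reach domcontentloaded.
    page.goto("https://example.com", timeout=30_000, wait_until="commit")
    page.wait_for_timeout(BOT_DETECTION_GRACE_PERIOD_MS)
    try:
        page.wait_for_load_state("networkidle", timeout=BOT_DETECTION_GRACE_PERIOD_MS)
    except PlaywrightTimeout:
        pass  # e.g. a challenge keeps polling; proceed with whatever rendered
    print(page.title())
    browser.close()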
@@ -39,7 +39,7 @@ from onyx.federated_connectors.slack.models import SlackEntities
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.onyxbot.slack.models import ChannelType
from onyx.onyxbot.slack.models import SlackContext
from onyx.redis.redis_pool import get_redis_client
@@ -830,8 +830,8 @@ def slack_retrieval(
    )

    # Query slack with entity filtering
    llm = get_default_llm()
    query_strings = build_slack_queries(query, llm, entities, available_channels)
    _, fast_llm = get_default_llms()
    query_strings = build_slack_queries(query, fast_llm, entities, available_channels)

    # Determine filtering based on entities OR context (bot)
    include_dm = False
@@ -234,6 +234,9 @@ def upsert_llm_provider(
    existing_llm_provider.default_model_name = (
        llm_provider_upsert_request.default_model_name
    )
    existing_llm_provider.fast_default_model_name = (
        llm_provider_upsert_request.fast_default_model_name
    )
    existing_llm_provider.is_public = llm_provider_upsert_request.is_public
    existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
@@ -1,9 +1,7 @@
import datetime
from typing import cast
from uuid import UUID

from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified
@@ -26,9 +24,7 @@ logger = setup_logger()
# MCPServer operations
def get_all_mcp_servers(db_session: Session) -> list[MCPServer]:
    """Get all MCP servers"""
    return list(
        db_session.scalars(select(MCPServer).order_by(MCPServer.created_at)).all()
    )
    return list(db_session.scalars(select(MCPServer)).all())


def get_mcp_server_by_id(server_id: int, db_session: Session) -> MCPServer:
@@ -128,7 +124,6 @@ def update_mcp_server__no_commit(
    auth_performer: MCPAuthenticationPerformer | None = None,
    transport: MCPTransport | None = None,
    status: MCPServerStatus | None = None,
    last_refreshed_at: datetime.datetime | None = None,
) -> MCPServer:
    """Update an existing MCP server"""
    server = get_mcp_server_by_id(server_id, db_session)
@@ -149,8 +144,6 @@ def update_mcp_server__no_commit(
        server.transport = transport
    if status is not None:
        server.status = status
    if last_refreshed_at is not None:
        server.last_refreshed_at = last_refreshed_at

    db_session.flush()  # Don't commit yet, let caller decide when to commit
    return server
@@ -337,15 +330,3 @@ def delete_user_connection_configs_for_server(
        db_session.delete(config)

    db_session.commit()


def delete_all_user_connection_configs_for_server_no_commit(
    server_id: int, db_session: Session
) -> None:
    """Delete all user connection configs for a specific MCP server"""
    db_session.execute(
        delete(MCPConnectionConfig).where(
            MCPConnectionConfig.mcp_server_id == server_id
        )
    )
    db_session.flush()  # Don't commit yet, let caller decide when to commit
99
backend/onyx/db/milestone.py
Normal file
@@ -0,0 +1,99 @@
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified

from onyx.configs.constants import MilestoneRecordType
from onyx.db.models import Milestone
from onyx.db.models import User


USER_ASSISTANT_PREFIX = "user_assistants_used_"
MULTI_ASSISTANT_USED = "multi_assistant_used"


def create_milestone(
    user: User | None,
    event_type: MilestoneRecordType,
    db_session: Session,
) -> Milestone:
    milestone = Milestone(
        event_type=event_type,
        user_id=user.id if user else None,
    )
    db_session.add(milestone)
    db_session.commit()

    return milestone


def create_milestone_if_not_exists(
    user: User | None, event_type: MilestoneRecordType, db_session: Session
) -> tuple[Milestone, bool]:
    # Check if it exists
    milestone = db_session.execute(
        select(Milestone).where(Milestone.event_type == event_type)
    ).scalar_one_or_none()

    if milestone is not None:
        return milestone, False

    # If it doesn't exist, try to create it.
    try:
        milestone = create_milestone(user, event_type, db_session)
        return milestone, True
    except IntegrityError:
        # Another thread or process inserted it in the meantime
        db_session.rollback()
        # Fetch again to return the existing record
        milestone = db_session.execute(
            select(Milestone).where(Milestone.event_type == event_type)
        ).scalar_one()  # Now should exist
        return milestone, False


def update_user_assistant_milestone(
    milestone: Milestone,
    user_id: str | None,
    assistant_id: int,
    db_session: Session,
) -> None:
    event_tracker = milestone.event_tracker
    if event_tracker is None:
        milestone.event_tracker = event_tracker = {}

    if event_tracker.get(MULTI_ASSISTANT_USED):
        # No need to keep tracking and populating if the milestone has already been hit
        return

    user_key = f"{USER_ASSISTANT_PREFIX}{user_id}"

    if event_tracker.get(user_key) is None:
        event_tracker[user_key] = [assistant_id]
    elif assistant_id not in event_tracker[user_key]:
        event_tracker[user_key].append(assistant_id)

    flag_modified(milestone, "event_tracker")
    db_session.commit()


def check_multi_assistant_milestone(
    milestone: Milestone,
    db_session: Session,
) -> tuple[bool, bool]:
    """Returns if the milestone was hit and if it was just hit for the first time"""
    event_tracker = milestone.event_tracker
    if event_tracker is None:
        return False, False

    if event_tracker.get(MULTI_ASSISTANT_USED):
        return True, False

    for key, value in event_tracker.items():
        if key.startswith(USER_ASSISTANT_PREFIX) and len(value) > 1:
            event_tracker[MULTI_ASSISTANT_USED] = True
            flag_modified(milestone, "event_tracker")
            db_session.commit()
            return True, True

    return False, False
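create_milestone_if_not_exists above handles concurrent creation with an optimistic insert plus IntegrityError recovery rather than a lock. A self-contained sketch of the same pattern against an in-memory SQLite database (the Event model below is a stand-in for Milestone):

# Hedged sketch of the insert-then-recover pattern; the model is a stand-in.
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Event(Base):
    __tablename__ = "event"
    id = Column(Integer, primary_key=True)
    event_type = Column(String, unique=True)  # uniqueness makes the race detectable

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

def get_or_create(session: Session, event_type: str) -> tuple[Event, bool]:
    existing = session.execute(
        select(Event).where(Event.event_type == event_type)
    ).scalar_one_or_none()
    if existing is not None:
        return existing, False
    try:
        event = Event(event_type=event_type)
        session.add(event)
        session.commit()
        return event, True
    except IntegrityError:
        session.rollback()  # a concurrent writer won the race; fetch their row
        return session.execute(
            select(Event).where(Event.event_type == event_type)
        ).scalar_one(), False

with Session(engine) as s:
    _, created = get_or_create(s, "multi_assistant_used")
    assert created
    _, created = get_or_create(s, "multi_assistant_used")
    assert not created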
@@ -2215,8 +2215,6 @@ class ToolCall(Base):
    # The tools with the same turn number (and parent) were called in parallel
    # Ones with different turn numbers (and same parent) were called sequentially
    turn_number: Mapped[int] = mapped_column(Integer)
    # Index order of tool calls from the LLM for parallel tool calls
    tab_index: Mapped[int] = mapped_column(Integer, default=0)

    # Not a FK because we want to be able to delete the tool without deleting
    # this entry
@@ -2384,6 +2382,7 @@ class LLMProvider(Base):
        postgresql.JSONB(), nullable=True
    )
    default_model_name: Mapped[str] = mapped_column(String)
    fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)

    deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)

@@ -3674,9 +3673,6 @@ class MCPServer(Base):
    updated_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )
    last_refreshed_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )

    # Relationships
    admin_connection_config: Mapped["MCPConnectionConfig | None"] = relationship(
@@ -3689,7 +3685,6 @@ class MCPServer(Base):
        "MCPConnectionConfig",
        foreign_keys="MCPConnectionConfig.mcp_server_id",
        back_populates="mcp_server",
        passive_deletes=True,
    )
    current_actions: Mapped[list["Tool"]] = relationship(
        "Tool", back_populates="mcp_server", cascade="all, delete-orphan"
@@ -3918,22 +3913,3 @@ class ExternalGroupPermissionSyncAttempt(Base):

    def is_finished(self) -> bool:
        return self.status.is_terminal()


class License(Base):
    """Stores the signed license blob (singleton pattern - only one row)."""

    __tablename__ = "license"
    __table_args__ = (
        # Singleton pattern - unique index on constant ensures only one row
        Index("idx_license_singleton", text("(true)"), unique=True),
    )

    id: Mapped[int] = mapped_column(primary_key=True)
    license_data: Mapped[str] = mapped_column(Text, nullable=False)
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    updated_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )
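The License model above keeps its table a singleton with a unique index over the constant expression (true): every row maps to the same index key, so a second insert violates the index. A sketch of the same idea in a portable form, using a fixed-value column instead of an expression index (SQLite here purely for a runnable demo):

# Hedged sketch: a fixed-value unique column allows at most one row, the
# portable equivalent of the Postgres unique index on text("(true)") above.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE license ("
    " id INTEGER PRIMARY KEY,"
    " license_data TEXT NOT NULL,"
    " singleton INTEGER NOT NULL DEFAULT 1 CHECK (singleton = 1) UNIQUE)"
)
conn.execute("INSERT INTO license (license_data) VALUES ('blob-1')")  # ok
try:
    conn.execute("INSERT INTO license (license_data) VALUES ('blob-2')")
except sqlite3.IntegrityError as e:
    print("second row rejected:", e)  # every row collides on the same key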
@@ -221,7 +221,6 @@ def create_tool_call_no_commit(
    parent_tool_call_id: int | None = None,
    reasoning_tokens: str | None = None,
    generated_images: list[dict] | None = None,
    tab_index: int = 0,
    add_only: bool = True,
) -> ToolCall:
    """
@@ -240,7 +239,6 @@ def create_tool_call_no_commit(
        parent_tool_call_id: Optional parent tool call ID (for nested tool calls)
        reasoning_tokens: Optional reasoning tokens
        generated_images: Optional list of generated image metadata for replay
        tab_index: Index order of tool calls from the LLM for parallel tool calls
        commit: If True, commit the transaction; if False, flush only

    Returns:
@@ -251,7 +249,6 @@ def create_tool_call_no_commit(
        parent_chat_message_id=parent_chat_message_id,
        parent_tool_call_id=parent_tool_call_id,
        turn_number=turn_number,
        tab_index=tab_index,
        tool_id=tool_id,
        tool_call_id=tool_call_id,
        reasoning_tokens=reasoning_tokens,
@@ -42,15 +42,6 @@ def fetch_web_search_provider_by_name(
    return db_session.scalars(stmt).first()


def fetch_web_search_provider_by_type(
    provider_type: WebSearchProviderType, db_session: Session
) -> InternetSearchProvider | None:
    stmt = select(InternetSearchProvider).where(
        InternetSearchProvider.provider_type == provider_type.value
    )
    return db_session.scalars(stmt).first()


def _ensure_unique_search_name(
    name: str, provider_id: int | None, db_session: Session
) -> None:
@@ -198,15 +189,6 @@ def fetch_web_content_provider_by_name(
    return db_session.scalars(stmt).first()


def fetch_web_content_provider_by_type(
    provider_type: WebContentProviderType, db_session: Session
) -> InternetContentProvider | None:
    stmt = select(InternetContentProvider).where(
        InternetContentProvider.provider_type == provider_type.value
    )
    return db_session.scalars(stmt).first()


def _ensure_unique_content_name(
    name: str, provider_id: int | None, db_session: Session
) -> None:
@@ -12,30 +12,18 @@ from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.emitter import Emitter
from onyx.chat.llm_loop import construct_message_history
from onyx.chat.llm_step import run_llm_step
from onyx.chat.llm_step import run_llm_step_pkt_generator
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import LlmStepResult
from onyx.configs.constants import MessageType
from onyx.db.tools import get_tool_by_name
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_TOOL_NAME
from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_MESSAGE
from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_TOKEN_COUNT
from onyx.deep_research.utils import check_special_tool_calls
from onyx.deep_research.utils import create_think_tool_token_processor
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.models import ReasoningEffort
from onyx.llm.models import ToolChoiceOptions
from onyx.llm.utils import model_is_reasoning_model
from onyx.prompts.deep_research.orchestration_layer import CLARIFICATION_PROMPT
from onyx.prompts.deep_research.orchestration_layer import FINAL_REPORT_PROMPT
from onyx.prompts.deep_research.orchestration_layer import ORCHESTRATOR_PROMPT
from onyx.prompts.deep_research.orchestration_layer import ORCHESTRATOR_PROMPT_REASONING
from onyx.prompts.deep_research.orchestration_layer import RESEARCH_PLAN_PROMPT
from onyx.prompts.deep_research.orchestration_layer import USER_FINAL_REPORT_QUERY
from onyx.prompts.prompt_utils import get_current_llm_day_time
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
@@ -43,10 +31,6 @@ from onyx.server.query_and_chat.streaming_models import DeepResearchPlanDelta
from onyx.server.query_and_chat.streaming_models import DeepResearchPlanStart
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.fake_tools.research_agent import run_research_agent_calls
from onyx.tools.models import ToolCallInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
@@ -56,71 +40,7 @@ from onyx.utils.logger import setup_logger

logger = setup_logger()

MAX_USER_MESSAGES_FOR_CONTEXT = 5
# Might be something like:
# 1. Research 1-2
# 2. Think
# 3. Research 3-4
# 4. Think
# 5. Research 5-6
# 6. Think
# 7. Research, possibly something new or different from the plan
# 8. Think
# 9. Generate report
MAX_ORCHESTRATOR_CYCLES = 9

# Similar but without the 4 thinking tool calls
MAX_ORCHESTRATOR_CYCLES_REASONING = 5


def generate_final_report(
    history: list[ChatMessageSimple],
    llm: LLM,
    token_counter: Callable[[str], int],
    state_container: ChatStateContainer,
    emitter: Emitter,
    user_identity: LLMUserIdentity | None,
) -> None:
    final_report_prompt = FINAL_REPORT_PROMPT.format(
        current_datetime=get_current_llm_day_time(full_sentence=False),
    )
    system_prompt = ChatMessageSimple(
        message=final_report_prompt,
        token_count=token_counter(final_report_prompt),
        message_type=MessageType.SYSTEM,
    )
    reminder_message = ChatMessageSimple(
        message=USER_FINAL_REPORT_QUERY,
        token_count=token_counter(USER_FINAL_REPORT_QUERY),
        message_type=MessageType.USER,
    )
    final_report_history = construct_message_history(
        system_prompt=system_prompt,
        custom_agent_prompt=None,
        simple_chat_history=history,
        reminder_message=reminder_message,
        project_files=None,
        available_tokens=llm.config.max_input_tokens,
    )

    llm_step_result, _ = run_llm_step(
        emitter=emitter,
        history=final_report_history,
        tool_definitions=[],
        tool_choice=ToolChoiceOptions.NONE,
        llm=llm,
        turn_index=999,  # TODO
        citation_processor=None,
        state_container=state_container,
        reasoning_effort=ReasoningEffort.LOW,
        final_documents=None,
        user_identity=user_identity,
    )

    final_report = llm_step_result.answer
    if final_report is None:
        raise ValueError("LLM failed to generate the final deep research report")

    state_container.set_answer_tokens(final_report)
MAX_ORCHESTRATOR_CYCLES = 8


def run_deep_research_llm_loop(
@@ -153,8 +73,7 @@ def run_deep_research_llm_loop(

    # Filter tools to only allow web search, internal search, and open URL
    allowed_tool_names = {SearchTool.NAME, WebSearchTool.NAME, OpenURLTool.NAME}
    allowed_tools = [tool for tool in tools if tool.name in allowed_tool_names]
    orchestrator_start_turn_index = 0
    [tool for tool in tools if tool.name in allowed_tool_names]

    #########################################################
    # CLARIFICATION STEP (optional)
@@ -179,8 +98,7 @@ def run_deep_research_llm_loop(
        last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
    )

    llm_step_result, _ = run_llm_step(
        emitter=emitter,
    step_generator = run_llm_step(
        history=truncated_message_history,
        tool_definitions=get_clarification_tool_definitions(),
        tool_choice=ToolChoiceOptions.AUTO,
@@ -188,12 +106,24 @@ def run_deep_research_llm_loop(
        turn_index=0,
        # No citations in this step, it should just pass through all
        # tokens directly so initialized as an empty citation processor
        citation_processor=None,
        citation_processor=DynamicCitationProcessor(),
        state_container=state_container,
        final_documents=None,
        user_identity=user_identity,
    )

    # Consume the generator, emitting packets and capturing the final result
    while True:
        try:
            packet = next(step_generator)
            emitter.emit(packet)
        except StopIteration as e:
            llm_step_result, _ = e.value
            break

    # Type narrowing: generator always returns a result, so this can't be None
    llm_step_result = cast(LlmStepResult, llm_step_result)

    if not llm_step_result.tool_calls:
        # Mark this turn as a clarification question
        state_container.set_is_clarification(True)
@@ -224,13 +154,15 @@ def run_deep_research_llm_loop(
        last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
    )

    research_plan_generator = run_llm_step_pkt_generator(
    research_plan_generator = run_llm_step(
        history=truncated_message_history,
        tool_definitions=[],
        tool_choice=ToolChoiceOptions.NONE,
        llm=llm,
        turn_index=0,
        citation_processor=None,
        # No citations in this step, it should just pass through all
        # tokens directly so initialized as an empty citation processor
        citation_processor=DynamicCitationProcessor(),
        state_container=state_container,
        final_documents=None,
        user_identity=user_identity,
@@ -240,7 +172,6 @@ def run_deep_research_llm_loop(
        try:
            packet = next(research_plan_generator)
            # Translate AgentResponseStart/Delta packets to DeepResearchPlanStart/Delta
            # The LLM response from this prompt is the research plan
            if isinstance(packet.obj, AgentResponseStart):
                emitter.emit(
                    Packet(
@@ -259,16 +190,7 @@ def run_deep_research_llm_loop(
            # Pass through other packet types (e.g., ReasoningStart, ReasoningDelta, etc.)
            emitter.emit(packet)
        except StopIteration as e:
            llm_step_result, orchestrator_start_turn_index = e.value
            # TODO: All that is done with the plan is for streaming to the frontend and informing the flow
            # Currently not saved. It would have to be saved as a ToolCall for a new tool type.
            emitter.emit(
                Packet(
                    # Marks the last turn end which should be the plan generation
                    turn_index=orchestrator_start_turn_index - 1,
                    obj=SectionEnd(),
                )
            )
            llm_step_result, _ = e.value
            break
    llm_step_result = cast(LlmStepResult, llm_step_result)

@@ -281,12 +203,6 @@ def run_deep_research_llm_loop(
        llm.config.model_name, llm.config.model_provider
    )

    max_orchestrator_cycles = (
        MAX_ORCHESTRATOR_CYCLES
        if not is_reasoning_model
        else MAX_ORCHESTRATOR_CYCLES_REASONING
    )

    orchestrator_prompt_template = (
        ORCHESTRATOR_PROMPT if not is_reasoning_model else ORCHESTRATOR_PROMPT_REASONING
    )
@@ -294,19 +210,16 @@ def run_deep_research_llm_loop(
    token_count_prompt = orchestrator_prompt_template.format(
        current_datetime=get_current_llm_day_time(full_sentence=False),
        current_cycle_count=1,
        max_cycles=max_orchestrator_cycles,
        max_cycles=MAX_ORCHESTRATOR_CYCLES,
        research_plan=research_plan,
    )
    orchestration_tokens = token_counter(token_count_prompt)

    reasoning_cycles = 0
    for cycle in range(max_orchestrator_cycles):
        research_agent_calls: list[ToolCallKickoff] = []

    for cycle in range(MAX_ORCHESTRATOR_CYCLES):
        orchestrator_prompt = orchestrator_prompt_template.format(
            current_datetime=get_current_llm_day_time(full_sentence=False),
            current_cycle_count=cycle,
            max_cycles=max_orchestrator_cycles,
            max_cycles=MAX_ORCHESTRATOR_CYCLES,
            research_plan=research_plan,
        )

@@ -326,163 +239,16 @@ def run_deep_research_llm_loop(
            last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
        )

        # Use think tool processor for non-reasoning models to convert
        # think_tool calls to reasoning content
        custom_processor = (
            create_think_tool_token_processor() if not is_reasoning_model else None
        )

        llm_step_result, has_reasoned = run_llm_step(
            emitter=emitter,
        research_plan_generator = run_llm_step(
            history=truncated_message_history,
            tool_definitions=get_orchestrator_tools(
                include_think_tool=not is_reasoning_model
            ),
            tool_choice=ToolChoiceOptions.REQUIRED,
            tool_definitions=[],
            tool_choice=ToolChoiceOptions.AUTO,
            llm=llm,
            turn_index=orchestrator_start_turn_index + cycle + reasoning_cycles,
            turn_index=cycle,
            # No citations in this step, it should just pass through all
            # tokens directly so initialized as an empty citation processor
            citation_processor=DynamicCitationProcessor(),
            state_container=state_container,
            final_documents=None,
            user_identity=user_identity,
            custom_token_processor=custom_processor,
        )
        if has_reasoned:
            reasoning_cycles += 1

        tool_calls = llm_step_result.tool_calls or []

        if not tool_calls and cycle == 0:
            raise RuntimeError(
                "Deep Research failed to generate any research tasks for the agents."
            )

        if not tool_calls:
            logger.warning("No tool calls found, this should not happen.")
            generate_final_report(
                history=simple_chat_history,
                llm=llm,
                token_counter=token_counter,
                state_container=state_container,
                emitter=emitter,
                user_identity=user_identity,
            )
            break

        most_recent_reasoning: str | None = None
        special_tool_calls = check_special_tool_calls(tool_calls=tool_calls)

        if special_tool_calls.generate_report_tool_call:
            generate_final_report(
                history=simple_chat_history,
                llm=llm,
                token_counter=token_counter,
                state_container=state_container,
                emitter=emitter,
                user_identity=user_identity,
            )
            break
        elif special_tool_calls.think_tool_call:
            think_tool_call = special_tool_calls.think_tool_call
            # Only process the THINK_TOOL and skip all other tool calls
            # This will not actually get saved to the db as a tool call but we'll attach it to the tool(s) called after
            # it as if it were just a reasoning model doing it. In the chat history, because it happens in 2 steps,
            # we will show it as a separate message.
            most_recent_reasoning = state_container.reasoning_tokens
            tool_call_message = think_tool_call.to_msg_str()

            think_tool_msg = ChatMessageSimple(
                message=tool_call_message,
                token_count=token_counter(tool_call_message),
                message_type=MessageType.TOOL_CALL,
                tool_call_id=think_tool_call.tool_call_id,
                image_files=None,
            )
            simple_chat_history.append(think_tool_msg)

            think_tool_response_msg = ChatMessageSimple(
                message=THINK_TOOL_RESPONSE_MESSAGE,
                token_count=THINK_TOOL_RESPONSE_TOKEN_COUNT,
                message_type=MessageType.TOOL_CALL_RESPONSE,
                tool_call_id=think_tool_call.tool_call_id,
                image_files=None,
            )
            simple_chat_history.append(think_tool_response_msg)
            reasoning_cycles += 1
            continue
        else:
            for tool_call in tool_calls:
                if tool_call.tool_name != RESEARCH_AGENT_TOOL_NAME:
                    logger.warning(f"Unexpected tool call: {tool_call.tool_name}")
                    continue

                research_agent_calls.append(tool_call)

            if not research_agent_calls:
                logger.warning(
                    "No research agent tool calls found, this should not happen."
                )
                generate_final_report(
                    history=simple_chat_history,
                    llm=llm,
                    token_counter=token_counter,
                    state_container=state_container,
                    emitter=emitter,
                    user_identity=user_identity,
                )
                break

            research_results = run_research_agent_calls(
                research_agent_calls=research_agent_calls,
                tools=allowed_tools,
                emitter=emitter,
                state_container=state_container,
                llm=llm,
                is_reasoning_model=is_reasoning_model,
                token_counter=token_counter,
                user_identity=user_identity,
            )

            for tab_index, research_result in enumerate(research_results):
                tool_call_info = ToolCallInfo(
                    parent_tool_call_id=None,
                    turn_index=orchestrator_start_turn_index + cycle + reasoning_cycles,
                    tab_index=tab_index,
                    tool_name=tool_call.tool_name,
                    tool_call_id=tool_call.tool_call_id,
                    tool_id=get_tool_by_name(
                        tool_name=RESEARCH_AGENT_DB_NAME, db_session=db_session
                    ).id,
                    reasoning_tokens=most_recent_reasoning,
                    tool_call_arguments=tool_call.tool_args,
                    tool_call_response=research_result.intermediate_report,
                    search_docs=research_result.search_docs,
                    generated_images=None,
                )
                state_container.add_tool_call(tool_call_info)

                tool_call_message = tool_call.to_msg_str()
                tool_call_token_count = token_counter(tool_call_message)

                tool_call_msg = ChatMessageSimple(
                    message=tool_call_message,
                    token_count=tool_call_token_count,
                    message_type=MessageType.TOOL_CALL,
                    tool_call_id=tool_call.tool_call_id,
                    image_files=None,
                )
                simple_chat_history.append(tool_call_msg)

                tool_call_response_msg = ChatMessageSimple(
                    message=research_result.intermediate_report,
                    token_count=token_counter(research_result.intermediate_report),
                    message_type=MessageType.TOOL_CALL_RESPONSE,
                    tool_call_id=tool_call.tool_call_id,
                    image_files=None,
                )
                simple_chat_history.append(tool_call_response_msg)

            if not special_tool_calls.think_tool_call:
                most_recent_reasoning = None
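The clarification and research-plan steps above consume run_llm_step as a generator and recover its final result from StopIteration.value. A minimal illustration of that generator-return pattern:

# Hedged sketch: consuming a generator while capturing its return value.
from collections.abc import Generator

def step() -> Generator[str, None, tuple[str, int]]:
    yield "packet-1"
    yield "packet-2"
    return ("final-result", 42)  # surfaces as StopIteration.value

gen = step()
while True:
    try:
        packet = next(gen)
        print("emit", packet)
    except StopIteration as e:
        result, turn = e.value
        break

assert result == "final-result" and turn == 42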
@@ -1,139 +1,18 @@
GENERATE_PLAN_TOOL_NAME = "generate_plan"

RESEARCH_AGENT_DB_NAME = "ResearchAgent"
RESEARCH_AGENT_TOOL_NAME = "research_agent"
RESEARCH_AGENT_TASK_KEY = "task"

GENERATE_REPORT_TOOL_NAME = "generate_report"

THINK_TOOL_NAME = "think_tool"


# ruff: noqa: E501, W605 start
GENERATE_PLAN_TOOL_DESCRIPTION = {
    "type": "function",
    "function": {
        "name": GENERATE_PLAN_TOOL_NAME,
        "description": "No clarification needed, generate a research plan for the user's query.",
        "parameters": {
            "type": "object",
            "properties": {},
            "required": [],
        },
    },
}


RESEARCH_AGENT_TOOL_DESCRIPTION = {
    "type": "function",
    "function": {
        "name": RESEARCH_AGENT_TOOL_NAME,
        "description": "Conduct research on a specific topic.",
        "parameters": {
            "type": "object",
            "properties": {
                RESEARCH_AGENT_TASK_KEY: {
                    "type": "string",
                    "description": "The research task to investigate, should be 1-2 descriptive sentences outlining the direction of investigation.",
                }
            },
            "required": [RESEARCH_AGENT_TASK_KEY],
        },
    },
}


GENERATE_REPORT_TOOL_DESCRIPTION = {
    "type": "function",
    "function": {
        "name": GENERATE_REPORT_TOOL_NAME,
        "description": "Generate the final research report from all of the findings. Should be called when all aspects of the user's query have been researched, or maximum cycles are reached.",
        "parameters": {
            "type": "object",
            "properties": {},
            "required": [],
        },
    },
}


THINK_TOOL_DESCRIPTION = {
    "type": "function",
    "function": {
        "name": THINK_TOOL_NAME,
        "description": "Use this for reasoning between research_agent calls and before calling generate_report. Think deeply about key results, identify knowledge gaps, and plan next steps.",
        "parameters": {
            "type": "object",
            "properties": {
                "reasoning": {
                    "type": "string",
                    "description": "Your chain of thought reasoning, use paragraph format, no lists.",
                }
            },
            "required": ["reasoning"],
        },
    },
}


RESEARCH_AGENT_THINK_TOOL_DESCRIPTION = {
    "type": "function",
    "function": {
        "name": "think_tool",
        "description": "Use this for reasoning between research steps. Think deeply about key results, identify knowledge gaps, and plan next steps.",
        "parameters": {
            "type": "object",
            "properties": {
                "reasoning": {
                    "type": "string",
                    "description": "Your chain of thought reasoning, can be as long as a lengthy paragraph.",
                }
            },
            "required": ["reasoning"],
        },
    },
}


RESEARCH_AGENT_GENERATE_REPORT_TOOL_DESCRIPTION = {
    "type": "function",
    "function": {
        "name": "generate_report",
        "description": "Generate the final research report from all findings. Should be called when research is complete.",
        "parameters": {
            "type": "object",
            "properties": {},
            "required": [],
        },
    },
}


THINK_TOOL_RESPONSE_MESSAGE = "Acknowledged, please continue."
THINK_TOOL_RESPONSE_TOKEN_COUNT = 10


def get_clarification_tool_definitions() -> list[dict]:
    return [GENERATE_PLAN_TOOL_DESCRIPTION]


def get_orchestrator_tools(include_think_tool: bool) -> list[dict]:
    tools = [
        RESEARCH_AGENT_TOOL_DESCRIPTION,
        GENERATE_REPORT_TOOL_DESCRIPTION,
    return [
        {
            "type": "function",
            "function": {
                "name": GENERATE_PLAN_TOOL_NAME,
                "description": "No clarification needed, generate a research plan for the user's query.",
                "parameters": {
                    "type": "object",
                    "properties": {},
                    "required": [],
                },
            },
        }
    ]
    if include_think_tool:
        tools.append(THINK_TOOL_DESCRIPTION)
    return tools


def get_research_agent_additional_tool_definitions(
    include_think_tool: bool,
) -> list[dict]:
    tools = [GENERATE_REPORT_TOOL_DESCRIPTION]
    if include_think_tool:
        tools.append(RESEARCH_AGENT_THINK_TOOL_DESCRIPTION)
    return tools


# ruff: noqa: E501, W605 end
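These definitions follow the OpenAI-style function-tool schema ({"type": "function", "function": {...}}), so the orchestrator can dispatch on the returned tool name and JSON-encoded arguments. A small hedged sketch of such a dispatcher; no particular LLM client is assumed:

# Hedged sketch: dispatching a tool call shaped like the schemas above.
import json

RESEARCH_AGENT_TASK_KEY = "task"

def dispatch_tool_call(name: str, arguments_json: str) -> str:
    args = json.loads(arguments_json)
    if name == "research_agent":
        task = args[RESEARCH_AGENT_TASK_KEY]  # required per the schema
        return f"researching: {task}"
    if name == "generate_report":
        return "generating final report"
    raise ValueError(f"unexpected tool call: {name}")

print(dispatch_tool_call("research_agent", '{"task": "Survey recent RAG papers"}'))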
@@ -1,14 +0,0 @@
from pydantic import BaseModel

from onyx.context.search.models import SearchDoc
from onyx.tools.models import ToolCallKickoff


class SpecialToolCalls(BaseModel):
    think_tool_call: ToolCallKickoff | None = None
    generate_report_tool_call: ToolCallKickoff | None = None


class ResearchAgentCallResult(BaseModel):
    intermediate_report: str
    search_docs: list[SearchDoc]
@@ -1,168 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.deep_research.dr_mock_tools import GENERATE_REPORT_TOOL_NAME
|
||||
from onyx.deep_research.dr_mock_tools import THINK_TOOL_NAME
|
||||
from onyx.deep_research.models import SpecialToolCalls
|
||||
from onyx.llm.model_response import ChatCompletionDeltaToolCall
|
||||
from onyx.llm.model_response import Delta
|
||||
from onyx.llm.model_response import FunctionCall
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
|
||||
|
||||
# JSON prefixes to detect in think_tool arguments
|
||||
# The schema is: {"reasoning": "...content..."}
|
||||
JSON_PREFIX_WITH_SPACE = '{"reasoning": "'
|
||||
JSON_PREFIX_NO_SPACE = '{"reasoning":"'
|
||||
|
||||
|
||||
class ThinkToolProcessorState(BaseModel):
|
||||
"""State for tracking think tool processing across streaming deltas."""
|
||||
|
||||
think_tool_found: bool = False
|
||||
think_tool_index: int | None = None
|
||||
think_tool_id: str | None = None
|
||||
full_arguments: str = "" # Full accumulated arguments for final tool call
|
||||
accumulated_args: str = "" # Working buffer for JSON parsing
|
||||
json_prefix_stripped: bool = False
|
||||
# Buffer holds content that might be the JSON suffix "}
|
||||
# We hold back 2 chars to avoid emitting the closing "}
|
||||
buffer: str = ""
|
||||
|
||||
|
||||
def _extract_reasoning_chunk(state: ThinkToolProcessorState) -> str | None:
|
||||
"""
|
||||
Extract reasoning content from accumulated arguments, stripping JSON wrapper.
|
||||
|
||||
Returns the next chunk of reasoning to emit, or None if nothing to emit yet.
|
||||
"""
|
||||
# If we haven't found the JSON prefix yet, look for it
|
||||
if not state.json_prefix_stripped:
|
||||
# Try both prefix variants
|
||||
for prefix in [JSON_PREFIX_WITH_SPACE, JSON_PREFIX_NO_SPACE]:
|
||||
prefix_pos = state.accumulated_args.find(prefix)
|
||||
if prefix_pos != -1:
|
||||
# Found prefix - extract content after it
|
||||
content_start = prefix_pos + len(prefix)
|
||||
state.buffer = state.accumulated_args[content_start:]
|
||||
state.accumulated_args = ""
|
||||
state.json_prefix_stripped = True
|
||||
break
|
||||
|
||||
if not state.json_prefix_stripped:
|
||||
# Haven't seen full prefix yet, keep accumulating
|
||||
return None
|
||||
else:
|
||||
# Already stripped prefix, add new content to buffer
|
||||
state.buffer += state.accumulated_args
|
||||
state.accumulated_args = ""
|
||||
|
||||
# Hold back last 2 chars in case they're the JSON suffix "}
|
||||
if len(state.buffer) <= 2:
|
||||
return None
|
||||
|
||||
# Emit everything except last 2 chars
|
||||
to_emit = state.buffer[:-2]
|
||||
state.buffer = state.buffer[-2:]
|
||||
|
||||
return to_emit if to_emit else None
|
||||
|
||||
|
||||
def create_think_tool_token_processor() -> (
    Callable[[Delta | None, Any], tuple[Delta | None, Any]]
):
    """
    Create a custom token processor that converts think_tool calls to reasoning content.

    When the think_tool is detected:
    - Tool call arguments are converted to reasoning_content (JSON wrapper stripped)
    - All other deltas (content, other tool calls) are dropped

    This allows non-reasoning models to emit chain-of-thought via the think_tool,
    which gets displayed as reasoning tokens in the UI.

    Returns:
        A function compatible with run_llm_step_pkt_generator's custom_token_processor parameter.
        The function takes (Delta, state) and returns (modified Delta | None, new state).
    """

    def process_token(delta: Delta | None, state: Any) -> tuple[Delta | None, Any]:
        if state is None:
            state = ThinkToolProcessorState()

        # Handle flush signal (delta=None) - emit the complete tool call
        if delta is None:
            if state.think_tool_found and state.think_tool_id:
                # Return the complete think tool call
                complete_tool_call = ChatCompletionDeltaToolCall(
                    id=state.think_tool_id,
                    index=state.think_tool_index or 0,
                    type="function",
                    function=FunctionCall(
                        name=THINK_TOOL_NAME,
                        arguments=state.full_arguments,
                    ),
                )
                return Delta(tool_calls=[complete_tool_call]), state
            return None, state

        # Check for the think tool in tool_calls
        if delta.tool_calls:
            for tool_call in delta.tool_calls:
                # Detect the think tool by name
                if tool_call.function and tool_call.function.name == THINK_TOOL_NAME:
                    state.think_tool_found = True
                    state.think_tool_index = tool_call.index

                # Capture the tool call id when available
                if (
                    state.think_tool_found
                    and tool_call.index == state.think_tool_index
                    and tool_call.id
                ):
                    state.think_tool_id = tool_call.id

                # Accumulate arguments for the think tool
                if (
                    state.think_tool_found
                    and tool_call.index == state.think_tool_index
                    and tool_call.function
                    and tool_call.function.arguments
                ):
                    # Track full arguments for the final tool call
                    state.full_arguments += tool_call.function.arguments
                    # Also accumulate for JSON parsing
                    state.accumulated_args += tool_call.function.arguments

                    # Try to extract reasoning content
                    reasoning_chunk = _extract_reasoning_chunk(state)
                    if reasoning_chunk:
                        # Return a delta with reasoning_content to trigger reasoning streaming
                        return Delta(reasoning_content=reasoning_chunk), state

        # If the think tool was found, drop all other content
        if state.think_tool_found:
            return None, state

        # No think tool detected, pass through the original delta
        return delta, state

    return process_token
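A hedged usage sketch of the processor contract described above (the delta stream and the emit() sink are hypothetical, not part of the diff):

process = create_think_tool_token_processor()
state = None
for delta in stream_of_deltas:  # assumed: an iterable of Delta objects from the LLM stream
    out, state = process(delta, state)
    if out is not None:
        emit(out)  # hypothetical sink: reasoning chunks or passed-through deltas
# Flush signal: a final None yields the complete think tool call, if one was seen
final_delta, state = process(None, state)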
def check_special_tool_calls(tool_calls: list[ToolCallKickoff]) -> SpecialToolCalls:
    think_tool_call: ToolCallKickoff | None = None
    generate_report_tool_call: ToolCallKickoff | None = None

    for tool_call in tool_calls:
        if tool_call.tool_name == THINK_TOOL_NAME:
            think_tool_call = tool_call
        elif tool_call.tool_name == GENERATE_REPORT_TOOL_NAME:
            generate_report_tool_call = tool_call

    return SpecialToolCalls(
        think_tool_call=think_tool_call,
        generate_report_tool_call=generate_report_tool_call,
    )
@@ -1,4 +1,6 @@
import abc
from collections.abc import Iterator
from typing import Any

from pydantic import BaseModel

@@ -10,11 +12,13 @@ from onyx.db.enums import EmbeddingPrecision
from onyx.indexing.models import DocMetadataAwareIndexChunk
from shared_configs.model_server_models import Embedding

# NOTE: "Document" in the naming convention is used to refer to the entire
# document as represented in Onyx. What is actually stored in the index is the
# document chunks. By the terminology of most search engines / vector databases,
# the individual objects stored are called documents, but in this case it refers
# to a chunk.
# NOTE: "Document" in the naming convention is used to refer to the entire document as represented in Onyx.
# What is actually stored in the index is the document chunks. By the terminology of most search engines / vector
# databases, the individual objects stored are called documents, but in this case it refers to a chunk.

# Outside of searching and update capabilities, the document index must also implement the ability to port all of
# the documents over to a secondary index. This allows for embedding models to be updated and for porting documents
# to happen in the background while the primary index still serves the main traffic.
__all__ = [
@@ -38,7 +42,7 @@ __all__ = [

class DocumentInsertionRecord(BaseModel):
    """
    Result of indexing a document.
    Result of indexing a document
    """

    model_config = {"frozen": True}
@@ -48,11 +52,10 @@ class DocumentInsertionRecord(BaseModel):


class DocumentSectionRequest(BaseModel):
    """Request for a document section or whole document.

    If no min_chunk_ind is provided it should start at the beginning of the
    document.
    If no max_chunk_ind is provided it should go to the end of the document.
    """
    """
    Request for a document section or whole document
    If no min_chunk_ind is provided it should start at the beginning of the document
    If no max_chunk_ind is provided it should go to the end of the document
    """

    model_config = {"frozen": True}
@@ -64,37 +67,24 @@ class DocumentSectionRequest(BaseModel):


class IndexingMetadata(BaseModel):
    """
    Information about chunk counts for efficient cleaning / updating of document
    chunks.

    A common pattern to ensure that no chunks are left over is to delete all of
    the chunks for a document and then re-index the document. This information
    allows us to only delete the extra "tail" chunks when the document has
    gotten shorter.
    Information about chunk counts for efficient cleaning / updating of document chunks. A common pattern to ensure
    that no chunks are left over is to delete all of the chunks for a document and then re-index the document. This
    information allows us to only delete the extra "tail" chunks when the document has gotten shorter.
    """

    class ChunkCounts(BaseModel):
        model_config = {"frozen": True}

        old_chunk_cnt: int
        new_chunk_cnt: int

    model_config = {"frozen": True}

    doc_id_to_chunk_cnt_diff: dict[str, ChunkCounts]
    # The tuple is (old_chunk_cnt, new_chunk_cnt)
    doc_id_to_chunk_cnt_diff: dict[str, tuple[int, int]]
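For reference, a minimal sketch of populating the new nested ChunkCounts form introduced above (the doc IDs are made up):

# Sketch: one document shrank from 10 chunks to 7, another is brand new.
indexing_metadata = IndexingMetadata(
    doc_id_to_chunk_cnt_diff={
        "doc-a": IndexingMetadata.ChunkCounts(old_chunk_cnt=10, new_chunk_cnt=7),
        "doc-b": IndexingMetadata.ChunkCounts(old_chunk_cnt=0, new_chunk_cnt=3),
    }
)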
class MetadataUpdateRequest(BaseModel):
    """
    Updates to the documents that can happen without there being an update to
    the contents of the document.
    Updates to the documents that can happen without there being an update to the contents of the document.
    """

    document_ids: list[str]
    # Passed in to help with potential optimizations of the implementation. The
    # keys should be redundant with document_ids.
    # Passed in to help with potential optimizations of the implementation
    doc_id_to_chunk_cnt: dict[str, int]
    # For the ones that are None, there is no update required to that field.
    # For the ones that are None, there is no update required to that field
    access: DocumentAccess | None = None
    document_sets: set[str] | None = None
    boost: float | None = None

@@ -110,6 +100,17 @@ class SchemaVerifiable(abc.ABC):
    all valid in the schema.
    """

    def __init__(
        self,
        index_name: str,
        tenant_id: int | None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.index_name = index_name
        self.tenant_id = tenant_id

    @abc.abstractmethod
    def verify_and_create_index_if_necessary(
        self,

@@ -130,40 +131,34 @@ class SchemaVerifiable(abc.ABC):

class Indexable(abc.ABC):
    """
    Class must implement the ability to index document chunks.
    Class must implement the ability to index document chunks
    """

    @abc.abstractmethod
    def index(
        self,
        chunks: list[DocMetadataAwareIndexChunk],
        chunks: Iterator[DocMetadataAwareIndexChunk],
        indexing_metadata: IndexingMetadata,
    ) -> list[DocumentInsertionRecord]:
    ) -> set[DocumentInsertionRecord]:
        """
        Takes a list of document chunks and indexes them in the document index.
        Takes a list of document chunks and indexes them in the document index. This is often a batch operation
        including chunks from multiple documents.

        This is often a batch operation including chunks from multiple
        documents.
        NOTE: When a document is reindexed/updated here and has gotten shorter, it is important to delete the extra
        chunks at the end to ensure there are no stale chunks in the index.

        NOTE: When a document is reindexed/updated here and has gotten shorter,
        it is important to delete the extra chunks at the end to ensure there
        are no stale chunks in the index. The implementation should do this.
        NOTE: The chunks of a document are never separated into separate index() calls. So there is
        no worry of receiving the first 0 through n chunks in one index call and the next n through
        m chunks of a document in the next index call.

        NOTE: The chunks of a document are never separated into separate index()
        calls. So there is no worry of receiving the first 0 through n chunks in
        one index call and the next n through m chunks of a document in the next
        index call.

        Args:
            chunks: Document chunks with all of the information needed for
                indexing to the document index.
            indexing_metadata: Information about chunk counts for efficient
                cleaning / updating.
        Parameters:
        - chunks: Document chunks with all of the information needed for indexing to the document index.
        - indexing_metadata: Information about chunk counts for efficient cleaning / updating

        Returns:
            List of document ids which map to unique documents and are used for
            deduping chunks when updating, as well as if the document is newly
            indexed or already existed and just updated.
            List of document ids which map to unique documents and are used for deduping chunks
            when updating, as well as if the document is newly indexed or already existed and
            just updated
        """
        raise NotImplementedError
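The stale-tail-chunk rule the docstring above describes reduces to simple index arithmetic; a sketch (the helper name is illustrative, not from the codebase):

# If a doc shrinks from old_cnt chunks to new_cnt chunks (0-indexed),
# chunks new_cnt..old_cnt-1 are stale and must be deleted.
def stale_tail_indices(old_cnt: int, new_cnt: int) -> range:
    return range(new_cnt, old_cnt)  # empty when the doc grew or stayed the same

assert list(stale_tail_indices(10, 7)) == [7, 8, 9]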
@@ -178,6 +173,7 @@ class Deletable(abc.ABC):
    def delete(
        self,
        db_doc_id: str,
        *,
        # Passed in in case it helps the efficiency of the delete implementation
        chunk_count: int | None,
    ) -> int:
@@ -315,10 +311,10 @@ class DocumentIndex(
    abc.ABC,
):
    """
    A valid document index that can plug into all Onyx flows must implement all
    of these functionalities.
    A valid document index that can plug into all Onyx flows must implement all of these
    functionalities.

    As a high-level summary, document indices need to be able to:
    As a high level summary, document indices need to be able to
    - Verify the schema definition is valid
    - Index new documents
    - Update specific attributes of existing documents
@@ -39,14 +39,6 @@ def delete_vespa_chunks(
    http_client: httpx.Client,
    executor: concurrent.futures.ThreadPoolExecutor | None = None,
) -> None:
    """Deletes a list of chunks from a Vespa index in parallel.

    Args:
        doc_chunk_ids: List of chunk IDs to delete.
        index_name: Name of the index to delete from.
        http_client: HTTP client to use for the request.
        executor: Executor to use for the request.
    """
    external_executor = True

    if not executor:
@@ -36,9 +36,7 @@ from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.document_index_utils import get_document_chunk_ids
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import (
    DocumentInsertionRecord as OldDocumentInsertionRecord,
)
from onyx.document_index.interfaces import DocumentInsertionRecord
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
from onyx.document_index.interfaces import IndexBatchParams
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo

@@ -46,7 +44,6 @@ from onyx.document_index.interfaces import UpdateRequest
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.vespa.chunk_retrieval import batch_search_api_retrieval
from onyx.document_index.vespa.chunk_retrieval import (
    parallel_visit_api_retrieval,

@@ -54,7 +51,9 @@ from onyx.document_index.vespa.chunk_retrieval import (
from onyx.document_index.vespa.chunk_retrieval import query_vespa
from onyx.document_index.vespa.deletion import delete_vespa_chunks
from onyx.document_index.vespa.indexing_utils import BaseHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import batch_index_vespa_chunks
from onyx.document_index.vespa.indexing_utils import check_for_final_chunk_existence
from onyx.document_index.vespa.indexing_utils import clean_chunk_id_copy
from onyx.document_index.vespa.indexing_utils import GlobalHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import TemporaryHTTPXClientContext
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client

@@ -64,8 +63,6 @@ from onyx.document_index.vespa.shared_utils.utils import (
from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
    build_vespa_filters,
)
from onyx.document_index.vespa.vespa_document_index import TenantState
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import BATCH_SIZE
from onyx.document_index.vespa_constants import BOOST
@@ -256,10 +253,8 @@ class VespaIndex(DocumentIndex):

        self.multitenant = multitenant

        # Temporary until we refactor the entirety of this class.
        self.httpx_client = httpx_client

        self.httpx_client_context: BaseHTTPXClientContext

        if httpx_client:
            self.httpx_client_context = GlobalHTTPXClientContext(httpx_client)
        else:
@@ -480,45 +475,92 @@ class VespaIndex(DocumentIndex):
        self,
        chunks: list[DocMetadataAwareIndexChunk],
        index_batch_params: IndexBatchParams,
    ) -> set[OldDocumentInsertionRecord]:
        if len(index_batch_params.doc_id_to_previous_chunk_cnt) != len(
            index_batch_params.doc_id_to_new_chunk_cnt
        ):
            raise ValueError("Bug: Length of doc ID to chunk maps does not match.")
        doc_id_to_chunk_cnt_diff = {
            doc_id: IndexingMetadata.ChunkCounts(
                old_chunk_cnt=index_batch_params.doc_id_to_previous_chunk_cnt[doc_id],
                new_chunk_cnt=index_batch_params.doc_id_to_new_chunk_cnt[doc_id],
    ) -> set[DocumentInsertionRecord]:
        """Receive a list of chunks from a batch of documents and index the chunks into Vespa along
        with updating the associated permissions. Assumes that a document will not be split into
        multiple chunk batches calling this function multiple times, otherwise only the last set of
        chunks will be kept"""

        doc_id_to_previous_chunk_cnt = index_batch_params.doc_id_to_previous_chunk_cnt
        doc_id_to_new_chunk_cnt = index_batch_params.doc_id_to_new_chunk_cnt
        tenant_id = index_batch_params.tenant_id
        large_chunks_enabled = index_batch_params.large_chunks_enabled

        # IMPORTANT: This must be done one index at a time, do not use secondary index here
        cleaned_chunks = [clean_chunk_id_copy(chunk) for chunk in chunks]

        # needed so the final DocumentInsertionRecord returned can have the original document ID
        new_document_id_to_original_document_id: dict[str, str] = {}
        for ind, chunk in enumerate(cleaned_chunks):
            old_chunk = chunks[ind]
            new_document_id_to_original_document_id[chunk.source_document.id] = (
                old_chunk.source_document.id
            )
            for doc_id in index_batch_params.doc_id_to_previous_chunk_cnt.keys()
        }
        indexing_metadata = IndexingMetadata(
            doc_id_to_chunk_cnt_diff=doc_id_to_chunk_cnt_diff,
        )
        vespa_document_index = VespaDocumentIndex(
            index_name=self.index_name,
            tenant_state=TenantState(
                tenant_id=index_batch_params.tenant_id,
                multitenant=self.multitenant,
            ),
            large_chunks_enabled=self.large_chunks_enabled,
            httpx_client=self.httpx_client,
        )
        # This conversion from list to set only to be converted again to a list
        # upstream is suboptimal and only temporary until we refactor the
        # entirety of this class.
        document_insertion_records = vespa_document_index.index(
            chunks, indexing_metadata
        )
        return set(
            [
                OldDocumentInsertionRecord(
                    document_id=doc_insertion_record.document_id,
                    already_existed=doc_insertion_record.already_existed,

        existing_docs: set[str] = set()

        # NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
        # indexing / updates / deletes since we have to make a large volume of requests.
        with (
            concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
            self.httpx_client_context as http_client,
        ):
            # We require the start and end index for each document in order to
            # know precisely which chunks to delete. This information exists for
            # documents that have `chunk_count` in the database, but not for
            # `old_version` documents.

            enriched_doc_infos: list[EnrichedDocumentIndexingInfo] = [
                VespaIndex.enrich_basic_chunk_info(
                    index_name=self.index_name,
                    http_client=http_client,
                    document_id=doc_id,
                    previous_chunk_count=doc_id_to_previous_chunk_cnt.get(doc_id, 0),
                    new_chunk_count=doc_id_to_new_chunk_cnt.get(doc_id, 0),
                )
                for doc_insertion_record in document_insertion_records
                for doc_id in doc_id_to_new_chunk_cnt.keys()
            ]
        )

            for cleaned_doc_info in enriched_doc_infos:
                # If the document has previously indexed chunks, we know it previously existed
                if cleaned_doc_info.chunk_end_index:
                    existing_docs.add(cleaned_doc_info.doc_id)

            # Now, for each doc, we know exactly where to start and end our deletion
            # So let's generate the chunk IDs for each chunk to delete
            chunks_to_delete = get_document_chunk_ids(
                enriched_document_info_list=enriched_doc_infos,
                tenant_id=tenant_id,
                large_chunks_enabled=large_chunks_enabled,
            )

            # Delete old Vespa documents
            for doc_chunk_ids_batch in batch_generator(chunks_to_delete, BATCH_SIZE):
                delete_vespa_chunks(
                    doc_chunk_ids=doc_chunk_ids_batch,
                    index_name=self.index_name,
                    http_client=http_client,
                    executor=executor,
                )

            for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
                batch_index_vespa_chunks(
                    chunks=chunk_batch,
                    index_name=self.index_name,
                    http_client=http_client,
                    multitenant=self.multitenant,
                    executor=executor,
                )

        all_cleaned_doc_ids = {chunk.source_document.id for chunk in cleaned_chunks}

        return {
            DocumentInsertionRecord(
                document_id=new_document_id_to_original_document_id[cleaned_doc_id],
                already_existed=cleaned_doc_id in existing_docs,
            )
            for cleaned_doc_id in all_cleaned_doc_ids
        }

    @classmethod
    def _apply_updates_batched(
@@ -244,15 +244,6 @@ def batch_index_vespa_chunks(
    multitenant: bool,
    executor: concurrent.futures.ThreadPoolExecutor | None = None,
) -> None:
    """Indexes a list of chunks in a Vespa index in parallel.

    Args:
        chunks: List of chunks to index.
        index_name: Name of the index to index into.
        http_client: HTTP client to use for the request.
        multitenant: Whether the index is multitenant.
        executor: Executor to use for the request.
    """
    external_executor = True

    if not executor:
@@ -64,9 +64,10 @@ def remove_invalid_unicode_chars(text: str) -> str:

def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx.Client:
    """
    Configures and returns an HTTP client for communicating with Vespa,
    Configure and return an HTTP client for communicating with Vespa,
    including authentication if needed.
    """

    return httpx.Client(
        cert=(
            cast(tuple[str, str], (VESPA_CLOUD_CERT_PATH, VESPA_CLOUD_KEY_PATH))
@@ -1,287 +0,0 @@
import concurrent.futures

import httpx
from pydantic import BaseModel

from onyx.context.search.enums import QueryType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.document_index_utils import get_document_chunk_ids
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
from onyx.document_index.interfaces_new import DocumentIndex
from onyx.document_index.interfaces_new import DocumentInsertionRecord
from onyx.document_index.interfaces_new import DocumentSectionRequest
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import MetadataUpdateRequest
from onyx.document_index.vespa.deletion import delete_vespa_chunks
from onyx.document_index.vespa.indexing_utils import BaseHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import batch_index_vespa_chunks
from onyx.document_index.vespa.indexing_utils import check_for_final_chunk_existence
from onyx.document_index.vespa.indexing_utils import clean_chunk_id_copy
from onyx.document_index.vespa.indexing_utils import GlobalHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import TemporaryHTTPXClientContext
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa_constants import BATCH_SIZE
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.utils.batching import batch_generator
from shared_configs.model_server_models import Embedding


class TenantState(BaseModel):
    """
    Captures the tenant-related state for an instance of VespaDocumentIndex.

    TODO(andrei): If we find that we need this for Opensearch too, just move
    this to interfaces_new.py.
    """

    model_config = {"frozen": True}

    tenant_id: str
    multitenant: bool


def _enrich_basic_chunk_info(
    index_name: str,
    http_client: httpx.Client,
    document_id: str,
    previous_chunk_count: int | None,
    new_chunk_count: int,
) -> EnrichedDocumentIndexingInfo:
    """Determines which chunks need to be deleted during document reindexing.

    When a document is reindexed, it may have fewer chunks than before. This
    function identifies the range of old chunks that need to be deleted by
    comparing the new chunk count with the previous chunk count.

    Example:
        If a document previously had 10 chunks (0-9) and now has 7 chunks (0-6),
        this function identifies that chunks 7-9 need to be deleted.

    Args:
        index_name: The Vespa index/schema name.
        http_client: HTTP client for making requests to Vespa.
        document_id: The Vespa-sanitized ID of the document being reindexed.
        previous_chunk_count: The total number of chunks the document had before
            reindexing. None for documents using the legacy chunk ID system.
        new_chunk_count: The total number of chunks the document has after
            reindexing. This becomes the starting index for deletion since
            chunks are 0-indexed.

    Returns:
        EnrichedDocumentIndexingInfo with chunk_start_index set to
        new_chunk_count (where deletion begins) and chunk_end_index set to
        previous_chunk_count (where deletion ends).
    """
    # Technically last indexed chunk index +1.
    last_indexed_chunk = previous_chunk_count
    # If the document has no `chunk_count` in the database, we know that it
    # has the old chunk ID system and we must check for the final chunk index.
    is_old_version = False
    if last_indexed_chunk is None:
        is_old_version = True
        minimal_doc_info = MinimalDocumentIndexingInfo(
            doc_id=document_id, chunk_start_index=new_chunk_count
        )
        last_indexed_chunk = check_for_final_chunk_existence(
            minimal_doc_info=minimal_doc_info,
            start_index=new_chunk_count,
            index_name=index_name,
            http_client=http_client,
        )

    enriched_doc_info = EnrichedDocumentIndexingInfo(
        doc_id=document_id,
        chunk_start_index=new_chunk_count,
        chunk_end_index=last_indexed_chunk,
        old_version=is_old_version,
    )
    return enriched_doc_info
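Following the docstring's own 10-to-7 example, a hedged call sketch (the index name is an assumption, and a live Vespa instance is required for the legacy-chunk lookup path):

info = _enrich_basic_chunk_info(
    index_name="danswer_chunk",  # assumed index/schema name
    http_client=get_vespa_http_client(),
    document_id="doc-a",
    previous_chunk_count=10,
    new_chunk_count=7,
)
# Deletion range is [7, 10): chunks 7, 8, and 9 get removed.
assert (info.chunk_start_index, info.chunk_end_index) == (7, 10)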
class VespaDocumentIndex(DocumentIndex):
    """Vespa-specific implementation of the DocumentIndex interface.

    This class provides document indexing, retrieval, and management operations
    for a Vespa search engine instance. It handles the complete lifecycle of
    document chunks within a specific Vespa index/schema.
    """

    def __init__(
        self,
        index_name: str,
        tenant_state: TenantState,
        large_chunks_enabled: bool,
        httpx_client: httpx.Client | None = None,
    ) -> None:
        self._index_name = index_name
        self._tenant_id = tenant_state.tenant_id
        self._large_chunks_enabled = large_chunks_enabled
        # NOTE: using `httpx` here since `requests` doesn't support HTTP2. This
        # is beneficial for indexing / updates / deletes since we have to make a
        # large volume of requests.
        self._httpx_client_context: BaseHTTPXClientContext
        if httpx_client:
            # Use the provided client. Because this client is presumed global,
            # it does not close after exiting a context manager.
            self._httpx_client_context = GlobalHTTPXClientContext(httpx_client)
        else:
            # We did not receive a client, so create one that will close after
            # exiting a context manager.
            self._httpx_client_context = TemporaryHTTPXClientContext(
                get_vespa_http_client
            )
        self._multitenant = tenant_state.multitenant

    def verify_and_create_index_if_necessary(
        self, embedding_dim: int, embedding_precision: EmbeddingPrecision
    ) -> None:
        raise NotImplementedError

    def index(
        self,
        chunks: list[DocMetadataAwareIndexChunk],
        indexing_metadata: IndexingMetadata,
    ) -> list[DocumentInsertionRecord]:
        doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
        doc_id_to_previous_chunk_cnt = {
            doc_id: chunk_cnt_diff.old_chunk_cnt
            for doc_id, chunk_cnt_diff in doc_id_to_chunk_cnt_diff.items()
        }
        doc_id_to_new_chunk_cnt = {
            doc_id: chunk_cnt_diff.new_chunk_cnt
            for doc_id, chunk_cnt_diff in doc_id_to_chunk_cnt_diff.items()
        }
        assert (
            len(doc_id_to_chunk_cnt_diff)
            == len(doc_id_to_previous_chunk_cnt)
            == len(doc_id_to_new_chunk_cnt)
        ), "Bug: Doc ID to chunk maps have different lengths."

        # Vespa has restrictions on valid characters, yet document IDs originate
        # externally to this class. We need to sanitize them.
        cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
            clean_chunk_id_copy(chunk) for chunk in chunks
        ]
        assert len(cleaned_chunks) == len(
            chunks
        ), "Bug: Cleaned chunks and input chunks have different lengths."

        # Needed so the final DocumentInsertionRecord returned can have the
        # original document ID. cleaned_chunks might not contain IDs exactly as
        # callers supplied them.
        new_document_id_to_original_document_id: dict[str, str] = dict()
        for i, cleaned_chunk in enumerate(cleaned_chunks):
            old_chunk = chunks[i]
            new_document_id_to_original_document_id[
                cleaned_chunk.source_document.id
            ] = old_chunk.source_document.id

        existing_docs: set[str] = set()

        with (
            concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
            self._httpx_client_context as http_client,
        ):
            # We require the start and end index for each document in order to
            # know precisely which chunks to delete. This information exists for
            # documents that have `chunk_count` in the database, but not for
            # `old_version` documents.
            enriched_doc_infos: list[EnrichedDocumentIndexingInfo] = [
                _enrich_basic_chunk_info(
                    index_name=self._index_name,
                    http_client=http_client,
                    document_id=doc_id,
                    previous_chunk_count=doc_id_to_previous_chunk_cnt[doc_id],
                    new_chunk_count=doc_id_to_new_chunk_cnt[doc_id],
                )
                for doc_id in doc_id_to_chunk_cnt_diff.keys()
            ]

            for enriched_doc_info in enriched_doc_infos:
                # If the document has previously indexed chunks, we know it
                # previously existed and this is a reindex.
                if enriched_doc_info.chunk_end_index:
                    existing_docs.add(enriched_doc_info.doc_id)

            # Now, for each doc, we know exactly where to start and end our
            # deletion. So let's generate the chunk IDs for each chunk to
            # delete.
            # WARNING: This code seems to use
            # indexing_metadata.doc_id_to_chunk_cnt_diff as the source of truth
            # for which chunks to delete. This implies that the onus is on the
            # caller to ensure doc_id_to_chunk_cnt_diff only contains docs
            # relevant to the chunks argument to this method. This should not be
            # the contract of DocumentIndex; and this code is only a refactor
            # from old code. It would seem we should use all_cleaned_doc_ids as
            # the source of truth.
            chunks_to_delete = get_document_chunk_ids(
                enriched_document_info_list=enriched_doc_infos,
                tenant_id=self._tenant_id,  # TODO: Resolve the tenant_id typing here.
                large_chunks_enabled=self._large_chunks_enabled,
            )

            # Delete old Vespa documents.
            for doc_chunk_ids_batch in batch_generator(chunks_to_delete, BATCH_SIZE):
                delete_vespa_chunks(
                    doc_chunk_ids=doc_chunk_ids_batch,
                    index_name=self._index_name,
                    http_client=http_client,
                    executor=executor,
                )

            # Insert new Vespa documents.
            for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
                batch_index_vespa_chunks(
                    chunks=chunk_batch,
                    index_name=self._index_name,
                    http_client=http_client,
                    multitenant=self._multitenant,
                    executor=executor,
                )

        all_cleaned_doc_ids: set[str] = {
            chunk.source_document.id for chunk in cleaned_chunks
        }

        return [
            DocumentInsertionRecord(
                document_id=new_document_id_to_original_document_id[cleaned_doc_id],
                already_existed=cleaned_doc_id in existing_docs,
            )
            for cleaned_doc_id in all_cleaned_doc_ids
        ]

    def delete(self, db_doc_id: str, chunk_count: int | None) -> int:
        raise NotImplementedError

    def update(self, update_requests: list[MetadataUpdateRequest]) -> None:
        raise NotImplementedError

    def id_based_retrieval(
        self, chunk_requests: list[DocumentSectionRequest]
    ) -> list[InferenceChunk]:
        raise NotImplementedError

    def hybrid_retrieval(
        self,
        query: str,
        query_embedding: Embedding,
        final_keywords: list[str] | None,
        query_type: QueryType,
        filters: IndexFilters,
        num_to_retrieve: int,
        offset: int = 0,
    ) -> list[InferenceChunk]:
        raise NotImplementedError

    def random_retrieval(
        self,
        filters: IndexFilters | None = None,
        num_to_retrieve: int = 100,
        dirty: bool | None = None,
    ) -> list[InferenceChunk]:
        raise NotImplementedError
@@ -26,9 +26,9 @@ DOCUMENT_ID_ENDPOINT = (

SEARCH_ENDPOINT = f"{VESPA_APP_CONTAINER_URL}/search/"

# Since Vespa doesn't allow batching of inserts / updates, we use threads to
# parallelize the operations.
NUM_THREADS = 32
NUM_THREADS = (
    32  # since Vespa doesn't allow batching of inserts / updates, we use threads
)
MAX_ID_SEARCH_QUERY_SIZE = 400
# Suspect that adding too many "or" conditions will cause Vespa to timeout and return
# an empty list of hits (with no error status and coverage: 0 and degraded)

@@ -37,9 +37,7 @@ MAX_OR_CONDITIONS = 10
# in the long term, we are looking to improve the performance of Vespa
# so that we can bring this back to default
VESPA_TIMEOUT = "10s"
# The size of the batch to use for batched operations like inserts / updates.
# The batch will likely be sent to a threadpool of size NUM_THREADS.
BATCH_SIZE = 128
BATCH_SIZE = 128  # Specific to Vespa

TENANT_ID = "tenant_id"
DOCUMENT_ID = "document_id"
@@ -8,6 +8,8 @@ from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from email.parser import Parser as EmailParser
from enum import auto
from enum import IntFlag
from io import BytesIO
from pathlib import Path
from typing import Any

@@ -23,21 +25,65 @@ from PIL import Image

from onyx.configs.constants import ONYX_METADATA_FILENAME
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_processing.file_types import PRESENTATION_MIME_TYPE
from onyx.file_processing.file_types import WORD_PROCESSING_MIME_TYPE
from onyx.file_processing.file_validation import TEXT_MIME_TYPE
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.file_processing.unstructured import get_unstructured_api_key
from onyx.file_processing.unstructured import unstructured_to_text
from onyx.utils.file_types import PRESENTATION_MIME_TYPE
from onyx.utils.file_types import WORD_PROCESSING_MIME_TYPE
from onyx.utils.logger import setup_logger

if TYPE_CHECKING:
    from markitdown import MarkItDown

logger = setup_logger()

# NOTE(rkuo): Unify this with upload_files_for_chat and file_validation.py
TEXT_SECTION_SEPARATOR = "\n\n"

ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
    ".txt",
    ".md",
    ".mdx",
    ".conf",
    ".log",
    ".json",
    ".csv",
    ".tsv",
    ".xml",
    ".yml",
    ".yaml",
    ".sql",
]

ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
    ".pdf",
    ".docx",
    ".pptx",
    ".xlsx",
    ".eml",
    ".epub",
    ".html",
]

ACCEPTED_IMAGE_FILE_EXTENSIONS = [
    ".png",
    ".jpg",
    ".jpeg",
    ".webp",
]

ALL_ACCEPTED_FILE_EXTENSIONS = (
    ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
    + ACCEPTED_DOCUMENT_FILE_EXTENSIONS
    + ACCEPTED_IMAGE_FILE_EXTENSIONS
)

IMAGE_MEDIA_TYPES = [
    "image/png",
    "image/jpeg",
    "image/webp",
]

_MARKITDOWN_CONVERTER: Optional["MarkItDown"] = None

KNOWN_OPENPYXL_BUGS = [
@@ -56,11 +102,42 @@ def get_markitdown_converter() -> "MarkItDown":
    return _MARKITDOWN_CONVERTER


class OnyxExtensionType(IntFlag):
    Plain = auto()
    Document = auto()
    Multimedia = auto()
    All = Plain | Document | Multimedia


def is_text_file_extension(file_name: str) -> bool:
    return any(file_name.endswith(ext) for ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS)


def get_file_ext(file_path_or_name: str | Path) -> str:
    _, extension = os.path.splitext(file_path_or_name)
    return extension.lower()


def is_valid_media_type(media_type: str) -> bool:
    return media_type in IMAGE_MEDIA_TYPES


def is_accepted_file_ext(ext: str, ext_type: OnyxExtensionType) -> bool:
    if ext_type & OnyxExtensionType.Plain:
        if ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS:
            return True

    if ext_type & OnyxExtensionType.Document:
        if ext in ACCEPTED_DOCUMENT_FILE_EXTENSIONS:
            return True

    if ext_type & OnyxExtensionType.Multimedia:
        if ext in ACCEPTED_IMAGE_FILE_EXTENSIONS:
            return True

    return False
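Because OnyxExtensionType is an IntFlag, categories compose with bitwise OR, and is_accepted_file_ext tests membership per selected category. A few calls that follow directly from the lists above:

assert is_accepted_file_ext(".md", OnyxExtensionType.Plain)
assert not is_accepted_file_ext(".md", OnyxExtensionType.Multimedia)
assert is_accepted_file_ext(".pdf", OnyxExtensionType.Plain | OnyxExtensionType.Document)
assert is_accepted_file_ext(".png", OnyxExtensionType.All)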
def is_text_file(file: IO[bytes]) -> bool:
    """
    Checks whether the first 1024 bytes contain only printable or whitespace characters
@@ -497,7 +574,9 @@ def extract_file_text(
    if extension is None:
        extension = get_file_ext(file_name)

    if extension in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS:
    if is_accepted_file_ext(
        extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
    ):
        func = extension_to_function.get(extension, file_io_to_text)
        file.seek(0)
        return func(file)
@@ -548,23 +627,6 @@ def extract_text_and_images(
    """
    Primary new function for the updated connector.
    Returns structured extraction result with text content, embedded images, and metadata.

    Args:
        file: File-like object to extract content from.
        file_name: Name of the file (used to determine extension/type).
        pdf_pass: Optional password for encrypted PDFs.
        content_type: Optional MIME type override for the file.
        image_callback: Optional callback for streaming image extraction. When provided,
            embedded images are passed to this callback one at a time as (bytes, filename)
            instead of being accumulated in the returned ExtractionResult.embedded_images
            list. This is a memory optimization for large documents with many images -
            the caller can process/store each image immediately rather than holding all
            images in memory. When using a callback, ExtractionResult.embedded_images
            will be an empty list.

    Returns:
        ExtractionResult containing text_content, embedded_images (empty if callback used),
        and metadata extracted from the file.
    """
    res = _extract_text_and_images(
        file, file_name, pdf_pass, content_type, image_callback
@@ -601,7 +663,7 @@ def _extract_text_and_images(
    # with content types in UploadMimeTypes.DOCUMENT_MIME_TYPES as plain text files.
    # As a result, the file name extension may differ from the original content type.
    # We process files with a plain text content type first to handle this scenario.
    if content_type in OnyxMimeTypes.TEXT_MIME_TYPES:
    if content_type == TEXT_MIME_TYPE:
        return extract_result_from_text_file(file)

    # Default processing

@@ -663,7 +725,7 @@ def _extract_text_and_images(
    )

    # If we reach here and it's a recognized text extension
    if extension in OnyxFileExtensions.PLAIN_TEXT_EXTENSIONS:
    if is_text_file_extension(file_name):
        return extract_result_from_text_file(file)

    # If it's an image file or something else, we do not parse embedded images from them
@@ -1,84 +0,0 @@
PRESENTATION_MIME_TYPE = (
    "application/vnd.openxmlformats-officedocument.presentationml.presentation"
)

SPREADSHEET_MIME_TYPE = (
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
WORD_PROCESSING_MIME_TYPE = (
    "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
)
PDF_MIME_TYPE = "application/pdf"
PLAIN_TEXT_MIME_TYPE = "text/plain"


class OnyxMimeTypes:
    IMAGE_MIME_TYPES = {"image/jpg", "image/jpeg", "image/png", "image/webp"}
    CSV_MIME_TYPES = {"text/csv"}
    TEXT_MIME_TYPES = {
        PLAIN_TEXT_MIME_TYPE,
        "text/markdown",
        "text/x-markdown",
        "text/x-config",
        "text/tab-separated-values",
        "application/json",
        "application/xml",
        "text/xml",
        "application/x-yaml",
    }
    DOCUMENT_MIME_TYPES = {
        PDF_MIME_TYPE,
        WORD_PROCESSING_MIME_TYPE,
        PRESENTATION_MIME_TYPE,
        SPREADSHEET_MIME_TYPE,
        "message/rfc822",
        "application/epub+zip",
    }

    ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union(
        TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, CSV_MIME_TYPES
    )

    EXCLUDED_IMAGE_TYPES = {
        "image/bmp",
        "image/tiff",
        "image/gif",
        "image/svg+xml",
        "image/avif",
    }


class OnyxFileExtensions:
    PLAIN_TEXT_EXTENSIONS = {
        ".txt",
        ".md",
        ".mdx",
        ".conf",
        ".log",
        ".json",
        ".csv",
        ".tsv",
        ".xml",
        ".yml",
        ".yaml",
        ".sql",
    }
    DOCUMENT_EXTENSIONS = {
        ".pdf",
        ".docx",
        ".pptx",
        ".xlsx",
        ".eml",
        ".epub",
        ".html",
    }
    IMAGE_EXTENSIONS = {
        ".png",
        ".jpg",
        ".jpeg",
        ".webp",
    }

    TEXT_AND_DOCUMENT_EXTENSIONS = PLAIN_TEXT_EXTENSIONS.union(DOCUMENT_EXTENSIONS)

    ALL_ALLOWED_EXTENSIONS = TEXT_AND_DOCUMENT_EXTENSIONS.union(IMAGE_EXTENSIONS)
55 backend/onyx/file_processing/file_validation.py Normal file
@@ -0,0 +1,55 @@
"""
Centralized file type validation utilities.
"""

# NOTE(rkuo): Unify this with upload_files_for_chat and extract_file_text

# Standard image MIME types supported by most vision LLMs
IMAGE_MIME_TYPES = [
    "image/png",
    "image/jpeg",
    "image/jpg",
    "image/webp",
]

# Image types that should be excluded from processing
EXCLUDED_IMAGE_TYPES = [
    "image/bmp",
    "image/tiff",
    "image/gif",
    "image/svg+xml",
    "image/avif",
]

# Text MIME types
TEXT_MIME_TYPE = "text/plain"


def is_valid_image_type(mime_type: str) -> bool:
    """
    Check if mime_type is a valid image type.

    Args:
        mime_type: The MIME type to check

    Returns:
        True if the MIME type is a valid image type, False otherwise
    """
    return (
        bool(mime_type)
        and mime_type.startswith("image/")
        and mime_type not in EXCLUDED_IMAGE_TYPES
    )


def is_supported_by_vision_llm(mime_type: str) -> bool:
    """
    Check if this image type can be processed by vision LLMs.

    Args:
        mime_type: The MIME type to check

    Returns:
        True if the MIME type is supported by vision LLMs, False otherwise
    """
    return mime_type in IMAGE_MIME_TYPES
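A few example calls against the helpers in this new file, derived directly from the lists above (the image/heic case simply exercises the "image/* and not excluded" rule):

assert is_valid_image_type("image/png")              # in IMAGE_MIME_TYPES
assert is_valid_image_type("image/heic")             # image/* and not in the excluded list
assert not is_valid_image_type("image/svg+xml")      # explicitly excluded
assert not is_supported_by_vision_llm("image/heic")  # valid image, but not in the vision list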
@@ -126,15 +126,6 @@ class FileStore(ABC):
        Read the file record by the ID
        """

    @abstractmethod
    def get_file_size(
        self, file_id: str, db_session: Session | None = None
    ) -> int | None:
        """
        Get the size of a file in bytes.
        Optionally provide a db_session for database access.
        """

    @abstractmethod
    def delete_file(self, file_id: str) -> None:
        """
@@ -431,27 +422,6 @@ class S3BackedFileStore(FileStore):
        )
        return file_record

    def get_file_size(
        self, file_id: str, db_session: Session | None = None
    ) -> int | None:
        """
        Get the size of a file in bytes by querying S3 metadata.
        """
        try:
            with get_session_with_current_tenant_if_none(db_session) as db_session:
                file_record = get_filerecord_by_file_id(
                    file_id=file_id, db_session=db_session
                )

            s3_client = self._get_s3_client()
            response = s3_client.head_object(
                Bucket=file_record.bucket_name, Key=file_record.object_key
            )
            return response.get("ContentLength")
        except Exception as e:
            logger.warning(f"Error getting file size for {file_id}: {e}")
            return None

    def delete_file(self, file_id: str, db_session: Session | None = None) -> None:
        with get_session_with_current_tenant_if_none(db_session) as db_session:
            try:
@@ -23,7 +23,7 @@ from onyx.indexing.models import BuildMetadataAwareChunksResult
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.utils.logger import setup_logger

@@ -140,7 +140,7 @@ class UserFileIndexingAdapter:

        # Initialize tokenizer used for token count calculation
        try:
            llm = get_default_llm()
            llm, _ = get_default_llms()
            llm_tokenizer = get_tokenizer(
                model_name=llm.config.model_name,
                provider_type=llm.config.model_provider,
@@ -28,7 +28,7 @@ from onyx.kg.utils.formatting_utils import make_entity_id
from onyx.kg.utils.formatting_utils import make_relationship_id
from onyx.kg.utils.formatting_utils import make_relationship_type_id
from onyx.kg.vespa.vespa_interactions import get_document_vespa_contents
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.llm.utils import llm_response_to_string
from onyx.prompts.kg_prompts import CALL_CHUNK_PREPROCESSING_PROMPT
from onyx.prompts.kg_prompts import CALL_DOCUMENT_CLASSIFICATION_PROMPT

@@ -415,9 +415,9 @@ def kg_classify_document(
    )

    # classify with LLM
    llm = get_default_llm()
    primary_llm, _ = get_default_llms()
    try:
        raw_classification_result = llm_response_to_string(llm.invoke(prompt))
        raw_classification_result = llm_response_to_string(primary_llm.invoke(prompt))
        classification_result = (
            raw_classification_result.replace("```json", "").replace("```", "").strip()
        )

@@ -479,9 +479,9 @@ def kg_deep_extract_chunks(
    ).replace("---content---", llm_context)

    # extract with LLM
    llm = get_default_llm()
    _, fast_llm = get_default_llms()
    try:
        raw_extraction_result = llm_response_to_string(llm.invoke(prompt))
        raw_extraction_result = llm_response_to_string(fast_llm.invoke(prompt))
        cleaned_response = (
            raw_extraction_result.replace("{{", "{")
            .replace("}}", "}")
@@ -258,9 +258,6 @@ class LitellmLLM(LLM):
        )

        # Needed to get reasoning tokens from the model
        # NOTE: OpenAI Responses API is disabled for parallel tool calls because LiteLLM's transformation layer
        # doesn't properly pass parallel_tool_calls to the API, causing the model to
        # always return sequential tool calls. For this reason parallel tool calls won't work with OpenAI models
        if (
            is_true_openai_model(self.config.model_provider, self.config.model_name)
            or self.config.model_provider == AZURE_PROVIDER_NAME

@@ -293,7 +290,7 @@ class LitellmLLM(LLM):
            completion_kwargs["metadata"] = metadata

        try:
            response = litellm.completion(
            return litellm.completion(
                mock_response=MOCK_LLM_RESPONSE,
                # model choice
                # model="openai/gpt-4",

@@ -314,9 +311,24 @@ class LitellmLLM(LLM):
                temperature=(1 if is_reasoning else self._temperature),
                timeout=timeout_override or self._timeout,
                **({"stream_options": {"include_usage": True}} if stream else {}),
                # NOTE: we can't pass parallel_tool_calls if tools are not specified
                # For now, we don't support parallel tool calls
                # NOTE: we can't pass this in if tools are not specified
                # or else OpenAI throws an error
                **({"parallel_tool_calls": parallel_tool_calls} if tools else {}),
                **(
                    {"parallel_tool_calls": parallel_tool_calls}
                    if tools
                    and self.config.model_name
                    not in [
                        "o3-mini",
                        "o3-preview",
                        "o1",
                        "o1-preview",
                        "o1-mini",
                        "o1-mini-2024-09-12",
                        "o3-mini-2025-01-31",
                    ]
                    else {}
                ),
                # Anthropic Claude uses `thinking` with budget_tokens for extended thinking
                # This applies to Claude models on any provider (anthropic, vertex_ai, bedrock)
                **(

@@ -338,7 +350,10 @@ class LitellmLLM(LLM):
                # (litellm maps this to thinking_level for Gemini 3 models)
                **(
                    {"reasoning_effort": OPENAI_REASONING_EFFORT[reasoning_effort]}
                    if is_reasoning and "claude" not in self.config.model_name.lower()
                    if reasoning_effort
                    and reasoning_effort != ReasoningEffort.OFF
                    and is_reasoning
                    and "claude" not in self.config.model_name.lower()
                    else {}
                ),
                **(

@@ -349,7 +364,6 @@ class LitellmLLM(LLM):
                **({self._max_token_param: max_tokens} if max_tokens else {}),
                **completion_kwargs,
            )
            return response
        except Exception as e:

            self._record_error(prompt, e)
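The `**({...} if cond else {})` pattern used throughout this call splices keyword arguments in conditionally, so nothing is ever passed as an explicit None. A standalone sketch of the idiom (names are illustrative):

def call(**kwargs):
    return kwargs

tools = None
out = call(
    model="gpt-4o",
    **({"parallel_tool_calls": True} if tools else {}),  # omitted entirely when tools is falsy
)
assert "parallel_tool_calls" not in out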
@@ -54,6 +54,12 @@ def _build_provider_extra_headers(
    return {}


def get_main_llm_from_tuple(
    llms: tuple[LLM, LLM],
) -> LLM:
    return llms[0]


def get_llm_config_for_persona(
    persona: Persona,
    db_session: Session,

@@ -101,16 +107,16 @@ def get_llm_config_for_persona(
    )


def get_llm_for_persona(
def get_llms_for_persona(
    persona: Persona | PersonaOverrideConfig | None,
    user: User | None,
    llm_override: LLMOverride | None = None,
    additional_headers: dict[str, str] | None = None,
    long_term_logger: LongTermLogger | None = None,
) -> LLM:
) -> tuple[LLM, LLM]:
    if persona is None:
        logger.warning("No persona provided, using default LLM")
        return get_default_llm()
        logger.warning("No persona provided, using default LLMs")
        return get_default_llms()

    provider_name_override = llm_override.model_provider if llm_override else None
    model_version_override = llm_override.model_version if llm_override else None

@@ -118,7 +124,7 @@ def get_llm_for_persona(

    provider_name = provider_name_override or persona.llm_model_provider_override
    if not provider_name:
        return get_default_llm(
        return get_default_llms(
            temperature=temperature_override or GEN_AI_TEMPERATURE,
            additional_headers=additional_headers,
            long_term_logger=long_term_logger,

@@ -147,7 +153,7 @@ def get_llm_for_persona(
            getattr(persona_model, "id", None),
            provider_model.name,
        )
        return get_default_llm(
        return get_default_llms(
            temperature=temperature_override or GEN_AI_TEMPERATURE,
            additional_headers=additional_headers,
            long_term_logger=long_term_logger,

@@ -156,24 +162,30 @@ def get_llm_for_persona(
    llm_provider = LLMProviderView.from_model(provider_model)

    model = model_version_override or persona.llm_model_version_override
    fast_model = llm_provider.fast_default_model_name or llm_provider.default_model_name
    if not model:
        raise ValueError("No model name found")
    if not fast_model:
        raise ValueError("No fast model name found")

    return get_llm(
        provider=llm_provider.provider,
        model=model,
        deployment_name=llm_provider.deployment_name,
        api_key=llm_provider.api_key,
        api_base=llm_provider.api_base,
        api_version=llm_provider.api_version,
        custom_config=llm_provider.custom_config,
        temperature=temperature_override,
        additional_headers=additional_headers,
        long_term_logger=long_term_logger,
        max_input_tokens=get_max_input_tokens_from_llm_provider(
            llm_provider=llm_provider, model_name=model
        ),
    )
    def _create_llm(model: str) -> LLM:
        return get_llm(
            provider=llm_provider.provider,
            model=model,
            deployment_name=llm_provider.deployment_name,
            api_key=llm_provider.api_key,
            api_base=llm_provider.api_base,
            api_version=llm_provider.api_version,
            custom_config=llm_provider.custom_config,
            temperature=temperature_override,
            additional_headers=additional_headers,
            long_term_logger=long_term_logger,
            max_input_tokens=get_max_input_tokens_from_llm_provider(
                llm_provider=llm_provider, model_name=model
            ),
        )

    return _create_llm(model), _create_llm(fast_model)
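A hedged usage sketch of the renamed factory (the persona and user objects are assumed to exist in the caller's scope):

# The persona-scoped factory now returns (primary, fast) instead of one LLM.
primary_llm, fast_llm = get_llms_for_persona(persona=persona, user=user)
answer = primary_llm.invoke(prompt)      # heavier default model
quick = fast_llm.invoke(cheap_prompt)    # fast model, falling back to the default when unset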
def get_default_llm_with_vision(
@@ -235,7 +247,7 @@ def get_default_llm_with_vision(
    ):
        return create_vision_llm(provider_view, provider.default_vision_model)

    # If no model-configurations are specified, try default model
    # If no model-configurations are specified, try default models in priority order
    if not provider.model_configurations:
        # Try default_model_name
        if provider.default_model_name and model_supports_image_input(

@@ -243,6 +255,14 @@ def get_default_llm_with_vision(
        ):
            return create_vision_llm(provider_view, provider.default_model_name)

        # Try fast_default_model_name
        if provider.fast_default_model_name and model_supports_image_input(
            provider.fast_default_model_name, provider.provider
        ):
            return create_vision_llm(
                provider_view, provider.fast_default_model_name
            )

    # Otherwise, if model-configurations are specified, check each model
    else:
        for model_configuration in provider.model_configurations:
@@ -291,12 +311,12 @@ def get_llm_for_contextual_rag(model_name: str, model_provider: str) -> LLM:
    )


def get_default_llm(
def get_default_llms(
    timeout: int | None = None,
    temperature: float | None = None,
    additional_headers: dict[str, str] | None = None,
    long_term_logger: LongTermLogger | None = None,
) -> LLM:
) -> tuple[LLM, LLM]:
    with get_session_with_current_tenant() as db_session:
        llm_provider = fetch_default_provider(db_session)

@@ -304,17 +324,25 @@ def get_default_llm(
        raise ValueError("No default LLM provider found")

    model_name = llm_provider.default_model_name
    fast_model_name = (
        llm_provider.fast_default_model_name or llm_provider.default_model_name
    )
    if not model_name:
        raise ValueError("No default model name found")
    if not fast_model_name:
        raise ValueError("No fast default model name found")

    return llm_from_provider(
        model_name=model_name,
        llm_provider=llm_provider,
        timeout=timeout,
        temperature=temperature,
        additional_headers=additional_headers,
        long_term_logger=long_term_logger,
    )
    def _create_llm(model: str) -> LLM:
        return llm_from_provider(
            model_name=model,
            llm_provider=llm_provider,
            timeout=timeout,
            temperature=temperature,
            additional_headers=additional_headers,
            long_term_logger=long_term_logger,
        )

    return _create_llm(model_name), _create_llm(fast_model_name)
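Call sites unpack the new tuple and keep only what they need, matching the updated callers shown earlier in this diff:

primary_llm, fast_llm = get_default_llms()
llm, _ = get_default_llms()       # callers that only need the primary model
_, fast_llm = get_default_llms()  # callers that only need the fast model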
def get_llm(
|
||||
|
||||
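For orientation, a minimal usage sketch of the tuple-returning factory above; the timeout value and prompts are illustrative only, not taken from this diff:

# Hedged sketch: get_default_llms() returns (primary LLM, fast LLM) per the signature above.
llm, fast_llm = get_default_llms(timeout=30)
summary = llm_response_to_string(llm.invoke("Summarize the indexed document."))
title = llm_response_to_string(fast_llm.invoke("Write a short title for this chat."))
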
@@ -47,6 +47,7 @@ class WellKnownLLMProviderDescriptor(BaseModel):
    custom_config_keys: list[CustomConfigKey] | None = None
    model_configurations: list[ModelConfigurationView]
    default_model: str | None = None
    default_fast_model: str | None = None
    default_api_base: str | None = None
    # set for providers like Azure, which require a deployment name.
    deployment_name_required: bool = False
@@ -147,6 +148,7 @@ VERTEXAI_PROVIDER_NAME = "vertex_ai"
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
VERTEX_LOCATION_KWARG = "vertex_location"
VERTEXAI_DEFAULT_MODEL = "gemini-2.5-flash"
VERTEXAI_DEFAULT_FAST_MODEL = "gemini-2.5-flash-lite"
# Curated list of Vertex AI models to show by default in the UI
VERTEXAI_VISIBLE_MODEL_NAMES = {
    "gemini-2.5-flash",
@@ -333,6 +335,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
                OPENAI_PROVIDER_NAME
            ),
            default_model="gpt-4o",
            default_fast_model="gpt-4o-mini",
        ),
        WellKnownLLMProviderDescriptor(
            name=OLLAMA_PROVIDER_NAME,
@@ -354,6 +357,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
                OLLAMA_PROVIDER_NAME
            ),
            default_model=None,
            default_fast_model=None,
            default_api_base="http://127.0.0.1:11434",
        ),
        WellKnownLLMProviderDescriptor(
@@ -368,6 +372,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
                ANTHROPIC_PROVIDER_NAME
            ),
            default_model="claude-sonnet-4-5-20250929",
            default_fast_model="claude-sonnet-4-20250514",
        ),
        WellKnownLLMProviderDescriptor(
            name=AZURE_PROVIDER_NAME,
@@ -450,6 +455,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
                BEDROCK_PROVIDER_NAME
            ),
            default_model=None,
            default_fast_model=None,
        ),
        WellKnownLLMProviderDescriptor(
            name=VERTEXAI_PROVIDER_NAME,
@@ -482,6 +488,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
                ),
            ],
            default_model=VERTEXAI_DEFAULT_MODEL,
            default_fast_model=VERTEXAI_DEFAULT_FAST_MODEL,
        ),
        WellKnownLLMProviderDescriptor(
            name=OPENROUTER_PROVIDER_NAME,
@@ -495,6 +502,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
                OPENROUTER_PROVIDER_NAME
            ),
            default_model=None,
            default_fast_model=None,
            default_api_base="https://openrouter.ai/api/v1",
        ),
    ]

@@ -35,11 +35,8 @@ CLAUDE_REASONING_BUDGET_TOKENS: dict[ReasoningEffort, int] = {
}

# OpenAI reasoning effort mapping (direct string values)
# TODO this needs to be cleaned up, there is a lot of jank and unnecessary slowness
# Also there should be auto for reasoning level which is not used here.
OPENAI_REASONING_EFFORT: dict[ReasoningEffort | None, str] = {
    None: "medium",  # Seems there is no auto mode in this version unfortunately
    ReasoningEffort.OFF: "low",  # Issues with 5.2 models not supporting minimal or off with this version of litellm
OPENAI_REASONING_EFFORT: dict[ReasoningEffort, str] = {
    ReasoningEffort.OFF: "none",  # this only works for the 5 series though
    ReasoningEffort.LOW: "low",
    ReasoningEffort.MEDIUM: "medium",
    ReasoningEffort.HIGH: "high",

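A hedged note on the mapping change above: the new dict is keyed strictly by ReasoningEffort (no None key), so a caller that previously relied on None mapping to "medium" presumably has to default before the lookup. A minimal sketch, where the effort variable is illustrative:

# Sketch only: default explicitly, since None is no longer a key of the map.
effort_str = OPENAI_REASONING_EFFORT[effort] if effort is not None else "medium"
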
@@ -34,7 +34,7 @@ from onyx.configs.onyxbot_configs (
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.users import get_user_by_email
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.llm.utils import llm_response_to_string
from onyx.onyxbot.slack.constants import FeedbackVisibility
from onyx.onyxbot.slack.models import ChannelType
@@ -141,7 +141,7 @@ def check_message_limit() -> bool:


def rephrase_slack_message(msg: str) -> str:
    llm = get_default_llm(timeout=5)
    llm, _ = get_default_llms(timeout=5)
    prompt = SLACK_LANGUAGE_REPHRASE_PROMPT.format(query=msg)
    model_output = llm_response_to_string(llm.invoke(prompt))
    logger.debug(model_output)

@@ -6,8 +6,7 @@ from onyx.prompts.deep_research.dr_tool_prompts import THINK_TOOL_NAME

# ruff: noqa: E501, W605 start
CLARIFICATION_PROMPT = f"""
You are a clarification agent that runs prior to deep research. Assess whether you need to ask clarifying questions, or if the user has already provided enough information for you to start research. \
CRITICAL - Never directly answer the user's query, you must only ask clarifying questions or call the `{GENERATE_PLAN_TOOL_NAME}` tool.
You are a clarification agent that runs prior to deep research. Assess whether you need to ask clarifying questions, or if the user has already provided enough information for you to start research. Clarifications are generally helpful.

If the user query is already very detailed or lengthy (more than 3 sentences), do not ask for clarification and instead call the `{GENERATE_PLAN_TOOL_NAME}` tool.

@@ -32,7 +31,7 @@ Focus on providing a thorough research of the user's query over being helpful.

For context, the date is {current_datetime}.

The research plan should be formatted as a numbered list of steps and have 6 or less individual steps.
The research plan should be formatted as a numbered list of steps and have less than 7 individual steps.

Each step should be a standalone exploration question or topic that can be researched independently but may build on previous steps.

@@ -110,7 +109,7 @@ Provide inline citations in the format [1], [2], [3], etc. based on the citation
"""


USER_FINAL_REPORT_QUERY = f"""
USER_FINAL_REPORT_QUERY = """
Provide a comprehensive answer to my previous query. CRITICAL: be as detailed as possible, stay on topic, and provide clear organization in your response.

Ignore the format styles of the intermediate {RESEARCH_AGENT_TOOL_NAME} reports, those are not end user facing and different from your task.

@@ -60,8 +60,4 @@ GENERATE_IMAGE_GUIDANCE = """
## generate_image
NEVER use generate_image unless the user specifically requests an image.
"""

TOOL_CALL_FAILURE_PROMPT = """
LLM attempted to call a tool but failed. Most likely the tool name was misspelled.
""".strip()
# ruff: noqa: E501, W605 end

@@ -1,4 +1,4 @@
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.llm.utils import llm_response_to_string
from onyx.prompts.answer_validation import ANSWER_VALIDITY_PROMPT
from onyx.utils.logger import setup_logger
@@ -17,7 +17,7 @@ def get_answer_validity(
            return False
        return True  # If something is wrong, let's not toss away the answer

    llm = get_default_llm()
    llm, _ = get_default_llms()

    prompt = ANSWER_VALIDITY_PROMPT.format(user_query=query, llm_answer=answer)
    model_output = llm_response_to_string(llm.invoke(prompt))

@@ -1,7 +1,7 @@
from collections.abc import Callable

from onyx.configs.constants import MessageType
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.llm.interfaces import LLM
from onyx.llm.models import AssistantMessage
from onyx.llm.models import ChatCompletionMessage
@@ -224,11 +224,11 @@ def keyword_query_expansion(


def llm_multilingual_query_expansion(query: str, language: str) -> str:
    llm = get_default_llm(timeout=5)
    _, fast_llm = get_default_llms(timeout=5)

    prompt = LANGUAGE_REPHRASE_PROMPT.format(query=query, target_language=language)
    model_output = llm_response_to_string(
        llm.invoke(prompt, reasoning_effort=ReasoningEffort.OFF)
        fast_llm.invoke(prompt, reasoning_effort=ReasoningEffort.OFF)
    )
    logger.debug(model_output)


@@ -15,7 +15,7 @@ from onyx.db.models import StarterMessage
from onyx.db.models import User
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.llm.utils import llm_response_to_string
from onyx.prompts.starter_messages import format_persona_starter_message_prompt
from onyx.prompts.starter_messages import PERSONA_CATEGORY_GENERATION_PROMPT
@@ -74,7 +74,7 @@ def generate_start_message_prompts(
    categories: List[str | None],
    chunk_contents: str,
    supports_structured_output: bool,
    llm: Any,
    fast_llm: Any,
) -> List[FunctionCall]:
    """
    Generates the list of FunctionCall objects for starter message generation.
@@ -99,7 +99,7 @@ def generate_start_message_prompts(

        functions.append(
            FunctionCall(
                llm.invoke,
                fast_llm.invoke,
                (start_message_generation_prompt,),
            )
        )
@@ -119,12 +119,12 @@ def generate_starter_messages(
    Generates starter messages by first obtaining categories and then generating messages for each category.
    On failure, returns an empty list (or list with processed starter messages if some messages are processed successfully).
    """
    llm = get_default_llm(temperature=0.5)
    _, fast_llm = get_default_llms(temperature=0.5)

    from litellm.utils import get_supported_openai_params

    provider = llm.config.model_provider
    model = llm.config.model_name
    provider = fast_llm.config.model_provider
    model = fast_llm.config.model_name

    params = get_supported_openai_params(model=model, custom_llm_provider=provider)
    supports_structured_output = (
@@ -142,7 +142,7 @@ def generate_starter_messages(
        num_categories=generation_count,
    )

    category_response = llm.invoke(category_generation_prompt)
    category_response = fast_llm.invoke(category_generation_prompt)
    response_content = llm_response_to_string(category_response)
    categories = parse_categories(response_content)

@@ -179,7 +179,7 @@ def generate_starter_messages(
        categories,
        chunk_contents,
        supports_structured_output,
        llm,
        fast_llm,
    )

    # Run LLM calls in parallel

@@ -1,3 +1,4 @@
import io
import json
import math
import mimetypes
@@ -9,8 +10,6 @@ from typing import cast

from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import Form
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
@@ -24,9 +23,6 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.tasks.pruning.tasks import (
    try_creating_prune_generator_task,
)
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.app_configs import ENABLED_CONNECTOR_TYPES
@@ -111,15 +107,13 @@ from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import AuthStatus
from onyx.server.documents.models import AuthUrl
from onyx.server.documents.models import ConnectorBase
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.server.documents.models import ConnectorFileInfo
from onyx.server.documents.models import ConnectorFilesResponse
from onyx.server.documents.models import ConnectorIndexingStatusLite
from onyx.server.documents.models import ConnectorIndexingStatusLiteResponse
from onyx.server.documents.models import ConnectorSnapshot
@@ -142,8 +136,9 @@ from onyx.server.documents.models import RunConnectorRequest
from onyx.server.documents.models import SourceSummary
from onyx.server.federated.models import FederatedConnectorStatus
from onyx.server.models import StatusResponse
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.threadpool_concurrency import CallableProtocol
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -461,6 +456,9 @@ def is_zip_file(file: UploadFile) -> bool:
def upload_files(
    files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR
) -> FileUploadResponse:
    for file in files:
        if not file.filename:
            raise HTTPException(status_code=400, detail="File name cannot be empty")

    # Skip directories and known macOS metadata entries
    def should_process_file(file_path: str) -> bool:
@@ -474,10 +472,6 @@ def upload_files(
    file_store = get_default_file_store()
    seen_zip = False
    for file in files:
        if not file.filename:
            logger.warning("File has no filename, skipping")
            continue

        if is_zip_file(file):
            if seen_zip:
                raise HTTPException(status_code=400, detail=SEEN_ZIP_DETAIL)
@@ -507,6 +501,24 @@ def upload_files(
                deduped_file_names.append(os.path.basename(file_info))
                continue

        # For mypy, actual check happens at start of function
        assert file.filename is not None

        # Special handling for doc files - only store the plaintext version
        file_type = mime_type_to_chat_file_type(file.content_type)
        if file_type == ChatFileType.DOC:
            extracted_text = extract_file_text(file.file, file.filename or "")
            text_file_id = file_store.save_file(
                content=io.BytesIO(extracted_text.encode()),
                display_name=file.filename,
                file_origin=file_origin,
                file_type="text/plain",
            )
            deduped_file_paths.append(text_file_id)
            deduped_file_names.append(file.filename)
            continue

        # Default handling for all other file types
        file_id = file_store.save_file(
            content=file.file,
            display_name=file.filename,
@@ -525,17 +537,6 @@ def upload_files(
    )


def _normalize_file_names_for_backwards_compatibility(
    file_locations: list[str], file_names: list[str]
) -> list[str]:
    """
    Ensures file_names list is the same length as file_locations for backwards compatibility.
    In legacy data, file_names might not exist or be shorter than file_locations.
    If file_names is shorter, pads it with corresponding file_locations values.
    """
    return file_names + file_locations[len(file_names) :]

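A worked example of the padding helper above, following directly from its one-line body (the inputs are illustrative):

# file_names shorter than file_locations: missing names fall back to the locations.
_normalize_file_names_for_backwards_compatibility(["loc_a", "loc_b", "loc_c"], ["a.txt"])
# -> ["a.txt", "loc_b", "loc_c"]
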
@router.post("/admin/connector/file/upload")
def upload_files_api(
    files: list[UploadFile],
@@ -544,229 +545,6 @@ def upload_files_api(
    return upload_files(files, FileOrigin.OTHER)


@router.get("/admin/connector/{connector_id}/files")
def list_connector_files(
    connector_id: int,
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> ConnectorFilesResponse:
    """List all files in a file connector."""
    connector = fetch_connector_by_id(connector_id, db_session)
    if connector is None:
        raise HTTPException(status_code=404, detail="Connector not found")

    if connector.source != DocumentSource.FILE:
        raise HTTPException(
            status_code=400, detail="This endpoint only works with file connectors"
        )

    file_locations = connector.connector_specific_config.get("file_locations", [])
    file_names = connector.connector_specific_config.get("file_names", [])

    # Normalize file_names for backwards compatibility with legacy data
    file_names = _normalize_file_names_for_backwards_compatibility(
        file_locations, file_names
    )

    file_store = get_default_file_store()
    files = []

    for file_id, file_name in zip(file_locations, file_names):
        try:
            file_record = file_store.read_file_record(file_id)
            file_size = None
            upload_date = None
            if file_record:
                file_size = file_store.get_file_size(file_id)
                upload_date = (
                    file_record.created_at.isoformat()
                    if file_record.created_at
                    else None
                )
            files.append(
                ConnectorFileInfo(
                    file_id=file_id,
                    file_name=file_name,
                    file_size=file_size,
                    upload_date=upload_date,
                )
            )
        except Exception as e:
            logger.warning(f"Error reading file record for {file_id}: {e}")
            # Include file with basic info even if record fetch fails
            files.append(
                ConnectorFileInfo(
                    file_id=file_id,
                    file_name=file_name,
                )
            )

    return ConnectorFilesResponse(files=files)


@router.post("/admin/connector/{connector_id}/files/update")
def update_connector_files(
    connector_id: int,
    files: list[UploadFile] | None = File(None),
    file_ids_to_remove: str = Form("[]"),
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> FileUploadResponse:
    """
    Update files in a connector by adding new files and/or removing existing ones.
    This is an atomic operation that validates, updates the connector config, and triggers indexing.
    """
    files = files or []
    connector = fetch_connector_by_id(connector_id, db_session)
    if connector is None:
        raise HTTPException(status_code=404, detail="Connector not found")

    if connector.source != DocumentSource.FILE:
        raise HTTPException(
            status_code=400, detail="This endpoint only works with file connectors"
        )

    # Get the connector-credential pair for indexing/pruning triggers
    cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
    if cc_pair is None:
        raise HTTPException(
            status_code=404,
            detail="No Connector-Credential Pair found for this connector",
        )

    # Parse file IDs to remove
    try:
        file_ids_list = json.loads(file_ids_to_remove)
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="Invalid file_ids_to_remove format")

    if not isinstance(file_ids_list, list):
        raise HTTPException(
            status_code=400,
            detail="file_ids_to_remove must be a JSON-encoded list",
        )

    # Get current connector config
    current_config = connector.connector_specific_config
    current_file_locations = current_config.get("file_locations", [])
    current_file_names = current_config.get("file_names", [])
    current_zip_metadata = current_config.get("zip_metadata", {})

    # Upload new files if any
    new_file_paths = []
    new_file_names_list = []
    new_zip_metadata = {}

    if files and len(files) > 0:
        upload_response = upload_files(files, FileOrigin.CONNECTOR)
        new_file_paths = upload_response.file_paths
        new_file_names_list = upload_response.file_names
        new_zip_metadata = upload_response.zip_metadata

    # Remove specified files
    files_to_remove_set = set(file_ids_list)

    # Normalize file_names for backwards compatibility with legacy data
    current_file_names = _normalize_file_names_for_backwards_compatibility(
        current_file_locations, current_file_names
    )

    remaining_file_locations = []
    remaining_file_names = []

    for file_id, file_name in zip(current_file_locations, current_file_names):
        if file_id not in files_to_remove_set:
            remaining_file_locations.append(file_id)
            remaining_file_names.append(file_name)

    # Combine remaining files with new files
    final_file_locations = remaining_file_locations + new_file_paths
    final_file_names = remaining_file_names + new_file_names_list

    # Validate that at least one file remains
    if not final_file_locations:
        raise HTTPException(
            status_code=400,
            detail="Cannot remove all files from connector. At least one file must remain.",
        )

    # Update zip metadata
    final_zip_metadata = {
        key: value
        for key, value in current_zip_metadata.items()
        if key not in files_to_remove_set
    }
    final_zip_metadata.update(new_zip_metadata)

    # Update connector config
    updated_config = {
        **current_config,
        "file_locations": final_file_locations,
        "file_names": final_file_names,
        "zip_metadata": final_zip_metadata,
    }

    connector_base = ConnectorBase(
        name=connector.name,
        source=connector.source,
        input_type=connector.input_type,
        connector_specific_config=updated_config,
        refresh_freq=connector.refresh_freq,
        prune_freq=connector.prune_freq,
        indexing_start=connector.indexing_start,
    )

    updated_connector = update_connector(connector_id, connector_base, db_session)
    if updated_connector is None:
        raise HTTPException(
            status_code=500, detail="Failed to update connector configuration"
        )

    # Trigger re-indexing for new files and pruning for removed files
    try:
        tenant_id = get_current_tenant_id()

        # If files were added, mark for UPDATE indexing (only new docs)
        if new_file_paths:
            mark_ccpair_with_indexing_trigger(
                cc_pair.id, IndexingMode.UPDATE, db_session
            )

            # Send task to check for indexing immediately
            client_app.send_task(
                OnyxCeleryTask.CHECK_FOR_INDEXING,
                kwargs={"tenant_id": tenant_id},
                priority=OnyxCeleryPriority.HIGH,
            )
            logger.info(
                f"Marked cc_pair {cc_pair.id} for UPDATE indexing (new files) for connector {connector_id}"
            )

        # If files were removed, trigger pruning immediately
        if file_ids_list:
            r = get_redis_client()
            payload_id = try_creating_prune_generator_task(
                client_app, cc_pair, db_session, r, tenant_id
            )
            if payload_id:
                logger.info(
                    f"Triggered pruning for cc_pair {cc_pair.id} (removed files) for connector "
                    f"{connector_id}, payload_id={payload_id}"
                )
            else:
                logger.warning(
                    f"Failed to trigger pruning for cc_pair {cc_pair.id} (removed files) for connector {connector_id}"
                )
    except Exception as e:
        logger.error(f"Failed to trigger re-indexing after file update: {e}")

    return FileUploadResponse(
        file_paths=final_file_locations,
        file_names=final_file_names,
        zip_metadata=final_zip_metadata,
    )


@router.get("/admin/connector")
def get_connectors_by_credential(
    _: User = Depends(current_curator_or_admin_user),
@@ -1156,10 +934,12 @@ def get_connector_indexing_status(
        ].total_docs_indexed += connector_status.docs_indexed

    # Track admin page visit for analytics
    mt_cloud_telemetry(
        tenant_id=tenant_id,
        distinct_id=user.email if user else tenant_id,
        event=MilestoneRecordType.VISITED_ADMIN_PAGE,
    create_milestone_and_report(
        user=user,
        distinct_id=user.email if user else tenant_id or "N/A",
        event_type=MilestoneRecordType.VISITED_ADMIN_PAGE,
        properties=None,
        db_session=db_session,
    )

    # Group statuses by source for pagination
@@ -1371,10 +1151,12 @@ def create_connector_from_model(
        connector_data=connector_base,
    )

    mt_cloud_telemetry(
        tenant_id=tenant_id,
        distinct_id=user.email if user else tenant_id,
        event=MilestoneRecordType.CREATED_CONNECTOR,
    create_milestone_and_report(
        user=user,
        distinct_id=user.email if user else tenant_id or "N/A",
        event_type=MilestoneRecordType.CREATED_CONNECTOR,
        properties=None,
        db_session=db_session,
    )

    return connector_response
@@ -1450,10 +1232,12 @@ def create_connector_with_mock_credential(
        f"cc_pair={response.data}"
    )

    mt_cloud_telemetry(
        tenant_id=tenant_id,
        distinct_id=user.email if user else tenant_id,
        event=MilestoneRecordType.CREATED_CONNECTOR,
    create_milestone_and_report(
        user=user,
        distinct_id=user.email if user else tenant_id or "N/A",
        event_type=MilestoneRecordType.CREATED_CONNECTOR,
        properties=None,
        db_session=db_session,
    )
    return response

@@ -571,17 +571,6 @@ class FileUploadResponse(BaseModel):
    zip_metadata: dict[str, Any]


class ConnectorFileInfo(BaseModel):
    file_id: str
    file_name: str
    file_size: int | None = None
    upload_date: str | None = None


class ConnectorFilesResponse(BaseModel):
    files: list[ConnectorFileInfo]


class ObjectCreationIdResponse(BaseModel):
    id: int
    credential: CredentialSnapshot | None = None

@@ -1,6 +1,5 @@
import asyncio
import base64
import datetime
import hashlib
import json
from collections.abc import Awaitable
@@ -41,7 +40,6 @@ from onyx.db.enums import MCPServerStatus
from onyx.db.enums import MCPTransport
from onyx.db.mcp import create_connection_config
from onyx.db.mcp import create_mcp_server__no_commit
from onyx.db.mcp import delete_all_user_connection_configs_for_server_no_commit
from onyx.db.mcp import delete_connection_config
from onyx.db.mcp import delete_mcp_server
from onyx.db.mcp import delete_user_connection_configs_for_server
@@ -1014,7 +1012,6 @@ def _db_mcp_server_to_api_mcp_server(
        is_authenticated=is_authenticated,
        user_authenticated=user_authenticated,
        status=db_server.status,
        last_refreshed_at=db_server.last_refreshed_at,
        tool_count=tool_count,
        auth_template=auth_template,
        user_credentials=user_credentials,
@@ -1136,7 +1133,6 @@ def get_mcp_server_tools_snapshots(
            server_id=server_id,
            db_session=db,
            status=MCPServerStatus.CONNECTED,
            last_refreshed_at=datetime.datetime.now(datetime.timezone.utc),
        )
        db.commit()
    except Exception as e:
@@ -1209,7 +1205,7 @@ def _upsert_db_tools(
                db_session=db,
                passthrough_auth=False,
                mcp_server_id=mcp_server_id,
                enabled=True,
                enabled=False,
            )
            new_tool.display_name = display_name
            new_tool.mcp_input_schema = input_schema
@@ -1387,21 +1383,7 @@ def _upsert_mcp_server(
        )

        # Cleanup: Delete existing connection configs
        # If the auth type is OAUTH, delete all user connection configs
        # If the auth type is API_TOKEN, delete the admin connection config and the admin user connection configs
        if (
            changing_connection_config
            and mcp_server.admin_connection_config_id
            and request.auth_type == MCPAuthenticationType.OAUTH
        ):
            delete_all_user_connection_configs_for_server_no_commit(
                mcp_server.id, db_session
            )
        elif (
            changing_connection_config
            and mcp_server.admin_connection_config_id
            and request.auth_type == MCPAuthenticationType.API_TOKEN
        ):
        if changing_connection_config and mcp_server.admin_connection_config_id:
            delete_connection_config(mcp_server.admin_connection_config_id, db_session)
            if user and user.email:
                delete_user_connection_configs_for_server(

@@ -1,4 +1,3 @@
import datetime
from enum import Enum
from typing import Any
from typing import List
@@ -298,7 +297,6 @@ class MCPServer(BaseModel):
    is_authenticated: bool
    user_authenticated: Optional[bool] = None
    status: MCPServerStatus
    last_refreshed_at: Optional[datetime.datetime] = None
    tool_count: int = Field(
        default=0, description="Number of tools associated with this server"
    )

@@ -59,7 +59,7 @@ from onyx.server.manage.llm.api import get_valid_model_names_for_persona
from onyx.server.models import DisplayPriorityRequest
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()
@@ -282,10 +282,12 @@ def create_persona(
        user=user,
        db_session=db_session,
    )
    mt_cloud_telemetry(
        tenant_id=tenant_id,
        distinct_id=user.email if user else tenant_id,
        event=MilestoneRecordType.CREATED_ASSISTANT,
    create_milestone_and_report(
        user=user,
        distinct_id=tenant_id or "N/A",
        event_type=MilestoneRecordType.CREATED_ASSISTANT,
        properties=None,
        db_session=db_session,
    )

    return persona_snapshot

@@ -8,10 +8,11 @@ from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field

from onyx.file_processing.extract_file_text import ACCEPTED_IMAGE_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -123,7 +124,7 @@ def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
    """

    results = CategorizedFiles()
    llm = get_default_llm()
    llm, _ = get_default_llms()

    tokenizer = get_tokenizer(
        model_name=llm.config.model_name, provider_type=llm.config.model_provider
@@ -154,7 +155,7 @@ def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
        extension = get_file_ext(filename)

        # If image, estimate tokens via dedicated method first
        if extension in OnyxFileExtensions.IMAGE_EXTENSIONS:
        if extension in ACCEPTED_IMAGE_FILE_EXTENSIONS:
            try:
                token_count = estimate_image_tokens_for_upload(upload)
            except (UnidentifiedImageError, OSError) as e:
@@ -172,7 +173,10 @@ def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
            continue

        # Otherwise, handle as text/document: extract text and count tokens
        elif extension in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
        if (
            extension in ALL_ACCEPTED_FILE_EXTENSIONS
            and extension not in ACCEPTED_IMAGE_FILE_EXTENSIONS
        ):
            text_content = extract_file_text(
                file=upload.file,
                file_name=filename,

@@ -30,7 +30,7 @@ from onyx.db.models import User
from onyx.file_store.file_store import get_default_file_store
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.llm.utils import test_llm
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.server.manage.models import BoostDoc
@@ -121,7 +121,7 @@ def validate_existing_genai_api_key(
        pass

    try:
        llm = get_default_llm(timeout=10)
        llm, __ = get_default_llms(timeout=10)
    except ValueError:
        raise HTTPException(status_code=404, detail="LLM not setup")

@@ -1,4 +1,5 @@
import os
from collections.abc import Callable
from datetime import datetime
from datetime import timezone

@@ -31,13 +32,14 @@ from onyx.db.llm import upsert_llm_provider
from onyx.db.llm import validate_persona_ids_exist
from onyx.db.models import User
from onyx.db.persona import user_can_access_persona
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llm
from onyx.llm.factory import get_max_input_tokens_from_llm_provider
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from onyx.llm.utils import get_bedrock_token_limit
from onyx.llm.utils import get_llm_contextual_cost
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.llm.utils import model_supports_image_input
from onyx.llm.utils import test_llm
from onyx.server.manage.llm.models import BedrockFinalModelResponse
@@ -62,6 +64,7 @@ from onyx.server.manage.llm.utils import is_valid_bedrock_model
from onyx.server.manage.llm.utils import ModelMetadata
from onyx.server.manage.llm.utils import strip_openrouter_vendor_prefix
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel

logger = setup_logger()

@@ -89,7 +92,7 @@ def test_llm_configuration(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Test LLM configuration settings"""
    """Test regular llm and fast llm settings"""

    # the api key is sanitized if we are testing a provider already in the system

@@ -120,10 +123,36 @@ def test_llm_configuration(
        max_input_tokens=max_input_tokens,
    )

    error_msg = test_llm(llm)
    functions_with_args: list[tuple[Callable, tuple]] = [(test_llm, (llm,))]
    if (
        test_llm_request.fast_default_model_name
        and test_llm_request.fast_default_model_name
        != test_llm_request.default_model_name
    ):
        fast_llm = get_llm(
            provider=test_llm_request.provider,
            model=test_llm_request.fast_default_model_name,
            api_key=test_api_key,
            api_base=test_llm_request.api_base,
            api_version=test_llm_request.api_version,
            custom_config=test_llm_request.custom_config,
            deployment_name=test_llm_request.deployment_name,
            max_input_tokens=max_input_tokens,
        )
        functions_with_args.append((test_llm, (fast_llm,)))

    if error_msg:
        raise HTTPException(status_code=400, detail=error_msg)
    parallel_results = run_functions_tuples_in_parallel(
        functions_with_args, allow_failures=False
    )
    error = parallel_results[0] or (
        parallel_results[1] if len(parallel_results) > 1 else None
    )

    if error:
        client_error_msg, _error_code, _is_retryable = litellm_exception_to_error_msg(
            error, llm, fallback_to_error_msg=True
        )
        raise HTTPException(status_code=400, detail=client_error_msg)


@admin_router.post("/test/default")
@@ -131,12 +160,21 @@ def test_default_provider(
    _: User | None = Depends(current_admin_user),
) -> None:
    try:
        llm = get_default_llm()
        llm, fast_llm = get_default_llms()
    except ValueError:
        logger.exception("Failed to fetch default LLM Provider")
        raise HTTPException(status_code=400, detail="No LLM Provider setup")

    error = test_llm(llm)
    functions_with_args: list[tuple[Callable, tuple]] = [
        (test_llm, (llm,)),
        (test_llm, (fast_llm,)),
    ]
    parallel_results = run_functions_tuples_in_parallel(
        functions_with_args, allow_failures=False
    )
    error = parallel_results[0] or (
        parallel_results[1] if len(parallel_results) > 1 else None
    )
    if error:
        raise HTTPException(status_code=400, detail=str(error))

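Both endpoints above now share the same pattern: run test_llm against the primary and fast models concurrently and surface the first failure. A minimal sketch of that pattern in isolation, assuming test_llm returns an error string or None:

functions_with_args: list[tuple[Callable, tuple]] = [
    (test_llm, (llm,)),
    (test_llm, (fast_llm,)),
]
results = run_functions_tuples_in_parallel(functions_with_args, allow_failures=False)
# Mirrors the "parallel_results[0] or parallel_results[1]" idiom: first non-None error wins.
error = next((err for err in results if err), None)
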
@@ -216,18 +254,34 @@ def put_llm_provider(
    llm_provider_upsert_request.personas = deduplicated_personas

    default_model_found = False
    default_fast_model_found = False

    for model_configuration in llm_provider_upsert_request.model_configurations:
        if model_configuration.name == llm_provider_upsert_request.default_model_name:
            model_configuration.is_visible = True
            default_model_found = True
        if (
            llm_provider_upsert_request.fast_default_model_name
            and llm_provider_upsert_request.fast_default_model_name
            == model_configuration.name
        ):
            model_configuration.is_visible = True
            default_fast_model_found = True

    default_inserts = set()
    if not default_model_found:
        llm_provider_upsert_request.model_configurations.append(
            ModelConfigurationUpsertRequest(
                name=llm_provider_upsert_request.default_model_name, is_visible=True
            )
        )
        default_inserts.add(llm_provider_upsert_request.default_model_name)

    if (
        llm_provider_upsert_request.fast_default_model_name
        and not default_fast_model_found
    ):
        default_inserts.add(llm_provider_upsert_request.fast_default_model_name)

    llm_provider_upsert_request.model_configurations.extend(
        ModelConfigurationUpsertRequest(name=name, is_visible=True)
        for name in default_inserts
    )

    # the llm api key is sanitized when returned to clients, so the only time we
    # should get a real key is when it is explicitly changed

@@ -32,6 +32,7 @@ class TestLLMRequest(BaseModel):

    # model level
    default_model_name: str
    fast_default_model_name: str | None = None
    deployment_name: str | None = None

    model_configurations: list["ModelConfigurationUpsertRequest"]
@@ -54,6 +55,7 @@ class LLMProviderDescriptor(BaseModel):
    provider: str
    provider_display_name: str  # Human-friendly name like "Claude (Anthropic)"
    default_model_name: str
    fast_default_model_name: str | None
    is_default_provider: bool | None
    is_default_vision_provider: bool | None
    default_vision_model: str | None
@@ -73,6 +75,7 @@ class LLMProviderDescriptor(BaseModel):
            provider=provider,
            provider_display_name=get_provider_display_name(provider),
            default_model_name=llm_provider_model.default_model_name,
            fast_default_model_name=llm_provider_model.fast_default_model_name,
            is_default_provider=llm_provider_model.is_default_provider,
            is_default_vision_provider=llm_provider_model.is_default_vision_provider,
            default_vision_model=llm_provider_model.default_vision_model,
@@ -90,6 +93,7 @@ class LLMProvider(BaseModel):
    api_version: str | None = None
    custom_config: dict[str, str] | None = None
    default_model_name: str
    fast_default_model_name: str | None = None
    is_public: bool = True
    groups: list[int] = Field(default_factory=list)
    personas: list[int] = Field(default_factory=list)
@@ -146,6 +150,7 @@ class LLMProviderView(LLMProvider):
            api_version=llm_provider_model.api_version,
            custom_config=llm_provider_model.custom_config,
            default_model_name=llm_provider_model.default_model_name,
            fast_default_model_name=llm_provider_model.fast_default_model_name,
            is_default_provider=llm_provider_model.is_default_provider,
            is_default_vision_provider=llm_provider_model.is_default_vision_provider,
            default_vision_model=llm_provider_model.default_vision_model,

@@ -30,7 +30,7 @@ from onyx.server.manage.validate_tokens import validate_app_token
from onyx.server.manage.validate_tokens import validate_bot_token
from onyx.server.manage.validate_tokens import validate_user_token
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import get_current_tenant_id

SLACK_API_CHANNELS_PER_PAGE = 100
@@ -274,10 +274,12 @@ def create_bot(
        is_default=True,
    )

    mt_cloud_telemetry(
        tenant_id=tenant_id,
        distinct_id=tenant_id,
        event=MilestoneRecordType.CREATED_ONYX_BOT,
    create_milestone_and_report(
        user=None,
        distinct_id=tenant_id or "N/A",
        event_type=MilestoneRecordType.CREATED_ONYX_BOT,
        properties=None,
        db_session=db_session,
    )

    return SlackBot.from_model(slack_bot_model)

@@ -203,17 +203,10 @@ def list_accepted_users(
@router.get("/manage/users/invited")
def list_invited_users(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[InvitedUserSnapshot]:
    invited_emails = get_invited_users()

    # Filter out users who are already active in the system
    active_user_emails = {user.email for user in get_all_users(db_session)}
    filtered_invited_emails = [
        email for email in invited_emails if email not in active_user_emails
    ]

    return [InvitedUserSnapshot(email=email) for email in filtered_invited_emails]
    return [InvitedUserSnapshot(email=email) for email in invited_emails]


@router.get("/manage/users")
@@ -238,13 +231,6 @@ def list_all_users(
    accepted_emails = {user.email for user in accepted_users}
    slack_users_emails = {user.email for user in slack_users}
    invited_emails = get_invited_users()

    # Filter out users who are already active (either accepted or slack users)
    all_active_emails = accepted_emails | slack_users_emails
    invited_emails = [
        email for email in invited_emails if email not in all_active_emails
    ]

    if q:
        invited_emails = [
            email for email in invited_emails if re.search(r"{}".format(q), email, re.I)

@@ -14,10 +14,8 @@ from onyx.db.web_search import deactivate_web_search_provider
from onyx.db.web_search import delete_web_content_provider
from onyx.db.web_search import delete_web_search_provider
from onyx.db.web_search import fetch_web_content_provider_by_name
from onyx.db.web_search import fetch_web_content_provider_by_type
from onyx.db.web_search import fetch_web_content_providers
from onyx.db.web_search import fetch_web_search_provider_by_name
from onyx.db.web_search import fetch_web_search_provider_by_type
from onyx.db.web_search import fetch_web_search_providers
from onyx.db.web_search import set_active_web_content_provider
from onyx.db.web_search import set_active_web_search_provider
@@ -149,33 +147,11 @@ def deactivate_search_provider(
def test_search_provider(
    request: WebSearchProviderTestRequest,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict[str, str]:
    provider_requires_api_key = request.provider_type != WebSearchProviderType.SEARXNG

    # Determine which API key to use
    api_key = request.api_key
    if request.use_stored_key and provider_requires_api_key:
        existing_provider = fetch_web_search_provider_by_type(
            request.provider_type, db_session
        )
        if existing_provider is None or not existing_provider.api_key:
            raise HTTPException(
                status_code=400,
                detail="No stored API key found for this provider type.",
            )
        api_key = existing_provider.api_key

    if provider_requires_api_key and not api_key:
        raise HTTPException(
            status_code=400,
            detail="API key is required. Either provide api_key or set use_stored_key to true.",
        )

    try:
        provider = build_search_provider_from_config(
            provider_type=request.provider_type,
            api_key=api_key or "",
            api_key=request.api_key,
            config=request.config or {},
        )
    except ValueError as exc:
@@ -315,31 +291,11 @@ def deactivate_content_provider(
def test_content_provider(
    request: WebContentProviderTestRequest,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict[str, str]:
    # Determine which API key to use
    api_key = request.api_key
    if request.use_stored_key:
        existing_provider = fetch_web_content_provider_by_type(
            request.provider_type, db_session
        )
        if existing_provider is None or not existing_provider.api_key:
            raise HTTPException(
                status_code=400,
                detail="No stored API key found for this provider type.",
            )
        api_key = existing_provider.api_key

    if not api_key:
        raise HTTPException(
            status_code=400,
            detail="API key is required. Either provide api_key or set use_stored_key to true.",
        )

    try:
        provider = build_content_provider_from_config(
            provider_type=request.provider_type,
            api_key=api_key,
            api_key=request.api_key,
            config=request.config,
        )
    except ValueError as exc:

@@ -60,25 +60,11 @@ class WebContentProviderUpsertRequest(BaseModel):

class WebSearchProviderTestRequest(BaseModel):
    provider_type: WebSearchProviderType
    api_key: str | None = Field(
        default=None,
        description="API key for testing. If not provided, use_stored_key must be true.",
    )
    use_stored_key: bool = Field(
        default=False,
        description="If true, use the stored API key for this provider type instead of api_key.",
    )
    api_key: str
    config: dict[str, Any] | None = None


class WebContentProviderTestRequest(BaseModel):
    provider_type: WebContentProviderType
    api_key: str | None = Field(
        default=None,
        description="API key for testing. If not provided, use_stored_key must be true.",
    )
    use_stored_key: bool = Field(
        default=False,
        description="If true, use the stored API key for this provider type instead of api_key.",
    )
    api_key: str
    config: WebContentProviderConfig

@@ -44,6 +44,7 @@ from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.feedback import create_chat_message_feedback
from onyx.db.feedback import create_doc_retrieval_feedback
from onyx.db.feedback import remove_chat_message_feedback
@@ -54,9 +55,9 @@ from onyx.db.projects import check_project_ownership
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.file_processing.extract_file_text import docx_to_txt_filename
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.factory import get_llms_for_persona
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.redis.redis_pool import get_redis_client
from onyx.secondary_llm_flows.chat_session_naming import (
@@ -88,7 +89,7 @@ from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.token_limit import check_token_rate_limits
from onyx.utils.headers import get_custom_tool_additional_request_headers
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()
@@ -116,7 +117,7 @@ def _get_available_tokens_for_persona(
        - default_reserved_tokens
    )

    llm = get_llm_for_persona(persona=persona, user=user)
    llm, _ = get_llms_for_persona(persona=persona, user=user)
    token_counter = get_llm_token_counter(llm)

    system_prompt = get_default_base_system_prompt(db_session)
@@ -353,7 +354,7 @@ def rename_chat_session(
        chat_session_id=chat_session_id, db_session=db_session
    )

    llm = get_default_llm(
    llm, _ = get_default_llms(
        additional_headers=extract_headers(
            request.headers, LITELLM_PASS_THROUGH_HEADERS
        )
@@ -450,11 +451,14 @@ def handle_new_chat_message(
    if not chat_message_req.message and not chat_message_req.use_existing_user_message:
        raise HTTPException(status_code=400, detail="Empty chat message is invalid")

    mt_cloud_telemetry(
        tenant_id=tenant_id,
        distinct_id=user.email if user else tenant_id,
        event=MilestoneRecordType.RAN_QUERY,
    )
    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
        create_milestone_and_report(
            user=user,
            distinct_id=user.email if user else tenant_id or "N/A",
            event_type=MilestoneRecordType.RAN_QUERY,
            properties=None,
            db_session=db_session,
        )

    def stream_generator() -> Generator[str, None, None]:
        try:
@@ -666,7 +670,7 @@ def seed_chat(
    root_message = get_or_create_root_message(
        chat_session_id=new_chat_session.id, db_session=db_session
    )
    llm = get_llm_for_persona(
    llm, _fast_llm = get_llms_for_persona(
        persona=new_chat_session.persona,
        user=user,
    )

@@ -1,18 +1,18 @@
from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_store.models import ChatFileType
from onyx.utils.file_types import UploadMimeTypes


def mime_type_to_chat_file_type(mime_type: str | None) -> ChatFileType:
    if mime_type is None:
        return ChatFileType.PLAIN_TEXT

    if mime_type in OnyxMimeTypes.IMAGE_MIME_TYPES:
    if mime_type in UploadMimeTypes.IMAGE_MIME_TYPES:
        return ChatFileType.IMAGE

    if mime_type in OnyxMimeTypes.CSV_MIME_TYPES:
    if mime_type in UploadMimeTypes.CSV_MIME_TYPES:
        return ChatFileType.CSV

    if mime_type in OnyxMimeTypes.DOCUMENT_MIME_TYPES:
    if mime_type in UploadMimeTypes.DOCUMENT_MIME_TYPES:
        return ChatFileType.DOC

    return ChatFileType.PLAIN_TEXT

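Illustrative outcomes for the remapped helper above, assuming the usual members of the UploadMimeTypes sets (e.g. "image/png" in IMAGE_MIME_TYPES):

mime_type_to_chat_file_type("image/png")  # ChatFileType.IMAGE, assuming set membership
mime_type_to_chat_file_type("text/csv")   # ChatFileType.CSV, assuming set membership
mime_type_to_chat_file_type(None)         # ChatFileType.PLAIN_TEXT, per the explicit guard
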
@@ -48,8 +48,9 @@ from onyx.db.search_settings import get_current_search_settings
from onyx.db.tag import find_tags
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.vespa.index import VespaIndex
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.models import AdminSearchRequest
from onyx.server.query_and_chat.models import AdminSearchResponse
@@ -107,7 +108,7 @@ def handle_search_request(
    query = search_request.message
    logger.notice(f"Received document search query: {query}")

    llm = get_default_llm()
    llm, _ = get_default_llms()
pagination_limit, pagination_offset = _normalize_pagination(
|
||||
limit=search_request.retrieval_options.limit,
|
||||
offset=search_request.retrieval_options.offset,
|
||||
@@ -228,7 +229,7 @@ def get_answer_stream(
|
||||
is_for_edit=False,
|
||||
)
|
||||
|
||||
llm = get_llm_for_persona(persona=persona_info, user=user)
|
||||
llm = get_main_llm_from_tuple(get_llms_for_persona(persona=persona_info, user=user))
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
|
||||
@@ -124,14 +124,13 @@ def create_reasoning_packets(reasoning_text: str, turn_index: int) -> list[Packe


def create_image_generation_packets(
    images: list[GeneratedImage], turn_index: int, tab_index: int = 0
    images: list[GeneratedImage], turn_index: int
) -> list[Packet]:
    packets: list[Packet] = []

    packets.append(
        Packet(
            turn_index=turn_index,
            tab_index=tab_index,
            obj=ImageGenerationToolStart(),
        )
    )
@@ -139,12 +138,11 @@ def create_image_generation_packets(
    packets.append(
        Packet(
            turn_index=turn_index,
            tab_index=tab_index,
            obj=ImageGenerationFinal(images=images),
        ),
    )

    packets.append(Packet(turn_index=turn_index, tab_index=tab_index, obj=SectionEnd()))
    packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))

    return packets

@@ -153,7 +151,6 @@ def create_custom_tool_packets(
    tool_name: str,
    response_type: str,
    turn_index: int,
    tab_index: int = 0,
    data: dict | list | str | int | float | bool | None = None,
    file_ids: list[str] | None = None,
) -> list[Packet]:
@@ -162,7 +159,6 @@ def create_custom_tool_packets(
    packets.append(
        Packet(
            turn_index=turn_index,
            tab_index=tab_index,
            obj=CustomToolStart(tool_name=tool_name),
        )
    )
@@ -170,7 +166,6 @@ def create_custom_tool_packets(
    packets.append(
        Packet(
            turn_index=turn_index,
            tab_index=tab_index,
            obj=CustomToolDelta(
                tool_name=tool_name,
                response_type=response_type,
@@ -180,7 +175,7 @@ def create_custom_tool_packets(
        ),
    )

    packets.append(Packet(turn_index=turn_index, tab_index=tab_index, obj=SectionEnd()))
    packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))

    return packets

@@ -189,14 +184,12 @@ def create_fetch_packets(
    fetch_docs: list[SavedSearchDoc],
    urls: list[str],
    turn_index: int,
    tab_index: int = 0,
) -> list[Packet]:
    packets: list[Packet] = []
    # Emit start packet
    packets.append(
        Packet(
            turn_index=turn_index,
            tab_index=tab_index,
            obj=OpenUrlStart(),
        )
    )
@@ -204,7 +197,6 @@ def create_fetch_packets(
    packets.append(
        Packet(
            turn_index=turn_index,
            tab_index=tab_index,
            obj=OpenUrlUrls(urls=urls),
        )
    )
@@ -212,13 +204,12 @@ def create_fetch_packets(
    packets.append(
        Packet(
            turn_index=turn_index,
            tab_index=tab_index,
            obj=OpenUrlDocuments(
                documents=[SearchDoc(**doc.model_dump()) for doc in fetch_docs]
            ),
        )
    )
    packets.append(Packet(turn_index=turn_index, tab_index=tab_index, obj=SectionEnd()))
    packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))
    return packets


@@ -227,14 +218,12 @@ def create_search_packets(
    search_docs: list[SavedSearchDoc],
    is_internet_search: bool,
    turn_index: int,
    tab_index: int = 0,
) -> list[Packet]:
    packets: list[Packet] = []

    packets.append(
        Packet(
            turn_index=turn_index,
            tab_index=tab_index,
            obj=SearchToolStart(
                is_internet_search=is_internet_search,
            ),
@@ -246,7 +235,6 @@ def create_search_packets(
    packets.append(
        Packet(
            turn_index=turn_index,
            tab_index=tab_index,
            obj=SearchToolQueriesDelta(queries=search_queries),
        ),
    )
@@ -259,7 +247,6 @@ def create_search_packets(
    packets.append(
        Packet(
            turn_index=turn_index,
            tab_index=tab_index,
            obj=SearchToolDocumentsDelta(
                documents=[
                    SearchDoc(**doc.model_dump()) for doc in sorted_search_docs
@@ -268,7 +255,7 @@ def create_search_packets(
        ),
    )

    packets.append(Packet(turn_index=turn_index, tab_index=tab_index, obj=SectionEnd()))
    packets.append(Packet(turn_index=turn_index, obj=SectionEnd()))

    return packets

@@ -324,7 +311,6 @@ def translate_assistant_message_to_packets(
                        is_internet_search=tool.in_code_tool_id
                        == WebSearchTool.__name__,
                        turn_index=turn_num,
                        tab_index=tool_call.tab_index,
                    )
                )

@@ -338,12 +324,7 @@ def translate_assistant_message_to_packets(
                    list[str], tool_call.tool_call_arguments.get("urls", [])
                )
                packet_list.extend(
                    create_fetch_packets(
                        fetch_docs,
                        urls,
                        turn_num,
                        tab_index=tool_call.tab_index,
                    )
                    create_fetch_packets(fetch_docs, urls, turn_num)
                )

            elif tool.in_code_tool_id == ImageGenerationTool.__name__:
@@ -353,9 +334,7 @@ def translate_assistant_message_to_packets(
                    for img in tool_call.generated_images
                ]
                packet_list.extend(
                    create_image_generation_packets(
                        images, turn_num, tab_index=tool_call.tab_index
                    )
                    create_image_generation_packets(images, turn_num)
                )

            else:
@@ -365,7 +344,6 @@ def translate_assistant_message_to_packets(
                        tool_name=tool.display_name or tool.name,
                        response_type="text",
                        turn_index=turn_num,
                        tab_index=tool_call.tab_index,
                        data=tool_call.tool_call_response,
                    )
                )

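Every create_*_packets helper above follows the same envelope: a Start packet opens the UI block, payload packets stream into it, and SectionEnd closes it, all tagged with one turn_index (the tab_index threading is what this diff removes). A toy reconstruction of that envelope, with stand-in types:

from dataclasses import dataclass


@dataclass
class ToyPacket:
    # Toy stand-in for the pydantic Packet model
    turn_index: int
    obj: object


def create_tool_packets(payloads: list[object], turn_index: int) -> list[ToyPacket]:
    packets = [ToyPacket(turn_index, "tool_start")]  # open the UI block
    packets.extend(ToyPacket(turn_index, p) for p in payloads)  # stream payloads
    packets.append(ToyPacket(turn_index, "section_end"))  # close the block
    return packets


assert len(create_tool_packets(["queries", "documents"], turn_index=2)) == 4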
@@ -36,16 +36,15 @@ class StreamingType(Enum):

    DEEP_RESEARCH_PLAN_START = "deep_research_plan_start"
    DEEP_RESEARCH_PLAN_DELTA = "deep_research_plan_delta"
    RESEARCH_AGENT_START = "research_agent_start"


class BaseObj(BaseModel):
    type: str = ""


################################################
# Reasoning Packets
################################################
"""Reasoning Packets"""


# Tells the frontend to display the reasoning block
class ReasoningStart(BaseObj):
    type: Literal["reasoning_start"] = StreamingType.REASONING_START.value
@@ -62,9 +61,9 @@ class ReasoningDone(BaseObj):
    type: Literal["reasoning_done"] = StreamingType.REASONING_DONE.value


################################################
# Final Agent Response Packets
################################################
"""Final Agent Response Packets"""


# Start of the final answer
class AgentResponseStart(BaseObj):
    type: Literal["message_start"] = StreamingType.MESSAGE_START.value
@@ -91,9 +90,9 @@ class CitationInfo(BaseObj):
    document_id: str


################################################
# Control Packets
################################################
"""Control Packets"""


# This one isn't strictly necessary, remove in the future
class SectionEnd(BaseObj):
    type: Literal["section_end"] = "section_end"
@@ -110,9 +109,9 @@ class OverallStop(BaseObj):
    type: Literal["stop"] = StreamingType.STOP.value


################################################
# Tool Packets
################################################
"""Tool Packets"""


# Search tool is called and the UI block needs to start
class SearchToolStart(BaseObj):
    type: Literal["search_tool_start"] = StreamingType.SEARCH_TOOL_START.value
@@ -226,9 +225,6 @@ class CustomToolDelta(BaseObj):
    file_ids: list[str] | None = None


################################################
# Deep Research Packets
################################################
class DeepResearchPlanStart(BaseObj):
    type: Literal["deep_research_plan_start"] = (
        StreamingType.DEEP_RESEARCH_PLAN_START.value
@@ -243,14 +239,8 @@ class DeepResearchPlanDelta(BaseObj):
    content: str


class ResearchAgentStart(BaseObj):
    type: Literal["research_agent_start"] = StreamingType.RESEARCH_AGENT_START.value
    research_task: str
"""Packet"""


################################################
# Packet Object
################################################
# Discriminated union of all possible packet object types
PacketObj = Union[
    # Agent Response Packets
@@ -284,24 +274,15 @@ PacketObj = Union[
    # Deep Research Packets
    DeepResearchPlanStart,
    DeepResearchPlanDelta,
    ResearchAgentStart,
]


class Placement(BaseModel):
    # Which iterative block in the UI is this part of, these are ordered and smaller ones happened first
    turn_index: int
    # For parallel tool calls to preserve order of execution
    tab_index: int
    # Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
    sub_turn_index: int


class Packet(BaseModel):
    turn_index: int | None
    # For parallel tool calls to preserve order of execution
    tab_index: int = 0
    # Used for tools/agents that call other tools, this currently doesn't support nested agents but can be added later
    sub_turn_index: int | None = None

    obj: Annotated[PacketObj, Field(discriminator="type")]


# This is for replaying it back from the DB to the frontend
class EndStepPacketList(BaseModel):
    turn_index: int
    packet_list: list[Packet]

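The obj field above is a pydantic discriminated union: the literal type string selects the concrete model when packets are replayed from the DB. A self-contained sketch of the pattern, using two of the packet types:

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class ReasoningStart(BaseModel):
    type: Literal["reasoning_start"] = "reasoning_start"


class SectionEnd(BaseModel):
    type: Literal["section_end"] = "section_end"


PacketObj = Union[ReasoningStart, SectionEnd]


class Packet(BaseModel):
    turn_index: int | None
    tab_index: int = 0
    sub_turn_index: int | None = None
    obj: Annotated[PacketObj, Field(discriminator="type")]


# The "type" value picks the concrete model during validation
pkt = Packet.model_validate({"turn_index": 0, "obj": {"type": "section_end"}})
assert isinstance(pkt.obj, SectionEnd)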
@@ -15,10 +15,9 @@ class PageType(str, Enum):


class ApplicationStatus(str, Enum):
    ACTIVE = "active"
    PAYMENT_REMINDER = "payment_reminder"
    GRACE_PERIOD = "grace_period"
    GATED_ACCESS = "gated_access"
    ACTIVE = "active"


class Notification(BaseModel):

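Moving ACTIVE to the bottom changes only definition order, and definition order is what drives iteration over an Enum; lookups by value are untouched. A quick illustration:

from enum import Enum


class ApplicationStatus(str, Enum):
    PAYMENT_REMINDER = "payment_reminder"
    GRACE_PERIOD = "grace_period"
    GATED_ACCESS = "gated_access"
    ACTIVE = "active"


# Iteration follows definition order, so ACTIVE now comes last
assert [status.value for status in ApplicationStatus][-1] == "active"
# Lookup by value is unaffected by the reorder
assert ApplicationStatus("active") is ApplicationStatus.ACTIVE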
@@ -11,6 +11,7 @@ from onyx.configs.constants import KV_REINDEX_KEY
from onyx.configs.constants import KV_SEARCH_SETTINGS
from onyx.configs.embedding_configs import SUPPORTED_EMBEDDING_MODELS
from onyx.configs.embedding_configs import SupportedEmbeddingModel
from onyx.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
from onyx.configs.model_configs import GEN_AI_API_KEY
from onyx.configs.model_configs import GEN_AI_MODEL_VERSION
from onyx.context.search.models import SavedSearchSettings
@@ -294,6 +295,7 @@ def setup_postgres(db_session: Session) -> None:
        logger.notice("Setting up default OpenAI LLM for dev.")

        llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
        fast_model = FAST_GEN_AI_MODEL_VERSION or "gpt-4o-mini"
        model_req = LLMProviderUpsertRequest(
            name="DevEnvPresetOpenAI",
            provider="openai",
@@ -302,6 +304,7 @@ def setup_postgres(db_session: Session) -> None:
            api_version=None,
            custom_config=None,
            default_model_name=llm_model,
            fast_default_model_name=fast_model,
            is_public=True,
            groups=[],
            model_configurations=[

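The new fast_default_model_name wiring reuses the env-override-with-fallback pattern already used for llm_model; presumably GEN_AI_MODEL_VERSION and FAST_GEN_AI_MODEL_VERSION are read from the environment in onyx.configs.model_configs. The pattern in isolation:

import os

# Same fallback shape as the hunk above; the literal defaults mirror the diff
llm_model = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-4o-mini"
fast_model = os.environ.get("FAST_GEN_AI_MODEL_VERSION") or "gpt-4o-mini"

print(f"default={llm_model} fast={fast_model}")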
3
backend/onyx/tools/built_in_tools_v2.py
Normal file
3
backend/onyx/tools/built_in_tools_v2.py
Normal file
@@ -0,0 +1,3 @@
from typing import Any

BUILT_IN_TOOL_MAP_V2: dict[str, Any] = {}
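The map is introduced empty, so its intended shape can only be guessed at; a hypothetical registration pattern (the helper and key scheme below are invented for illustration):

from typing import Any

BUILT_IN_TOOL_MAP_V2: dict[str, Any] = {}


def register_built_in_tool_v2(tool_id: str, tool_cls: Any) -> None:
    # Hypothetical helper; the diff only adds the empty map itself
    BUILT_IN_TOOL_MAP_V2[tool_id] = tool_cls


class StandInSearchTool:
    # Stand-in for a real built-in tool class
    pass


register_built_in_tool_v2("SearchTool", StandInSearchTool)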
@@ -1,432 +0,0 @@
from collections.abc import Callable

from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import create_tool_call_failure_messages
from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.emitter import Emitter
from onyx.chat.llm_loop import construct_message_history
from onyx.chat.llm_step import run_llm_step
from onyx.chat.models import ChatMessageSimple
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchDocsResponse
from onyx.deep_research.dr_mock_tools import (
    get_research_agent_additional_tool_definitions,
)
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_TASK_KEY
from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_MESSAGE
from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_TOKEN_COUNT
from onyx.deep_research.models import ResearchAgentCallResult
from onyx.deep_research.utils import check_special_tool_calls
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.models import ReasoningEffort
from onyx.llm.models import ToolChoiceOptions
from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
from onyx.prompts.deep_research.dr_tool_prompts import OPEN_URLS_TOOL_DESCRIPTION
from onyx.prompts.deep_research.dr_tool_prompts import (
    OPEN_URLS_TOOL_DESCRIPTION_REASONING,
)
from onyx.prompts.deep_research.dr_tool_prompts import WEB_SEARCH_TOOL_DESCRIPTION
from onyx.prompts.deep_research.research_agent import RESEARCH_AGENT_PROMPT
from onyx.prompts.deep_research.research_agent import RESEARCH_AGENT_PROMPT_REASONING
from onyx.prompts.deep_research.research_agent import RESEARCH_REPORT_PROMPT
from onyx.prompts.deep_research.research_agent import USER_REPORT_QUERY
from onyx.prompts.prompt_utils import get_current_llm_day_time
from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import Placement
from onyx.server.query_and_chat.streaming_models import ResearchAgentStart
from onyx.tools.models import ToolCallInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from onyx.tools.tool_runner import run_tool_calls
from onyx.tools.utils import generate_tools_description
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel

logger = setup_logger()


RESEARCH_CYCLE_CAP = 3


def generate_intermediate_report(
    research_topic: str,
    history: list[ChatMessageSimple],
    llm: LLM,
    token_counter: Callable[[str], int],
    user_identity: LLMUserIdentity | None,
    state_container: ChatStateContainer,
    emitter: Emitter,
    placement: Placement,
) -> str:
    system_prompt = ChatMessageSimple(
        message=RESEARCH_REPORT_PROMPT,
        token_count=token_counter(RESEARCH_REPORT_PROMPT),
        message_type=MessageType.SYSTEM,
    )

    reminder_str = USER_REPORT_QUERY.format(research_topic=research_topic)
    reminder_message = ChatMessageSimple(
        message=reminder_str,
        token_count=token_counter(reminder_str),
        message_type=MessageType.USER,
    )

    research_history = construct_message_history(
        system_prompt=system_prompt,
        custom_agent_prompt=None,
        simple_chat_history=history,
        reminder_message=reminder_message,
        project_files=None,
        available_tokens=llm.config.max_input_tokens,
    )

    llm_step_result, _ = run_llm_step(
        emitter=emitter,
        history=research_history,
        tool_definitions=[],
        tool_choice=ToolChoiceOptions.NONE,
        llm=llm,
        turn_index=999,  # TODO
        citation_processor=DynamicCitationProcessor(),
        state_container=state_container,
        reasoning_effort=ReasoningEffort.LOW,
        final_documents=None,
        user_identity=user_identity,
    )

    final_report = llm_step_result.answer
    if final_report is None:
        raise ValueError(
            f"LLM failed to generate a report for research task: {research_topic}"
        )

    # emitter.emit(
    #     Packet(
    #         obj=ResearchAgentStart(research_task=research_topic),
    #         placement=placement,
    #     )
    # )

    return final_report


def run_research_agent_call(
    research_agent_call: ToolCallKickoff,
    tools: list[Tool],
    emitter: Emitter,
    state_container: ChatStateContainer,
    llm: LLM,
    is_reasoning_model: bool,
    token_counter: Callable[[str], int],
    user_identity: LLMUserIdentity | None,
) -> ResearchAgentCallResult:
    cycle_count = 0
    llm_cycle_count = 0
    current_tools = tools
    gathered_documents: list[SearchDoc] | None = None
    reasoning_cycles = 0
    just_ran_web_search = False

    turn_index = research_agent_call.turn_index
    tab_index = research_agent_call.tab_index

    # If this fails to parse, we can't run the loop anyway, let this one fail in that case
    research_topic = research_agent_call.tool_args[RESEARCH_AGENT_TASK_KEY]

    emitter.emit(
        Packet(
            turn_index=turn_index,
            tab_index=tab_index,
            sub_turn_index=0,
            obj=ResearchAgentStart(research_task=research_topic),
        )
    )

    initial_user_message = ChatMessageSimple(
        message=research_topic,
        token_count=token_counter(research_topic),
        message_type=MessageType.USER,
    )
    msg_history: list[ChatMessageSimple] = [initial_user_message]

    citation_mapping: dict[int, str] = {}
    while cycle_count <= RESEARCH_CYCLE_CAP:
        if cycle_count == RESEARCH_CYCLE_CAP:
            current_tools = [
                tool
                for tool in tools
                if tool.name not in {SearchTool.NAME, WebSearchTool.NAME}
            ]

        tools_by_name = {tool.name: tool for tool in current_tools}

        tools_description = generate_tools_description(current_tools)

        internal_search_tip = (
            INTERNAL_SEARCH_GUIDANCE
            if any(isinstance(tool, SearchTool) for tool in current_tools)
            else ""
        )
        web_search_tip = (
            WEB_SEARCH_TOOL_DESCRIPTION
            if any(isinstance(tool, WebSearchTool) for tool in current_tools)
            else ""
        )
        open_urls_tip = (
            OPEN_URLS_TOOL_DESCRIPTION
            if any(isinstance(tool, OpenURLTool) for tool in current_tools)
            else ""
        )
        if is_reasoning_model and open_urls_tip:
            open_urls_tip = OPEN_URLS_TOOL_DESCRIPTION_REASONING

        system_prompt_template = (
            RESEARCH_AGENT_PROMPT_REASONING
            if is_reasoning_model
            else RESEARCH_AGENT_PROMPT
        )
        system_prompt_str = system_prompt_template.format(
            available_tools=tools_description,
            current_datetime=get_current_llm_day_time(full_sentence=False),
            current_cycle_count=cycle_count,
            optional_internal_search_tool_description=internal_search_tip,
            optional_web_search_tool_description=web_search_tip,
            optional_open_urls_tool_description=open_urls_tip,
        )

        system_prompt = ChatMessageSimple(
            message=system_prompt_str,
            token_count=token_counter(system_prompt_str),
            message_type=MessageType.SYSTEM,
        )

        if just_ran_web_search:
            reminder_message = ChatMessageSimple(
                message=OPEN_URL_REMINDER,
                token_count=token_counter(OPEN_URL_REMINDER),
                message_type=MessageType.USER,
            )
        else:
            reminder_message = None

        constructed_history = construct_message_history(
            system_prompt=system_prompt,
            custom_agent_prompt=None,
            simple_chat_history=msg_history,
            reminder_message=reminder_message,
            project_files=None,
            available_tokens=llm.config.max_input_tokens,
        )

        research_agent_tools = get_research_agent_additional_tool_definitions(
            include_think_tool=not is_reasoning_model
        )
        llm_step_result, has_reasoned = run_llm_step(
            emitter=emitter,
            history=constructed_history,
            tool_definitions=[tool.tool_definition() for tool in current_tools]
            + research_agent_tools,
            tool_choice=ToolChoiceOptions.REQUIRED,
            llm=llm,
            turn_index=llm_cycle_count + reasoning_cycles,
            citation_processor=DynamicCitationProcessor(),
            state_container=state_container,
            reasoning_effort=ReasoningEffort.LOW,
            final_documents=None,
            user_identity=user_identity,
        )
        if has_reasoned:
            reasoning_cycles += 1

        tool_responses: list[ToolResponse] = []
        tool_calls = llm_step_result.tool_calls or []

        just_ran_web_search = False

        special_tool_calls = check_special_tool_calls(tool_calls=tool_calls)
        if special_tool_calls.generate_report_tool_call:
            final_report = generate_intermediate_report(
                research_topic=research_topic,
                history=msg_history,
                llm=llm,
                token_counter=token_counter,
                user_identity=user_identity,
                state_container=state_container,
                emitter=emitter,
                placement=Placement(
                    turn_index=turn_index, tab_index=tab_index, sub_turn_index=0
                ),  # TODO
            )
            return ResearchAgentCallResult(
                intermediate_report=final_report, search_docs=[]
            )
        elif special_tool_calls.think_tool_call:
            think_tool_call = special_tool_calls.think_tool_call
            tool_call_message = think_tool_call.to_msg_str()

            think_tool_msg = ChatMessageSimple(
                message=tool_call_message,
                token_count=token_counter(tool_call_message),
                message_type=MessageType.TOOL_CALL,
                tool_call_id=think_tool_call.tool_call_id,
                image_files=None,
            )
            msg_history.append(think_tool_msg)

            think_tool_response_msg = ChatMessageSimple(
                message=THINK_TOOL_RESPONSE_MESSAGE,
                token_count=THINK_TOOL_RESPONSE_TOKEN_COUNT,
                message_type=MessageType.TOOL_CALL_RESPONSE,
                tool_call_id=think_tool_call.tool_call_id,
                image_files=None,
            )
            msg_history.append(think_tool_response_msg)
            reasoning_cycles += 1
            continue
        else:
            tool_responses, citation_mapping = run_tool_calls(
                tool_calls=tool_calls,
                tools=current_tools,
                message_history=msg_history,
                memories=None,
                user_info=None,
                citation_mapping=citation_mapping,
                citation_processor=DynamicCitationProcessor(),
                # May be better to not do this step, hard to say, needs to be tested
                skip_search_query_expansion=False,
            )

        if tool_calls and not tool_responses:
            failure_messages = create_tool_call_failure_messages(
                tool_calls[0], token_counter
            )
            msg_history.extend(failure_messages)
            continue

        for tool_response in tool_responses:
            # Extract tool_call from the response (set by run_tool_calls)
            if tool_response.tool_call is None:
                raise ValueError("Tool response missing tool_call reference")

            tool_call = tool_response.tool_call
            tab_index = tool_call.tab_index

            tool = tools_by_name.get(tool_call.tool_name)
            if not tool:
                raise ValueError(
                    f"Tool '{tool_call.tool_name}' not found in tools list"
                )

            # Extract search_docs if this is a search tool response
            search_docs = None
            if isinstance(tool_response.rich_response, SearchDocsResponse):
                search_docs = tool_response.rich_response.search_docs
                if gathered_documents:
                    gathered_documents.extend(search_docs)
                else:
                    gathered_documents = search_docs

            # This is used for the Open URL reminder in the next cycle
            # only do this if the web search tool yielded results
            if search_docs and tool_call.tool_name == WebSearchTool.NAME:
                just_ran_web_search = True

            tool_call_info = ToolCallInfo(
                parent_tool_call_id=None,  # TODO
                turn_index=llm_cycle_count
                + reasoning_cycles,  # TODO (subturn index also)
                tab_index=tab_index,
                tool_name=tool_call.tool_name,
                tool_call_id=tool_call.tool_call_id,
                tool_id=tool.id,
                reasoning_tokens=llm_step_result.reasoning,
                tool_call_arguments=tool_call.tool_args,
                tool_call_response=tool_response.llm_facing_response,
                search_docs=search_docs,
                generated_images=None,
            )
            # Add to state container for partial save support
            state_container.add_tool_call(tool_call_info)

            # Store tool call with function name and arguments in separate layers
            tool_call_message = tool_call.to_msg_str()
            tool_call_token_count = token_counter(tool_call_message)

            tool_call_msg = ChatMessageSimple(
                message=tool_call_message,
                token_count=tool_call_token_count,
                message_type=MessageType.TOOL_CALL,
                tool_call_id=tool_call.tool_call_id,
                image_files=None,
            )
            msg_history.append(tool_call_msg)

            tool_response_message = tool_response.llm_facing_response
            tool_response_token_count = token_counter(tool_response_message)

            tool_response_msg = ChatMessageSimple(
                message=tool_response_message,
                token_count=tool_response_token_count,
                message_type=MessageType.TOOL_CALL_RESPONSE,
                tool_call_id=tool_call.tool_call_id,
                image_files=None,
            )
            msg_history.append(tool_response_msg)

        llm_cycle_count += 1

    # If we've run out of cycles, just try to generate a report from everything so far
    final_report = generate_intermediate_report(
        research_topic=research_topic,
        history=msg_history,
        llm=llm,
        token_counter=token_counter,
        user_identity=user_identity,
        state_container=state_container,
        emitter=emitter,
        placement=Placement(
            turn_index=turn_index, tab_index=tab_index, sub_turn_index=0
        ),  # TODO
    )
    return ResearchAgentCallResult(intermediate_report=final_report, search_docs=[])


def run_research_agent_calls(
    research_agent_calls: list[ToolCallKickoff],
    tools: list[Tool],
    emitter: Emitter,
    state_container: ChatStateContainer,
    llm: LLM,
    is_reasoning_model: bool,
    token_counter: Callable[[str], int],
    user_identity: LLMUserIdentity | None = None,
) -> list[ResearchAgentCallResult]:
    # Run all research agent calls in parallel
    functions_with_args = [
        (
            run_research_agent_call,
            (
                research_agent_call,
                tools,
                emitter,
                state_container,
                llm,
                is_reasoning_model,
                token_counter,
                user_identity,
            ),
        )
        for research_agent_call in research_agent_calls
    ]

    return run_functions_tuples_in_parallel(
        functions_with_args,
        allow_failures=True,  # Continue even if some research agent calls fail
    )
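The deleted loop above runs at most RESEARCH_CYCLE_CAP + 1 iterations and, on the final one, strips the search tools so the model is pushed toward report generation rather than more gathering. That control flow, reduced to a standalone sketch (tool names here are stand-ins):

RESEARCH_CYCLE_CAP = 3
SEARCH_TOOL_NAMES = {"run_search", "run_web_search"}  # stand-in names


def research_loop(all_tools: list[str]) -> str:
    cycle_count = 0
    while cycle_count <= RESEARCH_CYCLE_CAP:
        current_tools = all_tools
        if cycle_count == RESEARCH_CYCLE_CAP:
            # Final cycle: drop search tools so the agent stops gathering
            current_tools = [t for t in all_tools if t not in SEARCH_TOOL_NAMES]
        # ... run one LLM step with current_tools here; return early if the
        # model calls the generate-report tool ...
        cycle_count += 1
    # Out of cycles: force a report from whatever was gathered so far
    return "intermediate report"


assert research_loop(["run_search", "open_url"]) == "intermediate report"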
27
backend/onyx/tools/force.py
Normal file
27
backend/onyx/tools/force.py
Normal file
@@ -0,0 +1,27 @@
from typing import Any

from pydantic import BaseModel

from onyx.tools.tool import Tool


class ForceUseTool(BaseModel):
    # Could be not a forced usage of the tool but still have args, in which case
    # if the tool is called, then those args are applied instead of what the LLM
    # wanted to call it with
    force_use: bool
    tool_name: str
    args: dict[str, Any] | None = None

    def build_openai_tool_choice_dict(self) -> dict[str, Any]:
        """Build dict in the format that OpenAI expects which tells them to use this tool."""
        return {"type": "function", "name": self.tool_name}


def filter_tools_for_force_tool_use(
    tools: list[Tool], force_use_tool: ForceUseTool
) -> list[Tool]:
    if not force_use_tool.force_use:
        return tools

    return [tool for tool in tools if tool.name == force_use_tool.tool_name]
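filter_tools_for_force_tool_use narrows the tool list only when force_use is set; with force_use False, the configured args merely override the LLM's arguments if that tool happens to be called. A usage sketch with a minimal stand-in for Tool (the filter only touches the name attribute, so a toy model suffices):

from pydantic import BaseModel


class StandInTool(BaseModel):
    # Minimal stand-in for onyx.tools.tool.Tool; only name matters here
    name: str


tools = [StandInTool(name="run_search"), StandInTool(name="generate_image")]
forced = ForceUseTool(force_use=True, tool_name="run_search")

# With force_use=True, only the forced tool survives filtering
assert [t.name for t in filter_tools_for_force_tool_use(tools, forced)] == ["run_search"]

# OpenAI-style tool_choice payload for the forced tool
assert forced.build_openai_tool_choice_dict() == {"type": "function", "name": "run_search"}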
46
backend/onyx/tools/message.py
Normal file
46
backend/onyx/tools/message.py
Normal file
@@ -0,0 +1,46 @@
import json
from typing import Any

from langchain_core.messages.ai import AIMessage
from langchain_core.messages.tool import ToolCall
from langchain_core.messages.tool import ToolMessage
from pydantic import BaseModel
from pydantic import ConfigDict

from onyx.natural_language_processing.utils import BaseTokenizer

# Langchain has its own version of pydantic, which is version 1


def build_tool_message(
    tool_call: ToolCall, tool_content: str | list[str | dict[str, Any]]
) -> ToolMessage:
    return ToolMessage(
        tool_call_id=tool_call["id"] or "",
        name=tool_call["name"],
        content=tool_content,
    )


class ToolCallSummary(BaseModel):
    tool_call_request: AIMessage
    tool_call_result: ToolMessage

    # This is a workaround to allow arbitrary types in the model
    # TODO: Remove this once we have a better solution
    model_config = ConfigDict(arbitrary_types_allowed=True)


def tool_call_tokens(
    tool_call_summary: ToolCallSummary, llm_tokenizer: BaseTokenizer
) -> int:
    request_tokens = len(
        llm_tokenizer.encode(
            json.dumps(tool_call_summary.tool_call_request.tool_calls[0]["args"])
        )
    )
    result_tokens = len(
        llm_tokenizer.encode(json.dumps(tool_call_summary.tool_call_result.content))
    )

    return request_tokens + result_tokens
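Note that tool_call_tokens counts only the request's JSON-encoded args and the result content, not the function name or message framing, so it is a lower bound on the real token cost. A usage sketch (the whitespace tokenizer is a stand-in for onyx's BaseTokenizer):

from langchain_core.messages.ai import AIMessage


class WhitespaceTokenizer:
    # Stand-in for the BaseTokenizer interface; real tokenizers are model-specific
    def encode(self, text: str) -> list[str]:
        return text.split()


request = AIMessage(
    content="",
    tool_calls=[{"name": "run_search", "args": {"query": "quarterly revenue"}, "id": "call_1"}],
)
result = build_tool_message(request.tool_calls[0], "3 documents found")

summary = ToolCallSummary(tool_call_request=request, tool_call_result=result)
print(tool_call_tokens(summary, WhitespaceTokenizer()))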