Compare commits

..

1 Commits

Author SHA1 Message Date
Jessica Singh
9b694767f7 refactor: simplify auth parameter handling in backend 2026-02-18 14:05:07 -08:00
334 changed files with 5954 additions and 13501 deletions

View File

@@ -33,7 +33,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Build chart dependencies

View File

@@ -45,6 +45,9 @@ env:
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000
# OpenSearch
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -115,10 +118,9 @@ jobs:
- name: Create .env file for Docker Compose
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
COMPOSE_PROFILES=s3-filestore
CODE_INTERPRETER_BETA_ENABLED=true
DISABLE_TELEMETRY=true
OPENSEARCH_FOR_ONYX_ENABLED=true
EOF
- name: Set up Standard Dependencies
@@ -127,6 +129,7 @@ jobs:
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
-f docker-compose.opensearch.yml \
up -d \
minio \
relational_db \

View File

@@ -91,7 +91,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Install Redis operator

View File

@@ -1,28 +0,0 @@
"""add scim_username to scim_user_mapping
Revision ID: 0bb4558f35df
Revises: 631fd2504136
Create Date: 2026-02-20 10:45:30.340188
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0bb4558f35df"
down_revision = "631fd2504136"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"scim_user_mapping",
sa.Column("scim_username", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("scim_user_mapping", "scim_username")

View File

@@ -1,32 +0,0 @@
"""add approx_chunk_count_in_vespa to opensearch tenant migration
Revision ID: 631fd2504136
Revises: c7f2e1b4a9d3
Create Date: 2026-02-18 21:07:52.831215
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "631fd2504136"
down_revision = "c7f2e1b4a9d3"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"approx_chunk_count_in_vespa",
sa.Integer(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("opensearch_tenant_migration_record", "approx_chunk_count_in_vespa")

View File

@@ -1,31 +0,0 @@
"""code interpreter server model
Revision ID: 7cb492013621
Revises: 0bb4558f35df
Create Date: 2026-02-22 18:54:54.007265
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7cb492013621"
down_revision = "0bb4558f35df"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"code_interpreter_server",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column(
"server_enabled", sa.Boolean, nullable=False, server_default=sa.true()
),
)
def downgrade() -> None:
op.drop_table("code_interpreter_server")

View File

@@ -1,31 +0,0 @@
"""add sharing_scope to build_session
Revision ID: c7f2e1b4a9d3
Revises: 19c0ccb01687
Create Date: 2026-02-17 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "c7f2e1b4a9d3"
down_revision = "19c0ccb01687"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"build_session",
sa.Column(
"sharing_scope",
sa.String(),
nullable=False,
server_default="private",
),
)
def downgrade() -> None:
op.drop_column("build_session", "sharing_scope")

View File

@@ -1,3 +1,4 @@
import os
from datetime import datetime
import jwt
@@ -20,7 +21,13 @@ logger = setup_logger()
def verify_auth_setting() -> None:
# All the Auth flows are valid for EE version
# All the Auth flows are valid for EE version, but warn about deprecated 'disabled'
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
if raw_auth_type == "disabled":
logger.warning(
"AUTH_TYPE='disabled' is no longer supported. "
"Using 'basic' instead. Please update your configuration."
)
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")

View File

@@ -9,7 +9,6 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
@@ -19,15 +18,11 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.db.models import User
@@ -200,60 +195,8 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
return db_session.scalar(stmt)
def _add_user_group_snapshot_eager_loads(
    stmt: Select,
) -> Select:
    """Add eager loading options needed by UserGroup.from_model snapshot creation.

    Pre-loads every relationship that snapshot serialization walks so that the
    subsequent attribute access triggers no per-row lazy loads (avoids N+1
    queries when many groups are returned).

    Args:
        stmt: A select over UserGroup to decorate.

    Returns:
        The same select with selectinload options attached.
    """
    return stmt.options(
        selectinload(UserGroup.users),
        selectinload(UserGroup.user_group_relationships),
        # cc-pair relationships, plus each pair's connector and credential
        # (and the credential's owning user).
        selectinload(UserGroup.cc_pair_relationships)
        .selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
        .options(
            selectinload(ConnectorCredentialPair.connector),
            selectinload(ConnectorCredentialPair.credential).selectinload(
                Credential.user
            ),
        ),
        # Document sets with their cc-pairs, membership, and federated connectors.
        selectinload(UserGroup.document_sets).options(
            selectinload(DocumentSet.connector_credential_pairs).selectinload(
                ConnectorCredentialPair.connector
            ),
            selectinload(DocumentSet.users),
            selectinload(DocumentSet.groups),
            selectinload(DocumentSet.federated_connectors).selectinload(
                FederatedConnector__DocumentSet.federated_connector
            ),
        ),
        # Personas and every nested relationship the snapshot reads from them,
        # including each persona's own document sets (same shape as above).
        selectinload(UserGroup.personas).options(
            selectinload(Persona.tools),
            selectinload(Persona.hierarchy_nodes),
            selectinload(Persona.attached_documents).selectinload(
                Document.parent_hierarchy_node
            ),
            selectinload(Persona.labels),
            selectinload(Persona.document_sets).options(
                selectinload(DocumentSet.connector_credential_pairs).selectinload(
                    ConnectorCredentialPair.connector
                ),
                selectinload(DocumentSet.users),
                selectinload(DocumentSet.groups),
                selectinload(DocumentSet.federated_connectors).selectinload(
                    FederatedConnector__DocumentSet.federated_connector
                ),
            ),
            selectinload(Persona.user),
            selectinload(Persona.user_files),
            selectinload(Persona.users),
            selectinload(Persona.groups),
        ),
    )
def fetch_user_groups(
db_session: Session,
only_up_to_date: bool = True,
eager_load_for_snapshot: bool = False,
db_session: Session, only_up_to_date: bool = True
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -266,8 +209,6 @@ def fetch_user_groups(
db_session (Session): The SQLAlchemy session used to query the database.
only_up_to_date (bool, optional): Flag to determine whether to filter the results
to include only up to date user groups. Defaults to `True`.
eager_load_for_snapshot: If True, adds eager loading for all relationships
needed by UserGroup.from_model snapshot creation.
Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -275,16 +216,11 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def fetch_user_groups_for_user(
db_session: Session,
user_id: UUID,
only_curator_groups: bool = False,
eager_load_for_snapshot: bool = False,
db_session: Session, user_id: UUID, only_curator_groups: bool = False
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -294,9 +230,7 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def construct_document_id_select_by_usergroup(

View File

@@ -1,13 +1,9 @@
from collections.abc import Generator
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from onyx.configs.app_configs import SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
@@ -47,27 +43,14 @@ def sharepoint_group_sync(
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
enumerate_all = connector_config.get(
"exhaustive_ad_enumeration", SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
)
msal_app = connector.msal_app
sp_tenant_domain = connector.sp_tenant_domain
sp_domain_suffix = connector.sharepoint_domain_suffix
# Process each site
for site_descriptor in site_descriptors:
logger.debug(f"Processing site: {site_descriptor.url}")
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
)
ctx = connector._create_rest_client_context(site_descriptor.url)
external_groups = get_sharepoint_external_groups(
ctx,
connector.graph_client,
graph_api_base=connector.graph_api_base,
get_access_token=connector._get_graph_access_token,
enumerate_all_ad_groups=enumerate_all,
)
# Get external groups for this site
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
# Yield each group
for group in external_groups:

View File

@@ -1,13 +1,9 @@
import re
import time
from collections import deque
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from urllib.parse import unquote
from urllib.parse import urlparse
import requests as _requests
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.runtime.client_request import ClientRequestException # type: ignore
@@ -18,10 +14,7 @@ from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES
from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
from onyx.connectors.sharepoint.connector import sleep_and_retry
from onyx.utils.logger import setup_logger
@@ -40,70 +33,6 @@ LIMITED_ACCESS_ROLE_TYPES = [1, 9]
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
AD_GROUP_ENUMERATION_THRESHOLD = 100_000
def _graph_api_get(
    url: str,
    get_access_token: Callable[[], str],
    params: dict[str, str] | None = None,
) -> dict[str, Any]:
    """Authenticated Graph API GET with retry on transient errors.

    Args:
        url: Absolute Graph API endpoint to fetch.
        get_access_token: Callable returning a bearer token; invoked on every
            attempt so a retry never reuses a possibly-expired token.
        params: Optional query parameters for the request.

    Returns:
        The parsed JSON body of the successful response.

    Raises:
        requests.HTTPError / ConnectionError / Timeout once retries are
        exhausted; RuntimeError as a defensive fallback if the loop ends
        without returning or raising.
    """
    for attempt in range(GRAPH_API_MAX_RETRIES + 1):
        # Fresh token per attempt — see docstring.
        access_token = get_access_token()
        headers = {"Authorization": f"Bearer {access_token}"}
        try:
            resp = _requests.get(
                url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS
            )
            if (
                resp.status_code in GRAPH_API_RETRYABLE_STATUSES
                and attempt < GRAPH_API_MAX_RETRIES
            ):
                # Prefer the server's Retry-After header, falling back to
                # exponential backoff (2**attempt); cap the wait at 60s.
                wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60)
                logger.warning(
                    f"Graph API {resp.status_code} on attempt {attempt + 1}, "
                    f"retrying in {wait}s: {url}"
                )
                time.sleep(wait)
                continue
            resp.raise_for_status()
            return resp.json()
        except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
            # Transient network/HTTP failure: exponential backoff (max 60s)
            # unless this was the final attempt, in which case re-raise.
            if attempt < GRAPH_API_MAX_RETRIES:
                wait = min(2**attempt, 60)
                logger.warning(
                    f"Graph API connection error on attempt {attempt + 1}, "
                    f"retrying in {wait}s: {url}"
                )
                time.sleep(wait)
                continue
            raise
    # Unreachable when GRAPH_API_MAX_RETRIES >= 0 (each final attempt either
    # returns or re-raises); kept so the function cannot fall through to None.
    raise RuntimeError(
        f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}"
    )
def _iter_graph_collection(
    initial_url: str,
    get_access_token: Callable[[], str],
    params: dict[str, str] | None = None,
) -> Generator[dict[str, Any], None, None]:
    """Walk every page of a Graph API collection, yielding items lazily."""
    next_url: str | None = initial_url
    page_params = params
    while next_url:
        page = _graph_api_get(next_url, get_access_token, page_params)
        # Query params apply only to the first request; @odata.nextLink
        # already encodes them for subsequent pages.
        page_params = None
        for item in page.get("value", []):
            yield item
        next_url = page.get("@odata.nextLink")
def _normalize_email(email: str) -> str:
    """Strip the Microsoft domain marker from *email* when present."""
    if MICROSOFT_DOMAIN not in email:
        return email
    return email.replace(MICROSOFT_DOMAIN, "")
class SharepointGroup(BaseModel):
model_config = {"frozen": True}
@@ -643,65 +572,8 @@ def get_external_access_from_sharepoint(
)
def _enumerate_ad_groups_paginated(
    get_access_token: Callable[[], str],
    already_resolved: set[str],
    graph_api_base: str,
) -> Generator[ExternalUserGroup, None, None]:
    """Paginate through all Azure AD groups and yield ExternalUserGroup for each.

    Skips groups whose suffixed name is already in *already_resolved*.
    Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD.

    Args:
        get_access_token: Callable returning a bearer token for Graph requests.
        already_resolved: Suffixed group names ("<displayName>_<id>") that were
            already produced elsewhere; members are not fetched for these.
        graph_api_base: Base URL of the Graph API (e.g. version-prefixed root).

    Yields:
        One ExternalUserGroup per AD group, with normalized member emails.
    """
    groups_url = f"{graph_api_base}/groups"
    # $top=999 is the Graph page-size maximum; paging handled by the iterator.
    groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"}
    total_groups = 0
    for group_json in _iter_graph_collection(
        groups_url, get_access_token, groups_params
    ):
        group_id: str = group_json.get("id", "")
        display_name: str = group_json.get("displayName", "")
        # Skip malformed entries missing either field.
        if not group_id or not display_name:
            continue
        total_groups += 1
        if total_groups > AD_GROUP_ENUMERATION_THRESHOLD:
            logger.warning(
                f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} "
                "groups — stopping to avoid excessive memory/API usage. "
                "Remaining groups will be resolved from role assignments only."
            )
            return
        # AD allows duplicate display names, so the GUID suffix disambiguates.
        name = f"{display_name}_{group_id}"
        if name in already_resolved:
            continue
        member_emails: list[str] = []
        members_url = f"{graph_api_base}/groups/{group_id}/members"
        members_params: dict[str, str] = {
            "$select": "userPrincipalName,mail",
            "$top": "999",
        }
        for member_json in _iter_graph_collection(
            members_url, get_access_token, members_params
        ):
            # Prefer userPrincipalName; fall back to mail when UPN is absent.
            email = member_json.get("userPrincipalName") or member_json.get("mail")
            if email:
                member_emails.append(_normalize_email(email))
        yield ExternalUserGroup(id=name, user_emails=member_emails)
    logger.info(f"Enumerated {total_groups} Azure AD groups via paginated Graph API")
def get_sharepoint_external_groups(
client_context: ClientContext,
graph_client: GraphClient,
graph_api_base: str,
get_access_token: Callable[[], str] | None = None,
enumerate_all_ad_groups: bool = False,
client_context: ClientContext, graph_client: GraphClient
) -> list[ExternalUserGroup]:
groups: set[SharepointGroup] = set()
@@ -757,22 +629,57 @@ def get_sharepoint_external_groups(
client_context, graph_client, groups, is_group_sync=True
)
external_user_groups: list[ExternalUserGroup] = [
ExternalUserGroup(id=group_name, user_emails=list(emails))
for group_name, emails in groups_and_members.groups_to_emails.items()
]
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
azure_ad_groups = sleep_and_retry(
graph_client.groups.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups",
)
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
ad_groups_to_emails: dict[str, set[str]] = {}
for group in azure_ad_groups:
# If the group is already identified, we don't need to get the members
if group.display_name in identified_groups:
continue
# AD groups allows same display name for multiple groups, so we need to add the GUID to the name
name = group.display_name
name = _get_group_name_with_suffix(group.id, name, graph_client)
if not enumerate_all_ad_groups or get_access_token is None:
logger.info(
"Skipping exhaustive Azure AD group enumeration. "
"Only groups found in site role assignments are included."
members = sleep_and_retry(
group.members.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
)
return external_user_groups
for member in members:
member_data = member.to_json()
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
if not ad_groups_to_emails.get(name):
ad_groups_to_emails[name] = set()
if user_principal_name:
if MICROSOFT_DOMAIN in user_principal_name:
user_principal_name = user_principal_name.replace(
MICROSOFT_DOMAIN, ""
)
ad_groups_to_emails[name].add(user_principal_name)
elif mail:
if MICROSOFT_DOMAIN in mail:
mail = mail.replace(MICROSOFT_DOMAIN, "")
ad_groups_to_emails[name].add(mail)
already_resolved = set(groups_and_members.groups_to_emails.keys())
for group in _enumerate_ad_groups_paginated(
get_access_token, already_resolved, graph_api_base
):
external_user_groups.append(group)
external_user_groups: list[ExternalUserGroup] = []
for group_name, emails in groups_and_members.groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
for group_name, emails in ad_groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
return external_user_groups

View File

@@ -37,15 +37,12 @@ def list_user_groups(
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(
db_session, only_up_to_date=False, eager_load_for_snapshot=True
)
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_for_snapshot=True,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]

View File

@@ -53,8 +53,7 @@ class UserGroup(BaseModel):
id=cc_pair_relationship.cc_pair.id,
name=cc_pair_relationship.cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_relationship.cc_pair.connector,
credential_ids=[cc_pair_relationship.cc_pair.credential_id],
cc_pair_relationship.cc_pair.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential

View File

@@ -1,4 +1,5 @@
import json
import os
import random
import secrets
import string
@@ -121,7 +122,6 @@ from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.store import load_settings
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
@@ -138,18 +138,28 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
REGISTER_INVITE_ONLY_CODE = "REGISTER_INVITE_ONLY"
def is_user_admin(user: User) -> bool:
    """Report whether the given user holds the ADMIN role."""
    role = user.role
    return role == UserRole.ADMIN
def verify_auth_setting() -> None:
if AUTH_TYPE == AuthType.CLOUD:
"""Log warnings for AUTH_TYPE issues.
This only runs on app startup not during migrations/scripts.
"""
raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
if raw_auth_type == "cloud":
raise ValueError(
f"{AUTH_TYPE.value} is not a valid auth type for self-hosted deployments."
"'cloud' is not a valid auth type for self-hosted deployments."
)
if raw_auth_type == "disabled":
logger.warning(
"AUTH_TYPE='disabled' is no longer supported. "
"Using 'basic' instead. Please update your configuration."
)
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
@@ -211,34 +221,22 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
return int(value.decode("utf-8")) == 1
def workspace_invite_only_enabled() -> bool:
settings = load_settings()
return settings.invite_only_enabled
def verify_email_is_invited(email: str) -> None:
if AUTH_TYPE in {AuthType.SAML, AuthType.OIDC}:
# SSO providers manage membership; allow JIT provisioning regardless of invites
return
if not workspace_invite_only_enabled():
whitelist = get_invited_users()
if not whitelist:
return
whitelist = get_invited_users()
if not email:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email must be specified"},
)
raise PermissionError("Email must be specified")
try:
email_info = validate_email(email, check_deliverability=False)
except EmailUndeliverableError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email is not valid"},
)
raise PermissionError("Email is not valid")
for email_whitelist in whitelist:
try:
@@ -255,13 +253,7 @@ def verify_email_is_invited(email: str) -> None:
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
return
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"code": REGISTER_INVITE_ONLY_CODE,
"reason": "This workspace is invite-only. Please ask your admin to invite you.",
},
)
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
@@ -1671,10 +1663,7 @@ def get_oauth_router(
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
# Use WEB_DOMAIN instead of request.url_for() to prevent host
# header poisoning — request.url_for() trusts the Host header.
callback_path = request.app.url_path_for(callback_route_name)
authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")

View File

@@ -0,0 +1,10 @@
"""Celery tasks for hierarchy fetching."""
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
check_for_hierarchy_fetching,
)
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
connector_hierarchy_fetching_task,
)
__all__ = ["check_for_hierarchy_fetching", "connector_hierarchy_fetching_task"]

View File

@@ -41,14 +41,3 @@ assert (
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S = 30 # 30 seconds.
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15
# WARNING: Do not change these values without knowing what changes also need to
# be made to OpenSearchTenantMigrationRecord.
GET_VESPA_CHUNKS_PAGE_SIZE = 500
GET_VESPA_CHUNKS_SLICE_COUNT = 4
# String used to indicate in the vespa_visit_continuation_token mapping that the
# slice has finished and there is nothing left to visit.
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN = (
"FINISHED_VISITING_SLICE_CONTINUATION_TOKEN"
)

View File

@@ -8,12 +8,6 @@ from celery import Task
from redis.lock import Lock as RedisLock
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.opensearch_migration.constants import (
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
GET_VESPA_CHUNKS_PAGE_SIZE,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
@@ -53,13 +47,7 @@ from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
def is_continuation_token_done_for_all_slices(
    continuation_token_map: dict[int, str | None],
) -> bool:
    """Return True once every slice's token is the FINISHED sentinel."""
    for token in continuation_token_map.values():
        if token != FINISHED_VISITING_SLICE_CONTINUATION_TOKEN:
            return False
    return True
GET_VESPA_CHUNKS_PAGE_SIZE = 1000
# shared_task allows this task to be shared across celery app instances.
@@ -88,15 +76,11 @@ def migrate_chunks_from_vespa_to_opensearch_task(
Uses Vespa's Visit API to iterate through ALL chunks in bulk (not
per-document), transform them, and index them into OpenSearch. Progress is
tracked via a continuation token map stored in the
tracked via a continuation token stored in the
OpenSearchTenantMigrationRecord.
The first time we see no continuation token map and non-zero chunks
migrated, we consider the migration complete and all subsequent invocations
are no-ops.
We divide the index into GET_VESPA_CHUNKS_SLICE_COUNT independent slices
where progress is tracked for each slice.
The first time we see no continuation token and non-zero chunks migrated, we
consider the migration complete and all subsequent invocations are no-ops.
Returns:
None if OpenSearch migration is not enabled, or if the lock could not be
@@ -169,28 +153,15 @@ def migrate_chunks_from_vespa_to_opensearch_task(
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
)
approx_chunk_count_in_vespa: int | None = None
get_chunk_count_start_time = time.monotonic()
try:
approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count()
except Exception:
task_logger.exception(
"Error getting approximate chunk count in Vespa. Moving on..."
)
task_logger.debug(
f"Took {time.monotonic() - get_chunk_count_start_time:.3f} seconds to attempt to get "
f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}."
)
while (
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
and lock.owned()
):
(
continuation_token_map,
continuation_token,
total_chunks_migrated,
) = get_vespa_visit_state(db_session)
if is_continuation_token_done_for_all_slices(continuation_token_map):
if continuation_token is None and total_chunks_migrated > 0:
task_logger.info(
f"OpenSearch migration COMPLETED for tenant {tenant_id}. "
f"Total chunks migrated: {total_chunks_migrated}."
@@ -199,19 +170,19 @@ def migrate_chunks_from_vespa_to_opensearch_task(
break
task_logger.debug(
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
f"Continuation token map: {continuation_token_map}"
f"Continuation token: {continuation_token}"
)
get_vespa_chunks_start_time = time.monotonic()
raw_vespa_chunks, next_continuation_token_map = (
raw_vespa_chunks, next_continuation_token = (
vespa_document_index.get_all_raw_document_chunks_paginated(
continuation_token_map=continuation_token_map,
continuation_token=continuation_token,
page_size=GET_VESPA_CHUNKS_PAGE_SIZE,
)
)
task_logger.debug(
f"Read {len(raw_vespa_chunks)} chunks from Vespa in {time.monotonic() - get_vespa_chunks_start_time:.3f} "
f"seconds. Next continuation token map: {next_continuation_token_map}"
f"seconds. Next continuation token: {next_continuation_token}"
)
opensearch_document_chunks, errored_chunks = (
@@ -241,11 +212,14 @@ def migrate_chunks_from_vespa_to_opensearch_task(
total_chunks_errored_this_task += len(errored_chunks)
update_vespa_visit_progress_with_commit(
db_session,
continuation_token_map=next_continuation_token_map,
continuation_token=next_continuation_token,
chunks_processed=len(opensearch_document_chunks),
chunks_errored=len(errored_chunks),
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
)
if next_continuation_token is None and len(raw_vespa_chunks) == 0:
task_logger.info("Vespa reported no more chunks to migrate.")
break
except Exception:
traceback.print_exc()
task_logger.exception("Error in the OpenSearch migration task.")

View File

@@ -37,35 +37,6 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
DOCUMENT_ID,
CHUNK_ID,
TITLE,
TITLE_EMBEDDING,
CONTENT,
EMBEDDINGS,
SOURCE_TYPE,
METADATA_LIST,
DOC_UPDATED_AT,
HIDDEN,
BOOST,
SEMANTIC_IDENTIFIER,
IMAGE_FILE_NAME,
SOURCE_LINKS,
BLURB,
DOC_SUMMARY,
CHUNK_CONTEXT,
METADATA_SUFFIX,
DOCUMENT_SETS,
USER_PROJECT,
PRIMARY_OWNERS,
SECONDARY_OWNERS,
ACCESS_CONTROL_LIST,
]
if MULTI_TENANT:
FIELDS_NEEDED_FOR_TRANSFORMATION.append(TENANT_ID)
def _extract_content_vector(embeddings: Any) -> list[float]:
"""Extracts the full chunk embedding vector from Vespa's embeddings tensor.

View File

@@ -0,0 +1,8 @@
"""Celery tasks for connector pruning."""
from onyx.background.celery.tasks.pruning.tasks import check_for_pruning # noqa: F401
from onyx.background.celery.tasks.pruning.tasks import ( # noqa: F401
connector_pruning_generator_task,
)
__all__ = ["check_for_pruning", "connector_pruning_generator_task"]

View File

@@ -13,7 +13,6 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import DISABLE_VECTOR_DB
@@ -22,14 +21,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
@@ -60,17 +57,6 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@@ -134,24 +120,7 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Three mechanisms prevent queue runaway:
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
Uses direct Redis locks to avoid overlapping runs.
"""
task_logger.info("check_user_file_processing - Starting")
@@ -166,21 +135,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -193,35 +148,12 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
finally:
@@ -229,8 +161,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None
@@ -373,12 +304,6 @@ def process_single_user_file(
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,

View File

@@ -1,4 +1,3 @@
import json
import re
from collections.abc import Callable
from typing import cast
@@ -46,7 +45,6 @@ from onyx.utils.timing import log_function_time
logger = setup_logger()
IMAGE_GENERATION_TOOL_NAME = "generate_image"
def create_chat_session_from_request(
@@ -424,40 +422,6 @@ def convert_chat_history_basic(
return list(reversed(trimmed_reversed))
def _build_tool_call_response_history_message(
    tool_name: str,
    generated_images: list[dict] | None,
    tool_call_response: str | None,
) -> str:
    """Build the history text for a past tool-call response.

    Image-generation calls get a compact JSON list of the generated images
    (file_id + revised_prompt), falling back to the raw tool response when no
    usable image entries exist. Every other tool uses the generic
    cross-message placeholder.
    """
    if tool_name != IMAGE_GENERATION_TOOL_NAME:
        return TOOL_CALL_RESPONSE_CROSS_MESSAGE

    # Keep only entries with a string file_id; coerce a missing/non-string
    # revised_prompt to "".
    image_context = [
        {
            "file_id": image["file_id"],
            "revised_prompt": (
                image["revised_prompt"]
                if isinstance(image.get("revised_prompt"), str)
                else ""
            ),
        }
        for image in (generated_images or [])
        if isinstance(image.get("file_id"), str)
    ]
    if image_context:
        return json.dumps(image_context)

    return tool_call_response or TOOL_CALL_RESPONSE_CROSS_MESSAGE
def convert_chat_history(
chat_history: list[ChatMessage],
files: list[ChatLoadedFile],
@@ -618,24 +582,10 @@ def convert_chat_history(
# Add TOOL_CALL_RESPONSE messages for each tool call in this turn
for tool_call in turn_tool_calls:
tool_name = tool_id_to_name_map.get(
tool_call.tool_id, "unknown"
)
tool_response_message = (
_build_tool_call_response_history_message(
tool_name=tool_name,
generated_images=tool_call.generated_images,
tool_call_response=tool_call.tool_call_response,
)
)
simple_messages.append(
ChatMessageSimple(
message=tool_response_message,
token_count=(
token_counter(tool_response_message)
if tool_name == IMAGE_GENERATION_TOOL_NAME
else 20
),
message=TOOL_CALL_RESPONSE_CROSS_MESSAGE,
token_count=20, # Tiny overestimate
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,

View File

@@ -57,7 +57,6 @@ from onyx.tools.tool_implementations.images.models import (
FinalImageGenerationResponse,
)
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
@@ -482,42 +481,7 @@ def construct_message_history(
if reminder_message:
result.append(reminder_message)
return _drop_orphaned_tool_call_responses(result)
def _drop_orphaned_tool_call_responses(
    messages: list[ChatMessageSimple],
) -> list[ChatMessageSimple]:
    """Drop tool response messages whose tool_call_id is not in prior assistant tool calls.

    History truncation can remove an ASSISTANT tool-call message while leaving a
    later TOOL_CALL_RESPONSE in context. Some providers (e.g. Ollama) reject such
    history with an "unexpected tool call id" error, so orphaned responses are
    filtered out here.
    """
    seen_tool_call_ids: set[str] = set()
    kept: list[ChatMessageSimple] = []
    for message in messages:
        if message.message_type == MessageType.ASSISTANT and message.tool_calls:
            # Record the ids this assistant turn introduced; later responses
            # must match one of them to survive.
            seen_tool_call_ids.update(
                tool_call.tool_call_id for tool_call in message.tool_calls
            )
        elif message.message_type == MessageType.TOOL_CALL_RESPONSE:
            if not (
                message.tool_call_id
                and message.tool_call_id in seen_tool_call_ids
            ):
                logger.debug(
                    "Dropping orphaned tool response with tool_call_id=%s while "
                    "constructing message history",
                    message.tool_call_id,
                )
                continue
        kept.append(message)
    return kept
return result
def _create_file_tool_metadata_message(
@@ -652,7 +616,6 @@ def run_llm_loop(
ran_image_gen: bool = False
just_ran_web_search: bool = False
has_called_search_tool: bool = False
code_interpreter_file_generated: bool = False
fallback_extraction_attempted: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
@@ -763,7 +726,6 @@ def run_llm_loop(
),
include_citation_reminder=should_cite_documents
or always_cite_documents,
include_file_reminder=code_interpreter_file_generated,
is_last_cycle=out_of_cycles,
)
@@ -903,18 +865,6 @@ def run_llm_loop(
if tool_call.tool_name == SearchTool.NAME:
has_called_search_tool = True
# Track if code interpreter generated files with download links
if (
tool_call.tool_name == PythonTool.NAME
and not code_interpreter_file_generated
):
try:
parsed = json.loads(tool_response.llm_facing_response)
if parsed.get("generated_files"):
code_interpreter_file_generated = True
except (json.JSONDecodeError, AttributeError):
pass
# Build a mapping of tool names to tool objects for getting tool_id
tools_by_name = {tool.name: tool for tool in final_tools}

View File

@@ -20,7 +20,6 @@ from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import ChatFileType
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LanguageModelInput
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
@@ -561,23 +560,6 @@ def _parse_xml_parameter_value(raw_value: str, string_attr: str | None) -> Any:
return value
def _resolve_tool_arguments(obj: dict[str, Any]) -> dict[str, Any] | None:
"""Extract and parse an arguments/parameters value from a tool-call-like object.
Looks for "arguments" or "parameters" keys, handles JSON-string values,
and returns a dict if successful, or None otherwise.
"""
arguments = obj.get("arguments", obj.get("parameters", {}))
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
if isinstance(arguments, dict):
return arguments
return None
def _try_match_json_to_tool(
json_obj: dict[str, Any],
tool_name_to_def: dict[str, dict],
@@ -600,8 +582,13 @@ def _try_match_json_to_tool(
# Format 1: Direct tool call format {"name": "...", "arguments": {...}}
if "name" in json_obj and json_obj["name"] in tool_name_to_def:
tool_name = json_obj["name"]
arguments = _resolve_tool_arguments(json_obj)
if arguments is not None:
arguments = json_obj.get("arguments", json_obj.get("parameters", {}))
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
if isinstance(arguments, dict):
return (tool_name, arguments)
# Format 2: Function call format {"function": {"name": "...", "arguments": {...}}}
@@ -609,8 +596,13 @@ def _try_match_json_to_tool(
func_obj = json_obj["function"]
if "name" in func_obj and func_obj["name"] in tool_name_to_def:
tool_name = func_obj["name"]
arguments = _resolve_tool_arguments(func_obj)
if arguments is not None:
arguments = func_obj.get("arguments", func_obj.get("parameters", {}))
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
if isinstance(arguments, dict):
return (tool_name, arguments)
# Format 3: Tool name as key {"tool_name": {...arguments...}}
@@ -677,107 +669,6 @@ def _extract_nested_arguments_obj(
return None
def _build_structured_assistant_message(msg: ChatMessageSimple) -> AssistantMessage:
    """Convert a history message into a structured assistant chat message.

    Tool calls, when present, are serialized into the provider ToolCall
    format with JSON-encoded arguments; empty content becomes None.
    """
    structured_calls: list[ToolCall] | None = None
    if msg.tool_calls:
        structured_calls = []
        for tc in msg.tool_calls:
            structured_calls.append(
                ToolCall(
                    id=tc.tool_call_id,
                    type="function",
                    function=FunctionCall(
                        name=tc.tool_name,
                        arguments=json.dumps(tc.tool_arguments),
                    ),
                )
            )
    return AssistantMessage(
        role="assistant",
        content=msg.message or None,
        tool_calls=structured_calls,
    )
def _build_structured_tool_response_message(msg: ChatMessageSimple) -> ToolMessage:
    """Convert a tool-response history message into a provider "tool" message.

    Raises:
        ValueError: if the message carries no tool_call_id, since a tool
            message cannot be correlated to its originating call without one.
    """
    if not msg.tool_call_id:
        raise ValueError(
            "Tool call response message encountered but tool_call_id is not available. "
            f"Message: {msg}"
        )
    return ToolMessage(
        role="tool",
        tool_call_id=msg.tool_call_id,
        content=msg.message,
    )
class _HistoryMessageFormatter:
    """Strategy interface for rendering chat history messages per provider.

    Subclasses decide how assistant tool-call turns and tool responses are
    represented (structured tool messages vs. flattened plain text).
    """

    def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage:
        raise NotImplementedError

    def format_tool_response_message(
        self, msg: ChatMessageSimple
    ) -> ToolMessage | UserMessage:
        raise NotImplementedError
class _DefaultHistoryMessageFormatter(_HistoryMessageFormatter):
    """Default strategy: emit structured assistant/tool messages unchanged."""

    def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage:
        return _build_structured_assistant_message(msg)

    def format_tool_response_message(self, msg: ChatMessageSimple) -> ToolMessage:
        return _build_structured_tool_response_message(msg)
class _OllamaHistoryMessageFormatter(_HistoryMessageFormatter):
    """Renders tool calls and tool results as plain text (used for Ollama).

    Assistant tool calls are serialized into the assistant message content,
    and tool responses are replayed as user messages, instead of using the
    structured tool-call message format.
    """

    def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage:
        # Without tool calls there is nothing to flatten.
        if not msg.tool_calls:
            return _build_structured_assistant_message(msg)

        parts: list[str] = [msg.message] if msg.message else []
        for tc in msg.tool_calls:
            parts.append(
                f"[Tool Call] name={tc.tool_name} "
                f"id={tc.tool_call_id} args={json.dumps(tc.tool_arguments)}"
            )
        return AssistantMessage(
            role="assistant",
            content="\n".join(parts),
            tool_calls=None,
        )

    def format_tool_response_message(self, msg: ChatMessageSimple) -> UserMessage:
        if not msg.tool_call_id:
            raise ValueError(
                "Tool call response message encountered but tool_call_id is not available. "
                f"Message: {msg}"
            )
        return UserMessage(
            role="user",
            content=f"[Tool Result] id={msg.tool_call_id}\n{msg.message}",
        )
# Module-level singletons: formatters are stateless, so one shared instance
# of each is sufficient.
_DEFAULT_HISTORY_MESSAGE_FORMATTER = _DefaultHistoryMessageFormatter()
_OLLAMA_HISTORY_MESSAGE_FORMATTER = _OllamaHistoryMessageFormatter()


def _get_history_message_formatter(llm_config: LLMConfig) -> _HistoryMessageFormatter:
    """Pick the history formatter for the configured provider.

    Ollama gets the plain-text flattening formatter; everyone else gets the
    structured default.
    """
    is_ollama = llm_config.model_provider == LlmProviderNames.OLLAMA_CHAT
    return (
        _OLLAMA_HISTORY_MESSAGE_FORMATTER
        if is_ollama
        else _DEFAULT_HISTORY_MESSAGE_FORMATTER
    )
def translate_history_to_llm_format(
history: list[ChatMessageSimple],
llm_config: LLMConfig,
@@ -788,10 +679,6 @@ def translate_history_to_llm_format(
handling different message types and image files for multimodal support.
"""
messages: list[ChatCompletionMessage] = []
history_message_formatter = _get_history_message_formatter(llm_config)
# Note: cacheability is computed from pre-translation ChatMessageSimple types.
# Some providers flatten tool history into plain assistant/user text, so this split
# may be less semantically meaningful, but it remains safe and order-preserving.
last_cacheable_msg_idx = -1
all_previous_msgs_cacheable = True
@@ -873,10 +760,39 @@ def translate_history_to_llm_format(
messages.append(reminder_msg)
elif msg.message_type == MessageType.ASSISTANT:
messages.append(history_message_formatter.format_assistant_message(msg))
tool_calls_list: list[ToolCall] | None = None
if msg.tool_calls:
tool_calls_list = [
ToolCall(
id=tc.tool_call_id,
type="function",
function=FunctionCall(
name=tc.tool_name,
arguments=json.dumps(tc.tool_arguments),
),
)
for tc in msg.tool_calls
]
assistant_msg = AssistantMessage(
role="assistant",
content=msg.message or None,
tool_calls=tool_calls_list,
)
messages.append(assistant_msg)
elif msg.message_type == MessageType.TOOL_CALL_RESPONSE:
messages.append(history_message_formatter.format_tool_response_message(msg))
if not msg.tool_call_id:
raise ValueError(
f"Tool call response message encountered but tool_call_id is not available. Message: {msg}"
)
tool_msg = ToolMessage(
role="tool",
content=msg.message,
tool_call_id=msg.tool_call_id,
)
messages.append(tool_msg)
else:
logger.warning(
@@ -1002,15 +918,8 @@ def run_llm_step_pkt_generator(
tab_index = placement.tab_index
sub_turn_index = placement.sub_turn_index
def _current_placement() -> Placement:
return Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
)
llm_msg_history = translate_history_to_llm_format(history, llm.config)
has_reasoned = False
has_reasoned = 0
if LOG_ONYX_MODEL_INTERACTIONS:
logger.debug(
@@ -1040,56 +949,12 @@ def run_llm_step_pkt_generator(
stream_start_time = time.monotonic()
first_action_recorded = False
def _emit_citation_results(
    results: Generator[str | CitationInfo, None, None],
) -> Generator[Packet, None, None]:
    """Yield packets for citation processor results (str or CitationInfo)."""
    # NOTE: closure over the enclosing generator's streaming state; mutates
    # accumulated_answer as text arrives.
    nonlocal accumulated_answer
    for result in results:
        if isinstance(result, str):
            # Plain answer text: accumulate for persistence, then stream it
            # out as an AgentResponseDelta packet.
            accumulated_answer += result
            if state_container:
                state_container.set_answer_tokens(accumulated_answer)
            yield Packet(
                placement=_current_placement(),
                obj=AgentResponseDelta(content=result),
            )
        elif isinstance(result, CitationInfo):
            # Citation marker: emit as its own packet and record it so the
            # saved chat state knows which citations were surfaced.
            yield Packet(
                placement=_current_placement(),
                obj=result,
            )
            if state_container:
                state_container.add_emitted_citation(result.citation_number)
def _close_reasoning_if_active() -> Generator[Packet, None, None]:
    """Emit ReasoningDone and increment turns if reasoning is in progress."""
    # NOTE: closure over the enclosing generator's placement/reasoning state;
    # no-op when no reasoning stream is currently open.
    nonlocal reasoning_start
    nonlocal has_reasoned
    nonlocal turn_index
    nonlocal sub_turn_index
    if reasoning_start:
        yield Packet(
            placement=Placement(
                turn_index=turn_index,
                tab_index=tab_index,
                sub_turn_index=sub_turn_index,
            ),
            obj=ReasoningDone(),
        )
        has_reasoned = True
        # Reasoning occupies its own turn slot; advance placement for the
        # content that follows.
        turn_index, sub_turn_index = _increment_turns(
            turn_index, sub_turn_index
        )
        reasoning_start = False
def _emit_content_chunk(content_chunk: str) -> Generator[Packet, None, None]:
nonlocal accumulated_answer
nonlocal accumulated_reasoning
nonlocal answer_start
nonlocal reasoning_start
nonlocal has_reasoned
nonlocal turn_index
nonlocal sub_turn_index
@@ -1102,18 +967,39 @@ def run_llm_step_pkt_generator(
state_container.set_reasoning_tokens(accumulated_reasoning)
if not reasoning_start:
yield Packet(
placement=_current_placement(),
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningStart(),
)
yield Packet(
placement=_current_placement(),
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDelta(reasoning=content_chunk),
)
reasoning_start = True
return
# Normal flow for AUTO or NONE tool choice
yield from _close_reasoning_if_active()
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(
turn_index, sub_turn_index
)
reasoning_start = False
if not answer_start:
# Store pre-answer processing time in state container for save_chat
@@ -1123,7 +1009,11 @@ def run_llm_step_pkt_generator(
)
yield Packet(
placement=_current_placement(),
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseStart(
final_documents=final_documents,
pre_answer_processing_seconds=pre_answer_processing_time,
@@ -1132,16 +1022,43 @@ def run_llm_step_pkt_generator(
answer_start = True
if citation_processor:
yield from _emit_citation_results(
citation_processor.process_token(content_chunk)
)
for result in citation_processor.process_token(content_chunk):
if isinstance(result, str):
accumulated_answer += result
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=result,
)
# Track emitted citation for saving
if state_container:
state_container.add_emitted_citation(result.citation_number)
else:
accumulated_answer += content_chunk
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=_current_placement(),
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseDelta(content=content_chunk),
)
@@ -1203,11 +1120,19 @@ def run_llm_step_pkt_generator(
state_container.set_reasoning_tokens(accumulated_reasoning)
if not reasoning_start:
yield Packet(
placement=_current_placement(),
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningStart(),
)
yield Packet(
placement=_current_placement(),
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDelta(reasoning=delta.reasoning_content),
)
reasoning_start = True
@@ -1221,7 +1146,20 @@ def run_llm_step_pkt_generator(
yield from _emit_content_chunk(filtered_content)
if delta.tool_calls:
yield from _close_reasoning_if_active()
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(
turn_index, sub_turn_index
)
reasoning_start = False
for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
@@ -1286,14 +1224,50 @@ def run_llm_step_pkt_generator(
# This may happen if the custom token processor is used to modify other packets into reasoning
# Then there won't necessarily be anything else to come after the reasoning tokens
yield from _close_reasoning_if_active()
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(turn_index, sub_turn_index)
reasoning_start = False
# Flush any remaining content from citation processor
# Reasoning is always first so this should use the post-incremented value of turn_index
# Note that this doesn't need to handle any sub-turns as those docs will not have citations
# as clickable items and will be stripped out instead.
if citation_processor:
yield from _emit_citation_results(citation_processor.process_token(None))
for result in citation_processor.process_token(None):
if isinstance(result, str):
accumulated_answer += result
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=result,
)
# Track emitted citation for saving
if state_container:
state_container.add_emitted_citation(result.citation_number)
# Note: Content (AgentResponseDelta) doesn't need an explicit end packet - OverallStop handles it
# Tool calls are handled by tool execution code and emit their own packets (e.g., SectionEnd)
@@ -1317,7 +1291,7 @@ def run_llm_step_pkt_generator(
tool_calls=tool_calls if tool_calls else None,
raw_answer=accumulated_raw_answer if accumulated_raw_answer else None,
),
has_reasoned,
bool(has_reasoned),
)
@@ -1372,4 +1346,4 @@ def run_llm_step(
emitter.emit(packet)
except StopIteration as e:
llm_step_result, has_reasoned = e.value
return llm_step_result, has_reasoned
return llm_step_result, bool(has_reasoned)

View File

@@ -10,7 +10,6 @@ from onyx.db.user_file import calculate_user_files_token_count
from onyx.file_store.models import FileDescriptor
from onyx.prompts.chat_prompts import CITATION_REMINDER
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
from onyx.prompts.chat_prompts import FILE_REMINDER
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
from onyx.prompts.prompt_utils import get_company_context
@@ -126,7 +125,6 @@ def calculate_reserved_tokens(
def build_reminder_message(
reminder_text: str | None,
include_citation_reminder: bool,
include_file_reminder: bool,
is_last_cycle: bool,
) -> str | None:
reminder = reminder_text.strip() if reminder_text else ""
@@ -134,8 +132,6 @@ def build_reminder_message(
reminder += "\n\n" + LAST_CYCLE_CITATION_REMINDER
if include_citation_reminder:
reminder += "\n\n" + CITATION_REMINDER
if include_file_reminder:
reminder += "\n\n" + FILE_REMINDER
reminder = reminder.strip()
return reminder if reminder else None
@@ -190,7 +186,7 @@ def _build_user_information_section(
if not sections:
return ""
return USER_INFORMATION_HEADER + "\n".join(sections)
return USER_INFORMATION_HEADER + "".join(sections)
def build_system_prompt(
@@ -228,21 +224,23 @@ def build_system_prompt(
system_prompt += REQUIRE_CITATION_GUIDANCE
if include_all_guidance:
tool_sections = [
TOOL_DESCRIPTION_SEARCH_GUIDANCE,
INTERNAL_SEARCH_GUIDANCE,
WEB_SEARCH_GUIDANCE.format(
system_prompt += (
TOOL_SECTION_HEADER
+ TOOL_DESCRIPTION_SEARCH_GUIDANCE
+ INTERNAL_SEARCH_GUIDANCE
+ WEB_SEARCH_GUIDANCE.format(
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
),
OPEN_URLS_GUIDANCE,
PYTHON_TOOL_GUIDANCE,
GENERATE_IMAGE_GUIDANCE,
MEMORY_GUIDANCE,
]
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_sections)
)
+ OPEN_URLS_GUIDANCE
+ PYTHON_TOOL_GUIDANCE
+ GENERATE_IMAGE_GUIDANCE
+ MEMORY_GUIDANCE
)
return system_prompt
if tools:
system_prompt += TOOL_SECTION_HEADER
has_web_search = any(isinstance(tool, WebSearchTool) for tool in tools)
has_internal_search = any(isinstance(tool, SearchTool) for tool in tools)
has_open_urls = any(isinstance(tool, OpenURLTool) for tool in tools)
@@ -252,14 +250,12 @@ def build_system_prompt(
)
has_memory = any(isinstance(tool, MemoryTool) for tool in tools)
tool_guidance_sections: list[str] = []
if has_web_search or has_internal_search or include_all_guidance:
tool_guidance_sections.append(TOOL_DESCRIPTION_SEARCH_GUIDANCE)
system_prompt += TOOL_DESCRIPTION_SEARCH_GUIDANCE
# These are not included at the Tool level because the ordering may matter.
if has_internal_search or include_all_guidance:
tool_guidance_sections.append(INTERNAL_SEARCH_GUIDANCE)
system_prompt += INTERNAL_SEARCH_GUIDANCE
if has_web_search or include_all_guidance:
site_disabled_guidance = ""
@@ -269,23 +265,20 @@ def build_system_prompt(
)
if web_search_tool and not web_search_tool.supports_site_filter:
site_disabled_guidance = WEB_SEARCH_SITE_DISABLED_GUIDANCE
tool_guidance_sections.append(
WEB_SEARCH_GUIDANCE.format(site_colon_disabled=site_disabled_guidance)
system_prompt += WEB_SEARCH_GUIDANCE.format(
site_colon_disabled=site_disabled_guidance
)
if has_open_urls or include_all_guidance:
tool_guidance_sections.append(OPEN_URLS_GUIDANCE)
system_prompt += OPEN_URLS_GUIDANCE
if has_python or include_all_guidance:
tool_guidance_sections.append(PYTHON_TOOL_GUIDANCE)
system_prompt += PYTHON_TOOL_GUIDANCE
if has_generate_image or include_all_guidance:
tool_guidance_sections.append(GENERATE_IMAGE_GUIDANCE)
system_prompt += GENERATE_IMAGE_GUIDANCE
if has_memory or include_all_guidance:
tool_guidance_sections.append(MEMORY_GUIDANCE)
if tool_guidance_sections:
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_guidance_sections)
system_prompt += MEMORY_GUIDANCE
return system_prompt

View File

@@ -85,19 +85,12 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
#####
# Auth Configs
#####
# Upgrades users from disabled auth to basic auth and shows warning.
_auth_type_str = (os.environ.get("AUTH_TYPE") or "basic").lower()
if _auth_type_str == "disabled":
logger.warning(
"AUTH_TYPE='disabled' is no longer supported. "
"Defaulting to 'basic'. Please update your configuration. "
"Your existing data will be migrated automatically."
)
_auth_type_str = AuthType.BASIC.value
try:
# Silently default to basic - warnings/errors logged in verify_auth_setting()
# which only runs on app startup, not during migrations/scripts
_auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower()
if _auth_type_str in [auth_type.value for auth_type in AuthType]:
AUTH_TYPE = AuthType(_auth_type_str)
except ValueError:
logger.error(f"Invalid AUTH_TYPE: {_auth_type_str}. Defaulting to 'basic'.")
else:
AUTH_TYPE = AuthType.BASIC
PASSWORD_MIN_LENGTH = int(os.getenv("PASSWORD_MIN_LENGTH", 8))
@@ -251,9 +244,7 @@ DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S = int(
os.environ.get("DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S") or 50
)
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
OPENSEARCH_ADMIN_PASSWORD = os.environ.get(
"OPENSEARCH_ADMIN_PASSWORD", "StrongPassword123!"
)
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
USING_AWS_MANAGED_OPENSEARCH = (
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
)
@@ -265,18 +256,6 @@ OPENSEARCH_PROFILING_DISABLED = (
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
)
# When enabled, OpenSearch returns detailed score breakdowns for each hit.
# Useful for debugging and tuning search relevance. Has ~10-30% performance overhead according to documentation.
# Seems for Hybrid Search in practice, the impact is actually more like 1000x slower.
OPENSEARCH_EXPLAIN_ENABLED = (
os.environ.get("OPENSEARCH_EXPLAIN_ENABLED", "").lower() == "true"
)
# Analyzer used for full-text fields (title, content). Use OpenSearch built-in analyzer
# names (e.g. "english", "standard", "german"). Affects stemming and tokenization;
# existing indices need reindexing after a change.
OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "english"
# This is the "base" config for now, the idea is that at least for our dev
# environments we always want to be dual indexing into both OpenSearch and Vespa
# to stress test the new codepaths. Only enable this if there is some instance
@@ -284,9 +263,6 @@ OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "englis
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
)
# NOTE: This effectively does nothing anymore, admins can now toggle whether
# retrieval is through OpenSearch. This value is only used as a final fallback
# in case that doesn't work for whatever reason.
# Given that the "base" config above is true, this enables whether we want to
# retrieve from OpenSearch or Vespa. We want to be able to quickly toggle this
# in the event we see issues with OpenSearch retrieval in our dev environments.
@@ -642,14 +618,6 @@ SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
)
# When True, group sync enumerates every Azure AD group in the tenant (expensive).
# When False (default), only groups found in site role assignments are synced.
# Can be overridden per-connector via the "exhaustive_ad_enumeration" key in
# connector_specific_config.
SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION = (
os.environ.get("SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION", "").lower() == "true"
)
BLOB_STORAGE_SIZE_THRESHOLD = int(
os.environ.get("BLOB_STORAGE_SIZE_THRESHOLD", 20 * 1024 * 1024)
)

View File

@@ -157,17 +157,6 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
@@ -454,9 +443,6 @@ class OnyxRedisLocks:
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a task is already queued.
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"

File diff suppressed because it is too large Load Diff

View File

@@ -50,15 +50,12 @@ class TeamsCheckpoint(ConnectorCheckpoint):
todo_team_ids: list[str] | None = None
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
class TeamsConnector(
CheckpointedConnectorWithPermSync[TeamsCheckpoint],
SlimConnectorWithPermSync,
):
MAX_WORKERS = 10
AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"
def __init__(
self,
@@ -66,15 +63,11 @@ class TeamsConnector(
# are not necessarily guaranteed to be unique
teams: list[str] = [],
max_workers: int = MAX_WORKERS,
authority_host: str = DEFAULT_AUTHORITY_HOST,
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
) -> None:
self.graph_client: GraphClient | None = None
self.msal_app: msal.ConfidentialClientApplication | None = None
self.max_workers = max_workers
self.requested_team_list: list[str] = teams
self.authority_host = authority_host.rstrip("/")
self.graph_api_host = graph_api_host.rstrip("/")
# impls for BaseConnector
@@ -83,7 +76,7 @@ class TeamsConnector(
teams_client_secret = credentials["teams_client_secret"]
teams_directory_id = credentials["teams_directory_id"]
authority_url = f"{self.authority_host}/{teams_directory_id}"
authority_url = f"{TeamsConnector.AUTHORITY_URL_PREFIX}{teams_directory_id}"
self.msal_app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=teams_client_id,
@@ -98,7 +91,7 @@ class TeamsConnector(
raise RuntimeError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
scopes=[f"{self.graph_api_host}/.default"]
scopes=["https://graph.microsoft.com/.default"]
)
if not isinstance(token, dict):

View File

@@ -32,7 +32,6 @@ from onyx.context.search.federated.slack_search_utils import should_include_mess
from onyx.context.search.models import ChunkIndexRequest
from onyx.context.search.models import InferenceChunk
from onyx.db.document import DocumentSource
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.document_index_utils import (
get_multipass_config,
@@ -906,15 +905,13 @@ def convert_slack_score(slack_score: float) -> float:
def slack_retrieval(
query: ChunkIndexRequest,
access_token: str,
db_session: Session | None = None,
db_session: Session,
connector: FederatedConnectorDetail | None = None, # noqa: ARG001
entities: dict[str, Any] | None = None,
limit: int | None = None,
slack_event_context: SlackContext | None = None,
bot_token: str | None = None, # Add bot token parameter
team_id: str | None = None,
# Pre-fetched data — when provided, avoids DB query (no session needed)
search_settings: SearchSettings | None = None,
) -> list[InferenceChunk]:
"""
Main entry point for Slack federated search with entity filtering.
@@ -928,7 +925,7 @@ def slack_retrieval(
Args:
query: Search query object
access_token: User OAuth access token
db_session: Database session (optional if search_settings provided)
db_session: Database session
connector: Federated connector detail (unused, kept for backwards compat)
entities: Connector-level config (entity filtering configuration)
limit: Maximum number of results
@@ -1156,10 +1153,7 @@ def slack_retrieval(
# chunk index docs into doc aware chunks
# a single index doc can get split into multiple chunks
if search_settings is None:
if db_session is None:
raise ValueError("Either db_session or search_settings must be provided")
search_settings = get_current_search_settings(db_session)
search_settings = get_current_search_settings(db_session)
embedder = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
)

View File

@@ -18,10 +18,8 @@ from onyx.context.search.utils import inference_section_from_chunks
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentIndex
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.llm.interfaces import LLM
from onyx.natural_language_processing.english_stopwords import strip_stopwords
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.secondary_llm_flows.source_filter import extract_source_filter
from onyx.secondary_llm_flows.time_filter import extract_time_filter
from onyx.utils.logger import setup_logger
@@ -43,7 +41,7 @@ def _build_index_filters(
user_file_ids: list[UUID] | None,
persona_document_sets: list[str] | None,
persona_time_cutoff: datetime | None,
db_session: Session | None = None,
db_session: Session,
auto_detect_filters: bool = False,
query: str | None = None,
llm: LLM | None = None,
@@ -51,8 +49,6 @@ def _build_index_filters(
# Assistant knowledge filters
attached_document_ids: list[str] | None = None,
hierarchy_node_ids: list[int] | None = None,
# Pre-fetched ACL filters (skips DB query when provided)
acl_filters: list[str] | None = None,
) -> IndexFilters:
if auto_detect_filters and (llm is None or query is None):
raise RuntimeError("LLM and query are required for auto detect filters")
@@ -107,14 +103,9 @@ def _build_index_filters(
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
logger.debug("Added USER_FILE to source_filter for user knowledge search")
if bypass_acl:
user_acl_filters = None
elif acl_filters is not None:
user_acl_filters = acl_filters
else:
if db_session is None:
raise ValueError("Either db_session or acl_filters must be provided")
user_acl_filters = build_access_filters_for_user(user, db_session)
user_acl_filters = (
None if bypass_acl else build_access_filters_for_user(user, db_session)
)
final_filters = IndexFilters(
user_file_ids=user_file_ids,
@@ -261,15 +252,11 @@ def search_pipeline(
user: User,
# Used for default filters and settings
persona: Persona | None,
db_session: Session | None = None,
db_session: Session,
auto_detect_filters: bool = False,
llm: LLM | None = None,
# If a project ID is provided, it will be exclusively scoped to that project
project_id: int | None = None,
# Pre-fetched data — when provided, avoids DB queries (no session needed)
acl_filters: list[str] | None = None,
embedding_model: EmbeddingModel | None = None,
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
) -> list[InferenceChunk]:
user_uploaded_persona_files: list[UUID] | None = (
[user_file.id for user_file in persona.user_files] if persona else None
@@ -310,7 +297,6 @@ def search_pipeline(
bypass_acl=chunk_search_request.bypass_acl,
attached_document_ids=attached_document_ids,
hierarchy_node_ids=hierarchy_node_ids,
acl_filters=acl_filters,
)
query_keywords = strip_stopwords(chunk_search_request.query)
@@ -329,8 +315,6 @@ def search_pipeline(
user_id=user.id if user else None,
document_index=document_index,
db_session=db_session,
embedding_model=embedding_model,
prefetched_federated_retrieval_infos=prefetched_federated_retrieval_infos,
)
# For some specific connectors like Salesforce, a user that has access to an object doesn't mean

View File

@@ -14,11 +14,9 @@ from onyx.context.search.utils import get_query_embedding
from onyx.context.search.utils import inference_section_from_chunks
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.federated_connectors.federated_retrieval import (
get_federated_retrieval_functions,
)
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -52,14 +50,9 @@ def combine_retrieval_results(
def _embed_and_search(
query_request: ChunkIndexRequest,
document_index: DocumentIndex,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
db_session: Session,
) -> list[InferenceChunk]:
query_embedding = get_query_embedding(
query_request.query,
db_session=db_session,
embedding_model=embedding_model,
)
query_embedding = get_query_embedding(query_request.query, db_session)
hybrid_alpha = query_request.hybrid_alpha or HYBRID_ALPHA
@@ -85,9 +78,7 @@ def search_chunks(
query_request: ChunkIndexRequest,
user_id: UUID | None,
document_index: DocumentIndex,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
db_session: Session,
) -> list[InferenceChunk]:
run_queries: list[tuple[Callable, tuple]] = []
@@ -97,22 +88,14 @@ def search_chunks(
else None
)
# Federated retrieval — use pre-fetched if available, otherwise query DB
if prefetched_federated_retrieval_infos is not None:
federated_retrieval_infos = prefetched_federated_retrieval_infos
else:
if db_session is None:
raise ValueError(
"Either db_session or prefetched_federated_retrieval_infos "
"must be provided"
)
federated_retrieval_infos = get_federated_retrieval_functions(
db_session=db_session,
user_id=user_id,
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
user_file_ids=query_request.filters.user_file_ids,
)
# Federated retrieval
federated_retrieval_infos = get_federated_retrieval_functions(
db_session=db_session,
user_id=user_id,
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
user_file_ids=query_request.filters.user_file_ids,
)
federated_sources = set(
federated_retrieval_info.source.to_non_federated_source()
@@ -131,10 +114,7 @@ def search_chunks(
if normal_search_enabled:
run_queries.append(
(
_embed_and_search,
(query_request, document_index, db_session, embedding_model),
)
(_embed_and_search, (query_request, document_index, db_session))
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)

View File

@@ -64,34 +64,23 @@ def inference_section_from_single_chunk(
)
def get_query_embeddings(
queries: list[str],
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
) -> list[Embedding]:
if embedding_model is None:
if db_session is None:
raise ValueError("Either db_session or embedding_model must be provided")
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
search_settings = get_current_search_settings(db_session)
query_embedding = embedding_model.encode(queries, text_type=EmbedTextType.QUERY)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
return query_embedding
@log_function_time(print_only=True, debug_only=True)
def get_query_embedding(
query: str,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
) -> Embedding:
return get_query_embeddings(
[query], db_session=db_session, embedding_model=embedding_model
)[0]
def get_query_embedding(query: str, db_session: Session) -> Embedding:
return get_query_embeddings([query], db_session)[0]
def convert_inference_sections_to_search_docs(

View File

@@ -4,7 +4,6 @@ from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.api_key import ApiKeyDescriptor
@@ -55,7 +54,6 @@ async def fetch_user_for_api_key(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
.options(selectinload(User.memories))
)

View File

@@ -13,7 +13,6 @@ from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
@@ -98,11 +97,6 @@ async def get_user_count(only_admin_users: bool = False) -> int:
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
async def _get_user(self, statement: Select) -> UP | None:
statement = statement.options(selectinload(User.memories))
results = await self.session.execute(statement)
return results.unique().scalar_one_or_none()
async def create(
self,
create_dict: Dict[str, Any],

View File

@@ -116,15 +116,12 @@ def get_connector_credential_pairs_for_user(
order_by_desc: bool = False,
source: DocumentSource | None = None,
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
defer_connector_config: bool = False,
) -> list[ConnectorCredentialPair]:
"""Get connector credential pairs for a user.
Args:
processing_mode: Filter by processing mode. Defaults to REGULAR to hide
FILE_SYSTEM connectors from standard admin UI. Pass None to get all.
defer_connector_config: If True, skips loading Connector.connector_specific_config
to avoid fetching large JSONB blobs when they aren't needed.
"""
if eager_load_user:
assert (
@@ -133,10 +130,7 @@ def get_connector_credential_pairs_for_user(
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
connector_load = selectinload(ConnectorCredentialPair.connector)
if defer_connector_config:
connector_load = connector_load.defer(Connector.connector_specific_config)
stmt = stmt.options(connector_load)
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
if eager_load_credential:
load_opts = selectinload(ConnectorCredentialPair.credential)
@@ -176,7 +170,6 @@ def get_connector_credential_pairs_for_user_parallel(
order_by_desc: bool = False,
source: DocumentSource | None = None,
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
defer_connector_config: bool = False,
) -> list[ConnectorCredentialPair]:
with get_session_with_current_tenant() as db_session:
return get_connector_credential_pairs_for_user(
@@ -190,7 +183,6 @@ def get_connector_credential_pairs_for_user_parallel(
order_by_desc=order_by_desc,
source=source,
processing_mode=processing_mode,
defer_connector_config=defer_connector_config,
)

View File

@@ -554,19 +554,10 @@ def fetch_all_document_sets_for_user(
stmt = (
select(DocumentSetDBModel)
.distinct()
.options(
selectinload(DocumentSetDBModel.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSetDBModel.users),
selectinload(DocumentSetDBModel.groups),
selectinload(DocumentSetDBModel.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
)
.options(selectinload(DocumentSetDBModel.federated_connectors))
)
stmt = _add_user_filters(stmt, user, get_editable=get_editable)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def fetch_documents_for_document_set_paginated(

View File

@@ -232,12 +232,6 @@ class BuildSessionStatus(str, PyEnum):
IDLE = "idle"
class SharingScope(str, PyEnum):
PRIVATE = "private"
PUBLIC_ORG = "public_org"
PUBLIC_GLOBAL = "public_global"
class SandboxStatus(str, PyEnum):
PROVISIONING = "provisioning"
RUNNING = "running"

View File

@@ -430,7 +430,7 @@ def fetch_existing_models(
def fetch_existing_llm_providers(
db_session: Session,
flow_type_filter: list[LLMModelFlowType],
flow_types: list[LLMModelFlowType],
only_public: bool = False,
exclude_image_generation_providers: bool = True,
) -> list[LLMProviderModel]:
@@ -438,27 +438,30 @@ def fetch_existing_llm_providers(
Args:
db_session: Database session
flow_type_filter: List of flow types to filter by, empty list for no filter
flow_types: List of flow types to filter by
only_public: If True, only return public providers
exclude_image_generation_providers: If True, exclude providers that are
used for image generation configs
"""
stmt = select(LLMProviderModel)
if flow_type_filter:
providers_with_flows = (
select(ModelConfiguration.llm_provider_id)
.join(LLMModelFlow)
.where(LLMModelFlow.llm_model_flow_type.in_(flow_type_filter))
.distinct()
)
stmt = stmt.where(LLMProviderModel.id.in_(providers_with_flows))
providers_with_flows = (
select(ModelConfiguration.llm_provider_id)
.join(LLMModelFlow)
.where(LLMModelFlow.llm_model_flow_type.in_(flow_types))
.distinct()
)
if exclude_image_generation_providers:
stmt = select(LLMProviderModel).where(
LLMProviderModel.id.in_(providers_with_flows)
)
else:
image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
ImageGenerationConfig
)
stmt = stmt.where(~LLMProviderModel.id.in_(image_gen_provider_ids))
stmt = select(LLMProviderModel).where(
LLMProviderModel.id.in_(providers_with_flows)
| LLMProviderModel.id.in_(image_gen_provider_ids)
)
stmt = stmt.options(
selectinload(LLMProviderModel.model_configurations),

View File

@@ -77,7 +77,6 @@ from onyx.db.enums import (
ThemePreference,
DefaultAppMode,
SwitchoverType,
SharingScope,
)
from onyx.configs.constants import NotificationType
from onyx.configs.constants import SearchFeedbackType
@@ -287,7 +286,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# relationships
credentials: Mapped[list["Credential"]] = relationship(
"Credential", back_populates="user"
"Credential", back_populates="user", lazy="joined"
)
chat_sessions: Mapped[list["ChatSession"]] = relationship(
"ChatSession", back_populates="user"
@@ -321,6 +320,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
"Memory",
back_populates="user",
cascade="all, delete-orphan",
lazy="selectin",
order_by="desc(Memory.id)",
)
oauth_user_tokens: Mapped[list["OAuthUserToken"]] = relationship(
@@ -1040,9 +1040,7 @@ class OpenSearchTenantMigrationRecord(Base):
nullable=False,
)
# Opaque continuation token from Vespa's Visit API.
# NULL means "not started".
# Otherwise contains a serialized mapping between slice ID and continuation
# token for that slice.
# NULL means "not started" or "visit completed".
vespa_visit_continuation_token: Mapped[str | None] = mapped_column(
Text, nullable=True
)
@@ -1066,9 +1064,6 @@ class OpenSearchTenantMigrationRecord(Base):
enable_opensearch_retrieval: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
approx_chunk_count_in_vespa: Mapped[int | None] = mapped_column(
Integer, nullable=True
)
class KGEntityType(Base):
@@ -4717,12 +4712,6 @@ class BuildSession(Base):
demo_data_enabled: Mapped[bool] = mapped_column(
Boolean, nullable=False, server_default=text("true")
)
sharing_scope: Mapped[SharingScope] = mapped_column(
String,
nullable=False,
default=SharingScope.PRIVATE,
server_default="private",
)
# Relationships
user: Mapped[User | None] = relationship("User", foreign_keys=[user_id])
@@ -4939,7 +4928,6 @@ class ScimUserMapping(Base):
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
)
scim_username: Mapped[str | None] = mapped_column(String, nullable=True)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
@@ -4978,12 +4966,3 @@ class ScimGroupMapping(Base):
user_group: Mapped[UserGroup] = relationship(
"UserGroup", foreign_keys=[user_group_id]
)
class CodeInterpreterServer(Base):
"""Details about the code interpreter server"""
__tablename__ = "code_interpreter_server"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

View File

@@ -4,7 +4,6 @@ This module provides functions to track the progress of migrating documents
from Vespa to OpenSearch.
"""
import json
from datetime import datetime
from datetime import timezone
@@ -13,9 +12,6 @@ from sqlalchemy import text
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.opensearch_migration.constants import (
GET_VESPA_CHUNKS_SLICE_COUNT,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
)
@@ -247,37 +243,29 @@ def should_document_migration_be_permanently_failed(
def get_vespa_visit_state(
db_session: Session,
) -> tuple[dict[int, str | None], int]:
) -> tuple[str | None, int]:
"""Gets the current Vespa migration state from the tenant migration record.
Requires the OpenSearchTenantMigrationRecord to exist.
Returns:
Tuple of (continuation_token_map, total_chunks_migrated).
Tuple of (continuation_token, total_chunks_migrated). continuation_token
is None if not started or completed.
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
if record.vespa_visit_continuation_token is None:
continuation_token_map: dict[int, str | None] = {
slice_id: None for slice_id in range(GET_VESPA_CHUNKS_SLICE_COUNT)
}
else:
json_loaded_continuation_token_map = json.loads(
record.vespa_visit_continuation_token
)
continuation_token_map = {
int(key): value for key, value in json_loaded_continuation_token_map.items()
}
return continuation_token_map, record.total_chunks_migrated
return (
record.vespa_visit_continuation_token,
record.total_chunks_migrated,
)
def update_vespa_visit_progress_with_commit(
db_session: Session,
continuation_token_map: dict[int, str | None],
continuation_token: str | None,
chunks_processed: int,
chunks_errored: int,
approx_chunk_count_in_vespa: int | None,
) -> None:
"""Updates the Vespa migration progress and commits.
@@ -285,26 +273,19 @@ def update_vespa_visit_progress_with_commit(
Args:
db_session: SQLAlchemy session.
continuation_token_map: The new continuation token map. None entry means
the visit is complete for that slice.
continuation_token: The new continuation token. None means the visit
is complete.
chunks_processed: Number of chunks processed in this batch (added to
the running total).
chunks_errored: Number of chunks errored in this batch (added to the
running errored total).
approx_chunk_count_in_vespa: Approximate number of chunks in Vespa. If
None, the existing value is used.
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
record.vespa_visit_continuation_token = json.dumps(continuation_token_map)
record.vespa_visit_continuation_token = continuation_token
record.total_chunks_migrated += chunks_processed
record.total_chunks_errored += chunks_errored
record.approx_chunk_count_in_vespa = (
approx_chunk_count_in_vespa
if approx_chunk_count_in_vespa is not None
else record.approx_chunk_count_in_vespa
)
db_session.commit()
@@ -372,27 +353,25 @@ def build_sanitized_to_original_doc_id_mapping(
def get_opensearch_migration_state(
db_session: Session,
) -> tuple[int, datetime | None, datetime | None, int | None]:
) -> tuple[int, datetime | None, datetime | None]:
"""Returns the state of the Vespa to OpenSearch migration.
If the tenant migration record is not found, returns defaults of 0, None,
None, None.
None.
Args:
db_session: SQLAlchemy session.
Returns:
Tuple of (total_chunks_migrated, created_at, migration_completed_at,
approx_chunk_count_in_vespa).
Tuple of (total_chunks_migrated, created_at, migration_completed_at).
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
return 0, None, None, None
return 0, None, None
return (
record.total_chunks_migrated,
record.created_at,
record.migration_completed_at,
record.approx_chunk_count_in_vespa,
)

View File

@@ -8,7 +8,6 @@ from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.pat import build_displayable_pat
@@ -32,59 +31,53 @@ async def fetch_user_for_pat(
NOTE: This is async since it's used during auth (which is necessarily async due to FastAPI Users).
NOTE: Expired includes both naturally expired and user-revoked tokens (revocation sets expires_at=NOW()).
Uses select(User) as primary entity so that joined-eager relationships (e.g. oauth_accounts)
are loaded correctly — matching the pattern in fetch_user_for_api_key.
"""
# Single joined query with all filters pushed to database
now = datetime.now(timezone.utc)
user = await async_db_session.scalar(
select(User)
.join(PersonalAccessToken, PersonalAccessToken.user_id == User.id)
result = await async_db_session.execute(
select(PersonalAccessToken, User)
.join(User, PersonalAccessToken.user_id == User.id)
.where(PersonalAccessToken.hashed_token == hashed_token)
.where(User.is_active) # type: ignore
.where(
(PersonalAccessToken.expires_at.is_(None))
| (PersonalAccessToken.expires_at > now)
)
.options(selectinload(User.memories))
.limit(1)
)
if not user:
row = result.first()
if not row:
return None
_schedule_pat_last_used_update(hashed_token, now)
return user
pat, user = row
# Throttle last_used_at updates to reduce DB load (5-minute granularity sufficient for auditing)
# For request-level auditing, use application logs or a dedicated audit table
should_update = (
pat.last_used_at is None or (now - pat.last_used_at).total_seconds() > 300
)
def _schedule_pat_last_used_update(hashed_token: str, now: datetime) -> None:
"""Fire-and-forget update of last_used_at, throttled to 5-minute granularity."""
async def _update() -> None:
try:
tenant_id = get_current_tenant_id()
async with get_async_session_context_manager(tenant_id) as session:
pat = await session.scalar(
select(PersonalAccessToken).where(
PersonalAccessToken.hashed_token == hashed_token
if should_update:
# Update in separate session to avoid transaction coupling (fire-and-forget)
async def _update_last_used() -> None:
try:
tenant_id = get_current_tenant_id()
async with get_async_session_context_manager(
tenant_id
) as separate_session:
await separate_session.execute(
update(PersonalAccessToken)
.where(PersonalAccessToken.hashed_token == hashed_token)
.values(last_used_at=now)
)
)
if not pat:
return
if (
pat.last_used_at is not None
and (now - pat.last_used_at).total_seconds() <= 300
):
return
await session.execute(
update(PersonalAccessToken)
.where(PersonalAccessToken.hashed_token == hashed_token)
.values(last_used_at=now)
)
await session.commit()
except Exception as e:
logger.warning(f"Failed to update last_used_at for PAT: {e}")
await separate_session.commit()
except Exception as e:
logger.warning(f"Failed to update last_used_at for PAT: {e}")
asyncio.create_task(_update())
asyncio.create_task(_update_last_used())
return user
def create_pat(

View File

@@ -28,7 +28,6 @@ from onyx.db.document_access import get_accessible_documents_by_ids
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import HierarchyNode
from onyx.db.models import Persona
from onyx.db.models import Persona__User
@@ -421,16 +420,9 @@ def get_minimal_persona_snapshots_for_user(
stmt = stmt.options(
selectinload(Persona.tools),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.connector),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
@@ -461,16 +453,7 @@ def get_persona_snapshots_for_user(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
@@ -567,16 +550,9 @@ def get_minimal_persona_snapshots_paginated(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.connector),
selectinload(Persona.user),
)
@@ -635,16 +611,7 @@ def get_persona_snapshots_paginated(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),

View File

@@ -54,9 +54,6 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
# Maps schema property name to a list of highlighted snippets with match
# terms wrapped in tags (e.g. "something <hi>keyword</hi> other thing").
match_highlights: dict[str, list[str]] = {}
# Score explanation from OpenSearch when "explain": true is set in the query.
# Contains detailed breakdown of how the score was calculated.
explanation: dict[str, Any] | None = None
def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
@@ -709,12 +706,10 @@ class OpenSearchClient:
)
document_chunk_score = hit.get("_score", None)
match_highlights: dict[str, list[str]] = hit.get("highlight", {})
explanation: dict[str, Any] | None = hit.get("_explanation", None)
search_hit = SearchHit[DocumentChunk](
document_chunk=DocumentChunk.model_validate(document_chunk_source),
score=document_chunk_score,
match_highlights=match_highlights,
explanation=explanation,
)
search_hits.append(search_hit)
logger.debug(

View File

@@ -10,31 +10,31 @@ EF_CONSTRUCTION = 256
# quality but increase memory footprint. Values typically range between 12 - 48.
M = 32 # Set relatively high for better accuracy.
# When performing hybrid search, we need to consider more candidates than the number of results to be returned.
# This is because the scoring is hybrid and the results are reordered due to the hybrid scoring.
# Higher = more candidates for hybrid fusion = better retrieval accuracy, but results in more computation per query.
# Imagine a simple case with a single keyword query and a single vector query and we want 10 final docs.
# If we only fetch 10 candidates from each of keyword and vector, they would have to have perfect overlap to get a good hybrid
# ranking for the 10 results. If we fetch 1000 candidates from each, we have a much higher chance of all 10 of the final desired
# docs showing up and getting scored. In worse situations, the final 10 docs don't even show up as the final 10 (worse than just
# a miss at the reranking step).
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = 750
# Number of vectors to examine for top k neighbors for the HNSW method.
EF_SEARCH = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
# Should be >= DEFAULT_K_NUM_CANDIDATES for good recall; higher = better accuracy, slower search.
# Bumped this to 1000, for dataset of low 10,000 docs, did not see improvement in recall.
EF_SEARCH = 256
# The default number of neighbors to consider for knn vector similarity search.
# We need this higher than the number of results because the scoring is hybrid.
# If there is only 1 query, setting k equal to the number of results is enough,
# but since there is heavy reordering due to hybrid scoring, we need to set k higher.
# Higher = more candidates for hybrid fusion = better retrieval accuracy, more query cost.
DEFAULT_K_NUM_CANDIDATES = 50 # TODO likely need to bump this way higher
# Since the titles are included in the contents, they are heavily downweighted as they act as a boost
# rather than an independent scoring component.
SEARCH_TITLE_VECTOR_WEIGHT = 0.1
SEARCH_CONTENT_VECTOR_WEIGHT = 0.45
# Single keyword weight for both title and content (merged from former title keyword + content keyword).
SEARCH_KEYWORD_WEIGHT = 0.45
SEARCH_TITLE_KEYWORD_WEIGHT = 0.1
SEARCH_CONTENT_VECTOR_WEIGHT = 0.4
SEARCH_CONTENT_KEYWORD_WEIGHT = 0.4
# NOTE: it is critical that the order of these weights matches the order of the sub-queries in the hybrid search.
HYBRID_SEARCH_NORMALIZATION_WEIGHTS = [
SEARCH_TITLE_VECTOR_WEIGHT,
SEARCH_TITLE_KEYWORD_WEIGHT,
SEARCH_CONTENT_VECTOR_WEIGHT,
SEARCH_KEYWORD_WEIGHT,
SEARCH_CONTENT_KEYWORD_WEIGHT,
]
assert sum(HYBRID_SEARCH_NORMALIZATION_WEIGHTS) == 1.0

View File

@@ -842,8 +842,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
body=query_body,
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
)
# Good place for a breakpoint to inspect the search hits if you have "explain" enabled.
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights

View File

@@ -11,7 +11,6 @@ from pydantic import model_serializer
from pydantic import model_validator
from pydantic import SerializerFunctionWrapHandler
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
@@ -55,11 +54,6 @@ SECONDARY_OWNERS_FIELD_NAME = "secondary_owners"
ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME = "ancestor_hierarchy_node_ids"
# Faiss was also tried but it didn't have any benefits
# NMSLIB is deprecated, not recommended
OPENSEARCH_KNN_ENGINE = "lucene"
def get_opensearch_doc_chunk_id(
tenant_state: TenantState,
document_id: str,
@@ -349,9 +343,6 @@ class DocumentSchema:
"properties": {
TITLE_FIELD_NAME: {
"type": "text",
# Language analyzer (e.g. english) stems at index and search time for variant matching.
# Configure via OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing after a change.
"analyzer": OPENSEARCH_TEXT_ANALYZER,
"fields": {
# Subfield accessed as title.keyword. Not indexed for
# values longer than 256 chars.
@@ -366,7 +357,9 @@ class DocumentSchema:
CONTENT_FIELD_NAME: {
"type": "text",
"store": True,
"analyzer": OPENSEARCH_TEXT_ANALYZER,
# This makes highlighting text during queries more efficient
# at the cost of disk space. See
# https://docs.opensearch.org/latest/search-plugins/searching-data/highlight/#methods-of-obtaining-offsets
"index_options": "offsets",
},
TITLE_VECTOR_FIELD_NAME: {
@@ -375,7 +368,7 @@ class DocumentSchema:
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": OPENSEARCH_KNN_ENGINE,
"engine": "lucene",
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
},
},
@@ -387,7 +380,7 @@ class DocumentSchema:
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": OPENSEARCH_KNN_ENGINE,
"engine": "lucene",
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
},
},

View File

@@ -6,16 +6,13 @@ from typing import Any
from uuid import UUID
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
from onyx.configs.app_configs import OPENSEARCH_PROFILING_DISABLED
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import Tag
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import (
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
)
from onyx.document_index.opensearch.constants import DEFAULT_K_NUM_CANDIDATES
from onyx.document_index.opensearch.constants import HYBRID_SEARCH_NORMALIZATION_WEIGHTS
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME
@@ -243,9 +240,6 @@ class DocumentQuery:
Returns:
A dictionary representing the final hybrid search query.
"""
# WARNING: Profiling does not work with hybrid search; do not add it at
# this level. See https://github.com/opensearch-project/neural-search/issues/1255
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
raise ValueError(
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
@@ -253,7 +247,7 @@ class DocumentQuery:
)
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
query_text, query_vector
query_text, query_vector, num_candidates=DEFAULT_K_NUM_CANDIDATES
)
hybrid_search_filters = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
@@ -281,31 +275,25 @@ class DocumentQuery:
hybrid_search_query: dict[str, Any] = {
"hybrid": {
"queries": hybrid_search_subqueries,
# Max results per subquery per shard before aggregation. Ensures keyword and vector
# subqueries contribute equally to the candidate pool for hybrid fusion.
# Sources:
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
# https://opensearch.org/blog/navigating-pagination-in-hybrid-queries-with-the-pagination_depth-parameter/
"pagination_depth": DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
# Applied to all the sub-queries independently (this avoids having subqueries having a lot of results thrown out).
# Sources:
# Applied to all the sub-queries. Source:
# https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
# https://opensearch.org/blog/introducing-common-filter-support-for-hybrid-search-queries
# Does AND for each filter in the list.
"filter": {"bool": {"filter": hybrid_search_filters}},
}
}
# NOTE: By default, hybrid search retrieves "size"-many results from
# each OpenSearch shard before aggregation. Source:
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
final_hybrid_search_body: dict[str, Any] = {
"query": hybrid_search_query,
"size": num_hits,
"highlight": match_highlights_configuration,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
}
# Explain is for scoring breakdowns.
if OPENSEARCH_EXPLAIN_ENABLED:
final_hybrid_search_body["explain"] = True
# WARNING: Profiling does not work with hybrid search; do not add it at
# this level. See https://github.com/opensearch-project/neural-search/issues/1255
return final_hybrid_search_body
@@ -367,12 +355,7 @@ class DocumentQuery:
@staticmethod
def _get_hybrid_search_subqueries(
query_text: str,
query_vector: list[float],
# The default number of neighbors to consider for knn vector similarity search.
# This is higher than the number of results because the scoring is hybrid.
# for a detailed breakdown, see where the default value is set.
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
query_text: str, query_vector: list[float], num_candidates: int
) -> list[dict[str, Any]]:
"""Returns subqueries for hybrid search.
@@ -384,8 +367,9 @@ class DocumentQuery:
Matches:
- Title vector
- Title keyword
- Content vector
- Keyword (title + content, match and phrase)
- Content keyword + phrase
Normalization is not performed here.
The weights of each of these subqueries should be configured in a search
@@ -406,9 +390,9 @@ class DocumentQuery:
NOTE: Options considered and rejected:
- minimum_should_match: Since it's hybrid search and users often provide semantic queries, there is often a lot of terms,
and very low number of meaningful keywords (and a low ratio of keywords).
- fuzziness AUTO: typo tolerance (0/1/2 edit distance by term length). It's mostly for typos as the analyzer ("english by
default") already does some stemming and tokenization. In testing datasets, this makes recall slightly worse. It also is
less performant so not really any reason to do it.
- fuzziness AUTO: typo tolerance (0/1/2 edit distance by term length). This is reasonable but in reality seeing the
user usage patterns, this is not very common and people tend to not be confused when a miss happens for this reason.
In testing datasets, this makes recall slightly worse.
Args:
query_text: The text of the query to search for.
@@ -417,27 +401,19 @@ class DocumentQuery:
similarity search.
"""
# Build sub-queries for hybrid search. Order must match normalization
# pipeline weights: title vector, content vector, keyword (title + content).
# pipeline weights: title vector, title keyword, content vector,
# content keyword.
hybrid_search_queries: list[dict[str, Any]] = [
# 1. Title vector search
{
"knn": {
TITLE_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": vector_candidates,
"k": num_candidates,
}
}
},
# 2. Content vector search
{
"knn": {
CONTENT_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": vector_candidates,
}
}
},
# 3. Keyword (title + content) match and phrase search.
# 2. Title keyword + phrase search.
{
"bool": {
"should": [
@@ -445,10 +421,8 @@ class DocumentQuery:
"match": {
TITLE_FIELD_NAME: {
"query": query_text,
# operator "or" = match doc if any query term matches (default, explicit for clarity).
"operator": "or",
# The title fields are strongly discounted as they are included in the content.
# It just acts as a minor boost
"boost": 0.1,
}
}
},
@@ -456,17 +430,35 @@ class DocumentQuery:
"match_phrase": {
TITLE_FIELD_NAME: {
"query": query_text,
# Slop = 1 allows one extra word or transposition in phrase match.
"slop": 1,
"boost": 0.2,
# Boost phrase over bag-of-words; exact phrase is a stronger signal.
"boost": 1.5,
}
}
},
]
}
},
# 3. Content vector search
{
"knn": {
CONTENT_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": num_candidates,
}
}
},
# 4. Content keyword + phrase search.
{
"bool": {
"should": [
{
"match": {
CONTENT_FIELD_NAME: {
"query": query_text,
# operator "or" = match doc if any query term matches (default, explicit for clarity).
"operator": "or",
"boost": 1.0,
}
}
},
@@ -474,7 +466,9 @@ class DocumentQuery:
"match_phrase": {
CONTENT_FIELD_NAME: {
"query": query_text,
# Slop = 1 allows one extra word or transposition in phrase match.
"slop": 1,
# Boost phrase over bag-of-words; exact phrase is a stronger signal.
"boost": 1.5,
}
}

View File

@@ -10,12 +10,6 @@ from typing import cast
import httpx
from retry import retry
from onyx.background.celery.tasks.opensearch_migration.constants import (
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
)
from onyx.background.celery.tasks.opensearch_migration.transformer import (
FIELDS_NEEDED_FOR_TRANSFORMATION,
)
from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
from onyx.configs.app_configs import VESPA_LANGUAGE_OVERRIDE
from onyx.context.search.models import IndexFilters
@@ -283,139 +277,54 @@ def get_chunks_via_visit_api(
def get_all_chunks_paginated(
index_name: str,
tenant_state: TenantState,
continuation_token_map: dict[int, str | None],
page_size: int,
) -> tuple[list[dict], dict[int, str | None]]:
continuation_token: str | None = None,
page_size: int = 1_000,
) -> tuple[list[dict], str | None]:
"""Gets all chunks in Vespa matching the filters, paginated.
Uses the Visit API with slicing. Each continuation token map entry is for a
different slice. The number of entries determines the number of slices.
Args:
index_name: The name of the Vespa index to visit.
tenant_state: The tenant state to filter by.
continuation_token_map: Map of slice ID to a token returned by Vespa
representing a page offset. None to start from the beginning of the
slice.
continuation_token: Token returned by Vespa representing a page offset.
None to start from the beginning. Defaults to None.
page_size: Best-effort batch size for the visit. Defaults to 1,000.
Returns:
Tuple of (list of chunk dicts, next continuation token or None). The
continuation token is None when the visit is complete.
"""
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
def _get_all_chunks_paginated_for_slice(
index_name: str,
tenant_state: TenantState,
slice_id: int,
total_slices: int,
continuation_token: str | None,
page_size: int,
) -> tuple[list[dict], str | None]:
if continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN:
logger.debug(
f"Slice {slice_id} has finished visiting. Returning empty list and {FINISHED_VISITING_SLICE_CONTINUATION_TOKEN}."
)
return [], FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
selection: str = f"{index_name}.large_chunk_reference_ids == null"
if MULTI_TENANT:
selection += f" and {index_name}.tenant_id=='{tenant_state.tenant_id}'"
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
selection: str = f"{index_name}.large_chunk_reference_ids == null"
if MULTI_TENANT:
selection += f" and {index_name}.tenant_id=='{tenant_state.tenant_id}'"
field_set = f"{index_name}:" + ",".join(FIELDS_NEEDED_FOR_TRANSFORMATION)
params: dict[str, str | int | None] = {
"selection": selection,
"fieldSet": field_set,
"wantedDocumentCount": page_size,
"format.tensors": "short-value",
"slices": total_slices,
"sliceId": slice_id,
}
if continuation_token is not None:
params["continuation"] = continuation_token
response: httpx.Response | None = None
try:
with get_vespa_http_client() as http_client:
response = http_client.get(url, params=params)
response.raise_for_status()
except httpx.HTTPError as e:
error_base = f"Failed to get chunks from Vespa slice {slice_id} with continuation token {continuation_token}."
logger.exception(
f"Request URL: {e.request.url}\n"
f"Request Headers: {e.request.headers}\n"
f"Request Payload: {params}\n"
)
error_message = (
response.json().get("message") if response else "No response"
)
logger.error("Error message from response: %s", error_message)
raise httpx.HTTPError(error_base) from e
response_data = response.json()
# NOTE: If we see a falsey value for "continuation" in the response we
# assume we are done and return
# FINISHED_VISITING_SLICE_CONTINUATION_TOKEN instead.
next_continuation_token = (
response_data.get("continuation")
or FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
)
chunks = [chunk["fields"] for chunk in response_data.get("documents", [])]
if next_continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN:
logger.debug(
f"Slice {slice_id} has finished visiting. Returning {len(chunks)} chunks and {next_continuation_token}."
)
return chunks, next_continuation_token
total_slices = len(continuation_token_map)
if total_slices < 1:
raise ValueError("continuation_token_map must have at least one entry.")
# We want to guarantee that these invocations are ordered by slice_id,
# because we read in the same order below when parsing parallel_results.
functions_with_args: list[tuple[Callable, tuple]] = [
(
_get_all_chunks_paginated_for_slice,
(
index_name,
tenant_state,
slice_id,
total_slices,
continuation_token,
page_size,
),
)
for slice_id, continuation_token in sorted(continuation_token_map.items())
]
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=True
)
if len(parallel_results) != total_slices:
raise RuntimeError(
f"Expected {total_slices} parallel results, but got {len(parallel_results)}."
)
chunks: list[dict] = []
next_continuation_token_map: dict[int, str | None] = {
key: value for key, value in continuation_token_map.items()
params: dict[str, str | int | None] = {
"selection": selection,
"wantedDocumentCount": page_size,
"format.tensors": "short-value",
}
for i, parallel_result in enumerate(parallel_results):
if i not in next_continuation_token_map:
raise RuntimeError(f"Slice {i} is not in the continuation token map.")
if parallel_result is None:
logger.error(
f"Failed to get chunks for slice {i} of {total_slices}. "
"The continuation token for this slice will not be updated."
)
continue
chunks.extend(parallel_result[0])
next_continuation_token_map[i] = parallel_result[1]
if continuation_token is not None:
params["continuation"] = continuation_token
return chunks, next_continuation_token_map
try:
with get_vespa_http_client() as http_client:
response = http_client.get(url, params=params)
response.raise_for_status()
except httpx.HTTPError as e:
error_base = "Failed to get chunks in Vespa."
logger.exception(
f"Request URL: {e.request.url}\n"
f"Request Headers: {e.request.headers}\n"
f"Request Payload: {params}\n"
)
raise httpx.HTTPError(error_base) from e
response_data = response.json()
return [
chunk["fields"] for chunk in response_data.get("documents", [])
], response_data.get("continuation") or None
# TODO(rkuo): candidate for removal if not being used

View File

@@ -56,7 +56,6 @@ from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import SEARCH_ENDPOINT
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
from onyx.document_index.vespa_constants import YQL_BASE
from onyx.indexing.models import DocMetadataAwareIndexChunk
@@ -653,9 +652,9 @@ class VespaDocumentIndex(DocumentIndex):
def get_all_raw_document_chunks_paginated(
self,
continuation_token_map: dict[int, str | None],
continuation_token: str | None,
page_size: int,
) -> tuple[list[dict[str, Any]], dict[int, str | None]]:
) -> tuple[list[dict[str, Any]], str | None]:
"""Gets all the chunks in Vespa, paginated.
Used in the chunk-level Vespa-to-OpenSearch migration task.
@@ -663,21 +662,21 @@ class VespaDocumentIndex(DocumentIndex):
Args:
continuation_token: Token returned by Vespa representing a page
offset. None to start from the beginning. Defaults to None.
page_size: Best-effort batch size for the visit.
page_size: Best-effort batch size for the visit. Defaults to 1,000.
Returns:
Tuple of (list of chunk dicts, next continuation token or None). The
continuation token is None when the visit is complete.
"""
raw_chunks, next_continuation_token_map = get_all_chunks_paginated(
raw_chunks, next_continuation_token = get_all_chunks_paginated(
index_name=self._index_name,
tenant_state=TenantState(
tenant_id=self._tenant_id, multitenant=MULTI_TENANT
),
continuation_token_map=continuation_token_map,
continuation_token=continuation_token,
page_size=page_size,
)
return raw_chunks, next_continuation_token_map
return raw_chunks, next_continuation_token
def index_raw_chunks(self, chunks: list[dict[str, Any]]) -> None:
"""Indexes raw document chunks into Vespa.
@@ -703,32 +702,3 @@ class VespaDocumentIndex(DocumentIndex):
json={"fields": chunk},
)
response.raise_for_status()
def get_chunk_count(self) -> int:
    """Return the exact number of document chunks in Vespa for this tenant.

    Issues a Search API query with ``limit 0`` and
    ``ranking.profile=unranked`` so Vespa computes an exact total without
    ranking or returning any document data.

    Note: large chunks are included in the count; the Search API offers no
    way to exclude them.
    """
    # Multi-tenant deployments scope the count to this tenant; single-tenant
    # deployments count everything ("true" matches all documents).
    if self._multitenant:
        tenant_filter = f'tenant_id contains "{self._tenant_id}"'
    else:
        tenant_filter = "true"

    count_query = (
        f"select documentid from {self._index_name} "
        f"where {tenant_filter} "
        f"limit 0"
    )
    request_body: dict[str, str | int] = {
        "yql": count_query,
        # Skip ranking entirely -- we only need the match count.
        "ranking.profile": "unranked",
        "timeout": VESPA_TIMEOUT,
    }

    with get_vespa_http_client() as http_client:
        response = http_client.post(SEARCH_ENDPOINT, json=request_body)
        response.raise_for_status()

    # totalCount is the exact number of matches for the query above.
    return response.json()["root"]["fields"]["totalCount"]

View File

@@ -20,20 +20,7 @@ class ImageGenerationProviderCredentials(BaseModel):
custom_config: dict[str, str] | None = None
class ReferenceImage(BaseModel):
    """An input image supplied to an image-generation provider as reference/editing context."""

    # Raw image bytes (passed as-is to the provider).
    data: bytes
    # Media type describing ``data``, e.g. "image/png".
    mime_type: str
class ImageGenerationProvider(abc.ABC):
@property
def supports_reference_images(self) -> bool:
return False
@property
def max_reference_images(self) -> int:
return 0
@classmethod
@abc.abstractmethod
def validate_credentials(
@@ -76,7 +63,6 @@ class ImageGenerationProvider(abc.ABC):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
"""Generates an image based on a prompt."""

View File

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -60,7 +59,6 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
**kwargs: Any,
) -> ImageGenerationResponse:
from litellm import image_generation

View File

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -46,7 +45,6 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
**kwargs: Any,
) -> ImageGenerationResponse:
from litellm import image_generation

View File

@@ -1,8 +1,6 @@
from __future__ import annotations
import base64
import json
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
@@ -11,7 +9,6 @@ from pydantic import BaseModel
from onyx.image_gen.exceptions import ImageProviderCredentialsError
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -54,15 +51,6 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
vertex_credentials=vertex_credentials,
)
@property
def supports_reference_images(self) -> bool:
return True
@property
def max_reference_images(self) -> int:
# Gemini image editing supports up to 14 input images.
return 14
def generate_image(
self,
prompt: str,
@@ -70,18 +58,8 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
if reference_images:
return self._generate_image_with_reference_images(
prompt=prompt,
model=model,
size=size,
n=n,
reference_images=reference_images,
)
from litellm import image_generation
return image_generation(
@@ -96,99 +74,6 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
**kwargs,
)
def _generate_image_with_reference_images(
    self,
    prompt: str,
    model: str,
    size: str,
    n: int,
    reference_images: list[ReferenceImage],
) -> ImageGenerationResponse:
    """Generate images on Vertex AI conditioned on reference images.

    Builds a multimodal request (reference images followed by the text
    prompt) via the google-genai client and repackages the returned inline
    images as a litellm ``ImageResponse`` so callers get the same shape as
    the plain ``generate_image`` path.

    Args:
        prompt: Text instruction for the generation/edit.
        model: Model identifier, optionally prefixed with "vertex_ai/".
        size: Size string (e.g. "1024x1024"), translated to an aspect-ratio
            label for the request.
        n: Number of candidates to request (clamped to at least 1).
        reference_images: Input images sent alongside the prompt.

    Raises:
        RuntimeError: If the response contains no usable image data.
    """
    # Imported inside the function so these dependencies are only loaded
    # when this code path is actually used.
    from google import genai
    from google.genai import types as genai_types
    from google.oauth2 import service_account
    from litellm.types.utils import ImageObject
    from litellm.types.utils import ImageResponse

    # self._vertex_credentials holds the service-account JSON as a string.
    service_account_info = json.loads(self._vertex_credentials)
    credentials = service_account.Credentials.from_service_account_info(
        service_account_info,
        scopes=["https://www.googleapis.com/auth/cloud-platform"],
    )
    client = genai.Client(
        vertexai=True,
        project=self._vertex_project,
        location=self._vertex_location,
        credentials=credentials,
    )

    # Reference images first, then the text instruction as the final part.
    parts: list[genai_types.Part] = [
        genai_types.Part.from_bytes(data=image.data, mime_type=image.mime_type)
        for image in reference_images
    ]
    parts.append(genai_types.Part.from_text(text=prompt))

    config = genai_types.GenerateContentConfig(
        # Request both text and image outputs from the model.
        response_modalities=["TEXT", "IMAGE"],
        # Clamp so a caller passing n <= 0 still requests one candidate.
        candidate_count=max(1, n),
        image_config=genai_types.ImageConfig(
            aspect_ratio=_map_size_to_aspect_ratio(size)
        ),
    )

    # Strip the litellm routing prefix; the genai client expects the bare
    # model name.
    model_name = model.replace("vertex_ai/", "")
    response = client.models.generate_content(
        model=model_name,
        contents=genai_types.Content(
            role="user",
            parts=parts,
        ),
        config=config,
    )

    # Collect every inline image across all candidates, normalizing to
    # base64 strings as litellm's ImageObject expects.
    generated_data: list[ImageObject] = []
    for candidate in response.candidates or []:
        candidate_content = candidate.content
        if not candidate_content:
            continue
        for part in candidate_content.parts or []:
            inline_data = part.inline_data
            if not inline_data or inline_data.data is None:
                continue
            if isinstance(inline_data.data, bytes):
                b64_json = base64.b64encode(inline_data.data).decode("utf-8")
            elif isinstance(inline_data.data, str):
                # Presumably already base64-encoded by the SDK -- TODO confirm.
                b64_json = inline_data.data
            else:
                continue
            generated_data.append(
                ImageObject(
                    b64_json=b64_json,
                    revised_prompt=prompt,
                )
            )

    if not generated_data:
        raise RuntimeError("No image data returned from Vertex AI.")

    return ImageResponse(
        created=int(datetime.now().timestamp()),
        data=generated_data,
    )
def _map_size_to_aspect_ratio(size: str) -> str:
return {
"1024x1024": "1:1",
"1792x1024": "16:9",
"1024x1792": "9:16",
"1536x1024": "3:2",
"1024x1536": "2:3",
}.get(size, "1:1")
def _parse_to_vertex_credentials(
credentials: ImageGenerationProviderCredentials,

View File

@@ -64,6 +64,21 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"anthropic.claude-3-7-sonnet-20240620-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -144,6 +159,11 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"apac.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"apac.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -1300,6 +1320,11 @@
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"bedrock/us-gov-east-1/anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0": {
"display_name": "Claude Sonnet 4.5",
"model_vendor": "anthropic",
@@ -1340,6 +1365,16 @@
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"bedrock/us-gov-west-1/anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"bedrock/us-gov-west-1/anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0": {
"display_name": "Claude Sonnet 4.5",
"model_vendor": "anthropic",
@@ -1470,6 +1505,26 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"claude-3-7-sonnet-20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"claude-3-7-sonnet-latest": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "latest"
},
"claude-3-7-sonnet@20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"claude-3-haiku-20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"claude-4-opus-20250514": {
"display_name": "Claude Opus 4",
"model_vendor": "anthropic",
@@ -1650,6 +1705,16 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"eu.anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"eu.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"eu.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -3161,6 +3226,15 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3-haiku": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic"
},
"openrouter/anthropic/claude-3-haiku-20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"openrouter/anthropic/claude-3-sonnet": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic"
@@ -3175,6 +3249,16 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3.7-sonnet": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3.7-sonnet:beta": {
"display_name": "Claude Sonnet 3.7:beta",
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-haiku-4.5": {
"display_name": "Claude Haiku 4.5",
"model_vendor": "anthropic",
@@ -3666,6 +3750,16 @@
"model_vendor": "anthropic",
"model_version": "20241022"
},
"us.anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"us.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"us.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -3785,6 +3879,20 @@
"model_vendor": "anthropic",
"model_version": "20240620"
},
"vertex_ai/claude-3-7-sonnet@20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"vertex_ai/claude-3-haiku": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic"
},
"vertex_ai/claude-3-haiku@20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"vertex_ai/claude-3-sonnet": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic"

View File

@@ -1,7 +1,5 @@
import json
import pathlib
import threading
import time
from onyx.llm.constants import LlmProviderNames
from onyx.llm.constants import PROVIDER_DISPLAY_NAMES
@@ -25,11 +23,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_RECOMMENDATIONS_CACHE_TTL_SECONDS = 300
_recommendations_cache_lock = threading.Lock()
_cached_recommendations: LLMRecommendations | None = None
_cached_recommendations_time: float = 0.0
def _get_provider_to_models_map() -> dict[str, list[str]]:
"""Lazy-load provider model mappings to avoid importing litellm at module level.
@@ -48,40 +41,19 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
}
def _load_bundled_recommendations() -> LLMRecommendations:
def get_recommendations() -> LLMRecommendations:
"""Get the recommendations from the GitHub config."""
recommendations_from_github = fetch_llm_recommendations_from_github()
if recommendations_from_github:
return recommendations_from_github
# Fall back to json bundled with code
json_path = pathlib.Path(__file__).parent / "recommended-models.json"
with open(json_path, "r") as f:
json_config = json.load(f)
return LLMRecommendations.model_validate(json_config)
def get_recommendations() -> LLMRecommendations:
"""Get the recommendations, with an in-memory cache to avoid
hitting GitHub on every API request."""
global _cached_recommendations, _cached_recommendations_time
now = time.monotonic()
if (
_cached_recommendations is not None
and (now - _cached_recommendations_time) < _RECOMMENDATIONS_CACHE_TTL_SECONDS
):
return _cached_recommendations
with _recommendations_cache_lock:
# Double-check after acquiring lock
if (
_cached_recommendations is not None
and (time.monotonic() - _cached_recommendations_time)
< _RECOMMENDATIONS_CACHE_TTL_SECONDS
):
return _cached_recommendations
recommendations_from_github = fetch_llm_recommendations_from_github()
result = recommendations_from_github or _load_bundled_recommendations()
_cached_recommendations = result
_cached_recommendations_time = time.monotonic()
return result
recommendations_from_json = LLMRecommendations.model_validate(json_config)
return recommendations_from_json
def is_obsolete_model(model_name: str, provider: str) -> bool:
@@ -243,23 +215,6 @@ def model_configurations_for_provider(
) -> list[ModelConfigurationView]:
recommended_visible_models = llm_recommendations.get_visible_models(provider_name)
recommended_visible_models_names = [m.name for m in recommended_visible_models]
# Preserve provider-defined ordering while de-duplicating.
model_names: list[str] = []
seen_model_names: set[str] = set()
for model_name in (
fetch_models_for_provider(provider_name) + recommended_visible_models_names
):
if model_name in seen_model_names:
continue
seen_model_names.add(model_name)
model_names.append(model_name)
# Vertex model list can be large and mixed-vendor; alphabetical ordering
# makes model discovery easier in admin selection UIs.
if provider_name == VERTEXAI_PROVIDER_NAME:
model_names = sorted(model_names, key=str.lower)
return [
ModelConfigurationView(
name=model_name,
@@ -267,7 +222,8 @@ def model_configurations_for_provider(
max_input_tokens=get_max_input_tokens(model_name, provider_name),
supports_image_input=model_supports_image_input(model_name, provider_name),
)
for model_name in model_names
for model_name in set(fetch_models_for_provider(provider_name))
| set(recommended_visible_models_names)
]

View File

@@ -52,7 +52,6 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import POSTGRES_WEB_APP_NAME
from onyx.db.engine.async_sql_engine import get_sqlalchemy_async_engine
from onyx.db.engine.connection_warmup import warm_up_connections
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
@@ -64,7 +63,7 @@ from onyx.server.documents.connector import router as connector_router
from onyx.server.documents.credential import router as credential_router
from onyx.server.documents.document import router as document_router
from onyx.server.documents.standard_oauth import router as standard_oauth_router
from onyx.server.features.build.api.api import public_build_router
from onyx.server.features.build.api.api import nextjs_assets_router
from onyx.server.features.build.api.api import router as build_router
from onyx.server.features.default_assistant.api import (
router as default_assistant_router,
@@ -115,16 +114,13 @@ from onyx.server.manage.users import router as user_router
from onyx.server.manage.web_search.api import (
admin_router as web_search_admin_router,
)
from onyx.server.metrics.postgres_connection_pool import (
setup_postgres_connection_pool_metrics,
)
from onyx.server.metrics.prometheus_setup import setup_prometheus_metrics
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_auth_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.server.middleware.rate_limiting import setup_auth_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.pat.api import router as pat_router
from onyx.server.prometheus_instrumentation import setup_prometheus_metrics
from onyx.server.query_and_chat.chat_backend import router as chat_router
from onyx.server.query_and_chat.query_backend import (
admin_router as admin_query_router,
@@ -142,7 +138,6 @@ from onyx.setup import setup_onyx
from onyx.tracing.setup import setup_tracing
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_endpoint_context_middleware
from onyx.utils.middleware import add_onyx_request_id_middleware
from onyx.utils.telemetry import get_or_generate_uuid
from onyx.utils.telemetry import optional_telemetry
@@ -271,17 +266,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
max_overflow=POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW,
)
# Register pool metrics now that engines are created.
# HTTP instrumentation is set up earlier in get_application() since it
# adds middleware (which Starlette forbids after the app has started).
setup_postgres_connection_pool_metrics(
engines={
"sync": SqlEngine.get_engine(),
"async": get_sqlalchemy_async_engine(),
"readonly": SqlEngine.get_readonly_engine(),
},
)
verify_auth = fetch_versioned_implementation(
"onyx.auth.users", "verify_auth_setting"
)
@@ -394,8 +378,8 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, cc_pair_router)
include_router_with_global_prefix_prepended(application, projects_router)
include_router_with_global_prefix_prepended(application, public_build_router)
include_router_with_global_prefix_prepended(application, build_router)
include_router_with_global_prefix_prepended(application, nextjs_assets_router)
include_router_with_global_prefix_prepended(application, document_set_router)
include_router_with_global_prefix_prepended(application, hierarchy_router)
include_router_with_global_prefix_prepended(application, search_settings_router)
@@ -576,18 +560,12 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
add_onyx_request_id_middleware(application, "API", logger)
# Set endpoint context for per-endpoint DB pool attribution metrics.
# Must be registered after all routes are added.
add_endpoint_context_middleware(application)
# HTTP request metrics (latency histograms, in-progress gauge, slow request
# counter). Must be called here — before the app starts — because the
# instrumentator adds middleware via app.add_middleware().
setup_prometheus_metrics(application)
# Ensure all routes have auth enabled or are explicitly marked as public
check_router_auth(application)
# Initialize and instrument the app with production Prometheus config
setup_prometheus_metrics(application)
use_route_function_names_as_operation_ids(application)
return application

View File

@@ -69,12 +69,6 @@ Very briefly describe the image(s) generated. Do not include any links or attach
""".strip()
FILE_REMINDER = """
Your code execution generated file(s) with download links.
If you reference or share these files, use the exact markdown format [filename](file_link) with the file_link from the execution result.
""".strip()
# Specifically for OpenAI models, this prefix needs to be in place for the model to output markdown and correct styling
CODE_BLOCK_MARKDOWN = "Formatting re-enabled. "

View File

@@ -1,6 +1,6 @@
# ruff: noqa: E501, W605 start
# If there are any tools, this section is included, the sections below are for the available tools
TOOL_SECTION_HEADER = "\n# Tools\n\n"
TOOL_SECTION_HEADER = "\n\n# Tools\n"
# This section is included if there are search type tools, currently internal_search and web_search
@@ -16,10 +16,11 @@ When searching for information, if the initial results cannot fully answer the u
Do not repeat the same or very similar queries if it already has been run in the chat history.
If it is unclear which tool to use, consider using multiple in parallel to be efficient with time.
""".lstrip()
"""
INTERNAL_SEARCH_GUIDANCE = """
## internal_search
Use the `internal_search` tool to search connected applications for information. Some examples of when to use `internal_search` include:
- Internal information: any time where there may be some information stored in internal applications that could help better answer the query.
@@ -27,31 +28,34 @@ Use the `internal_search` tool to search connected applications for information.
- Keyword Queries: queries that are heavily keyword based are often internal document search queries.
- Ambiguity: questions about something that is not widely known or understood.
Never provide more than 3 queries at once to `internal_search`.
""".lstrip()
"""
WEB_SEARCH_GUIDANCE = """
## web_search
Use the `web_search` tool to access up-to-date information from the web. Some examples of when to use `web_search` include:
- Freshness: when the answer might be enhanced by up-to-date information on a topic. Very important for topics that are changing or evolving.
- Accuracy: if the cost of outdated/inaccurate information is high.
- Niche Information: when detailed info is not widely known or understood (but is likely found on the internet).{site_colon_disabled}
""".lstrip()
"""
WEB_SEARCH_SITE_DISABLED_GUIDANCE = """
Do not use the "site:" operator in your web search queries.
""".lstrip()
""".rstrip()
OPEN_URLS_GUIDANCE = """
## open_url
Use the `open_url` tool to read the content of one or more URLs. Use this tool to access the contents of the most promising web pages from your web searches or user specified URLs. \
You can open many URLs at once by passing multiple URLs in the array if multiple pages seem promising. Prioritize the most promising pages and reputable sources. \
Do not open URLs that are image files like .png, .jpg, etc.
You should almost always use open_url after a web_search call. Use this tool when a user asks about a specific provided URL.
""".lstrip()
"""
PYTHON_TOOL_GUIDANCE = """
## python
Use the `python` tool to execute Python code in an isolated sandbox. The tool will respond with the output of the execution or time out after 60.0 seconds.
Any files uploaded to the chat will be automatically be available in the execution environment's current directory. \
@@ -60,21 +64,21 @@ Use this to give the user a way to download the file OR to display generated ima
Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
Use `openpyxl` to read and write Excel files. You have access to libraries like numpy, pandas, scipy, matplotlib, and PIL.
IMPORTANT: each call to this tool is independent. Variables from previous calls will NOT be available in the current call.
""".lstrip()
"""
GENERATE_IMAGE_GUIDANCE = """
## generate_image
NEVER use generate_image unless the user specifically requests an image.
For edits/variations of a previously generated image, pass `reference_image_file_ids` with
the `file_id` values returned by earlier `generate_image` tool results.
""".lstrip()
"""
MEMORY_GUIDANCE = """
## add_memory
Use the `add_memory` tool for facts shared by the user that should be remembered for future conversations. \
Only add memories that are specific, likely to remain true, and likely to be useful later. \
Focus on enduring preferences, long-term goals, stable constraints, and explicit "remember this" type requests.
""".lstrip()
"""
TOOL_CALL_FAILURE_PROMPT = """
LLM attempted to call a tool but failed. Most likely the tool name or arguments were misspelled.

View File

@@ -1,36 +1,40 @@
# ruff: noqa: E501, W605 start
USER_INFORMATION_HEADER = "\n# User Information\n\n"
USER_INFORMATION_HEADER = "\n\n# User Information\n"
BASIC_INFORMATION_PROMPT = """
## Basic Information
User name: {user_name}
User email: {user_email}{user_role}
""".lstrip()
"""
# This line only shows up if the user has configured their role.
USER_ROLE_PROMPT = """
User role: {user_role}
""".lstrip()
"""
# Team information should be a paragraph style description of the user's team.
TEAM_INFORMATION_PROMPT = """
## Team Information
{team_information}
""".lstrip()
"""
# User preferences should be a paragraph style description of the user's preferences.
USER_PREFERENCES_PROMPT = """
## User Preferences
{user_preferences}
""".lstrip()
"""
# User memories should look something like:
# - Memory 1
# - Memory 2
# - Memory 3
USER_MEMORIES_PROMPT = """
## User Memories
{user_memories}
""".lstrip()
"""
# ruff: noqa: E501, W605 end

View File

@@ -109,7 +109,6 @@ class TenantRedis(redis.Redis):
"unlock",
"get",
"set",
"setex",
"delete",
"exists",
"incrby",

View File

@@ -59,9 +59,6 @@ PUBLIC_ENDPOINT_SPECS = [
# anonymous user on cloud
("/tenants/anonymous-user", {"POST"}),
("/metrics", {"GET"}), # added by prometheus_fastapi_instrumentator
# craft webapp proxy — access enforced per-session via sharing_scope in handler
("/build/sessions/{session_id}/webapp", {"GET"}),
("/build/sessions/{session_id}/webapp/{path:path}", {"GET"}),
]

View File

@@ -103,7 +103,6 @@ from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingMode
from onyx.db.enums import ProcessingMode
from onyx.db.federated import fetch_all_federated_connectors_parallel
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempts_by_status
@@ -988,7 +987,6 @@ def get_connector_status(
user=user,
eager_load_connector=True,
eager_load_credential=True,
eager_load_user=True,
get_editable=False,
)
@@ -1002,23 +1000,11 @@ def get_connector_status(
relationship.user_group_id
)
# Pre-compute credential_ids per connector to avoid N+1 lazy loads
connector_to_credential_ids: dict[int, list[int]] = {}
for cc_pair in cc_pairs:
connector_to_credential_ids.setdefault(cc_pair.connector_id, []).append(
cc_pair.credential_id
)
return [
ConnectorStatus(
cc_pair_id=cc_pair.id,
name=cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair.connector,
credential_ids=connector_to_credential_ids.get(
cc_pair.connector_id, []
),
),
connector=ConnectorSnapshot.from_connector_db_model(cc_pair.connector),
credential=CredentialSnapshot.from_credential_db_model(cc_pair.credential),
access_type=cc_pair.access_type,
groups=group_cc_pair_relationships_dict.get(cc_pair.id, []),
@@ -1073,27 +1059,15 @@ def get_connector_indexing_status(
parallel_functions: list[tuple[CallableProtocol, tuple[Any, ...]]] = [
# Get editable connector/credential pairs
(
lambda: get_connector_credential_pairs_for_user_parallel(
user, True, None, True, True, False, True, request.source
),
(),
get_connector_credential_pairs_for_user_parallel,
(user, True, None, True, True, True, True, request.source),
),
# Get federated connectors
(fetch_all_federated_connectors_parallel, ()),
# Get most recent index attempts
(
lambda: get_latest_index_attempts_parallel(
request.secondary_index, True, False
),
(),
),
(get_latest_index_attempts_parallel, (request.secondary_index, True, False)),
# Get most recent finished index attempts
(
lambda: get_latest_index_attempts_parallel(
request.secondary_index, True, True
),
(),
),
(get_latest_index_attempts_parallel, (request.secondary_index, True, True)),
]
if user and user.role == UserRole.ADMIN:
@@ -1110,10 +1084,8 @@ def get_connector_indexing_status(
parallel_functions.append(
# Get non-editable connector/credential pairs
(
lambda: get_connector_credential_pairs_for_user_parallel(
user, False, None, True, True, False, True, request.source
),
(),
get_connector_credential_pairs_for_user_parallel,
(user, False, None, True, True, True, True, request.source),
),
)
@@ -1859,7 +1831,7 @@ def get_connector_by_id(
@router.post("/connector-request")
def submit_connector_request(
request_data: ConnectorRequestSubmission,
user: User | None = Depends(current_user),
user: User = Depends(current_user),
) -> StatusResponse:
"""
Submit a connector request for Cloud deployments.
@@ -1872,7 +1844,7 @@ def submit_connector_request(
raise HTTPException(status_code=400, detail="Connector name cannot be empty")
# Get user identifier for telemetry
user_email = user.email if user else None
user_email = user.email
distinct_id = user_email or tenant_id
# Track connector request via PostHog telemetry (Cloud only)
@@ -1939,7 +1911,6 @@ Tenant ID: {tenant_id}
class BasicCCPairInfo(BaseModel):
has_successful_run: bool
source: DocumentSource
status: ConnectorCredentialPairStatus
@router.get("/connector-status", tags=PUBLIC_API_TAGS)
@@ -1953,17 +1924,13 @@ def get_basic_connector_indexing_status(
get_editable=False,
user=user,
)
# NOTE: This endpoint excludes Craft connectors
return [
BasicCCPairInfo(
has_successful_run=cc_pair.last_successful_index_time is not None,
source=cc_pair.connector.source,
status=cc_pair.status,
)
for cc_pair in cc_pairs
if cc_pair.connector.source != DocumentSource.INGESTION_API
and cc_pair.processing_mode == ProcessingMode.REGULAR
]

View File

@@ -365,8 +365,7 @@ class CCPairFullInfo(BaseModel):
in_repeated_error_state=cc_pair_model.in_repeated_error_state,
num_docs_indexed=num_docs_indexed,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_model.connector,
credential_ids=[cc_pair_model.credential_id],
cc_pair_model.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_model.credential

View File

@@ -1,5 +1,4 @@
from collections.abc import Iterator
from pathlib import Path
from uuid import UUID
import httpx
@@ -8,19 +7,16 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from fastapi.responses import RedirectResponse
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.auth.users import optional_user
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import ProcessingMode
from onyx.db.enums import SharingScope
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.models import BuildSession
from onyx.db.models import User
@@ -221,15 +217,12 @@ def get_build_connectors(
return BuildConnectorListResponse(connectors=connectors)
# Headers to skip when proxying.
# Hop-by-hop headers must not be forwarded, and set-cookie is stripped to
# prevent LLM-generated apps from setting cookies on the parent Onyx domain.
# Headers to skip when proxying (hop-by-hop headers)
EXCLUDED_HEADERS = {
"content-encoding",
"content-length",
"transfer-encoding",
"connection",
"set-cookie",
}
@@ -287,7 +280,7 @@ def _get_sandbox_url(session_id: UUID, db_session: Session) -> str:
db_session: Database session
Returns:
Internal URL to proxy requests to
The internal URL to proxy requests to
Raises:
HTTPException: If session not found, port not allocated, or sandbox not found
@@ -301,10 +294,12 @@ def _get_sandbox_url(session_id: UUID, db_session: Session) -> str:
if session.user_id is None:
raise HTTPException(status_code=404, detail="User not found")
# Get the user's sandbox to get the sandbox_id
sandbox = get_sandbox_by_user_id(db_session, session.user_id)
if sandbox is None:
raise HTTPException(status_code=404, detail="Sandbox not found")
# Use sandbox manager to get the correct internal URL
sandbox_manager = get_sandbox_manager()
return sandbox_manager.get_webapp_url(sandbox.id, session.nextjs_port)
@@ -370,73 +365,71 @@ def _proxy_request(
raise HTTPException(status_code=502, detail="Bad gateway")
def _check_webapp_access(
    session_id: UUID, user: User | None, db_session: Session
) -> BuildSession:
    """Resolve a build session and enforce its webapp sharing policy.

    Access rules by sharing scope:
      * public_global — open to everyone, even unauthenticated visitors
      * public_org    — open to any logged-in user
      * private       — restricted to the session's owner

    Raises:
        HTTPException: 404 when the session is missing or hidden from the
            caller, 401 when authentication is required but absent.
    """
    found = db_session.get(BuildSession, session_id)
    if not found:
        raise HTTPException(status_code=404, detail="Session not found")

    # Globally shared sessions bypass every auth check.
    if found.sharing_scope == SharingScope.PUBLIC_GLOBAL:
        return found

    # Everything past this point requires a logged-in user.
    if user is None:
        raise HTTPException(status_code=401, detail="Authentication required")

    # Private sessions masquerade as missing for non-owners (404, not 403),
    # so their existence is not leaked.
    not_owner = found.user_id != user.id
    if found.sharing_scope == SharingScope.PRIVATE and not_owner:
        raise HTTPException(status_code=404, detail="Session not found")

    return found
_OFFLINE_HTML_PATH = Path(__file__).parent / "templates" / "webapp_offline.html"
def _offline_html_response() -> Response:
    """Build the 503 response shown when a session's sandbox is unreachable.

    Serves the bundled Craft-branded offline page (terminal-window look,
    mirroring the default Craft web template in outputs/web/app/page.tsx).
    """
    page_markup = _OFFLINE_HTML_PATH.read_text()
    return Response(
        content=page_markup,
        status_code=503,
        media_type="text/html",
    )
# Public router for webapp proxy — no authentication required
# (access controlled per-session via sharing_scope)
public_build_router = APIRouter(prefix="/build")
@public_build_router.get("/sessions/{session_id}/webapp", response_model=None)
@public_build_router.get(
"/sessions/{session_id}/webapp/{path:path}", response_model=None
)
def get_webapp(
@router.get("/sessions/{session_id}/webapp", response_model=None)
def get_webapp_root(
session_id: UUID,
request: Request,
path: str = "",
user: User | None = Depends(optional_user),
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
"""Proxy the webapp for a specific session (root and subpaths).
"""Proxy the root path of the webapp for a specific session."""
return _proxy_request("", request, session_id, db_session)
Accessible without authentication when sharing_scope is public_global.
Returns a friendly offline page when the sandbox is not running.
@router.get("/sessions/{session_id}/webapp/{path:path}", response_model=None)
def get_webapp_path(
session_id: UUID,
path: str,
request: Request,
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
"""Proxy any subpath of the webapp (static assets, etc.) for a specific session."""
return _proxy_request(path, request, session_id, db_session)
# Separate router for Next.js static assets at /_next/*
# This is needed because Next.js apps may reference assets with root-relative paths
# that don't get rewritten. The session_id is extracted from the Referer header.
nextjs_assets_router = APIRouter()
def _extract_session_from_referer(request: Request) -> UUID | None:
"""Extract session_id from the Referer header.
Expects Referer to contain /api/build/sessions/{session_id}/webapp
"""
try:
_check_webapp_access(session_id, user, db_session)
except HTTPException as e:
if e.status_code == 401:
return RedirectResponse(url="/auth/login", status_code=302)
raise
try:
return _proxy_request(path, request, session_id, db_session)
except HTTPException as e:
if e.status_code in (502, 503, 504):
return _offline_html_response()
raise
import re
referer = request.headers.get("referer", "")
match = re.search(r"/api/build/sessions/([a-f0-9-]+)/webapp", referer)
if match:
try:
return UUID(match.group(1))
except ValueError:
return None
return None
@nextjs_assets_router.get("/_next/{path:path}", response_model=None)
def get_nextjs_assets(
    path: str,
    request: Request,
    _: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
    """Serve Next.js static assets that the browser requests at root /_next/.

    These asset requests originate inside the webapp iframe, so the owning
    session cannot appear in the URL; it is recovered from the Referer header
    instead.
    """
    referer_session = _extract_session_from_referer(request)
    if referer_session is None:
        raise HTTPException(
            status_code=400,
            detail="Could not determine session from request context",
        )
    # Re-root the asset path under the session's proxied webapp.
    return _proxy_request(f"_next/{path}", request, referer_session, db_session)
# =============================================================================

View File

@@ -57,9 +57,6 @@ def list_messages(
db_session: Session = Depends(get_session),
) -> MessageListResponse:
"""Get all messages for a build session."""
if user is None:
raise HTTPException(status_code=401, detail="Authentication required")
session_manager = SessionManager(db_session)
messages = session_manager.list_messages(session_id, user.id)

View File

@@ -10,7 +10,6 @@ from onyx.configs.constants import MessageType
from onyx.db.enums import ArtifactType
from onyx.db.enums import BuildSessionStatus
from onyx.db.enums import SandboxStatus
from onyx.db.enums import SharingScope
from onyx.server.features.build.sandbox.models import (
FilesystemEntry as FileSystemEntry,
)
@@ -108,7 +107,6 @@ class SessionResponse(BaseModel):
nextjs_port: int | None
sandbox: SandboxResponse | None
artifacts: list[ArtifactResponse]
sharing_scope: SharingScope
@classmethod
def from_model(
@@ -131,7 +129,6 @@ class SessionResponse(BaseModel):
nextjs_port=session.nextjs_port,
sandbox=(SandboxResponse.from_model(sandbox) if sandbox else None),
artifacts=[ArtifactResponse.from_model(a) for a in session.artifacts],
sharing_scope=session.sharing_scope,
)
@@ -162,19 +159,6 @@ class SessionListResponse(BaseModel):
sessions: list[SessionResponse]
class SetSessionSharingRequest(BaseModel):
"""Request to set the sharing scope of a session."""
sharing_scope: SharingScope
class SetSessionSharingResponse(BaseModel):
"""Response after setting session sharing scope."""
session_id: str
sharing_scope: SharingScope
# ===== Message Models =====
class MessageRequest(BaseModel):
"""Request to send a message to the CLI agent."""
@@ -260,7 +244,6 @@ class WebappInfo(BaseModel):
webapp_url: str | None # URL to access the webapp (e.g., http://localhost:3015)
status: str # Sandbox status (running, terminated, etc.)
ready: bool # Whether the NextJS dev server is actually responding
sharing_scope: SharingScope
# ===== File Upload Models =====

View File

@@ -30,8 +30,6 @@ from onyx.server.features.build.api.models import SessionListResponse
from onyx.server.features.build.api.models import SessionNameGenerateResponse
from onyx.server.features.build.api.models import SessionResponse
from onyx.server.features.build.api.models import SessionUpdateRequest
from onyx.server.features.build.api.models import SetSessionSharingRequest
from onyx.server.features.build.api.models import SetSessionSharingResponse
from onyx.server.features.build.api.models import SuggestionBubble
from onyx.server.features.build.api.models import SuggestionTheme
from onyx.server.features.build.api.models import UploadResponse
@@ -40,7 +38,6 @@ from onyx.server.features.build.configs import SANDBOX_BACKEND
from onyx.server.features.build.configs import SandboxBackend
from onyx.server.features.build.db.build_session import allocate_nextjs_port
from onyx.server.features.build.db.build_session import get_build_session
from onyx.server.features.build.db.build_session import set_build_session_sharing_scope
from onyx.server.features.build.db.sandbox import get_latest_snapshot_for_session
from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id
from onyx.server.features.build.db.sandbox import update_sandbox_heartbeat
@@ -297,25 +294,6 @@ def update_session_name(
return SessionResponse.from_model(session, sandbox)
@router.patch("/{session_id}/public")
def set_session_public(
session_id: UUID,
request: SetSessionSharingRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SetSessionSharingResponse:
"""Set the sharing scope of a build session's webapp."""
updated = set_build_session_sharing_scope(
session_id, user.id, request.sharing_scope, db_session
)
if not updated:
raise HTTPException(status_code=404, detail="Session not found")
return SetSessionSharingResponse(
session_id=str(session_id),
sharing_scope=updated.sharing_scope,
)
@router.delete("/{session_id}", response_model=None)
def delete_session(
session_id: UUID,

View File

@@ -1,110 +0,0 @@
<!doctype html>
<html lang="en">
  <head>
    <!--
      Craft "sandbox offline" page: served (with HTTP 503) by the webapp
      proxy when a session's sandbox is not reachable. The meta refresh
      below reloads every 15s so the page recovers once the sandbox wakes.
      Styling mirrors the default Craft web template: terminal-window look.
    -->
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <meta http-equiv="refresh" content="15" />
    <title>Craft — Starting up</title>
    <style>
      /* Reset: zero out default box model for all elements. */
      *,
      *::before,
      *::after {
        box-sizing: border-box;
        margin: 0;
        padding: 0;
      }
      /* Full-viewport centered column on a dark gradient background. */
      body {
        font-family: ui-monospace, SFMono-Regular, "SF Mono", Menlo, Consolas,
          monospace;
        background: linear-gradient(to bottom right, #030712, #111827, #030712);
        min-height: 100vh;
        display: flex;
        flex-direction: column;
        align-items: center;
        justify-content: center;
        gap: 1.5rem;
        padding: 2rem;
      }
      /* Faux terminal window containing the status message. */
      .terminal {
        width: 100%;
        max-width: 580px;
        border: 2px solid #374151;
        border-radius: 2px;
      }
      /* Terminal title bar with macOS-style traffic-light buttons. */
      .titlebar {
        background: #1f2937;
        padding: 0.5rem 0.75rem;
        display: flex;
        align-items: center;
        gap: 0.5rem;
        border-bottom: 1px solid #374151;
      }
      .btn {
        width: 12px;
        height: 12px;
        border-radius: 2px;
        flex-shrink: 0;
      }
      .btn-red {
        background: #ef4444;
      }
      .btn-yellow {
        background: #eab308;
      }
      .btn-green {
        background: #22c55e;
      }
      /* Centered window title; right margin balances the button cluster. */
      .title-label {
        flex: 1;
        text-align: center;
        font-size: 0.75rem;
        color: #6b7280;
        margin-right: 36px;
      }
      /* Terminal body: prompt plus status text. */
      .body {
        background: #111827;
        padding: 1.5rem;
        min-height: 200px;
        font-size: 0.875rem;
        color: #d1d5db;
        display: flex;
        align-items: flex-start;
        gap: 0.375rem;
      }
      .prompt {
        color: #10b981;
        user-select: none;
      }
      .tagline {
        font-size: 0.8125rem;
        color: #4b5563;
        text-align: center;
      }
    </style>
  </head>
  <body>
    <div class="terminal">
      <div class="titlebar">
        <div class="btn btn-red"></div>
        <div class="btn btn-yellow"></div>
        <div class="btn btn-green"></div>
        <span class="title-label">crafting_table</span>
      </div>
      <div class="body">
        <span class="prompt">/&gt;</span>
        <span>Sandbox is asleep...</span>
      </div>
    </div>
    <p class="tagline">
      Ask the owner to open their Craft session to wake it up.
    </p>
  </body>
</html>

View File

@@ -13,7 +13,6 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import MessageType
from onyx.db.enums import BuildSessionStatus
from onyx.db.enums import SandboxStatus
from onyx.db.enums import SharingScope
from onyx.db.models import Artifact
from onyx.db.models import BuildMessage
from onyx.db.models import BuildSession
@@ -160,26 +159,6 @@ def update_session_status(
logger.info(f"Updated build session {session_id} status to {status}")
def set_build_session_sharing_scope(
    session_id: UUID,
    user_id: UUID,
    sharing_scope: SharingScope,
    db_session: Session,
) -> BuildSession | None:
    """Update a build session's webapp sharing scope.

    Ownership is enforced implicitly: the session is looked up by the
    (session_id, user_id) pair, so only the owner can change the setting.

    Returns:
        The session after the update, or None when no owned session matches.
    """
    owned_session = get_build_session(session_id, user_id, db_session)
    if not owned_session:
        return None

    owned_session.sharing_scope = sharing_scope
    db_session.commit()
    logger.info(f"Set build session {session_id} sharing_scope={sharing_scope}")
    return owned_session
def delete_build_session__no_commit(
session_id: UUID,
user_id: UUID,

View File

@@ -474,23 +474,6 @@ class SandboxManager(ABC):
"""
...
def ensure_nextjs_running(
self,
sandbox_id: UUID,
session_id: UUID,
nextjs_port: int,
) -> None:
"""Ensure the Next.js server is running for a session.
Default is a no-op — only meaningful for local backends that manage
process lifecycles directly (e.g., LocalSandboxManager).
Args:
sandbox_id: The sandbox ID
session_id: The session ID
nextjs_port: The port the Next.js server should be listening on
"""
# Singleton instance cache for the factory
_sandbox_manager_instance: SandboxManager | None = None

View File

@@ -15,8 +15,6 @@ from collections.abc import Generator
from pathlib import Path
from uuid import UUID
import httpx
from onyx.db.enums import SandboxStatus
from onyx.file_store.file_store import get_default_file_store
from onyx.server.features.build.configs import DEMO_DATA_PATH
@@ -37,7 +35,6 @@ from onyx.server.features.build.sandbox.models import LLMProviderConfig
from onyx.server.features.build.sandbox.models import SandboxInfo
from onyx.server.features.build.sandbox.models import SnapshotResult
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import ThreadSafeSet
logger = setup_logger()
@@ -92,17 +89,9 @@ class LocalSandboxManager(SandboxManager):
self._acp_clients: dict[tuple[UUID, UUID], ACPAgentClient] = {}
# Track Next.js processes - keyed by (sandbox_id, session_id) tuple
# Used for clean shutdown when sessions are deleted.
# Mutated from background threads; all access must hold _nextjs_lock.
# Used for clean shutdown when sessions are deleted
self._nextjs_processes: dict[tuple[UUID, UUID], subprocess.Popen[bytes]] = {}
# Track sessions currently being (re)started - prevents concurrent restarts.
# ThreadSafeSet allows atomic check-and-add without holding _nextjs_lock.
self._nextjs_starting: ThreadSafeSet[tuple[UUID, UUID]] = ThreadSafeSet()
# Lock guarding _nextjs_processes (shared across sessions; hold briefly only)
self._nextjs_lock = threading.Lock()
# Validate templates exist (raises RuntimeError if missing)
self._validate_templates()
@@ -337,18 +326,16 @@ class LocalSandboxManager(SandboxManager):
RuntimeError: If termination fails
"""
# Stop all Next.js processes for this sandbox (keyed by (sandbox_id, session_id))
with self._nextjs_lock:
processes_to_stop = [
(key, process)
for key, process in self._nextjs_processes.items()
if key[0] == sandbox_id
]
processes_to_stop = [
(key, process)
for key, process in self._nextjs_processes.items()
if key[0] == sandbox_id
]
for key, process in processes_to_stop:
session_id = key[1]
try:
self._stop_nextjs_process(process, session_id)
with self._nextjs_lock:
self._nextjs_processes.pop(key, None)
del self._nextjs_processes[key]
except Exception as e:
logger.warning(
f"Failed to stop Next.js for sandbox {sandbox_id}, "
@@ -529,8 +516,7 @@ class LocalSandboxManager(SandboxManager):
web_dir, nextjs_port
)
# Store process for clean shutdown on session delete
with self._nextjs_lock:
self._nextjs_processes[(sandbox_id, session_id)] = nextjs_process
self._nextjs_processes[(sandbox_id, session_id)] = nextjs_process
logger.info("Next.js server started successfully")
# Setup venv and AGENTS.md
@@ -589,8 +575,7 @@ class LocalSandboxManager(SandboxManager):
"""
# Stop Next.js dev server - try stored process first, then fallback to port lookup
process_key = (sandbox_id, session_id)
with self._nextjs_lock:
nextjs_process = self._nextjs_processes.pop(process_key, None)
nextjs_process = self._nextjs_processes.pop(process_key, None)
if nextjs_process is not None:
self._stop_nextjs_process(nextjs_process, session_id)
elif nextjs_port is not None:
@@ -781,85 +766,6 @@ class LocalSandboxManager(SandboxManager):
outputs_path = session_path / "outputs"
return outputs_path.exists()
def ensure_nextjs_running(
    self,
    sandbox_id: UUID,
    session_id: UUID,
    nextjs_port: int,
) -> None:
    """Start Next.js server for a session if not already running.

    Called when the server is detected as unreachable (e.g., after API
    server restart). Returns immediately — the actual startup runs in a
    background daemon thread. A per-session guard prevents concurrent
    restarts from racing.

    Lock design: _nextjs_lock is shared across ALL sessions. Holding it
    during httpx (1s) or start_nextjs_server (several seconds) would block
    every other session's status checks and restarts. We only hold the lock
    for fast in-memory ops (dict get, check_and_add). The slow I/O runs in
    the background thread without holding any lock.

    Args:
        sandbox_id: The sandbox ID
        session_id: The session ID
        nextjs_port: The port number for the Next.js server
    """
    process_key = (sandbox_id, session_id)
    with self._nextjs_lock:
        existing = self._nextjs_processes.get(process_key)
        # poll() is None means the tracked process is still alive.
        if existing is not None and existing.poll() is None:
            return
        # Atomic check-and-add: returns True if already in set (another
        # thread is starting this session's server).
        if self._nextjs_starting.check_and_add(process_key):
            return

    def _start_in_background() -> None:
        try:
            # Port check in background to avoid blocking the main thread.
            # A live port means an orphan process (not in our dict) is
            # already serving; don't start a duplicate.
            try:
                with httpx.Client(timeout=1.0) as client:
                    client.get(f"http://localhost:{nextjs_port}")
                logger.info(
                    f"Port {nextjs_port} already alive for session {session_id} "
                    "(orphan process) — skipping restart"
                )
                return
            except Exception:
                pass  # Port is dead; proceed with restart
            logger.info(
                f"Starting Next.js for session {session_id} on port {nextjs_port}"
            )
            sandbox_path = self._get_sandbox_path(sandbox_id)
            web_dir = self._directory_manager.get_web_path(
                sandbox_path, str(session_id)
            )
            if not web_dir.exists():
                # BUGFIX: the two implicitly-concatenated literals had no
                # separator, producing "...{web_dir}cannot restart Next.js".
                logger.warning(
                    f"Web dir missing for session {session_id}: {web_dir} — "
                    "cannot restart Next.js"
                )
                return
            process = self._process_manager.start_nextjs_server(
                web_dir, nextjs_port
            )
            with self._nextjs_lock:
                self._nextjs_processes[process_key] = process
            logger.info(
                f"Auto-restarted Next.js for session {session_id} "
                f"on port {nextjs_port}"
            )
        except Exception as e:
            logger.error(
                f"Failed to auto-restart Next.js for session {session_id}: {e}"
            )
        finally:
            # Always clear the in-flight guard so future restarts can proceed.
            self._nextjs_starting.discard(process_key)

    threading.Thread(target=_start_in_background, daemon=True).start()
def restore_snapshot(
self,
sandbox_id: UUID,

View File

@@ -0,0 +1,10 @@
"""Celery tasks for sandbox management."""
from onyx.server.features.build.sandbox.tasks.tasks import (
cleanup_idle_sandboxes_task,
) # noqa: F401
from onyx.server.features.build.sandbox.tasks.tasks import (
sync_sandbox_files,
) # noqa: F401
__all__ = ["cleanup_idle_sandboxes_task", "sync_sandbox_files"]

View File

@@ -1765,7 +1765,6 @@ class SessionManager:
"webapp_url": None,
"status": "no_sandbox",
"ready": False,
"sharing_scope": session.sharing_scope,
}
# Return the proxy URL - the proxy handles routing to the correct sandbox
@@ -1778,21 +1777,11 @@ class SessionManager:
# Quick health check: can the API server reach the NextJS dev server?
ready = self._check_nextjs_ready(sandbox.id, session.nextjs_port)
# If not ready, ask the sandbox manager to ensure Next.js is running.
# For the local backend this triggers a background restart so that the
# frontend poll loop eventually sees ready=True without the user having
# to manually recreate the session.
if not ready:
self._sandbox_manager.ensure_nextjs_running(
sandbox.id, session_id, session.nextjs_port
)
return {
"has_webapp": session.nextjs_port is not None,
"webapp_url": webapp_url,
"status": sandbox.status.value,
"ready": ready,
"sharing_scope": session.sharing_scope,
}
def _check_nextjs_ready(self, sandbox_id: UUID, port: int) -> bool:

View File

@@ -111,8 +111,7 @@ class DocumentSet(BaseModel):
id=cc_pair.id,
name=cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair.connector,
credential_ids=[cc_pair.credential_id],
cc_pair.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair.credential

View File

@@ -30,42 +30,27 @@ OPENSEARCH_NOT_ENABLED_MESSAGE = (
"OpenSearch indexing must be enabled to use this feature."
)
MIGRATION_STATUS_MESSAGE = (
"Our records indicate that the transition to OpenSearch is still in progress. "
"OpenSearch retrieval is necessary to use this feature. "
"You can still use Document Sets, though! "
"If you would like to manually switch to OpenSearch, "
'Go to the "Document Index Migration" section in the Admin panel.'
)
router = APIRouter(prefix=HIERARCHY_NODES_PREFIX)
def _require_opensearch(db_session: Session) -> None:
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX or not get_opensearch_retrieval_state(
db_session
):
raise HTTPException(
status_code=403,
detail=OPENSEARCH_NOT_ENABLED_MESSAGE,
)
if not get_opensearch_retrieval_state(db_session):
raise HTTPException(
status_code=403,
detail=MIGRATION_STATUS_MESSAGE,
)
def _get_user_access_info(
user: User | None, db_session: Session
) -> tuple[str | None, list[str]]:
if not user:
return None, []
def _get_user_access_info(user: User, db_session: Session) -> tuple[str, list[str]]:
return user.email, get_user_external_group_ids(db_session, user)
@router.get(HIERARCHY_NODES_LIST_PATH)
def list_accessible_hierarchy_nodes(
source: DocumentSource,
user: User | None = Depends(current_user),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> HierarchyNodesResponse:
_require_opensearch(db_session)
@@ -92,7 +77,7 @@ def list_accessible_hierarchy_nodes(
@router.post(HIERARCHY_NODE_DOCUMENTS_PATH)
def list_accessible_hierarchy_node_documents(
documents_request: HierarchyNodeDocumentsRequest,
user: User | None = Depends(current_user),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> HierarchyNodeDocumentsResponse:
_require_opensearch(db_session)

View File

@@ -1013,7 +1013,7 @@ def get_mcp_servers_for_assistant(
@router.get("/servers", response_model=MCPServersResponse)
def get_mcp_servers_for_user(
db: Session = Depends(get_session),
user: User | None = Depends(current_user),
user: User = Depends(current_user),
) -> MCPServersResponse:
"""List all MCP servers for use in agent configuration and chat UI.

View File

@@ -310,7 +310,7 @@ def list_llm_providers(
llm_provider_list: list[LLMProviderView] = []
for llm_provider_model in fetch_existing_llm_providers(
db_session=db_session,
flow_type_filter=[],
flow_types=[LLMModelFlowType.CHAT, LLMModelFlowType.VISION],
exclude_image_generation_providers=not include_image_gen,
):
from_model_start = datetime.now(timezone.utc)
@@ -568,7 +568,9 @@ def list_llm_provider_basics(
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch user-accessible LLM providers")
all_providers = fetch_existing_llm_providers(db_session, [])
all_providers = fetch_existing_llm_providers(
db_session, [LLMModelFlowType.CHAT, LLMModelFlowType.VISION]
)
user_group_ids = fetch_user_group_ids(db_session, user)
is_admin = user.role == UserRole.ADMIN

View File

@@ -26,17 +26,13 @@ def get_opensearch_migration_status(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> OpenSearchMigrationStatusResponse:
(
total_chunks_migrated,
created_at,
migration_completed_at,
approx_chunk_count_in_vespa,
) = get_opensearch_migration_state(db_session)
total_chunks_migrated, created_at, migration_completed_at = (
get_opensearch_migration_state(db_session)
)
return OpenSearchMigrationStatusResponse(
total_chunks_migrated=total_chunks_migrated,
created_at=created_at,
migration_completed_at=migration_completed_at,
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
)

View File

@@ -8,7 +8,6 @@ class OpenSearchMigrationStatusResponse(BaseModel):
total_chunks_migrated: int
created_at: datetime | None
migration_completed_at: datetime | None
approx_chunk_count_in_vespa: int | None
class OpenSearchRetrievalStatusRequest(BaseModel):

View File

@@ -608,8 +608,7 @@ def list_all_users_basic_info(
return [
MinimalUserSnapshot(id=user.id, email=user.email)
for user in users
if user.role != UserRole.SLACK_USER
and (include_api_keys or not is_api_key_email_address(user.email))
if include_api_keys or not is_api_key_email_address(user.email)
]

View File

@@ -1,241 +0,0 @@
"""SQLAlchemy connection pool Prometheus metrics.
Provides production-grade visibility into database connection pool state:
- Pool state gauges (checked-out, idle, overflow, configured size)
- Pool lifecycle counters (checkouts, checkins, creates, invalidations, timeouts)
- Per-endpoint connection attribution (which endpoints hold connections, for how long)
Metrics are collected via two mechanisms:
1. A custom Prometheus Collector that reads pool snapshots on each /metrics scrape
2. SQLAlchemy pool event listeners (checkout, checkin, connect, invalidate) for
counters, histograms, and attribution
"""
import time
from fastapi import Request
from fastapi.responses import JSONResponse
from prometheus_client import Counter
from prometheus_client import Gauge
from prometheus_client import Histogram
from prometheus_client.core import GaugeMetricFamily
from prometheus_client.registry import Collector
from prometheus_client.registry import REGISTRY
from sqlalchemy import event
from sqlalchemy.engine import Engine
from sqlalchemy.engine.interfaces import DBAPIConnection
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.pool import ConnectionPoolEntry
from sqlalchemy.pool import PoolProxiedConnection
from sqlalchemy.pool import QueuePool
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_ENDPOINT_CONTEXTVAR
logger = setup_logger()
# --- Pool lifecycle counters (event-driven) ---
# Incremented from SQLAlchemy pool event listeners; labeled by engine name
# (e.g. "sync", "async") so multiple engines can be monitored side by side.
_checkout_total = Counter(
"onyx_db_pool_checkout_total",
"Total connection checkouts from the pool",
["engine"],
)
_checkin_total = Counter(
"onyx_db_pool_checkin_total",
"Total connection checkins to the pool",
["engine"],
)
# New physical DBAPI connections (pool "connect" event), not checkouts.
_connections_created_total = Counter(
"onyx_db_pool_connections_created_total",
"Total new database connections created",
["engine"],
)
_invalidations_total = Counter(
"onyx_db_pool_invalidations_total",
"Total connection invalidations",
["engine"],
)
# Incremented from pool_timeout_handler when a checkout times out.
_checkout_timeout_total = Counter(
"onyx_db_pool_checkout_timeout_total",
"Total connection checkout timeouts",
["engine"],
)
# --- Per-endpoint attribution (event-driven) ---
# "handler" comes from CURRENT_ENDPOINT_CONTEXTVAR at checkout time, so each
# held connection is attributed to the endpoint that checked it out.
_connections_held = Gauge(
"onyx_db_connections_held_by_endpoint",
"Number of DB connections currently held, by endpoint and engine",
["handler", "engine"],
)
_hold_seconds = Histogram(
"onyx_db_connection_hold_seconds",
"Duration a DB connection is held by an endpoint",
["handler", "engine"],
)
def pool_timeout_handler(
    request: Request,  # noqa: ARG001
    exc: Exception,
) -> JSONResponse:
    """Record a pool checkout timeout and answer with an HTTP 503.

    The engine label is "unknown" because at this point we can no longer
    tell which engine's pool timed out.
    """
    _checkout_timeout_total.labels(engine="unknown").inc()
    payload = {
        "detail": "Database connection pool timeout",
        "error": str(exc),
    }
    return JSONResponse(status_code=503, content=payload)
class PoolStateCollector(Collector):
    """Custom Prometheus collector that reads QueuePool state on each scrape.

    Uses pool.checkedout(), pool.checkedin(), pool.overflow(), and
    pool.size() for an atomic snapshot of pool state. Registered engines are
    stored as (label, pool) tuples to avoid holding references to the full
    Engine.
    """

    def __init__(self) -> None:
        self._pools: list[tuple[str, QueuePool]] = []

    def add_pool(self, label: str, pool: QueuePool) -> None:
        self._pools.append((label, pool))

    def collect(self) -> list[GaugeMetricFamily]:
        # Metric name, help text, and the QueuePool accessor that supplies
        # the value — driven from one table instead of four parallel locals.
        specs = (
            ("onyx_db_pool_checked_out", "Currently checked-out connections", "checkedout"),
            ("onyx_db_pool_checked_in", "Idle connections available in the pool", "checkedin"),
            ("onyx_db_pool_overflow", "Current overflow connections beyond pool_size", "overflow"),
            ("onyx_db_pool_size", "Configured pool size", "size"),
        )
        families: list[GaugeMetricFamily] = []
        for metric_name, help_text, accessor in specs:
            family = GaugeMetricFamily(metric_name, help_text, labels=["engine"])
            for label, pool in self._pools:
                family.add_metric([label], getattr(pool, accessor)())
            families.append(family)
        return families

    def describe(self) -> list[GaugeMetricFamily]:
        # Return empty to mark this as an "unchecked" collector. Prometheus
        # skips upfront descriptor validation and just calls collect() at
        # scrape time. Required because our metrics are dynamic (engine
        # labels depend on which engines are registered at runtime).
        return []
def _register_pool_events(engine: Engine, label: str) -> None:
"""Attach pool event listeners for metrics collection.
Listens to checkout, checkin, connect, and invalidate events.
Stores per-connection metadata on connection_record.info for attribution.
"""
# Checkout: stamp the connection with the current endpoint (from the
# request-scoped contextvar) and a monotonic start time, then bump gauges.
@event.listens_for(engine, "checkout")
def on_checkout(
dbapi_conn: DBAPIConnection, # noqa: ARG001
conn_record: ConnectionPoolEntry,
conn_proxy: PoolProxiedConnection, # noqa: ARG001
) -> None:
handler = CURRENT_ENDPOINT_CONTEXTVAR.get() or "unknown"
conn_record.info["_metrics_endpoint"] = handler
conn_record.info["_metrics_checkout_time"] = time.monotonic()
_checkout_total.labels(engine=label).inc()
_connections_held.labels(handler=handler, engine=label).inc()
# Checkin: pop (not get) the stamped metadata so a connection reused by a
# different endpoint starts clean, then record the hold duration.
@event.listens_for(engine, "checkin")
def on_checkin(
dbapi_conn: DBAPIConnection, # noqa: ARG001
conn_record: ConnectionPoolEntry,
) -> None:
handler = conn_record.info.pop("_metrics_endpoint", "unknown")
start = conn_record.info.pop("_metrics_checkout_time", None)
_checkin_total.labels(engine=label).inc()
_connections_held.labels(handler=handler, engine=label).dec()
if start is not None:
_hold_seconds.labels(handler=handler, engine=label).observe(
time.monotonic() - start
)
# Connect: counts new physical DBAPI connections only.
@event.listens_for(engine, "connect")
def on_connect(
dbapi_conn: DBAPIConnection, # noqa: ARG001
conn_record: ConnectionPoolEntry, # noqa: ARG001
) -> None:
_connections_created_total.labels(engine=label).inc()
@event.listens_for(engine, "invalidate")
def on_invalidate(
dbapi_conn: DBAPIConnection, # noqa: ARG001
conn_record: ConnectionPoolEntry,
exception: BaseException | None, # noqa: ARG001
) -> None:
_invalidations_total.labels(engine=label).inc()
# Defensively clean up the held-connections gauge in case checkin
# doesn't fire after invalidation (e.g. hard pool shutdown).
# Because the keys are popped here, a later checkin for the same record
# sees them absent and falls back to "unknown"/None — no double-count.
handler = conn_record.info.pop("_metrics_endpoint", None)
start = conn_record.info.pop("_metrics_checkout_time", None)
if handler:
_connections_held.labels(handler=handler, engine=label).dec()
if start is not None:
_hold_seconds.labels(handler=handler or "unknown", engine=label).observe(
time.monotonic() - start
)
def setup_postgres_connection_pool_metrics(
engines: dict[str, Engine | AsyncEngine],
) -> None:
"""Register pool metrics for all provided engines.
Args:
engines: Mapping of engine label to Engine or AsyncEngine.
Example: {"sync": sync_engine, "async": async_engine, "readonly": ro_engine}
Engines using NullPool are skipped (no pool state to monitor).
For AsyncEngine, events are registered on the underlying sync_engine.
"""
collector = PoolStateCollector()
for label, engine in engines.items():
# Resolve async engines to their underlying sync engine
sync_engine = engine.sync_engine if isinstance(engine, AsyncEngine) else engine
pool = sync_engine.pool
# Only QueuePool exposes checkedout()/checkedin()/overflow()/size();
# NullPool and others have no meaningful state to scrape.
if not isinstance(pool, QueuePool):
logger.info(
f"Skipping pool metrics for engine '{label}' "
f"({type(pool).__name__} — no pool state)"
)
continue
collector.add_pool(label, pool)
_register_pool_events(sync_engine, label)
logger.info(f"Registered pool metrics for engine '{label}'")
# NOTE(review): presumably called exactly once at process startup —
# calling it again would register a second collector with the default
# REGISTRY (duplicate registration raises). Confirm call sites.
REGISTRY.register(collector)

Some files were not shown because too many files have changed in this diff Show More