Compare commits

..

1 Commit

Author SHA1 Message Date
Jessica Singh
53a5ee2a6e stt and tts 2026-02-23 18:27:37 -08:00
264 changed files with 7885 additions and 10508 deletions

View File

@@ -33,7 +33,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Build chart dependencies

View File

@@ -45,6 +45,9 @@ env:
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000
# OpenSearch
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -115,10 +118,9 @@ jobs:
- name: Create .env file for Docker Compose
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
COMPOSE_PROFILES=s3-filestore
CODE_INTERPRETER_BETA_ENABLED=true
DISABLE_TELEMETRY=true
OPENSEARCH_FOR_ONYX_ENABLED=true
EOF
- name: Set up Standard Dependencies
@@ -127,6 +129,7 @@ jobs:
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
-f docker-compose.opensearch.yml \
up -d \
minio \
relational_db \

View File

@@ -91,7 +91,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Install Redis operator

View File

@@ -1,28 +0,0 @@
"""add scim_username to scim_user_mapping

Adds a nullable ``scim_username`` column so SCIM user mappings can record
the userName reported by the SCIM client.

Revision ID: 0bb4558f35df
Revises: 631fd2504136
Create Date: 2026-02-20 10:45:30.340188
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0bb4558f35df"
down_revision = "631fd2504136"
branch_labels = None
depends_on = None
def upgrade() -> None:
    # Nullable so the migration succeeds on populated tables without a backfill.
    op.add_column(
        "scim_user_mapping",
        sa.Column("scim_username", sa.String(), nullable=True),
    )
def downgrade() -> None:
    # Exact inverse of upgrade(): drop the added column.
    op.drop_column("scim_user_mapping", "scim_username")

View File

@@ -1,31 +0,0 @@
"""code interpreter server model

Creates the ``code_interpreter_server`` table, a single-row settings table
tracking whether the code interpreter server is enabled.

Revision ID: 7cb492013621
Revises: 0bb4558f35df
Create Date: 2026-02-22 18:54:54.007265
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7cb492013621"
down_revision = "0bb4558f35df"
branch_labels = None
depends_on = None
def upgrade() -> None:
    op.create_table(
        "code_interpreter_server",
        sa.Column("id", sa.Integer, primary_key=True),
        # Defaults to enabled at the DB level so existing deployments keep
        # their current behavior after the migration runs.
        sa.Column(
            "server_enabled", sa.Boolean, nullable=False, server_default=sa.true()
        ),
    )
def downgrade() -> None:
    # Exact inverse of upgrade(): drop the table.
    op.drop_table("code_interpreter_server")

View File

@@ -0,0 +1,100 @@
"""add_voice_provider_and_user_voice_prefs
Revision ID: 93a2e195e25c
Revises: 631fd2504136
Create Date: 2026-02-23 15:16:39.507304
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "93a2e195e25c"
down_revision = "631fd2504136"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create voice_provider table
op.create_table(
"voice_provider",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("name", sa.String(), unique=True, nullable=False),
sa.Column("provider_type", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("api_base", sa.String(), nullable=True),
sa.Column("custom_config", postgresql.JSONB(), nullable=True),
sa.Column("stt_model", sa.String(), nullable=True),
sa.Column("tts_model", sa.String(), nullable=True),
sa.Column("default_voice", sa.String(), nullable=True),
sa.Column(
"is_default_stt", sa.Boolean(), nullable=False, server_default="false"
),
sa.Column(
"is_default_tts", sa.Boolean(), nullable=False, server_default="false"
),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
onupdate=sa.func.now(),
nullable=False,
),
)
# Add voice preference columns to user table
op.add_column(
"user",
sa.Column(
"voice_auto_send",
sa.Boolean(),
default=False,
nullable=False,
server_default="false",
),
)
op.add_column(
"user",
sa.Column(
"voice_auto_playback",
sa.Boolean(),
default=False,
nullable=False,
server_default="false",
),
)
op.add_column(
"user",
sa.Column(
"voice_playback_speed",
sa.Float(),
default=1.0,
nullable=False,
server_default="1.0",
),
)
op.add_column(
"user",
sa.Column("preferred_voice", sa.String(), nullable=True),
)
def downgrade() -> None:
# Remove user voice preference columns
op.drop_column("user", "preferred_voice")
op.drop_column("user", "voice_playback_speed")
op.drop_column("user", "voice_auto_playback")
op.drop_column("user", "voice_auto_send")
# Drop voice_provider table
op.drop_table("voice_provider")

View File

@@ -9,7 +9,6 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
@@ -19,15 +18,11 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.db.models import User
@@ -200,60 +195,8 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
return db_session.scalar(stmt)
def _add_user_group_snapshot_eager_loads(
stmt: Select,
) -> Select:
"""Add eager loading options needed by UserGroup.from_model snapshot creation."""
return stmt.options(
selectinload(UserGroup.users),
selectinload(UserGroup.user_group_relationships),
selectinload(UserGroup.cc_pair_relationships)
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
.options(
selectinload(ConnectorCredentialPair.connector),
selectinload(ConnectorCredentialPair.credential).selectinload(
Credential.user
),
),
selectinload(UserGroup.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(UserGroup.personas).options(
selectinload(Persona.tools),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
selectinload(Persona.groups),
),
)
def fetch_user_groups(
db_session: Session,
only_up_to_date: bool = True,
eager_load_for_snapshot: bool = False,
db_session: Session, only_up_to_date: bool = True
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -266,8 +209,6 @@ def fetch_user_groups(
db_session (Session): The SQLAlchemy session used to query the database.
only_up_to_date (bool, optional): Flag to determine whether to filter the results
to include only up to date user groups. Defaults to `True`.
eager_load_for_snapshot: If True, adds eager loading for all relationships
needed by UserGroup.from_model snapshot creation.
Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -275,16 +216,11 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def fetch_user_groups_for_user(
db_session: Session,
user_id: UUID,
only_curator_groups: bool = False,
eager_load_for_snapshot: bool = False,
db_session: Session, user_id: UUID, only_curator_groups: bool = False
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -294,9 +230,7 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def construct_document_id_select_by_usergroup(

View File

@@ -1,13 +1,9 @@
from collections.abc import Generator
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from onyx.configs.app_configs import SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
@@ -47,27 +43,14 @@ def sharepoint_group_sync(
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
enumerate_all = connector_config.get(
"exhaustive_ad_enumeration", SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
)
msal_app = connector.msal_app
sp_tenant_domain = connector.sp_tenant_domain
sp_domain_suffix = connector.sharepoint_domain_suffix
# Process each site
for site_descriptor in site_descriptors:
logger.debug(f"Processing site: {site_descriptor.url}")
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
)
ctx = connector._create_rest_client_context(site_descriptor.url)
external_groups = get_sharepoint_external_groups(
ctx,
connector.graph_client,
graph_api_base=connector.graph_api_base,
get_access_token=connector._get_graph_access_token,
enumerate_all_ad_groups=enumerate_all,
)
# Get external groups for this site
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
# Yield each group
for group in external_groups:

View File

@@ -1,13 +1,9 @@
import re
import time
from collections import deque
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from urllib.parse import unquote
from urllib.parse import urlparse
import requests as _requests
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.runtime.client_request import ClientRequestException # type: ignore
@@ -18,10 +14,7 @@ from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES
from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
from onyx.connectors.sharepoint.connector import sleep_and_retry
from onyx.utils.logger import setup_logger
@@ -40,70 +33,6 @@ LIMITED_ACCESS_ROLE_TYPES = [1, 9]
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
AD_GROUP_ENUMERATION_THRESHOLD = 100_000
def _graph_api_get(
url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> dict[str, Any]:
"""Authenticated Graph API GET with retry on transient errors."""
for attempt in range(GRAPH_API_MAX_RETRIES + 1):
access_token = get_access_token()
headers = {"Authorization": f"Bearer {access_token}"}
try:
resp = _requests.get(
url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS
)
if (
resp.status_code in GRAPH_API_RETRYABLE_STATUSES
and attempt < GRAPH_API_MAX_RETRIES
):
wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60)
logger.warning(
f"Graph API {resp.status_code} on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
resp.raise_for_status()
return resp.json()
except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
if attempt < GRAPH_API_MAX_RETRIES:
wait = min(2**attempt, 60)
logger.warning(
f"Graph API connection error on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
raise
raise RuntimeError(
f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}"
)
def _iter_graph_collection(
initial_url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> Generator[dict[str, Any], None, None]:
"""Paginate through a Graph API collection, yielding items one at a time."""
url: str | None = initial_url
while url:
data = _graph_api_get(url, get_access_token, params)
params = None
yield from data.get("value", [])
url = data.get("@odata.nextLink")
def _normalize_email(email: str) -> str:
if MICROSOFT_DOMAIN in email:
return email.replace(MICROSOFT_DOMAIN, "")
return email
class SharepointGroup(BaseModel):
model_config = {"frozen": True}
@@ -643,65 +572,8 @@ def get_external_access_from_sharepoint(
)
def _enumerate_ad_groups_paginated(
get_access_token: Callable[[], str],
already_resolved: set[str],
graph_api_base: str,
) -> Generator[ExternalUserGroup, None, None]:
"""Paginate through all Azure AD groups and yield ExternalUserGroup for each.
Skips groups whose suffixed name is already in *already_resolved*.
Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD.
"""
groups_url = f"{graph_api_base}/groups"
groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"}
total_groups = 0
for group_json in _iter_graph_collection(
groups_url, get_access_token, groups_params
):
group_id: str = group_json.get("id", "")
display_name: str = group_json.get("displayName", "")
if not group_id or not display_name:
continue
total_groups += 1
if total_groups > AD_GROUP_ENUMERATION_THRESHOLD:
logger.warning(
f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} "
"groups — stopping to avoid excessive memory/API usage. "
"Remaining groups will be resolved from role assignments only."
)
return
name = f"{display_name}_{group_id}"
if name in already_resolved:
continue
member_emails: list[str] = []
members_url = f"{graph_api_base}/groups/{group_id}/members"
members_params: dict[str, str] = {
"$select": "userPrincipalName,mail",
"$top": "999",
}
for member_json in _iter_graph_collection(
members_url, get_access_token, members_params
):
email = member_json.get("userPrincipalName") or member_json.get("mail")
if email:
member_emails.append(_normalize_email(email))
yield ExternalUserGroup(id=name, user_emails=member_emails)
logger.info(f"Enumerated {total_groups} Azure AD groups via paginated Graph API")
def get_sharepoint_external_groups(
client_context: ClientContext,
graph_client: GraphClient,
graph_api_base: str,
get_access_token: Callable[[], str] | None = None,
enumerate_all_ad_groups: bool = False,
client_context: ClientContext, graph_client: GraphClient
) -> list[ExternalUserGroup]:
groups: set[SharepointGroup] = set()
@@ -757,22 +629,57 @@ def get_sharepoint_external_groups(
client_context, graph_client, groups, is_group_sync=True
)
external_user_groups: list[ExternalUserGroup] = [
ExternalUserGroup(id=group_name, user_emails=list(emails))
for group_name, emails in groups_and_members.groups_to_emails.items()
]
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
azure_ad_groups = sleep_and_retry(
graph_client.groups.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups",
)
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
ad_groups_to_emails: dict[str, set[str]] = {}
for group in azure_ad_groups:
# If the group is already identified, we don't need to get the members
if group.display_name in identified_groups:
continue
# AD groups allows same display name for multiple groups, so we need to add the GUID to the name
name = group.display_name
name = _get_group_name_with_suffix(group.id, name, graph_client)
if not enumerate_all_ad_groups or get_access_token is None:
logger.info(
"Skipping exhaustive Azure AD group enumeration. "
"Only groups found in site role assignments are included."
members = sleep_and_retry(
group.members.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
)
return external_user_groups
for member in members:
member_data = member.to_json()
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
if not ad_groups_to_emails.get(name):
ad_groups_to_emails[name] = set()
if user_principal_name:
if MICROSOFT_DOMAIN in user_principal_name:
user_principal_name = user_principal_name.replace(
MICROSOFT_DOMAIN, ""
)
ad_groups_to_emails[name].add(user_principal_name)
elif mail:
if MICROSOFT_DOMAIN in mail:
mail = mail.replace(MICROSOFT_DOMAIN, "")
ad_groups_to_emails[name].add(mail)
already_resolved = set(groups_and_members.groups_to_emails.keys())
for group in _enumerate_ad_groups_paginated(
get_access_token, already_resolved, graph_api_base
):
external_user_groups.append(group)
external_user_groups: list[ExternalUserGroup] = []
for group_name, emails in groups_and_members.groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
for group_name, emails in ad_groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
return external_user_groups

View File

@@ -37,15 +37,12 @@ def list_user_groups(
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(
db_session, only_up_to_date=False, eager_load_for_snapshot=True
)
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_for_snapshot=True,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]

View File

@@ -53,8 +53,7 @@ class UserGroup(BaseModel):
id=cc_pair_relationship.cc_pair.id,
name=cc_pair_relationship.cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_relationship.cc_pair.connector,
credential_ids=[cc_pair_relationship.cc_pair.credential_id],
cc_pair_relationship.cc_pair.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential

View File

@@ -121,7 +121,6 @@ from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.store import load_settings
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
@@ -138,8 +137,6 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
REGISTER_INVITE_ONLY_CODE = "REGISTER_INVITE_ONLY"
def is_user_admin(user: User) -> bool:
return user.role == UserRole.ADMIN
@@ -211,34 +208,22 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
return int(value.decode("utf-8")) == 1
def workspace_invite_only_enabled() -> bool:
settings = load_settings()
return settings.invite_only_enabled
def verify_email_is_invited(email: str) -> None:
if AUTH_TYPE in {AuthType.SAML, AuthType.OIDC}:
# SSO providers manage membership; allow JIT provisioning regardless of invites
return
if not workspace_invite_only_enabled():
whitelist = get_invited_users()
if not whitelist:
return
whitelist = get_invited_users()
if not email:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email must be specified"},
)
raise PermissionError("Email must be specified")
try:
email_info = validate_email(email, check_deliverability=False)
except EmailUndeliverableError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email is not valid"},
)
raise PermissionError("Email is not valid")
for email_whitelist in whitelist:
try:
@@ -255,13 +240,7 @@ def verify_email_is_invited(email: str) -> None:
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
return
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"code": REGISTER_INVITE_ONLY_CODE,
"reason": "This workspace is invite-only. Please ask your admin to invite you.",
},
)
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
@@ -1671,10 +1650,7 @@ def get_oauth_router(
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
# Use WEB_DOMAIN instead of request.url_for() to prevent host
# header poisoning — request.url_for() trusts the Host header.
callback_path = request.app.url_path_for(callback_route_name)
authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")

View File

@@ -13,7 +13,6 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import DISABLE_VECTOR_DB
@@ -22,14 +21,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
@@ -60,17 +57,6 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@@ -134,24 +120,7 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Three mechanisms prevent queue runaway:
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
Uses direct Redis locks to avoid overlapping runs.
"""
task_logger.info("check_user_file_processing - Starting")
@@ -166,21 +135,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -193,35 +148,12 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
finally:
@@ -229,8 +161,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None
@@ -373,12 +304,6 @@ def process_single_user_file(
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,

View File

@@ -1,4 +1,3 @@
import json
import re
from collections.abc import Callable
from typing import cast
@@ -46,7 +45,6 @@ from onyx.utils.timing import log_function_time
logger = setup_logger()
IMAGE_GENERATION_TOOL_NAME = "generate_image"
def create_chat_session_from_request(
@@ -424,40 +422,6 @@ def convert_chat_history_basic(
return list(reversed(trimmed_reversed))
def _build_tool_call_response_history_message(
tool_name: str,
generated_images: list[dict] | None,
tool_call_response: str | None,
) -> str:
if tool_name != IMAGE_GENERATION_TOOL_NAME:
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
if generated_images:
llm_image_context: list[dict[str, str]] = []
for image in generated_images:
file_id = image.get("file_id")
revised_prompt = image.get("revised_prompt")
if not isinstance(file_id, str):
continue
llm_image_context.append(
{
"file_id": file_id,
"revised_prompt": (
revised_prompt if isinstance(revised_prompt, str) else ""
),
}
)
if llm_image_context:
return json.dumps(llm_image_context)
if tool_call_response:
return tool_call_response
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
def convert_chat_history(
chat_history: list[ChatMessage],
files: list[ChatLoadedFile],
@@ -618,24 +582,10 @@ def convert_chat_history(
# Add TOOL_CALL_RESPONSE messages for each tool call in this turn
for tool_call in turn_tool_calls:
tool_name = tool_id_to_name_map.get(
tool_call.tool_id, "unknown"
)
tool_response_message = (
_build_tool_call_response_history_message(
tool_name=tool_name,
generated_images=tool_call.generated_images,
tool_call_response=tool_call.tool_call_response,
)
)
simple_messages.append(
ChatMessageSimple(
message=tool_response_message,
token_count=(
token_counter(tool_response_message)
if tool_name == IMAGE_GENERATION_TOOL_NAME
else 20
),
message=TOOL_CALL_RESPONSE_CROSS_MESSAGE,
token_count=20, # Tiny overestimate
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,

View File

@@ -57,7 +57,6 @@ from onyx.tools.tool_implementations.images.models import (
FinalImageGenerationResponse,
)
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
@@ -652,7 +651,6 @@ def run_llm_loop(
ran_image_gen: bool = False
just_ran_web_search: bool = False
has_called_search_tool: bool = False
code_interpreter_file_generated: bool = False
fallback_extraction_attempted: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
@@ -763,7 +761,6 @@ def run_llm_loop(
),
include_citation_reminder=should_cite_documents
or always_cite_documents,
include_file_reminder=code_interpreter_file_generated,
is_last_cycle=out_of_cycles,
)
@@ -903,18 +900,6 @@ def run_llm_loop(
if tool_call.tool_name == SearchTool.NAME:
has_called_search_tool = True
# Track if code interpreter generated files with download links
if (
tool_call.tool_name == PythonTool.NAME
and not code_interpreter_file_generated
):
try:
parsed = json.loads(tool_response.llm_facing_response)
if parsed.get("generated_files"):
code_interpreter_file_generated = True
except (json.JSONDecodeError, AttributeError):
pass
# Build a mapping of tool names to tool objects for getting tool_id
tools_by_name = {tool.name: tool for tool in final_tools}

View File

@@ -10,7 +10,6 @@ from onyx.db.user_file import calculate_user_files_token_count
from onyx.file_store.models import FileDescriptor
from onyx.prompts.chat_prompts import CITATION_REMINDER
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
from onyx.prompts.chat_prompts import FILE_REMINDER
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
from onyx.prompts.prompt_utils import get_company_context
@@ -126,7 +125,6 @@ def calculate_reserved_tokens(
def build_reminder_message(
reminder_text: str | None,
include_citation_reminder: bool,
include_file_reminder: bool,
is_last_cycle: bool,
) -> str | None:
reminder = reminder_text.strip() if reminder_text else ""
@@ -134,8 +132,6 @@ def build_reminder_message(
reminder += "\n\n" + LAST_CYCLE_CITATION_REMINDER
if include_citation_reminder:
reminder += "\n\n" + CITATION_REMINDER
if include_file_reminder:
reminder += "\n\n" + FILE_REMINDER
reminder = reminder.strip()
return reminder if reminder else None
@@ -190,7 +186,7 @@ def _build_user_information_section(
if not sections:
return ""
return USER_INFORMATION_HEADER + "\n".join(sections)
return USER_INFORMATION_HEADER + "".join(sections)
def build_system_prompt(
@@ -228,21 +224,23 @@ def build_system_prompt(
system_prompt += REQUIRE_CITATION_GUIDANCE
if include_all_guidance:
tool_sections = [
TOOL_DESCRIPTION_SEARCH_GUIDANCE,
INTERNAL_SEARCH_GUIDANCE,
WEB_SEARCH_GUIDANCE.format(
system_prompt += (
TOOL_SECTION_HEADER
+ TOOL_DESCRIPTION_SEARCH_GUIDANCE
+ INTERNAL_SEARCH_GUIDANCE
+ WEB_SEARCH_GUIDANCE.format(
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
),
OPEN_URLS_GUIDANCE,
PYTHON_TOOL_GUIDANCE,
GENERATE_IMAGE_GUIDANCE,
MEMORY_GUIDANCE,
]
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_sections)
)
+ OPEN_URLS_GUIDANCE
+ PYTHON_TOOL_GUIDANCE
+ GENERATE_IMAGE_GUIDANCE
+ MEMORY_GUIDANCE
)
return system_prompt
if tools:
system_prompt += TOOL_SECTION_HEADER
has_web_search = any(isinstance(tool, WebSearchTool) for tool in tools)
has_internal_search = any(isinstance(tool, SearchTool) for tool in tools)
has_open_urls = any(isinstance(tool, OpenURLTool) for tool in tools)
@@ -252,14 +250,12 @@ def build_system_prompt(
)
has_memory = any(isinstance(tool, MemoryTool) for tool in tools)
tool_guidance_sections: list[str] = []
if has_web_search or has_internal_search or include_all_guidance:
tool_guidance_sections.append(TOOL_DESCRIPTION_SEARCH_GUIDANCE)
system_prompt += TOOL_DESCRIPTION_SEARCH_GUIDANCE
# These are not included at the Tool level because the ordering may matter.
if has_internal_search or include_all_guidance:
tool_guidance_sections.append(INTERNAL_SEARCH_GUIDANCE)
system_prompt += INTERNAL_SEARCH_GUIDANCE
if has_web_search or include_all_guidance:
site_disabled_guidance = ""
@@ -269,23 +265,20 @@ def build_system_prompt(
)
if web_search_tool and not web_search_tool.supports_site_filter:
site_disabled_guidance = WEB_SEARCH_SITE_DISABLED_GUIDANCE
tool_guidance_sections.append(
WEB_SEARCH_GUIDANCE.format(site_colon_disabled=site_disabled_guidance)
system_prompt += WEB_SEARCH_GUIDANCE.format(
site_colon_disabled=site_disabled_guidance
)
if has_open_urls or include_all_guidance:
tool_guidance_sections.append(OPEN_URLS_GUIDANCE)
system_prompt += OPEN_URLS_GUIDANCE
if has_python or include_all_guidance:
tool_guidance_sections.append(PYTHON_TOOL_GUIDANCE)
system_prompt += PYTHON_TOOL_GUIDANCE
if has_generate_image or include_all_guidance:
tool_guidance_sections.append(GENERATE_IMAGE_GUIDANCE)
system_prompt += GENERATE_IMAGE_GUIDANCE
if has_memory or include_all_guidance:
tool_guidance_sections.append(MEMORY_GUIDANCE)
if tool_guidance_sections:
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_guidance_sections)
system_prompt += MEMORY_GUIDANCE
return system_prompt

View File

@@ -251,9 +251,7 @@ DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S = int(
os.environ.get("DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S") or 50
)
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
OPENSEARCH_ADMIN_PASSWORD = os.environ.get(
"OPENSEARCH_ADMIN_PASSWORD", "StrongPassword123!"
)
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
USING_AWS_MANAGED_OPENSEARCH = (
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
)
@@ -284,9 +282,6 @@ OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "englis
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
)
# NOTE: This effectively does nothing anymore, admins can now toggle whether
# retrieval is through OpenSearch. This value is only used as a final fallback
# in case that doesn't work for whatever reason.
# Given that the "base" config above is true, this enables whether we want to
# retrieve from OpenSearch or Vespa. We want to be able to quickly toggle this
# in the event we see issues with OpenSearch retrieval in our dev environments.
@@ -642,14 +637,6 @@ SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
)
# When True, group sync enumerates every Azure AD group in the tenant (expensive).
# When False (default), only groups found in site role assignments are synced.
# Can be overridden per-connector via the "exhaustive_ad_enumeration" key in
# connector_specific_config.
SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION = (
os.environ.get("SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION", "").lower() == "true"
)
BLOB_STORAGE_SIZE_THRESHOLD = int(
os.environ.get("BLOB_STORAGE_SIZE_THRESHOLD", 20 * 1024 * 1024)
)

View File

@@ -157,17 +157,6 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
@@ -454,9 +443,6 @@ class OnyxRedisLocks:
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a task is already queued.
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"

File diff suppressed because it is too large Load Diff

View File

@@ -50,15 +50,12 @@ class TeamsCheckpoint(ConnectorCheckpoint):
todo_team_ids: list[str] | None = None
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
class TeamsConnector(
CheckpointedConnectorWithPermSync[TeamsCheckpoint],
SlimConnectorWithPermSync,
):
MAX_WORKERS = 10
AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"
def __init__(
self,
@@ -66,15 +63,11 @@ class TeamsConnector(
# are not necessarily guaranteed to be unique
teams: list[str] = [],
max_workers: int = MAX_WORKERS,
authority_host: str = DEFAULT_AUTHORITY_HOST,
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
) -> None:
self.graph_client: GraphClient | None = None
self.msal_app: msal.ConfidentialClientApplication | None = None
self.max_workers = max_workers
self.requested_team_list: list[str] = teams
self.authority_host = authority_host.rstrip("/")
self.graph_api_host = graph_api_host.rstrip("/")
# impls for BaseConnector
@@ -83,7 +76,7 @@ class TeamsConnector(
teams_client_secret = credentials["teams_client_secret"]
teams_directory_id = credentials["teams_directory_id"]
authority_url = f"{self.authority_host}/{teams_directory_id}"
authority_url = f"{TeamsConnector.AUTHORITY_URL_PREFIX}{teams_directory_id}"
self.msal_app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=teams_client_id,
@@ -98,7 +91,7 @@ class TeamsConnector(
raise RuntimeError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
scopes=[f"{self.graph_api_host}/.default"]
scopes=["https://graph.microsoft.com/.default"]
)
if not isinstance(token, dict):

View File

@@ -32,7 +32,6 @@ from onyx.context.search.federated.slack_search_utils import should_include_mess
from onyx.context.search.models import ChunkIndexRequest
from onyx.context.search.models import InferenceChunk
from onyx.db.document import DocumentSource
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.document_index_utils import (
get_multipass_config,
@@ -906,15 +905,13 @@ def convert_slack_score(slack_score: float) -> float:
def slack_retrieval(
query: ChunkIndexRequest,
access_token: str,
db_session: Session | None = None,
db_session: Session,
connector: FederatedConnectorDetail | None = None, # noqa: ARG001
entities: dict[str, Any] | None = None,
limit: int | None = None,
slack_event_context: SlackContext | None = None,
bot_token: str | None = None, # Add bot token parameter
team_id: str | None = None,
# Pre-fetched data — when provided, avoids DB query (no session needed)
search_settings: SearchSettings | None = None,
) -> list[InferenceChunk]:
"""
Main entry point for Slack federated search with entity filtering.
@@ -928,7 +925,7 @@ def slack_retrieval(
Args:
query: Search query object
access_token: User OAuth access token
db_session: Database session (optional if search_settings provided)
db_session: Database session
connector: Federated connector detail (unused, kept for backwards compat)
entities: Connector-level config (entity filtering configuration)
limit: Maximum number of results
@@ -1156,10 +1153,7 @@ def slack_retrieval(
# chunk index docs into doc aware chunks
# a single index doc can get split into multiple chunks
if search_settings is None:
if db_session is None:
raise ValueError("Either db_session or search_settings must be provided")
search_settings = get_current_search_settings(db_session)
search_settings = get_current_search_settings(db_session)
embedder = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
)

View File

@@ -18,10 +18,8 @@ from onyx.context.search.utils import inference_section_from_chunks
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentIndex
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.llm.interfaces import LLM
from onyx.natural_language_processing.english_stopwords import strip_stopwords
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.secondary_llm_flows.source_filter import extract_source_filter
from onyx.secondary_llm_flows.time_filter import extract_time_filter
from onyx.utils.logger import setup_logger
@@ -43,7 +41,7 @@ def _build_index_filters(
user_file_ids: list[UUID] | None,
persona_document_sets: list[str] | None,
persona_time_cutoff: datetime | None,
db_session: Session | None = None,
db_session: Session,
auto_detect_filters: bool = False,
query: str | None = None,
llm: LLM | None = None,
@@ -51,8 +49,6 @@ def _build_index_filters(
# Assistant knowledge filters
attached_document_ids: list[str] | None = None,
hierarchy_node_ids: list[int] | None = None,
# Pre-fetched ACL filters (skips DB query when provided)
acl_filters: list[str] | None = None,
) -> IndexFilters:
if auto_detect_filters and (llm is None or query is None):
raise RuntimeError("LLM and query are required for auto detect filters")
@@ -107,14 +103,9 @@ def _build_index_filters(
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
logger.debug("Added USER_FILE to source_filter for user knowledge search")
if bypass_acl:
user_acl_filters = None
elif acl_filters is not None:
user_acl_filters = acl_filters
else:
if db_session is None:
raise ValueError("Either db_session or acl_filters must be provided")
user_acl_filters = build_access_filters_for_user(user, db_session)
user_acl_filters = (
None if bypass_acl else build_access_filters_for_user(user, db_session)
)
final_filters = IndexFilters(
user_file_ids=user_file_ids,
@@ -261,15 +252,11 @@ def search_pipeline(
user: User,
# Used for default filters and settings
persona: Persona | None,
db_session: Session | None = None,
db_session: Session,
auto_detect_filters: bool = False,
llm: LLM | None = None,
# If a project ID is provided, it will be exclusively scoped to that project
project_id: int | None = None,
# Pre-fetched data — when provided, avoids DB queries (no session needed)
acl_filters: list[str] | None = None,
embedding_model: EmbeddingModel | None = None,
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
) -> list[InferenceChunk]:
user_uploaded_persona_files: list[UUID] | None = (
[user_file.id for user_file in persona.user_files] if persona else None
@@ -310,7 +297,6 @@ def search_pipeline(
bypass_acl=chunk_search_request.bypass_acl,
attached_document_ids=attached_document_ids,
hierarchy_node_ids=hierarchy_node_ids,
acl_filters=acl_filters,
)
query_keywords = strip_stopwords(chunk_search_request.query)
@@ -329,8 +315,6 @@ def search_pipeline(
user_id=user.id if user else None,
document_index=document_index,
db_session=db_session,
embedding_model=embedding_model,
prefetched_federated_retrieval_infos=prefetched_federated_retrieval_infos,
)
# For some specific connectors like Salesforce, a user that has access to an object doesn't mean

View File

@@ -14,11 +14,9 @@ from onyx.context.search.utils import get_query_embedding
from onyx.context.search.utils import inference_section_from_chunks
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.federated_connectors.federated_retrieval import (
get_federated_retrieval_functions,
)
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -52,14 +50,9 @@ def combine_retrieval_results(
def _embed_and_search(
query_request: ChunkIndexRequest,
document_index: DocumentIndex,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
db_session: Session,
) -> list[InferenceChunk]:
query_embedding = get_query_embedding(
query_request.query,
db_session=db_session,
embedding_model=embedding_model,
)
query_embedding = get_query_embedding(query_request.query, db_session)
hybrid_alpha = query_request.hybrid_alpha or HYBRID_ALPHA
@@ -85,9 +78,7 @@ def search_chunks(
query_request: ChunkIndexRequest,
user_id: UUID | None,
document_index: DocumentIndex,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
db_session: Session,
) -> list[InferenceChunk]:
run_queries: list[tuple[Callable, tuple]] = []
@@ -97,22 +88,14 @@ def search_chunks(
else None
)
# Federated retrieval — use pre-fetched if available, otherwise query DB
if prefetched_federated_retrieval_infos is not None:
federated_retrieval_infos = prefetched_federated_retrieval_infos
else:
if db_session is None:
raise ValueError(
"Either db_session or prefetched_federated_retrieval_infos "
"must be provided"
)
federated_retrieval_infos = get_federated_retrieval_functions(
db_session=db_session,
user_id=user_id,
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
user_file_ids=query_request.filters.user_file_ids,
)
# Federated retrieval
federated_retrieval_infos = get_federated_retrieval_functions(
db_session=db_session,
user_id=user_id,
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
user_file_ids=query_request.filters.user_file_ids,
)
federated_sources = set(
federated_retrieval_info.source.to_non_federated_source()
@@ -131,10 +114,7 @@ def search_chunks(
if normal_search_enabled:
run_queries.append(
(
_embed_and_search,
(query_request, document_index, db_session, embedding_model),
)
(_embed_and_search, (query_request, document_index, db_session))
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)

View File

@@ -64,34 +64,23 @@ def inference_section_from_single_chunk(
)
def get_query_embeddings(
queries: list[str],
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
) -> list[Embedding]:
if embedding_model is None:
if db_session is None:
raise ValueError("Either db_session or embedding_model must be provided")
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
search_settings = get_current_search_settings(db_session)
query_embedding = embedding_model.encode(queries, text_type=EmbedTextType.QUERY)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
return query_embedding
@log_function_time(print_only=True, debug_only=True)
def get_query_embedding(
query: str,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
) -> Embedding:
return get_query_embeddings(
[query], db_session=db_session, embedding_model=embedding_model
)[0]
def get_query_embedding(query: str, db_session: Session) -> Embedding:
return get_query_embeddings([query], db_session)[0]
def convert_inference_sections_to_search_docs(

View File

@@ -4,7 +4,6 @@ from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.api_key import ApiKeyDescriptor
@@ -55,7 +54,6 @@ async def fetch_user_for_api_key(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
.options(selectinload(User.memories))
)

View File

@@ -13,7 +13,6 @@ from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
@@ -98,11 +97,6 @@ async def get_user_count(only_admin_users: bool = False) -> int:
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
async def _get_user(self, statement: Select) -> UP | None:
statement = statement.options(selectinload(User.memories))
results = await self.session.execute(statement)
return results.unique().scalar_one_or_none()
async def create(
self,
create_dict: Dict[str, Any],

View File

@@ -116,15 +116,12 @@ def get_connector_credential_pairs_for_user(
order_by_desc: bool = False,
source: DocumentSource | None = None,
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
defer_connector_config: bool = False,
) -> list[ConnectorCredentialPair]:
"""Get connector credential pairs for a user.
Args:
processing_mode: Filter by processing mode. Defaults to REGULAR to hide
FILE_SYSTEM connectors from standard admin UI. Pass None to get all.
defer_connector_config: If True, skips loading Connector.connector_specific_config
to avoid fetching large JSONB blobs when they aren't needed.
"""
if eager_load_user:
assert (
@@ -133,10 +130,7 @@ def get_connector_credential_pairs_for_user(
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
connector_load = selectinload(ConnectorCredentialPair.connector)
if defer_connector_config:
connector_load = connector_load.defer(Connector.connector_specific_config)
stmt = stmt.options(connector_load)
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
if eager_load_credential:
load_opts = selectinload(ConnectorCredentialPair.credential)
@@ -176,7 +170,6 @@ def get_connector_credential_pairs_for_user_parallel(
order_by_desc: bool = False,
source: DocumentSource | None = None,
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
defer_connector_config: bool = False,
) -> list[ConnectorCredentialPair]:
with get_session_with_current_tenant() as db_session:
return get_connector_credential_pairs_for_user(
@@ -190,7 +183,6 @@ def get_connector_credential_pairs_for_user_parallel(
order_by_desc=order_by_desc,
source=source,
processing_mode=processing_mode,
defer_connector_config=defer_connector_config,
)

View File

@@ -554,19 +554,10 @@ def fetch_all_document_sets_for_user(
stmt = (
select(DocumentSetDBModel)
.distinct()
.options(
selectinload(DocumentSetDBModel.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSetDBModel.users),
selectinload(DocumentSetDBModel.groups),
selectinload(DocumentSetDBModel.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
)
.options(selectinload(DocumentSetDBModel.federated_connectors))
)
stmt = _add_user_filters(stmt, user, get_editable=get_editable)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def fetch_documents_for_document_set_paginated(

View File

@@ -285,9 +285,15 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# organized in typical structured fashion
# formatted as `displayName__provider__modelName`
# Voice preferences
voice_auto_send: Mapped[bool] = mapped_column(Boolean, default=False)
voice_auto_playback: Mapped[bool] = mapped_column(Boolean, default=False)
voice_playback_speed: Mapped[float] = mapped_column(Float, default=1.0)
preferred_voice: Mapped[str | None] = mapped_column(String, nullable=True)
# relationships
credentials: Mapped[list["Credential"]] = relationship(
"Credential", back_populates="user"
"Credential", back_populates="user", lazy="joined"
)
chat_sessions: Mapped[list["ChatSession"]] = relationship(
"ChatSession", back_populates="user"
@@ -321,6 +327,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
"Memory",
back_populates="user",
cascade="all, delete-orphan",
lazy="selectin",
order_by="desc(Memory.id)",
)
oauth_user_tokens: Mapped[list["OAuthUserToken"]] = relationship(
@@ -2960,6 +2967,47 @@ class ImageGenerationConfig(Base):
)
class VoiceProvider(Base):
"""Configuration for voice services (STT and TTS)."""
__tablename__ = "voice_provider"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
provider_type: Mapped[str] = mapped_column(
String
) # "openai", "azure", "elevenlabs"
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=True
)
api_base: Mapped[str | None] = mapped_column(String, nullable=True)
custom_config: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# Model/voice configuration
stt_model: Mapped[str | None] = mapped_column(
String, nullable=True
) # e.g., "whisper-1"
tts_model: Mapped[str | None] = mapped_column(
String, nullable=True
) # e.g., "tts-1", "tts-1-hd"
default_voice: Mapped[str | None] = mapped_column(
String, nullable=True
) # e.g., "alloy", "echo"
# STT and TTS can use different providers - only one provider per type
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
class CloudEmbeddingProvider(Base):
__tablename__ = "embedding_provider"
@@ -4939,7 +4987,6 @@ class ScimUserMapping(Base):
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
)
scim_username: Mapped[str | None] = mapped_column(String, nullable=True)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
@@ -4978,12 +5025,3 @@ class ScimGroupMapping(Base):
user_group: Mapped[UserGroup] = relationship(
"UserGroup", foreign_keys=[user_group_id]
)
class CodeInterpreterServer(Base):
"""Details about the code interpreter server"""
__tablename__ = "code_interpreter_server"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

View File

@@ -8,7 +8,6 @@ from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.pat import build_displayable_pat
@@ -32,59 +31,53 @@ async def fetch_user_for_pat(
NOTE: This is async since it's used during auth (which is necessarily async due to FastAPI Users).
NOTE: Expired includes both naturally expired and user-revoked tokens (revocation sets expires_at=NOW()).
Uses select(User) as primary entity so that joined-eager relationships (e.g. oauth_accounts)
are loaded correctly — matching the pattern in fetch_user_for_api_key.
"""
# Single joined query with all filters pushed to database
now = datetime.now(timezone.utc)
user = await async_db_session.scalar(
select(User)
.join(PersonalAccessToken, PersonalAccessToken.user_id == User.id)
result = await async_db_session.execute(
select(PersonalAccessToken, User)
.join(User, PersonalAccessToken.user_id == User.id)
.where(PersonalAccessToken.hashed_token == hashed_token)
.where(User.is_active) # type: ignore
.where(
(PersonalAccessToken.expires_at.is_(None))
| (PersonalAccessToken.expires_at > now)
)
.options(selectinload(User.memories))
.limit(1)
)
if not user:
row = result.first()
if not row:
return None
_schedule_pat_last_used_update(hashed_token, now)
return user
pat, user = row
# Throttle last_used_at updates to reduce DB load (5-minute granularity sufficient for auditing)
# For request-level auditing, use application logs or a dedicated audit table
should_update = (
pat.last_used_at is None or (now - pat.last_used_at).total_seconds() > 300
)
def _schedule_pat_last_used_update(hashed_token: str, now: datetime) -> None:
"""Fire-and-forget update of last_used_at, throttled to 5-minute granularity."""
async def _update() -> None:
try:
tenant_id = get_current_tenant_id()
async with get_async_session_context_manager(tenant_id) as session:
pat = await session.scalar(
select(PersonalAccessToken).where(
PersonalAccessToken.hashed_token == hashed_token
if should_update:
# Update in separate session to avoid transaction coupling (fire-and-forget)
async def _update_last_used() -> None:
try:
tenant_id = get_current_tenant_id()
async with get_async_session_context_manager(
tenant_id
) as separate_session:
await separate_session.execute(
update(PersonalAccessToken)
.where(PersonalAccessToken.hashed_token == hashed_token)
.values(last_used_at=now)
)
)
if not pat:
return
if (
pat.last_used_at is not None
and (now - pat.last_used_at).total_seconds() <= 300
):
return
await session.execute(
update(PersonalAccessToken)
.where(PersonalAccessToken.hashed_token == hashed_token)
.values(last_used_at=now)
)
await session.commit()
except Exception as e:
logger.warning(f"Failed to update last_used_at for PAT: {e}")
await separate_session.commit()
except Exception as e:
logger.warning(f"Failed to update last_used_at for PAT: {e}")
asyncio.create_task(_update())
asyncio.create_task(_update_last_used())
return user
def create_pat(

View File

@@ -28,7 +28,6 @@ from onyx.db.document_access import get_accessible_documents_by_ids
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import HierarchyNode
from onyx.db.models import Persona
from onyx.db.models import Persona__User
@@ -421,16 +420,9 @@ def get_minimal_persona_snapshots_for_user(
stmt = stmt.options(
selectinload(Persona.tools),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.connector),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
@@ -461,16 +453,7 @@ def get_persona_snapshots_for_user(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
@@ -567,16 +550,9 @@ def get_minimal_persona_snapshots_paginated(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.connector),
selectinload(Persona.user),
)
@@ -635,16 +611,7 @@ def get_persona_snapshots_paginated(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),

251
backend/onyx/db/voice.py Normal file
View File

@@ -0,0 +1,251 @@
from typing import Any
from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.db.models import User
from onyx.db.models import VoiceProvider
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
    """Return every configured voice provider, ordered alphabetically by name."""
    stmt = select(VoiceProvider).order_by(VoiceProvider.name)
    providers = db_session.scalars(stmt).all()
    return list(providers)
def fetch_voice_provider_by_id(
    db_session: Session, provider_id: int
) -> VoiceProvider | None:
    """Look up a voice provider by primary key; return None when absent."""
    stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
    return db_session.scalars(stmt).first()
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
    """Return the provider currently flagged as the default for speech-to-text."""
    stmt = select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
    return db_session.scalars(stmt).first()
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
    """Return the provider currently flagged as the default for text-to-speech."""
    stmt = select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
    return db_session.scalars(stmt).first()
def fetch_voice_provider_by_type(
    db_session: Session, provider_type: str
) -> VoiceProvider | None:
    """Return a provider matching the given provider type, if one exists.

    NOTE(review): if multiple rows share the same type, an arbitrary one is
    returned — presumably types are unique in practice; confirm at callers.
    """
    stmt = select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
    return db_session.scalars(stmt).first()
def upsert_voice_provider(
    *,
    db_session: Session,
    provider_id: int | None,
    name: str,
    provider_type: str,
    api_key: str | None,
    api_key_changed: bool,
    api_base: str | None = None,
    custom_config: dict[str, Any] | None = None,
    stt_model: str | None = None,
    tts_model: str | None = None,
    default_voice: str | None = None,
    activate_stt: bool = False,
    activate_tts: bool = False,
) -> VoiceProvider:
    """Create a new voice provider or update an existing one in place.

    When ``provider_id`` is given, the matching row is updated; a
    ``ValueError`` is raised if it does not exist. Otherwise a fresh row is
    inserted. The stored API key is overwritten only when the caller marked
    it as changed or no key was stored yet, so existing secrets are never
    silently clobbered. Optionally promotes the provider to default STT
    and/or TTS afterwards.
    """
    if provider_id is None:
        provider = VoiceProvider()
        db_session.add(provider)
    else:
        existing = fetch_voice_provider_by_id(db_session, provider_id)
        if existing is None:
            raise ValueError(f"No voice provider with id {provider_id} exists.")
        provider = existing

    provider.name = name
    provider.provider_type = provider_type
    provider.api_base = api_base
    provider.custom_config = custom_config
    provider.stt_model = stt_model
    provider.tts_model = tts_model
    provider.default_voice = default_voice

    # Preserve an existing secret unless the caller explicitly replaced it.
    if api_key_changed or provider.api_key is None:
        provider.api_key = api_key  # type: ignore[assignment]

    # Flush so a newly inserted row has an id before default activation below.
    db_session.flush()

    if activate_stt:
        set_default_stt_provider(db_session=db_session, provider_id=provider.id)
    if activate_tts:
        set_default_tts_provider(db_session=db_session, provider_id=provider.id)

    db_session.refresh(provider)
    return provider
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
    """Delete the provider with the given id and commit; no-op when missing."""
    target = fetch_voice_provider_by_id(db_session, provider_id)
    if target is None:
        return
    db_session.delete(target)
    db_session.commit()
def set_default_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
    """Promote one provider to default STT, demoting any other holder.

    Raises:
        ValueError: if no provider with ``provider_id`` exists.
    """
    record = fetch_voice_provider_by_id(db_session, provider_id)
    if record is None:
        raise ValueError(f"No voice provider with id {provider_id} exists.")

    # Clear the flag on every other row so at most one default STT exists.
    demote = (
        update(VoiceProvider)
        .where(
            VoiceProvider.is_default_stt.is_(True),
            VoiceProvider.id != provider_id,
        )
        .values(is_default_stt=False)
    )
    db_session.execute(demote)

    record.is_default_stt = True
    db_session.flush()
    db_session.refresh(record)
    return record
def set_default_tts_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
    """Promote one provider to default TTS, demoting any other holder.

    Raises:
        ValueError: if no provider with ``provider_id`` exists.
    """
    record = fetch_voice_provider_by_id(db_session, provider_id)
    if record is None:
        raise ValueError(f"No voice provider with id {provider_id} exists.")

    # Clear the flag on every other row so at most one default TTS exists.
    demote = (
        update(VoiceProvider)
        .where(
            VoiceProvider.is_default_tts.is_(True),
            VoiceProvider.id != provider_id,
        )
        .values(is_default_tts=False)
    )
    db_session.execute(demote)

    record.is_default_tts = True
    db_session.flush()
    db_session.refresh(record)
    return record
def deactivate_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
    """Clear the default-STT flag on the given provider.

    Raises:
        ValueError: if no provider with ``provider_id`` exists.
    """
    record = fetch_voice_provider_by_id(db_session, provider_id)
    if record is None:
        raise ValueError(f"No voice provider with id {provider_id} exists.")
    record.is_default_stt = False
    db_session.flush()
    db_session.refresh(record)
    return record
def deactivate_tts_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
    """Clear the default-TTS flag on the given provider.

    Raises:
        ValueError: if no provider with ``provider_id`` exists.
    """
    record = fetch_voice_provider_by_id(db_session, provider_id)
    if record is None:
        raise ValueError(f"No voice provider with id {provider_id} exists.")
    record.is_default_tts = False
    db_session.flush()
    db_session.refresh(record)
    return record
# User voice preferences
def update_user_voice_auto_send(
    db_session: Session, user_id: UUID, auto_send: bool
) -> None:
    """Persist the user's voice auto-send preference and commit."""
    stmt = update(User).where(User.id == user_id).values(voice_auto_send=auto_send)
    db_session.execute(stmt)
    db_session.commit()
def update_user_voice_auto_playback(
    db_session: Session, user_id: UUID, auto_playback: bool
) -> None:
    """Persist the user's voice auto-playback preference and commit."""
    stmt = (
        update(User)
        .where(User.id == user_id)
        .values(voice_auto_playback=auto_playback)
    )
    db_session.execute(stmt)
    db_session.commit()
def update_user_voice_playback_speed(
    db_session: Session, user_id: UUID, speed: float
) -> None:
    """Persist the user's playback speed, clamped to the supported 0.5x-2.0x range."""
    clamped = max(0.5, min(2.0, speed))
    stmt = (
        update(User)
        .where(User.id == user_id)
        .values(voice_playback_speed=clamped)
    )
    db_session.execute(stmt)
    db_session.commit()
def update_user_preferred_voice(
    db_session: Session, user_id: UUID, voice: str | None
) -> None:
    """Persist the user's preferred voice; passing None clears the preference."""
    stmt = update(User).where(User.id == user_id).values(preferred_voice=voice)
    db_session.execute(stmt)
    db_session.commit()
def update_user_voice_settings(
    db_session: Session,
    user_id: UUID,
    auto_send: bool | None = None,
    auto_playback: bool | None = None,
    playback_speed: float | None = None,
    preferred_voice: str | None = None,
) -> None:
    """Batch-update the user's voice settings in a single statement.

    Arguments left as None are skipped entirely (which means the preferred
    voice cannot be cleared through this function — use
    ``update_user_preferred_voice`` for that). Playback speed is clamped to
    the supported 0.5x-2.0x range. Commits only when at least one field
    actually changed.
    """
    changes: dict[str, Any] = {}
    if auto_send is not None:
        changes["voice_auto_send"] = auto_send
    if auto_playback is not None:
        changes["voice_auto_playback"] = auto_playback
    if playback_speed is not None:
        changes["voice_playback_speed"] = max(0.5, min(2.0, playback_speed))
    if preferred_voice is not None:
        changes["preferred_voice"] = preferred_voice

    if not changes:
        return
    db_session.execute(update(User).where(User.id == user_id).values(**changes))
    db_session.commit()

View File

@@ -20,20 +20,7 @@ class ImageGenerationProviderCredentials(BaseModel):
custom_config: dict[str, str] | None = None
class ReferenceImage(BaseModel):
data: bytes
mime_type: str
class ImageGenerationProvider(abc.ABC):
@property
def supports_reference_images(self) -> bool:
return False
@property
def max_reference_images(self) -> int:
return 0
@classmethod
@abc.abstractmethod
def validate_credentials(
@@ -76,7 +63,6 @@ class ImageGenerationProvider(abc.ABC):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
"""Generates an image based on a prompt."""

View File

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -60,7 +59,6 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
**kwargs: Any,
) -> ImageGenerationResponse:
from litellm import image_generation

View File

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -46,7 +45,6 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
**kwargs: Any,
) -> ImageGenerationResponse:
from litellm import image_generation

View File

@@ -1,8 +1,6 @@
from __future__ import annotations
import base64
import json
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
@@ -11,7 +9,6 @@ from pydantic import BaseModel
from onyx.image_gen.exceptions import ImageProviderCredentialsError
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -54,15 +51,6 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
vertex_credentials=vertex_credentials,
)
@property
def supports_reference_images(self) -> bool:
return True
@property
def max_reference_images(self) -> int:
# Gemini image editing supports up to 14 input images.
return 14
def generate_image(
self,
prompt: str,
@@ -70,18 +58,8 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
if reference_images:
return self._generate_image_with_reference_images(
prompt=prompt,
model=model,
size=size,
n=n,
reference_images=reference_images,
)
from litellm import image_generation
return image_generation(
@@ -96,99 +74,6 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
**kwargs,
)
def _generate_image_with_reference_images(
self,
prompt: str,
model: str,
size: str,
n: int,
reference_images: list[ReferenceImage],
) -> ImageGenerationResponse:
from google import genai
from google.genai import types as genai_types
from google.oauth2 import service_account
from litellm.types.utils import ImageObject
from litellm.types.utils import ImageResponse
service_account_info = json.loads(self._vertex_credentials)
credentials = service_account.Credentials.from_service_account_info(
service_account_info,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
client = genai.Client(
vertexai=True,
project=self._vertex_project,
location=self._vertex_location,
credentials=credentials,
)
parts: list[genai_types.Part] = [
genai_types.Part.from_bytes(data=image.data, mime_type=image.mime_type)
for image in reference_images
]
parts.append(genai_types.Part.from_text(text=prompt))
config = genai_types.GenerateContentConfig(
response_modalities=["TEXT", "IMAGE"],
candidate_count=max(1, n),
image_config=genai_types.ImageConfig(
aspect_ratio=_map_size_to_aspect_ratio(size)
),
)
model_name = model.replace("vertex_ai/", "")
response = client.models.generate_content(
model=model_name,
contents=genai_types.Content(
role="user",
parts=parts,
),
config=config,
)
generated_data: list[ImageObject] = []
for candidate in response.candidates or []:
candidate_content = candidate.content
if not candidate_content:
continue
for part in candidate_content.parts or []:
inline_data = part.inline_data
if not inline_data or inline_data.data is None:
continue
if isinstance(inline_data.data, bytes):
b64_json = base64.b64encode(inline_data.data).decode("utf-8")
elif isinstance(inline_data.data, str):
b64_json = inline_data.data
else:
continue
generated_data.append(
ImageObject(
b64_json=b64_json,
revised_prompt=prompt,
)
)
if not generated_data:
raise RuntimeError("No image data returned from Vertex AI.")
return ImageResponse(
created=int(datetime.now().timestamp()),
data=generated_data,
)
def _map_size_to_aspect_ratio(size: str) -> str:
return {
"1024x1024": "1:1",
"1792x1024": "16:9",
"1024x1792": "9:16",
"1536x1024": "3:2",
"1024x1536": "2:3",
}.get(size, "1:1")
def _parse_to_vertex_credentials(
credentials: ImageGenerationProviderCredentials,

View File

@@ -64,6 +64,21 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"anthropic.claude-3-7-sonnet-20240620-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -144,6 +159,11 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"apac.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"apac.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -1300,6 +1320,11 @@
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"bedrock/us-gov-east-1/anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0": {
"display_name": "Claude Sonnet 4.5",
"model_vendor": "anthropic",
@@ -1340,6 +1365,16 @@
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"bedrock/us-gov-west-1/anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"bedrock/us-gov-west-1/anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0": {
"display_name": "Claude Sonnet 4.5",
"model_vendor": "anthropic",
@@ -1470,6 +1505,26 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"claude-3-7-sonnet-20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"claude-3-7-sonnet-latest": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "latest"
},
"claude-3-7-sonnet@20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"claude-3-haiku-20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"claude-4-opus-20250514": {
"display_name": "Claude Opus 4",
"model_vendor": "anthropic",
@@ -1650,6 +1705,16 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"eu.anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"eu.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"eu.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -3161,6 +3226,15 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3-haiku": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic"
},
"openrouter/anthropic/claude-3-haiku-20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"openrouter/anthropic/claude-3-sonnet": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic"
@@ -3175,6 +3249,16 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3.7-sonnet": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3.7-sonnet:beta": {
"display_name": "Claude Sonnet 3.7:beta",
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-haiku-4.5": {
"display_name": "Claude Haiku 4.5",
"model_vendor": "anthropic",
@@ -3666,6 +3750,16 @@
"model_vendor": "anthropic",
"model_version": "20241022"
},
"us.anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"us.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"us.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -3785,6 +3879,20 @@
"model_vendor": "anthropic",
"model_version": "20240620"
},
"vertex_ai/claude-3-7-sonnet@20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"vertex_ai/claude-3-haiku": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic"
},
"vertex_ai/claude-3-haiku@20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"vertex_ai/claude-3-sonnet": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic"

View File

@@ -1,7 +1,5 @@
import json
import pathlib
import threading
import time
from onyx.llm.constants import LlmProviderNames
from onyx.llm.constants import PROVIDER_DISPLAY_NAMES
@@ -25,11 +23,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_RECOMMENDATIONS_CACHE_TTL_SECONDS = 300
_recommendations_cache_lock = threading.Lock()
_cached_recommendations: LLMRecommendations | None = None
_cached_recommendations_time: float = 0.0
def _get_provider_to_models_map() -> dict[str, list[str]]:
"""Lazy-load provider model mappings to avoid importing litellm at module level.
@@ -48,40 +41,19 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
}
def _load_bundled_recommendations() -> LLMRecommendations:
def get_recommendations() -> LLMRecommendations:
"""Get the recommendations from the GitHub config."""
recommendations_from_github = fetch_llm_recommendations_from_github()
if recommendations_from_github:
return recommendations_from_github
# Fall back to json bundled with code
json_path = pathlib.Path(__file__).parent / "recommended-models.json"
with open(json_path, "r") as f:
json_config = json.load(f)
return LLMRecommendations.model_validate(json_config)
def get_recommendations() -> LLMRecommendations:
"""Get the recommendations, with an in-memory cache to avoid
hitting GitHub on every API request."""
global _cached_recommendations, _cached_recommendations_time
now = time.monotonic()
if (
_cached_recommendations is not None
and (now - _cached_recommendations_time) < _RECOMMENDATIONS_CACHE_TTL_SECONDS
):
return _cached_recommendations
with _recommendations_cache_lock:
# Double-check after acquiring lock
if (
_cached_recommendations is not None
and (time.monotonic() - _cached_recommendations_time)
< _RECOMMENDATIONS_CACHE_TTL_SECONDS
):
return _cached_recommendations
recommendations_from_github = fetch_llm_recommendations_from_github()
result = recommendations_from_github or _load_bundled_recommendations()
_cached_recommendations = result
_cached_recommendations_time = time.monotonic()
return result
recommendations_from_json = LLMRecommendations.model_validate(json_config)
return recommendations_from_json
def is_obsolete_model(model_name: str, provider: str) -> bool:

View File

@@ -112,6 +112,8 @@ from onyx.server.manage.opensearch_migration.api import (
from onyx.server.manage.search_settings import router as search_settings_router
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router
from onyx.server.manage.voice.api import admin_router as voice_admin_router
from onyx.server.manage.voice.user_api import router as voice_router
from onyx.server.manage.web_search.api import (
admin_router as web_search_admin_router,
)
@@ -428,6 +430,8 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, embedding_router)
include_router_with_global_prefix_prepended(application, web_search_router)
include_router_with_global_prefix_prepended(application, web_search_admin_router)
include_router_with_global_prefix_prepended(application, voice_admin_router)
include_router_with_global_prefix_prepended(application, voice_router)
include_router_with_global_prefix_prepended(
application, opensearch_migration_admin_router
)

View File

@@ -69,12 +69,6 @@ Very briefly describe the image(s) generated. Do not include any links or attach
""".strip()
FILE_REMINDER = """
Your code execution generated file(s) with download links.
If you reference or share these files, use the exact markdown format [filename](file_link) with the file_link from the execution result.
""".strip()
# Specifically for OpenAI models, this prefix needs to be in place for the model to output markdown and correct styling
CODE_BLOCK_MARKDOWN = "Formatting re-enabled. "

View File

@@ -1,6 +1,6 @@
# ruff: noqa: E501, W605 start
# If there are any tools, this section is included, the sections below are for the available tools
TOOL_SECTION_HEADER = "\n# Tools\n\n"
TOOL_SECTION_HEADER = "\n\n# Tools\n"
# This section is included if there are search type tools, currently internal_search and web_search
@@ -16,10 +16,11 @@ When searching for information, if the initial results cannot fully answer the u
Do not repeat the same or very similar queries if it already has been run in the chat history.
If it is unclear which tool to use, consider using multiple in parallel to be efficient with time.
""".lstrip()
"""
INTERNAL_SEARCH_GUIDANCE = """
## internal_search
Use the `internal_search` tool to search connected applications for information. Some examples of when to use `internal_search` include:
- Internal information: any time where there may be some information stored in internal applications that could help better answer the query.
@@ -27,31 +28,34 @@ Use the `internal_search` tool to search connected applications for information.
- Keyword Queries: queries that are heavily keyword based are often internal document search queries.
- Ambiguity: questions about something that is not widely known or understood.
Never provide more than 3 queries at once to `internal_search`.
""".lstrip()
"""
WEB_SEARCH_GUIDANCE = """
## web_search
Use the `web_search` tool to access up-to-date information from the web. Some examples of when to use `web_search` include:
- Freshness: when the answer might be enhanced by up-to-date information on a topic. Very important for topics that are changing or evolving.
- Accuracy: if the cost of outdated/inaccurate information is high.
- Niche Information: when detailed info is not widely known or understood (but is likely found on the internet).{site_colon_disabled}
""".lstrip()
"""
WEB_SEARCH_SITE_DISABLED_GUIDANCE = """
Do not use the "site:" operator in your web search queries.
""".lstrip()
""".rstrip()
OPEN_URLS_GUIDANCE = """
## open_url
Use the `open_url` tool to read the content of one or more URLs. Use this tool to access the contents of the most promising web pages from your web searches or user specified URLs. \
You can open many URLs at once by passing multiple URLs in the array if multiple pages seem promising. Prioritize the most promising pages and reputable sources. \
Do not open URLs that are image files like .png, .jpg, etc.
You should almost always use open_url after a web_search call. Use this tool when a user asks about a specific provided URL.
""".lstrip()
"""
PYTHON_TOOL_GUIDANCE = """
## python
Use the `python` tool to execute Python code in an isolated sandbox. The tool will respond with the output of the execution or time out after 60.0 seconds.
Any files uploaded to the chat will be automatically be available in the execution environment's current directory. \
@@ -60,21 +64,21 @@ Use this to give the user a way to download the file OR to display generated ima
Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
Use `openpyxl` to read and write Excel files. You have access to libraries like numpy, pandas, scipy, matplotlib, and PIL.
IMPORTANT: each call to this tool is independent. Variables from previous calls will NOT be available in the current call.
""".lstrip()
"""
GENERATE_IMAGE_GUIDANCE = """
## generate_image
NEVER use generate_image unless the user specifically requests an image.
For edits/variations of a previously generated image, pass `reference_image_file_ids` with
the `file_id` values returned by earlier `generate_image` tool results.
""".lstrip()
"""
MEMORY_GUIDANCE = """
## add_memory
Use the `add_memory` tool for facts shared by the user that should be remembered for future conversations. \
Only add memories that are specific, likely to remain true, and likely to be useful later. \
Focus on enduring preferences, long-term goals, stable constraints, and explicit "remember this" type requests.
""".lstrip()
"""
TOOL_CALL_FAILURE_PROMPT = """
LLM attempted to call a tool but failed. Most likely the tool name or arguments were misspelled.

View File

@@ -1,36 +1,40 @@
# ruff: noqa: E501, W605 start
USER_INFORMATION_HEADER = "\n# User Information\n\n"
USER_INFORMATION_HEADER = "\n\n# User Information\n"
BASIC_INFORMATION_PROMPT = """
## Basic Information
User name: {user_name}
User email: {user_email}{user_role}
""".lstrip()
"""
# This line only shows up if the user has configured their role.
USER_ROLE_PROMPT = """
User role: {user_role}
""".lstrip()
"""
# Team information should be a paragraph style description of the user's team.
TEAM_INFORMATION_PROMPT = """
## Team Information
{team_information}
""".lstrip()
"""
# User preferences should be a paragraph style description of the user's preferences.
USER_PREFERENCES_PROMPT = """
## User Preferences
{user_preferences}
""".lstrip()
"""
# User memories should look something like:
# - Memory 1
# - Memory 2
# - Memory 3
USER_MEMORIES_PROMPT = """
## User Memories
{user_memories}
""".lstrip()
"""
# ruff: noqa: E501, W605 end

View File

@@ -109,7 +109,6 @@ class TenantRedis(redis.Redis):
"unlock",
"get",
"set",
"setex",
"delete",
"exists",
"incrby",

View File

@@ -103,7 +103,6 @@ from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingMode
from onyx.db.enums import ProcessingMode
from onyx.db.federated import fetch_all_federated_connectors_parallel
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempts_by_status
@@ -988,7 +987,6 @@ def get_connector_status(
user=user,
eager_load_connector=True,
eager_load_credential=True,
eager_load_user=True,
get_editable=False,
)
@@ -1002,23 +1000,11 @@ def get_connector_status(
relationship.user_group_id
)
# Pre-compute credential_ids per connector to avoid N+1 lazy loads
connector_to_credential_ids: dict[int, list[int]] = {}
for cc_pair in cc_pairs:
connector_to_credential_ids.setdefault(cc_pair.connector_id, []).append(
cc_pair.credential_id
)
return [
ConnectorStatus(
cc_pair_id=cc_pair.id,
name=cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair.connector,
credential_ids=connector_to_credential_ids.get(
cc_pair.connector_id, []
),
),
connector=ConnectorSnapshot.from_connector_db_model(cc_pair.connector),
credential=CredentialSnapshot.from_credential_db_model(cc_pair.credential),
access_type=cc_pair.access_type,
groups=group_cc_pair_relationships_dict.get(cc_pair.id, []),
@@ -1073,27 +1059,15 @@ def get_connector_indexing_status(
parallel_functions: list[tuple[CallableProtocol, tuple[Any, ...]]] = [
# Get editable connector/credential pairs
(
lambda: get_connector_credential_pairs_for_user_parallel(
user, True, None, True, True, False, True, request.source
),
(),
get_connector_credential_pairs_for_user_parallel,
(user, True, None, True, True, True, True, request.source),
),
# Get federated connectors
(fetch_all_federated_connectors_parallel, ()),
# Get most recent index attempts
(
lambda: get_latest_index_attempts_parallel(
request.secondary_index, True, False
),
(),
),
(get_latest_index_attempts_parallel, (request.secondary_index, True, False)),
# Get most recent finished index attempts
(
lambda: get_latest_index_attempts_parallel(
request.secondary_index, True, True
),
(),
),
(get_latest_index_attempts_parallel, (request.secondary_index, True, True)),
]
if user and user.role == UserRole.ADMIN:
@@ -1110,10 +1084,8 @@ def get_connector_indexing_status(
parallel_functions.append(
# Get non-editable connector/credential pairs
(
lambda: get_connector_credential_pairs_for_user_parallel(
user, False, None, True, True, False, True, request.source
),
(),
get_connector_credential_pairs_for_user_parallel,
(user, False, None, True, True, True, True, request.source),
),
)
@@ -1939,7 +1911,6 @@ Tenant ID: {tenant_id}
class BasicCCPairInfo(BaseModel):
has_successful_run: bool
source: DocumentSource
status: ConnectorCredentialPairStatus
@router.get("/connector-status", tags=PUBLIC_API_TAGS)
@@ -1953,17 +1924,13 @@ def get_basic_connector_indexing_status(
get_editable=False,
user=user,
)
# NOTE: This endpoint excludes Craft connectors
return [
BasicCCPairInfo(
has_successful_run=cc_pair.last_successful_index_time is not None,
source=cc_pair.connector.source,
status=cc_pair.status,
)
for cc_pair in cc_pairs
if cc_pair.connector.source != DocumentSource.INGESTION_API
and cc_pair.processing_mode == ProcessingMode.REGULAR
]

View File

@@ -365,8 +365,7 @@ class CCPairFullInfo(BaseModel):
in_repeated_error_state=cc_pair_model.in_repeated_error_state,
num_docs_indexed=num_docs_indexed,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_model.connector,
credential_ids=[cc_pair_model.credential_id],
cc_pair_model.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_model.credential

View File

@@ -111,8 +111,7 @@ class DocumentSet(BaseModel):
id=cc_pair.id,
name=cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair.connector,
credential_ids=[cc_pair.credential_id],
cc_pair.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair.credential

View File

@@ -73,6 +73,12 @@ class UserPreferences(BaseModel):
chat_background: str | None = None
default_app_mode: DefaultAppMode = DefaultAppMode.CHAT
# Voice preferences
voice_auto_send: bool | None = None
voice_auto_playback: bool | None = None
voice_playback_speed: float | None = None
preferred_voice: str | None = None
# controls which tools are enabled for the user for a specific assistant
assistant_specific_configs: UserSpecificAssistantPreferences | None = None
@@ -152,6 +158,10 @@ class UserInfo(BaseModel):
theme_preference=user.theme_preference,
chat_background=user.chat_background,
default_app_mode=user.default_app_mode,
voice_auto_send=user.voice_auto_send,
voice_auto_playback=user.voice_auto_playback,
voice_playback_speed=user.voice_playback_speed,
preferred_voice=user.preferred_voice,
assistant_specific_configs=assistant_specific_configs,
)
),
@@ -228,6 +238,13 @@ class ChatBackgroundRequest(BaseModel):
chat_background: str | None
class VoiceSettingsUpdateRequest(BaseModel):
auto_send: bool | None = None
auto_playback: bool | None = None
playback_speed: float | None = Field(default=None, ge=0.5, le=2.0)
preferred_voice: str | None = None
class PersonalizationUpdateRequest(BaseModel):
name: str | None = None
role: str | None = None

View File

@@ -608,8 +608,7 @@ def list_all_users_basic_info(
return [
MinimalUserSnapshot(id=user.id, email=user.email)
for user in users
if user.role != UserRole.SLACK_USER
and (include_api_keys or not is_api_key_email_address(user.email))
if include_api_keys or not is_api_key_email_address(user.email)
]

View File

@@ -0,0 +1,232 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.voice import deactivate_stt_provider
from onyx.db.voice import deactivate_tts_provider
from onyx.db.voice import delete_voice_provider
from onyx.db.voice import fetch_voice_provider_by_id
from onyx.db.voice import fetch_voice_provider_by_type
from onyx.db.voice import fetch_voice_providers
from onyx.db.voice import set_default_stt_provider
from onyx.db.voice import set_default_tts_provider
from onyx.db.voice import upsert_voice_provider
from onyx.server.manage.voice.models import VoiceProviderTestRequest
from onyx.server.manage.voice.models import VoiceProviderUpsertRequest
from onyx.server.manage.voice.models import VoiceProviderView
from onyx.utils.logger import setup_logger
from onyx.voice.factory import get_voice_provider
logger = setup_logger()
admin_router = APIRouter(prefix="/admin/voice")
def _provider_to_view(provider) -> VoiceProviderView:
    """Convert a VoiceProvider model to a VoiceProviderView.

    `provider` is the VoiceProvider DB row (from onyx.db.voice fetchers) —
    not annotated here since the ORM model is not imported in this module.
    Only non-sensitive fields are exposed: the stored API key is reduced to
    a boolean `has_api_key` flag and never returned to the client.
    """
    return VoiceProviderView(
        id=provider.id,
        name=provider.name,
        provider_type=provider.provider_type,
        is_default_stt=provider.is_default_stt,
        is_default_tts=provider.is_default_tts,
        stt_model=provider.stt_model,
        tts_model=provider.tts_model,
        default_voice=provider.default_voice,
        # Expose only whether a key exists — never the key material itself.
        has_api_key=bool(provider.api_key),
    )
@admin_router.get("/providers")
def list_voice_providers(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[VoiceProviderView]:
"""List all configured voice providers."""
providers = fetch_voice_providers(db_session)
return [_provider_to_view(provider) for provider in providers]
@admin_router.post("/providers")
def upsert_voice_provider_endpoint(
request: VoiceProviderUpsertRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> VoiceProviderView:
"""Create or update a voice provider."""
provider = upsert_voice_provider(
db_session=db_session,
provider_id=request.id,
name=request.name,
provider_type=request.provider_type,
api_key=request.api_key,
api_key_changed=request.api_key_changed,
api_base=request.api_base,
custom_config=request.custom_config,
stt_model=request.stt_model,
tts_model=request.tts_model,
default_voice=request.default_voice,
activate_stt=request.activate_stt,
activate_tts=request.activate_tts,
)
db_session.commit()
return _provider_to_view(provider)
@admin_router.delete(
    "/providers/{provider_id}", status_code=204, response_class=Response
)
def delete_voice_provider_endpoint(
    provider_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> Response:
    """Delete a voice provider.

    Returns 204 No Content on success.
    """
    # NOTE(review): unlike the other mutating endpoints in this module, there
    # is no db_session.commit() here — presumably delete_voice_provider
    # commits internally; confirm. Behavior for an unknown provider_id
    # (raise vs. no-op) also depends on that helper.
    delete_voice_provider(db_session, provider_id)
    return Response(status_code=204)
@admin_router.post("/providers/{provider_id}/activate-stt")
def activate_stt_provider_endpoint(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> VoiceProviderView:
"""Set a voice provider as the default STT provider."""
provider = set_default_stt_provider(db_session=db_session, provider_id=provider_id)
db_session.commit()
return _provider_to_view(provider)
@admin_router.post("/providers/{provider_id}/deactivate-stt")
def deactivate_stt_provider_endpoint(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Remove the default STT status from a voice provider."""
deactivate_stt_provider(db_session=db_session, provider_id=provider_id)
db_session.commit()
return {"status": "ok"}
@admin_router.post("/providers/{provider_id}/activate-tts")
def activate_tts_provider_endpoint(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> VoiceProviderView:
"""Set a voice provider as the default TTS provider."""
provider = set_default_tts_provider(db_session=db_session, provider_id=provider_id)
db_session.commit()
return _provider_to_view(provider)
@admin_router.post("/providers/{provider_id}/deactivate-tts")
def deactivate_tts_provider_endpoint(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Remove the default TTS status from a voice provider."""
deactivate_tts_provider(db_session=db_session, provider_id=provider_id)
db_session.commit()
return {"status": "ok"}
@admin_router.post("/providers/test")
def test_voice_provider(
request: VoiceProviderTestRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Test a voice provider connection."""
api_key = request.api_key
if request.use_stored_key:
existing_provider = fetch_voice_provider_by_type(
db_session, request.provider_type
)
if existing_provider is None or not existing_provider.api_key:
raise HTTPException(
status_code=400,
detail="No stored API key found for this provider type.",
)
api_key = existing_provider.api_key.get_value(apply_mask=False)
if not api_key:
raise HTTPException(
status_code=400,
detail="API key is required. Either provide api_key or set use_stored_key to true.",
)
try:
provider = get_voice_provider(
provider_type=request.provider_type,
api_key=api_key,
api_base=request.api_base,
custom_config=request.custom_config or {},
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
if provider is None:
raise HTTPException(
status_code=400, detail="Unable to build provider configuration."
)
# Test the provider by getting available voices (lightweight check)
try:
voices = provider.get_available_voices()
if not voices:
raise HTTPException(
status_code=400,
detail="Provider returned no available voices.",
)
except NotImplementedError:
# Provider not fully implemented yet (Azure, ElevenLabs placeholders)
pass
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Connection test failed: {str(e)}",
) from e
logger.info(f"Voice provider test succeeded for {request.provider_type}.")
return {"status": "ok"}
@admin_router.get("/providers/{provider_id}/voices")
def get_provider_voices(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[dict[str, str]]:
"""Get available voices for a provider."""
provider_db = fetch_voice_provider_by_id(db_session, provider_id)
if provider_db is None:
raise HTTPException(status_code=404, detail="Voice provider not found.")
if not provider_db.api_key:
raise HTTPException(
status_code=400, detail="Provider has no API key configured."
)
try:
provider = get_voice_provider(
provider_type=provider_db.provider_type,
api_key=provider_db.api_key.get_value(apply_mask=False),
api_base=provider_db.api_base,
custom_config=provider_db.custom_config or {},
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return provider.get_available_voices()

View File

@@ -0,0 +1,74 @@
from typing import Any
from pydantic import BaseModel
from pydantic import Field
class VoiceProviderView(BaseModel):
    """Response model for voice provider listing.

    Deliberately omits the API key itself; ``has_api_key`` is the only
    signal that a key is stored for the provider.
    """

    id: int
    name: str
    provider_type: str  # "openai", "azure", "elevenlabs"
    # Whether this provider is currently the default for speech-to-text /
    # text-to-speech, respectively.
    is_default_stt: bool
    is_default_tts: bool
    stt_model: str | None
    tts_model: str | None
    default_voice: str | None
    has_api_key: bool = Field(
        default=False,
        description="Indicates whether an API key is stored for this provider.",
    )
class VoiceProviderUpsertRequest(BaseModel):
    """Request model for creating or updating a voice provider.

    When ``id`` is set an existing provider is updated; otherwise a new one
    is created.
    """

    id: int | None = Field(default=None, description="Existing provider ID to update.")
    name: str
    provider_type: str  # "openai", "azure", "elevenlabs"
    api_key: str | None = Field(
        default=None,
        description="API key for the provider.",
    )
    # NOTE(review): presumably this guards against overwriting the stored key
    # when the client echoes back a masked placeholder — confirm in
    # upsert_voice_provider.
    api_key_changed: bool = Field(
        default=False,
        description="Set to true when providing a new API key for an existing provider.",
    )
    api_base: str | None = None
    custom_config: dict[str, Any] | None = None
    stt_model: str | None = None
    tts_model: str | None = None
    default_voice: str | None = None
    activate_stt: bool = Field(
        default=False,
        description="If true, sets this provider as the default STT provider after upsert.",
    )
    activate_tts: bool = Field(
        default=False,
        description="If true, sets this provider as the default TTS provider after upsert.",
    )
class VoiceProviderTestRequest(BaseModel):
    """Request model for testing a voice provider connection.

    The key comes from exactly one source: either ``api_key`` is supplied
    directly, or ``use_stored_key`` is true and the server uses the key
    stored for ``provider_type``.
    """

    provider_type: str
    api_key: str | None = Field(
        default=None,
        description="API key for testing. If not provided, use_stored_key must be true.",
    )
    use_stored_key: bool = Field(
        default=False,
        description="If true, use the stored API key for this provider type.",
    )
    api_base: str | None = None
    custom_config: dict[str, Any] | None = None
class SynthesizeRequest(BaseModel):
    """Request model for text-to-speech synthesis.

    ``voice`` and ``speed`` are optional; when omitted the server falls back
    to the user's stored preferences (and finally the provider defaults).
    """

    text: str = Field(..., min_length=1, max_length=4096)
    voice: str | None = None
    # Default is None (not 1.0) so the server's fallback chain
    # (`request.speed or user.voice_playback_speed or 1.0`) can tell
    # "client did not send a speed" apart from an explicit 1.0x request and
    # actually apply the user's stored playback-speed preference. With a
    # default of 1.0 (and ge=0.5 forcing truthiness) that fallback was dead.
    speed: float | None = Field(default=None, ge=0.5, le=2.0)

View File

@@ -0,0 +1,160 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import HTTPException
from fastapi import UploadFile
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.voice import fetch_default_stt_provider
from onyx.db.voice import fetch_default_tts_provider
from onyx.db.voice import update_user_voice_settings
from onyx.server.manage.models import VoiceSettingsUpdateRequest
from onyx.server.manage.voice.models import SynthesizeRequest
from onyx.utils.logger import setup_logger
from onyx.voice.factory import get_voice_provider
logger = setup_logger()
router = APIRouter(prefix="/voice")
# Max audio file size: 25MB (Whisper limit)
MAX_AUDIO_SIZE = 25 * 1024 * 1024
@router.post("/transcribe")
async def transcribe_audio(
audio: UploadFile = File(...),
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Transcribe audio to text using the default STT provider."""
provider_db = fetch_default_stt_provider(db_session)
if provider_db is None:
raise HTTPException(
status_code=400,
detail="No speech-to-text provider configured. Please contact your administrator.",
)
if not provider_db.api_key:
raise HTTPException(
status_code=400,
detail="Voice provider API key not configured.",
)
audio_data = await audio.read()
if len(audio_data) > MAX_AUDIO_SIZE:
raise HTTPException(
status_code=400,
detail=f"Audio file too large. Maximum size is {MAX_AUDIO_SIZE // (1024 * 1024)}MB.",
)
# Extract format from filename
filename = audio.filename or "audio.webm"
audio_format = filename.rsplit(".", 1)[-1] if "." in filename else "webm"
try:
provider = get_voice_provider(
provider_type=provider_db.provider_type,
api_key=provider_db.api_key.get_value(apply_mask=False),
api_base=provider_db.api_base,
custom_config=provider_db.custom_config or {},
stt_model=provider_db.stt_model,
tts_model=provider_db.tts_model,
default_voice=provider_db.default_voice,
)
except ValueError as exc:
raise HTTPException(status_code=500, detail=str(exc)) from exc
try:
text = await provider.transcribe(audio_data, audio_format)
return {"text": text}
except NotImplementedError as exc:
raise HTTPException(
status_code=501,
detail=f"Speech-to-text not implemented for {provider_db.provider_type}.",
) from exc
except Exception as exc:
logger.error(f"Transcription failed: {exc}")
raise HTTPException(
status_code=500,
detail=f"Transcription failed: {str(exc)}",
) from exc
@router.post("/synthesize")
async def synthesize_speech(
request: SynthesizeRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse:
"""Synthesize text to speech using the default TTS provider."""
provider_db = fetch_default_tts_provider(db_session)
if provider_db is None:
raise HTTPException(
status_code=400,
detail="No text-to-speech provider configured. Please contact your administrator.",
)
if not provider_db.api_key:
raise HTTPException(
status_code=400,
detail="Voice provider API key not configured.",
)
# Use request voice, or user's preferred voice, or provider default
voice = request.voice or user.preferred_voice or provider_db.default_voice
speed = request.speed or user.voice_playback_speed or 1.0
try:
provider = get_voice_provider(
provider_type=provider_db.provider_type,
api_key=provider_db.api_key.get_value(apply_mask=False),
api_base=provider_db.api_base,
custom_config=provider_db.custom_config or {},
stt_model=provider_db.stt_model,
tts_model=provider_db.tts_model,
default_voice=provider_db.default_voice,
)
except ValueError as exc:
raise HTTPException(status_code=500, detail=str(exc)) from exc
async def audio_stream():
try:
async for chunk in provider.synthesize_stream(
text=request.text, voice=voice, speed=speed
):
yield chunk
except NotImplementedError as exc:
logger.error(f"TTS not implemented: {exc}")
raise
except Exception as exc:
logger.error(f"Synthesis failed: {exc}")
raise
return StreamingResponse(
audio_stream(),
media_type="audio/mpeg",
headers={"Content-Disposition": "inline; filename=speech.mp3"},
)
@router.patch("/settings")
def update_voice_settings(
request: VoiceSettingsUpdateRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Update user's voice settings."""
update_user_voice_settings(
db_session=db_session,
user_id=user.id,
auto_send=request.auto_send,
auto_playback=request.auto_playback,
playback_speed=request.playback_speed,
preferred_voice=request.preferred_voice,
)
return {"status": "ok"}

View File

@@ -36,8 +36,6 @@ from onyx.server.query_and_chat.streaming_models import OpenUrlStart
from onyx.server.query_and_chat.streaming_models import OpenUrlUrls
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
from onyx.server.query_and_chat.streaming_models import PythonToolStart
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import ResearchAgentStart
@@ -52,7 +50,6 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
)
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from onyx.utils.logger import setup_logger
@@ -380,37 +377,6 @@ def create_memory_packets(
return packets
def create_python_tool_packets(
code: str,
stdout: str,
stderr: str,
file_ids: list[str],
turn_index: int,
tab_index: int = 0,
) -> list[Packet]:
"""Recreate PythonToolStart + PythonToolDelta + SectionEnd from the stored
tool call data so the frontend can display both the code and its output
on page reload."""
packets: list[Packet] = []
placement = Placement(turn_index=turn_index, tab_index=tab_index)
packets.append(Packet(placement=placement, obj=PythonToolStart(code=code)))
packets.append(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout=stdout,
stderr=stderr,
file_ids=file_ids,
),
)
)
packets.append(Packet(placement=placement, obj=SectionEnd()))
return packets
def create_search_packets(
search_queries: list[str],
search_docs: list[SavedSearchDoc],
@@ -620,41 +586,6 @@ def translate_assistant_message_to_packets(
)
)
elif tool.in_code_tool_id == PythonTool.__name__:
code = cast(
str,
tool_call.tool_call_arguments.get("code", ""),
)
stdout = ""
stderr = ""
file_ids: list[str] = []
if tool_call.tool_call_response:
try:
response_data = json.loads(tool_call.tool_call_response)
stdout = response_data.get("stdout", "")
stderr = response_data.get("stderr", "")
generated_files = response_data.get(
"generated_files", []
)
file_ids = [
f.get("file_link", "").split("/")[-1]
for f in generated_files
if f.get("file_link")
]
except (json.JSONDecodeError, KeyError):
# Fall back to raw response as stdout
stdout = tool_call.tool_call_response
turn_tool_packets.extend(
create_python_tool_packets(
code=code,
stdout=stdout,
stderr=stderr,
file_ids=file_ids,
turn_index=turn_num,
tab_index=tool_call.tab_index,
)
)
else:
# Custom tool or unknown tool
turn_tool_packets.extend(

View File

@@ -24,7 +24,6 @@ from onyx.auth.users import get_user_manager
from onyx.auth.users import UserManager
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from onyx.configs.app_configs import SAML_CONF_DIR
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
@@ -124,12 +123,9 @@ async def prepare_from_fastapi_request(request: Request) -> dict[str, Any]:
if request.client is None:
raise ValueError("Invalid request for SAML")
# Derive http_host and server_port from WEB_DOMAIN (a trusted env var)
# instead of X-Forwarded-* headers, which can be spoofed by an attacker
# to poison SAML redirect URLs (host header poisoning).
parsed_domain = urlparse(WEB_DOMAIN)
http_host = parsed_domain.hostname or request.client.host
server_port = parsed_domain.port or (443 if parsed_domain.scheme == "https" else 80)
# Use X-Forwarded headers if available
http_host = request.headers.get("X-Forwarded-Host") or request.client.host
server_port = request.headers.get("X-Forwarded-Port") or request.url.port
rv: dict[str, Any] = {
"http_host": http_host,

View File

@@ -55,9 +55,7 @@ class Settings(BaseModel):
gpu_enabled: bool | None = None
application_status: ApplicationStatus = ApplicationStatus.ACTIVE
anonymous_user_enabled: bool | None = None
invite_only_enabled: bool = False
deep_research_enabled: bool | None = None
search_ui_enabled: bool | None = None
# Enterprise features flag - set by license enforcement at runtime
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status

View File

@@ -199,12 +199,6 @@ class PythonToolOverrideKwargs(BaseModel):
chat_files: list[ChatFile] = []
class ImageGenerationToolOverrideKwargs(BaseModel):
"""Override kwargs for image generation tool calls."""
recent_generated_image_file_ids: list[str] = []
class SearchToolRunContext(BaseModel):
emitter: Emitter

View File

@@ -171,8 +171,10 @@ def construct_tools(
if not search_tool_config:
search_tool_config = SearchToolConfig()
# TODO concerning passing the db_session here.
search_tool = SearchTool(
tool_id=db_tool_model.id,
db_session=db_session,
emitter=emitter,
user=user,
persona=persona,
@@ -420,6 +422,7 @@ def construct_tools(
search_tool = SearchTool(
tool_id=search_tool_db_model.id,
db_session=db_session,
emitter=emitter,
user=user,
persona=persona,

View File

@@ -11,14 +11,11 @@ from onyx.chat.emitter import Emitter
from onyx.configs.app_configs import IMAGE_MODEL_NAME
from onyx.configs.app_configs import IMAGE_MODEL_PROVIDER
from onyx.db.image_generation import get_default_image_generation_config
from onyx.file_store.models import ChatFileType
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import load_chat_file_by_id
from onyx.file_store.utils import save_files
from onyx.image_gen.factory import get_image_generation_provider
from onyx.image_gen.factory import validate_credentials
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
@@ -26,7 +23,6 @@ from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeart
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.interface import Tool
from onyx.tools.models import ImageGenerationToolOverrideKwargs
from onyx.tools.models import ToolCallException
from onyx.tools.models import ToolExecutionException
from onyx.tools.models import ToolResponse
@@ -35,7 +31,6 @@ from onyx.tools.tool_implementations.images.models import (
)
from onyx.tools.tool_implementations.images.models import ImageGenerationResponse
from onyx.tools.tool_implementations.images.models import ImageShape
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -45,10 +40,10 @@ logger = setup_logger()
HEARTBEAT_INTERVAL = 5.0
PROMPT_FIELD = "prompt"
REFERENCE_IMAGE_FILE_IDS_FIELD = "reference_image_file_ids"
class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
# override_kwargs is not supported for image generation tools
class ImageGenerationTool(Tool[None]):
NAME = "generate_image"
DESCRIPTION = "Generate an image based on a prompt. Do not use unless the user specifically requests an image."
DISPLAY_NAME = "Image Generation"
@@ -64,7 +59,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
) -> None:
super().__init__(emitter=emitter)
self.model = model
self.provider = provider
self.num_imgs = num_imgs
self.img_provider = get_image_generation_provider(
@@ -139,16 +133,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
),
"enum": [shape.value for shape in ImageShape],
},
REFERENCE_IMAGE_FILE_IDS_FIELD: {
"type": "array",
"description": (
"Optional image file IDs to use as reference context for edits/variations. "
"Use the file_id values returned by previous generate_image calls."
),
"items": {
"type": "string",
},
},
},
"required": [PROMPT_FIELD],
},
@@ -164,10 +148,7 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
)
def _generate_image(
self,
prompt: str,
shape: ImageShape,
reference_images: list[ReferenceImage] | None = None,
self, prompt: str, shape: ImageShape
) -> tuple[ImageGenerationResponse, Any]:
if shape == ImageShape.LANDSCAPE:
if "gpt-image-1" in self.model:
@@ -188,7 +169,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
model=self.model,
size=size,
n=1,
reference_images=reference_images,
# response_format parameter is not supported for gpt-image-1
response_format=None if "gpt-image-1" in self.model else "b64_json",
)
@@ -251,117 +231,10 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
emit_error_packet=True,
)
def _resolve_reference_image_file_ids(
self,
llm_kwargs: dict[str, Any],
override_kwargs: ImageGenerationToolOverrideKwargs | None,
) -> list[str]:
raw_reference_ids = llm_kwargs.get(REFERENCE_IMAGE_FILE_IDS_FIELD)
if raw_reference_ids is not None:
if not isinstance(raw_reference_ids, list) or not all(
isinstance(file_id, str) for file_id in raw_reference_ids
):
raise ToolCallException(
message=(
f"Invalid {REFERENCE_IMAGE_FILE_IDS_FIELD}: expected array of strings, "
f"got {type(raw_reference_ids)}"
),
llm_facing_message=(
f"The '{REFERENCE_IMAGE_FILE_IDS_FIELD}' field must be an array of file_id strings."
),
)
reference_image_file_ids = [
file_id.strip() for file_id in raw_reference_ids if file_id.strip()
]
elif (
override_kwargs
and override_kwargs.recent_generated_image_file_ids
and self.img_provider.supports_reference_images
):
# If no explicit reference was provided, default to the most recently generated image.
reference_image_file_ids = [
override_kwargs.recent_generated_image_file_ids[-1]
]
else:
reference_image_file_ids = []
# Deduplicate while preserving order.
deduped_reference_image_ids: list[str] = []
seen_ids: set[str] = set()
for file_id in reference_image_file_ids:
if file_id in seen_ids:
continue
seen_ids.add(file_id)
deduped_reference_image_ids.append(file_id)
if not deduped_reference_image_ids:
return []
if not self.img_provider.supports_reference_images:
raise ToolCallException(
message=(
f"Reference images requested but provider '{self.provider}' "
"does not support image-editing context."
),
llm_facing_message=(
"This image provider does not support editing from previous image context. "
"Try text-only generation, or switch to a provider/model that supports image edits."
),
)
max_reference_images = self.img_provider.max_reference_images
if max_reference_images > 0:
return deduped_reference_image_ids[-max_reference_images:]
return deduped_reference_image_ids
def _load_reference_images(
self,
reference_image_file_ids: list[str],
) -> list[ReferenceImage]:
reference_images: list[ReferenceImage] = []
for file_id in reference_image_file_ids:
try:
loaded_file = load_chat_file_by_id(file_id)
except Exception as e:
raise ToolCallException(
message=f"Could not load reference image file '{file_id}': {e}",
llm_facing_message=(
f"Reference image file '{file_id}' could not be loaded. "
"Use file_id values returned by previous generate_image calls."
),
)
if loaded_file.file_type != ChatFileType.IMAGE:
raise ToolCallException(
message=f"Reference file '{file_id}' is not an image",
llm_facing_message=f"Reference file '{file_id}' is not an image.",
)
try:
mime_type = get_image_type_from_bytes(loaded_file.content)
except Exception as e:
raise ToolCallException(
message=f"Unsupported reference image format for '{file_id}': {e}",
llm_facing_message=(
f"Reference image '{file_id}' has an unsupported format. "
"Only PNG, JPEG, GIF, and WEBP are supported."
),
)
reference_images.append(
ReferenceImage(
data=loaded_file.content,
mime_type=mime_type,
)
)
return reference_images
def run(
self,
placement: Placement,
override_kwargs: ImageGenerationToolOverrideKwargs | None = None,
override_kwargs: None = None, # noqa: ARG002
**llm_kwargs: Any,
) -> ToolResponse:
if PROMPT_FIELD not in llm_kwargs:
@@ -374,11 +247,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
)
prompt = cast(str, llm_kwargs[PROMPT_FIELD])
shape = ImageShape(llm_kwargs.get("shape", ImageShape.SQUARE.value))
reference_image_file_ids = self._resolve_reference_image_file_ids(
llm_kwargs=llm_kwargs,
override_kwargs=override_kwargs,
)
reference_images = self._load_reference_images(reference_image_file_ids)
# Use threading to generate images in parallel while emitting heartbeats
results: list[tuple[ImageGenerationResponse, Any] | None] = [
@@ -399,7 +267,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
(
prompt,
shape,
reference_images or None,
),
)
for _ in range(self.num_imgs)
@@ -480,7 +347,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
llm_facing_response = json.dumps(
[
{
"file_id": img.file_id,
"revised_prompt": img.revised_prompt,
}
for img in generated_images_metadata

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from onyx.file_processing.html_utils import ParsedHTML
from onyx.file_processing.html_utils import web_html_cleanup
@@ -22,22 +21,10 @@ from onyx.utils.web_content import title_from_url
logger = setup_logger()
DEFAULT_READ_TIMEOUT_SECONDS = 15
DEFAULT_CONNECT_TIMEOUT_SECONDS = 5
DEFAULT_TIMEOUT_SECONDS = 15
DEFAULT_USER_AGENT = "OnyxWebCrawler/1.0 (+https://www.onyx.app)"
DEFAULT_MAX_PDF_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB
DEFAULT_MAX_HTML_SIZE_BYTES = 20 * 1024 * 1024 # 20 MB
DEFAULT_MAX_WORKERS = 5
def _failed_result(url: str) -> WebContent:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
class OnyxWebCrawler(WebContentProvider):
@@ -50,14 +37,12 @@ class OnyxWebCrawler(WebContentProvider):
def __init__(
self,
*,
timeout_seconds: int = DEFAULT_READ_TIMEOUT_SECONDS,
connect_timeout_seconds: int = DEFAULT_CONNECT_TIMEOUT_SECONDS,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
user_agent: str = DEFAULT_USER_AGENT,
max_pdf_size_bytes: int | None = None,
max_html_size_bytes: int | None = None,
) -> None:
self._read_timeout_seconds = timeout_seconds
self._connect_timeout_seconds = connect_timeout_seconds
self._timeout_seconds = timeout_seconds
self._max_pdf_size_bytes = max_pdf_size_bytes
self._max_html_size_bytes = max_html_size_bytes
self._headers = {
@@ -66,68 +51,75 @@ class OnyxWebCrawler(WebContentProvider):
}
def contents(self, urls: Sequence[str]) -> list[WebContent]:
if not urls:
return []
max_workers = min(DEFAULT_MAX_WORKERS, len(urls))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
return list(executor.map(self._fetch_url_safe, urls))
def _fetch_url_safe(self, url: str) -> WebContent:
"""Wrapper that catches all exceptions so one bad URL doesn't kill the batch."""
try:
return self._fetch_url(url)
except Exception as exc:
logger.warning(
"Onyx crawler unexpected error for %s (%s)",
url,
exc.__class__.__name__,
)
return _failed_result(url)
results: list[WebContent] = []
for url in urls:
results.append(self._fetch_url(url))
return results
def _fetch_url(self, url: str) -> WebContent:
try:
# Use SSRF-safe request to prevent DNS rebinding attacks
response = ssrf_safe_get(
url,
headers=self._headers,
timeout=(self._connect_timeout_seconds, self._read_timeout_seconds),
url, headers=self._headers, timeout=self._timeout_seconds
)
except SSRFException as exc:
logger.error(
"SSRF protection blocked request to %s (%s)",
"SSRF protection blocked request to %s: %s",
url,
exc.__class__.__name__,
str(exc),
)
return _failed_result(url)
except Exception as exc:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
except Exception as exc: # pragma: no cover - network failures vary
logger.warning(
"Onyx crawler failed to fetch %s (%s)",
url,
exc.__class__.__name__,
)
return _failed_result(url)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
if response.status_code >= 400:
logger.warning("Onyx crawler received %s for %s", response.status_code, url)
return _failed_result(url)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
content_type = response.headers.get("Content-Type", "")
content = response.content
content_sniff = content[:1024] if content else None
content_sniff = response.content[:1024] if response.content else None
if is_pdf_resource(url, content_type, content_sniff):
if (
self._max_pdf_size_bytes is not None
and len(content) > self._max_pdf_size_bytes
and len(response.content) > self._max_pdf_size_bytes
):
logger.warning(
"PDF content too large (%d bytes) for %s, max is %d",
len(content),
len(response.content),
url,
self._max_pdf_size_bytes,
)
return _failed_result(url)
text_content, metadata = extract_pdf_text(content)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
text_content, metadata = extract_pdf_text(response.content)
title = title_from_pdf_metadata(metadata) or title_from_url(url)
return WebContent(
title=title,
@@ -139,19 +131,25 @@ class OnyxWebCrawler(WebContentProvider):
if (
self._max_html_size_bytes is not None
and len(content) > self._max_html_size_bytes
and len(response.content) > self._max_html_size_bytes
):
logger.warning(
"HTML content too large (%d bytes) for %s, max is %d",
len(content),
len(response.content),
url,
self._max_html_size_bytes,
)
return _failed_result(url)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
try:
decoded_html = decode_html_bytes(
content,
response.content,
content_type=content_type,
fallback_encoding=response.apparent_encoding or response.encoding,
)

File diff suppressed because it is too large Load Diff

View File

@@ -1,260 +0,0 @@
from __future__ import annotations
from typing import Any
import requests
from fastapi import HTTPException
from onyx.tools.tool_implementations.web_search.models import (
WebSearchProvider,
)
from onyx.tools.tool_implementations.web_search.models import WebSearchResult
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
BRAVE_WEB_SEARCH_URL = "https://api.search.brave.com/res/v1/web/search"
BRAVE_MAX_RESULTS_PER_REQUEST = 20
BRAVE_SAFESEARCH_OPTIONS = {"off", "moderate", "strict"}
BRAVE_FRESHNESS_OPTIONS = {"pd", "pw", "pm", "py"}
class RetryableBraveSearchError(Exception):
    """Raised for transient Brave search failures that should be retried."""
class BraveClient(WebSearchProvider):
    """Web search provider backed by the Brave Web Search API.

    Transient failures (network errors, HTTP 429, and 5xx responses) are
    retried for up to three attempts with exponential backoff before an
    error is surfaced to the caller as a ValueError.
    """

    def __init__(
        self,
        api_key: str,
        *,
        num_results: int = 10,
        timeout_seconds: int = 10,
        country: str | None = None,
        search_lang: str | None = None,
        ui_lang: str | None = None,
        safesearch: str | None = None,
        freshness: str | None = None,
    ) -> None:
        """Validate and normalize the provider configuration.

        Args:
            api_key: Brave subscription token sent on every request.
            num_results: Desired result count; clamped to [1, 20].
            timeout_seconds: Per-request timeout in seconds; must be positive.
            country: Optional 2-letter ISO country code.
            search_lang: Optional search language code (passed through).
            ui_lang: Optional UI language code (passed through).
            safesearch: One of "off", "moderate", "strict" (case-insensitive).
            freshness: One of "pd", "pw", "pm", "py" (case-insensitive).

        Raises:
            ValueError: if any configuration value fails validation.
        """
        if timeout_seconds <= 0:
            raise ValueError("Brave provider config 'timeout_seconds' must be > 0.")

        self._headers = {
            "Accept": "application/json",
            "X-Subscription-Token": api_key,
        }
        logger.debug(f"Count of results passed to BraveClient: {num_results}")
        # Brave caps a single request at 20 results; also floor at 1.
        self._num_results = max(1, min(num_results, BRAVE_MAX_RESULTS_PER_REQUEST))
        self._timeout_seconds = timeout_seconds
        self._country = _normalize_country(country)
        self._search_lang = _normalize_language_code(
            search_lang, field_name="search_lang"
        )
        self._ui_lang = _normalize_language_code(ui_lang, field_name="ui_lang")
        self._safesearch = _normalize_option(
            safesearch,
            field_name="safesearch",
            allowed_values=BRAVE_SAFESEARCH_OPTIONS,
        )
        self._freshness = _normalize_option(
            freshness,
            field_name="freshness",
            allowed_values=BRAVE_FRESHNESS_OPTIONS,
        )

    def _build_search_params(self, query: str) -> dict[str, str]:
        """Build the query-string parameters, omitting unset optional filters."""
        params = {
            "q": query,
            "count": str(self._num_results),
        }
        if self._country:
            params["country"] = self._country
        if self._search_lang:
            params["search_lang"] = self._search_lang
        if self._ui_lang:
            params["ui_lang"] = self._ui_lang
        if self._safesearch:
            params["safesearch"] = self._safesearch
        if self._freshness:
            params["freshness"] = self._freshness
        return params

    @retry_builder(
        tries=3,
        delay=1,
        backoff=2,
        exceptions=(RetryableBraveSearchError,),
    )
    def _search_with_retries(self, query: str) -> list[WebSearchResult]:
        """Issue one search request; transient failures raise
        RetryableBraveSearchError so the retry decorator re-invokes us.

        Returns:
            Parsed web results; entries without a usable URL are dropped.

        Raises:
            RetryableBraveSearchError: network error, 429, or 5xx (retried).
            ValueError: non-retryable HTTP error.
        """
        params = self._build_search_params(query)
        try:
            response = requests.get(
                BRAVE_WEB_SEARCH_URL,
                headers=self._headers,
                params=params,
                timeout=self._timeout_seconds,
            )
        except requests.RequestException as exc:
            # Connection/timeout problems are treated as transient.
            raise RetryableBraveSearchError(
                f"Brave search request failed: {exc}"
            ) from exc
        try:
            response.raise_for_status()
        except requests.HTTPError as exc:
            error_msg = _build_error_message(response)
            if _is_retryable_status(response.status_code):
                raise RetryableBraveSearchError(error_msg) from exc
            raise ValueError(error_msg) from exc
        data = response.json()
        # Defensive: any of these payload levels may be missing or null.
        web_results = (data.get("web") or {}).get("results") or []
        results: list[WebSearchResult] = []
        for result in web_results:
            if not isinstance(result, dict):
                continue
            link = _clean_string(result.get("url"))
            if not link:
                # A result without a URL is unusable downstream.
                continue
            title = _clean_string(result.get("title"))
            description = _clean_string(result.get("description"))
            results.append(
                WebSearchResult(
                    title=title,
                    link=link,
                    snippet=description,
                    author=None,
                    published_date=None,
                )
            )
        return results

    def search(self, query: str) -> list[WebSearchResult]:
        """Run a web search, converting exhausted retries into ValueError."""
        try:
            return self._search_with_retries(query)
        except RetryableBraveSearchError as exc:
            raise ValueError(str(exc)) from exc

    def test_connection(self) -> dict[str, str]:
        """Verify the configured API key by running a probe search.

        Returns:
            {"status": "ok"} on success.

        Raises:
            HTTPException: status 400 with a detail message classifying the
                failure (invalid key, rate limit, or generic failure).
        """
        try:
            test_results = self.search("test")
            if not test_results or not any(result.link for result in test_results):
                raise HTTPException(
                    status_code=400,
                    detail="Brave API key validation failed: search returned no results.",
                )
        except HTTPException:
            raise
        except (ValueError, requests.RequestException) as e:
            error_msg = str(e)
            lower = error_msg.lower()
            # Classify by sniffing the error text, since the original HTTP
            # status code is no longer available at this layer.
            if (
                "status 401" in lower
                or "status 403" in lower
                or "api key" in lower
                or "auth" in lower
            ):
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid Brave API key: {error_msg}",
                ) from e
            if "status 429" in lower or "rate limit" in lower:
                raise HTTPException(
                    status_code=400,
                    detail=f"Brave API rate limit exceeded: {error_msg}",
                ) from e
            raise HTTPException(
                status_code=400,
                detail=f"Brave API key validation failed: {error_msg}",
            ) from e
        logger.info("Web search provider test succeeded for Brave.")
        return {"status": "ok"}
def _build_error_message(response: requests.Response) -> str:
    """Compose a human-readable error string for a failed Brave response."""
    detail = _extract_error_detail(response)
    return f"Brave search failed (status {response.status_code}): {detail}"
def _extract_error_detail(response: requests.Response) -> str:
try:
payload: Any = response.json()
except Exception:
text = response.text.strip()
return text[:200] if text else "No error details"
if isinstance(payload, dict):
error = payload.get("error")
if isinstance(error, dict):
detail = error.get("detail") or error.get("message")
if isinstance(detail, str):
return detail
if isinstance(error, str):
return error
message = payload.get("message")
if isinstance(message, str):
return message
return str(payload)[:200]
def _is_retryable_status(status_code: int) -> bool:
return status_code == 429 or status_code >= 500
def _clean_string(value: Any) -> str:
return value.strip() if isinstance(value, str) else ""
def _normalize_country(country: str | None) -> str | None:
if country is None:
return None
normalized = country.strip().upper()
if not normalized:
return None
if len(normalized) != 2 or not normalized.isalpha():
raise ValueError(
"Brave provider config 'country' must be a 2-letter ISO country code."
)
return normalized
def _normalize_language_code(value: str | None, *, field_name: str) -> str | None:
if value is None:
return None
normalized = value.strip()
if not normalized:
return None
if len(normalized) > 20:
raise ValueError(f"Brave provider config '{field_name}' is too long.")
return normalized
def _normalize_option(
value: str | None,
*,
field_name: str,
allowed_values: set[str],
) -> str | None:
if value is None:
return None
normalized = value.strip().lower()
if not normalized:
return None
if normalized not in allowed_values:
allowed = ", ".join(sorted(allowed_values))
raise ValueError(
f"Brave provider config '{field_name}' must be one of: {allowed}."
)
return normalized

View File

@@ -13,9 +13,6 @@ from onyx.tools.tool_implementations.open_url.onyx_web_crawler import (
DEFAULT_MAX_PDF_SIZE_BYTES,
)
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import OnyxWebCrawler
from onyx.tools.tool_implementations.web_search.clients.brave_client import (
BraveClient,
)
from onyx.tools.tool_implementations.web_search.clients.exa_client import (
ExaClient,
)
@@ -38,28 +35,6 @@ from shared_configs.enums import WebSearchProviderType
logger = setup_logger()
def _parse_positive_int_config(
*,
raw_value: str | None,
default: int,
provider_name: str,
config_key: str,
) -> int:
if not raw_value:
return default
try:
value = int(raw_value)
except ValueError as exc:
raise ValueError(
f"{provider_name} provider config '{config_key}' must be an integer."
) from exc
if value <= 0:
raise ValueError(
f"{provider_name} provider config '{config_key}' must be greater than 0."
)
return value
def provider_requires_api_key(provider_type: WebSearchProviderType) -> bool:
"""Return True if the given provider type requires an API key.
This list is most likely just going to contain SEARXNG. The way it works is that it uses public search engines that do not
@@ -92,22 +67,6 @@ def build_search_provider_from_config(
if provider_type == WebSearchProviderType.EXA:
return ExaClient(api_key=api_key, num_results=num_results)
if provider_type == WebSearchProviderType.BRAVE:
return BraveClient(
api_key=api_key,
num_results=num_results,
timeout_seconds=_parse_positive_int_config(
raw_value=config.get("timeout_seconds"),
default=10,
provider_name="Brave",
config_key="timeout_seconds",
),
country=config.get("country"),
search_lang=config.get("search_lang"),
ui_lang=config.get("ui_lang"),
safesearch=config.get("safesearch"),
freshness=config.get("freshness"),
)
if provider_type == WebSearchProviderType.SERPER:
return SerperClient(api_key=api_key, num_results=num_results)
if provider_type == WebSearchProviderType.GOOGLE_PSE:

View File

@@ -1,4 +1,3 @@
import json
import traceback
from collections import defaultdict
from typing import Any
@@ -14,7 +13,6 @@ from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.interface import Tool
from onyx.tools.models import ChatFile
from onyx.tools.models import ChatMinimalTextMessage
from onyx.tools.models import ImageGenerationToolOverrideKwargs
from onyx.tools.models import OpenURLToolOverrideKwargs
from onyx.tools.models import ParallelToolCallResponse
from onyx.tools.models import PythonToolOverrideKwargs
@@ -24,9 +22,6 @@ from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolExecutionException
from onyx.tools.models import ToolResponse
from onyx.tools.models import WebSearchToolOverrideKwargs
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
from onyx.tools.tool_implementations.memory.memory_tool import MemoryToolOverrideKwargs
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
@@ -110,63 +105,6 @@ def _merge_tool_calls(tool_calls: list[ToolCallKickoff]) -> list[ToolCallKickoff
return merged_calls
def _extract_image_file_ids_from_tool_response_message(
message: str,
) -> list[str]:
try:
parsed_message = json.loads(message)
except json.JSONDecodeError:
return []
parsed_items: list[Any] = (
parsed_message if isinstance(parsed_message, list) else [parsed_message]
)
file_ids: list[str] = []
for item in parsed_items:
if not isinstance(item, dict):
continue
file_id = item.get("file_id")
if isinstance(file_id, str):
file_ids.append(file_id)
return file_ids
def _extract_recent_generated_image_file_ids(
message_history: list[ChatMessageSimple],
) -> list[str]:
tool_name_by_tool_call_id: dict[str, str] = {}
recent_image_file_ids: list[str] = []
seen_file_ids: set[str] = set()
for message in message_history:
if message.message_type == MessageType.ASSISTANT and message.tool_calls:
for tool_call in message.tool_calls:
tool_name_by_tool_call_id[tool_call.tool_call_id] = tool_call.tool_name
continue
if (
message.message_type != MessageType.TOOL_CALL_RESPONSE
or not message.tool_call_id
):
continue
tool_name = tool_name_by_tool_call_id.get(message.tool_call_id)
if tool_name != ImageGenerationTool.NAME:
continue
for file_id in _extract_image_file_ids_from_tool_response_message(
message.message
):
if file_id in seen_file_ids:
continue
seen_file_ids.add(file_id)
recent_image_file_ids.append(file_id)
return recent_image_file_ids
def _safe_run_single_tool(
tool: Tool,
tool_call: ToolCallKickoff,
@@ -386,9 +324,6 @@ def run_tool_calls(
url_to_citation: dict[str, int] = {
url: citation_num for citation_num, url in citation_mapping.items()
}
recent_generated_image_file_ids = _extract_recent_generated_image_file_ids(
message_history
)
# Prepare all tool calls with their override_kwargs
# Each tool gets a unique starting citation number to avoid conflicts when running in parallel
@@ -405,7 +340,6 @@ def run_tool_calls(
| WebSearchToolOverrideKwargs
| OpenURLToolOverrideKwargs
| PythonToolOverrideKwargs
| ImageGenerationToolOverrideKwargs
| MemoryToolOverrideKwargs
| None
) = None
@@ -454,10 +388,6 @@ def run_tool_calls(
override_kwargs = PythonToolOverrideKwargs(
chat_files=chat_files or [],
)
elif isinstance(tool, ImageGenerationTool):
override_kwargs = ImageGenerationToolOverrideKwargs(
recent_generated_image_file_ids=recent_generated_image_file_ids
)
elif isinstance(tool, MemoryTool):
override_kwargs = MemoryToolOverrideKwargs(
user_name=(

View File

@@ -146,7 +146,7 @@ MAX_REDIRECTS = 10
def _make_ssrf_safe_request(
url: str,
headers: dict[str, str] | None = None,
timeout: float | tuple[float, float] = 15,
timeout: int = 15,
**kwargs: Any,
) -> requests.Response:
"""
@@ -204,7 +204,7 @@ def _make_ssrf_safe_request(
def ssrf_safe_get(
url: str,
headers: dict[str, str] | None = None,
timeout: float | tuple[float, float] = 15,
timeout: int = 15,
follow_redirects: bool = True,
**kwargs: Any,
) -> requests.Response:

View File

View File

@@ -0,0 +1,54 @@
from onyx.db.models import VoiceProvider
from onyx.voice.interface import VoiceProviderInterface
def get_voice_provider(provider: VoiceProvider) -> VoiceProviderInterface:
    """Instantiate the voice provider implementation for *provider*.

    Dispatches on ``provider.provider_type`` (case-insensitive). Provider
    modules are imported lazily inside each branch so SDKs for unused
    providers are never loaded.

    Args:
        provider: VoiceProvider database model instance.

    Returns:
        A concrete VoiceProviderInterface implementation.

    Raises:
        ValueError: if provider_type is not one of the supported values.
    """
    kind = provider.provider_type.lower()

    if kind == "openai":
        from onyx.voice.providers.openai import OpenAIVoiceProvider

        return OpenAIVoiceProvider(
            api_key=provider.api_key,
            api_base=provider.api_base,
            stt_model=provider.stt_model,
            tts_model=provider.tts_model,
            default_voice=provider.default_voice,
        )

    if kind == "azure":
        from onyx.voice.providers.azure import AzureVoiceProvider

        return AzureVoiceProvider(
            api_key=provider.api_key,
            custom_config=provider.custom_config or {},
            stt_model=provider.stt_model,
            tts_model=provider.tts_model,
            default_voice=provider.default_voice,
        )

    if kind == "elevenlabs":
        from onyx.voice.providers.elevenlabs import ElevenLabsVoiceProvider

        return ElevenLabsVoiceProvider(
            api_key=provider.api_key,
            api_base=provider.api_base,
            stt_model=provider.stt_model,
            tts_model=provider.tts_model,
            default_voice=provider.default_voice,
        )

    raise ValueError(f"Unsupported voice provider type: {kind}")

View File

@@ -0,0 +1,65 @@
from abc import ABC
from abc import abstractmethod
from collections.abc import AsyncIterator
class VoiceProviderInterface(ABC):
    """Contract for speech providers offering STT and TTS capabilities."""

    @abstractmethod
    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        """Turn recorded audio into text (Speech-to-Text).

        Args:
            audio_data: The raw audio payload.
            audio_format: Container/codec name such as "webm", "wav", or "mp3".

        Returns:
            The recognized text.
        """

    @abstractmethod
    async def synthesize_stream(
        self, text: str, voice: str, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """Turn text into an audio stream (Text-to-Speech).

        Chunks are yielded as they become available so playback can start
        before synthesis finishes.

        Args:
            text: The text to speak.
            voice: Provider-specific voice identifier (e.g. "alloy", "echo").
            speed: Playback rate multiplier, 0.25 to 4.0.

        Yields:
            Encoded audio chunks.
        """

    @abstractmethod
    def get_available_voices(self) -> list[dict[str, str]]:
        """Return this provider's voices as dicts with 'id' and 'name' keys."""

    @abstractmethod
    def get_available_stt_models(self) -> list[dict[str, str]]:
        """Return this provider's STT models as dicts with 'id' and 'name' keys."""

    @abstractmethod
    def get_available_tts_models(self) -> list[dict[str, str]]:
        """Return this provider's TTS models as dicts with 'id' and 'name' keys."""

View File

View File

@@ -0,0 +1,51 @@
from collections.abc import AsyncIterator
from typing import Any
from onyx.voice.interface import VoiceProviderInterface
class AzureVoiceProvider(VoiceProviderInterface):
    """Azure Speech Services provider.

    STT and TTS calls are not implemented yet; only the static
    voice/model listings are functional.
    """

    def __init__(
        self,
        api_key: str | None,
        custom_config: dict[str, Any],
        stt_model: str | None = None,
        tts_model: str | None = None,
        default_voice: str | None = None,
    ):
        self.api_key = api_key
        self.custom_config = custom_config
        # Azure pairs the key with a region; empty string when unset.
        self.speech_region = custom_config.get("speech_region", "")
        self.stt_model = stt_model
        self.tts_model = tts_model
        self.default_voice = default_voice or "en-US-JennyNeural"

    async def transcribe(self, _audio_data: bytes, _audio_format: str) -> str:
        raise NotImplementedError("Azure STT not yet implemented")

    async def synthesize_stream(
        self, _text: str, _voice: str | None = None, _speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        # NOTE(review): parameter names differ from VoiceProviderInterface
        # (text/voice/speed) — confirm callers pass these positionally.
        raise NotImplementedError("Azure TTS not yet implemented")
        yield b""  # keeps this function an async generator despite the raise

    def get_available_voices(self) -> list[dict[str, str]]:
        # A small curated subset; Azure exposes many more neural voices.
        voices = (
            ("en-US-JennyNeural", "Jenny (US)"),
            ("en-US-GuyNeural", "Guy (US)"),
            ("en-GB-SoniaNeural", "Sonia (UK)"),
            ("en-GB-RyanNeural", "Ryan (UK)"),
        )
        return [{"id": voice_id, "name": label} for voice_id, label in voices]

    def get_available_stt_models(self) -> list[dict[str, str]]:
        return [{"id": "default", "name": "Azure Speech Recognition"}]

    def get_available_tts_models(self) -> list[dict[str, str]]:
        return [{"id": "neural", "name": "Neural TTS"}]

View File

@@ -0,0 +1,47 @@
from collections.abc import AsyncIterator
from onyx.voice.interface import VoiceProviderInterface
class ElevenLabsVoiceProvider(VoiceProviderInterface):
    """ElevenLabs provider (STT and TTS calls not implemented yet)."""

    def __init__(
        self,
        api_key: str | None,
        api_base: str | None = None,
        stt_model: str | None = None,
        tts_model: str | None = None,
        default_voice: str | None = None,
    ):
        self.api_key = api_key
        self.api_base = api_base or "https://api.elevenlabs.io"
        self.stt_model = stt_model
        self.tts_model = tts_model or "eleven_multilingual_v2"
        self.default_voice = default_voice

    async def transcribe(self, _audio_data: bytes, _audio_format: str) -> str:
        raise NotImplementedError("ElevenLabs STT not yet implemented")

    async def synthesize_stream(
        self, _text: str, _voice: str | None = None, _speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        raise NotImplementedError("ElevenLabs TTS not yet implemented")
        yield b""  # keeps this function an async generator despite the raise

    def get_available_voices(self) -> list[dict[str, str]]:
        # Voices are account-specific and fetched dynamically via the API;
        # the frontend should query the /voices endpoint instead.
        return []

    def get_available_stt_models(self) -> list[dict[str, str]]:
        return [{"id": "scribe_v1", "name": "Scribe v1"}]

    def get_available_tts_models(self) -> list[dict[str, str]]:
        return [
            {"id": "eleven_multilingual_v2", "name": "Multilingual v2"},
            {"id": "eleven_turbo_v2_5", "name": "Turbo v2.5"},
            {"id": "eleven_monolingual_v1", "name": "Monolingual v1"},
        ]

View File

@@ -0,0 +1,127 @@
import io
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING
from onyx.voice.interface import VoiceProviderInterface
if TYPE_CHECKING:
from openai import AsyncOpenAI
# Voices supported by OpenAI's TTS endpoint (id == lower-cased display name).
OPENAI_VOICES = [
    {"id": name.lower(), "name": name}
    for name in ("Alloy", "Echo", "Fable", "Onyx", "Nova", "Shimmer")
]

# Speech-to-text models offered by OpenAI.
OPENAI_STT_MODELS = [{"id": "whisper-1", "name": "Whisper v1"}]

# Text-to-speech models offered by OpenAI.
OPENAI_TTS_MODELS = [
    {"id": "tts-1", "name": "TTS-1 (Standard)"},
    {"id": "tts-1-hd", "name": "TTS-1 HD (High Quality)"},
]
class OpenAIVoiceProvider(VoiceProviderInterface):
    """OpenAI voice provider using Whisper for STT and TTS API for speech synthesis."""

    def __init__(
        self,
        api_key: str | None,
        api_base: str | None = None,
        stt_model: str | None = None,
        tts_model: str | None = None,
        default_voice: str | None = None,
    ):
        # api_base of None makes the SDK fall back to its default endpoint.
        self.api_key = api_key
        self.api_base = api_base
        self.stt_model = stt_model or "whisper-1"
        self.tts_model = tts_model or "tts-1"
        self.default_voice = default_voice or "alloy"
        # Lazily constructed AsyncOpenAI client; see _get_client.
        self._client: "AsyncOpenAI | None" = None

    def _get_client(self) -> "AsyncOpenAI":
        """Return a cached AsyncOpenAI client, creating it on first use.

        The import is deferred so the openai SDK is only required when a
        voice operation is actually performed.
        """
        if self._client is None:
            from openai import AsyncOpenAI

            self._client = AsyncOpenAI(
                api_key=self.api_key,
                base_url=self.api_base,
            )
        return self._client

    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        """
        Transcribe audio using OpenAI Whisper.

        Args:
            audio_data: Raw audio bytes
            audio_format: Audio format (e.g., "webm", "wav", "mp3")

        Returns:
            Transcribed text
        """
        client = self._get_client()

        # Create a file-like object from the audio bytes; the .name
        # attribute lets the SDK infer the upload's format.
        audio_file = io.BytesIO(audio_data)
        audio_file.name = f"audio.{audio_format}"

        response = await client.audio.transcriptions.create(
            model=self.stt_model,
            file=audio_file,
        )
        return response.text

    async def synthesize_stream(
        self, text: str, voice: str | None = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio using OpenAI TTS with streaming.

        Args:
            text: Text to convert to speech
            voice: Voice identifier (defaults to provider's default voice)
            speed: Playback speed multiplier (0.25 to 4.0)

        Yields:
            Audio data chunks (mp3 format)
        """
        client = self._get_client()

        # Clamp speed to valid range
        speed = max(0.25, min(4.0, speed))

        response = await client.audio.speech.create(
            model=self.tts_model,
            voice=voice or self.default_voice,
            input=text,
            speed=speed,
            response_format="mp3",
        )

        # Stream the response content
        # NOTE(review): with AsyncOpenAI the binary response also exposes
        # aiter_bytes(); confirm iter_bytes() supports `async for` here.
        async for chunk in response.iter_bytes(chunk_size=4096):
            yield chunk

    def get_available_voices(self) -> list[dict[str, str]]:
        """Get available OpenAI TTS voices."""
        # Copy so callers cannot mutate the module-level list.
        return OPENAI_VOICES.copy()

    def get_available_stt_models(self) -> list[dict[str, str]]:
        """Get available OpenAI STT models."""
        return OPENAI_STT_MODELS.copy()

    def get_available_tts_models(self) -> list[dict[str, str]]:
        """Get available OpenAI TTS models."""
        return OPENAI_TTS_MODELS.copy()

View File

@@ -317,7 +317,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.6.0
onyx-devtools==0.5.7
# via onyx
openai==2.14.0
# via

View File

@@ -3,8 +3,8 @@ set -e
cleanup() {
echo "Error occurred. Cleaning up..."
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
}
# Trap errors and output a message, then cleanup
@@ -20,8 +20,8 @@ MINIO_VOLUME=${4:-""} # Default is empty if not provided
# Stop and remove the existing containers
echo "Stopping and removing existing containers..."
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
# Start the PostgreSQL container with optional volume
echo "Starting PostgreSQL container..."
@@ -55,10 +55,6 @@ else
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin minio/minio server /data --console-address ":9001"
fi
# Start the Code Interpreter container
echo "Starting Code Interpreter container..."
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api
# Ensure alembic runs in the correct directory (backend/)
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PARENT_DIR="$(dirname "$SCRIPT_DIR")"

View File

@@ -243,12 +243,12 @@ USAGE_LIMIT_CHUNKS_INDEXED_PAID = int(
)
# Per-week API calls using API keys or Personal Access Tokens
USAGE_LIMIT_API_CALLS_TRIAL = int(os.environ.get("USAGE_LIMIT_API_CALLS_TRIAL", "0"))
USAGE_LIMIT_API_CALLS_TRIAL = int(os.environ.get("USAGE_LIMIT_API_CALLS_TRIAL", "400"))
USAGE_LIMIT_API_CALLS_PAID = int(os.environ.get("USAGE_LIMIT_API_CALLS_PAID", "40000"))
# Per-week non-streaming API calls (more expensive, so lower limits)
USAGE_LIMIT_NON_STREAMING_CALLS_TRIAL = int(
os.environ.get("USAGE_LIMIT_NON_STREAMING_CALLS_TRIAL", "0")
os.environ.get("USAGE_LIMIT_NON_STREAMING_CALLS_TRIAL", "80")
)
USAGE_LIMIT_NON_STREAMING_CALLS_PAID = int(
os.environ.get("USAGE_LIMIT_NON_STREAMING_CALLS_PAID", "160")

View File

@@ -26,7 +26,6 @@ class WebSearchProviderType(str, Enum):
SERPER = "serper"
EXA = "exa"
SEARXNG = "searxng"
BRAVE = "brave"
class WebContentProviderType(str, Enum):

View File

@@ -25,7 +25,6 @@ class ExpectedDocument:
content: str
folder_path: str | None = None
library: str = "Shared Documents" # Default to main library
expected_link_substrings: list[str] | None = None
EXPECTED_DOCUMENTS = [
@@ -33,29 +32,22 @@ EXPECTED_DOCUMENTS = [
semantic_identifier="test1.docx",
content="test1",
folder_path="test",
expected_link_substrings=["_layouts/15/Doc.aspx", "file=test1.docx"],
),
ExpectedDocument(
semantic_identifier="test2.docx",
content="test2",
folder_path="test/nested with spaces",
expected_link_substrings=["_layouts/15/Doc.aspx", "file=test2.docx"],
),
ExpectedDocument(
semantic_identifier="should-not-index-on-specific-folder.docx",
content="should-not-index-on-specific-folder",
folder_path=None, # root folder
expected_link_substrings=[
"_layouts/15/Doc.aspx",
"file=should-not-index-on-specific-folder.docx",
],
),
ExpectedDocument(
semantic_identifier="other.docx",
content="other",
folder_path=None,
library="Other Library",
expected_link_substrings=["_layouts/15/Doc.aspx", "file=other.docx"],
),
]
@@ -69,13 +61,11 @@ EXPECTED_PAGES = [
"Add a document library\n\n## Document library"
),
folder_path=None,
expected_link_substrings=["SitePages/CollabHome.aspx"],
),
ExpectedDocument(
semantic_identifier="Home",
content="# Home",
folder_path=None,
expected_link_substrings=["SitePages/Home.aspx"],
),
]
@@ -98,20 +88,6 @@ def verify_document_content(doc: Document, expected: ExpectedDocument) -> None:
assert len(doc.sections) == 1
assert doc.sections[0].text is not None
assert expected.content == doc.sections[0].text
if expected.expected_link_substrings is not None:
actual_link = doc.sections[0].link
assert actual_link is not None, (
f"Expected section link containing {expected.expected_link_substrings} "
f"for '{expected.semantic_identifier}', but link was None"
)
for substr in expected.expected_link_substrings:
assert substr in actual_link, (
f"Section link for '{expected.semantic_identifier}' "
f"missing expected substring '{substr}', "
f"actual link: '{actual_link}'"
)
verify_document_metadata(doc)

View File

@@ -1,281 +0,0 @@
"""
External dependency unit tests for user file processing queue protections.
Verifies that the three mechanisms added to check_user_file_processing work
correctly:
1. Queue depth backpressure when the broker queue exceeds
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH, no new tasks are enqueued.
2. Per-file Redis guard key if the guard key for a file already exists in
Redis, that file is skipped even though it is still in PROCESSING status.
3. Task expiry every send_task call carries expires=
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES so that stale queued tasks are
discarded by workers automatically.
Also verifies that process_single_user_file clears the guard key the moment
it is picked up by a worker.
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
on the task class so no real broker is needed.
"""
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from unittest.mock import PropertyMock
from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_lock_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_queued_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
check_user_file_processing,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
process_single_user_file,
)
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from onyx.redis.redis_pool import get_redis_client
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PATCH_QUEUE_LEN = (
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
)
def _create_processing_user_file(db_session: Session, user_id: object) -> UserFile:
    """Insert and return a UserFile row stuck in PROCESSING status."""
    attrs = {
        "id": uuid4(),
        "user_id": user_id,
        "file_id": f"test_file_{uuid4().hex[:8]}",
        "name": f"test_{uuid4().hex[:8]}.txt",
        "file_type": "text/plain",
        "status": UserFileStatus.PROCESSING,
    }
    user_file = UserFile(**attrs)
    db_session.add(user_file)
    db_session.commit()
    db_session.refresh(user_file)
    return user_file
@contextmanager
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
"""Patch the ``app`` property on *task*'s class so that ``self.app``
inside the task function returns *mock_app*.
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
the actual task instance. We patch ``app`` on that instance's class
(a unique Celery-generated Task subclass) so the mock is scoped to this
task only.
"""
task_instance = task.run.__self__
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
):
yield
# ---------------------------------------------------------------------------
# Test classes
# ---------------------------------------------------------------------------
class TestQueueDepthBackpressure:
    """Protection 1: skip all enqueuing when the broker queue is too deep."""

    def test_no_tasks_enqueued_when_queue_over_limit(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """When the queue depth exceeds the limit the beat cycle is skipped."""
        user = create_test_user(db_session, "bp_user")
        _create_processing_user_file(db_session, user.id)

        celery_app = MagicMock()
        over_limit = USER_FILE_PROCESSING_MAX_QUEUE_DEPTH + 1
        with (
            _patch_task_app(check_user_file_processing, celery_app),
            patch(_PATCH_QUEUE_LEN, return_value=over_limit),
        ):
            check_user_file_processing.run(tenant_id=TEST_TENANT_ID)

        # Even with a PROCESSING file waiting, nothing may be enqueued.
        celery_app.send_task.assert_not_called()
class TestPerFileGuardKey:
    """Protection 2: per-file Redis guard key prevents duplicate enqueue."""

    def test_guarded_file_not_re_enqueued(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """A file whose guard key is already set in Redis is skipped."""
        user = create_test_user(db_session, "guard_user")
        uf = _create_processing_user_file(db_session, user.id)

        # Pre-set the guard key so the beat task sees this file as already
        # queued even though its DB status is still PROCESSING.
        redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
        guard_key = _user_file_queued_key(uf.id)
        redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)

        mock_app = MagicMock()
        try:
            with (
                _patch_task_app(check_user_file_processing, mock_app),
                # Queue depth 0 so backpressure is not a factor in this test.
                patch(_PATCH_QUEUE_LEN, return_value=0),
            ):
                check_user_file_processing.run(tenant_id=TEST_TENANT_ID)

            # send_task must not have been called with this specific file's ID
            # (other PROCESSING files may legitimately be enqueued).
            for call in mock_app.send_task.call_args_list:
                kwargs = call.kwargs.get("kwargs", {})
                assert kwargs.get("user_file_id") != str(
                    uf.id
                ), f"File {uf.id} should have been skipped because its guard key exists"
        finally:
            # Always clean up the guard key so other tests start clean.
            redis_client.delete(guard_key)

    def test_guard_key_exists_in_redis_after_enqueue(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """After a file is enqueued its guard key is present in Redis with a TTL."""
        user = create_test_user(db_session, "guard_set_user")
        uf = _create_processing_user_file(db_session, user.id)

        redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
        guard_key = _user_file_queued_key(uf.id)
        redis_client.delete(guard_key)  # clean slate

        mock_app = MagicMock()
        try:
            with (
                _patch_task_app(check_user_file_processing, mock_app),
                patch(_PATCH_QUEUE_LEN, return_value=0),
            ):
                check_user_file_processing.run(tenant_id=TEST_TENANT_ID)

            assert redis_client.exists(
                guard_key
            ), "Guard key should be set in Redis after enqueue"
            # The key must expire on its own so a crashed worker cannot
            # block the file forever.
            ttl = int(redis_client.ttl(guard_key))  # type: ignore[arg-type]
            assert 0 < ttl <= CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, (
                f"Guard key TTL {ttl}s is outside the expected range "
                f"(0, {CELERY_USER_FILE_PROCESSING_TASK_EXPIRES}]"
            )
        finally:
            redis_client.delete(guard_key)
class TestTaskExpiry:
    """Protection 3: every send_task call carries an expires value."""

    def test_send_task_called_with_expires(
        self,
        db_session: Session,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """Every submitted task uses the right task name, queue, and expires."""
        owner = create_test_user(db_session, "expires_user")
        user_file = _create_processing_user_file(db_session, owner.id)

        redis = get_redis_client(tenant_id=TEST_TENANT_ID)
        guard_key = _user_file_queued_key(user_file.id)
        redis.delete(guard_key)

        celery_app = MagicMock()
        try:
            with _patch_task_app(check_user_file_processing, celery_app):
                with patch(_PATCH_QUEUE_LEN, return_value=0):
                    check_user_file_processing.run(tenant_id=TEST_TENANT_ID)

            # Our file should have triggered at least one submission.
            assert (
                celery_app.send_task.call_count >= 1
            ), "Expected at least one task to be submitted"
            # Each submission must name the right task, queue, and expiry.
            for submitted in celery_app.send_task.call_args_list:
                assert submitted.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
                assert (
                    submitted.kwargs.get("queue")
                    == OnyxCeleryQueues.USER_FILE_PROCESSING
                )
                assert (
                    submitted.kwargs.get("expires")
                    == CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
                ), (
                    "Task must be submitted with the correct expires value to prevent "
                    "stale task accumulation"
                )
        finally:
            redis.delete(guard_key)
class TestWorkerClearsGuardKey:
    """process_single_user_file drops the guard key as soon as it picks up a task."""

    def test_guard_key_deleted_on_pickup(
        self,
        tenant_context: None,  # noqa: ARG002
    ) -> None:
        """The worker removes the guard key before doing any real work.

        Holding the per-file processing lock forces process_single_user_file to
        bail out early — but only after it has already deleted the guard key.
        """
        file_id = str(uuid4())
        redis = get_redis_client(tenant_id=TEST_TENANT_ID)
        guard_key = _user_file_queued_key(file_id)

        # Mimic the guard key the beat sets when it enqueues the task.
        redis.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
        assert redis.exists(guard_key), "Guard key must exist before pickup"

        # Occupy the per-file lock so the worker returns without ever touching
        # the database or the file store.
        processing_lock = redis.lock(_user_file_lock_key(file_id), timeout=10)
        assert processing_lock.acquire(
            blocking=False
        ), "Should be able to acquire the processing lock for this test"
        try:
            process_single_user_file.run(
                user_file_id=file_id,
                tenant_id=TEST_TENANT_ID,
            )
        finally:
            if processing_lock.owned():
                processing_lock.release()

        assert not redis.exists(
            guard_key
        ), "Guard key should be deleted when the worker picks up the task"

View File

@@ -217,8 +217,8 @@ class TestAutoModeSyncFeature:
),
additional_visible_models=[
SimpleKnownModel(
name="claude-haiku-4-5",
display_name="Claude Haiku 4.5",
name="claude-3-5-haiku-latest",
display_name="Claude 3.5 Haiku",
)
],
),
@@ -260,7 +260,7 @@ class TestAutoModeSyncFeature:
# Anthropic models should NOT be present
assert "claude-3-5-sonnet-latest" not in model_names
assert "claude-haiku-4-5" not in model_names
assert "claude-3-5-haiku-latest" not in model_names
finally:
db_session.rollback()
@@ -485,7 +485,7 @@ class TestAutoModeSyncFeature:
# Provider 2 (Anthropic) config
provider_2_default_model = "claude-3-5-sonnet-latest"
provider_2_additional_models = ["claude-haiku-4-5"]
provider_2_additional_models = ["claude-3-5-haiku-latest"]
# Create mock recommendations with both providers
mock_recommendations = LLMRecommendations(

View File

@@ -281,22 +281,15 @@ def test_anthropic_prompt_caching_reduces_costs(
Anthropic requires explicit cache_control parameters.
"""
# Prompt caching support is model/account specific.
# Allow override via env var and otherwise try a few non-retired candidates.
anthropic_prompt_cache_models_env = os.environ.get("ANTHROPIC_PROMPT_CACHE_MODELS")
if anthropic_prompt_cache_models_env:
candidate_models = [
model.strip()
for model in anthropic_prompt_cache_models_env.split(",")
if model.strip()
]
else:
candidate_models = [
"claude-haiku-4-5-20251001",
"claude-sonnet-4-5-20250929",
"claude-3-5-sonnet-20241022",
"claude-3-5-sonnet-latest",
]
# Create Anthropic LLM
# NOTE: prompt caching support is model-specific; `claude-3-haiku-20240307` is known
# to return cache_creation/cache_read usage metrics, while some newer aliases may not.
llm = LitellmLLM(
api_key=os.environ["ANTHROPIC_API_KEY"],
model_provider="anthropic",
model_name="claude-3-haiku-20240307",
max_input_tokens=200000,
)
import random
import string
@@ -322,107 +315,79 @@ def test_anthropic_prompt_caching_reduces_costs(
UserMessage(role="user", content=long_context)
]
unavailable_models: list[str] = []
non_caching_models: list[str] = []
# First call - creates cache
print("\n=== First call (cache creation) ===")
question1: list[ChatCompletionMessage] = [
UserMessage(role="user", content="What are the main topics discussed?")
]
for model_name in candidate_models:
llm = LitellmLLM(
api_key=os.environ["ANTHROPIC_API_KEY"],
model_provider="anthropic",
model_name=model_name,
max_input_tokens=200000,
)
# First call - creates cache
print(f"\n=== First call (cache creation) model={model_name} ===")
question1: list[ChatCompletionMessage] = [
UserMessage(
role="user",
content="Reply with exactly one lowercase word: topics",
)
]
processed_messages1, _ = process_with_prompt_cache(
llm_config=llm.config,
cacheable_prefix=base_messages,
suffix=question1,
continuation=False,
)
try:
response1 = llm.invoke(prompt=processed_messages1, max_tokens=8)
except Exception as e:
error_str = str(e).lower()
if (
"not_found_error" in error_str
or "model_not_found" in error_str
or ('"type":"not_found_error"' in error_str and "model:" in error_str)
):
unavailable_models.append(model_name)
continue
raise
cost1 = completion_cost(
completion_response=response1.model_dump(),
model=f"{llm._model_provider}/{llm._model_version}",
)
usage1 = response1.usage
print(f"Response 1 usage: {usage1}")
print(f"Cost 1: ${cost1:.10f}")
# Wait to ensure cache is available
time.sleep(2)
# Second call with same context - should use cache
print(f"\n=== Second call (cache read) model={model_name} ===")
question2: list[ChatCompletionMessage] = [
UserMessage(
role="user",
content="Reply with exactly one lowercase word: neural",
)
]
processed_messages2, _ = process_with_prompt_cache(
llm_config=llm.config,
cacheable_prefix=base_messages,
suffix=question2,
continuation=False,
)
response2 = llm.invoke(prompt=processed_messages2, max_tokens=8)
cost2 = completion_cost(
completion_response=response2.model_dump(),
model=f"{llm._model_provider}/{llm._model_version}",
)
usage2 = response2.usage
print(f"Response 2 usage: {usage2}")
print(f"Cost 2: ${cost2:.10f}")
cache_creation_tokens = _get_usage_value(usage1, "cache_creation_input_tokens")
cache_read_tokens = _get_usage_value(usage2, "cache_read_input_tokens")
print(f"\nCache creation tokens (call 1): {cache_creation_tokens}")
print(f"Cache read tokens (call 2): {cache_read_tokens}")
print(f"Cost reduction: ${cost1 - cost2:.10f}")
# Model is available but does not expose Anthropic cache usage metrics
if cache_creation_tokens <= 0 or cache_read_tokens <= 0:
non_caching_models.append(model_name)
continue
# Cost should be lower on second call
assert (
cost2 < cost1
), f"Expected lower cost on cached call. Cost 1: ${cost1:.10f}, Cost 2: ${cost2:.10f}"
return
pytest.skip(
"No Anthropic model available with observable prompt-cache metrics. "
f"Tried models={candidate_models}, unavailable={unavailable_models}, non_caching={non_caching_models}"
# Apply prompt caching
processed_messages1, _ = process_with_prompt_cache(
llm_config=llm.config,
cacheable_prefix=base_messages,
suffix=question1,
continuation=False,
)
response1 = llm.invoke(prompt=processed_messages1)
cost1 = completion_cost(
completion_response=response1.model_dump(),
model=f"{llm._model_provider}/{llm._model_version}",
)
usage1 = response1.usage
print(f"Response 1 usage: {usage1}")
print(f"Cost 1: ${cost1:.10f}")
# Wait to ensure cache is available
time.sleep(2)
# Second call with same context - should use cache
print("\n=== Second call (cache read) ===")
question2: list[ChatCompletionMessage] = [
UserMessage(role="user", content="Can you elaborate on neural networks?")
]
# Apply prompt caching (same cacheable prefix)
processed_messages2, _ = process_with_prompt_cache(
llm_config=llm.config,
cacheable_prefix=base_messages,
suffix=question2,
continuation=False,
)
response2 = llm.invoke(prompt=processed_messages2)
cost2 = completion_cost(
completion_response=response2.model_dump(),
model=f"{llm._model_provider}/{llm._model_version}",
)
usage2 = response2.usage
print(f"Response 2 usage: {usage2}")
print(f"Cost 2: ${cost2:.10f}")
# Verify caching occurred
cache_creation_tokens = _get_usage_value(usage1, "cache_creation_input_tokens")
cache_read_tokens = _get_usage_value(usage2, "cache_read_input_tokens")
print(f"\nCache creation tokens (call 1): {cache_creation_tokens}")
print(f"Cache read tokens (call 2): {cache_read_tokens}")
print(f"Cost reduction: ${cost1 - cost2:.10f}")
# For Anthropic, we should see cache creation on first call and cache reads on second
assert (
cache_creation_tokens > 0
), f"Expected cache creation tokens on first call. Got: {cache_creation_tokens}"
assert (
cache_read_tokens > 0
), f"Expected cache read tokens on second call. Got: {cache_read_tokens}"
# Cost should be lower on second call
assert (
cost2 < cost1
), f"Expected lower cost on cached call. Cost 1: ${cost1:.10f}, Cost 2: ${cost2:.10f}"
@pytest.mark.skipif(
not os.environ.get(VERTEX_CREDENTIALS_ENV),

View File

@@ -13,7 +13,6 @@ from litellm.types.utils import ImageResponse
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
from onyx.llm.interfaces import LLMConfig
@@ -63,7 +62,6 @@ class MockImageGenerationProvider(
size: str, # noqa: ARG002
n: int, # noqa: ARG002
quality: str | None = None, # noqa: ARG002
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
**kwargs: Any, # noqa: ARG002
) -> ImageResponse:
image_data = self._images.pop(0)

View File

@@ -2,7 +2,6 @@ from collections.abc import Callable
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from pydantic import BaseModel
@@ -13,13 +12,9 @@ from onyx.context.search.models import ChunkSearchRequest
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import SearchDoc
from onyx.db.models import Persona
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentIndex
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.llm.interfaces import LLM
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.tools.tool_implementations.search.search_tool import SearchTool
def run_functions_tuples_sequential(
@@ -140,25 +135,13 @@ def use_mock_search_pipeline(
document_index: DocumentIndex, # noqa: ARG001
user: User | None, # noqa: ARG001
persona: Persona | None, # noqa: ARG001
db_session: Session | None = None, # noqa: ARG001
db_session: Session, # noqa: ARG001
auto_detect_filters: bool = False, # noqa: ARG001
llm: LLM | None = None, # noqa: ARG001
project_id: int | None = None, # noqa: ARG001
# Pre-fetched data (used by SearchTool to avoid DB access in parallel)
acl_filters: list[str] | None = None, # noqa: ARG001
embedding_model: EmbeddingModel | None = None, # noqa: ARG001
prefetched_federated_retrieval_infos: ( # noqa: ARG001
list[FederatedRetrievalInfo] | None
) = None,
) -> list[InferenceChunk]:
return controller.get_search_results(chunk_search_request.query)
# Mock the pre-fetch session and DB queries in SearchTool.run() so
# tests don't need a fully initialised DB with search settings.
@contextmanager
def mock_get_session() -> Generator[MagicMock, None, None]:
yield MagicMock(spec=Session)
with (
patch(
"onyx.tools.tool_implementations.search.search_tool.search_pipeline",
@@ -200,31 +183,5 @@ def use_mock_search_pipeline(
"onyx.db.connector.fetch_unique_document_sources",
new=mock_fetch_unique_document_sources,
),
# Mock the pre-fetch phase of SearchTool.run()
patch(
"onyx.tools.tool_implementations.search.search_tool.get_session_with_current_tenant",
new=mock_get_session,
),
patch(
"onyx.tools.tool_implementations.search.search_tool.build_access_filters_for_user",
return_value=[],
),
patch(
"onyx.tools.tool_implementations.search.search_tool.get_current_search_settings",
return_value=MagicMock(spec=SearchSettings),
),
patch(
"onyx.tools.tool_implementations.search.search_tool.EmbeddingModel.from_db_model",
return_value=MagicMock(spec=EmbeddingModel),
),
patch(
"onyx.tools.tool_implementations.search.search_tool.get_federated_retrieval_functions",
return_value=[],
),
patch.object(
SearchTool,
"_prefetch_slack_data",
return_value=(None, None, {}),
),
):
yield controller

View File

@@ -943,18 +943,10 @@ from onyx.db.tools import get_builtin_tool
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.server.features.projects.api import upload_user_files
from onyx.server.query_and_chat.chat_backend import get_chat_session
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
from onyx.server.query_and_chat.streaming_models import PythonToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from tests.external_dependency_unit.answer.stream_test_builder import StreamTestBuilder
from tests.external_dependency_unit.answer.stream_test_utils import create_chat_session
from tests.external_dependency_unit.answer.stream_test_utils import create_placement
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.mock_llm import LLMAnswerResponse
from tests.external_dependency_unit.mock_llm import LLMToolCallResponse
from tests.external_dependency_unit.mock_llm import use_mock_llm
@@ -1182,134 +1174,3 @@ def test_code_interpreter_receives_chat_files(
assert execute_body["code"] == code
assert len(execute_body["files"]) == 1
assert execute_body["files"][0]["path"] == "data.csv"
def test_code_interpreter_replay_packets_include_code_and_output(
db_session: Session,
mock_ci_server: MockCodeInterpreterServer,
_attach_python_tool_to_default_persona: None,
initialize_file_store: None, # noqa: ARG001
) -> None:
"""After a code interpreter message completes, retrieving the message
via translate_assistant_message_to_packets should emit PythonToolStart
(containing the executed code) and PythonToolDelta (containing
stdout/stderr), not generic CustomTool packets."""
mock_ci_server.captured_requests.clear()
mock_ci_server._file_counter = 0
mock_url = mock_ci_server.url
user = create_test_user(db_session, "ci_replay_test")
chat_session = create_chat_session(db_session=db_session, user=user)
code = 'x = 2 + 2\nprint(f"Result: {x}")'
msg_req = SendMessageRequest(
message="Calculate 2 + 2",
chat_session_id=chat_session.id,
stream=True,
)
original_defaults = ci_mod.CodeInterpreterClient.__init__.__defaults__
with (
use_mock_llm() as mock_llm,
patch(
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
mock_url,
),
patch(
"onyx.tools.tool_implementations.python.code_interpreter_client.CODE_INTERPRETER_BASE_URL",
mock_url,
),
):
answer_tokens = ["The ", "result ", "is ", "4."]
ci_mod.CodeInterpreterClient.__init__.__defaults__ = (mock_url,)
try:
handler = StreamTestBuilder(llm_controller=mock_llm)
stream = handle_stream_message_objects(
new_msg_req=msg_req, user=user, db_session=db_session
)
# First packet is always MessageResponseIDInfo
next(stream)
# Phase 1: LLM requests python tool execution.
handler.add_response(
LLMToolCallResponse(
tool_name="python",
tool_call_id="call_replay_test",
tool_call_argument_tokens=[json.dumps({"code": code})],
)
).expect(
Packet(
placement=create_placement(0),
obj=PythonToolStart(code=code),
),
forward=2,
).expect(
Packet(
placement=create_placement(0),
obj=PythonToolDelta(stdout="mock output\n", stderr="", file_ids=[]),
),
forward=False,
).expect(
Packet(
placement=create_placement(0),
obj=SectionEnd(),
),
forward=False,
).run_and_validate(
stream=stream
)
# Phase 2: LLM produces a final answer after tool execution.
handler.add_response(
LLMAnswerResponse(answer_tokens=answer_tokens)
).expect_agent_response(
answer_tokens=answer_tokens,
turn_index=1,
).run_and_validate(
stream=stream
)
with pytest.raises(StopIteration):
next(stream)
finally:
ci_mod.CodeInterpreterClient.__init__.__defaults__ = original_defaults
# Retrieve the chat session through the same endpoint the frontend uses
chat_detail = get_chat_session(
session_id=chat_session.id,
user=user,
db_session=db_session,
)
assert len(mock_ci_server.get_requests(method="POST", path="/v1/execute")) == 1
# The response contains `packets` — a list of packet-lists, one per
# assistant message. We should have exactly one assistant message.
assert (
len(chat_detail.packets) == 1
), f"Expected 1 assistant packet list, got {len(chat_detail.packets)}"
packets = chat_detail.packets[0]
# Extract PythonToolStart packets — these must contain the code
start_packets = [p for p in packets if isinstance(p.obj, PythonToolStart)]
assert len(start_packets) == 1, (
f"Expected 1 PythonToolStart packet, got {len(start_packets)}. "
f"Packet types: {[type(p.obj).__name__ for p in packets]}"
)
start_obj = start_packets[0].obj
assert isinstance(start_obj, PythonToolStart)
assert start_obj.code == code
# Extract PythonToolDelta packets — these must contain stdout/stderr
delta_packets = [p for p in packets if isinstance(p.obj, PythonToolDelta)]
assert len(delta_packets) >= 1, (
f"Expected at least 1 PythonToolDelta packet, got {len(delta_packets)}. "
f"Packet types: {[type(p.obj).__name__ for p in packets]}"
)
# The mock CI server returns "mock output\n" as stdout
delta_obj = delta_packets[0].obj
assert isinstance(delta_obj, PythonToolDelta)
assert "mock output" in delta_obj.stdout

View File

@@ -13,9 +13,9 @@ from tests.integration.common_utils.test_models import DATestUser
class APIKeyManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
api_key_role: UserRole = UserRole.ADMIN,
user_performing_action: DATestUser | None = None,
) -> DATestAPIKey:
name = f"{name}-api-key" if name else f"test-api-key-{uuid4()}"
api_key_request = APIKeyArgs(
@@ -25,7 +25,11 @@ class APIKeyManager:
api_key_response = requests.post(
f"{API_SERVER_URL}/admin/api-key",
json=api_key_request.model_dump(),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
api_key_response.raise_for_status()
api_key = api_key_response.json()
@@ -44,21 +48,29 @@ class APIKeyManager:
@staticmethod
def delete(
api_key: DATestAPIKey,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
api_key_response = requests.delete(
f"{API_SERVER_URL}/admin/api-key/{api_key.api_key_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
api_key_response.raise_for_status()
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[DATestAPIKey]:
api_key_response = requests.get(
f"{API_SERVER_URL}/admin/api-key",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
api_key_response.raise_for_status()
return [DATestAPIKey(**api_key) for api_key in api_key_response.json()]
@@ -66,8 +78,8 @@ class APIKeyManager:
@staticmethod
def verify(
api_key: DATestAPIKey,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
retrieved_keys = APIKeyManager.get_all(
user_performing_action=user_performing_action

View File

@@ -17,6 +17,7 @@ from onyx.server.documents.models import DocumentSource
from onyx.server.documents.models import DocumentSyncStatus
from tests.integration.common_utils.config import api_config
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
@@ -27,10 +28,10 @@ from tests.integration.common_utils.test_models import DATestUser
def _cc_pair_creator(
connector_id: int,
credential_id: int,
user_performing_action: DATestUser,
name: str | None = None,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestCCPair:
name = f"{name}-cc-pair" if name else f"test-cc-pair-{uuid4()}"
@@ -39,12 +40,17 @@ def _cc_pair_creator(
connector_credential_pair_metadata = api.ConnectorCredentialPairMetadata(
name=name, access_type=access_type, groups=groups or []
)
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
api_response: api.StatusResponseInt = (
api_instance.associate_credential_to_connector(
connector_id,
credential_id,
connector_credential_pair_metadata,
_headers=user_performing_action.headers,
_headers=headers,
)
)
@@ -61,7 +67,6 @@ def _cc_pair_creator(
class CCPairManager:
@staticmethod
def create_from_scratch(
user_performing_action: DATestUser,
name: str | None = None,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
@@ -69,25 +74,26 @@ class CCPairManager:
input_type: InputType = InputType.LOAD_STATE,
connector_specific_config: dict[str, Any] | None = None,
credential_json: dict[str, Any] | None = None,
user_performing_action: DATestUser | None = None,
refresh_freq: int | None = None,
) -> DATestCCPair:
connector = ConnectorManager.create(
user_performing_action=user_performing_action,
name=name,
source=source,
input_type=input_type,
connector_specific_config=connector_specific_config,
access_type=access_type,
groups=groups,
user_performing_action=user_performing_action,
refresh_freq=refresh_freq,
)
credential = CredentialManager.create(
user_performing_action=user_performing_action,
credential_json=credential_json,
name=name,
source=source,
curator_public=(access_type == AccessType.PUBLIC),
groups=groups,
user_performing_action=user_performing_action,
)
cc_pair = _cc_pair_creator(
connector_id=connector.id,
@@ -103,10 +109,10 @@ class CCPairManager:
def create(
connector_id: int,
credential_id: int,
user_performing_action: DATestUser,
name: str | None = None,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestCCPair:
cc_pair = _cc_pair_creator(
connector_id=connector_id,
@@ -121,31 +127,39 @@ class CCPairManager:
@staticmethod
def pause_cc_pair(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
result = requests.put(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
json={"status": "PAUSED"},
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
result.raise_for_status()
@staticmethod
def unpause_cc_pair(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
result = requests.put(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
json={"status": "ACTIVE"},
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
result.raise_for_status()
@staticmethod
def delete(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
cc_pair_identifier = ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
@@ -154,18 +168,26 @@ class CCPairManager:
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/deletion-attempt",
json=cc_pair_identifier.model_dump(),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
result.raise_for_status()
@staticmethod
def get_single(
cc_pair_id: int,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> CCPairFullInfo | None:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
cc_pair_json = response.json()
@@ -174,11 +196,15 @@ class CCPairManager:
@staticmethod
def get_indexing_status_by_id(
cc_pair_id: int,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> ConnectorIndexingStatusLite | None:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
json={"get_all_connectors": True},
)
response.raise_for_status()
@@ -193,11 +219,15 @@ class CCPairManager:
@staticmethod
def get_indexing_statuses(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[ConnectorIndexingStatusLite]:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
json={"get_all_connectors": True},
)
response.raise_for_status()
@@ -211,11 +241,15 @@ class CCPairManager:
@staticmethod
def get_connector_statuses(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[ConnectorStatus]:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/connector/status",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [ConnectorStatus(**status) for status in response.json()]
@@ -223,8 +257,8 @@ class CCPairManager:
@staticmethod
def verify(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
all_cc_pairs = CCPairManager.get_connector_statuses(user_performing_action)
for retrieved_cc_pair in all_cc_pairs:
@@ -251,7 +285,7 @@ class CCPairManager:
def run_once(
cc_pair: DATestCCPair,
from_beginning: bool,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
body = {
"connector_id": cc_pair.connector_id,
@@ -261,15 +295,19 @@ class CCPairManager:
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/connector/run-once",
json=body,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
result.raise_for_status()
@staticmethod
def wait_for_indexing_inactive(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""wait for the number of docs to be indexed on the connector.
This is used to test pausing a connector in the middle of indexing and
@@ -304,9 +342,9 @@ class CCPairManager:
@staticmethod
def wait_for_indexing_in_progress(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
num_docs: int = 16,
user_performing_action: DATestUser | None = None,
) -> None:
"""wait for the number of docs to be indexed on the connector.
This is used to test pausing a connector in the middle of indexing and
@@ -355,8 +393,8 @@ class CCPairManager:
def wait_for_indexing_completion(
cc_pair: DATestCCPair,
after: datetime,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""after: Wait for an indexing success time after this time"""
start = time.monotonic()
@@ -392,22 +430,30 @@ class CCPairManager:
@staticmethod
def prune(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
result.raise_for_status()
@staticmethod
def last_pruned(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> datetime | None:
response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/last_pruned",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
response_str = response.json()
@@ -425,8 +471,8 @@ class CCPairManager:
def wait_for_prune(
cc_pair: DATestCCPair,
after: datetime,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""after: The task register time must be after this time."""
start = time.monotonic()
@@ -450,7 +496,7 @@ class CCPairManager:
@staticmethod
def sync(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
"""This function triggers a permission sync.
Naming / intent of this function probably could use improvement, but currently it's letting
@@ -458,14 +504,22 @@ class CCPairManager:
"""
result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
if result.status_code != 409:
result.raise_for_status()
group_sync_result = requests.post(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-groups",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
if group_sync_result.status_code != 409:
group_sync_result.raise_for_status()
@@ -474,11 +528,15 @@ class CCPairManager:
@staticmethod
def get_doc_sync_task(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> datetime | None:
doc_sync_response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
doc_sync_response.raise_for_status()
doc_sync_response_str = doc_sync_response.json()
@@ -495,11 +553,15 @@ class CCPairManager:
@staticmethod
def get_group_sync_task(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> datetime | None:
group_sync_response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-groups",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
group_sync_response.raise_for_status()
group_sync_response_str = group_sync_response.json()
@@ -516,11 +578,15 @@ class CCPairManager:
@staticmethod
def get_doc_sync_statuses(
cc_pair: DATestCCPair,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[DocumentSyncStatus]:
response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/get-docs-sync-status",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
doc_sync_statuses: list[DocumentSyncStatus] = []
@@ -547,9 +613,9 @@ class CCPairManager:
def wait_for_sync(
cc_pair: DATestCCPair,
after: datetime,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
number_of_updated_docs: int = 0,
user_performing_action: DATestUser | None = None,
# Sometimes waiting for a group sync is not necessary
should_wait_for_group_sync: bool = True,
# Sometimes waiting for a vespa sync is not necessary
@@ -637,8 +703,8 @@ class CCPairManager:
@staticmethod
def wait_for_deletion_completion(
user_performing_action: DATestUser,
cc_pair_id: int | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
"""if cc_pair_id is not specified, just waits until no connectors are in the deleting state.
if cc_pair_id is specified, checks to ensure the specific cc_pair_id is gone.

View File

@@ -17,6 +17,7 @@ from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.streaming_models import StreamingType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestChatMessage
from tests.integration.common_utils.test_models import DATestChatSession
from tests.integration.common_utils.test_models import DATestUser
@@ -73,9 +74,9 @@ class StreamPacketData(TypedDict, total=False):
class ChatSessionManager:
@staticmethod
def create(
user_performing_action: DATestUser,
persona_id: int = 0,
description: str = "Test chat session",
user_performing_action: DATestUser | None = None,
) -> DATestChatSession:
chat_session_creation_req = ChatSessionCreationRequest(
persona_id=persona_id, description=description
@@ -83,7 +84,11 @@ class ChatSessionManager:
response = requests.post(
f"{API_SERVER_URL}/chat/create-chat-session",
json=chat_session_creation_req.model_dump(),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
chat_session_id = response.json()["chat_session_id"]
@@ -95,8 +100,8 @@ class ChatSessionManager:
def send_message(
chat_session_id: UUID,
message: str,
user_performing_action: DATestUser,
parent_message_id: int | None = None,
user_performing_action: DATestUser | None = None,
file_descriptors: list[FileDescriptor] | None = None,
allowed_tool_ids: list[int] | None = None,
forced_tool_ids: list[int] | None = None,
@@ -121,12 +126,19 @@ class ChatSessionManager:
llm_override=llm_override,
)
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
cookies = user_performing_action.cookies if user_performing_action else None
response = requests.post(
f"{API_SERVER_URL}/chat/send-chat-message",
json=chat_message_req.model_dump(mode="json"),
headers=user_performing_action.headers,
headers=headers,
stream=True,
cookies=user_performing_action.cookies,
cookies=cookies,
)
streamed_response = ChatSessionManager.analyze_response(response)
@@ -155,9 +167,9 @@ class ChatSessionManager:
def send_message_with_disconnect(
chat_session_id: UUID,
message: str,
user_performing_action: DATestUser,
disconnect_after_packets: int = 0,
parent_message_id: int | None = None,
user_performing_action: DATestUser | None = None,
file_descriptors: list[FileDescriptor] | None = None,
allowed_tool_ids: list[int] | None = None,
forced_tool_ids: list[int] | None = None,
@@ -196,14 +208,21 @@ class ChatSessionManager:
llm_override=llm_override,
)
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
cookies = user_performing_action.cookies if user_performing_action else None
packets_received = 0
with requests.post(
f"{API_SERVER_URL}/chat/send-chat-message",
json=chat_message_req.model_dump(mode="json"),
headers=user_performing_action.headers,
headers=headers,
stream=True,
cookies=user_performing_action.cookies,
cookies=cookies,
) as response:
for line in response.iter_lines():
if not line:
@@ -340,11 +359,15 @@ class ChatSessionManager:
@staticmethod
def get_chat_history(
chat_session: DATestChatSession,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[DATestChatMessage]:
response = requests.get(
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@@ -364,7 +387,7 @@ class ChatSessionManager:
def create_chat_message_feedback(
message_id: int,
is_positive: bool,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
feedback_text: str | None = None,
predefined_feedback: str | None = None,
) -> None:
@@ -376,14 +399,18 @@ class ChatSessionManager:
"feedback_text": feedback_text,
"predefined_feedback": predefined_feedback,
},
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@staticmethod
def delete(
chat_session: DATestChatSession,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Delete a chat session and all its related records (messages, agent data, etc.)
@@ -393,14 +420,18 @@ class ChatSessionManager:
"""
response = requests.delete(
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
return response.ok
@staticmethod
def soft_delete(
chat_session: DATestChatSession,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Soft delete a chat session (marks as deleted but keeps in database).
@@ -411,14 +442,18 @@ class ChatSessionManager:
# or make a direct call with hard_delete=False parameter via a new endpoint
response = requests.delete(
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}?hard_delete=false",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
return response.ok
@staticmethod
def hard_delete(
chat_session: DATestChatSession,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Hard delete a chat session (completely removes from database).
@@ -427,14 +462,18 @@ class ChatSessionManager:
"""
response = requests.delete(
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}?hard_delete=true",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
return response.ok
@staticmethod
def verify_deleted(
chat_session: DATestChatSession,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Verify that a chat session has been deleted by attempting to retrieve it.
@@ -443,7 +482,11 @@ class ChatSessionManager:
"""
response = requests.get(
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
# Chat session should return 404 if it doesn't exist or is deleted
return response.status_code == 404
@@ -451,7 +494,7 @@ class ChatSessionManager:
@staticmethod
def verify_soft_deleted(
chat_session: DATestChatSession,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Verify that a chat session has been soft deleted (marked as deleted but still in DB).
@@ -461,7 +504,11 @@ class ChatSessionManager:
# Try to get the chat session with include_deleted=true
response = requests.get(
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}?include_deleted=true",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
if response.status_code == 200:
@@ -473,7 +520,7 @@ class ChatSessionManager:
@staticmethod
def verify_hard_deleted(
chat_session: DATestChatSession,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Verify that a chat session has been hard deleted (completely removed from DB).
@@ -483,7 +530,11 @@ class ChatSessionManager:
# Try to get the chat session with include_deleted=true
response = requests.get(
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}?include_deleted=true",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
# For hard delete, even with include_deleted=true, the record should not exist

View File

@@ -8,6 +8,7 @@ from onyx.db.enums import AccessType
from onyx.server.documents.models import ConnectorUpdateRequest
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestConnector
from tests.integration.common_utils.test_models import DATestUser
@@ -15,13 +16,13 @@ from tests.integration.common_utils.test_models import DATestUser
class ConnectorManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
source: DocumentSource = DocumentSource.FILE,
input_type: InputType = InputType.LOAD_STATE,
connector_specific_config: dict[str, Any] | None = None,
access_type: AccessType = AccessType.PUBLIC,
groups: list[int] | None = None,
user_performing_action: DATestUser | None = None,
refresh_freq: int | None = None,
) -> DATestConnector:
name = f"{name}-connector" if name else f"test-connector-{uuid4()}"
@@ -50,7 +51,11 @@ class ConnectorManager:
response = requests.post(
url=f"{API_SERVER_URL}/manage/admin/connector",
json=connector_update_request.model_dump(),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@@ -68,33 +73,45 @@ class ConnectorManager:
@staticmethod
def edit(
connector: DATestConnector,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
response = requests.patch(
url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}",
json=connector.model_dump(exclude={"id"}),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@staticmethod
def delete(
connector: DATestConnector,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
response = requests.delete(
url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[DATestConnector]:
response = requests.get(
url=f"{API_SERVER_URL}/manage/connector",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [
@@ -110,12 +127,15 @@ class ConnectorManager:
@staticmethod
def get(
connector_id: int,
user_performing_action: DATestUser,
connector_id: int, user_performing_action: DATestUser | None = None
) -> DATestConnector:
response = requests.get(
url=f"{API_SERVER_URL}/manage/connector/{connector_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
conn = response.json()

View File

@@ -6,6 +6,7 @@ import requests
from onyx.server.documents.models import CredentialSnapshot
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestCredential
from tests.integration.common_utils.test_models import DATestUser
@@ -13,13 +14,13 @@ from tests.integration.common_utils.test_models import DATestUser
class CredentialManager:
@staticmethod
def create(
user_performing_action: DATestUser,
credential_json: dict[str, Any] | None = None,
admin_public: bool = True,
name: str | None = None,
source: DocumentSource = DocumentSource.FILE,
curator_public: bool = True,
groups: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestCredential:
name = f"{name}-credential" if name else f"test-credential-{uuid4()}"
@@ -35,7 +36,11 @@ class CredentialManager:
response = requests.post(
url=f"{API_SERVER_URL}/manage/credential",
json=credential_request,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@@ -52,46 +57,61 @@ class CredentialManager:
@staticmethod
def edit(
credential: DATestCredential,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
request = credential.model_dump(include={"name", "credential_json"})
response = requests.put(
url=f"{API_SERVER_URL}/manage/admin/credential/{credential.id}",
json=request,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@staticmethod
def delete(
credential: DATestCredential,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
response = requests.delete(
url=f"{API_SERVER_URL}/manage/credential/{credential.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@staticmethod
def get(
credential_id: int,
user_performing_action: DATestUser,
credential_id: int, user_performing_action: DATestUser | None = None
) -> CredentialSnapshot:
response = requests.get(
url=f"{API_SERVER_URL}/manage/credential/{credential_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return CredentialSnapshot(**response.json())
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[CredentialSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/manage/credential",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [CredentialSnapshot(**cred) for cred in response.json()]
@@ -99,8 +119,8 @@ class CredentialManager:
@staticmethod
def verify(
credential: DATestCredential,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
all_credentials = CredentialManager.get_all(user_performing_action)
for fetched_credential in all_credentials:

View File

@@ -10,6 +10,7 @@ from onyx.db.enums import AccessType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import DocumentByConnectorCredentialPair
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.managers.api_key import DATestAPIKey
from tests.integration.common_utils.managers.cc_pair import DATestCCPair
@@ -21,9 +22,9 @@ from tests.integration.common_utils.vespa import vespa_fixture
def _verify_document_permissions(
retrieved_doc: dict,
cc_pair: DATestCCPair,
doc_creating_user: DATestUser,
doc_set_names: list[str] | None = None,
group_names: list[str] | None = None,
doc_creating_user: DATestUser | None = None,
) -> None:
acl_keys = set(retrieved_doc.get("access_control_list", {}).keys())
print(f"ACL keys: {acl_keys}")
@@ -35,11 +36,12 @@ def _verify_document_permissions(
" does not have the PUBLIC ACL key"
)
if f"user_email:{doc_creating_user.email}" not in acl_keys:
raise ValueError(
f"Document {retrieved_doc['document_id']} was created by user"
f" {doc_creating_user.email} but does not have the user_email:{doc_creating_user.email} ACL key"
)
if doc_creating_user is not None:
if f"user_email:{doc_creating_user.email}" not in acl_keys:
raise ValueError(
f"Document {retrieved_doc['document_id']} was created by user"
f" {doc_creating_user.email} but does not have the user_email:{doc_creating_user.email} ACL key"
)
if group_names is not None:
expected_group_keys = {f"group:{group_name}" for group_name in group_names}
@@ -99,9 +101,9 @@ class DocumentManager:
@staticmethod
def seed_dummy_docs(
cc_pair: DATestCCPair,
api_key: DATestAPIKey,
num_docs: int = NUM_DOCS,
document_ids: list[str] | None = None,
api_key: DATestAPIKey | None = None,
) -> list[SimpleTestDocument]:
# Use provided document_ids if available, otherwise generate random UUIDs
if document_ids is None:
@@ -116,13 +118,12 @@ class DocumentManager:
response = requests.post(
f"{API_SERVER_URL}/onyx-api/ingestion",
json=document,
headers=api_key.headers,
headers=api_key.headers if api_key else GENERAL_HEADERS,
)
response.raise_for_status()
print(
f"Seeding docs for api_key_id={api_key.api_key_id} completed successfully."
)
api_key_id = api_key.api_key_id if api_key else ""
print(f"Seeding docs for api_key_id={api_key_id} completed successfully.")
return [
SimpleTestDocument(
id=document["document"]["id"],
@@ -135,8 +136,8 @@ class DocumentManager:
def seed_doc_with_content(
cc_pair: DATestCCPair,
content: str,
api_key: DATestAPIKey,
document_id: str | None = None,
api_key: DATestAPIKey | None = None,
metadata: dict | None = None,
) -> SimpleTestDocument:
# Use provided document_ids if available, otherwise generate random UUIDs
@@ -152,13 +153,12 @@ class DocumentManager:
response = requests.post(
f"{API_SERVER_URL}/onyx-api/ingestion",
json=document,
headers=api_key.headers,
headers=api_key.headers if api_key else GENERAL_HEADERS,
)
response.raise_for_status()
print(
f"Seeding doc for api_key_id={api_key.api_key_id} completed successfully."
)
api_key_id = api_key.api_key_id if api_key else ""
print(f"Seeding doc for api_key_id={api_key_id} completed successfully.")
return SimpleTestDocument(
id=document["document"]["id"],
@@ -169,11 +169,11 @@ class DocumentManager:
def verify(
vespa_client: vespa_fixture,
cc_pair: DATestCCPair,
doc_creating_user: DATestUser,
# If None, will not check doc sets or groups
# If empty list, will check for empty doc sets or groups
doc_set_names: list[str] | None = None,
group_names: list[str] | None = None,
doc_creating_user: DATestUser | None = None,
verify_deleted: bool = False,
) -> None:
doc_ids = [document.id for document in cc_pair.documents]
@@ -212,9 +212,9 @@ class DocumentManager:
_verify_document_permissions(
retrieved_doc,
cc_pair,
doc_creating_user,
doc_set_names,
group_names,
doc_creating_user,
)
@staticmethod
@@ -268,11 +268,11 @@ class IngestionManager(DocumentManager):
@staticmethod
def list_all_ingestion_docs(
api_key: DATestAPIKey,
api_key: DATestAPIKey | None = None,
) -> list[dict]:
response = requests.get(
f"{API_SERVER_URL}/onyx-api/ingestion",
headers=api_key.headers,
headers=api_key.headers if api_key else GENERAL_HEADERS,
)
response.raise_for_status()
return response.json()
@@ -280,11 +280,11 @@ class IngestionManager(DocumentManager):
@staticmethod
def delete(
document_id: str,
api_key: DATestAPIKey,
api_key: DATestAPIKey | None = None,
) -> None:
response = requests.delete(
f"{API_SERVER_URL}/onyx-api/ingestion/{document_id}",
headers=api_key.headers,
headers=api_key.headers if api_key else GENERAL_HEADERS,
)
response.raise_for_status()
print(f"Deleted document {document_id} successfully.")

View File

@@ -3,6 +3,7 @@ import requests
from ee.onyx.server.query_and_chat.models import SearchFullResponse
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
@@ -10,7 +11,7 @@ class DocumentSearchManager:
@staticmethod
def search_documents(
query: str,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[str]:
"""
Search for documents using the EE search API.
@@ -30,7 +31,11 @@ class DocumentSearchManager:
result = requests.post(
url=f"{API_SERVER_URL}/search/send-search-message",
json=search_request.model_dump(),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
result.raise_for_status()
result_json = result.json()

View File

@@ -6,6 +6,7 @@ from uuid import uuid4
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import DATestDocumentSet
from tests.integration.common_utils.test_models import DATestUser
@@ -14,7 +15,6 @@ from tests.integration.common_utils.test_models import DATestUser
class DocumentSetManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
description: str | None = None,
cc_pair_ids: list[int] | None = None,
@@ -22,6 +22,7 @@ class DocumentSetManager:
users: list[str] | None = None,
groups: list[int] | None = None,
federated_connectors: list[dict[str, Any]] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestDocumentSet:
if name is None:
name = f"test_doc_set_{str(uuid4())}"
@@ -39,7 +40,11 @@ class DocumentSetManager:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/document-set",
json=doc_set_creation_request,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@@ -58,7 +63,7 @@ class DocumentSetManager:
@staticmethod
def edit(
document_set: DATestDocumentSet,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
doc_set_update_request = {
"id": document_set.id,
@@ -72,7 +77,11 @@ class DocumentSetManager:
response = requests.patch(
f"{API_SERVER_URL}/manage/admin/document-set",
json=doc_set_update_request,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return True
@@ -80,22 +89,30 @@ class DocumentSetManager:
@staticmethod
def delete(
document_set: DATestDocumentSet,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
response = requests.delete(
f"{API_SERVER_URL}/manage/admin/document-set/{document_set.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return True
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[DATestDocumentSet]:
response = requests.get(
f"{API_SERVER_URL}/manage/document-set",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [
@@ -115,8 +132,8 @@ class DocumentSetManager:
@staticmethod
def wait_for_sync(
user_performing_action: DATestUser,
document_sets_to_check: list[DATestDocumentSet] | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
# wait for document sets to be synced
start = time.time()
@@ -158,8 +175,8 @@ class DocumentSetManager:
@staticmethod
def verify(
document_set: DATestDocumentSet,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
doc_sets = DocumentSetManager.get_all(user_performing_action)
for doc_set in doc_sets:

View File

@@ -10,6 +10,7 @@ import requests
from onyx.file_store.models import FileDescriptor
from onyx.server.documents.models import FileUploadResponse
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
@@ -17,9 +18,13 @@ class FileManager:
@staticmethod
def upload_files(
files: List[Tuple[str, IO]],
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> Tuple[List[FileDescriptor], str]:
headers = user_performing_action.headers
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
headers.pop("Content-Type", None)
files_param = []
@@ -62,11 +67,15 @@ class FileManager:
@staticmethod
def fetch_uploaded_file(
file_id: str,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bytes:
response = requests.get(
f"{API_SERVER_URL}/chat/file/{file_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return response.content

View File

@@ -6,6 +6,7 @@ from uuid import uuid4
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestImageGenerationConfig
from tests.integration.common_utils.test_models import DATestUser
@@ -25,7 +26,6 @@ def _serialize_custom_config(
class ImageGenerationConfigManager:
@staticmethod
def create(
user_performing_action: DATestUser,
image_provider_id: str | None = None,
model_name: str = "gpt-image-1",
provider: str = "openai",
@@ -35,6 +35,7 @@ class ImageGenerationConfigManager:
deployment_name: str | None = None,
custom_config: dict[str, Any] | None = None,
is_default: bool = False,
user_performing_action: DATestUser | None = None,
) -> DATestImageGenerationConfig:
"""Create a new image generation config with new credentials."""
image_provider_id = image_provider_id or f"test-provider-{uuid4()}"
@@ -52,7 +53,11 @@ class ImageGenerationConfigManager:
"custom_config": _serialize_custom_config(custom_config),
"is_default": is_default,
},
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
data = response.json()
@@ -69,13 +74,13 @@ class ImageGenerationConfigManager:
@staticmethod
def create_from_provider(
source_llm_provider_id: int,
user_performing_action: DATestUser,
image_provider_id: str | None = None,
model_name: str = "gpt-image-1",
api_base: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
is_default: bool = False,
user_performing_action: DATestUser | None = None,
) -> DATestImageGenerationConfig:
"""Create a new image generation config by cloning from an existing LLM provider."""
image_provider_id = image_provider_id or f"test-provider-{uuid4()}"
@@ -91,7 +96,11 @@ class ImageGenerationConfigManager:
"deployment_name": deployment_name,
"is_default": is_default,
},
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
data = response.json()
@@ -107,12 +116,16 @@ class ImageGenerationConfigManager:
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[DATestImageGenerationConfig]:
"""Get all image generation configs."""
response = requests.get(
f"{API_SERVER_URL}/admin/image-generation/config",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [DATestImageGenerationConfig(**config) for config in response.json()]
@@ -120,12 +133,16 @@ class ImageGenerationConfigManager:
@staticmethod
def get_credentials(
image_provider_id: str,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> dict:
"""Get credentials for an image generation config."""
response = requests.get(
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}/credentials",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return response.json()
@@ -134,13 +151,13 @@ class ImageGenerationConfigManager:
def update(
image_provider_id: str,
model_name: str,
user_performing_action: DATestUser,
provider: str | None = None,
api_key: str | None = None,
source_llm_provider_id: int | None = None,
api_base: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestImageGenerationConfig:
"""Update an existing image generation config."""
payload: dict = {
@@ -161,10 +178,14 @@ class ImageGenerationConfigManager:
f"Got: source_llm_provider_id={source_llm_provider_id}, provider={provider}, api_key={'***' if api_key else None}"
)
headers = {**GENERAL_HEADERS}
if user_performing_action:
headers.update(user_performing_action.headers)
response = requests.put(
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}",
json=payload,
headers=user_performing_action.headers,
headers=headers,
)
if not response.ok:
print(f"Update failed with status {response.status_code}: {response.text}")
@@ -183,32 +204,40 @@ class ImageGenerationConfigManager:
@staticmethod
def delete(
image_provider_id: str,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
"""Delete an image generation config."""
response = requests.delete(
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@staticmethod
def set_default(
image_provider_id: str,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> None:
"""Set an image generation config as the default."""
response = requests.post(
f"{API_SERVER_URL}/admin/image-generation/config/{image_provider_id}/default",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@staticmethod
def verify(
config: DATestImageGenerationConfig,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
"""Verify that a config exists (or doesn't exist if verify_deleted=True)."""
all_configs = ImageGenerationConfigManager.get_all(user_performing_action)

View File

@@ -14,6 +14,7 @@ from onyx.db.search_settings import get_current_search_settings
from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import PaginatedReturn
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import DATestIndexAttempt
from tests.integration.common_utils.test_models import DATestUser
@@ -85,9 +86,9 @@ class IndexAttemptManager:
@staticmethod
def get_index_attempt_page(
cc_pair_id: int,
user_performing_action: DATestUser,
page: int = 0,
page_size: int = 10,
user_performing_action: DATestUser | None = None,
) -> PaginatedReturn[IndexAttemptSnapshot]:
query_params: dict[str, str | int] = {
"page_num": page,
@@ -100,7 +101,11 @@ class IndexAttemptManager:
)
response = requests.get(
url=url,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
data = response.json()
@@ -112,7 +117,7 @@ class IndexAttemptManager:
@staticmethod
def get_latest_index_attempt_for_cc_pair(
cc_pair_id: int,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot | None:
"""Get an IndexAttempt by ID"""
index_attempts = IndexAttemptManager.get_index_attempt_page(
@@ -129,9 +134,9 @@ class IndexAttemptManager:
@staticmethod
def wait_for_index_attempt_start(
cc_pair_id: int,
user_performing_action: DATestUser,
index_attempts_to_ignore: list[int] | None = None,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot:
"""Wait for an IndexAttempt to start"""
start = datetime.now()
@@ -159,7 +164,7 @@ class IndexAttemptManager:
def get_index_attempt_by_id(
index_attempt_id: int,
cc_pair_id: int,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> IndexAttemptSnapshot:
page_num = 0
page_size = 10
@@ -185,8 +190,8 @@ class IndexAttemptManager:
def wait_for_index_attempt_completion(
index_attempt_id: int,
cc_pair_id: int,
user_performing_action: DATestUser,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""Wait for an IndexAttempt to complete"""
start = time.monotonic()
@@ -218,15 +223,19 @@ class IndexAttemptManager:
@staticmethod
def get_index_attempt_errors_for_cc_pair(
cc_pair_id: int,
user_performing_action: DATestUser,
include_resolved: bool = True,
user_performing_action: DATestUser | None = None,
) -> list[IndexAttemptErrorPydantic]:
url = f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/errors?page_size=100"
if include_resolved:
url += "&include_resolved=true"
response = requests.get(
url=url,
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
data = response.json()

View File

@@ -8,6 +8,7 @@ from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
@@ -15,7 +16,6 @@ from tests.integration.common_utils.test_models import DATestUser
class LLMProviderManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
provider: str | None = None,
api_key: str | None = None,
@@ -26,8 +26,13 @@ class LLMProviderManager:
personas: list[int] | None = None,
is_public: bool | None = None,
set_as_default: bool = True,
user_performing_action: DATestUser | None = None,
) -> DATestLLMProvider:
print(f"Seeding LLM Providers for {user_performing_action.email}...")
email = "Unknown"
if user_performing_action:
email = user_performing_action.email
print(f"Seeding LLM Providers for {email}...")
llm_provider = LLMProviderUpsertRequest(
name=name or f"test-provider-{uuid4()}",
@@ -55,7 +60,11 @@ class LLMProviderManager:
llm_response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
json=llm_provider.model_dump(),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
llm_response.raise_for_status()
response_data = llm_response.json()
@@ -77,7 +86,11 @@ class LLMProviderManager:
if set_as_default:
set_default_response = requests.post(
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
set_default_response.raise_for_status()
@@ -86,22 +99,30 @@ class LLMProviderManager:
@staticmethod
def delete(
llm_provider: DATestLLMProvider,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
response = requests.delete(
f"{API_SERVER_URL}/admin/llm/provider/{llm_provider.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return True
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[LLMProviderView]:
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [LLMProviderView(**ug) for ug in response.json()]
@@ -109,8 +130,8 @@ class LLMProviderManager:
@staticmethod
def verify(
llm_provider: DATestLLMProvider,
user_performing_action: DATestUser,
verify_deleted: bool = False,
user_performing_action: DATestUser | None = None,
) -> None:
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
for fetched_llm_provider in all_llm_providers:

View File

@@ -7,6 +7,7 @@ from onyx.context.search.enums import RecencyBiasSetting
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestPersona
from tests.integration.common_utils.test_models import DATestPersonaLabel
from tests.integration.common_utils.test_models import DATestUser
@@ -15,7 +16,6 @@ from tests.integration.common_utils.test_models import DATestUser
class PersonaManager:
@staticmethod
def create(
user_performing_action: DATestUser,
name: str | None = None,
description: str | None = None,
system_prompt: str | None = None,
@@ -34,6 +34,7 @@ class PersonaManager:
groups: list[int] | None = None,
label_ids: list[int] | None = None,
user_file_ids: list[str] | None = None,
user_performing_action: DATestUser | None = None,
display_priority: int | None = None,
) -> DATestPersona:
name = name or f"test-persona-{uuid4()}"
@@ -66,7 +67,11 @@ class PersonaManager:
response = requests.post(
f"{API_SERVER_URL}/persona",
json=persona_creation_request.model_dump(mode="json"),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
persona_data = response.json()
@@ -95,7 +100,6 @@ class PersonaManager:
@staticmethod
def edit(
persona: DATestPersona,
user_performing_action: DATestUser,
name: str | None = None,
description: str | None = None,
system_prompt: str | None = None,
@@ -113,6 +117,7 @@ class PersonaManager:
users: list[str] | None = None,
groups: list[int] | None = None,
label_ids: list[int] | None = None,
user_performing_action: DATestUser | None = None,
) -> DATestPersona:
system_prompt = system_prompt or f"System prompt for {persona.name}"
task_prompt = task_prompt or f"Task prompt for {persona.name}"
@@ -146,7 +151,11 @@ class PersonaManager:
response = requests.patch(
f"{API_SERVER_URL}/persona/{persona.id}",
json=persona_update_request.model_dump(mode="json"),
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
updated_persona_data = response.json()
@@ -178,11 +187,15 @@ class PersonaManager:
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[FullPersonaSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/admin/persona",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [FullPersonaSnapshot(**persona) for persona in response.json()]
@@ -190,11 +203,15 @@ class PersonaManager:
@staticmethod
def get_one(
persona_id: int,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[FullPersonaSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/persona/{persona_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [FullPersonaSnapshot(**response.json())]
@@ -202,7 +219,7 @@ class PersonaManager:
@staticmethod
def verify(
persona: DATestPersona,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
all_personas = PersonaManager.get_one(
persona_id=persona.id,
@@ -371,11 +388,15 @@ class PersonaManager:
@staticmethod
def delete(
persona: DATestPersona,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
response = requests.delete(
f"{API_SERVER_URL}/persona/{persona.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
return response.ok
@@ -384,14 +405,18 @@ class PersonaLabelManager:
@staticmethod
def create(
label: DATestPersonaLabel,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> DATestPersonaLabel:
response = requests.post(
f"{API_SERVER_URL}/persona/labels",
json={
"name": label.name,
},
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
response_data = response.json()
@@ -400,11 +425,15 @@ class PersonaLabelManager:
@staticmethod
def get_all(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[DATestPersonaLabel]:
response = requests.get(
f"{API_SERVER_URL}/persona/labels",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [DATestPersonaLabel(**label) for label in response.json()]
@@ -412,14 +441,18 @@ class PersonaLabelManager:
@staticmethod
def update(
label: DATestPersonaLabel,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> DATestPersonaLabel:
response = requests.patch(
f"{API_SERVER_URL}/admin/persona/label/{label.id}",
json={
"label_name": label.name,
},
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return label
@@ -427,18 +460,22 @@ class PersonaLabelManager:
@staticmethod
def delete(
label: DATestPersonaLabel,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
response = requests.delete(
f"{API_SERVER_URL}/admin/persona/label/{label.id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
return response.ok
@staticmethod
def verify(
label: DATestPersonaLabel,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
all_labels = PersonaLabelManager.get_all(user_performing_action)
for fetched_label in all_labels:

View File

@@ -6,6 +6,7 @@ from onyx.server.features.projects.models import CategorizedFilesSnapshot
from onyx.server.features.projects.models import UserFileSnapshot
from onyx.server.features.projects.models import UserProjectSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
@@ -19,7 +20,7 @@ class ProjectManager:
response = requests.post(
f"{API_SERVER_URL}/user/projects/create",
params={"name": name},
headers=user_performing_action.headers,
headers=user_performing_action.headers or GENERAL_HEADERS,
)
response.raise_for_status()
return UserProjectSnapshot.model_validate(response.json())
@@ -31,7 +32,7 @@ class ProjectManager:
"""Get all projects for a user via API."""
response = requests.get(
f"{API_SERVER_URL}/user/projects",
headers=user_performing_action.headers,
headers=user_performing_action.headers or GENERAL_HEADERS,
)
response.raise_for_status()
return [UserProjectSnapshot.model_validate(obj) for obj in response.json()]
@@ -44,7 +45,7 @@ class ProjectManager:
"""Delete a project via API."""
response = requests.delete(
f"{API_SERVER_URL}/user/projects/{project_id}",
headers=user_performing_action.headers,
headers=user_performing_action.headers or GENERAL_HEADERS,
)
return response.status_code == 204
@@ -56,7 +57,7 @@ class ProjectManager:
"""Verify that a project has been deleted by ensuring it's not in list."""
response = requests.get(
f"{API_SERVER_URL}/user/projects",
headers=user_performing_action.headers,
headers=user_performing_action.headers or GENERAL_HEADERS,
)
response.raise_for_status()
projects = [UserProjectSnapshot.model_validate(obj) for obj in response.json()]
@@ -65,12 +66,16 @@ class ProjectManager:
@staticmethod
def verify_files_unlinked(
project_id: int,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
"""Verify that all files have been unlinked from the project via API."""
response = requests.get(
f"{API_SERVER_URL}/user/projects/files/{project_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
if response.status_code == 404:
return True
@@ -82,12 +87,16 @@ class ProjectManager:
@staticmethod
def verify_chat_sessions_unlinked(
project_id: int,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> bool:
"""Verify that all chat sessions have been unlinked from the project via API."""
response = requests.get(
f"{API_SERVER_URL}/user/projects/{project_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
if response.status_code == 404:
return True
@@ -135,12 +144,16 @@ class ProjectManager:
@staticmethod
def get_project_files(
project_id: int,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> List[UserFileSnapshot]:
"""Get all files associated with a project via API."""
response = requests.get(
f"{API_SERVER_URL}/user/projects/files/{project_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
if response.status_code == 404:
return []
@@ -157,7 +170,7 @@ class ProjectManager:
response = requests.post(
f"{API_SERVER_URL}/user/projects/{project_id}/instructions",
json={"instructions": instructions},
headers=user_performing_action.headers,
headers=user_performing_action.headers or GENERAL_HEADERS,
)
response.raise_for_status()
return (response.json() or {}).get("instructions") or ""

View File

@@ -10,18 +10,19 @@ from ee.onyx.server.query_history.models import ChatSessionSnapshot
from onyx.configs.constants import QAFeedbackType
from onyx.server.documents.models import PaginatedReturn
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
class QueryHistoryManager:
@staticmethod
def get_query_history_page(
user_performing_action: DATestUser,
page_num: int = 0,
page_size: int = 10,
feedback_type: QAFeedbackType | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
user_performing_action: DATestUser | None = None,
) -> PaginatedReturn[ChatSessionMinimal]:
query_params: dict[str, str | int] = {
"page_num": page_num,
@@ -36,7 +37,11 @@ class QueryHistoryManager:
response = requests.get(
url=f"{API_SERVER_URL}/admin/chat-session-history?{urlencode(query_params, doseq=True)}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
data = response.json()
@@ -48,20 +53,24 @@ class QueryHistoryManager:
@staticmethod
def get_chat_session_admin(
chat_session_id: UUID | str,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> ChatSessionSnapshot:
response = requests.get(
url=f"{API_SERVER_URL}/admin/chat-session-history/{chat_session_id}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return ChatSessionSnapshot(**response.json())
@staticmethod
def get_query_history_as_csv(
user_performing_action: DATestUser,
start_time: datetime | None = None,
end_time: datetime | None = None,
user_performing_action: DATestUser | None = None,
) -> tuple[CaseInsensitiveDict[str], str]:
query_params: dict[str, str | int] = {}
if start_time:
@@ -71,7 +80,11 @@ class QueryHistoryManager:
response = requests.get(
url=f"{API_SERVER_URL}/admin/query-history-csv?{urlencode(query_params, doseq=True)}",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return response.headers, response.content.decode()

View File

@@ -5,6 +5,7 @@ from typing import Optional
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestSettings
from tests.integration.common_utils.test_models import DATestUser
@@ -12,9 +13,13 @@ from tests.integration.common_utils.test_models import DATestUser
class SettingsManager:
@staticmethod
def get_settings(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> tuple[Dict[str, Any], str]:
headers = user_performing_action.headers
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
headers.pop("Content-Type", None)
response = requests.get(
@@ -33,9 +38,13 @@ class SettingsManager:
@staticmethod
def update_settings(
settings: DATestSettings,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> tuple[Dict[str, Any], str]:
headers = user_performing_action.headers
headers = (
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
)
headers.pop("Content-Type", None)
payload = settings.model_dump()
@@ -56,7 +65,7 @@ class SettingsManager:
@staticmethod
def get_setting(
key: str,
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> Optional[Any]:
settings, error = SettingsManager.get_settings(user_performing_action)
if error:

View File

@@ -8,6 +8,7 @@ from onyx.server.manage.models import AllUsersResponse
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
@@ -25,11 +26,15 @@ def generate_auth_token() -> str:
class TenantManager:
@staticmethod
def get_all_users(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> AllUsersResponse:
response = requests.get(
url=f"{API_SERVER_URL}/manage/users",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
@@ -45,8 +50,7 @@ class TenantManager:
@staticmethod
def verify_user_in_tenant(
user: DATestUser,
user_performing_action: DATestUser,
user: DATestUser, user_performing_action: DATestUser | None = None
) -> None:
all_users = TenantManager.get_all_users(user_performing_action)
for accepted_user in all_users.accepted:

View File

@@ -1,6 +1,7 @@
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestTool
from tests.integration.common_utils.test_models import DATestUser
@@ -8,11 +9,15 @@ from tests.integration.common_utils.test_models import DATestUser
class ToolManager:
@staticmethod
def list_tools(
user_performing_action: DATestUser,
user_performing_action: DATestUser | None = None,
) -> list[DATestTool]:
response = requests.get(
url=f"{API_SERVER_URL}/tool",
headers=user_performing_action.headers,
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
return [

Some files were not shown because too many files have changed in this diff Show More