Compare commits

72 Commits

Author SHA1 Message Date
Jessica Singh
dc5bba094c fix: use str() for StrEnum comparisons to satisfy mypy 2026-03-11 18:41:23 -07:00
Jessica Singh
a9028e5ae8 fix: resolve merge conflict in InputComboBox hooks 2026-03-11 18:26:35 -07:00
Jessica Singh
4fdbef4185 Merge branch 'main' into voice-mode 2026-03-11 18:23:17 -07:00
Jessica Singh
9cac41bb6b fix: chain voice migration after latest main head 2026-03-11 18:09:39 -07:00
Jamison Lahman
a78607f1b5 fix(fe): InputComboBox resets filter value on open (#9287) 2026-03-12 01:06:02 +00:00
Jessica Singh
1fce5b6bf5 fix: resolve merge conflict in AppInputBar with main's useQueryController API 2026-03-11 16:59:48 -07:00
Jessica Singh
6b035b5908 Merge branch 'main' into voice-mode 2026-03-11 16:57:09 -07:00
Jessica Singh
764394d5cf fix: address valid cubic review comments
- Fix OpenAI transcript wait exiting early on prior transcript
- Cancel synthesis task on flush timeout instead of orphaning it
- Clean up resources on VoiceRecorderSession start failure
- Move ref mutations from useMemo to useEffect in MessageTextRenderer
- Add messageNodeId to RendererComponent memo comparison
- Fix TOCTOU race in WS token rate limiting (INCR-first)
- Assert actual clamped values in playback speed tests
- Update soft-delete test to match new pattern
2026-03-11 16:28:42 -07:00
roshan
e213853f63 fix(craft): rename webapp download endpoint to avoid route conflict (#9283)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Wenxi <wenxi@onyx.app>
2026-03-11 23:19:38 +00:00
Jessica Singh
2217d0ab48 fix: use flex flow for waveform positioning instead of absolute offsets 2026-03-11 16:07:23 -07:00
Jessica Singh
e6e05681a3 refactor: address PR review follow-up items for voice mode
- Add soft delete (deleted column) to VoiceProvider following Persona pattern
- Extract voice admin API calls into lib/admin/voice/svc.ts
- Create dedicated useVoiceProviders SWR hook
- Update voice admin page and modal to use svc.ts and hook
- Extract magic numbers in VoiceModeProvider to named constants
- Add 5-minute TTS session safety timeout

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 15:44:22 -07:00
Wenxi
8dc379c6fd feat(ods): use release-tag to print highest stable semver that should receive the latest tag (#9278)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-11 22:18:13 +00:00
Jessica Singh
feef46f9f3 test: add unit tests for voice providers
Add comprehensive unit tests for voice provider implementations:
- OpenAI: URL conversion, WAV headers, model defaults (11 tests)
- ElevenLabs: resampling, model validation, defaults (13 tests)
- Azure: region parsing, URL construction, SSRF protection (25 tests)
- API validation: update to use OnyxError, add None handling (4 tests)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-03-11 15:06:46 -07:00
dependabot[bot]
787f117e17 chore(deps): bump pypdf from 6.7.5 to 6.8.0 (#9260)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-11 21:59:35 +00:00
Jamison Lahman
665640fac8 chore(opensearch): unset container ulimits in dev (#9277)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-11 21:58:43 +00:00
Jessica Singh
243514f601 test: add comprehensive unit tests for voice db module
Add 37 unit tests covering all functions in onyx.db.voice:
- fetch_voice_providers, fetch_voice_provider_by_id
- fetch_default_stt_provider, fetch_default_tts_provider
- fetch_voice_provider_by_type
- upsert_voice_provider (create/update/error cases)
- delete_voice_provider
- set_default_stt_provider, set_default_tts_provider
- deactivate_stt_provider, deactivate_tts_provider
- update_user_voice_settings (including speed clamping)

Tests follow the mocked session pattern used in test_scim_dal.py.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-03-11 14:58:00 -07:00
Danelegend
d2d44c1e68 fix(indexing): Stop deep-copy during indexing (#9275) 2026-03-11 21:24:15 +00:00
Jessica Singh
72eb5c5626 refactor: address PR review comments for voice mode
- Replace ValueError with OnyxError in db/voice.py
- Add StrEnum message types for ElevenLabs and OpenAI providers
- Extract magic numbers into named constants (sample rates, VAD thresholds, etc.)
- Add SSML_NAMESPACE constant in Azure provider
- Add validate_credentials to VoiceProviderInterface and implementations
- Add module docstrings to all voice providers
- Add console.error alongside toast.error calls in frontend voice components
- Add explanatory comments for VoiceModeProvider placement, auth tokens,
  WebSocket routing, and VAD behavior
- Replace silent catch with console.warn in VoiceModeProvider cleanup

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 14:06:41 -07:00
Nikolas Garza
ffe04ab91f fix(tests): remove deprecated o1-preview and o1-mini model tests (#9280) 2026-03-11 20:32:51 +00:00
Raunak Bhagat
6499b21235 feat(opal): add Card and EmptyMessageCard components (#9271) 2026-03-11 13:14:17 -07:00
Nikolas Garza
c5bfd5a152 feat(admin): add Users page shell with stats bar and SCIM card - 1/9 (#9079) 2026-03-11 16:28:47 +00:00
Jessica Singh
eb4f806a44 real waveform 2026-03-11 09:16:25 -07:00
Jessica Singh
b4ab51e307 refactor: move elapsed time formatting to dateUtils
Addresses PR review comment - extract MM:SS formatting from
RecordingWaveform into a reusable formatElapsedTime utility.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-11 00:04:12 -07:00
Jessica Singh
a321abf8a6 fix: voice mode bugs - hide when unconfigured, fix button style, guard auto-playback
- Add GET /api/voice/status endpoint so frontend can check if STT/TTS is configured
- Hide microphone button and TTS play button when providers are not set up
- Fix TTSButton using primary prominence (too dark) - now uses tertiary
- Guard auto-playback in VoiceModeProvider against unconfigured TTS
- Clear input text after voice auto-send to prevent stale text in next iteration

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-10 20:59:14 -07:00
Justin Tahara
a0329161b0 feat(litellm): Adding FE Provider workflow (#9264) 2026-03-11 03:45:08 +00:00
Raunak Bhagat
334b7a6d2f feat(opal): add foldable support to OpenButton + fix MessageToolbar (#9265) 2026-03-11 03:00:51 +00:00
dependabot[bot]
36196373a8 chore(deps): bump hono from 4.12.5 to 4.12.7 in /backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web (#9263)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-10 18:54:17 -07:00
Jamison Lahman
533aa8eff8 chore(release): upgrade release-tag (#9257) 2026-03-11 00:50:55 +00:00
Raunak Bhagat
ecbb267f80 fix: Consolidate search state-machine (#9234) 2026-03-11 00:42:39 +00:00
Danelegend
66023dbb6d feat(llm-provider): fetch litellm models (#8418) 2026-03-10 23:48:56 +00:00
Wenxi
f97466e4de chore: redeclare cache_okay for EncryptedBase children (#9253) 2026-03-10 23:44:51 +00:00
Evan Lohn
2cc8303e5f chore: sharepoint dedupe (#9254) 2026-03-10 23:41:51 +00:00
Wenxi
a92ff61f64 chore: add cache_okay to EncryptedJson (#9252) 2026-03-10 22:18:39 +00:00
acaprau
17551a907e fix(opensearch): Update should clear projects and personas when they are empty (#8845) 2026-03-10 21:49:55 +00:00
Jamison Lahman
9e42951fa4 fix(fe): increase responsive breakpoint for centering modals (#9250) 2026-03-10 21:45:23 +00:00
acaprau
dcb18c2411 chore(opensearch): Followup for #9243 (#9247) 2026-03-10 14:31:44 -07:00
Jamison Lahman
2f628e39d3 fix(fe): correctly parse comma literals in CSVs (#9245) 2026-03-10 21:03:47 +00:00
Nikolas Garza
fd200d46f8 fix(storybook): case-sensitivity, icon rename, and story fixes (#9244) 2026-03-10 20:05:32 +00:00
Evan Lohn
ec7482619b fix: update jira group sync endpoint (#9241) 2026-03-10 19:57:01 +00:00
Jamison Lahman
9d1a357533 fix(fe): make CSV inline display responsive (#9242) 2026-03-10 19:42:23 +00:00
acaprau
fbe823b551 chore(opensearch): Allow configuring num hits from hybrid subquery from env var (#9243) 2026-03-10 19:27:36 +00:00
acaprau
1608e2f274 fix(opensearch): Allow configuring the page size of chunks we get from Vespa during migration (#9239) 2026-03-10 17:51:52 +00:00
Jamison Lahman
4dbb1fa606 chore(tests): fix nightly model-server tests (#9236) 2026-03-10 17:49:08 +00:00
Jessica Singh
9573274039 fix: mypy type ignore for azure speech SDK import 2026-03-09 16:26:48 -07:00
Jessica Singh
4f7b1332e2 fix: merge main into voice-mode with Disabled wrapper pattern 2026-03-09 15:46:50 -07:00
Jessica Singh
0dac14abd4 Merge branch 'main' into voice-mode 2026-03-09 15:41:13 -07:00
Jessica Singh
aa0eac8ae8 pr comments 2026-03-09 15:22:15 -07:00
Jessica Singh
8fead4dfbf jest test 2026-03-05 21:55:58 -08:00
Jessica Singh
ac4b49a7f9 sync to autoplayback text generation 2026-03-05 21:50:07 -08:00
Jessica Singh
fc22232f14 fix: mypy type errors in websocket_api 2026-03-05 17:41:15 -08:00
Jessica Singh
9ddd44bf56 fix: prevent UnboundLocalError in TTS fallback 2026-03-05 17:27:00 -08:00
Jessica Singh
8587911cf6 fix: critical bugs from PR review 2026-03-05 17:25:52 -08:00
Jessica Singh
d7300d50d7 fix: address PR review feedback 2026-03-05 17:19:21 -08:00
Jessica Singh
cc950a2da2 fix: session lifecycle, atomic WS token, and clearable voice prefs 2026-03-05 16:32:38 -08:00
Jessica Singh
8d6640159a fix: address PR bot review feedback (security, dead code, logging) 2026-03-05 16:09:31 -08:00
Jessica Singh
bba77749c3 fix: position recording bar above input on subsequent turns 2026-03-05 14:33:23 -08:00
Jessica Singh
3e9a66c8ff chore: add @types/sbd for TypeScript support 2026-03-05 14:25:12 -08:00
Jessica Singh
548b9d9e0e fix: remove unused type ignore in azure.py 2026-03-05 14:13:42 -08:00
Jessica Singh
0d3967baee mypy 2026-03-05 13:50:27 -08:00
Jessica Singh
6ed806eebb migration 2026-03-05 09:32:35 -08:00
Jessica Singh
3b6a35b2c4 recording bar + bug fixes 2026-03-04 21:55:42 -08:00
Jessica Singh
62e612f85f Merge branch 'main' into voice-mode 2026-03-04 21:25:37 -08:00
Jessica Singh
b375b7f0ff azure 2026-03-04 20:14:18 -08:00
Jessica Singh
c158ae2622 remove logs 2026-03-04 17:02:44 -08:00
Jessica Singh
698494626f eleven labs and bug fixes 2026-03-04 16:40:26 -08:00
Jessica Singh
93cefe7ef0 chore: trigger Greptile review 2026-03-03 23:04:53 -08:00
Jessica Singh
8a326c4089 address greptile review feedback (greploop iteration 2)
- Narrow WebSocket auth bypass to only voice endpoints in auth_check.py
- Add query param validation (max_length, ge/le) for TTS synthesize endpoint
- Fix ObjectURL memory leak in useVoicePlayback.ts

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-03-03 22:39:12 -08:00
Jessica Singh
0c5410f429 address greptile review feedback (greploop iteration 1)
- Add WebSocket authentication to /voice/transcribe/stream and /voice/synthesize/stream endpoints
- Fix useVoicePlayback.ts to use query params instead of JSON body (matches API signature)
- Fix delete_voice_provider to use flush() instead of commit() for consistency
- Disable Azure streaming STT until audio resampling is implemented
- Add SSML escaping to prevent injection in Azure TTS
- Remove debug console.log statements from voice components
- Fix blob URL memory leak in VoiceModeProvider

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-03-03 22:31:59 -08:00
Jessica Singh
0b05b9b235 fix(voice): move error toast to useEffect 2026-03-03 17:26:20 -08:00
Jessica Singh
59d8a988bd streaming tts 2026-03-03 03:37:18 -08:00
Jessica Singh
6d08cfb25a all changes 2026-03-02 13:16:39 -08:00
Jessica Singh
53a5ee2a6e stt and tts 2026-02-23 18:27:37 -08:00
173 changed files with 14328 additions and 2120 deletions

@@ -48,7 +48,7 @@ jobs:
- name: Deploy to Vercel (Production)
working-directory: web
run: npx --yes "$VERCEL_CLI" deploy storybook-static/ --prod --yes
run: npx --yes "$VERCEL_CLI" deploy storybook-static/ --prod --yes --token="$VERCEL_TOKEN"
notify-slack-on-failure:
needs: Deploy-Storybook

@@ -0,0 +1,117 @@
"""add_voice_provider_and_user_voice_prefs
Revision ID: 93a2e195e25c
Revises: b5c4d7e8f9a1
Create Date: 2026-02-23 15:16:39.507304
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import column
from sqlalchemy import true
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "93a2e195e25c"
down_revision = "b5c4d7e8f9a1"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create voice_provider table
op.create_table(
"voice_provider",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("name", sa.String(), unique=True, nullable=False),
sa.Column("provider_type", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("api_base", sa.String(), nullable=True),
sa.Column("custom_config", postgresql.JSONB(), nullable=True),
sa.Column("stt_model", sa.String(), nullable=True),
sa.Column("tts_model", sa.String(), nullable=True),
sa.Column("default_voice", sa.String(), nullable=True),
sa.Column(
"is_default_stt", sa.Boolean(), nullable=False, server_default="false"
),
sa.Column(
"is_default_tts", sa.Boolean(), nullable=False, server_default="false"
),
sa.Column("deleted", sa.Boolean(), nullable=False, server_default="false"),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
onupdate=sa.func.now(),
nullable=False,
),
)
# Add partial unique indexes to enforce only one default STT/TTS provider
op.create_index(
"ix_voice_provider_one_default_stt",
"voice_provider",
["is_default_stt"],
unique=True,
postgresql_where=column("is_default_stt") == true(),
)
op.create_index(
"ix_voice_provider_one_default_tts",
"voice_provider",
["is_default_tts"],
unique=True,
postgresql_where=column("is_default_tts") == true(),
)
# Add voice preference columns to user table
op.add_column(
"user",
sa.Column(
"voice_auto_send",
sa.Boolean(),
default=False,
nullable=False,
server_default="false",
),
)
op.add_column(
"user",
sa.Column(
"voice_auto_playback",
sa.Boolean(),
default=False,
nullable=False,
server_default="false",
),
)
op.add_column(
"user",
sa.Column(
"voice_playback_speed",
sa.Float(),
default=1.0,
nullable=False,
server_default="1.0",
),
)
def downgrade() -> None:
# Remove user voice preference columns
op.drop_column("user", "voice_playback_speed")
op.drop_column("user", "voice_auto_playback")
op.drop_column("user", "voice_auto_send")
op.drop_index("ix_voice_provider_one_default_tts", table_name="voice_provider")
op.drop_index("ix_voice_provider_one_default_stt", table_name="voice_provider")
# Drop voice_provider table
op.drop_table("voice_provider")
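Reviewer note: the two partial unique indexes in this migration are what enforce "at most one default STT provider" and "at most one default TTS provider" at the database level. The following is a minimal sketch of that behavior, assuming SQLite's partial-index support as a stand-in for `postgresql_where`. The table is reduced to three columns, and `make_db` and `try_insert` are hypothetical helpers that are not part of this PR.

```python
import sqlite3


def make_db() -> sqlite3.Connection:
    """Build an in-memory table with a partial unique index (hypothetical helper)."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE voice_provider ("
        "id INTEGER PRIMARY KEY, "
        "name TEXT NOT NULL UNIQUE, "
        "is_default_stt INTEGER NOT NULL DEFAULT 0)"
    )
    # Only rows where is_default_stt is true enter the index, so the UNIQUE
    # constraint permits many non-default rows but at most one default row.
    conn.execute(
        "CREATE UNIQUE INDEX ix_voice_provider_one_default_stt "
        "ON voice_provider (is_default_stt) WHERE is_default_stt = 1"
    )
    return conn


def try_insert(conn: sqlite3.Connection, name: str, is_default: bool) -> bool:
    """Return True if the row was accepted, False if the index rejected it."""
    try:
        conn.execute(
            "INSERT INTO voice_provider (name, is_default_stt) VALUES (?, ?)",
            (name, int(is_default)),
        )
        return True
    except sqlite3.IntegrityError:
        return False
```

With this setup, any number of non-default rows can coexist, while a second row with the default flag set is rejected. A plain unique index on the boolean column would instead allow only one `false` row and one `true` row in total.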

@@ -1,6 +1,8 @@
from collections.abc import Generator
from typing import Any
from jira import JIRA
from jira.exceptions import JIRAError
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.connectors.jira.utils import build_jira_client
@@ -9,107 +11,102 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_ATLASSIAN_ACCOUNT_TYPE = "atlassian"
_GROUP_MEMBER_PAGE_SIZE = 50
def _get_jira_group_members_email(
# The GET /group/member endpoint was introduced in Jira 6.0.
# Jira versions older than 6.0 do not have group management REST APIs at all.
_MIN_JIRA_VERSION_FOR_GROUP_MEMBER = "6.0"
def _fetch_group_member_page(
jira_client: JIRA,
group_name: str,
) -> list[str]:
"""Get all member emails for a Jira group.
start_at: int,
) -> dict[str, Any]:
"""Fetch a single page from the non-deprecated GET /group/member endpoint.
Filters out app accounts (bots, integrations) and only returns real user emails.
The old GET /group endpoint (used by jira_client.group_members()) is deprecated
and decommissioned in Jira Server 10.3+. This uses the replacement endpoint
directly via the library's internal _get_json helper, following the same pattern
as enhanced_search_ids / bulk_fetch_issues in connector.py.
There is an open PR to the library to switch to this endpoint since last year:
https://github.com/pycontribs/jira/pull/2356
so once it is merged and released, we can switch to using the library function.
"""
emails: list[str] = []
try:
# group_members returns an OrderedDict of account_id -> member_info
members = jira_client.group_members(group=group_name)
if not members:
logger.warning(f"No members found for group {group_name}")
return emails
for account_id, member_info in members.items():
# member_info is a dict with keys like 'fullname', 'email', 'active'
email = member_info.get("email")
# Skip "hidden" emails - these are typically app accounts
if email and email != "hidden":
emails.append(email)
else:
# For cloud, we might need to fetch user details separately
try:
user = jira_client.user(id=account_id)
# Skip app accounts (bots, integrations, etc.)
if hasattr(user, "accountType") and user.accountType == "app":
logger.info(
f"Skipping app account {account_id} for group {group_name}"
)
continue
if hasattr(user, "emailAddress") and user.emailAddress:
emails.append(user.emailAddress)
else:
logger.warning(f"User {account_id} has no email address")
except Exception as e:
logger.warning(
f"Could not fetch email for user {account_id} in group {group_name}: {e}"
)
except Exception as e:
logger.error(f"Error fetching members for group {group_name}: {e}")
return emails
return jira_client._get_json(
"group/member",
params={
"groupname": group_name,
"includeInactiveUsers": "false",
"startAt": start_at,
"maxResults": _GROUP_MEMBER_PAGE_SIZE,
},
)
except JIRAError as e:
if e.status_code == 404:
raise RuntimeError(
f"GET /group/member returned 404 for group '{group_name}'. "
f"This endpoint requires Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}+. "
f"If you are running a self-hosted Jira instance, please upgrade "
f"to at least Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}."
) from e
raise
def _build_group_member_email_map(
def _get_group_member_emails(
jira_client: JIRA,
) -> dict[str, set[str]]:
"""Build a map of group names to member emails."""
group_member_emails: dict[str, set[str]] = {}
group_name: str,
) -> set[str]:
"""Get all member emails for a single Jira group.
try:
# Get all groups from Jira - returns a list of group name strings
group_names = jira_client.groups()
Uses the non-deprecated GET /group/member endpoint which returns full user
objects including accountType, so we can filter out app/customer accounts
without making separate user() calls.
"""
emails: set[str] = set()
start_at = 0
if not group_names:
logger.warning("No groups found in Jira")
return group_member_emails
while True:
try:
page = _fetch_group_member_page(jira_client, group_name, start_at)
except Exception as e:
logger.error(f"Error fetching members for group {group_name}: {e}")
raise
logger.info(f"Found {len(group_names)} groups in Jira")
for group_name in group_names:
if not group_name:
members: list[dict[str, Any]] = page.get("values", [])
for member in members:
account_type = member.get("accountType")
# On Jira DC < 9.0, accountType is absent; include those users.
# On Cloud / DC 9.0+, filter to real user accounts only.
if account_type is not None and account_type != _ATLASSIAN_ACCOUNT_TYPE:
continue
member_emails = _get_jira_group_members_email(
jira_client=jira_client,
group_name=group_name,
)
if member_emails:
group_member_emails[group_name] = set(member_emails)
logger.debug(
f"Found {len(member_emails)} members for group {group_name}"
)
email = member.get("emailAddress")
if email:
emails.add(email)
else:
logger.debug(f"No members found for group {group_name}")
logger.warning(
f"Atlassian user {member.get('accountId', 'unknown')} "
f"in group {group_name} has no visible email address"
)
except Exception as e:
logger.error(f"Error building group member email map: {e}")
if page.get("isLast", True) or not members:
break
start_at += len(members)
return group_member_emails
return emails
def jira_group_sync(
tenant_id: str, # noqa: ARG001
cc_pair: ConnectorCredentialPair,
) -> Generator[ExternalUserGroup, None, None]:
"""
Sync Jira groups and their members.
"""Sync Jira groups and their members, yielding one group at a time.
This function fetches all groups from Jira and yields ExternalUserGroup
objects containing the group ID and member emails.
Streams group-by-group rather than accumulating all groups in memory.
"""
jira_base_url = cc_pair.connector.connector_specific_config.get("jira_base_url", "")
scoped_token = cc_pair.connector.connector_specific_config.get(
@@ -130,12 +127,26 @@ def jira_group_sync(
scoped_token=scoped_token,
)
group_member_email_map = _build_group_member_email_map(jira_client=jira_client)
if not group_member_email_map:
raise ValueError(f"No groups with members found for cc_pair_id={cc_pair.id}")
group_names = jira_client.groups()
if not group_names:
raise ValueError(f"No groups found for cc_pair_id={cc_pair.id}")
for group_id, group_member_emails in group_member_email_map.items():
yield ExternalUserGroup(
id=group_id,
user_emails=list(group_member_emails),
logger.info(f"Found {len(group_names)} groups in Jira")
for group_name in group_names:
if not group_name:
continue
member_emails = _get_group_member_emails(
jira_client=jira_client,
group_name=group_name,
)
if not member_emails:
logger.debug(f"No members found for group {group_name}")
continue
logger.debug(f"Found {len(member_emails)} members for group {group_name}")
yield ExternalUserGroup(
id=group_name,
user_emails=list(member_emails),
)
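Reviewer note: the new pagination loop in `_get_group_member_emails` can be exercised without a Jira instance. This is a minimal sketch, assuming the response shape used in the diff (`values`, `isLast`, `accountType`, `emailAddress`). The names `collect_member_emails` and `fake_fetch_page`, and the canned member data, are hypothetical and only stand in for the real `_fetch_group_member_page` call.

```python
from typing import Any, Callable

ATLASSIAN_ACCOUNT_TYPE = "atlassian"


def collect_member_emails(
    fetch_page: Callable[[int], dict[str, Any]],
) -> set[str]:
    """Walk startAt-style pages until isLast is set or a page comes back empty."""
    emails: set[str] = set()
    start_at = 0
    while True:
        page = fetch_page(start_at)
        members = page.get("values", [])
        for member in members:
            account_type = member.get("accountType")
            # A missing accountType is kept; any other non-"atlassian" type is skipped.
            if account_type is not None and account_type != ATLASSIAN_ACCOUNT_TYPE:
                continue
            email = member.get("emailAddress")
            if email:
                emails.add(email)
        if page.get("isLast", True) or not members:
            break
        start_at += len(members)
    return emails


# Hypothetical canned data standing in for the Jira REST API.
_FAKE_MEMBERS = [
    {"accountType": "atlassian", "emailAddress": "a@example.com"},
    {"accountType": "app", "emailAddress": "bot@example.com"},
    {"emailAddress": "b@example.com"},
    {"accountType": "atlassian"},
]


def fake_fetch_page(start_at: int, page_size: int = 2) -> dict[str, Any]:
    """Serve fixed-size slices of the canned data, flagging the final page."""
    values = _FAKE_MEMBERS[start_at : start_at + page_size]
    return {"values": values, "isLast": start_at + page_size >= len(_FAKE_MEMBERS)}
```

Advancing `start_at` by the number of members actually returned, rather than by the requested page size, keeps the loop correct if the server returns short pages.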

@@ -29,6 +29,7 @@ from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi import WebSocket
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager
@@ -121,6 +122,7 @@ from onyx.db.models import User
from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import retrieve_ws_token_data
from onyx.server.settings.store import load_settings
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
@@ -1612,6 +1614,102 @@ async def current_admin_user(user: User = Depends(current_user)) -> User:
return user
async def _get_user_from_token_data(token_data: dict) -> User | None:
"""Shared logic: token data dict → User object.
Args:
token_data: Decoded token data containing 'sub' (user ID).
Returns:
User object if found and active, None otherwise.
"""
user_id = token_data.get("sub")
if not user_id:
return None
try:
user_uuid = uuid.UUID(user_id)
except ValueError:
return None
async with get_async_session_context_manager() as async_db_session:
user = await async_db_session.get(User, user_uuid)
if user is None or not user.is_active:
return None
return user
async def current_user_from_websocket(
websocket: WebSocket,
token: str = Query(..., description="WebSocket authentication token"),
) -> User:
"""
WebSocket authentication dependency using query parameter.
Validates the WS token from query param and returns the User.
Raises BasicAuthenticationError if authentication fails.
The token must be obtained from POST /voice/ws-token before connecting.
Tokens are single-use and expire after 60 seconds.
Usage:
1. POST /voice/ws-token -> {"token": "xxx"}
2. Connect to ws://host/path?token=xxx
This applies the same auth checks as current_user() for HTTP endpoints.
"""
# Check Origin header to prevent Cross-Site WebSocket Hijacking (CSWSH)
# Browsers always send Origin on WebSocket connections
origin = websocket.headers.get("origin")
expected_origin = WEB_DOMAIN.rstrip("/")
if not origin:
logger.warning("WS auth: missing Origin header")
raise BasicAuthenticationError(detail="Access denied. Missing origin.")
actual_origin = origin.rstrip("/")
if actual_origin != expected_origin:
logger.warning(
f"WS auth: origin mismatch. Expected {expected_origin}, got {actual_origin}"
)
raise BasicAuthenticationError(detail="Access denied. Invalid origin.")
# Validate WS token in Redis (single-use, deleted after retrieval)
try:
token_data = await retrieve_ws_token_data(token)
if token_data is None:
raise BasicAuthenticationError(
detail="Access denied. Invalid or expired authentication token."
)
except BasicAuthenticationError:
raise
except Exception as e:
logger.error(f"WS auth: error during token validation: {e}")
raise BasicAuthenticationError(
detail="Authentication verification failed."
) from e
# Get user from token data
user = await _get_user_from_token_data(token_data)
if user is None:
logger.warning(f"WS auth: user not found for id={token_data.get('sub')}")
raise BasicAuthenticationError(
detail="Access denied. User not found or inactive."
)
# Apply same checks as HTTP auth (verification, OIDC expiry, role)
user = await double_check_user(user)
# Block LIMITED users (same as current_user)
if user.role == UserRole.LIMITED:
logger.warning(f"WS auth: user {user.email} has LIMITED role")
raise BasicAuthenticationError(
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
)
logger.debug(f"WS auth: authenticated {user.email}")
return user
def get_default_admin_user_emails_() -> list[str]:
# No default seeding available for Onyx MIT
return []
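Reviewer note: `current_user_from_websocket` relies on two properties described in its docstring, an Origin check and a token that is single-use and expires after 60 seconds. The sketch below illustrates both with an in-memory dictionary as a stand-in for Redis. `InMemoryTokenStore`, `issue`, `retrieve`, and `origins_match` are hypothetical names, and the real `retrieve_ws_token_data` is not shown in this diff, so its internals are an assumption here.

```python
from __future__ import annotations

import secrets
import time

WS_TOKEN_TTL_SECONDS = 60  # matches the 60-second expiry stated in the docstring


class InMemoryTokenStore:
    """Hypothetical stand-in for a Redis-backed single-use token store."""

    def __init__(self) -> None:
        self._tokens: dict[str, tuple[dict, float]] = {}

    def issue(self, user_id: str, now: float | None = None) -> str:
        now = time.monotonic() if now is None else now
        token = secrets.token_urlsafe(32)
        self._tokens[token] = ({"sub": user_id}, now + WS_TOKEN_TTL_SECONDS)
        return token

    def retrieve(self, token: str, now: float | None = None) -> dict | None:
        now = time.monotonic() if now is None else now
        # pop() removes the entry, so a second retrieve of the same token fails.
        entry = self._tokens.pop(token, None)
        if entry is None:
            return None
        data, expires_at = entry
        return data if now < expires_at else None


def origins_match(origin: str | None, expected: str) -> bool:
    """Compare the Origin header to the expected domain, ignoring trailing slashes."""
    if not origin:
        return False
    return origin.rstrip("/") == expected.rstrip("/")
```

A process-local dictionary is only adequate for illustration. With several API workers, the lookup and the delete need to happen as one atomic operation in the shared store, otherwise two connections could redeem the same token.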

@@ -11,6 +11,9 @@
# lock after its cleanup which happens at most after its soft timeout.
# Constants corresponding to migrate_documents_from_vespa_to_opensearch_task.
from onyx.configs.app_configs import OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE
MIGRATION_TASK_SOFT_TIME_LIMIT_S = 60 * 5 # 5 minutes.
MIGRATION_TASK_TIME_LIMIT_S = 60 * 6 # 6 minutes.
# The maximum time the lock can be held for. Will automatically be released
@@ -44,7 +47,7 @@ TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15
# WARNING: Do not change these values without knowing what changes also need to
# be made to OpenSearchTenantMigrationRecord.
GET_VESPA_CHUNKS_PAGE_SIZE = 500
GET_VESPA_CHUNKS_PAGE_SIZE = OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE
GET_VESPA_CHUNKS_SLICE_COUNT = 4
# String used to indicate in the vespa_visit_continuation_token mapping that the

@@ -311,6 +311,12 @@ VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
os.environ.get("VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT", "true").lower()
== "true"
)
OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE = int(
os.environ.get("OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE") or 500
)
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = int(
os.environ.get("OPENSEARCH_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES") or 0
)
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
# NOTE: this is used if and only if the vespa config server is accessible via a
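Reviewer note: both new settings use the `int(os.environ.get(NAME) or default)` pattern. A minimal sketch of how that pattern behaves follows; `int_from_env` is a hypothetical helper introduced only for illustration.

```python
import os


def int_from_env(name: str, default: int) -> int:
    """Mirror the `int(os.environ.get(NAME) or default)` pattern from the diff."""
    # `or` treats both an unset variable (None) and an empty string as missing,
    # so either case falls back to the default.
    return int(os.environ.get(name) or default)
```

One consequence of this pattern is that a non-numeric value raises `ValueError` at import time instead of silently falling back to the default.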

@@ -258,6 +258,10 @@ class SharepointConnectorCheckpoint(ConnectorCheckpoint):
# Track yielded hierarchy nodes by their raw_node_id (URLs) to avoid duplicates
seen_hierarchy_node_raw_ids: set[str] = Field(default_factory=set)
# Track yielded document IDs to avoid processing the same document twice.
# The Microsoft Graph delta API can return the same item on multiple pages.
seen_document_ids: set[str] = Field(default_factory=set)
class SharepointAuthMethod(Enum):
CLIENT_SECRET = "client_secret"
@@ -1557,6 +1561,7 @@ class SharepointConnector(
checkpoint.current_drive_id = None
checkpoint.current_drive_web_url = None
checkpoint.current_drive_delta_next_link = None
checkpoint.seen_document_ids.clear()
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
site_descriptors = self.site_descriptors or self.fetch_sites()
@@ -2137,6 +2142,14 @@ class SharepointConnector(
item_count = 0
for driveitem in driveitems:
item_count += 1
if driveitem.id and driveitem.id in checkpoint.seen_document_ids:
logger.debug(
f"Skipping duplicate document {driveitem.id} "
f"({driveitem.name})"
)
continue
driveitem_extension = get_file_ext(driveitem.name)
if driveitem_extension not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
logger.warning(
@@ -2189,11 +2202,13 @@ class SharepointConnector(
if isinstance(doc_or_failure, Document):
if doc_or_failure.sections:
checkpoint.seen_document_ids.add(doc_or_failure.id)
yield doc_or_failure
elif should_yield_if_empty:
doc_or_failure.sections = [
TextSection(link=driveitem.web_url, text="")
]
checkpoint.seen_document_ids.add(doc_or_failure.id)
yield doc_or_failure
else:
logger.warning(

@@ -25,6 +25,7 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import SyncModelEntry
from onyx.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
@@ -369,9 +370,9 @@ def upsert_llm_provider(
def sync_model_configurations(
db_session: Session,
provider_name: str,
models: list[dict],
models: list[SyncModelEntry],
) -> int:
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama).
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama, etc.).
This inserts NEW models from the source API without overwriting existing ones.
User preferences (is_visible, max_input_tokens) are preserved for existing models.
@@ -379,7 +380,7 @@ def sync_model_configurations(
Args:
db_session: Database session
provider_name: Name of the LLM provider
models: List of model dicts with keys: name, display_name, max_input_tokens, supports_image_input
models: List of SyncModelEntry objects describing the fetched models
Returns:
Number of new models added
@@ -393,21 +394,20 @@ def sync_model_configurations(
new_count = 0
for model in models:
model_name = model["name"]
if model_name not in existing_names:
if model.name not in existing_names:
# Insert new model with is_visible=False (user must explicitly enable)
supported_flows = [LLMModelFlowType.CHAT]
if model.get("supports_image_input", False):
if model.supports_image_input:
supported_flows.append(LLMModelFlowType.VISION)
insert_new_model_configuration__no_commit(
db_session=db_session,
llm_provider_id=provider.id,
model_name=model_name,
model_name=model.name,
supported_flows=supported_flows,
is_visible=False,
max_input_tokens=model.get("max_input_tokens"),
display_name=model.get("display_name"),
max_input_tokens=model.max_input_tokens,
display_name=model.display_name,
)
new_count += 1
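Reviewer note: switching `models` from `list[dict]` to `list[SyncModelEntry]` replaces string-keyed lookups with attribute access. The sketch below shows the insert-only-new selection logic under that change. `ModelEntry` is a hypothetical dataclass with the four fields the diff reads; the real `SyncModelEntry` lives in `onyx.server.manage.llm.models` and is not shown here. The strings `"chat"` and `"vision"` are placeholders for `LLMModelFlowType.CHAT` and `LLMModelFlowType.VISION`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelEntry:
    """Hypothetical stand-in for SyncModelEntry, with the fields the diff reads."""

    name: str
    display_name: Optional[str] = None
    max_input_tokens: Optional[int] = None
    supports_image_input: bool = False


def select_new_models(
    existing_names: set[str], models: list[ModelEntry]
) -> list[tuple[str, list[str]]]:
    """Return (model name, supported flows) for models not already stored."""
    selected: list[tuple[str, list[str]]] = []
    for model in models:
        if model.name in existing_names:
            continue  # existing rows are left untouched
        flows = ["chat"]  # placeholder for LLMModelFlowType.CHAT
        if model.supports_image_input:
            flows.append("vision")  # placeholder for LLMModelFlowType.VISION
        selected.append((model.name, flows))
    return selected
```

With typed entries, a misspelled field name fails as an `AttributeError` (or earlier under mypy) instead of quietly returning `None` from `dict.get`.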

@@ -163,6 +163,8 @@ class _EncryptedBase(TypeDecorator):
class EncryptedString(_EncryptedBase):
# Must redeclare cache_ok in this child class since we explicitly redeclare _is_json
cache_ok = True
_is_json: bool = False
def process_bind_param(
@@ -189,6 +191,7 @@ class EncryptedString(_EncryptedBase):
class EncryptedJson(_EncryptedBase):
cache_ok = True
_is_json: bool = True
def process_bind_param(
@@ -340,6 +343,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# organized in typical structured fashion
# formatted as `displayName__provider__modelName`
# Voice preferences
voice_auto_send: Mapped[bool] = mapped_column(Boolean, default=False)
voice_auto_playback: Mapped[bool] = mapped_column(Boolean, default=False)
voice_playback_speed: Mapped[float] = mapped_column(Float, default=1.0)
# relationships
credentials: Mapped[list["Credential"]] = relationship(
"Credential", back_populates="user"
@@ -3052,6 +3060,65 @@ class ImageGenerationConfig(Base):
)
class VoiceProvider(Base):
"""Configuration for voice services (STT and TTS)."""
__tablename__ = "voice_provider"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
provider_type: Mapped[str] = mapped_column(
String
) # "openai", "azure", "elevenlabs"
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=True
)
api_base: Mapped[str | None] = mapped_column(String, nullable=True)
custom_config: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# Model/voice configuration
stt_model: Mapped[str | None] = mapped_column(
String, nullable=True
) # e.g., "whisper-1"
tts_model: Mapped[str | None] = mapped_column(
String, nullable=True
) # e.g., "tts-1", "tts-1-hd"
default_voice: Mapped[str | None] = mapped_column(
String, nullable=True
) # e.g., "alloy", "echo"
# STT and TTS can use different providers, but only one default provider per type
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
# Enforce only one default STT provider and one default TTS provider at DB level
__table_args__ = (
Index(
"ix_voice_provider_one_default_stt",
"is_default_stt",
unique=True,
postgresql_where=(is_default_stt == True), # noqa: E712
),
Index(
"ix_voice_provider_one_default_tts",
"is_default_tts",
unique=True,
postgresql_where=(is_default_tts == True), # noqa: E712
),
)
class CloudEmbeddingProvider(Base):
__tablename__ = "embedding_provider"

backend/onyx/db/voice.py (new file, 248 lines)

@@ -0,0 +1,248 @@
from typing import Any
from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.db.models import User
from onyx.db.models import VoiceProvider
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
MIN_VOICE_PLAYBACK_SPEED = 0.5
MAX_VOICE_PLAYBACK_SPEED = 2.0
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
"""Fetch all non-deleted voice providers, ordered by name."""
return list(
db_session.scalars(
select(VoiceProvider)
.where(VoiceProvider.deleted.is_(False))
.order_by(VoiceProvider.name)
).all()
)
def fetch_voice_provider_by_id(
db_session: Session, provider_id: int, include_deleted: bool = False
) -> VoiceProvider | None:
"""Fetch a voice provider by ID."""
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
if not include_deleted:
stmt = stmt.where(VoiceProvider.deleted.is_(False))
return db_session.scalar(stmt)
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default STT provider."""
return db_session.scalar(
select(VoiceProvider)
.where(VoiceProvider.is_default_stt.is_(True))
.where(VoiceProvider.deleted.is_(False))
)
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default TTS provider."""
return db_session.scalar(
select(VoiceProvider)
.where(VoiceProvider.is_default_tts.is_(True))
.where(VoiceProvider.deleted.is_(False))
)
def fetch_voice_provider_by_type(
db_session: Session, provider_type: str
) -> VoiceProvider | None:
"""Fetch a voice provider by type."""
return db_session.scalar(
select(VoiceProvider)
.where(VoiceProvider.provider_type == provider_type)
.where(VoiceProvider.deleted.is_(False))
)
def upsert_voice_provider(
*,
db_session: Session,
provider_id: int | None,
name: str,
provider_type: str,
api_key: str | None,
api_key_changed: bool,
api_base: str | None = None,
custom_config: dict[str, Any] | None = None,
stt_model: str | None = None,
tts_model: str | None = None,
default_voice: str | None = None,
activate_stt: bool = False,
activate_tts: bool = False,
) -> VoiceProvider:
"""Create or update a voice provider."""
provider: VoiceProvider | None = None
if provider_id is not None:
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider is None:
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"No voice provider with id {provider_id} exists.",
)
else:
provider = VoiceProvider()
db_session.add(provider)
# Apply updates
provider.name = name
provider.provider_type = provider_type
provider.api_base = api_base
provider.custom_config = custom_config
provider.stt_model = stt_model
provider.tts_model = tts_model
provider.default_voice = default_voice
# Only update API key if explicitly changed or if provider has no key
if api_key_changed or provider.api_key is None:
provider.api_key = api_key # type: ignore[assignment]
db_session.flush()
if activate_stt:
set_default_stt_provider(db_session=db_session, provider_id=provider.id)
if activate_tts:
set_default_tts_provider(db_session=db_session, provider_id=provider.id)
db_session.refresh(provider)
return provider
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
"""Soft-delete a voice provider by ID."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider:
provider.deleted = True
db_session.flush()
def set_default_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
"""Set a voice provider as the default STT provider."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider is None:
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"No voice provider with id {provider_id} exists.",
)
# Deactivate all other STT providers
db_session.execute(
update(VoiceProvider)
.where(
VoiceProvider.is_default_stt.is_(True),
VoiceProvider.id != provider_id,
)
.values(is_default_stt=False)
)
# Activate this provider
provider.is_default_stt = True
db_session.flush()
db_session.refresh(provider)
return provider
def set_default_tts_provider(
*, db_session: Session, provider_id: int, tts_model: str | None = None
) -> VoiceProvider:
"""Set a voice provider as the default TTS provider."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider is None:
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"No voice provider with id {provider_id} exists.",
)
# Deactivate all other TTS providers
db_session.execute(
update(VoiceProvider)
.where(
VoiceProvider.is_default_tts.is_(True),
VoiceProvider.id != provider_id,
)
.values(is_default_tts=False)
)
# Activate this provider
provider.is_default_tts = True
# Update the TTS model if specified
if tts_model is not None:
provider.tts_model = tts_model
db_session.flush()
db_session.refresh(provider)
return provider
def deactivate_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
"""Remove the default STT status from a voice provider."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider is None:
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"No voice provider with id {provider_id} exists.",
)
provider.is_default_stt = False
db_session.flush()
db_session.refresh(provider)
return provider
def deactivate_tts_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
"""Remove the default TTS status from a voice provider."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider is None:
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"No voice provider with id {provider_id} exists.",
)
provider.is_default_tts = False
db_session.flush()
db_session.refresh(provider)
return provider
# User voice preferences
def update_user_voice_settings(
db_session: Session,
user_id: UUID,
auto_send: bool | None = None,
auto_playback: bool | None = None,
playback_speed: float | None = None,
) -> None:
"""Update user's voice settings.
For all fields, None means "don't update this field".
"""
values: dict[str, bool | float] = {}
if auto_send is not None:
values["voice_auto_send"] = auto_send
if auto_playback is not None:
values["voice_auto_playback"] = auto_playback
if playback_speed is not None:
values["voice_playback_speed"] = max(
MIN_VOICE_PLAYBACK_SPEED, min(MAX_VOICE_PLAYBACK_SPEED, playback_speed)
)
if values:
db_session.execute(update(User).where(User.id == user_id).values(**values)) # type: ignore[arg-type]
db_session.flush()
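
The clamping applied to `playback_speed` in `update_user_voice_settings` can be looked at in isolation. The following is a minimal standalone sketch, not part of the diff: `clamp_playback_speed` is a hypothetical helper name introduced only for illustration, and the two bounds are copied from the constants defined at the top of this file.

```python
# Bounds copied from the constants in backend/onyx/db/voice.py.
MIN_VOICE_PLAYBACK_SPEED = 0.5
MAX_VOICE_PLAYBACK_SPEED = 2.0


def clamp_playback_speed(speed: float) -> float:
    """Hypothetical helper: pin a requested speed into the allowed range."""
    # min() caps the value at the upper bound, max() raises it to the lower bound.
    return max(MIN_VOICE_PLAYBACK_SPEED, min(MAX_VOICE_PLAYBACK_SPEED, speed))
```

Out-of-range requests are silently adjusted rather than rejected, so a caller that needs to know the stored value should read it back instead of assuming the requested one was kept.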

@@ -1,5 +1,10 @@
# Default value for the maximum number of tokens a chunk can hold, if none is
# specified when creating an index.
from onyx.configs.app_configs import (
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
)
DEFAULT_MAX_CHUNK_SIZE = 512
# Size of the dynamic list used to consider elements during kNN graph creation.
@@ -10,27 +15,43 @@ EF_CONSTRUCTION = 256
# quality but increase memory footprint. Values typically range between 12 - 48.
M = 32 # Set relatively high for better accuracy.
# When performing hybrid search, we need to consider more candidates than the number of results to be returned.
# This is because the scoring is hybrid and the results are reordered due to the hybrid scoring.
# Higher = more candidates for hybrid fusion = better retrieval accuracy, but results in more computation per query.
# Imagine a simple case with a single keyword query and a single vector query and we want 10 final docs.
# If we only fetch 10 candidates from each of keyword and vector, they would have to have perfect overlap to get a good hybrid
# ranking for the 10 results. If we fetch 1000 candidates from each, we have a much higher chance of all 10 of the final desired
# docs showing up and getting scored. In worse situations, the final 10 docs don't even show up as the final 10 (worse than just
# a miss at the reranking step).
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = 750
# When performing hybrid search, we need to consider more candidates than the
# number of results to be returned. This is because the scoring is hybrid and
# the results are reordered due to the hybrid scoring. Higher = more candidates
# for hybrid fusion = better retrieval accuracy, but results in more computation
# per query. Imagine a simple case with a single keyword query and a single
# vector query and we want 10 final docs. If we only fetch 10 candidates from
# each of keyword and vector, they would have to have perfect overlap to get a
# good hybrid ranking for the 10 results. If we fetch 1000 candidates from each,
# we have a much higher chance of all 10 of the final desired docs showing up
# and getting scored. In worse situations, the final 10 docs don't even show up
# as the final 10 (worse than just a miss at the reranking step).
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = (
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
if OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES > 0
else 750
)
# Number of vectors to examine for top k neighbors for the HNSW method.
# Number of vectors to examine to decide the top k neighbors for the HNSW
# method.
# NOTE: "When creating a search query, you must specify k. If you provide both k
# and ef_search, then the larger value is passed to the engine. If ef_search is
# larger than k, you can provide the size parameter to limit the final number of
# results to k." from
# https://docs.opensearch.org/latest/query-dsl/specialized/k-nn/index/#ef_search
EF_SEARCH = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
# Since the titles are included in the contents, they are heavily downweighted as they act as a boost
# rather than an independent scoring component.
# Since the titles are included in the contents, the embedding matches are
# heavily downweighted as they act as a boost rather than an independent scoring
# component.
SEARCH_TITLE_VECTOR_WEIGHT = 0.1
SEARCH_CONTENT_VECTOR_WEIGHT = 0.45
# Single keyword weight for both title and content (merged from former title keyword + content keyword).
# Single keyword weight for both title and content (merged from former title
# keyword + content keyword).
SEARCH_KEYWORD_WEIGHT = 0.45
# NOTE: it is critical that the order of these weights matches the order of the sub-queries in the hybrid search.
# NOTE: It is critical that the order of these weights matches the order of the
# sub-queries in the hybrid search.
HYBRID_SEARCH_NORMALIZATION_WEIGHTS = [
SEARCH_TITLE_VECTOR_WEIGHT,
SEARCH_CONTENT_VECTOR_WEIGHT,
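
The override logic for `DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES` above treats any non-positive override as "not set". A minimal sketch of that rule, assuming the override is an integer; `resolve_num_candidates` is a hypothetical name used only here, and 750 is the fallback shown in the diff.

```python
def resolve_num_candidates(override: int, fallback: int = 750) -> int:
    """Hypothetical helper: use the override only when it is a positive integer."""
    # Zero and negative values both mean "no override configured".
    return override if override > 0 else fallback
```

Because `EF_SEARCH` is assigned from the same constant, an override also changes the HNSW search breadth, not just the hybrid candidate count.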

@@ -433,12 +433,16 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
hidden=fields.hidden if fields else None,
project_ids=(
set(user_fields.user_projects)
if user_fields and user_fields.user_projects
# NOTE: Empty user_projects is semantically different from None
# user_projects.
if user_fields and user_fields.user_projects is not None
else None
),
persona_ids=(
set(user_fields.personas)
if user_fields and user_fields.personas
# NOTE: Empty personas is semantically different from None
# personas.
if user_fields and user_fields.personas is not None
else None
),
)
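
The switch from a truthiness check to `is not None` matters because an empty list is falsy. The sketch below is illustrative only and uses hypothetical function names; it contrasts the two checks on plain lists rather than on the real `user_fields` object.

```python
from typing import Optional


def to_id_set_truthy(ids: Optional[list[int]]) -> Optional[set[int]]:
    """Old behaviour: an empty list collapses to None ("no update")."""
    return set(ids) if ids else None


def to_id_set_explicit(ids: Optional[list[int]]) -> Optional[set[int]]:
    """New behaviour: an empty list stays an empty set ("clear all IDs")."""
    return set(ids) if ids is not None else None
```

With the explicit check, passing `[]` can be used to clear the field, while `None` still means the field should be left untouched.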

@@ -255,8 +255,12 @@ class DocumentQuery:
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
)
# TODO(andrei, yuhong): We can tune this more dynamically based on
# num_hits.
max_results_per_subquery = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
query_text, query_vector
query_text, query_vector, vector_candidates=max_results_per_subquery
)
hybrid_search_filters = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
@@ -285,13 +289,16 @@ class DocumentQuery:
hybrid_search_query: dict[str, Any] = {
"hybrid": {
"queries": hybrid_search_subqueries,
# Max results per subquery per shard before aggregation. Ensures keyword and vector
# subqueries contribute equally to the candidate pool for hybrid fusion.
# Max results per subquery per shard before aggregation. Ensures
# keyword and vector subqueries contribute equally to the
# candidate pool for hybrid fusion.
# Sources:
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
# https://opensearch.org/blog/navigating-pagination-in-hybrid-queries-with-the-pagination_depth-parameter/
"pagination_depth": DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
# Applied to all the sub-queries independently (this avoids having subqueries having a lot of results thrown out).
"pagination_depth": max_results_per_subquery,
# Applied to all the sub-queries independently (this avoids
# subqueries having a lot of results thrown out during
# aggregation).
# Sources:
# https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
# https://opensearch.org/blog/introducing-common-filter-support-for-hybrid-search-queries
@@ -374,9 +381,10 @@ class DocumentQuery:
def _get_hybrid_search_subqueries(
query_text: str,
query_vector: list[float],
# The default number of neighbors to consider for knn vector similarity search.
# This is higher than the number of results because the scoring is hybrid.
# for a detailed breakdown, see where the default value is set.
# The default number of neighbors to consider for knn vector similarity
# search. This is higher than the number of results because the scoring
# is hybrid. For a detailed breakdown, see where the default value is
# set.
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
) -> list[dict[str, Any]]:
"""Returns subqueries for hybrid search.
@@ -400,20 +408,27 @@ class DocumentQuery:
in a single hybrid query. Source:
https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
NOTE: Each query is independent during the search phase, there is no backfilling of scores for missing query components.
What this means is that if a document was a good vector match but did not show up for keyword, it gets a score of 0 for
the keyword component of the hybrid scoring. This is not as bad as just disregarding a score though as there is
normalization applied after. So really it is "increasing" the missing score compared to if it was included and the range
was renormalized. This does however mean that between docs that have high scores for say the vector field, the keyword
scores between them are completely ignored unless they also showed up in the keyword query as a reasonably high match.
TLDR, this is a bit of unique funky behavior but it seems ok.
NOTE: Each query is independent during the search phase, there is no
backfilling of scores for missing query components. What this means is
that if a document was a good vector match but did not show up for
keyword, it gets a score of 0 for the keyword component of the hybrid
scoring. This is not as bad as just disregarding a score though as there
is normalization applied after. So really it is "increasing" the missing
score compared to if it was included and the range was renormalized.
This does however mean that between docs that have high scores for say
the vector field, the keyword scores between them are completely ignored
unless they also showed up in the keyword query as a reasonably high
match. TLDR, this is a bit of unique funky behavior but it seems ok.
NOTE: Options considered and rejected:
- minimum_should_match: Since it's hybrid search and users often provide semantic queries, there is often a lot of terms,
and very low number of meaningful keywords (and a low ratio of keywords).
- fuzziness AUTO: typo tolerance (0/1/2 edit distance by term length). It's mostly for typos as the analyzer ("english by
default") already does some stemming and tokenization. In testing datasets, this makes recall slightly worse. It also is
less performant so not really any reason to do it.
- minimum_should_match: Since it's hybrid search and users often provide
semantic queries, there is often a lot of terms, and very low number
of meaningful keywords (and a low ratio of keywords).
- fuzziness AUTO: Typo tolerance (0/1/2 edit distance by term length).
It's mostly for typos as the analyzer ("english" by default) already
does some stemming and tokenization. In testing datasets, this makes
recall slightly worse. It also is less performant so not really any
reason to do it.
Args:
query_text: The text of the query to search for.
@@ -723,14 +738,13 @@ class DocumentQuery:
# document's metadata list.
filter_clauses.append(_get_tag_filter(tags))
# Knowledge scope: explicit knowledge attachments restrict what
# an assistant can see. When none are set the assistant
# searches everything.
# Knowledge scope: explicit knowledge attachments restrict what an
# assistant can see. When none are set the assistant searches
# everything.
#
# project_id / persona_id are additive: they make overflowing
# user files findable but must NOT trigger the restriction on
# their own (an agent with no explicit knowledge should search
# everything).
# project_id / persona_id are additive: they make overflowing user files
# findable but must NOT trigger the restriction on their own (an agent
# with no explicit knowledge should search everything).
has_knowledge_scope = (
attached_document_ids
or hierarchy_node_ids
@@ -758,9 +772,8 @@ class DocumentQuery:
knowledge_filter["bool"]["should"].append(
_get_document_set_filter(document_sets)
)
# Additive: widen scope to also cover overflowing user
# files, but only when an explicit restriction is already
# in effect.
# Additive: widen scope to also cover overflowing user files, but
# only when an explicit restriction is already in effect.
if project_id is not None:
knowledge_filter["bool"]["should"].append(
_get_user_project_filter(project_id)

@@ -690,9 +690,12 @@ class VespaIndex(DocumentIndex):
)
project_ids: set[int] | None = None
# NOTE: Empty user_projects is semantically different from None
# user_projects.
if user_fields is not None and user_fields.user_projects is not None:
project_ids = set(user_fields.user_projects)
persona_ids: set[int] | None = None
# NOTE: Empty personas is semantically different from None personas.
if user_fields is not None and user_fields.personas is not None:
persona_ids = set(user_fields.personas)
update_request = MetadataUpdateRequest(

@@ -66,6 +66,11 @@ class OnyxErrorCode(Enum):
RATE_LIMITED = ("RATE_LIMITED", 429)
SEAT_LIMIT_EXCEEDED = ("SEAT_LIMIT_EXCEEDED", 402)
# ------------------------------------------------------------------
# Payload (413)
# ------------------------------------------------------------------
PAYLOAD_TOO_LARGE = ("PAYLOAD_TOO_LARGE", 413)
# ------------------------------------------------------------------
# Connector / Credential Errors (400-range)
# ------------------------------------------------------------------

@@ -1,4 +1,3 @@
import csv
import gc
import io
import json
@@ -20,7 +19,6 @@ from zipfile import BadZipFile
import chardet
import openpyxl
from openpyxl.worksheet.worksheet import Worksheet
from PIL import Image
from onyx.configs.constants import ONYX_METADATA_FILENAME
@@ -354,65 +352,6 @@ def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
return presentation.markdown
def _worksheet_to_matrix(
worksheet: Worksheet,
) -> list[list[str]]:
"""
Converts a singular worksheet to a matrix of values
"""
rows: list[list[str]] = []
for worksheet_row in worksheet.iter_rows(min_row=1, values_only=True):
row = ["" if cell is None else str(cell) for cell in worksheet_row]
rows.append(row)
return rows
def _clean_worksheet_matrix(matrix: list[list[str]]) -> list[list[str]]:
"""
Cleans a worksheet matrix by removing rows if there are N consecutive empty
rows and removing cols if there are M consecutive empty columns
"""
MAX_EMPTY_ROWS = 2 # Runs longer than this are capped to max_empty; shorter runs are preserved as-is
MAX_EMPTY_COLS = 2
# Row cleanup
matrix = _remove_empty_runs(matrix, max_empty=MAX_EMPTY_ROWS)
# Column cleanup (transpose, clean, transpose back)
transposed = list(map(list, zip(*matrix))) if matrix else []
transposed = _remove_empty_runs(transposed, max_empty=MAX_EMPTY_COLS)
matrix = list(map(list, zip(*transposed))) if transposed else []
return matrix
def _remove_empty_runs(
rows: list[list[str]],
max_empty: int,
) -> list[list[str]]:
"""Removes entire runs of empty rows when the run length exceeds max_empty.
Leading and trailing empty rows are always dropped regardless of run length,
since there is no adjacent non-empty row to bound the run.
"""
result: list[list[str]] = []
empty_buffer: list[list[str]] = []
for row in rows:
# Check if empty
if not any(row):
empty_buffer.append(row)
else:
# Add up to max empty rows onto the result - that's what we allow
result.extend(empty_buffer[:max_empty])
# Add the new non-empty row
result.append(row)
empty_buffer = []
return result
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
# TODO: switch back to this approach in a few months when markitdown
# fixes their handling of excel files
@@ -451,15 +390,30 @@ def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
)
return ""
raise
raise e
text_content = []
for sheet in workbook.worksheets:
sheet_matrix = _clean_worksheet_matrix(_worksheet_to_matrix(sheet))
buf = io.StringIO()
writer = csv.writer(buf, lineterminator="\n")
writer.writerows(sheet_matrix)
text_content.append(buf.getvalue().rstrip("\n"))
rows = []
num_empty_consecutive_rows = 0
for row in sheet.iter_rows(min_row=1, values_only=True):
row_str = ",".join(str(cell or "") for cell in row)
# Only add the row if there are any values in the cells
if len(row_str) >= len(row):
rows.append(row_str)
num_empty_consecutive_rows = 0
else:
num_empty_consecutive_rows += 1
if num_empty_consecutive_rows > 100:
# handle massive excel sheets with mostly empty cells
logger.warning(
f"Found {num_empty_consecutive_rows} empty rows in {file_name}, skipping rest of file"
)
break
sheet_str = "\n".join(rows)
text_content.append(sheet_str)
return TEXT_SECTION_SEPARATOR.join(text_content)
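
The empty-row test `len(row_str) >= len(row)` in the restored loop relies on a counting argument: joining n cells inserts n - 1 commas, so the joined string reaches length n only if at least one cell contributed a character. A minimal sketch of that check, using a hypothetical helper name and plain tuples in place of openpyxl rows:

```python
from typing import Any, Optional


def row_to_csv_line(row: tuple[Any, ...]) -> Optional[str]:
    """Hypothetical helper: return the joined row, or None if it is empty."""
    # `cell or ""` maps None (and other falsy values such as 0) to "".
    row_str = ",".join(str(cell or "") for cell in row)
    # n cells produce n - 1 commas; anything longer means real content.
    return row_str if len(row_str) >= len(row) else None
```

Note that `cell or ""` also blanks out falsy values such as `0` and `False`, so a row containing only zeros is treated as empty by this check.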

@@ -123,15 +123,11 @@ class DocumentIndexingBatchAdapter:
}
doc_id_to_new_chunk_cnt: dict[str, int] = {
document_id: len(
[
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == document_id
]
)
for document_id in updatable_ids
doc_id: 0 for doc_id in updatable_ids
}
for chunk in chunks_with_embeddings:
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
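
The rewritten counting replaces one full scan of the chunk list per document with a single pass over the chunks. A minimal sketch of the same pattern, assuming each chunk is represented only by its source document ID; `count_chunks_per_doc` is a hypothetical name, not a function in the codebase.

```python
def count_chunks_per_doc(
    updatable_ids: list[str], chunk_doc_ids: list[str]
) -> dict[str, int]:
    """Hypothetical helper: count chunks per updatable document in one pass."""
    # Pre-seed every updatable document with zero so documents
    # without chunks still appear in the result.
    counts = {doc_id: 0 for doc_id in updatable_ids}
    for doc_id in chunk_doc_ids:
        # Chunks that belong to non-updatable documents are ignored.
        if doc_id in counts:
            counts[doc_id] += 1
    return counts
```

For D documents and C chunks this does O(D + C) work, where the list comprehension it replaces did O(D * C).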
# Get ancestor hierarchy node IDs for each document
doc_id_to_ancestor_ids = self._get_ancestor_ids_for_documents(

@@ -16,6 +16,7 @@ from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import IndexChunk
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.utils.logger import setup_logger
from onyx.utils.pydantic_util import shallow_model_dump
from onyx.utils.timing import log_function_time
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
@@ -210,8 +211,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
)[0]
title_embed_dict[title] = title_embedding
new_embedded_chunk = IndexChunk(
**chunk.model_dump(),
new_embedded_chunk = IndexChunk.model_construct(
**shallow_model_dump(chunk),
embeddings=ChunkEmbedding(
full_embedding=chunk_embeddings[0],
mini_chunk_embeddings=chunk_embeddings[1:],

@@ -12,6 +12,7 @@ from onyx.connectors.models import Document
from onyx.db.enums import EmbeddingPrecision
from onyx.db.enums import SwitchoverType
from onyx.utils.logger import setup_logger
from onyx.utils.pydantic_util import shallow_model_dump
from shared_configs.enums import EmbeddingProvider
from shared_configs.model_server_models import Embedding
@@ -133,9 +134,8 @@ class DocMetadataAwareIndexChunk(IndexChunk):
tenant_id: str,
ancestor_hierarchy_node_ids: list[int] | None = None,
) -> "DocMetadataAwareIndexChunk":
index_chunk_data = index_chunk.model_dump()
return cls(
**index_chunk_data,
return cls.model_construct(
**shallow_model_dump(index_chunk),
access=access,
document_sets=document_sets,
user_project=user_project,

@@ -43,6 +43,7 @@ WELL_KNOWN_PROVIDER_NAMES = [
LlmProviderNames.AZURE,
LlmProviderNames.OLLAMA_CHAT,
LlmProviderNames.LM_STUDIO,
LlmProviderNames.LITELLM_PROXY,
]
@@ -59,6 +60,7 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
"ollama": "Ollama",
LlmProviderNames.OLLAMA_CHAT: "Ollama",
LlmProviderNames.LM_STUDIO: "LM Studio",
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
"groq": "Groq",
"anyscale": "Anyscale",
"deepseek": "DeepSeek",
@@ -109,6 +111,7 @@ AGGREGATOR_PROVIDERS: set[str] = {
LlmProviderNames.LM_STUDIO,
LlmProviderNames.VERTEX_AI,
LlmProviderNames.AZURE,
LlmProviderNames.LITELLM_PROXY,
}
# Model family name mappings for display name generation

@@ -11,6 +11,8 @@ OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"
LM_STUDIO_PROVIDER_NAME = "lm_studio"
LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"
LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
# Providers that use optional Bearer auth from custom_config
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,

@@ -15,6 +15,7 @@ from onyx.llm.well_known_providers.auto_update_service import (
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
@@ -47,6 +48,7 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
OLLAMA_PROVIDER_NAME: [], # Dynamic - fetched from Ollama API
LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
}
@@ -331,6 +333,7 @@ def get_provider_display_name(provider_name: str) -> str:
BEDROCK_PROVIDER_NAME: "Amazon Bedrock",
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
OPENROUTER_PROVIDER_NAME: "OpenRouter",
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
}
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:

@@ -119,6 +119,9 @@ from onyx.server.manage.opensearch_migration.api import (
from onyx.server.manage.search_settings import router as search_settings_router
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router
from onyx.server.manage.voice.api import admin_router as voice_admin_router
from onyx.server.manage.voice.user_api import router as voice_router
from onyx.server.manage.voice.websocket_api import router as voice_websocket_router
from onyx.server.manage.web_search.api import (
admin_router as web_search_admin_router,
)
@@ -497,6 +500,9 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, embedding_router)
include_router_with_global_prefix_prepended(application, web_search_router)
include_router_with_global_prefix_prepended(application, web_search_admin_router)
include_router_with_global_prefix_prepended(application, voice_admin_router)
include_router_with_global_prefix_prepended(application, voice_router)
include_router_with_global_prefix_prepended(application, voice_websocket_router)
include_router_with_global_prefix_prepended(
application, opensearch_migration_admin_router
)

@@ -419,12 +419,15 @@ async def get_async_redis_connection() -> aioredis.Redis:
return _async_redis_connection
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
if not token:
logger.debug("No auth token cookie found")
return None
async def retrieve_auth_token_data(token: str) -> dict | None:
"""Validate auth token against Redis and return token data.
Args:
token: The raw authentication token string.
Returns:
Token data dict if valid, None if invalid/expired.
"""
try:
redis = await get_async_redis_connection()
redis_key = REDIS_AUTH_KEY_PREFIX + token
@@ -439,13 +442,97 @@ async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
logger.error("Error decoding token data from Redis")
return None
except Exception as e:
logger.error(
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
)
raise ValueError(
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
logger.error(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
raise ValueError(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
"""Validate auth token from request cookie. Wrapper for backwards compatibility."""
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
if not token:
logger.debug("No auth token cookie found")
return None
return await retrieve_auth_token_data(token)
# WebSocket token prefix (separate from regular auth tokens)
REDIS_WS_TOKEN_PREFIX = "ws_token:"
# WebSocket tokens expire after 60 seconds
WS_TOKEN_TTL_SECONDS = 60
# Rate limit: max tokens per user per window
WS_TOKEN_RATE_LIMIT_MAX = 10
WS_TOKEN_RATE_LIMIT_WINDOW_SECONDS = 60
REDIS_WS_TOKEN_RATE_LIMIT_PREFIX = "ws_token_rate:"
class WsTokenRateLimitExceeded(Exception):
"""Raised when a user exceeds the WS token generation rate limit."""
async def store_ws_token(token: str, user_id: str) -> None:
"""Store a short-lived WebSocket authentication token in Redis.
Args:
token: The generated WS token.
user_id: The user ID to associate with this token.
Raises:
WsTokenRateLimitExceeded: If the user has exceeded the rate limit.
"""
redis = await get_async_redis_connection()
# Atomically increment and check rate limit to avoid TOCTOU races
rate_limit_key = REDIS_WS_TOKEN_RATE_LIMIT_PREFIX + user_id
pipe = redis.pipeline()
pipe.incr(rate_limit_key)
pipe.expire(rate_limit_key, WS_TOKEN_RATE_LIMIT_WINDOW_SECONDS)
results = await pipe.execute()
new_count = results[0]
if new_count > WS_TOKEN_RATE_LIMIT_MAX:
# Over limit — decrement back since we won't use this slot
await redis.decr(rate_limit_key)
logger.warning(f"WS token rate limit exceeded for user {user_id}")
raise WsTokenRateLimitExceeded(
f"Rate limit exceeded. Maximum {WS_TOKEN_RATE_LIMIT_MAX} tokens per minute."
)
# Store the actual token
redis_key = REDIS_WS_TOKEN_PREFIX + token
token_data = json.dumps({"sub": user_id})
await redis.set(redis_key, token_data, ex=WS_TOKEN_TTL_SECONDS)
async def retrieve_ws_token_data(token: str) -> dict | None:
"""Validate a WebSocket token and return the token data.
This uses GETDEL for atomic get-and-delete to prevent race conditions
where the same token could be used twice.
Args:
token: The WS token to validate.
Returns:
Token data dict with 'sub' (user ID) if valid, None if invalid/expired.
"""
try:
redis = await get_async_redis_connection()
redis_key = REDIS_WS_TOKEN_PREFIX + token
# Atomic get-and-delete to prevent race conditions (Redis 6.2+)
token_data_str = await redis.getdel(redis_key)
if not token_data_str:
return None
return json.loads(token_data_str)
except json.JSONDecodeError:
logger.error("Error decoding WS token data from Redis")
return None
except Exception as e:
logger.error(f"Unexpected error in retrieve_ws_token_data: {str(e)}")
return None
def redis_lock_dump(lock: RedisLock, r: Redis) -> None:
# diagnostic logging for lock errors
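
Review note: the INCR-first check in `store_ws_token` can be illustrated without a Redis server. The sketch below is a minimal in-memory stand-in, assuming a hypothetical `FakeCounterStore` class in place of the async Redis client and a hypothetical `claim_slot` helper. It only shows why incrementing before checking closes the check-then-act gap; it does not model key expiry or the production wiring.

```python
import threading


class RateLimitExceeded(Exception):
    """Raised when a caller has used up its slots for the current window."""


class FakeCounterStore:
    """Hypothetical in-memory stand-in for Redis INCR/DECR (no expiry handling)."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._counts = {}

    def incr(self, key: str) -> int:
        with self._lock:
            self._counts[key] = self._counts.get(key, 0) + 1
            return self._counts[key]

    def decr(self, key: str) -> int:
        with self._lock:
            self._counts[key] = self._counts.get(key, 0) - 1
            return self._counts[key]

    def get(self, key: str) -> int:
        with self._lock:
            return self._counts.get(key, 0)


def claim_slot(store: FakeCounterStore, key: str, limit: int) -> int:
    """Increment first, then check: every caller observes a distinct counter value."""
    new_count = store.incr(key)
    if new_count > limit:
        # Give the slot back so rejected requests do not inflate the counter.
        store.decr(key)
        raise RateLimitExceeded(f"limit of {limit} reached for {key}")
    return new_count
```

Because each caller receives its own return value from the increment, two concurrent requests cannot both observe "one slot left". As written in the diff, `EXPIRE` is issued on every increment, so the window appears to restart on each request rather than being fixed from the first one.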


@@ -9,6 +9,7 @@ from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
from onyx.auth.users import current_user_from_websocket
from onyx.auth.users import current_user_with_expired_token
from onyx.configs.app_configs import APP_API_PREFIX
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -129,6 +130,7 @@ def check_router_auth(
or depends_fn == current_curator_or_admin_user
or depends_fn == current_user_with_expired_token
or depends_fn == current_chat_accessible_user
or depends_fn == current_user_from_websocket
or depends_fn == control_plane_dep
or depends_fn == current_cloud_superuser
or depends_fn == verify_scim_token


@@ -732,7 +732,7 @@ def get_webapp_info(
return WebappInfo(**webapp_info)
@router.get("/{session_id}/webapp/download")
@router.get("/{session_id}/webapp-download")
def download_webapp(
session_id: UUID,
user: User = Depends(current_user),


@@ -7424,9 +7424,9 @@
}
},
"node_modules/hono": {
"version": "4.12.5",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz",
"integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==",
"version": "4.12.7",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.7.tgz",
"integrity": "sha512-jq9l1DM0zVIvsm3lv9Nw9nlJnMNPOcAtsbsgiUhWcFzPE99Gvo6yRTlszSLLYacMeQ6quHD6hMfId8crVHvexw==",
"license": "MIT",
"engines": {
"node": ">=16.9.0"


@@ -58,6 +58,9 @@ from onyx.llm.well_known_providers.llm_provider_options import (
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LitellmFinalModelResponse
from onyx.server.manage.llm.models import LitellmModelDetails
from onyx.server.manage.llm.models import LitellmModelsRequest
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderResponse
@@ -72,6 +75,7 @@ from onyx.server.manage.llm.models import OllamaModelsRequest
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
from onyx.server.manage.llm.models import OpenRouterModelDetails
from onyx.server.manage.llm.models import OpenRouterModelsRequest
from onyx.server.manage.llm.models import SyncModelEntry
from onyx.server.manage.llm.models import TestLLMRequest
from onyx.server.manage.llm.models import VisionProviderResponse
from onyx.server.manage.llm.utils import generate_bedrock_display_name
@@ -98,6 +102,34 @@ def _mask_string(value: str) -> str:
return value[:4] + "****" + value[-4:]
def _sync_fetched_models(
db_session: Session,
provider_name: str,
models: list[SyncModelEntry],
source_label: str,
) -> None:
"""Sync fetched models to DB for the given provider.
Args:
db_session: Database session
provider_name: Name of the LLM provider
models: List of SyncModelEntry objects describing the fetched models
source_label: Human-readable label for log messages (e.g. "Bedrock", "LiteLLM")
"""
try:
new_count = sync_model_configurations(
db_session=db_session,
provider_name=provider_name,
models=models,
)
if new_count > 0:
logger.info(
f"Added {new_count} new {source_label} models to provider '{provider_name}'"
)
except ValueError as e:
logger.warning(f"Failed to sync {source_label} models to DB: {e}")
# Keys in custom_config that contain sensitive credentials
_SENSITIVE_CONFIG_KEYS = {
"vertex_credentials",
@@ -963,27 +995,20 @@ def get_bedrock_available_models(
# Sync new models to DB if provider_name is specified
if request.provider_name:
try:
models_to_sync = [
{
"name": r.name,
"display_name": r.display_name,
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in results
]
new_count = sync_model_configurations(
db_session=db_session,
provider_name=request.provider_name,
models=models_to_sync,
)
if new_count > 0:
logger.info(
f"Added {new_count} new Bedrock models to provider '{request.provider_name}'"
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
except ValueError as e:
logger.warning(f"Failed to sync Bedrock models to DB: {e}")
for r in results
],
source_label="Bedrock",
)
return results
@@ -1101,27 +1126,20 @@ def get_ollama_available_models(
# Sync new models to DB if provider_name is specified
if request.provider_name:
try:
models_to_sync = [
{
"name": r.name,
"display_name": r.display_name,
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in sorted_results
]
new_count = sync_model_configurations(
db_session=db_session,
provider_name=request.provider_name,
models=models_to_sync,
)
if new_count > 0:
logger.info(
f"Added {new_count} new Ollama models to provider '{request.provider_name}'"
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
except ValueError as e:
logger.warning(f"Failed to sync Ollama models to DB: {e}")
for r in sorted_results
],
source_label="Ollama",
)
return sorted_results
@@ -1210,27 +1228,20 @@ def get_openrouter_available_models(
# Sync new models to DB if provider_name is specified
if request.provider_name:
try:
models_to_sync = [
{
"name": r.name,
"display_name": r.display_name,
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in sorted_results
]
new_count = sync_model_configurations(
db_session=db_session,
provider_name=request.provider_name,
models=models_to_sync,
)
if new_count > 0:
logger.info(
f"Added {new_count} new OpenRouter models to provider '{request.provider_name}'"
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
except ValueError as e:
logger.warning(f"Failed to sync OpenRouter models to DB: {e}")
for r in sorted_results
],
source_label="OpenRouter",
)
return sorted_results
@@ -1324,26 +1335,119 @@ def get_lm_studio_available_models(
# Sync new models to DB if provider_name is specified
if request.provider_name:
try:
models_to_sync = [
{
"name": r.name,
"display_name": r.display_name,
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in sorted_results
]
new_count = sync_model_configurations(
db_session=db_session,
provider_name=request.provider_name,
models=models_to_sync,
)
if new_count > 0:
logger.info(
f"Added {new_count} new LM Studio models to provider '{request.provider_name}'"
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.name,
display_name=r.display_name,
max_input_tokens=r.max_input_tokens,
supports_image_input=r.supports_image_input,
)
except ValueError as e:
logger.warning(f"Failed to sync LM Studio models to DB: {e}")
for r in sorted_results
],
source_label="LM Studio",
)
return sorted_results
@admin_router.post("/litellm/available-models")
def get_litellm_available_models(
request: LitellmModelsRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[LitellmFinalModelResponse]:
"""Fetch available models from the LiteLLM proxy /v1/models endpoint."""
response_json = _get_litellm_models_response(
api_key=request.api_key, api_base=request.api_base
)
models = response_json.get("data", [])
if not isinstance(models, list) or len(models) == 0:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No models found from your Litellm endpoint",
)
results: list[LitellmFinalModelResponse] = []
for model in models:
try:
model_details = LitellmModelDetails.model_validate(model)
results.append(
LitellmFinalModelResponse(
provider_name=model_details.owned_by,
model_name=model_details.id,
)
)
except Exception as e:
logger.warning(
"Failed to parse Litellm model entry",
extra={"error": str(e), "item": str(model)[:1000]},
)
if not results:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No compatible models found from Litellm",
)
sorted_results = sorted(results, key=lambda m: m.model_name.lower())
# Sync new models to DB if provider_name is specified
if request.provider_name:
_sync_fetched_models(
db_session=db_session,
provider_name=request.provider_name,
models=[
SyncModelEntry(
name=r.model_name,
display_name=r.model_name,
)
for r in sorted_results
],
source_label="LiteLLM",
)
return sorted_results
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
"""Perform a GET against the LiteLLM proxy /v1/models endpoint and return the parsed JSON."""
cleaned_api_base = api_base.strip().rstrip("/")
url = f"{cleaned_api_base}/v1/models"
headers = {
"Authorization": f"Bearer {api_key}",
"HTTP-Referer": "https://onyx.app",
"X-Title": "Onyx",
}
try:
response = httpx.get(url, headers=headers, timeout=10.0)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Authentication failed: invalid or missing API key for LiteLLM proxy.",
)
elif e.response.status_code == 404:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"LiteLLM models endpoint not found at {url}. "
"Please verify the API base URL.",
)
else:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch LiteLLM models: {e}",
)
except Exception as e:
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
f"Failed to fetch LiteLLM models: {e}",
)
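
Review note: the error handling in `_get_litellm_models_response` sorts upstream failures into two buckets. The sketch below restates that mapping in isolation, assuming hypothetical helper names (`build_models_url`, `classify_status`) and plain strings in place of `OnyxErrorCode`; it is an illustration of the branching, not the code in this diff.

```python
def build_models_url(api_base: str) -> str:
    """Normalize the base URL as the diff does, then append /v1/models."""
    cleaned = api_base.strip().rstrip("/")
    return f"{cleaned}/v1/models"


def classify_status(status_code: int) -> str:
    """Map an upstream HTTP status to an error category (names are illustrative)."""
    if status_code in (401, 404):
        # Bad key or wrong base URL: something the admin can correct.
        return "VALIDATION_ERROR"
    # Anything else is treated as a failure of the upstream proxy.
    return "BAD_GATEWAY"
```

Stripping whitespace and trailing slashes before joining keeps a pasted base URL such as `http://host:4000/ ` from producing a double slash in the request path.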


@@ -420,3 +420,32 @@ class LLMProviderResponse(BaseModel, Generic[T]):
default_text=default_text,
default_vision=default_vision,
)
class SyncModelEntry(BaseModel):
"""Typed model for syncing fetched models to the DB."""
name: str
display_name: str
max_input_tokens: int | None = None
supports_image_input: bool = False
class LitellmModelsRequest(BaseModel):
api_key: str
api_base: str
provider_name: str | None = None # Optional: to save models to existing provider
class LitellmModelDetails(BaseModel):
"""One entry in the LiteLLM proxy /v1/models response."""
id: str # Model ID (e.g. "gpt-4o")
object: str # "model"
created: int # Unix timestamp in seconds
owned_by: str # Provider name (e.g. "openai")
class LitellmFinalModelResponse(BaseModel):
provider_name: str # Provider name (e.g. "openai")
model_name: str # Model ID (e.g. "gpt-4o")
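
Review note: the tolerant parsing loop in `get_litellm_available_models` (skip entries that fail validation, then sort by model name) can be sketched with the standard library alone. The `ModelSummary` dataclass and `parse_models` function below are hypothetical stand-ins for the Pydantic models in this diff, and the sketch checks only `id` and `owned_by`, whereas `LitellmModelDetails` also requires `object` and `created`.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass(frozen=True)
class ModelSummary:
    provider_name: str
    model_name: str


def parse_models(payload: Dict[str, Any]) -> List[ModelSummary]:
    """Keep well-formed entries from payload["data"], sorted case-insensitively."""
    data = payload.get("data", [])
    if not isinstance(data, list):
        return []
    results: List[ModelSummary] = []
    for item in data:
        try:
            results.append(
                ModelSummary(
                    provider_name=str(item["owned_by"]),
                    model_name=str(item["id"]),
                )
            )
        except (KeyError, TypeError):
            # Malformed entry: skip it instead of failing the whole listing.
            continue
    return sorted(results, key=lambda m: m.model_name.lower())
```

Skipping bad entries one by one means a single unexpected item in the proxy response does not hide the remaining models from the admin.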


@@ -85,6 +85,11 @@ class UserPreferences(BaseModel):
chat_background: str | None = None
default_app_mode: DefaultAppMode = DefaultAppMode.CHAT
# Voice preferences
voice_auto_send: bool | None = None
voice_auto_playback: bool | None = None
voice_playback_speed: float | None = None
# controls which tools are enabled for the user for a specific assistant
assistant_specific_configs: UserSpecificAssistantPreferences | None = None
@@ -164,6 +169,9 @@ class UserInfo(BaseModel):
theme_preference=user.theme_preference,
chat_background=user.chat_background,
default_app_mode=user.default_app_mode,
voice_auto_send=user.voice_auto_send,
voice_auto_playback=user.voice_auto_playback,
voice_playback_speed=user.voice_playback_speed,
assistant_specific_configs=assistant_specific_configs,
)
),
@@ -240,6 +248,12 @@ class ChatBackgroundRequest(BaseModel):
chat_background: str | None
class VoiceSettingsUpdateRequest(BaseModel):
auto_send: bool | None = None
auto_playback: bool | None = None
playback_speed: float | None = Field(default=None, ge=0.5, le=2.0)
class PersonalizationUpdateRequest(BaseModel):
name: str | None = None
role: str | None = None
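
Review note: `VoiceSettingsUpdateRequest.playback_speed` is declared with `ge=0.5, le=2.0` and a default of `None`. The plain-Python sketch below spells out the same contract, assuming a hypothetical `validate_playback_speed` helper and assuming that `None` is meant as "leave the stored value unchanged"; Pydantic performs the equivalent bounds check in the real model.

```python
from typing import Optional

MIN_SPEED = 0.5
MAX_SPEED = 2.0


def validate_playback_speed(value: Optional[float]) -> Optional[float]:
    """Return the value if it is None or within the inclusive bounds."""
    if value is None:
        # No value supplied: nothing to validate.
        return None
    if not (MIN_SPEED <= value <= MAX_SPEED):
        raise ValueError(
            f"playback_speed must be between {MIN_SPEED} and {MAX_SPEED}, got {value}"
        )
    return value
```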


@@ -0,0 +1,315 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Response
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import User
from onyx.db.models import VoiceProvider
from onyx.db.voice import deactivate_stt_provider
from onyx.db.voice import deactivate_tts_provider
from onyx.db.voice import delete_voice_provider
from onyx.db.voice import fetch_voice_provider_by_id
from onyx.db.voice import fetch_voice_provider_by_type
from onyx.db.voice import fetch_voice_providers
from onyx.db.voice import set_default_stt_provider
from onyx.db.voice import set_default_tts_provider
from onyx.db.voice import upsert_voice_provider
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.server.manage.voice.models import VoiceOption
from onyx.server.manage.voice.models import VoiceProviderTestRequest
from onyx.server.manage.voice.models import VoiceProviderUpdateSuccess
from onyx.server.manage.voice.models import VoiceProviderUpsertRequest
from onyx.server.manage.voice.models import VoiceProviderView
from onyx.utils.logger import setup_logger
from onyx.utils.url import SSRFException
from onyx.utils.url import validate_outbound_http_url
from onyx.voice.factory import get_voice_provider
logger = setup_logger()
admin_router = APIRouter(prefix="/admin/voice")
def _validate_voice_api_base(provider_type: str, api_base: str | None) -> str | None:
"""Validate and normalize provider api_base / target URI."""
if api_base is None:
return None
allow_private_network = provider_type.lower() == "azure"
try:
return validate_outbound_http_url(
api_base, allow_private_network=allow_private_network
)
except (ValueError, SSRFException) as e:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Invalid target URI: {str(e)}",
) from e
def _provider_to_view(provider: VoiceProvider) -> VoiceProviderView:
"""Convert a VoiceProvider model to a VoiceProviderView."""
return VoiceProviderView(
id=provider.id,
name=provider.name,
provider_type=provider.provider_type,
is_default_stt=provider.is_default_stt,
is_default_tts=provider.is_default_tts,
stt_model=provider.stt_model,
tts_model=provider.tts_model,
default_voice=provider.default_voice,
has_api_key=bool(provider.api_key),
target_uri=provider.api_base, # api_base stores the target URI for Azure
)
@admin_router.get("/providers")
def list_voice_providers(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[VoiceProviderView]:
"""List all configured voice providers."""
providers = fetch_voice_providers(db_session)
return [_provider_to_view(provider) for provider in providers]
@admin_router.post("/providers")
async def upsert_voice_provider_endpoint(
request: VoiceProviderUpsertRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> VoiceProviderView:
"""Create or update a voice provider."""
api_key = request.api_key
api_key_changed = request.api_key_changed
# If llm_provider_id is specified, copy the API key from that LLM provider
if request.llm_provider_id is not None:
llm_provider = db_session.get(LLMProviderModel, request.llm_provider_id)
if llm_provider is None:
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"LLM provider with id {request.llm_provider_id} not found.",
)
if llm_provider.api_key is None:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Selected LLM provider has no API key configured.",
)
api_key = llm_provider.api_key.get_value(apply_mask=False)
api_key_changed = True
# Use target_uri if provided, otherwise fall back to api_base
api_base = _validate_voice_api_base(
request.provider_type, request.target_uri or request.api_base
)
provider = upsert_voice_provider(
db_session=db_session,
provider_id=request.id,
name=request.name,
provider_type=request.provider_type,
api_key=api_key,
api_key_changed=api_key_changed,
api_base=api_base,
custom_config=request.custom_config,
stt_model=request.stt_model,
tts_model=request.tts_model,
default_voice=request.default_voice,
activate_stt=request.activate_stt,
activate_tts=request.activate_tts,
)
# Validate credentials before committing - rollback on failure
try:
voice_provider = get_voice_provider(provider)
await voice_provider.validate_credentials()
except Exception as e:
db_session.rollback()
logger.error(f"Voice provider credential validation failed on save: {e}")
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Connection test failed. Please verify your API key and settings.",
) from e
db_session.commit()
return _provider_to_view(provider)
@admin_router.delete(
"/providers/{provider_id}", status_code=204, response_class=Response
)
def delete_voice_provider_endpoint(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> Response:
"""Delete a voice provider."""
delete_voice_provider(db_session, provider_id)
db_session.commit()
return Response(status_code=204)
@admin_router.post("/providers/{provider_id}/activate-stt")
def activate_stt_provider_endpoint(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> VoiceProviderView:
"""Set a voice provider as the default STT provider."""
provider = set_default_stt_provider(db_session=db_session, provider_id=provider_id)
db_session.commit()
return _provider_to_view(provider)
@admin_router.post("/providers/{provider_id}/deactivate-stt")
def deactivate_stt_provider_endpoint(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> VoiceProviderUpdateSuccess:
"""Remove the default STT status from a voice provider."""
deactivate_stt_provider(db_session=db_session, provider_id=provider_id)
db_session.commit()
return VoiceProviderUpdateSuccess()
@admin_router.post("/providers/{provider_id}/activate-tts")
def activate_tts_provider_endpoint(
provider_id: int,
tts_model: str | None = None,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> VoiceProviderView:
"""Set a voice provider as the default TTS provider."""
provider = set_default_tts_provider(
db_session=db_session, provider_id=provider_id, tts_model=tts_model
)
db_session.commit()
return _provider_to_view(provider)
@admin_router.post("/providers/{provider_id}/deactivate-tts")
def deactivate_tts_provider_endpoint(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> VoiceProviderUpdateSuccess:
"""Remove the default TTS status from a voice provider."""
deactivate_tts_provider(db_session=db_session, provider_id=provider_id)
db_session.commit()
return VoiceProviderUpdateSuccess()
@admin_router.post("/providers/test")
async def test_voice_provider(
request: VoiceProviderTestRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> VoiceProviderUpdateSuccess:
"""Test a voice provider connection by making a real API call."""
api_key = request.api_key
if request.use_stored_key:
existing_provider = fetch_voice_provider_by_type(
db_session, request.provider_type
)
if existing_provider is None or not existing_provider.api_key:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No stored API key found for this provider type.",
)
api_key = existing_provider.api_key.get_value(apply_mask=False)
if not api_key:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"API key is required. Either provide api_key or set use_stored_key to true.",
)
# Use target_uri if provided, otherwise fall back to api_base
api_base = _validate_voice_api_base(
request.provider_type, request.target_uri or request.api_base
)
# Create a temporary VoiceProvider for testing (not saved to DB)
temp_provider = VoiceProvider(
name="__test__",
provider_type=request.provider_type,
api_base=api_base,
custom_config=request.custom_config or {},
)
temp_provider.api_key = api_key # type: ignore[assignment]
try:
provider = get_voice_provider(temp_provider)
except ValueError as exc:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
# Validate credentials with a real API call
try:
await provider.validate_credentials()
except OnyxError:
raise
except Exception as e:
logger.error(f"Voice provider connection test failed: {e}")
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Connection test failed. Please verify your API key and settings.",
) from e
logger.info(f"Voice provider test succeeded for {request.provider_type}.")
return VoiceProviderUpdateSuccess()
@admin_router.get("/providers/{provider_id}/voices")
def get_provider_voices(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[VoiceOption]:
"""Get available voices for a provider."""
provider_db = fetch_voice_provider_by_id(db_session, provider_id)
if provider_db is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Voice provider not found.")
if not provider_db.api_key:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR, "Provider has no API key configured."
)
try:
provider = get_voice_provider(provider_db)
except ValueError as exc:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
return [VoiceOption(**voice) for voice in provider.get_available_voices()]
@admin_router.get("/voices")
def get_voices_by_type(
provider_type: str,
_: User = Depends(current_admin_user),
) -> list[VoiceOption]:
"""Get available voices for a provider type.
For providers like ElevenLabs and OpenAI, this fetches voices
without requiring an existing provider configuration.
"""
# Create a temporary VoiceProvider to get static voice list
temp_provider = VoiceProvider(
name="__temp__",
provider_type=provider_type,
)
try:
provider = get_voice_provider(temp_provider)
except ValueError as exc:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
return [VoiceOption(**voice) for voice in provider.get_available_voices()]
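
Review note: `upsert_voice_provider_endpoint` stages the row, validates the credentials, and only then commits, rolling back on failure. The sketch below reproduces that ordering with the standard-library `sqlite3` module, assuming a hypothetical `voice_provider` table and a caller-supplied `validate` callable in place of `validate_credentials()`; it is a simplified synchronous illustration, not the SQLAlchemy code in this diff.

```python
import sqlite3
from typing import Callable


def save_provider(
    conn: sqlite3.Connection,
    name: str,
    api_key: str,
    validate: Callable[[str], None],
) -> bool:
    """Stage the row, validate, then commit; roll back if validation fails."""
    conn.execute(
        "INSERT INTO voice_provider (name, api_key) VALUES (?, ?)",
        (name, api_key),
    )
    try:
        validate(api_key)
    except Exception:
        # Validation failed: discard the staged row so nothing is persisted.
        conn.rollback()
        return False
    conn.commit()
    return True
```

Validating inside the open transaction means a provider with unusable credentials never becomes visible to other sessions.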


@@ -0,0 +1,95 @@
from typing import Any
from pydantic import BaseModel
from pydantic import Field
class VoiceProviderView(BaseModel):
"""Response model for voice provider listing."""
id: int
name: str
provider_type: str # "openai", "azure", "elevenlabs"
is_default_stt: bool
is_default_tts: bool
stt_model: str | None
tts_model: str | None
default_voice: str | None
has_api_key: bool = Field(
default=False,
description="Indicates whether an API key is stored for this provider.",
)
target_uri: str | None = Field(
default=None,
description="Target URI for Azure Speech Services.",
)
class VoiceProviderUpdateSuccess(BaseModel):
"""Simple status response for voice provider actions."""
status: str = "ok"
class VoiceOption(BaseModel):
"""Voice option returned by voice providers."""
id: str
name: str
class VoiceProviderUpsertRequest(BaseModel):
"""Request model for creating or updating a voice provider."""
id: int | None = Field(default=None, description="Existing provider ID to update.")
name: str
provider_type: str # "openai", "azure", "elevenlabs"
api_key: str | None = Field(
default=None,
description="API key for the provider.",
)
api_key_changed: bool = Field(
default=False,
description="Set to true when providing a new API key for an existing provider.",
)
llm_provider_id: int | None = Field(
default=None,
description="If set, copies the API key from the specified LLM provider.",
)
api_base: str | None = None
target_uri: str | None = Field(
default=None,
description="Target URI for Azure Speech Services (maps to api_base).",
)
custom_config: dict[str, Any] | None = None
stt_model: str | None = None
tts_model: str | None = None
default_voice: str | None = None
activate_stt: bool = Field(
default=False,
description="If true, sets this provider as the default STT provider after upsert.",
)
activate_tts: bool = Field(
default=False,
description="If true, sets this provider as the default TTS provider after upsert.",
)
class VoiceProviderTestRequest(BaseModel):
"""Request model for testing a voice provider connection."""
provider_type: str
api_key: str | None = Field(
default=None,
description="API key for testing. If not provided, use_stored_key must be true.",
)
use_stored_key: bool = Field(
default=False,
description="If true, use the stored API key for this provider type.",
)
api_base: str | None = None
target_uri: str | None = Field(
default=None,
description="Target URI for Azure Speech Services (maps to api_base).",
)
custom_config: dict[str, Any] | None = None


@@ -0,0 +1,250 @@
import secrets
from collections.abc import AsyncIterator
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import Query
from fastapi import UploadFile
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import User
from onyx.db.voice import fetch_default_stt_provider
from onyx.db.voice import fetch_default_tts_provider
from onyx.db.voice import update_user_voice_settings
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.redis.redis_pool import store_ws_token
from onyx.redis.redis_pool import WsTokenRateLimitExceeded
from onyx.server.manage.models import VoiceSettingsUpdateRequest
from onyx.utils.logger import setup_logger
from onyx.voice.factory import get_voice_provider
logger = setup_logger()
router = APIRouter(prefix="/voice")
# Max audio file size: 25MB (Whisper limit)
MAX_AUDIO_SIZE = 25 * 1024 * 1024
# Chunk size for streaming uploads (8KB)
UPLOAD_READ_CHUNK_SIZE = 8192
class VoiceStatusResponse(BaseModel):
stt_enabled: bool
tts_enabled: bool
@router.get("/status")
def get_voice_status(
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> VoiceStatusResponse:
"""Check whether STT and TTS providers are configured and ready."""
stt_provider = fetch_default_stt_provider(db_session)
tts_provider = fetch_default_tts_provider(db_session)
return VoiceStatusResponse(
stt_enabled=stt_provider is not None and stt_provider.api_key is not None,
tts_enabled=tts_provider is not None and tts_provider.api_key is not None,
)
@router.post("/transcribe")
async def transcribe_audio(
audio: UploadFile = File(...),
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Transcribe audio to text using the default STT provider."""
provider_db = fetch_default_stt_provider(db_session)
if provider_db is None:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No speech-to-text provider configured. Please contact your administrator.",
)
if not provider_db.api_key:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Voice provider API key not configured.",
)
# Read in chunks to enforce size limit during streaming (prevents OOM attacks)
chunks: list[bytes] = []
total = 0
while chunk := await audio.read(UPLOAD_READ_CHUNK_SIZE):
total += len(chunk)
if total > MAX_AUDIO_SIZE:
raise OnyxError(
OnyxErrorCode.PAYLOAD_TOO_LARGE,
f"Audio file too large. Maximum size is {MAX_AUDIO_SIZE // (1024 * 1024)}MB.",
)
chunks.append(chunk)
audio_data = b"".join(chunks)
# Extract format from filename
filename = audio.filename or "audio.webm"
audio_format = filename.rsplit(".", 1)[-1] if "." in filename else "webm"
try:
provider = get_voice_provider(provider_db)
except ValueError as exc:
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
try:
text = await provider.transcribe(audio_data, audio_format)
return {"text": text}
except NotImplementedError as exc:
raise OnyxError(
OnyxErrorCode.NOT_IMPLEMENTED,
f"Speech-to-text not implemented for {provider_db.provider_type}.",
) from exc
except Exception as exc:
logger.error(f"Transcription failed: {exc}")
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
"Transcription failed. Please try again.",
) from exc
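
Review note: the size check in `transcribe_audio` runs while the upload is being read, so an oversized body is rejected before it is fully buffered. A synchronous sketch of the same loop follows, assuming a hypothetical `read_capped` helper and `PayloadTooLarge` exception and using `io.BytesIO` in place of `UploadFile`.

```python
import io
from typing import BinaryIO, List


class PayloadTooLarge(Exception):
    """Raised as soon as the running total passes the cap."""


def read_capped(stream: BinaryIO, max_bytes: int, chunk_size: int = 8192) -> bytes:
    """Read the stream chunk by chunk, aborting once more than max_bytes arrive."""
    chunks: List[bytes] = []
    total = 0
    while chunk := stream.read(chunk_size):
        total += len(chunk)
        if total > max_bytes:
            # Stop before appending: at most one extra chunk is ever held.
            raise PayloadTooLarge(f"payload exceeds {max_bytes} bytes")
        chunks.append(chunk)
    return b"".join(chunks)


# A payload exactly at the cap is accepted.
data = read_capped(io.BytesIO(b"x" * 100), max_bytes=100, chunk_size=16)
```

Memory use is bounded by the cap plus one chunk, regardless of how large the client's upload claims to be.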
@router.post("/synthesize")
async def synthesize_speech(
text: str | None = Query(
default=None, description="Text to synthesize", max_length=4096
),
voice: str | None = Query(default=None, description="Voice ID to use"),
speed: float | None = Query(
default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
),
user: User = Depends(current_user),
) -> StreamingResponse:
"""
Synthesize text to speech using the default TTS provider.
Accepts parameters via query string for streaming compatibility.
"""
logger.info(
f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
)
if not text:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Text is required")
# Use short-lived session to fetch provider config, then release connection
# before starting the long-running streaming response
with get_session_with_current_tenant() as db_session:
provider_db = fetch_default_tts_provider(db_session)
if provider_db is None:
logger.error("No TTS provider configured")
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No text-to-speech provider configured. Please contact your administrator.",
)
if not provider_db.api_key:
logger.error("TTS provider has no API key")
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Voice provider API key not configured.",
)
# Use request voice or provider default
final_voice = voice or provider_db.default_voice
# Use explicit None checks to avoid falsy float issues (0.0 would be skipped with `or`)
final_speed = (
speed
if speed is not None
else (
user.voice_playback_speed
if user.voice_playback_speed is not None
else 1.0
)
)
logger.info(
f"TTS using provider: {provider_db.provider_type}, voice: {final_voice}, speed: {final_speed}"
)
try:
provider = get_voice_provider(provider_db)
except ValueError as exc:
logger.error(f"Failed to get voice provider: {exc}")
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
# Session is now closed - streaming response won't hold DB connection
async def audio_stream() -> AsyncIterator[bytes]:
try:
chunk_count = 0
async for chunk in provider.synthesize_stream(
text=text, voice=final_voice, speed=final_speed
):
chunk_count += 1
yield chunk
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
except NotImplementedError as exc:
logger.error(f"TTS not implemented: {exc}")
raise
except Exception as exc:
logger.error(f"Synthesis failed: {exc}")
raise
return StreamingResponse(
audio_stream(),
media_type="audio/mpeg",
headers={
"Content-Disposition": "inline; filename=speech.mp3",
# Allow streaming by not setting content-length
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no", # Disable nginx buffering
},
)
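
Review note: the nested conditional that computes `final_speed` in `synthesize_speech` is a three-level fallback (request value, then user preference, then 1.0) using explicit `None` checks. The sketch below, assuming a hypothetical `resolve_speed` helper, shows the same precedence as a flat function.

```python
from typing import Optional

DEFAULT_SPEED = 1.0


def resolve_speed(requested: Optional[float], user_pref: Optional[float]) -> float:
    """Prefer the request value, then the stored preference, then the default."""
    if requested is not None:
        return requested
    if user_pref is not None:
        return user_pref
    return DEFAULT_SPEED
```

The `is not None` checks matter because `requested or user_pref` would discard a legitimate falsy value such as `0.0`; in this endpoint the query parameter is already bounded below by 0.5, so the explicit checks are mainly defensive.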
@router.patch("/settings")
def update_voice_settings(
request: VoiceSettingsUpdateRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Update user's voice settings."""
update_user_voice_settings(
db_session=db_session,
user_id=user.id,
auto_send=request.auto_send,
auto_playback=request.auto_playback,
playback_speed=request.playback_speed,
)
return {"status": "ok"}
class WSTokenResponse(BaseModel):
token: str
@router.post("/ws-token")
async def get_ws_token(
user: User = Depends(current_user),
) -> WSTokenResponse:
"""
Generate a short-lived token for WebSocket authentication.
This token should be passed as a query parameter when connecting
to voice WebSocket endpoints (e.g., /voice/transcribe/stream?token=xxx).
The token expires after 60 seconds and is single-use.
Rate limited to 10 tokens per minute per user.
"""
token = secrets.token_urlsafe(32)
try:
await store_ws_token(token, str(user.id))
except WsTokenRateLimitExceeded:
raise OnyxError(
OnyxErrorCode.RATE_LIMITED,
"Too many token requests. Please wait before requesting another.",
)
return WSTokenResponse(token=token)


@@ -0,0 +1,860 @@
"""WebSocket API for streaming speech-to-text and text-to-speech."""
import asyncio
import io
import json
import os
from collections.abc import MutableMapping
from typing import Any
from fastapi import APIRouter
from fastapi import Depends
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from sqlalchemy.orm import Session
from onyx.auth.users import current_user_from_websocket
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.db.models import User
from onyx.db.voice import fetch_default_stt_provider
from onyx.db.voice import fetch_default_tts_provider
from onyx.utils.logger import setup_logger
from onyx.voice.factory import get_voice_provider
from onyx.voice.interface import StreamingSynthesizerProtocol
from onyx.voice.interface import StreamingTranscriberProtocol
from onyx.voice.interface import TranscriptResult
logger = setup_logger()
router = APIRouter(prefix="/voice")
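# Minimum buffered bytes before the chunked fallback transcribes a chunk.
# Sized for webm/opus (~2-4KB/s, so ~1-2KB per 0.5s of audio); raw PCM16 input
# reaches this threshold much sooner, so chunks are correspondingly shorter.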
MIN_CHUNK_BYTES = 1500
VOICE_DISABLE_STREAMING_FALLBACK = (
os.environ.get("VOICE_DISABLE_STREAMING_FALLBACK", "").lower() == "true"
)
# WebSocket size limits to prevent memory exhaustion attacks
WS_MAX_MESSAGE_SIZE = 64 * 1024 # 64KB per message (OWASP recommendation)
WS_MAX_TOTAL_BYTES = 25 * 1024 * 1024 # 25MB total per connection (matches REST API)
WS_MAX_TEXT_MESSAGE_SIZE = 16 * 1024 # 16KB for text/JSON messages
WS_MAX_TTS_TEXT_LENGTH = 4096 # Max text length per synthesize call (matches REST API)
class ChunkedTranscriber:
"""Fallback transcriber for providers without streaming support."""
def __init__(self, provider: Any, audio_format: str = "webm"):
self.provider = provider
self.audio_format = audio_format
self.chunk_buffer = io.BytesIO()
self.full_audio = io.BytesIO()
self.chunk_bytes = 0
self.transcripts: list[str] = []
async def add_chunk(self, chunk: bytes) -> str | None:
"""Add audio chunk. Returns transcript if enough audio accumulated."""
self.chunk_buffer.write(chunk)
self.full_audio.write(chunk)
self.chunk_bytes += len(chunk)
if self.chunk_bytes >= MIN_CHUNK_BYTES:
return await self._transcribe_chunk()
return None
async def _transcribe_chunk(self) -> str | None:
"""Transcribe current chunk and append to running transcript."""
audio_data = self.chunk_buffer.getvalue()
if not audio_data:
return None
try:
transcript = await self.provider.transcribe(audio_data, self.audio_format)
self.chunk_buffer = io.BytesIO()
self.chunk_bytes = 0
if transcript and transcript.strip():
self.transcripts.append(transcript.strip())
return " ".join(self.transcripts)
return None
except Exception as e:
logger.error(f"Transcription error: {e}")
self.chunk_buffer = io.BytesIO()
self.chunk_bytes = 0
return None
async def flush(self) -> str:
"""Get final transcript from full audio for best accuracy."""
full_audio_data = self.full_audio.getvalue()
if full_audio_data:
try:
transcript = await self.provider.transcribe(
full_audio_data, self.audio_format
)
if transcript and transcript.strip():
return transcript.strip()
except Exception as e:
logger.error(f"Final transcription error: {e}")
return " ".join(self.transcripts)
async def handle_streaming_transcription(
websocket: WebSocket,
transcriber: StreamingTranscriberProtocol,
) -> None:
"""Handle transcription using native streaming API."""
logger.info("Streaming transcription: starting handler")
last_transcript = ""
chunk_count = 0
total_bytes = 0
async def receive_transcripts() -> None:
"""Background task to receive and send transcripts."""
nonlocal last_transcript
logger.info("Streaming transcription: starting transcript receiver")
while True:
result: TranscriptResult | None = await transcriber.receive_transcript()
if result is None: # End of stream
logger.info("Streaming transcription: transcript stream ended")
break
# Send if text changed OR if VAD detected end of speech (for auto-send trigger)
if result.text and (result.text != last_transcript or result.is_vad_end):
last_transcript = result.text
logger.debug(
f"Streaming transcription: got transcript: {result.text[:50]}... "
f"(is_vad_end={result.is_vad_end})"
)
await websocket.send_json(
{
"type": "transcript",
"text": result.text,
"is_final": result.is_vad_end,
}
)
# Start receiving transcripts in background
receive_task = asyncio.create_task(receive_transcripts())
try:
while True:
message = await websocket.receive()
msg_type = message.get("type", "unknown")
if msg_type == "websocket.disconnect":
logger.info(
f"Streaming transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
)
break
if "bytes" in message:
chunk_size = len(message["bytes"])
# Enforce per-message size limit
if chunk_size > WS_MAX_MESSAGE_SIZE:
logger.warning(
f"Streaming transcription: message too large ({chunk_size} bytes)"
)
await websocket.send_json(
{"type": "error", "message": "Message too large"}
)
break
# Enforce total connection size limit
if total_bytes + chunk_size > WS_MAX_TOTAL_BYTES:
logger.warning(
f"Streaming transcription: total size limit exceeded ({total_bytes + chunk_size} bytes)"
)
await websocket.send_json(
{"type": "error", "message": "Total size limit exceeded"}
)
break
chunk_count += 1
total_bytes += chunk_size
logger.debug(
f"Streaming transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
)
await transcriber.send_audio(message["bytes"])
elif "text" in message:
try:
data = json.loads(message["text"])
logger.debug(
f"Streaming transcription: received text message: {data}"
)
if data.get("type") == "end":
logger.info(
"Streaming transcription: end signal received, closing transcriber"
)
final_transcript = await transcriber.close()
receive_task.cancel()
logger.info(
"Streaming transcription: final transcript: "
f"{final_transcript[:100] if final_transcript else '(empty)'}..."
)
await websocket.send_json(
{
"type": "transcript",
"text": final_transcript,
"is_final": True,
}
)
break
elif data.get("type") == "reset":
# Reset accumulated transcript after auto-send
logger.info(
"Streaming transcription: reset signal received, clearing transcript"
)
transcriber.reset_transcript()
except json.JSONDecodeError:
logger.warning(
f"Streaming transcription: failed to parse JSON: {message.get('text', '')[:100]}"
)
except Exception as e:
logger.error(f"Streaming transcription: error: {e}", exc_info=True)
raise
finally:
receive_task.cancel()
try:
await receive_task
except asyncio.CancelledError:
pass
logger.info(
f"Streaming transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
)
async def handle_chunked_transcription(
websocket: WebSocket,
transcriber: ChunkedTranscriber,
) -> None:
"""Handle transcription using chunked batch API."""
logger.info("Chunked transcription: starting handler")
chunk_count = 0
total_bytes = 0
while True:
message = await websocket.receive()
msg_type = message.get("type", "unknown")
if msg_type == "websocket.disconnect":
logger.info(
f"Chunked transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
)
break
if "bytes" in message:
chunk_size = len(message["bytes"])
# Enforce per-message size limit
if chunk_size > WS_MAX_MESSAGE_SIZE:
logger.warning(
f"Chunked transcription: message too large ({chunk_size} bytes)"
)
await websocket.send_json(
{"type": "error", "message": "Message too large"}
)
break
# Enforce total connection size limit
if total_bytes + chunk_size > WS_MAX_TOTAL_BYTES:
logger.warning(
f"Chunked transcription: total size limit exceeded ({total_bytes + chunk_size} bytes)"
)
await websocket.send_json(
{"type": "error", "message": "Total size limit exceeded"}
)
break
chunk_count += 1
total_bytes += chunk_size
logger.debug(
f"Chunked transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
)
transcript = await transcriber.add_chunk(message["bytes"])
if transcript:
logger.debug(
f"Chunked transcription: got transcript: {transcript[:50]}..."
)
await websocket.send_json(
{
"type": "transcript",
"text": transcript,
"is_final": False,
}
)
elif "text" in message:
try:
data = json.loads(message["text"])
logger.debug(f"Chunked transcription: received text message: {data}")
if data.get("type") == "end":
logger.info("Chunked transcription: end signal received, flushing")
final_transcript = await transcriber.flush()
logger.info(
f"Chunked transcription: final transcript: {final_transcript[:100] if final_transcript else '(empty)'}..."
)
await websocket.send_json(
{
"type": "transcript",
"text": final_transcript,
"is_final": True,
}
)
break
except json.JSONDecodeError:
logger.warning(
f"Chunked transcription: failed to parse JSON: {message.get('text', '')[:100]}"
)
logger.info(
f"Chunked transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
)
@router.websocket("/transcribe/stream")
async def websocket_transcribe(
websocket: WebSocket,
_user: User = Depends(current_user_from_websocket),
) -> None:
"""
WebSocket endpoint for streaming speech-to-text.
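Protocol:
- Client sends binary audio chunks
- Server sends JSON: {"type": "transcript", "text": "...", "is_final": false}
- Client may send JSON {"type": "reset"} to clear the accumulated transcript
(honored only when the provider streams natively)
- Client sends JSON {"type": "end"} to signal end
- Server responds with final transcript and closes
- On failure, server sends JSON: {"type": "error", "message": "..."}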
Authentication:
Requires `token` query parameter (e.g., /voice/transcribe/stream?token=xxx).
Applies same auth checks as HTTP endpoints (verification, role checks).
"""
logger.info("WebSocket transcribe: connection request received (authenticated)")
try:
await websocket.accept()
logger.info("WebSocket transcribe: connection accepted")
except Exception as e:
logger.error(f"WebSocket transcribe: failed to accept connection: {e}")
return
streaming_transcriber = None
provider = None
try:
# Get STT provider
logger.info("WebSocket transcribe: fetching STT provider from database")
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
provider_db = fetch_default_stt_provider(db_session)
if provider_db is None:
logger.warning(
"WebSocket transcribe: no default STT provider configured"
)
await websocket.send_json(
{
"type": "error",
"message": "No speech-to-text provider configured",
}
)
return
if not provider_db.api_key:
logger.warning("WebSocket transcribe: STT provider has no API key")
await websocket.send_json(
{
"type": "error",
"message": "Speech-to-text provider has no API key configured",
}
)
return
logger.info(
f"WebSocket transcribe: creating voice provider: {provider_db.provider_type}"
)
try:
provider = get_voice_provider(provider_db)
logger.info(
f"WebSocket transcribe: voice provider created, streaming supported: {provider.supports_streaming_stt()}"
)
except ValueError as e:
logger.error(
f"WebSocket transcribe: failed to create voice provider: {e}"
)
await websocket.send_json({"type": "error", "message": str(e)})
return
# Use native streaming if provider supports it
if provider.supports_streaming_stt():
logger.info("WebSocket transcribe: using native streaming STT")
try:
streaming_transcriber = await provider.create_streaming_transcriber()
logger.info(
"WebSocket transcribe: streaming transcriber created successfully"
)
await handle_streaming_transcription(websocket, streaming_transcriber)
except Exception as e:
logger.error(
f"WebSocket transcribe: streaming STT failed (setup or mid-stream): {e}"
)
if VOICE_DISABLE_STREAMING_FALLBACK:
await websocket.send_json(
{"type": "error", "message": f"Streaming STT failed: {e}"}
)
return
logger.info("WebSocket transcribe: falling back to chunked STT")
# Browser stream provides raw PCM16 chunks over WebSocket.
chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
await handle_chunked_transcription(websocket, chunked_transcriber)
else:
# Fall back to chunked transcription
if VOICE_DISABLE_STREAMING_FALLBACK:
await websocket.send_json(
{
"type": "error",
"message": "Provider doesn't support streaming STT",
}
)
return
logger.info(
"WebSocket transcribe: using chunked STT (provider doesn't support streaming)"
)
chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
await handle_chunked_transcription(websocket, chunked_transcriber)
except WebSocketDisconnect:
logger.debug("WebSocket transcribe: client disconnected")
except Exception as e:
logger.error(f"WebSocket transcribe: unhandled error: {e}", exc_info=True)
try:
# Send generic error to avoid leaking sensitive details
await websocket.send_json(
{"type": "error", "message": "An unexpected error occurred"}
)
except Exception:
pass
finally:
if streaming_transcriber:
try:
await streaming_transcriber.close()
except Exception:
pass
try:
await websocket.close()
except Exception:
pass
logger.info("WebSocket transcribe: connection closed")
async def handle_streaming_synthesis(
websocket: WebSocket,
synthesizer: StreamingSynthesizerProtocol,
) -> None:
"""Handle TTS using native streaming API."""
logger.info("Streaming synthesis: starting handler")
async def send_audio() -> None:
"""Background task to send audio chunks to client."""
chunk_count = 0
total_bytes = 0
try:
while True:
audio_chunk = await synthesizer.receive_audio()
if audio_chunk is None:
logger.info(
f"Streaming synthesis: audio stream ended, sent {chunk_count} chunks, {total_bytes} bytes"
)
try:
await websocket.send_json({"type": "audio_done"})
logger.info("Streaming synthesis: sent audio_done to client")
except Exception as e:
logger.warning(
f"Streaming synthesis: failed to send audio_done: {e}"
)
break
if audio_chunk: # Skip empty chunks
chunk_count += 1
total_bytes += len(audio_chunk)
try:
await websocket.send_bytes(audio_chunk)
except Exception as e:
logger.warning(
f"Streaming synthesis: failed to send chunk: {e}"
)
break
except asyncio.CancelledError:
logger.info(
f"Streaming synthesis: send_audio cancelled after {chunk_count} chunks"
)
except Exception as e:
logger.error(f"Streaming synthesis: send_audio error: {e}")
send_task: asyncio.Task | None = None
disconnected = False
try:
while not disconnected:
try:
message = await websocket.receive()
except WebSocketDisconnect:
logger.info("Streaming synthesis: client disconnected")
break
msg_type = message.get("type", "unknown") # type: ignore[possibly-undefined]
if msg_type == "websocket.disconnect":
logger.info("Streaming synthesis: client disconnected")
disconnected = True
break
if "text" in message:
# Enforce text message size limit
msg_size = len(message["text"])
if msg_size > WS_MAX_TEXT_MESSAGE_SIZE:
logger.warning(
f"Streaming synthesis: text message too large ({msg_size} bytes)"
)
await websocket.send_json(
{"type": "error", "message": "Message too large"}
)
break
try:
data = json.loads(message["text"])
if data.get("type") == "synthesize":
text = data.get("text", "")
# Enforce per-text size limit
if len(text) > WS_MAX_TTS_TEXT_LENGTH:
logger.warning(
f"Streaming synthesis: text too long ({len(text)} chars)"
)
await websocket.send_json(
{"type": "error", "message": "Text too long"}
)
continue
if text:
# Start audio receiver on first text chunk so playback
# can begin before the full assistant response completes.
if send_task is None:
send_task = asyncio.create_task(send_audio())
logger.debug(
f"Streaming synthesis: forwarding text chunk ({len(text)} chars)"
)
await synthesizer.send_text(text)
elif data.get("type") == "end":
logger.info("Streaming synthesis: end signal received")
# Ensure receiver is active even if no prior text chunks arrived.
if send_task is None:
send_task = asyncio.create_task(send_audio())
# Signal end of input
if hasattr(synthesizer, "flush"):
await synthesizer.flush()
# Wait for all audio to be sent
logger.info(
"Streaming synthesis: waiting for audio stream to complete"
)
try:
await asyncio.wait_for(send_task, timeout=60.0)
except asyncio.TimeoutError:
logger.warning(
"Streaming synthesis: timeout waiting for audio"
)
break
except json.JSONDecodeError:
logger.warning(
f"Streaming synthesis: failed to parse JSON: {message.get('text', '')[:100]}"
)
except WebSocketDisconnect:
logger.debug("Streaming synthesis: client disconnected during synthesis")
except Exception as e:
logger.error(f"Streaming synthesis: error: {e}", exc_info=True)
finally:
if send_task and not send_task.done():
logger.info("Streaming synthesis: waiting for send_task to finish")
try:
await asyncio.wait_for(send_task, timeout=30.0)
except asyncio.TimeoutError:
logger.warning("Streaming synthesis: timeout waiting for send_task")
send_task.cancel()
try:
await send_task
except asyncio.CancelledError:
pass
except asyncio.CancelledError:
pass
logger.info("Streaming synthesis: handler finished")
async def handle_chunked_synthesis(
websocket: WebSocket,
provider: Any,
first_message: MutableMapping[str, Any] | None = None,
) -> None:
"""Fallback TTS handler using provider.synthesize_stream.
Args:
websocket: The WebSocket connection
provider: Voice provider instance
first_message: Optional first message already received (used when falling
back from streaming mode, where the first message was already consumed)
"""
logger.info("Chunked synthesis: starting handler")
text_buffer: list[str] = []
voice: str | None = None
speed = 1.0
# Process pre-received message if provided
pending_message = first_message
try:
while True:
if pending_message is not None:
message = pending_message
pending_message = None
else:
message = await websocket.receive()
msg_type = message.get("type", "unknown")
if msg_type == "websocket.disconnect":
logger.info("Chunked synthesis: client disconnected")
break
if "text" not in message:
continue
# Enforce text message size limit
msg_size = len(message["text"])
if msg_size > WS_MAX_TEXT_MESSAGE_SIZE:
logger.warning(
f"Chunked synthesis: text message too large ({msg_size} bytes)"
)
await websocket.send_json(
{"type": "error", "message": "Message too large"}
)
break
try:
data = json.loads(message["text"])
except json.JSONDecodeError:
logger.warning(
"Chunked synthesis: failed to parse JSON: "
f"{message.get('text', '')[:100]}"
)
continue
msg_data_type = data.get("type") # type: ignore[possibly-undefined]
if msg_data_type == "synthesize":
text = data.get("text", "")
# Enforce per-text size limit
if len(text) > WS_MAX_TTS_TEXT_LENGTH:
logger.warning(
f"Chunked synthesis: text too long ({len(text)} chars)"
)
await websocket.send_json(
{"type": "error", "message": "Text too long"}
)
continue
if text:
text_buffer.append(text)
logger.debug(
f"Chunked synthesis: buffered text ({len(text)} chars), "
f"total buffered: {len(text_buffer)} chunks"
)
if isinstance(data.get("voice"), str) and data["voice"]:
voice = data["voice"]
if isinstance(data.get("speed"), (int, float)):
speed = float(data["speed"])
elif msg_data_type == "end":
logger.info("Chunked synthesis: end signal received")
full_text = " ".join(text_buffer).strip()
if not full_text:
await websocket.send_json({"type": "audio_done"})
logger.info("Chunked synthesis: no text, sent audio_done")
break
chunk_count = 0
total_bytes = 0
logger.info(
f"Chunked synthesis: sending full text ({len(full_text)} chars)"
)
async for audio_chunk in provider.synthesize_stream(
full_text, voice=voice, speed=speed
):
if not audio_chunk:
continue
chunk_count += 1
total_bytes += len(audio_chunk)
await websocket.send_bytes(audio_chunk)
await websocket.send_json({"type": "audio_done"})
logger.info(
f"Chunked synthesis: sent audio_done after {chunk_count} chunks, {total_bytes} bytes"
)
break
except WebSocketDisconnect:
logger.debug("Chunked synthesis: client disconnected")
except Exception as e:
logger.error(f"Chunked synthesis: error: {e}", exc_info=True)
raise
finally:
logger.info("Chunked synthesis: handler finished")
@router.websocket("/synthesize/stream")
async def websocket_synthesize(
websocket: WebSocket,
_user: User = Depends(current_user_from_websocket),
) -> None:
"""
WebSocket endpoint for streaming text-to-speech.
Protocol:
- Client sends JSON: {"type": "synthesize", "text": "...", "voice": "...", "speed": 1.0}
- Server sends binary audio chunks
- Client sends JSON {"type": "end"} to signal end of text input
- Server sends any remaining audio, then JSON {"type": "audio_done"}, and closes
- On failure, server sends JSON: {"type": "error", "message": "..."}
Authentication:
Requires `token` query parameter (e.g., /voice/synthesize/stream?token=xxx).
Applies same auth checks as HTTP endpoints (verification, role checks).
"""
logger.info("WebSocket synthesize: connection request received (authenticated)")
try:
await websocket.accept()
logger.info("WebSocket synthesize: connection accepted")
except Exception as e:
logger.error(f"WebSocket synthesize: failed to accept connection: {e}")
return
streaming_synthesizer: StreamingSynthesizerProtocol | None = None
provider = None
try:
# Get TTS provider
logger.info("WebSocket synthesize: fetching TTS provider from database")
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
provider_db = fetch_default_tts_provider(db_session)
if provider_db is None:
logger.warning(
"WebSocket synthesize: no default TTS provider configured"
)
await websocket.send_json(
{
"type": "error",
"message": "No text-to-speech provider configured",
}
)
return
if not provider_db.api_key:
logger.warning("WebSocket synthesize: TTS provider has no API key")
await websocket.send_json(
{
"type": "error",
"message": "Text-to-speech provider has no API key configured",
}
)
return
logger.info(
f"WebSocket synthesize: creating voice provider: {provider_db.provider_type}"
)
try:
provider = get_voice_provider(provider_db)
logger.info(
f"WebSocket synthesize: voice provider created, streaming TTS supported: {provider.supports_streaming_tts()}"
)
except ValueError as e:
logger.error(
f"WebSocket synthesize: failed to create voice provider: {e}"
)
await websocket.send_json({"type": "error", "message": str(e)})
return
# Use native streaming if provider supports it
if provider.supports_streaming_tts():
logger.info("WebSocket synthesize: using native streaming TTS")
message = None # Initialize to avoid UnboundLocalError in except block
try:
# Wait for initial config message with voice/speed
message = await websocket.receive()
voice = None
speed = 1.0
if "text" in message:
try:
data = json.loads(message["text"])
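# Validate types the same way the chunked handler does; ignore malformed values
if isinstance(data.get("voice"), str) and data["voice"]:
voice = data["voice"]
if isinstance(data.get("speed"), (int, float)):
speed = float(data["speed"])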
except json.JSONDecodeError:
pass
streaming_synthesizer = await provider.create_streaming_synthesizer(
voice=voice, speed=speed
)
logger.info(
"WebSocket synthesize: streaming synthesizer created successfully"
)
await handle_streaming_synthesis(websocket, streaming_synthesizer)
except Exception as e:
logger.error(
f"WebSocket synthesize: streaming TTS failed (setup or mid-stream): {e}"
)
if VOICE_DISABLE_STREAMING_FALLBACK:
await websocket.send_json(
{"type": "error", "message": f"Streaming TTS failed: {e}"}
)
return
logger.info(
"WebSocket synthesize: falling back to chunked TTS synthesis"
)
# Pass the first message so it's not lost in the fallback
await handle_chunked_synthesis(
websocket, provider, first_message=message
)
else:
if VOICE_DISABLE_STREAMING_FALLBACK:
await websocket.send_json(
{
"type": "error",
"message": "Provider doesn't support streaming TTS",
}
)
return
logger.info(
"WebSocket synthesize: using chunked TTS (provider doesn't support streaming)"
)
await handle_chunked_synthesis(websocket, provider)
except WebSocketDisconnect:
logger.debug("WebSocket synthesize: client disconnected")
except Exception as e:
logger.error(f"WebSocket synthesize: unhandled error: {e}", exc_info=True)
try:
# Send generic error to avoid leaking sensitive details
await websocket.send_json(
{"type": "error", "message": "An unexpected error occurred"}
)
except Exception:
pass
finally:
if streaming_synthesizer:
try:
await streaming_synthesizer.close()
except Exception:
pass
try:
await websocket.close()
except Exception:
pass
logger.info("WebSocket synthesize: connection closed")


@@ -0,0 +1,13 @@
from typing import Any

from pydantic import BaseModel


def shallow_model_dump(model_instance: BaseModel) -> dict[str, Any]:
    """Like model_dump(), but returns references to field values instead of
    deep copies. Use with model_construct() to avoid unnecessary memory
    duplication when building subclass instances."""
    return {
        field_name: getattr(model_instance, field_name)
        for field_name in model_instance.__class__.model_fields
    }
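
The docstring suggests pairing `shallow_model_dump` with `model_construct()` when building subclass instances. A small sketch of that pairing follows, assuming pydantic v2; `Document` and `ScoredDocument` are hypothetical models invented for the example, and the helper body is repeated only so the snippet runs on its own.

```python
from typing import Any

from pydantic import BaseModel


def shallow_model_dump(model_instance: BaseModel) -> dict[str, Any]:
    # Same body as the helper in the diff, repeated for a standalone example.
    return {
        field_name: getattr(model_instance, field_name)
        for field_name in model_instance.__class__.model_fields
    }


class Document(BaseModel):  # hypothetical base model
    title: str
    chunks: list[str]


class ScoredDocument(Document):  # hypothetical subclass adding one field
    score: float


doc = Document(title="notes", chunks=["a", "b"])

# model_construct() skips validation and stores the values it is given,
# so the subclass instance ends up holding the very same list object.
scored = ScoredDocument.model_construct(**shallow_model_dump(doc), score=0.9)
shares_list = scored.chunks is doc.chunks
```

Because the field values are shared rather than copied, mutating `scored.chunks` also mutates `doc.chunks`; that trade-off is worth keeping in mind when the base instance stays in use.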


@@ -140,6 +140,44 @@ def _validate_and_resolve_url(url: str) -> tuple[str, str, int]:
    return validated_ip, hostname, port


def validate_outbound_http_url(url: str, *, allow_private_network: bool = False) -> str:
    """
    Validate a URL that will be used by backend outbound HTTP calls.

    Returns:
        A normalized URL string with surrounding whitespace removed.

    Raises:
        ValueError: If URL is malformed.
        SSRFException: If URL fails SSRF checks.
    """
    normalized_url = url.strip()
    if not normalized_url:
        raise ValueError("URL cannot be empty")

    parsed = urlparse(normalized_url)
    if parsed.scheme not in ("http", "https"):
        raise SSRFException(
            f"Invalid URL scheme '{parsed.scheme}'. Only http and https are allowed."
        )

    if not parsed.hostname:
        raise ValueError("URL must contain a hostname")

    if parsed.username or parsed.password:
        raise SSRFException("URLs with embedded credentials are not allowed.")

    hostname = parsed.hostname.lower()
    if hostname in BLOCKED_HOSTNAMES:
        raise SSRFException(f"Access to hostname '{parsed.hostname}' is not allowed.")

    if not allow_private_network:
        _validate_and_resolve_url(normalized_url)

    return normalized_url


MAX_REDIRECTS = 10
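
To see which inputs the syntactic checks in `validate_outbound_http_url` reject before any DNS resolution happens, the sketch below reproduces just those checks with the standard library. `SSRFException` and `BLOCKED_HOSTNAMES` here are stand-ins (the real definitions sit outside this hunk, and the hostnames listed are assumptions), and `check_url_shape` and `classify` are hypothetical helper names. The `_validate_and_resolve_url` step is deliberately left out.

```python
from urllib.parse import urlparse


class SSRFException(Exception):
    """Stand-in for the project's exception of the same name."""


# Assumed contents for illustration; the real set is defined elsewhere.
BLOCKED_HOSTNAMES = {"localhost", "metadata.google.internal"}


def check_url_shape(url: str) -> str:
    """Apply only the syntactic checks; no DNS lookup is performed."""
    normalized_url = url.strip()
    if not normalized_url:
        raise ValueError("URL cannot be empty")
    parsed = urlparse(normalized_url)
    if parsed.scheme not in ("http", "https"):
        raise SSRFException(f"Invalid URL scheme '{parsed.scheme}'")
    if not parsed.hostname:
        raise ValueError("URL must contain a hostname")
    if parsed.username or parsed.password:
        raise SSRFException("URLs with embedded credentials are not allowed.")
    if parsed.hostname.lower() in BLOCKED_HOSTNAMES:
        raise SSRFException(f"Access to hostname '{parsed.hostname}' is not allowed.")
    return normalized_url


def classify(url: str) -> str:
    """Return "ok" or the name of the exception type the checks raise."""
    try:
        check_url_shape(url)
        return "ok"
    except (ValueError, SSRFException) as exc:
        return type(exc).__name__


results = {
    url: classify(url)
    for url in [
        "  https://api.example.com/v1  ",  # accepted after stripping whitespace
        "ftp://example.com/file",  # wrong scheme
        "https://user:secret@example.com/",  # embedded credentials
        "https:///missing-host",  # no hostname
        "http://LOCALHOST:8080/",  # blocked hostname, case-insensitive
        "   ",  # empty after stripping
    ]
}
```

Malformed input surfaces as `ValueError` while policy violations surface as `SSRFException`, which lets callers distinguish a user typo from a blocked destination.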


@@ -0,0 +1,70 @@
from onyx.db.models import VoiceProvider
from onyx.voice.interface import VoiceProviderInterface


def get_voice_provider(provider: VoiceProvider) -> VoiceProviderInterface:
    """
    Factory function to get the appropriate voice provider implementation.

    Args:
        provider: VoiceProvider model instance (can be from DB or constructed temporarily)

    Returns:
        VoiceProviderInterface implementation

    Raises:
        ValueError: If provider_type is not supported
    """
    provider_type = provider.provider_type.lower()

    # Handle both SensitiveValue (from DB) and plain string (from temp model)
    if provider.api_key is None:
        api_key = None
    elif hasattr(provider.api_key, "get_value"):
        # SensitiveValue from database
        api_key = provider.api_key.get_value(apply_mask=False)
    else:
        # Plain string from temporary model
        api_key = provider.api_key  # type: ignore[assignment]

    api_base = provider.api_base
    custom_config = provider.custom_config
    stt_model = provider.stt_model
    tts_model = provider.tts_model
    default_voice = provider.default_voice

    if provider_type == "openai":
        from onyx.voice.providers.openai import OpenAIVoiceProvider

        return OpenAIVoiceProvider(
            api_key=api_key,
            api_base=api_base,
            stt_model=stt_model,
            tts_model=tts_model,
            default_voice=default_voice,
        )
    elif provider_type == "azure":
        from onyx.voice.providers.azure import AzureVoiceProvider

        return AzureVoiceProvider(
            api_key=api_key,
            api_base=api_base,
            custom_config=custom_config or {},
            stt_model=stt_model,
            tts_model=tts_model,
            default_voice=default_voice,
        )
    elif provider_type == "elevenlabs":
        from onyx.voice.providers.elevenlabs import ElevenLabsVoiceProvider

        return ElevenLabsVoiceProvider(
            api_key=api_key,
            api_base=api_base,
            stt_model=stt_model,
            tts_model=tts_model,
            default_voice=default_voice,
        )
    else:
        raise ValueError(f"Unsupported voice provider type: {provider_type}")


@@ -0,0 +1,182 @@
from abc import ABC
from abc import abstractmethod
from collections.abc import AsyncIterator
from typing import Protocol
from pydantic import BaseModel
class TranscriptResult(BaseModel):
"""Result from streaming transcription."""
text: str
"""The accumulated transcript text."""
is_vad_end: bool = False
"""True if VAD detected end of speech (silence). Use for auto-send."""
class StreamingTranscriberProtocol(Protocol):
"""Protocol for streaming transcription sessions."""
async def send_audio(self, chunk: bytes) -> None:
"""Send an audio chunk for transcription."""
...
async def receive_transcript(self) -> TranscriptResult | None:
"""
Receive next transcript update.
Returns:
TranscriptResult with accumulated text and VAD status, or None when stream ends.
"""
...
async def close(self) -> str:
"""Close the session and return final transcript."""
...
def reset_transcript(self) -> None:
"""Reset accumulated transcript. Call after auto-send to start fresh."""
...
class StreamingSynthesizerProtocol(Protocol):
"""Protocol for streaming TTS sessions (real-time text-to-speech)."""
async def connect(self) -> None:
"""Establish connection to TTS provider."""
...
async def send_text(self, text: str) -> None:
"""Send text to be synthesized."""
...
async def receive_audio(self) -> bytes | None:
"""
Receive next audio chunk.
Returns:
Audio bytes, or None when stream ends.
"""
...
async def flush(self) -> None:
"""Signal end of text input and wait for pending audio."""
...
async def close(self) -> None:
"""Close the session."""
...
class VoiceProviderInterface(ABC):
"""Abstract base class for voice providers (STT and TTS)."""
@abstractmethod
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
"""
Convert audio to text (Speech-to-Text).
Args:
audio_data: Raw audio bytes
audio_format: Audio format (e.g., "webm", "wav", "mp3")
Returns:
Transcribed text
"""
@abstractmethod
def synthesize_stream(
self, text: str, voice: str | None = None, speed: float = 1.0
) -> AsyncIterator[bytes]:
"""
Convert text to audio stream (Text-to-Speech).
Streams audio chunks progressively for lower latency playback.
Args:
text: Text to convert to speech
voice: Voice identifier (e.g., "alloy", "echo"), or None for default
speed: Playback speed multiplier (0.25 to 4.0)
Yields:
Audio data chunks
"""
@abstractmethod
async def validate_credentials(self) -> None:
"""
Validate that the provider credentials are correct by making a
lightweight API call. Raises on failure.
"""
@abstractmethod
def get_available_voices(self) -> list[dict[str, str]]:
"""
Get list of available voices for this provider.
Returns:
List of voice dictionaries with 'id' and 'name' keys
"""
@abstractmethod
def get_available_stt_models(self) -> list[dict[str, str]]:
"""
Get list of available STT models for this provider.
Returns:
List of model dictionaries with 'id' and 'name' keys
"""
@abstractmethod
def get_available_tts_models(self) -> list[dict[str, str]]:
"""
Get list of available TTS models for this provider.
Returns:
List of model dictionaries with 'id' and 'name' keys
"""
def supports_streaming_stt(self) -> bool:
"""Returns True if this provider supports streaming STT."""
return False
def supports_streaming_tts(self) -> bool:
"""Returns True if this provider supports real-time streaming TTS."""
return False
async def create_streaming_transcriber(
self, audio_format: str = "webm"
) -> StreamingTranscriberProtocol:
"""
Create a streaming transcription session.
Args:
audio_format: Audio format being sent (e.g., "webm", "pcm16")
Returns:
A streaming transcriber that can send audio chunks and receive transcripts
Raises:
NotImplementedError: If streaming STT is not supported
"""
raise NotImplementedError("Streaming STT not supported by this provider")
async def create_streaming_synthesizer(
self, voice: str | None = None, speed: float = 1.0
) -> "StreamingSynthesizerProtocol":
"""
Create a streaming TTS session for real-time audio synthesis.
Args:
voice: Voice identifier
speed: Playback speed multiplier
Returns:
A streaming synthesizer that can send text and receive audio chunks
Raises:
NotImplementedError: If streaming TTS is not supported
"""
raise NotImplementedError("Streaming TTS not supported by this provider")
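
Note on the `synthesize_stream` contract in the interface above: the abstract method is declared with a plain `def` and an `AsyncIterator[bytes]` return type, which an implementation can satisfy with an `async def` generator. The following is a minimal sketch of that pattern. `FakeTTS`, its 4-byte chunking, and the `collect` helper are hypothetical and are not part of `onyx.voice`.

```python
import asyncio
from collections.abc import AsyncIterator
from typing import Optional


class FakeTTS:
    """Hypothetical stand-in for a provider; it only echoes the text as bytes."""

    async def synthesize_stream(
        self, text: str, voice: Optional[str] = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        # Because the body contains `yield`, calling this method returns an
        # async iterator immediately; chunks are produced as the caller iterates.
        data = text.encode("utf-8")
        for start in range(0, len(data), 4):
            await asyncio.sleep(0)  # hand control back, as a network read would
            yield data[start : start + 4]


async def collect(text: str) -> list[bytes]:
    # The call itself is not awaited; the result is consumed with `async for`.
    return [chunk async for chunk in FakeTTS().synthesize_stream(text)]
```

A caller would typically forward each chunk to the client as it arrives rather than collecting the whole list, which is what gives the lower playback latency the docstring mentions.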

View File

@@ -0,0 +1,626 @@
"""Azure Speech Services voice provider for STT and TTS.
Azure supports:
- **STT**: Batch transcription via REST API (audio/wav POST) and real-time
streaming via the Azure Speech SDK (push audio stream with continuous
recognition). The SDK handles VAD natively through its recognizing/recognized
events.
- **TTS**: SSML-based synthesis via REST API (streaming response) and real-time
synthesis via the Speech SDK. Text is escaped with ``xml.sax.saxutils.escape``
and attributes with ``quoteattr`` to prevent SSML injection.
Both modes support Azure cloud endpoints (region-based URLs) and self-hosted
Speech containers (custom endpoint URLs). The ``speech_region`` is validated to
contain only ``[a-z0-9-]`` to prevent URL injection.
The Azure Speech SDK (``azure-cognitiveservices-speech``) is an optional C
extension dependency — it is imported lazily inside streaming methods so the
provider can still be instantiated and used for REST-based operations without it.
See https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/
for API reference.
"""
import asyncio
import io
import re
import struct
import wave
from collections.abc import AsyncIterator
from typing import Any
from urllib.parse import urlparse
from xml.sax.saxutils import escape
from xml.sax.saxutils import quoteattr
import aiohttp
from onyx.utils.logger import setup_logger
from onyx.voice.interface import StreamingSynthesizerProtocol
from onyx.voice.interface import StreamingTranscriberProtocol
from onyx.voice.interface import TranscriptResult
from onyx.voice.interface import VoiceProviderInterface
# SSML namespace — W3C standard for Speech Synthesis Markup Language.
# This is a fixed W3C specification and will not change.
SSML_NAMESPACE = "http://www.w3.org/2001/10/synthesis"
# Common Azure Neural voices
AZURE_VOICES = [
{"id": "en-US-JennyNeural", "name": "Jenny (en-US, Female)"},
{"id": "en-US-GuyNeural", "name": "Guy (en-US, Male)"},
{"id": "en-US-AriaNeural", "name": "Aria (en-US, Female)"},
{"id": "en-US-DavisNeural", "name": "Davis (en-US, Male)"},
{"id": "en-US-AmberNeural", "name": "Amber (en-US, Female)"},
{"id": "en-US-AnaNeural", "name": "Ana (en-US, Female)"},
{"id": "en-US-BrandonNeural", "name": "Brandon (en-US, Male)"},
{"id": "en-US-ChristopherNeural", "name": "Christopher (en-US, Male)"},
{"id": "en-US-CoraNeural", "name": "Cora (en-US, Female)"},
{"id": "en-GB-SoniaNeural", "name": "Sonia (en-GB, Female)"},
{"id": "en-GB-RyanNeural", "name": "Ryan (en-GB, Male)"},
]
class AzureStreamingTranscriber(StreamingTranscriberProtocol):
"""Streaming transcription using Azure Speech SDK."""
def __init__(
self,
api_key: str,
region: str | None = None,
endpoint: str | None = None,
input_sample_rate: int = 24000,
target_sample_rate: int = 16000,
):
self.api_key = api_key
self.region = region
self.endpoint = endpoint
self.input_sample_rate = input_sample_rate
self.target_sample_rate = target_sample_rate
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
self._accumulated_transcript = ""
self._recognizer: Any = None
self._audio_stream: Any = None
self._closed = False
self._loop: asyncio.AbstractEventLoop | None = None
async def connect(self) -> None:
"""Initialize Azure Speech recognizer with push stream."""
try:
import azure.cognitiveservices.speech as speechsdk # type: ignore
except ImportError as e:
raise RuntimeError(
"Azure Speech SDK is required for streaming STT. "
"Install `azure-cognitiveservices-speech`."
) from e
self._loop = asyncio.get_running_loop()
# Use endpoint for self-hosted containers, region for Azure cloud
if self.endpoint:
speech_config = speechsdk.SpeechConfig(
subscription=self.api_key,
endpoint=self.endpoint,
)
else:
speech_config = speechsdk.SpeechConfig(
subscription=self.api_key,
region=self.region,
)
audio_format = speechsdk.audio.AudioStreamFormat(
samples_per_second=self.target_sample_rate,
bits_per_sample=16,
channels=1,
)
self._audio_stream = speechsdk.audio.PushAudioInputStream(audio_format)
audio_config = speechsdk.audio.AudioConfig(stream=self._audio_stream)
self._recognizer = speechsdk.SpeechRecognizer(
speech_config=speech_config,
audio_config=audio_config,
)
transcriber = self
def on_recognizing(evt: Any) -> None:
if evt.result.text and transcriber._loop and not transcriber._closed:
full_text = transcriber._accumulated_transcript
if full_text:
full_text += " " + evt.result.text
else:
full_text = evt.result.text
transcriber._loop.call_soon_threadsafe(
transcriber._transcript_queue.put_nowait,
TranscriptResult(text=full_text, is_vad_end=False),
)
def on_recognized(evt: Any) -> None:
if evt.result.text and transcriber._loop and not transcriber._closed:
if transcriber._accumulated_transcript:
transcriber._accumulated_transcript += " " + evt.result.text
else:
transcriber._accumulated_transcript = evt.result.text
transcriber._loop.call_soon_threadsafe(
transcriber._transcript_queue.put_nowait,
TranscriptResult(
text=transcriber._accumulated_transcript, is_vad_end=True
),
)
self._recognizer.recognizing.connect(on_recognizing)
self._recognizer.recognized.connect(on_recognized)
self._recognizer.start_continuous_recognition_async()
async def send_audio(self, chunk: bytes) -> None:
"""Send audio chunk to Azure."""
if self._audio_stream and not self._closed:
self._audio_stream.write(self._resample_pcm16(chunk))
def _resample_pcm16(self, data: bytes) -> bytes:
"""Resample PCM16 audio from input_sample_rate to target_sample_rate."""
if self.input_sample_rate == self.target_sample_rate:
return data
num_samples = len(data) // 2
if num_samples == 0:
return b""
samples = list(struct.unpack(f"<{num_samples}h", data[: num_samples * 2]))
ratio = self.input_sample_rate / self.target_sample_rate
new_length = int(num_samples / ratio)
resampled: list[int] = []
for i in range(new_length):
src_idx = i * ratio
idx_floor = int(src_idx)
idx_ceil = min(idx_floor + 1, num_samples - 1)
frac = src_idx - idx_floor
sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
sample = max(-32768, min(32767, sample))
resampled.append(sample)
return struct.pack(f"<{len(resampled)}h", *resampled)
async def receive_transcript(self) -> TranscriptResult | None:
"""Receive the next transcript, or an empty result if none arrives within 100 ms."""
try:
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return TranscriptResult(text="", is_vad_end=False)
async def close(self) -> str:
"""Stop recognition and return final transcript."""
self._closed = True
if self._recognizer:
self._recognizer.stop_continuous_recognition_async()
if self._audio_stream:
self._audio_stream.close()
self._loop = None
return self._accumulated_transcript
def reset_transcript(self) -> None:
"""Reset accumulated transcript."""
self._accumulated_transcript = ""
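
Note on the callback pattern in `AzureStreamingTranscriber`: the SDK invokes `on_recognizing` and `on_recognized` on its own threads, so results are handed to the event loop with `call_soon_threadsafe` instead of touching the `asyncio.Queue` directly. The sketch below shows the same hand-off in isolation. The worker thread stands in for the SDK, and `collect_from_thread` is a hypothetical name used only for illustration.

```python
import asyncio
import threading
from typing import Optional


async def collect_from_thread(items: list[str]) -> list[str]:
    """Gather items that a worker thread pushes into an asyncio queue."""
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue[Optional[str]] = asyncio.Queue()

    def worker() -> None:
        # Runs outside the event loop thread, as an SDK callback would.
        # asyncio.Queue is not thread-safe, so every put is scheduled on the loop.
        for item in items:
            loop.call_soon_threadsafe(queue.put_nowait, item)
        loop.call_soon_threadsafe(queue.put_nowait, None)  # end-of-stream marker

    threading.Thread(target=worker, daemon=True).start()

    received: list[str] = []
    while True:
        item = await queue.get()
        if item is None:
            break
        received.append(item)
    return received
```

Callbacks scheduled with `call_soon_threadsafe` run in the order they were submitted, so results from a single producer thread keep their order.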
class AzureStreamingSynthesizer(StreamingSynthesizerProtocol):
"""Real-time streaming TTS using Azure Speech SDK."""
def __init__(
self,
api_key: str,
region: str | None = None,
endpoint: str | None = None,
voice: str = "en-US-JennyNeural",
speed: float = 1.0,
):
self._logger = setup_logger()
self.api_key = api_key
self.region = region
self.endpoint = endpoint
self.voice = voice
self.speed = max(0.5, min(2.0, speed))
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
self._synthesizer: Any = None
self._closed = False
self._loop: asyncio.AbstractEventLoop | None = None
async def connect(self) -> None:
"""Initialize the Azure Speech synthesizer; audio chunks arrive via SDK events."""
try:
import azure.cognitiveservices.speech as speechsdk
except ImportError as e:
raise RuntimeError(
"Azure Speech SDK is required for streaming TTS. "
"Install `azure-cognitiveservices-speech`."
) from e
self._logger.info("AzureStreamingSynthesizer: connecting")
# Store the event loop for thread-safe queue operations
self._loop = asyncio.get_running_loop()
# Use endpoint for self-hosted containers, region for Azure cloud
if self.endpoint:
speech_config = speechsdk.SpeechConfig(
subscription=self.api_key,
endpoint=self.endpoint,
)
else:
speech_config = speechsdk.SpeechConfig(
subscription=self.api_key,
region=self.region,
)
speech_config.speech_synthesis_voice_name = self.voice
# Use MP3 format for streaming - compatible with MediaSource Extensions
speech_config.set_speech_synthesis_output_format(
speechsdk.SpeechSynthesisOutputFormat.Audio16Khz64KBitRateMonoMp3
)
# Create the synthesizer without an audio output device; chunks arrive via events
self._synthesizer = speechsdk.SpeechSynthesizer(
speech_config=speech_config,
audio_config=None, # We'll manually handle audio
)
# Connect to synthesis events
self._synthesizer.synthesizing.connect(self._on_synthesizing)
self._synthesizer.synthesis_completed.connect(self._on_completed)
self._logger.info("AzureStreamingSynthesizer: connected")
def _on_synthesizing(self, evt: Any) -> None:
"""Called when audio chunk is available (runs in Azure SDK thread)."""
if evt.result.audio_data and self._loop and not self._closed:
# Thread-safe way to put item in async queue
self._loop.call_soon_threadsafe(
self._audio_queue.put_nowait, evt.result.audio_data
)
def _on_completed(self, _evt: Any) -> None:
"""Called when synthesis is complete (runs in Azure SDK thread)."""
if self._loop and not self._closed:
self._loop.call_soon_threadsafe(self._audio_queue.put_nowait, None)
async def send_text(self, text: str) -> None:
"""Send text to be synthesized using SSML for prosody control."""
if self._synthesizer and not self._closed:
# Build SSML with prosody for speed control
rate = f"{round((self.speed - 1) * 100):+d}%"
escaped_text = escape(text)
ssml = f"""<speak version='1.0' xmlns='{SSML_NAMESPACE}' xml:lang='en-US'>
<voice name={quoteattr(self.voice)}>
<prosody rate='{rate}'>{escaped_text}</prosody>
</voice>
</speak>"""
# Use speak_ssml_async for SSML support (includes speed/prosody)
self._synthesizer.speak_ssml_async(ssml)
async def receive_audio(self) -> bytes | None:
"""Receive the next audio chunk: b"" on a 100 ms timeout, None once synthesis completes."""
try:
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return b"" # No audio yet, but not done
async def flush(self) -> None:
"""Signal end of text input. This is a no-op for Azure."""
# Azure SDK handles flushing automatically
async def close(self) -> None:
"""Close the session."""
self._closed = True
if self._synthesizer:
self._synthesizer.synthesis_completed.disconnect_all()
self._synthesizer.synthesizing.disconnect_all()
self._loop = None
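
Note on the SSML built in `send_text`: the speed multiplier becomes a relative prosody rate, the text goes through `escape`, and the voice name goes through `quoteattr`. A minimal standalone sketch of that construction follows. `speed_to_rate` and `build_ssml` are hypothetical helper names. The sketch uses `round()` so that a speed such as 1.2 maps to `+20%` despite floating-point error.

```python
from xml.sax.saxutils import escape, quoteattr

SSML_NAMESPACE = "http://www.w3.org/2001/10/synthesis"


def speed_to_rate(speed: float) -> str:
    """Map a speed multiplier to a signed percentage for <prosody rate=...>."""
    clamped = max(0.5, min(2.0, speed))
    return f"{round((clamped - 1) * 100):+d}%"


def build_ssml(text: str, voice: str, speed: float = 1.0) -> str:
    # escape() replaces &, < and > in element content. quoteattr() returns the
    # value already wrapped in quotes, so no extra quotes are written around it.
    return (
        f"<speak version='1.0' xmlns='{SSML_NAMESPACE}' xml:lang='en-US'>"
        f"<voice name={quoteattr(voice)}>"
        f"<prosody rate='{speed_to_rate(speed)}'>{escape(text)}</prosody>"
        "</voice></speak>"
    )
```

Keeping the rate conversion in one helper would also let the REST and SDK paths share the same clamping and rounding.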
class AzureVoiceProvider(VoiceProviderInterface):
"""Azure Speech Services voice provider."""
def __init__(
self,
api_key: str | None,
api_base: str | None,
custom_config: dict[str, Any],
stt_model: str | None = None,
tts_model: str | None = None,
default_voice: str | None = None,
):
self.api_key = api_key
self.api_base = api_base
self.custom_config = custom_config
raw_speech_region = (
custom_config.get("speech_region")
or self._extract_speech_region_from_uri(api_base)
or ""
)
self.speech_region = self._validate_speech_region(raw_speech_region)
self.stt_model = stt_model
self.tts_model = tts_model
self.default_voice = default_voice or "en-US-JennyNeural"
@staticmethod
def _is_azure_cloud_url(uri: str | None) -> bool:
"""Check if URI is an Azure cloud endpoint (vs custom/self-hosted)."""
if not uri:
return False
try:
hostname = (urlparse(uri).hostname or "").lower()
except ValueError:
return False
return hostname.endswith(
(
".speech.microsoft.com",
".api.cognitive.microsoft.com",
".cognitiveservices.azure.com",
)
)
@staticmethod
def _extract_speech_region_from_uri(uri: str | None) -> str | None:
"""Extract Azure speech region from endpoint URI.
Note: Custom domains (*.cognitiveservices.azure.com) contain the resource
name, not the region. For custom domains, the region must be specified
explicitly via custom_config["speech_region"].
"""
if not uri:
return None
# Accepted examples:
# - https://eastus.tts.speech.microsoft.com/cognitiveservices/v1
# - https://eastus.stt.speech.microsoft.com/speech/recognition/...
# - https://westus.api.cognitive.microsoft.com/
#
# NOT supported (requires explicit speech_region config):
# - https://<resource>.cognitiveservices.azure.com/ (resource name != region)
try:
hostname = (urlparse(uri).hostname or "").lower()
except ValueError:
return None
stt_tts_match = re.match(
r"^([a-z0-9-]+)\.(?:tts|stt)\.speech\.microsoft\.com$", hostname
)
if stt_tts_match:
return stt_tts_match.group(1)
api_match = re.match(
r"^([a-z0-9-]+)\.api\.cognitive\.microsoft\.com$", hostname
)
if api_match:
return api_match.group(1)
return None
@staticmethod
def _validate_speech_region(speech_region: str) -> str:
normalized_region = speech_region.strip().lower()
if not normalized_region:
return ""
if not re.fullmatch(r"[a-z0-9-]+", normalized_region):
raise ValueError(
"Invalid Azure speech_region. Use lowercase letters, digits, and hyphens only."
)
return normalized_region
def _get_stt_url(self) -> str:
"""Get the STT endpoint URL (auto-detects cloud vs self-hosted)."""
if self.api_base and not self._is_azure_cloud_url(self.api_base):
# Self-hosted container endpoint
return f"{self.api_base.rstrip('/')}/speech/recognition/conversation/cognitiveservices/v1"
# Azure cloud endpoint
return (
f"https://{self.speech_region}.stt.speech.microsoft.com/"
"speech/recognition/conversation/cognitiveservices/v1"
)
def _get_tts_url(self) -> str:
"""Get the TTS endpoint URL (auto-detects cloud vs self-hosted)."""
if self.api_base and not self._is_azure_cloud_url(self.api_base):
# Self-hosted container endpoint
return f"{self.api_base.rstrip('/')}/cognitiveservices/v1"
# Azure cloud endpoint
return f"https://{self.speech_region}.tts.speech.microsoft.com/cognitiveservices/v1"
def _is_self_hosted(self) -> bool:
"""Check if using self-hosted container vs Azure cloud."""
return bool(self.api_base and not self._is_azure_cloud_url(self.api_base))
@staticmethod
def _pcm16_to_wav(pcm_data: bytes, sample_rate: int = 24000) -> bytes:
"""Wrap raw PCM16 mono bytes into a WAV container."""
buffer = io.BytesIO()
with wave.open(buffer, "wb") as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(sample_rate)
wav_file.writeframes(pcm_data)
return buffer.getvalue()
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
if not self.api_key:
raise ValueError("Azure API key required for STT")
if not self._is_self_hosted() and not self.speech_region:
raise ValueError("Azure speech region required for STT (cloud mode)")
normalized_format = audio_format.lower()
payload = audio_data
content_type = f"audio/{normalized_format}"
# WebSocket chunked fallback sends raw PCM16 bytes.
if normalized_format in {"pcm", "pcm16", "raw"}:
payload = self._pcm16_to_wav(audio_data, sample_rate=24000)
content_type = "audio/wav"
elif normalized_format in {"wav", "wave"}:
content_type = "audio/wav"
elif normalized_format == "webm":
content_type = "audio/webm; codecs=opus"
url = self._get_stt_url()
params = {"language": "en-US", "format": "detailed"}
headers = {
"Ocp-Apim-Subscription-Key": self.api_key,
"Content-Type": content_type,
"Accept": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.post(
url, params=params, headers=headers, data=payload
) as response:
if response.status != 200:
error_text = await response.text()
raise RuntimeError(f"Azure STT failed: {error_text}")
result = await response.json()
if result.get("RecognitionStatus") != "Success":
return ""
nbest = result.get("NBest") or []
if nbest and isinstance(nbest, list):
display = nbest[0].get("Display")
if isinstance(display, str):
return display
display_text = result.get("DisplayText", "")
return display_text if isinstance(display_text, str) else ""
async def synthesize_stream(
self, text: str, voice: str | None = None, speed: float = 1.0
) -> AsyncIterator[bytes]:
"""
Convert text to audio using Azure TTS with streaming.
Args:
text: Text to convert to speech
voice: Voice name (defaults to provider's default voice)
speed: Playback speed multiplier (0.5 to 2.0)
Yields:
Audio data chunks (mp3 format)
"""
if not self.api_key:
raise ValueError("Azure API key required for TTS")
if not self._is_self_hosted() and not self.speech_region:
raise ValueError("Azure speech region required for TTS (cloud mode)")
voice_name = voice or self.default_voice
# Clamp speed to valid range and convert to rate format
speed = max(0.5, min(2.0, speed))
rate = f"{round((speed - 1) * 100):+d}%" # e.g., 1.0 -> "+0%", 1.5 -> "+50%"
# Build SSML with escaped text and quoted attributes to prevent injection
escaped_text = escape(text)
ssml = f"""<speak version='1.0' xmlns='{SSML_NAMESPACE}' xml:lang='en-US'>
<voice name={quoteattr(voice_name)}>
<prosody rate='{rate}'>{escaped_text}</prosody>
</voice>
</speak>"""
url = self._get_tts_url()
headers = {
"Ocp-Apim-Subscription-Key": self.api_key,
"Content-Type": "application/ssml+xml",
"X-Microsoft-OutputFormat": "audio-16khz-128kbitrate-mono-mp3",
"User-Agent": "Onyx",
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, data=ssml) as response:
if response.status != 200:
error_text = await response.text()
raise RuntimeError(f"Azure TTS failed: {error_text}")
# Use 8192 byte chunks for smoother streaming
async for chunk in response.content.iter_chunked(8192):
if chunk:
yield chunk
async def validate_credentials(self) -> None:
"""Validate Azure credentials by listing available voices."""
if not self.api_key:
raise ValueError("Azure API key required")
if not self._is_self_hosted() and not self.speech_region:
raise ValueError("Azure speech region required (cloud mode)")
url = f"https://{self.speech_region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
if self._is_self_hosted():
url = f"{(self.api_base or '').rstrip('/')}/cognitiveservices/voices/list"
headers = {"Ocp-Apim-Subscription-Key": self.api_key}
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
raise RuntimeError(
f"Azure credential validation failed: {error_text}"
)
def get_available_voices(self) -> list[dict[str, str]]:
"""Return common Azure Neural voices."""
return AZURE_VOICES.copy()
def get_available_stt_models(self) -> list[dict[str, str]]:
return [
{"id": "default", "name": "Azure Speech Recognition"},
]
def get_available_tts_models(self) -> list[dict[str, str]]:
return [
{"id": "neural", "name": "Neural TTS"},
]
def supports_streaming_stt(self) -> bool:
"""Azure supports streaming STT via Speech SDK."""
return True
def supports_streaming_tts(self) -> bool:
"""Azure supports real-time streaming TTS via Speech SDK."""
return True
async def create_streaming_transcriber(
self, _audio_format: str = "webm"
) -> AzureStreamingTranscriber:
"""Create a streaming transcription session."""
if not self.api_key:
raise ValueError("API key required for streaming transcription")
if not self._is_self_hosted() and not self.speech_region:
raise ValueError(
"Speech region required for Azure streaming transcription (cloud mode)"
)
# Use endpoint for self-hosted, region for cloud
transcriber = AzureStreamingTranscriber(
api_key=self.api_key,
region=self.speech_region if not self._is_self_hosted() else None,
endpoint=self.api_base if self._is_self_hosted() else None,
input_sample_rate=24000,
target_sample_rate=16000,
)
await transcriber.connect()
return transcriber
async def create_streaming_synthesizer(
self, voice: str | None = None, speed: float = 1.0
) -> AzureStreamingSynthesizer:
"""Create a streaming TTS session."""
if not self.api_key:
raise ValueError("API key required for streaming TTS")
if not self._is_self_hosted() and not self.speech_region:
raise ValueError(
"Speech region required for Azure streaming TTS (cloud mode)"
)
# Use endpoint for self-hosted, region for cloud
synthesizer = AzureStreamingSynthesizer(
api_key=self.api_key,
region=self.speech_region if not self._is_self_hosted() else None,
endpoint=self.api_base if self._is_self_hosted() else None,
voice=voice or self.default_voice or "en-US-JennyNeural",
speed=speed,
)
await synthesizer.connect()
return synthesizer
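
Note on region detection in `AzureVoiceProvider`: the region is taken from the hostname only for the regional `*.tts/stt.speech.microsoft.com` and `*.api.cognitive.microsoft.com` hosts, and custom-domain or self-hosted URLs yield no region. The sketch below restates that rule as a single function. `region_from_uri` and the combined regular expression are assumptions made for illustration, not the code in this PR.

```python
import re
from typing import Optional
from urllib.parse import urlparse

# One pattern covering both regional host families named in the docstring.
_REGION_HOST = re.compile(
    r"^([a-z0-9-]+)\.(?:(?:tts|stt)\.speech|api\.cognitive)\.microsoft\.com$"
)


def region_from_uri(uri: Optional[str]) -> Optional[str]:
    """Return the Azure region encoded in a regional endpoint hostname, if any."""
    if not uri:
        return None
    try:
        hostname = (urlparse(uri).hostname or "").lower()
    except ValueError:
        return None
    match = _REGION_HOST.match(hostname)
    return match.group(1) if match else None
```

For a `<resource>.cognitiveservices.azure.com` endpoint this returns `None`, which is why the region has to come from `custom_config["speech_region"]` in that case.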

View File

@@ -0,0 +1,858 @@
"""ElevenLabs voice provider for STT and TTS.
ElevenLabs supports:
- **STT**: Scribe API (batch via REST, streaming via WebSocket with Scribe v2 Realtime).
The streaming endpoint sends base64-encoded PCM16 audio chunks and receives JSON
transcript messages (partial_transcript, committed_transcript, utterance_end).
- **TTS**: Text-to-speech via REST streaming and WebSocket stream-input.
The WebSocket variant accepts incremental text chunks and returns audio in order,
enabling low-latency playback before the full text is available.
See https://elevenlabs.io/docs for API reference.
"""
import asyncio
import base64
import json
from collections.abc import AsyncIterator
from enum import StrEnum
from typing import Any
import aiohttp
from onyx.voice.interface import StreamingSynthesizerProtocol
from onyx.voice.interface import StreamingTranscriberProtocol
from onyx.voice.interface import TranscriptResult
from onyx.voice.interface import VoiceProviderInterface
# Default ElevenLabs API base URL
DEFAULT_ELEVENLABS_API_BASE = "https://api.elevenlabs.io"
# Default sample rates for STT streaming
DEFAULT_INPUT_SAMPLE_RATE = 24000 # What the browser frontend sends
DEFAULT_TARGET_SAMPLE_RATE = 16000 # What ElevenLabs Scribe expects
# Default streaming TTS output format
DEFAULT_TTS_OUTPUT_FORMAT = "mp3_44100_64"
# Default TTS voice settings
DEFAULT_VOICE_STABILITY = 0.5
DEFAULT_VOICE_SIMILARITY_BOOST = 0.75
# Chunk length schedule for streaming TTS (optimized for real-time playback)
DEFAULT_CHUNK_LENGTH_SCHEDULE = [120, 160, 250, 290]
# Default STT streaming VAD configuration
DEFAULT_VAD_SILENCE_THRESHOLD_SECS = 1.0
DEFAULT_VAD_THRESHOLD = 0.4
DEFAULT_MIN_SPEECH_DURATION_MS = 100
DEFAULT_MIN_SILENCE_DURATION_MS = 300
class ElevenLabsSTTMessageType(StrEnum):
"""Message types from ElevenLabs Scribe Realtime STT API."""
SESSION_STARTED = "session_started"
PARTIAL_TRANSCRIPT = "partial_transcript"
COMMITTED_TRANSCRIPT = "committed_transcript"
UTTERANCE_END = "utterance_end"
SESSION_ENDED = "session_ended"
ERROR = "error"
class ElevenLabsTTSMessageType(StrEnum):
"""Message types from ElevenLabs stream-input TTS API."""
AUDIO = "audio"
ERROR = "error"
def _http_to_ws_url(http_url: str) -> str:
"""Convert http(s) URL to ws(s) URL for WebSocket connections."""
if http_url.startswith("https://"):
return "wss://" + http_url[8:]
elif http_url.startswith("http://"):
return "ws://" + http_url[7:]
return http_url
# Common ElevenLabs voices
ELEVENLABS_VOICES = [
{"id": "21m00Tcm4TlvDq8ikWAM", "name": "Rachel"},
{"id": "AZnzlk1XvdvUeBnXmlld", "name": "Domi"},
{"id": "EXAVITQu4vr4xnSDxMaL", "name": "Bella"},
{"id": "ErXwobaYiN019PkySvjV", "name": "Antoni"},
{"id": "MF3mGyEYCl7XYWbV9V6O", "name": "Elli"},
{"id": "TxGEqnHWrfWFTfGW9XjX", "name": "Josh"},
{"id": "VR6AewLTigWG4xSOukaG", "name": "Arnold"},
{"id": "pNInz6obpgDQGcFmaJgB", "name": "Adam"},
{"id": "yoZ06aMxZJJ28mfd3POQ", "name": "Sam"},
]
class ElevenLabsStreamingTranscriber(StreamingTranscriberProtocol):
"""Streaming transcription session using ElevenLabs Scribe Realtime API."""
def __init__(
self,
api_key: str,
model: str = "scribe_v2_realtime",
input_sample_rate: int = DEFAULT_INPUT_SAMPLE_RATE,
target_sample_rate: int = DEFAULT_TARGET_SAMPLE_RATE,
language_code: str = "en",
api_base: str | None = None,
):
# Import logger first
from onyx.utils.logger import setup_logger
self._logger = setup_logger()
self._logger.info(
f"ElevenLabsStreamingTranscriber: initializing with model {model}"
)
self.api_key = api_key
self.model = model
self.input_sample_rate = input_sample_rate
self.target_sample_rate = target_sample_rate
self.language_code = language_code
self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._session: aiohttp.ClientSession | None = None
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
self._final_transcript = ""
self._receive_task: asyncio.Task | None = None
self._closed = False
async def connect(self) -> None:
"""Establish WebSocket connection to ElevenLabs."""
self._logger.info(
"ElevenLabsStreamingTranscriber: connecting to ElevenLabs API"
)
self._session = aiohttp.ClientSession()
# VAD is configured via query parameters.
# commit_strategy=vad enables automatic transcript commit on silence detection.
# These params are part of the ElevenLabs Scribe Realtime API contract:
# https://elevenlabs.io/docs/api-reference/speech-to-text/realtime
ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
url = (
f"{ws_base}/v1/speech-to-text/realtime"
f"?model_id={self.model}"
f"&sample_rate={self.target_sample_rate}"
f"&language_code={self.language_code}"
f"&commit_strategy=vad"
f"&vad_silence_threshold_secs={DEFAULT_VAD_SILENCE_THRESHOLD_SECS}"
f"&vad_threshold={DEFAULT_VAD_THRESHOLD}"
f"&min_speech_duration_ms={DEFAULT_MIN_SPEECH_DURATION_MS}"
f"&min_silence_duration_ms={DEFAULT_MIN_SILENCE_DURATION_MS}"
)
self._logger.info(
f"ElevenLabsStreamingTranscriber: connecting to {url} "
f"(input={self.input_sample_rate}Hz, target={self.target_sample_rate}Hz)"
)
try:
self._ws = await self._session.ws_connect(
url,
headers={"xi-api-key": self.api_key},
)
self._logger.info(
f"ElevenLabsStreamingTranscriber: connected successfully, "
f"ws.closed={self._ws.closed}, close_code={self._ws.close_code}"
)
except Exception as e:
self._logger.error(
f"ElevenLabsStreamingTranscriber: failed to connect: {e}"
)
if self._session:
await self._session.close()
raise
# Start receiving transcripts in background
self._receive_task = asyncio.create_task(self._receive_loop())
async def _receive_loop(self) -> None:
"""Background task to receive transcripts from WebSocket."""
self._logger.info("ElevenLabsStreamingTranscriber: receive loop started")
if not self._ws:
self._logger.warning(
"ElevenLabsStreamingTranscriber: no WebSocket connection"
)
return
try:
async for msg in self._ws:
self._logger.debug(
f"ElevenLabsStreamingTranscriber: raw message type: {msg.type}"
)
if msg.type == aiohttp.WSMsgType.TEXT:
parsed_data: Any = None
data: dict[str, Any]
try:
parsed_data = json.loads(msg.data)
except json.JSONDecodeError:
self._logger.error(
f"ElevenLabsStreamingTranscriber: failed to parse JSON: {msg.data[:200]}"
)
continue
if not isinstance(parsed_data, dict):
self._logger.error(
"ElevenLabsStreamingTranscriber: expected object JSON payload"
)
continue
data = parsed_data
# ElevenLabs uses the message_type field; skip the packet if neither it nor "type" is present
if "message_type" not in data and "type" not in data:
self._logger.error(
f"ElevenLabsStreamingTranscriber: malformed packet missing 'message_type' field: {data}"
)
continue
msg_type = data.get("message_type", data.get("type", ""))
self._logger.info(
f"ElevenLabsStreamingTranscriber: received message_type: '{msg_type}', data keys: {list(data.keys())}"
)
# Check for error in various formats
if "error" in data or msg_type == ElevenLabsSTTMessageType.ERROR:
error_msg = data.get("error", data.get("message", data))
self._logger.error(
f"ElevenLabsStreamingTranscriber: API error: {error_msg}"
)
continue
# Handle message types from ElevenLabs Scribe Realtime API.
# See https://elevenlabs.io/docs/api-reference/speech-to-text/realtime
if msg_type == ElevenLabsSTTMessageType.SESSION_STARTED:
self._logger.info(
f"ElevenLabsStreamingTranscriber: session started, "
f"id={data.get('session_id')}, config={data.get('config')}"
)
elif msg_type == ElevenLabsSTTMessageType.PARTIAL_TRANSCRIPT:
# Interim result — updated as more audio is processed
text = data.get("text", "")
if text:
self._logger.info(
f"ElevenLabsStreamingTranscriber: partial_transcript: {text[:50]}..."
)
self._final_transcript = text
await self._transcript_queue.put(
TranscriptResult(text=text, is_vad_end=False)
)
elif msg_type == ElevenLabsSTTMessageType.COMMITTED_TRANSCRIPT:
# Final transcript for the current utterance (VAD detected end)
text = data.get("text", "")
if text:
self._logger.info(
f"ElevenLabsStreamingTranscriber: committed_transcript: {text[:50]}..."
)
self._final_transcript = text
await self._transcript_queue.put(
TranscriptResult(text=text, is_vad_end=True)
)
elif msg_type == ElevenLabsSTTMessageType.UTTERANCE_END:
# VAD detected end of speech (may carry text or be empty)
text = data.get("text", "") or self._final_transcript
if text:
self._logger.info(
f"ElevenLabsStreamingTranscriber: utterance_end: {text[:50]}..."
)
self._final_transcript = text
await self._transcript_queue.put(
TranscriptResult(text=text, is_vad_end=True)
)
elif msg_type == ElevenLabsSTTMessageType.SESSION_ENDED:
self._logger.info(
"ElevenLabsStreamingTranscriber: session ended"
)
break
else:
# Log unhandled message types with full data for debugging
self._logger.warning(
f"ElevenLabsStreamingTranscriber: unhandled message_type: {msg_type}, full data: {data}"
)
elif msg.type == aiohttp.WSMsgType.BINARY:
self._logger.debug(
f"ElevenLabsStreamingTranscriber: received binary message: {len(msg.data)} bytes"
)
elif msg.type == aiohttp.WSMsgType.CLOSED:
close_code = self._ws.close_code if self._ws else "N/A"
self._logger.info(
"ElevenLabsStreamingTranscriber: WebSocket closed by "
f"server, close_code={close_code}"
)
break
elif msg.type == aiohttp.WSMsgType.ERROR:
self._logger.error(
f"ElevenLabsStreamingTranscriber: WebSocket error: {self._ws.exception() if self._ws else 'N/A'}"
)
break
elif msg.type == aiohttp.WSMsgType.CLOSE:
self._logger.info(
f"ElevenLabsStreamingTranscriber: WebSocket CLOSE frame received, data={msg.data}, extra={msg.extra}"
)
break
except Exception as e:
self._logger.error(
f"ElevenLabsStreamingTranscriber: error in receive loop: {e}",
exc_info=True,
)
finally:
close_code = self._ws.close_code if self._ws else "N/A"
self._logger.info(
f"ElevenLabsStreamingTranscriber: receive loop ended, close_code={close_code}"
)
await self._transcript_queue.put(None) # Signal end
def _resample_pcm16(self, data: bytes) -> bytes:
"""Resample PCM16 audio from input_sample_rate to target_sample_rate."""
import struct
if self.input_sample_rate == self.target_sample_rate:
return data
# Parse int16 samples
num_samples = len(data) // 2
samples = list(struct.unpack(f"<{num_samples}h", data[: num_samples * 2]))
# Calculate resampling ratio
ratio = self.input_sample_rate / self.target_sample_rate
new_length = int(num_samples / ratio)
# Linear interpolation resampling
resampled = []
for i in range(new_length):
src_idx = i * ratio
idx_floor = int(src_idx)
idx_ceil = min(idx_floor + 1, num_samples - 1)
frac = src_idx - idx_floor
sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
# Clamp to int16 range
sample = max(-32768, min(32767, sample))
resampled.append(sample)
return struct.pack(f"<{len(resampled)}h", *resampled)
async def send_audio(self, chunk: bytes) -> None:
"""Send an audio chunk for transcription."""
if not self._ws:
self._logger.warning("send_audio: no WebSocket connection")
return
if self._closed:
self._logger.warning("send_audio: transcriber is closed")
return
if self._ws.closed:
self._logger.warning(
f"send_audio: WebSocket is closed, close_code={self._ws.close_code}"
)
return
try:
# Resample from input rate (24kHz) to target rate (16kHz)
resampled = self._resample_pcm16(chunk)
# ElevenLabs expects input_audio_chunk message format with audio_base_64
audio_b64 = base64.b64encode(resampled).decode("utf-8")
message = {
"message_type": "input_audio_chunk",
"audio_base_64": audio_b64,
"sample_rate": self.target_sample_rate,
}
self._logger.info(
f"send_audio: {len(chunk)} bytes -> {len(resampled)} bytes (resampled) -> {len(audio_b64)} chars base64"
)
await self._ws.send_str(json.dumps(message))
self._logger.debug("send_audio: message sent successfully")
except Exception as e:
self._logger.error(f"send_audio: failed to send: {e}", exc_info=True)
raise
async def receive_transcript(self) -> TranscriptResult | None:
"""Receive next transcript. Returns None when done."""
try:
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return TranscriptResult(
text="", is_vad_end=False
) # No transcript yet, but not done
async def close(self) -> str:
"""Close the session and return final transcript."""
self._logger.info("ElevenLabsStreamingTranscriber: closing session")
self._closed = True
if self._ws and not self._ws.closed:
try:
# Just close the WebSocket - ElevenLabs Scribe doesn't need a special end message
self._logger.info(
"ElevenLabsStreamingTranscriber: closing WebSocket connection"
)
await self._ws.close()
except Exception as e:
self._logger.debug(f"Error closing WebSocket: {e}")
if self._receive_task and not self._receive_task.done():
self._receive_task.cancel()
try:
await self._receive_task
except asyncio.CancelledError:
pass
if self._session and not self._session.closed:
await self._session.close()
return self._final_transcript
def reset_transcript(self) -> None:
"""Reset accumulated transcript. Call after auto-send to start fresh."""
self._final_transcript = ""
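Review note: the wire format that `send_audio` produces can be checked without a WebSocket. This is a small sketch under that assumption; `build_audio_chunk_message` is a hypothetical helper name, while the three JSON fields are the ones used in the diff.

```python
import base64
import json


def build_audio_chunk_message(pcm16: bytes, sample_rate: int) -> str:
    """Serialize one PCM16 chunk as an input_audio_chunk message (sketch)."""
    return json.dumps(
        {
            "message_type": "input_audio_chunk",
            "audio_base_64": base64.b64encode(pcm16).decode("utf-8"),
            "sample_rate": sample_rate,
        }
    )


# 20 ms of 16 kHz mono PCM16 is 320 samples (640 bytes)
message = build_audio_chunk_message(bytes(640), 16000)
decoded = json.loads(message)
```

Base64 expands the payload by roughly a third, which is worth keeping in mind when sizing chunks for a real-time stream.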
class ElevenLabsStreamingSynthesizer(StreamingSynthesizerProtocol):
"""Real-time streaming TTS using ElevenLabs WebSocket API.
Uses ElevenLabs' stream-input WebSocket which processes text as one
continuous stream and returns audio in order.
"""
def __init__(
self,
api_key: str,
voice_id: str,
model_id: str = "eleven_multilingual_v2",
output_format: str = "mp3_44100_64",
api_base: str | None = None,
speed: float = 1.0,
):
from onyx.utils.logger import setup_logger
self._logger = setup_logger()
self.api_key = api_key
self.voice_id = voice_id
self.model_id = model_id
self.output_format = output_format
self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
self.speed = speed
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._session: aiohttp.ClientSession | None = None
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
self._receive_task: asyncio.Task | None = None
self._closed = False
async def connect(self) -> None:
"""Establish WebSocket connection to ElevenLabs TTS."""
self._logger.info("ElevenLabsStreamingSynthesizer: connecting")
self._session = aiohttp.ClientSession()
# WebSocket URL for streaming input TTS with output format for streaming compatibility
# Using mp3_44100_64 for good quality with smaller chunks for real-time playback
ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
url = (
f"{ws_base}/v1/text-to-speech/{self.voice_id}/stream-input"
f"?model_id={self.model_id}&output_format={self.output_format}"
)
self._ws = await self._session.ws_connect(
url,
headers={"xi-api-key": self.api_key},
)
# Send initial configuration with generation settings optimized for streaming.
# Note: API key is sent via header only (not in body to avoid log exposure).
# See https://elevenlabs.io/docs/api-reference/text-to-speech/stream-input
await self._ws.send_str(
json.dumps(
{
"text": " ", # Initial space to start the stream
"voice_settings": {
"stability": DEFAULT_VOICE_STABILITY,
"similarity_boost": DEFAULT_VOICE_SIMILARITY_BOOST,
"speed": self.speed,
},
"generation_config": {
"chunk_length_schedule": DEFAULT_CHUNK_LENGTH_SCHEDULE,
},
}
)
)
# Start receiving audio in background
self._receive_task = asyncio.create_task(self._receive_loop())
self._logger.info("ElevenLabsStreamingSynthesizer: connected")
async def _receive_loop(self) -> None:
"""Background task to receive audio chunks from WebSocket.
Audio is returned in order as one continuous stream.
"""
if not self._ws:
return
chunk_count = 0
total_bytes = 0
try:
async for msg in self._ws:
if self._closed:
self._logger.info(
"ElevenLabsStreamingSynthesizer: closed flag set, stopping "
"receive loop"
)
break
if msg.type == aiohttp.WSMsgType.TEXT:
data = json.loads(msg.data)
# Process audio if present
if "audio" in data and data["audio"]:
audio_bytes = base64.b64decode(data["audio"])
chunk_count += 1
total_bytes += len(audio_bytes)
await self._audio_queue.put(audio_bytes)
# Check isFinal separately - a message can have both audio AND isFinal
if "isFinal" in data:
self._logger.info(
f"ElevenLabsStreamingSynthesizer: received isFinal={data['isFinal']}, "
f"chunks so far: {chunk_count}, bytes: {total_bytes}"
)
if data.get("isFinal"):
self._logger.info(
"ElevenLabsStreamingSynthesizer: isFinal=true, signaling end of audio"
)
await self._audio_queue.put(None)
# Check for errors
if "error" in data or data.get("type") == "error":
self._logger.error(
f"ElevenLabsStreamingSynthesizer: received error: {data}"
)
elif msg.type == aiohttp.WSMsgType.BINARY:
chunk_count += 1
total_bytes += len(msg.data)
await self._audio_queue.put(msg.data)
elif msg.type in (
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.ERROR,
):
self._logger.info(
f"ElevenLabsStreamingSynthesizer: WebSocket closed/error, type={msg.type}"
)
break
except Exception as e:
self._logger.error(f"ElevenLabsStreamingSynthesizer receive error: {e}")
finally:
self._logger.info(
f"ElevenLabsStreamingSynthesizer: receive loop ended, {chunk_count} chunks, {total_bytes} bytes"
)
await self._audio_queue.put(None) # Signal end of stream
async def send_text(self, text: str) -> None:
"""Send text to be synthesized.
ElevenLabs processes text as a continuous stream and returns
audio in order. We let ElevenLabs handle buffering via chunk_length_schedule
and only force generation when flush() is called at the end.
Args:
text: Text to synthesize
"""
if self._ws and not self._closed and text.strip():
self._logger.info(
f"ElevenLabsStreamingSynthesizer: sending text ({len(text)} chars): '{text}'"
)
# Let ElevenLabs buffer and auto-generate based on chunk_length_schedule
# Don't trigger generation here - wait for flush() at the end
await self._ws.send_str(
json.dumps(
{
"text": text + " ", # Space for natural speech flow
}
)
)
self._logger.info("ElevenLabsStreamingSynthesizer: text sent successfully")
else:
self._logger.warning(
f"ElevenLabsStreamingSynthesizer: skipping send_text - "
f"ws={self._ws is not None}, closed={self._closed}, text='{text[:30] if text else ''}'"
)
async def receive_audio(self) -> bytes | None:
"""Receive next audio chunk."""
try:
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return b"" # No audio yet, but not done
async def flush(self) -> None:
"""Signal end of text input. ElevenLabs will generate remaining audio and close."""
if self._ws and not self._closed:
# Send empty string to signal end of input
# ElevenLabs will generate any remaining buffered text,
# send all audio chunks, send isFinal, then close the connection
self._logger.info(
"ElevenLabsStreamingSynthesizer: sending end-of-input (empty string)"
)
await self._ws.send_str(json.dumps({"text": ""}))
self._logger.info("ElevenLabsStreamingSynthesizer: end-of-input sent")
else:
self._logger.warning(
f"ElevenLabsStreamingSynthesizer: skipping flush - "
f"ws={self._ws is not None}, closed={self._closed}"
)
async def close(self) -> None:
"""Close the session."""
self._closed = True
if self._ws:
await self._ws.close()
if self._receive_task:
self._receive_task.cancel()
try:
await self._receive_task
except asyncio.CancelledError:
pass
if self._session:
await self._session.close()
# Valid ElevenLabs model IDs
ELEVENLABS_STT_MODELS = {"scribe_v1", "scribe_v2_realtime"}
ELEVENLABS_TTS_MODELS = {
"eleven_multilingual_v2",
"eleven_turbo_v2_5",
"eleven_monolingual_v1",
"eleven_flash_v2_5",
"eleven_flash_v2",
}
class ElevenLabsVoiceProvider(VoiceProviderInterface):
"""ElevenLabs voice provider."""
def __init__(
self,
api_key: str | None,
api_base: str | None = None,
stt_model: str | None = None,
tts_model: str | None = None,
default_voice: str | None = None,
):
self.api_key = api_key
self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
# Validate and default models - use valid ElevenLabs model IDs
self.stt_model = (
stt_model if stt_model in ELEVENLABS_STT_MODELS else "scribe_v1"
)
self.tts_model = (
tts_model
if tts_model in ELEVENLABS_TTS_MODELS
else "eleven_multilingual_v2"
)
self.default_voice = default_voice
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
"""
Transcribe audio using ElevenLabs Speech-to-Text API.
Args:
audio_data: Raw audio bytes
audio_format: Format of the audio (e.g., 'webm', 'mp3', 'wav')
Returns:
Transcribed text
"""
if not self.api_key:
raise ValueError("ElevenLabs API key required for transcription")
from onyx.utils.logger import setup_logger
logger = setup_logger()
url = f"{self.api_base.rstrip('/')}/v1/speech-to-text"
# Map common formats to MIME types
mime_types = {
"webm": "audio/webm",
"mp3": "audio/mpeg",
"wav": "audio/wav",
"ogg": "audio/ogg",
"flac": "audio/flac",
"m4a": "audio/mp4",
}
mime_type = mime_types.get(audio_format.lower(), f"audio/{audio_format}")
headers = {
"xi-api-key": self.api_key,
}
# ElevenLabs expects multipart form data
form_data = aiohttp.FormData()
form_data.add_field(
"audio",
audio_data,
filename=f"audio.{audio_format}",
content_type=mime_type,
)
# Batch STT always uses scribe_v1, never the realtime model
batch_model = "scribe_v1"
form_data.add_field("model_id", batch_model)
logger.info(
f"ElevenLabs transcribe: sending {len(audio_data)} bytes, format={audio_format}"
)
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, data=form_data) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"ElevenLabs transcribe failed: {error_text}")
raise RuntimeError(f"ElevenLabs transcription failed: {error_text}")
result = await response.json()
text = result.get("text", "")
logger.info(f"ElevenLabs transcribe: got result: {text[:50]}...")
return text
async def synthesize_stream(
self, text: str, voice: str | None = None, speed: float = 1.0
) -> AsyncIterator[bytes]:
"""
Convert text to audio using ElevenLabs TTS with streaming.
Args:
text: Text to convert to speech
voice: Voice ID (defaults to provider's default voice or Rachel)
speed: Playback speed multiplier
Yields:
Audio data chunks (mp3 format)
"""
from onyx.utils.logger import setup_logger
logger = setup_logger()
if not self.api_key:
raise ValueError("ElevenLabs API key required for TTS")
voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM" # Rachel
url = f"{self.api_base.rstrip('/')}/v1/text-to-speech/{voice_id}/stream"
logger.info(
f"ElevenLabs TTS: starting synthesis, text='{text[:50]}...', "
f"voice={voice_id}, model={self.tts_model}, speed={speed}"
)
headers = {
"xi-api-key": self.api_key,
"Content-Type": "application/json",
"Accept": "audio/mpeg",
}
payload = {
"text": text,
"model_id": self.tts_model,
"voice_settings": {
"stability": DEFAULT_VOICE_STABILITY,
"similarity_boost": DEFAULT_VOICE_SIMILARITY_BOOST,
"speed": speed,
},
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=payload) as response:
logger.info(
f"ElevenLabs TTS: got response status={response.status}, "
f"content-type={response.headers.get('content-type')}"
)
if response.status != 200:
error_text = await response.text()
logger.error(f"ElevenLabs TTS failed: {error_text}")
raise RuntimeError(f"ElevenLabs TTS failed: {error_text}")
# Use 8192 byte chunks for smoother streaming
chunk_count = 0
total_bytes = 0
async for chunk in response.content.iter_chunked(8192):
if chunk:
chunk_count += 1
total_bytes += len(chunk)
yield chunk
logger.info(
f"ElevenLabs TTS: streaming complete, {chunk_count} chunks, "
f"{total_bytes} total bytes"
)
async def validate_credentials(self) -> None:
"""Validate ElevenLabs API key by fetching user info."""
if not self.api_key:
raise ValueError("ElevenLabs API key required")
headers = {"xi-api-key": self.api_key}
async with aiohttp.ClientSession() as session:
async with session.get(
f"{self.api_base.rstrip('/')}/v1/user", headers=headers
) as response:
if response.status != 200:
error_text = await response.text()
raise RuntimeError(
f"ElevenLabs credential validation failed: {error_text}"
)
def get_available_voices(self) -> list[dict[str, str]]:
"""Return common ElevenLabs voices."""
return ELEVENLABS_VOICES.copy()
def get_available_stt_models(self) -> list[dict[str, str]]:
return [
{"id": "scribe_v2_realtime", "name": "Scribe v2 Realtime (Streaming)"},
{"id": "scribe_v1", "name": "Scribe v1 (Batch)"},
]
def get_available_tts_models(self) -> list[dict[str, str]]:
return [
{"id": "eleven_multilingual_v2", "name": "Multilingual v2"},
{"id": "eleven_turbo_v2_5", "name": "Turbo v2.5"},
{"id": "eleven_monolingual_v1", "name": "Monolingual v1"},
]
def supports_streaming_stt(self) -> bool:
"""ElevenLabs supports streaming via Scribe Realtime API."""
return True
def supports_streaming_tts(self) -> bool:
"""ElevenLabs supports real-time streaming TTS via WebSocket."""
return True
async def create_streaming_transcriber(
self, _audio_format: str = "webm"
) -> ElevenLabsStreamingTranscriber:
"""Create a streaming transcription session."""
if not self.api_key:
raise ValueError("API key required for streaming transcription")
# ElevenLabs realtime STT requires scribe_v2_realtime model.
# Frontend sends PCM16 at DEFAULT_INPUT_SAMPLE_RATE (24kHz),
# but ElevenLabs expects DEFAULT_TARGET_SAMPLE_RATE (16kHz).
# The transcriber resamples automatically.
transcriber = ElevenLabsStreamingTranscriber(
api_key=self.api_key,
model="scribe_v2_realtime",
input_sample_rate=DEFAULT_INPUT_SAMPLE_RATE,
target_sample_rate=DEFAULT_TARGET_SAMPLE_RATE,
language_code="en",
api_base=self.api_base,
)
await transcriber.connect()
return transcriber
async def create_streaming_synthesizer(
self, voice: str | None = None, speed: float = 1.0
) -> ElevenLabsStreamingSynthesizer:
"""Create a streaming TTS session."""
if not self.api_key:
raise ValueError("API key required for streaming TTS")
voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM"
synthesizer = ElevenLabsStreamingSynthesizer(
api_key=self.api_key,
voice_id=voice_id,
model_id=self.tts_model,
output_format=DEFAULT_TTS_OUTPUT_FORMAT,
api_base=self.api_base,
speed=speed,
)
await synthesizer.connect()
return synthesizer

@@ -0,0 +1,626 @@
"""OpenAI voice provider for STT and TTS.
OpenAI supports:
- **STT**: Whisper (batch transcription via REST) and Realtime API (streaming
transcription via WebSocket with server-side VAD). Audio is sent as base64-encoded
PCM16 at 24kHz mono. The Realtime API returns transcript deltas and completed
transcription events per VAD-detected utterance.
- **TTS**: HTTP streaming endpoint that returns audio chunks progressively.
Supported models: tts-1 (standard) and tts-1-hd (high quality).
See https://platform.openai.com/docs for API reference.
"""
import asyncio
import base64
import io
import json
from collections.abc import AsyncIterator
from enum import StrEnum
from typing import TYPE_CHECKING
import aiohttp
from onyx.voice.interface import StreamingSynthesizerProtocol
from onyx.voice.interface import StreamingTranscriberProtocol
from onyx.voice.interface import TranscriptResult
from onyx.voice.interface import VoiceProviderInterface
if TYPE_CHECKING:
from openai import AsyncOpenAI
# Default OpenAI API base URL
DEFAULT_OPENAI_API_BASE = "https://api.openai.com"
class OpenAIRealtimeMessageType(StrEnum):
"""Message types from OpenAI Realtime transcription API."""
ERROR = "error"
SPEECH_STARTED = "input_audio_buffer.speech_started"
SPEECH_STOPPED = "input_audio_buffer.speech_stopped"
BUFFER_COMMITTED = "input_audio_buffer.committed"
TRANSCRIPTION_DELTA = "conversation.item.input_audio_transcription.delta"
TRANSCRIPTION_COMPLETED = "conversation.item.input_audio_transcription.completed"
SESSION_CREATED = "transcription_session.created"
SESSION_UPDATED = "transcription_session.updated"
ITEM_CREATED = "conversation.item.created"
def _http_to_ws_url(http_url: str) -> str:
"""Convert http(s) URL to ws(s) URL for WebSocket connections."""
if http_url.startswith("https://"):
return "wss://" + http_url[8:]
elif http_url.startswith("http://"):
return "ws://" + http_url[7:]
return http_url
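Review note: a quick check of the scheme conversion, using the same logic as the helper above under the public name `http_to_ws_url` (renamed here only so the snippet stands alone). The example URLs are illustrative.

```python
def http_to_ws_url(http_url: str) -> str:
    """Convert an http(s) URL to the matching ws(s) URL."""
    if http_url.startswith("https://"):
        return "wss://" + http_url[len("https://"):]
    if http_url.startswith("http://"):
        return "ws://" + http_url[len("http://"):]
    return http_url  # already ws(s) or an unknown scheme: leave untouched


# Strip a trailing slash first so the path is not doubled up
base = http_to_ws_url("https://api.openai.com/".rstrip("/"))
realtime_url = f"{base}/v1/realtime?intent=transcription"

local_url = http_to_ws_url("http://localhost:8080")
passthrough = http_to_ws_url("wss://example.test")
```

URLs that are already `ws://` or `wss://` pass through unchanged, so the helper is safe to call on a base URL that was configured with a WebSocket scheme.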
class OpenAIStreamingTranscriber(StreamingTranscriberProtocol):
"""Streaming transcription using OpenAI Realtime API."""
def __init__(
self,
api_key: str,
model: str = "whisper-1",
api_base: str | None = None,
):
# Import logger first
from onyx.utils.logger import setup_logger
self._logger = setup_logger()
self._logger.info(
f"OpenAIStreamingTranscriber: initializing with model {model}"
)
self.api_key = api_key
self.model = model
self.api_base = api_base or DEFAULT_OPENAI_API_BASE
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._session: aiohttp.ClientSession | None = None
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
self._current_turn_transcript = "" # Transcript for current VAD turn
self._accumulated_transcript = "" # Accumulated across all turns
self._receive_task: asyncio.Task | None = None
self._closed = False
async def connect(self) -> None:
"""Establish WebSocket connection to OpenAI Realtime API."""
self._session = aiohttp.ClientSession()
# OpenAI Realtime transcription endpoint
ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
url = f"{ws_base}/v1/realtime?intent=transcription"
headers = {
"Authorization": f"Bearer {self.api_key}",
"OpenAI-Beta": "realtime=v1",
}
try:
self._ws = await self._session.ws_connect(url, headers=headers)
self._logger.info("Connected to OpenAI Realtime API")
except Exception as e:
self._logger.error(f"Failed to connect to OpenAI Realtime API: {e}")
raise
# Configure the session for transcription
# Enable server-side VAD (Voice Activity Detection) for automatic speech detection
config_message = {
"type": "transcription_session.update",
"session": {
"input_audio_format": "pcm16", # 16-bit PCM at 24kHz mono
"input_audio_transcription": {
"model": self.model,
},
"turn_detection": {
"type": "server_vad",
"threshold": 0.5,
"prefix_padding_ms": 300,
"silence_duration_ms": 500,
},
},
}
await self._ws.send_str(json.dumps(config_message))
self._logger.info(f"Sent config for model: {self.model} with server VAD")
# Start receiving transcripts
self._receive_task = asyncio.create_task(self._receive_loop())
async def _receive_loop(self) -> None:
"""Background task to receive transcripts."""
if not self._ws:
return
try:
async for msg in self._ws:
if msg.type == aiohttp.WSMsgType.TEXT:
data = json.loads(msg.data)
msg_type = data.get("type", "")
self._logger.debug(f"Received message type: {msg_type}")
# Handle errors
if msg_type == OpenAIRealtimeMessageType.ERROR:
error = data.get("error", {})
self._logger.error(f"OpenAI error: {error}")
continue
# Handle VAD events
if msg_type == OpenAIRealtimeMessageType.SPEECH_STARTED:
self._logger.info("OpenAI: Speech started")
# Reset current turn transcript for new speech
self._current_turn_transcript = ""
continue
elif msg_type == OpenAIRealtimeMessageType.SPEECH_STOPPED:
self._logger.info(
"OpenAI: Speech stopped (VAD detected silence)"
)
continue
elif msg_type == OpenAIRealtimeMessageType.BUFFER_COMMITTED:
self._logger.info("OpenAI: Audio buffer committed")
continue
# Handle transcription events
if msg_type == OpenAIRealtimeMessageType.TRANSCRIPTION_DELTA:
delta = data.get("delta", "")
if delta:
self._logger.info(f"OpenAI: Transcription delta: {delta}")
self._current_turn_transcript += delta
# Show accumulated + current turn transcript
full_transcript = self._accumulated_transcript
if full_transcript and self._current_turn_transcript:
full_transcript += " "
full_transcript += self._current_turn_transcript
await self._transcript_queue.put(
TranscriptResult(text=full_transcript, is_vad_end=False)
)
elif msg_type == OpenAIRealtimeMessageType.TRANSCRIPTION_COMPLETED:
transcript = data.get("transcript", "")
if transcript:
self._logger.info(
f"OpenAI: Transcription completed (VAD turn end): {transcript[:50]}..."
)
# This is the final transcript for this VAD turn
self._current_turn_transcript = transcript
# Accumulate this turn's transcript
if self._accumulated_transcript:
self._accumulated_transcript += " " + transcript
else:
self._accumulated_transcript = transcript
# Send with is_vad_end=True to trigger auto-send
await self._transcript_queue.put(
TranscriptResult(
text=self._accumulated_transcript,
is_vad_end=True,
)
)
elif msg_type not in (
OpenAIRealtimeMessageType.SESSION_CREATED,
OpenAIRealtimeMessageType.SESSION_UPDATED,
OpenAIRealtimeMessageType.ITEM_CREATED,
):
# Log any other message types we might be missing
self._logger.info(
f"OpenAI: Unhandled message type '{msg_type}': {data}"
)
elif msg.type == aiohttp.WSMsgType.ERROR:
self._logger.error(f"WebSocket error: {self._ws.exception()}")
break
elif msg.type == aiohttp.WSMsgType.CLOSED:
self._logger.info("WebSocket closed by server")
break
except Exception as e:
self._logger.error(f"Error in receive loop: {e}")
finally:
await self._transcript_queue.put(None)
async def send_audio(self, chunk: bytes) -> None:
"""Send audio chunk to OpenAI."""
if self._ws and not self._closed:
# OpenAI expects base64-encoded PCM16 audio at 24kHz mono
# PCM16 at 24kHz: 24000 samples/sec * 2 bytes/sample = 48000 bytes/sec
# So chunk_bytes / 48000 = duration in seconds
duration_ms = (len(chunk) / 48000) * 1000
self._logger.debug(
f"Sending {len(chunk)} bytes ({duration_ms:.1f}ms) of audio to OpenAI. "
f"First 10 bytes: {chunk[:10].hex() if len(chunk) >= 10 else chunk.hex()}"
)
message = {
"type": "input_audio_buffer.append",
"audio": base64.b64encode(chunk).decode("utf-8"),
}
await self._ws.send_str(json.dumps(message))
def reset_transcript(self) -> None:
"""Reset accumulated transcript. Call after auto-send to start fresh."""
self._logger.info("OpenAI: Resetting accumulated transcript")
self._accumulated_transcript = ""
self._current_turn_transcript = ""
async def receive_transcript(self) -> TranscriptResult | None:
"""Receive next transcript."""
try:
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return TranscriptResult(text="", is_vad_end=False)
async def close(self) -> str:
"""Close session and return final transcript."""
self._closed = True
if self._ws:
# With server VAD, the buffer is auto-committed when speech stops.
# But we should still commit any remaining audio and wait for transcription.
try:
await self._ws.send_str(
json.dumps({"type": "input_audio_buffer.commit"})
)
except Exception as e:
self._logger.debug(f"Error sending commit (may be expected): {e}")
# Wait for *new* transcription to arrive (up to 5 seconds)
self._logger.info("Waiting for transcription to complete...")
transcript_before_commit = self._accumulated_transcript
for _ in range(50): # 50 * 100ms = 5 seconds max
await asyncio.sleep(0.1)
if self._accumulated_transcript != transcript_before_commit:
self._logger.info(
f"Got final transcript: {self._accumulated_transcript[:50]}..."
)
break
else:
self._logger.warning("Timed out waiting for transcription")
await self._ws.close()
if self._receive_task:
self._receive_task.cancel()
try:
await self._receive_task
except asyncio.CancelledError:
pass
if self._session:
await self._session.close()
return self._accumulated_transcript
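Review note: the receive loop keeps two strings, one for the current VAD turn and one accumulated across turns. The sketch below isolates that bookkeeping so it can be tested without a WebSocket. `TurnAccumulator` is a hypothetical name; the method bodies mirror the delta and completed branches in the diff.

```python
from dataclasses import dataclass


@dataclass
class TurnAccumulator:
    """Tracks the per-turn and cross-turn transcript (sketch)."""

    accumulated: str = ""
    current_turn: str = ""

    def on_speech_started(self) -> None:
        # A new VAD turn begins: discard the previous turn's partial text
        self.current_turn = ""

    def on_delta(self, delta: str) -> str:
        """Return the text to display: finished turns plus the live turn."""
        self.current_turn += delta
        if self.accumulated and self.current_turn:
            return f"{self.accumulated} {self.current_turn}"
        return self.accumulated + self.current_turn

    def on_completed(self, transcript: str) -> str:
        """Fold the final text for this turn into the accumulated transcript."""
        if not transcript:
            return self.accumulated
        self.current_turn = transcript
        if self.accumulated:
            self.accumulated += " " + transcript
        else:
            self.accumulated = transcript
        return self.accumulated


acc = TurnAccumulator()
acc.on_speech_started()
acc.on_delta("hello")
partial_1 = acc.on_delta(" world")
final_1 = acc.on_completed("Hello world.")

acc.on_speech_started()
partial_2 = acc.on_delta("How")
final_2 = acc.on_completed("How are you?")
```

The completed transcript replaces the delta text for the turn rather than being appended to it, which is why the final result carries the punctuated version.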
# OpenAI available voices for TTS
OPENAI_VOICES = [
{"id": "alloy", "name": "Alloy"},
{"id": "echo", "name": "Echo"},
{"id": "fable", "name": "Fable"},
{"id": "onyx", "name": "Onyx"},
{"id": "nova", "name": "Nova"},
{"id": "shimmer", "name": "Shimmer"},
]
# OpenAI available STT models (all support streaming via Realtime API)
OPENAI_STT_MODELS = [
{"id": "whisper-1", "name": "Whisper v1"},
{"id": "gpt-4o-transcribe", "name": "GPT-4o Transcribe"},
{"id": "gpt-4o-mini-transcribe", "name": "GPT-4o Mini Transcribe"},
]
# OpenAI available TTS models
OPENAI_TTS_MODELS = [
{"id": "tts-1", "name": "TTS-1 (Standard)"},
{"id": "tts-1-hd", "name": "TTS-1 HD (High Quality)"},
]
def _create_wav_header(
data_length: int,
sample_rate: int = 24000,
channels: int = 1,
bits_per_sample: int = 16,
) -> bytes:
"""Create a WAV file header for PCM audio data."""
import struct
byte_rate = sample_rate * channels * bits_per_sample // 8
block_align = channels * bits_per_sample // 8
# WAV header is 44 bytes
header = struct.pack(
"<4sI4s4sIHHIIHH4sI",
b"RIFF", # ChunkID
36 + data_length, # ChunkSize
b"WAVE", # Format
b"fmt ", # Subchunk1ID
16, # Subchunk1Size (PCM)
1, # AudioFormat (1 = PCM)
channels, # NumChannels
sample_rate, # SampleRate
byte_rate, # ByteRate
block_align, # BlockAlign
bits_per_sample, # BitsPerSample
b"data", # Subchunk2ID
data_length, # Subchunk2Size
)
return header
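Review note: the 44-byte header can be validated by handing the result to the standard-library `wave` module, which will refuse malformed RIFF data. This sketch repeats the helper under the name `create_wav_header` so it runs on its own; the one second of silence is illustrative input.

```python
import io
import struct
import wave


def create_wav_header(
    data_length: int,
    sample_rate: int = 24000,
    channels: int = 1,
    bits_per_sample: int = 16,
) -> bytes:
    """Build a canonical 44-byte PCM WAV header."""
    byte_rate = sample_rate * channels * bits_per_sample // 8
    block_align = channels * bits_per_sample // 8
    return struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF", 36 + data_length, b"WAVE",
        b"fmt ", 16, 1, channels, sample_rate,
        byte_rate, block_align, bits_per_sample,
        b"data", data_length,
    )


# One second of silence: 24000 samples * 2 bytes per sample
pcm = bytes(48000)
header = create_wav_header(len(pcm))

# Parse the result back with the stdlib reader
with wave.open(io.BytesIO(header + pcm), "rb") as wav:
    params = (
        wav.getnchannels(),
        wav.getsampwidth(),
        wav.getframerate(),
        wav.getnframes(),
    )
```

A round trip like this catches field-order or size mistakes in the `struct` format string that are easy to miss by eye.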
class OpenAIStreamingSynthesizer(StreamingSynthesizerProtocol):
"""Streaming TTS using OpenAI HTTP TTS API with streaming responses."""
def __init__(
self,
api_key: str,
voice: str = "alloy",
model: str = "tts-1",
speed: float = 1.0,
api_base: str | None = None,
):
from onyx.utils.logger import setup_logger
self._logger = setup_logger()
self.api_key = api_key
self.voice = voice
self.model = model
self.speed = max(0.25, min(4.0, speed))
self.api_base = api_base or DEFAULT_OPENAI_API_BASE
self._session: aiohttp.ClientSession | None = None
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
self._text_queue: asyncio.Queue[str | None] = asyncio.Queue()
self._synthesis_task: asyncio.Task | None = None
self._closed = False
self._flushed = False
async def connect(self) -> None:
"""Initialize HTTP session for TTS requests."""
self._logger.info("OpenAIStreamingSynthesizer: connecting")
self._session = aiohttp.ClientSession()
# Start background task to process text queue
self._synthesis_task = asyncio.create_task(self._process_text_queue())
self._logger.info("OpenAIStreamingSynthesizer: connected")
async def _process_text_queue(self) -> None:
"""Background task to process queued text for synthesis."""
while not self._closed:
try:
text = await asyncio.wait_for(self._text_queue.get(), timeout=0.1)
if text is None:
break
await self._synthesize_text(text)
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
break
except Exception as e:
self._logger.error(f"Error processing text queue: {e}")
async def _synthesize_text(self, text: str) -> None:
"""Make HTTP TTS request and stream audio to queue."""
if not self._session or self._closed:
return
url = f"{self.api_base.rstrip('/')}/v1/audio/speech"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
payload = {
"model": self.model,
"voice": self.voice,
"input": text,
"speed": self.speed,
"response_format": "mp3",
}
try:
async with self._session.post(
url, headers=headers, json=payload
) as response:
if response.status != 200:
error_text = await response.text()
self._logger.error(f"OpenAI TTS error: {error_text}")
return
# Use 8192 byte chunks for smoother streaming
# (larger chunks = more complete MP3 frames, better playback)
async for chunk in response.content.iter_chunked(8192):
if self._closed:
break
if chunk:
await self._audio_queue.put(chunk)
except Exception as e:
self._logger.error(f"OpenAIStreamingSynthesizer synthesis error: {e}")
async def send_text(self, text: str) -> None:
"""Queue text to be synthesized via HTTP streaming."""
if not text.strip() or self._closed:
return
await self._text_queue.put(text)
async def receive_audio(self) -> bytes | None:
"""Receive next audio chunk (MP3 format)."""
try:
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return b"" # No audio yet, but not done
async def flush(self) -> None:
"""Signal end of text input - wait for synthesis to complete."""
if self._flushed:
return
self._flushed = True
# Signal end of text input
await self._text_queue.put(None)
# Wait for synthesis task to complete processing all text
if self._synthesis_task and not self._synthesis_task.done():
try:
await asyncio.wait_for(self._synthesis_task, timeout=60.0)
except asyncio.TimeoutError:
self._logger.warning("OpenAIStreamingSynthesizer: flush timeout")
self._synthesis_task.cancel()
try:
await self._synthesis_task
except asyncio.CancelledError:
pass
except asyncio.CancelledError:
pass
# Signal end of audio stream
await self._audio_queue.put(None)
async def close(self) -> None:
"""Close the session."""
if self._closed:
return
self._closed = True
# Signal end of queues only if flush wasn't already called
if not self._flushed:
await self._text_queue.put(None)
await self._audio_queue.put(None)
if self._synthesis_task and not self._synthesis_task.done():
self._synthesis_task.cancel()
try:
await self._synthesis_task
except asyncio.CancelledError:
pass
if self._session:
await self._session.close()
class OpenAIVoiceProvider(VoiceProviderInterface):
"""OpenAI voice provider using Whisper for STT and TTS API for speech synthesis."""
def __init__(
self,
api_key: str | None,
api_base: str | None = None,
stt_model: str | None = None,
tts_model: str | None = None,
default_voice: str | None = None,
):
self.api_key = api_key
self.api_base = api_base
self.stt_model = stt_model or "whisper-1"
self.tts_model = tts_model or "tts-1"
self.default_voice = default_voice or "alloy"
self._client: "AsyncOpenAI | None" = None
def _get_client(self) -> "AsyncOpenAI":
if self._client is None:
from openai import AsyncOpenAI
self._client = AsyncOpenAI(
api_key=self.api_key,
base_url=self.api_base,
)
return self._client
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
"""
Transcribe audio using OpenAI Whisper.
Args:
audio_data: Raw audio bytes
audio_format: Audio format (e.g., "webm", "wav", "mp3")
Returns:
Transcribed text
"""
client = self._get_client()
# Create a file-like object from the audio bytes
audio_file = io.BytesIO(audio_data)
audio_file.name = f"audio.{audio_format}"
response = await client.audio.transcriptions.create(
model=self.stt_model,
file=audio_file,
)
return response.text
async def synthesize_stream(
self, text: str, voice: str | None = None, speed: float = 1.0
) -> AsyncIterator[bytes]:
"""
Convert text to audio using OpenAI TTS with streaming.
Args:
text: Text to convert to speech
voice: Voice identifier (defaults to provider's default voice)
speed: Playback speed multiplier (0.25 to 4.0)
Yields:
Audio data chunks (mp3 format)
"""
client = self._get_client()
# Clamp speed to valid range
speed = max(0.25, min(4.0, speed))
# Use with_streaming_response for proper async streaming
# Using 8192 byte chunks for better streaming performance
# (larger chunks = fewer round-trips, more complete MP3 frames)
async with client.audio.speech.with_streaming_response.create(
model=self.tts_model,
voice=voice or self.default_voice,
input=text,
speed=speed,
response_format="mp3",
) as response:
async for chunk in response.iter_bytes(chunk_size=8192):
yield chunk
async def validate_credentials(self) -> None:
"""Validate OpenAI API key by listing models."""
client = self._get_client()
await client.models.list()
def get_available_voices(self) -> list[dict[str, str]]:
"""Get available OpenAI TTS voices."""
return OPENAI_VOICES.copy()
def get_available_stt_models(self) -> list[dict[str, str]]:
"""Get available OpenAI STT models."""
return OPENAI_STT_MODELS.copy()
def get_available_tts_models(self) -> list[dict[str, str]]:
"""Get available OpenAI TTS models."""
return OPENAI_TTS_MODELS.copy()
def supports_streaming_stt(self) -> bool:
"""OpenAI supports streaming via Realtime API for all STT models."""
return True
def supports_streaming_tts(self) -> bool:
"""OpenAI supports streaming TTS via the HTTP speech endpoint."""
return True
async def create_streaming_transcriber(
self, _audio_format: str = "webm"
) -> OpenAIStreamingTranscriber:
"""Create a streaming transcription session using Realtime API."""
if not self.api_key:
raise ValueError("API key required for streaming transcription")
transcriber = OpenAIStreamingTranscriber(
api_key=self.api_key,
model=self.stt_model,
api_base=self.api_base,
)
await transcriber.connect()
return transcriber
async def create_streaming_synthesizer(
self, voice: str | None = None, speed: float = 1.0
) -> OpenAIStreamingSynthesizer:
"""Create a streaming TTS session using HTTP streaming API."""
if not self.api_key:
raise ValueError("API key required for streaming TTS")
synthesizer = OpenAIStreamingSynthesizer(
api_key=self.api_key,
voice=voice or self.default_voice or "alloy",
model=self.tts_model or "tts-1",
speed=speed,
api_base=self.api_base,
)
await synthesizer.connect()
return synthesizer
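
The speed clamping in synthesize_stream can be checked in isolation. This is a minimal sketch that assumes only the 0.25 to 4.0 range stated in the docstring; `clamp_speed`, `MIN_SPEED`, and `MAX_SPEED` are hypothetical names for illustration, since the provider inlines the expression.

```python
# Hypothetical helper; the provider inlines this expression in synthesize_stream.
MIN_SPEED = 0.25  # lower bound taken from the docstring
MAX_SPEED = 4.0  # upper bound taken from the docstring


def clamp_speed(speed: float) -> float:
    """Return speed limited to the [MIN_SPEED, MAX_SPEED] range."""
    return max(MIN_SPEED, min(MAX_SPEED, speed))
```

Out-of-range values are pulled to the nearest bound rather than rejected, so a caller passing an unsupported speed still gets audio back.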


@@ -67,6 +67,8 @@ attrs==25.4.0
# zeep
authlib==1.6.7
# via fastmcp
azure-cognitiveservices-speech==1.38.0
# via onyx
babel==2.17.0
# via courlan
backoff==2.2.1
@@ -750,7 +752,7 @@ pypandoc-binary==1.16.2
# via onyx
pyparsing==3.2.5
# via httplib2
pypdf==6.7.5
pypdf==6.8.0
# via
# onyx
# unstructured-client


@@ -406,7 +406,7 @@ referencing==0.36.2
# jsonschema-specifications
regex==2025.11.3
# via tiktoken
release-tag==0.4.3
release-tag==0.5.2
# via onyx
reorder-python-imports-black==3.14.0
# via onyx


@@ -19,7 +19,7 @@ from fastapi.testclient import TestClient
from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import UserRole
from onyx.main import fetch_versioned_implementation
from onyx.main import get_application
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -51,11 +51,8 @@ def client() -> Generator[TestClient, None, None]:
# Patch out prometheus metrics setup to avoid "Duplicated timeseries in
# CollectorRegistry" errors when multiple tests each create a new app
# (prometheus registers metrics globally and rejects duplicate names).
get_app = fetch_versioned_implementation(
module="onyx.main", attribute="get_application"
)
with patch("onyx.main.setup_prometheus_metrics"):
app: FastAPI = get_app(lifespan_override=test_lifespan)
app: FastAPI = get_application(lifespan_override=test_lifespan)
# Override the database session dependency with a mock
# (these tests don't actually need DB access)


@@ -0,0 +1,398 @@
"""External dependency tests for the old DocumentIndex interface.
These tests assume Vespa and OpenSearch are running.
TODO(ENG-3764)(andrei): Consolidate some of these test fixtures.
"""
import os
import time
import uuid
from collections.abc import Generator
from unittest.mock import patch
import httpx
import pytest
from onyx.access.models import DocumentAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.context.search.models import IndexFilters
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import IndexBatchParams
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchOldDocumentIndex,
)
from onyx.document_index.vespa.index import VespaIndex
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
from tests.external_dependency_unit.constants import TEST_TENANT_ID
@pytest.fixture(scope="module")
def opensearch_available() -> Generator[None, None, None]:
"""Verifies OpenSearch is running, fails the test if not."""
if not wait_for_opensearch_with_timeout():
pytest.fail("OpenSearch is not available.")
yield # Test runs here.
@pytest.fixture(scope="module")
def test_index_name() -> Generator[str, None, None]:
yield f"test_index_{uuid.uuid4().hex[:8]}" # Test runs here.
@pytest.fixture(scope="module")
def tenant_context() -> Generator[None, None, None]:
"""Sets up tenant context for testing."""
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
try:
yield # Test runs here.
finally:
# Reset the tenant context after the test
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@pytest.fixture(scope="module")
def httpx_client() -> Generator[httpx.Client, None, None]:
client = get_vespa_http_client()
try:
yield client
finally:
client.close()
@pytest.fixture(scope="module")
def vespa_document_index(
httpx_client: httpx.Client,
tenant_context: None, # noqa: ARG001
test_index_name: str,
) -> Generator[VespaIndex, None, None]:
vespa_index = VespaIndex(
index_name=test_index_name,
secondary_index_name=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=None,
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
)
backend_dir = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "..", "..")
)
with patch("os.getcwd", return_value=backend_dir):
vespa_index.ensure_indices_exist(
primary_embedding_dim=128,
primary_embedding_precision=EmbeddingPrecision.FLOAT,
secondary_index_embedding_dim=None,
secondary_index_embedding_precision=None,
)
# Verify Vespa is running, fails the test if not. Try 90 seconds for testing
# in CI. We have to do this here because this endpoint only becomes live
# once we create an index.
if not wait_for_vespa_with_timeout(wait_limit=90):
pytest.fail("Vespa is not available.")
# Wait until the schema is actually ready for writes on content nodes. We
# probe by attempting a document put (an HTTP POST here); 200 means the
# schema is live, 400 means not yet. This is so scuffed but running the test
# is really flaky otherwise; this is only temporary until we entirely move
# off of Vespa.
probe_doc = {
"fields": {
"document_id": "__probe__",
"chunk_id": 0,
"blurb": "",
"title": "",
"skip_title": True,
"content": "",
"content_summary": "",
"source_type": "file",
"source_links": "null",
"semantic_identifier": "",
"section_continuation": False,
"large_chunk_reference_ids": [],
"metadata": "{}",
"metadata_list": [],
"metadata_suffix": "",
"chunk_context": "",
"doc_summary": "",
"embeddings": {"full_chunk": [1.0] + [0.0] * 127},
"access_control_list": {},
"document_sets": {},
"image_file_name": None,
"user_project": [],
"personas": [],
"boost": 0.0,
"aggregated_chunk_boost_factor": 0.0,
"primary_owners": [],
"secondary_owners": [],
}
}
schema_ready = False
probe_url = (
f"http://localhost:8081/document/v1/default/{test_index_name}/docid/__probe__"
)
for _ in range(60):
resp = httpx_client.post(probe_url, json=probe_doc)
if resp.status_code == 200:
schema_ready = True
# Clean up the probe document.
httpx_client.delete(probe_url)
break
time.sleep(1)
if not schema_ready:
pytest.fail(f"Vespa schema '{test_index_name}' did not become ready in time.")
yield vespa_index # Test runs here.
# TODO(ENG-3765)(andrei): Explicitly cleanup index. Not immediately
# pressing; in CI we should be using fresh instances of dependencies each
# time anyway.
@pytest.fixture(scope="module")
def opensearch_document_index(
opensearch_available: None, # noqa: ARG001
tenant_context: None, # noqa: ARG001
test_index_name: str,
) -> Generator[OpenSearchOldDocumentIndex, None, None]:
opensearch_index = OpenSearchOldDocumentIndex(
index_name=test_index_name,
embedding_dim=128,
embedding_precision=EmbeddingPrecision.FLOAT,
secondary_index_name=None,
secondary_embedding_dim=None,
secondary_embedding_precision=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=None,
multitenant=MULTI_TENANT,
)
opensearch_index.ensure_indices_exist(
primary_embedding_dim=128,
primary_embedding_precision=EmbeddingPrecision.FLOAT,
secondary_index_embedding_dim=None,
secondary_index_embedding_precision=None,
)
yield opensearch_index # Test runs here.
# TODO(ENG-3765)(andrei): Explicitly cleanup index. Not immediately
# pressing; in CI we should be using fresh instances of dependencies each
# time anyway.
@pytest.fixture(scope="module")
def document_indices(
vespa_document_index: VespaIndex,
opensearch_document_index: OpenSearchOldDocumentIndex,
) -> Generator[list[DocumentIndex], None, None]:
# Ideally these are parametrized; doing so with pytest fixtures is tricky.
yield [opensearch_document_index, vespa_document_index] # Test runs here.
@pytest.fixture(scope="function")
def chunks(
tenant_context: None, # noqa: ARG001
) -> Generator[list[DocMetadataAwareIndexChunk], None, None]:
result = []
chunk_count = 5
doc_id = "test_doc"
tenant_id = get_current_tenant_id()
access = DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=True,
)
document_sets: set[str] = set()
user_project: list[int] = list()
personas: list[int] = list()
boost = 0
blurb = "blurb"
content = "content"
title_prefix = ""
doc_summary = ""
chunk_context = ""
title_embedding = [1.0] + [0] * 127
# Full 0 vectors are not supported for cos similarity.
embeddings = ChunkEmbedding(
full_embedding=[1.0] + [0] * 127, mini_chunk_embeddings=[]
)
source_document = Document(
id=doc_id,
semantic_identifier="semantic identifier",
source=DocumentSource.FILE,
sections=[],
metadata={},
title="title",
)
metadata_suffix_keyword = ""
image_file_id = None
source_links: dict[int, str] = {0: ""}
ancestor_hierarchy_node_ids: list[int] = []
for i in range(chunk_count):
result.append(
DocMetadataAwareIndexChunk(
tenant_id=tenant_id,
access=access,
document_sets=document_sets,
user_project=user_project,
personas=personas,
boost=boost,
aggregated_chunk_boost_factor=0,
ancestor_hierarchy_node_ids=ancestor_hierarchy_node_ids,
embeddings=embeddings,
title_embedding=title_embedding,
source_document=source_document,
title_prefix=title_prefix,
metadata_suffix_keyword=metadata_suffix_keyword,
metadata_suffix_semantic="",
contextual_rag_reserved_tokens=0,
doc_summary=doc_summary,
chunk_context=chunk_context,
mini_chunk_texts=None,
large_chunk_id=None,
chunk_id=i,
blurb=blurb,
content=content,
source_links=source_links,
image_file_id=image_file_id,
section_continuation=False,
)
)
yield result # Test runs here.
@pytest.fixture(scope="function")
def index_batch_params(
tenant_context: None, # noqa: ARG001
) -> Generator[IndexBatchParams, None, None]:
# WARNING: doc_id_to_previous_chunk_cnt={"test_doc": 0} is hardcoded to 0,
# which is only correct on the very first index call. The document_indices
# fixture is scope="module", meaning the same OpenSearch and Vespa backends
# persist across all test functions in this module. When a second test
# function uses this fixture and calls document_index.index(...), the
# backend already has 5 chunks for "test_doc" from the previous test run,
# but the batch params still claim 0 prior chunks exist. This can lead to
# orphaned/duplicate chunks that make subsequent assertions incorrect.
# TODO: Whenever adding a second test, either change this or cleanup the
# index between test cases.
yield IndexBatchParams(
doc_id_to_previous_chunk_cnt={"test_doc": 0},
doc_id_to_new_chunk_cnt={"test_doc": 5},
tenant_id=get_current_tenant_id(),
large_chunks_enabled=False,
)
class TestDocumentIndexOld:
"""Tests the old DocumentIndex interface."""
def test_update_single_can_clear_user_projects_and_personas(
self,
document_indices: list[DocumentIndex],
# This test case assumes all these chunks correspond to one document.
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> None:
"""
Tests that update_single can clear user_projects and personas.
"""
for document_index in document_indices:
# Precondition.
# Ensure there is some non-empty value for user project and
# personas.
for chunk in chunks:
chunk.user_project = [1]
chunk.personas = [2]
document_index.index(chunks, index_batch_params)
# Ensure that we can get chunks as expected with filters.
doc_id = chunks[0].source_document.id
chunk_count = len(chunks)
tenant_id = get_current_tenant_id()
# We need to specify the chunk index range and specify
# batch_retrieval=True below to trigger the codepath for Vespa's
# search API, which uses the expected additive filtering for
# project_id and persona_id. Otherwise we would use the codepath for
# the visit API, which does not have this kind of filtering
# implemented.
chunk_request = VespaChunkRequest(
document_id=doc_id, min_chunk_ind=0, max_chunk_ind=chunk_count - 1
)
project_persona_filters = IndexFilters(
access_control_list=None,
tenant_id=tenant_id,
project_id=1,
persona_id=2,
# We need this even though none of the chunks belong to a
# document set because project_id and persona_id are only
# additive filters in the event the agent has knowledge scope;
# if the agent does not, it is implied that it can see
# everything it is allowed to.
document_set=["1"],
)
# Not best practice here but the API for refreshing the index to
# ensure that the latest data is present is not exposed in this
# class and is not the same for Vespa and OpenSearch, so we just
# tolerate a sleep for now. As a consequence the number of tests in
# this suite should be small. We only need to tolerate this for as
# long as we continue to use Vespa; we can consider exposing
# something for OpenSearch later.
time.sleep(1)
inference_chunks = document_index.id_based_retrieval(
chunk_requests=[chunk_request],
filters=project_persona_filters,
batch_retrieval=True,
)
assert len(inference_chunks) == chunk_count
# Sort by chunk id to easily test if we have all chunks.
for i, inference_chunk in enumerate(
sorted(inference_chunks, key=lambda x: x.chunk_id)
):
assert inference_chunk.chunk_id == i
assert inference_chunk.document_id == doc_id
# Under test.
# Explicitly set empty fields here.
user_fields = VespaDocumentUserFields(user_projects=[], personas=[])
document_index.update_single(
doc_id=doc_id,
chunk_count=chunk_count,
tenant_id=tenant_id,
fields=None,
user_fields=user_fields,
)
# Postcondition.
filters = IndexFilters(access_control_list=None, tenant_id=tenant_id)
# We should expect to get back all expected chunks with no filters.
# Again, not best practice here.
time.sleep(1)
inference_chunks = document_index.id_based_retrieval(
chunk_requests=[chunk_request], filters=filters, batch_retrieval=True
)
assert len(inference_chunks) == chunk_count
# Sort by chunk id to easily test if we have all chunks.
for i, inference_chunk in enumerate(
sorted(inference_chunks, key=lambda x: x.chunk_id)
):
assert inference_chunk.chunk_id == i
assert inference_chunk.document_id == doc_id
# Now, we should expect to not get any chunks if we specify the user
# project and personas filters.
inference_chunks = document_index.id_based_retrieval(
chunk_requests=[chunk_request],
filters=project_persona_filters,
batch_retrieval=True,
)
assert len(inference_chunks) == 0
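
The schema-readiness probe in the vespa_document_index fixture is a bounded polling loop. The sketch below shows the same pattern as a standalone helper; `wait_until_ready` and its parameters are hypothetical names and not part of the Onyx codebase. Unlike the fixture, it skips the sleep after the final failed attempt.

```python
import time
from collections.abc import Callable


def wait_until_ready(
    probe: Callable[[], bool],
    attempts: int = 60,
    delay_seconds: float = 1.0,
) -> bool:
    """Call probe up to `attempts` times, sleeping between failed tries.

    Returns True as soon as probe() succeeds, False if every attempt fails.
    """
    for attempt in range(attempts):
        if probe():
            return True
        # Skip the sleep after the final failed attempt.
        if attempt < attempts - 1:
            time.sleep(delay_seconds)
    return False
```

In the fixture, the probe would wrap the POST to the probe URL and return whether the response status is 200.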


@@ -17,6 +17,9 @@ from unittest.mock import patch
import pytest
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.opensearch_migration.constants import (
GET_VESPA_CHUNKS_SLICE_COUNT,
)
from onyx.background.celery.tasks.opensearch_migration.tasks import (
is_continuation_token_done_for_all_slices,
)
@@ -236,6 +239,8 @@ def full_deployment_setup() -> Generator[None, None, None]:
NOTE: We deliberately duplicate this logic from
backend/tests/external_dependency_unit/conftest.py because we need to set
opensearch_available just for this module, not the entire test session.
TODO(ENG-3764)(andrei): Consolidate some of these test fixtures.
"""
# Patch ENABLE_OPENSEARCH_INDEXING_FOR_ONYX just for this test because we
# don't yet want that enabled for all tests.
@@ -320,9 +325,15 @@ def test_embedding_dimension(db_session: Session) -> Generator[int, None, None]:
@pytest.fixture(scope="function")
def patch_get_vespa_chunks_page_size() -> Generator[int, None, None]:
test_page_size = 5
with patch(
"onyx.background.celery.tasks.opensearch_migration.tasks.GET_VESPA_CHUNKS_PAGE_SIZE",
test_page_size,
with (
patch(
"onyx.background.celery.tasks.opensearch_migration.tasks.GET_VESPA_CHUNKS_PAGE_SIZE",
test_page_size,
),
patch(
"onyx.background.celery.tasks.opensearch_migration.constants.GET_VESPA_CHUNKS_PAGE_SIZE",
test_page_size,
),
):
yield test_page_size # Test runs here.
@@ -582,6 +593,175 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
document_chunks[document.id][opensearch_chunk.chunk_index],
)
def test_chunk_migration_visits_all_chunks_even_when_batch_size_varies(
self,
db_session: Session,
test_documents: list[Document],
vespa_document_index: VespaDocumentIndex,
opensearch_client: OpenSearchIndexClient,
test_embedding_dimension: int,
clean_migration_tables: None, # noqa: ARG002
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
) -> None:
"""
Tests that chunk migration works correctly even when the batch size
changes halfway through a migration.
Simulates task time running out by mocking the locking behavior.
"""
# Precondition.
# Index chunks into Vespa.
document_chunks: dict[str, list[dict[str, Any]]] = {
document.id: [
_create_raw_document_chunk(
document_id=document.id,
chunk_index=i,
content=f"Test content {i} for {document.id}",
embedding=_generate_test_vector(test_embedding_dimension),
now=datetime.now(),
title=f"Test title {document.id}",
title_embedding=_generate_test_vector(test_embedding_dimension),
)
for i in range(CHUNK_COUNT)
]
for document in test_documents
}
all_chunks: list[dict[str, Any]] = []
for chunks in document_chunks.values():
all_chunks.extend(chunks)
vespa_document_index.index_raw_chunks(all_chunks)
# Run the initial batch. To simulate partial progress we will mock the
# redis lock to return True for the first invocation of .owned() and
# False subsequently.
# NOTE: The batch size is currently set to 5 in
# patch_get_vespa_chunks_page_size.
mock_redis_client = Mock()
mock_lock = Mock()
mock_lock.owned.side_effect = [True, False, False]
mock_lock.acquire.return_value = True
mock_redis_client.lock.return_value = mock_lock
with patch(
"onyx.background.celery.tasks.opensearch_migration.tasks.get_redis_client",
return_value=mock_redis_client,
):
result_1 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
)
assert result_1 is True
# Expire the session cache to see the committed changes from the task.
db_session.expire_all()
# Verify partial progress was saved.
tenant_record = db_session.query(OpenSearchTenantMigrationRecord).first()
assert tenant_record is not None
partial_chunks_migrated = tenant_record.total_chunks_migrated
assert partial_chunks_migrated > 0
# page_size applies per slice, so one iteration can fetch up to
# page_size * GET_VESPA_CHUNKS_SLICE_COUNT chunks total.
assert partial_chunks_migrated <= 5 * GET_VESPA_CHUNKS_SLICE_COUNT
assert tenant_record.vespa_visit_continuation_token is not None
# Slices are not necessarily evenly distributed across all document
# chunks so we can't test that every token is non-None, but certainly at
# least one must be.
assert any(json.loads(tenant_record.vespa_visit_continuation_token).values())
assert tenant_record.migration_completed_at is None
assert tenant_record.approx_chunk_count_in_vespa is not None
# Under test.
# Now patch the batch size to be some other number, like 2.
mock_redis_client = Mock()
mock_lock = Mock()
mock_lock.owned.side_effect = [True, False, False]
mock_lock.acquire.return_value = True
mock_redis_client.lock.return_value = mock_lock
with (
patch(
"onyx.background.celery.tasks.opensearch_migration.tasks.GET_VESPA_CHUNKS_PAGE_SIZE",
2,
),
patch(
"onyx.background.celery.tasks.opensearch_migration.constants.GET_VESPA_CHUNKS_PAGE_SIZE",
2,
),
patch(
"onyx.background.celery.tasks.opensearch_migration.tasks.get_redis_client",
return_value=mock_redis_client,
),
):
result_2 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
)
# Postcondition.
assert result_2 is True
# Expire the session cache to see the committed changes from the task.
db_session.expire_all()
# Verify next partial progress was saved.
tenant_record = db_session.query(OpenSearchTenantMigrationRecord).first()
assert tenant_record is not None
new_partial_chunks_migrated = tenant_record.total_chunks_migrated
assert new_partial_chunks_migrated > partial_chunks_migrated
# page_size applies per slice, so one iteration can fetch up to
# page_size * GET_VESPA_CHUNKS_SLICE_COUNT chunks total.
assert new_partial_chunks_migrated <= (5 + 2) * GET_VESPA_CHUNKS_SLICE_COUNT
assert tenant_record.vespa_visit_continuation_token is not None
# Slices are not necessarily evenly distributed across all document
# chunks so we can't test that every token is non-None, but certainly at
# least one must be.
assert any(json.loads(tenant_record.vespa_visit_continuation_token).values())
assert tenant_record.migration_completed_at is None
assert tenant_record.approx_chunk_count_in_vespa is not None
# Under test.
# Run the remainder of the migration.
with (
patch(
"onyx.background.celery.tasks.opensearch_migration.tasks.GET_VESPA_CHUNKS_PAGE_SIZE",
2,
),
patch(
"onyx.background.celery.tasks.opensearch_migration.constants.GET_VESPA_CHUNKS_PAGE_SIZE",
2,
),
):
result_3 = migrate_chunks_from_vespa_to_opensearch_task(
tenant_id=get_current_tenant_id()
)
# Postcondition.
assert result_3 is True
# Expire the session cache to see the committed changes from the task.
db_session.expire_all()
# Verify completion.
tenant_record = db_session.query(OpenSearchTenantMigrationRecord).first()
assert tenant_record is not None
assert tenant_record.total_chunks_migrated > new_partial_chunks_migrated
assert tenant_record.total_chunks_migrated == len(all_chunks)
# Visit is complete, so the continuation token should still be present but
# marked done for all slices.
assert tenant_record.vespa_visit_continuation_token is not None
assert is_continuation_token_done_for_all_slices(
json.loads(tenant_record.vespa_visit_continuation_token)
)
assert tenant_record.migration_completed_at is not None
assert tenant_record.approx_chunk_count_in_vespa == len(all_chunks)
# Verify chunks were indexed in OpenSearch.
for document in test_documents:
opensearch_chunks = _get_document_chunks_from_opensearch(
opensearch_client, document.id, get_current_tenant_id()
)
assert len(opensearch_chunks) == CHUNK_COUNT
opensearch_chunks.sort(key=lambda x: x.chunk_index)
for opensearch_chunk in opensearch_chunks:
_assert_chunk_matches_vespa_chunk(
opensearch_chunk,
document_chunks[document.id][opensearch_chunk.chunk_index],
)
def test_chunk_migration_empty_vespa(
self,
db_session: Session,
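
The partial-progress simulation in the test above relies on giving a Mock's `side_effect` a list, so successive calls return successive values. The sketch below assumes a worker that checks `owned()` before each batch; `run_batches` is a hypothetical stand-in, and the real task's loop is not shown in this diff.

```python
from unittest.mock import Mock

# A stand-in lock whose owned() answer changes on successive calls.
mock_lock = Mock()
mock_lock.owned.side_effect = [True, False, False]


def run_batches(lock: Mock, max_batches: int = 10) -> int:
    """Hypothetical worker loop: process one batch per owned() == True."""
    processed = 0
    for _ in range(max_batches):
        if not lock.owned():
            break  # lock lost, stop and leave the rest for the next run
        processed += 1
    return processed


batches_done = run_batches(mock_lock)
```

Under this assumed loop, exactly one batch is processed before the lock reports it is no longer owned, which is how the test forces the migration to stop early and persist a continuation token.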


@@ -6,6 +6,7 @@ Validates that:
- Crash + resume skips already-processed pages
- BFS (folder-scoped) drives process all items in one call
- 410 Gone triggers a full-resync URL in the checkpoint
- Duplicate document IDs across delta pages are deduplicated
"""
from __future__ import annotations
@@ -457,3 +458,228 @@ class TestDeltaPageFetchFailure:
assert final_cp.current_drive_name is None
assert final_cp.current_drive_id is None
assert final_cp.current_drive_delta_next_link is None
class TestDeltaDuplicateDocumentDedup:
"""The Microsoft Graph delta API can return the same item on multiple
pages. Documents already yielded should be skipped via
checkpoint.seen_document_ids."""
def test_duplicate_across_pages_is_skipped(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Item 'dup' appears on both page 1 and page 2. It should only be
yielded once."""
connector = _setup_connector(monkeypatch)
_mock_convert(monkeypatch)
call_count = 0
def fake_fetch_page(
self: SharepointConnector, # noqa: ARG001
page_url: str, # noqa: ARG001
drive_id: str, # noqa: ARG001
start: datetime | None = None, # noqa: ARG001
end: datetime | None = None, # noqa: ARG001
page_size: int = 200, # noqa: ARG001
) -> tuple[list[DriveItemData], str | None]:
nonlocal call_count
call_count += 1
if call_count == 1:
return [_make_item("a"), _make_item("dup")], "https://next2"
return [_make_item("dup"), _make_item("b")], None
monkeypatch.setattr(
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
)
checkpoint = _build_ready_checkpoint()
# Page 1: yields a, dup
gen = connector._load_from_checkpoint(
_START_TS, _END_TS, checkpoint, include_permissions=False
)
yielded, checkpoint = _consume_generator(gen)
docs = _docs_from(yielded)
assert [d.id for d in docs] == ["a", "dup"]
assert "dup" in checkpoint.seen_document_ids
# Page 2: dup should be skipped, only b yielded
gen = connector._load_from_checkpoint(
_START_TS, _END_TS, checkpoint, include_permissions=False
)
yielded, checkpoint = _consume_generator(gen)
docs = _docs_from(yielded)
assert [d.id for d in docs] == ["b"]
def test_duplicate_within_same_page_is_skipped(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""If the same item appears twice on a single delta page, only the
first occurrence should be yielded."""
connector = _setup_connector(monkeypatch)
_mock_convert(monkeypatch)
def fake_fetch_page(
self: SharepointConnector, # noqa: ARG001
page_url: str, # noqa: ARG001
drive_id: str, # noqa: ARG001
start: datetime | None = None, # noqa: ARG001
end: datetime | None = None, # noqa: ARG001
page_size: int = 200, # noqa: ARG001
) -> tuple[list[DriveItemData], str | None]:
return [_make_item("x"), _make_item("x"), _make_item("y")], None
monkeypatch.setattr(
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
)
checkpoint = _build_ready_checkpoint()
gen = connector._load_from_checkpoint(
_START_TS, _END_TS, checkpoint, include_permissions=False
)
yielded, checkpoint = _consume_generator(gen)
docs = _docs_from(yielded)
assert [d.id for d in docs] == ["x", "y"]
def test_seen_ids_survive_checkpoint_serialization(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""seen_document_ids must survive JSON serialization so that
dedup works across crash + resume."""
connector = _setup_connector(monkeypatch)
_mock_convert(monkeypatch)
call_count = 0
def fake_fetch_page(
self: SharepointConnector, # noqa: ARG001
page_url: str, # noqa: ARG001
drive_id: str, # noqa: ARG001
start: datetime | None = None, # noqa: ARG001
end: datetime | None = None, # noqa: ARG001
page_size: int = 200, # noqa: ARG001
) -> tuple[list[DriveItemData], str | None]:
nonlocal call_count
call_count += 1
if call_count == 1:
return [_make_item("a")], "https://next2"
return [_make_item("a"), _make_item("b")], None
monkeypatch.setattr(
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
)
checkpoint = _build_ready_checkpoint()
# Page 1
gen = connector._load_from_checkpoint(
_START_TS, _END_TS, checkpoint, include_permissions=False
)
_, checkpoint = _consume_generator(gen)
assert "a" in checkpoint.seen_document_ids
# Simulate crash: round-trip through JSON
restored = SharepointConnectorCheckpoint.model_validate_json(
checkpoint.model_dump_json()
)
assert "a" in restored.seen_document_ids
# Page 2 with restored checkpoint: 'a' should be skipped
connector2 = _setup_connector(monkeypatch)
_mock_convert(monkeypatch)
monkeypatch.setattr(
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
)
gen = connector2._load_from_checkpoint(
_START_TS, _END_TS, restored, include_permissions=False
)
yielded, final_cp = _consume_generator(gen)
docs = _docs_from(yielded)
assert [d.id for d in docs] == ["b"]
def test_no_dedup_across_separate_indexing_runs(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""A fresh checkpoint (new indexing run) should have an empty
seen_document_ids, so previously-indexed docs are re-processed."""
connector = _setup_connector(monkeypatch)
_mock_convert(monkeypatch)
def fake_fetch_page(
self: SharepointConnector, # noqa: ARG001
page_url: str, # noqa: ARG001
drive_id: str, # noqa: ARG001
start: datetime | None = None, # noqa: ARG001
end: datetime | None = None, # noqa: ARG001
page_size: int = 200, # noqa: ARG001
) -> tuple[list[DriveItemData], str | None]:
return [_make_item("a")], None
monkeypatch.setattr(
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
)
# First run
cp1 = _build_ready_checkpoint()
gen = connector._load_from_checkpoint(
_START_TS, _END_TS, cp1, include_permissions=False
)
yielded, _ = _consume_generator(gen)
assert len(_docs_from(yielded)) == 1
# Second run with a fresh checkpoint — same doc should appear again
cp2 = _build_ready_checkpoint()
assert len(cp2.seen_document_ids) == 0
gen = connector._load_from_checkpoint(
_START_TS, _END_TS, cp2, include_permissions=False
)
yielded, _ = _consume_generator(gen)
assert len(_docs_from(yielded)) == 1
def test_same_id_across_drives_not_skipped(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Graph item IDs are only unique within a drive. An item in drive B
that happens to share an ID with an item already seen in drive A must
NOT be skipped."""
connector = _setup_connector(monkeypatch)
_mock_convert(monkeypatch)
def fake_fetch_page(
self: SharepointConnector, # noqa: ARG001
page_url: str, # noqa: ARG001
drive_id: str, # noqa: ARG001
start: datetime | None = None, # noqa: ARG001
end: datetime | None = None, # noqa: ARG001
page_size: int = 200, # noqa: ARG001
) -> tuple[list[DriveItemData], str | None]:
return [_make_item("shared-id")], None
monkeypatch.setattr(
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
)
checkpoint = _build_ready_checkpoint(drive_names=["DriveA", "DriveB"])
# Drive A: yields the item
gen = connector._load_from_checkpoint(
_START_TS, _END_TS, checkpoint, include_permissions=False
)
yielded, checkpoint = _consume_generator(gen)
docs = _docs_from(yielded)
assert len(docs) == 1
assert docs[0].id == "shared-id"
# seen_document_ids should have been cleared when drive A finished
assert len(checkpoint.seen_document_ids) == 0
# Drive B: same ID must be yielded again (different drive)
gen = connector._load_from_checkpoint(
_START_TS, _END_TS, checkpoint, include_permissions=False
)
yielded, checkpoint = _consume_generator(gen)
docs = _docs_from(yielded)
assert len(docs) == 1
assert docs[0].id == "shared-id"


@@ -7,6 +7,7 @@ import pytest
from onyx.db.llm import sync_model_configurations
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import SyncModelEntry
class TestSyncModelConfigurations:
@@ -25,18 +26,18 @@ class TestSyncModelConfigurations:
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
):
models = [
{
"name": "gpt-4",
"display_name": "GPT-4",
"max_input_tokens": 128000,
"supports_image_input": True,
},
{
"name": "gpt-4o",
"display_name": "GPT-4o",
"max_input_tokens": 128000,
"supports_image_input": True,
},
SyncModelEntry(
name="gpt-4",
display_name="GPT-4",
max_input_tokens=128000,
supports_image_input=True,
),
SyncModelEntry(
name="gpt-4o",
display_name="GPT-4o",
max_input_tokens=128000,
supports_image_input=True,
),
]
result = sync_model_configurations(
@@ -67,18 +68,18 @@ class TestSyncModelConfigurations:
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
):
models = [
{
"name": "gpt-4", # Existing - should be skipped
"display_name": "GPT-4",
"max_input_tokens": 128000,
"supports_image_input": True,
},
{
"name": "gpt-4o", # New - should be inserted
"display_name": "GPT-4o",
"max_input_tokens": 128000,
"supports_image_input": True,
},
SyncModelEntry(
name="gpt-4", # Existing - should be skipped
display_name="GPT-4",
max_input_tokens=128000,
supports_image_input=True,
),
SyncModelEntry(
name="gpt-4o", # New - should be inserted
display_name="GPT-4o",
max_input_tokens=128000,
supports_image_input=True,
),
]
result = sync_model_configurations(
@@ -105,12 +106,12 @@ class TestSyncModelConfigurations:
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
):
models = [
{
"name": "gpt-4", # Already exists
"display_name": "GPT-4",
"max_input_tokens": 128000,
"supports_image_input": True,
},
SyncModelEntry(
name="gpt-4", # Already exists
display_name="GPT-4",
max_input_tokens=128000,
supports_image_input=True,
),
]
result = sync_model_configurations(
@@ -131,7 +132,7 @@ class TestSyncModelConfigurations:
sync_model_configurations(
db_session=mock_session,
provider_name="nonexistent",
- models=[{"name": "model", "display_name": "Model"}],
+ models=[SyncModelEntry(name="model", display_name="Model")],
)
def test_handles_missing_optional_fields(self) -> None:
@@ -145,12 +146,12 @@ class TestSyncModelConfigurations:
with patch(
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
):
- # Model with only required fields
+ # Model with only required fields (max_input_tokens and supports_image_input default)
models = [
- {
- "name": "model-1",
- # No display_name, max_input_tokens, or supports_image_input
- },
+ SyncModelEntry(
+ name="model-1",
+ display_name="Model 1",
+ ),
]
result = sync_model_configurations(

@@ -0,0 +1,507 @@
"""Unit tests for onyx.db.voice module."""
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from onyx.db.models import VoiceProvider
from onyx.db.voice import deactivate_stt_provider
from onyx.db.voice import deactivate_tts_provider
from onyx.db.voice import delete_voice_provider
from onyx.db.voice import fetch_default_stt_provider
from onyx.db.voice import fetch_default_tts_provider
from onyx.db.voice import fetch_voice_provider_by_id
from onyx.db.voice import fetch_voice_provider_by_type
from onyx.db.voice import fetch_voice_providers
from onyx.db.voice import MAX_VOICE_PLAYBACK_SPEED
from onyx.db.voice import MIN_VOICE_PLAYBACK_SPEED
from onyx.db.voice import set_default_stt_provider
from onyx.db.voice import set_default_tts_provider
from onyx.db.voice import update_user_voice_settings
from onyx.db.voice import upsert_voice_provider
from onyx.error_handling.exceptions import OnyxError
def _make_voice_provider(
id: int = 1,
name: str = "Test Provider",
provider_type: str = "openai",
is_default_stt: bool = False,
is_default_tts: bool = False,
) -> VoiceProvider:
"""Create a VoiceProvider instance for testing."""
provider = VoiceProvider()
provider.id = id
provider.name = name
provider.provider_type = provider_type
provider.is_default_stt = is_default_stt
provider.is_default_tts = is_default_tts
provider.api_key = None
provider.api_base = None
provider.custom_config = None
provider.stt_model = None
provider.tts_model = None
provider.default_voice = None
return provider
class TestFetchVoiceProviders:
"""Tests for fetch_voice_providers."""
def test_returns_all_providers(self, mock_db_session: MagicMock) -> None:
providers = [
_make_voice_provider(id=1, name="Provider A"),
_make_voice_provider(id=2, name="Provider B"),
]
mock_db_session.scalars.return_value.all.return_value = providers
result = fetch_voice_providers(mock_db_session)
assert result == providers
mock_db_session.scalars.assert_called_once()
def test_returns_empty_list_when_no_providers(
self, mock_db_session: MagicMock
) -> None:
mock_db_session.scalars.return_value.all.return_value = []
result = fetch_voice_providers(mock_db_session)
assert result == []
class TestFetchVoiceProviderById:
"""Tests for fetch_voice_provider_by_id."""
def test_returns_provider_when_found(self, mock_db_session: MagicMock) -> None:
provider = _make_voice_provider(id=1)
mock_db_session.scalar.return_value = provider
result = fetch_voice_provider_by_id(mock_db_session, 1)
assert result is provider
mock_db_session.scalar.assert_called_once()
def test_returns_none_when_not_found(self, mock_db_session: MagicMock) -> None:
mock_db_session.scalar.return_value = None
result = fetch_voice_provider_by_id(mock_db_session, 999)
assert result is None
class TestFetchDefaultProviders:
"""Tests for fetch_default_stt_provider and fetch_default_tts_provider."""
def test_fetch_default_stt_provider_returns_provider(
self, mock_db_session: MagicMock
) -> None:
provider = _make_voice_provider(id=1, is_default_stt=True)
mock_db_session.scalar.return_value = provider
result = fetch_default_stt_provider(mock_db_session)
assert result is provider
def test_fetch_default_stt_provider_returns_none_when_no_default(
self, mock_db_session: MagicMock
) -> None:
mock_db_session.scalar.return_value = None
result = fetch_default_stt_provider(mock_db_session)
assert result is None
def test_fetch_default_tts_provider_returns_provider(
self, mock_db_session: MagicMock
) -> None:
provider = _make_voice_provider(id=1, is_default_tts=True)
mock_db_session.scalar.return_value = provider
result = fetch_default_tts_provider(mock_db_session)
assert result is provider
def test_fetch_default_tts_provider_returns_none_when_no_default(
self, mock_db_session: MagicMock
) -> None:
mock_db_session.scalar.return_value = None
result = fetch_default_tts_provider(mock_db_session)
assert result is None
class TestFetchVoiceProviderByType:
"""Tests for fetch_voice_provider_by_type."""
def test_returns_provider_when_found(self, mock_db_session: MagicMock) -> None:
provider = _make_voice_provider(id=1, provider_type="openai")
mock_db_session.scalar.return_value = provider
result = fetch_voice_provider_by_type(mock_db_session, "openai")
assert result is provider
def test_returns_none_when_not_found(self, mock_db_session: MagicMock) -> None:
mock_db_session.scalar.return_value = None
result = fetch_voice_provider_by_type(mock_db_session, "nonexistent")
assert result is None
class TestUpsertVoiceProvider:
"""Tests for upsert_voice_provider."""
def test_creates_new_provider_when_no_id(self, mock_db_session: MagicMock) -> None:
mock_db_session.flush.return_value = None
mock_db_session.refresh.return_value = None
upsert_voice_provider(
db_session=mock_db_session,
provider_id=None,
name="New Provider",
provider_type="openai",
api_key="test-key",
api_key_changed=True,
)
mock_db_session.add.assert_called_once()
mock_db_session.flush.assert_called()
added_obj = mock_db_session.add.call_args[0][0]
assert added_obj.name == "New Provider"
assert added_obj.provider_type == "openai"
def test_updates_existing_provider(self, mock_db_session: MagicMock) -> None:
existing_provider = _make_voice_provider(id=1, name="Old Name")
mock_db_session.scalar.return_value = existing_provider
mock_db_session.flush.return_value = None
mock_db_session.refresh.return_value = None
upsert_voice_provider(
db_session=mock_db_session,
provider_id=1,
name="Updated Name",
provider_type="elevenlabs",
api_key="new-key",
api_key_changed=True,
)
mock_db_session.add.assert_not_called()
assert existing_provider.name == "Updated Name"
assert existing_provider.provider_type == "elevenlabs"
def test_raises_when_provider_not_found(self, mock_db_session: MagicMock) -> None:
mock_db_session.scalar.return_value = None
with pytest.raises(OnyxError) as exc_info:
upsert_voice_provider(
db_session=mock_db_session,
provider_id=999,
name="Test",
provider_type="openai",
api_key=None,
api_key_changed=False,
)
assert "No voice provider with id 999" in str(exc_info.value)
def test_does_not_update_api_key_when_not_changed(
self, mock_db_session: MagicMock
) -> None:
existing_provider = _make_voice_provider(id=1)
existing_provider.api_key = "original-key" # type: ignore[assignment]
original_api_key = existing_provider.api_key
mock_db_session.scalar.return_value = existing_provider
mock_db_session.flush.return_value = None
mock_db_session.refresh.return_value = None
upsert_voice_provider(
db_session=mock_db_session,
provider_id=1,
name="Test",
provider_type="openai",
api_key="new-key",
api_key_changed=False,
)
# api_key should remain unchanged (same object reference)
assert existing_provider.api_key is original_api_key
def test_activates_stt_when_requested(self, mock_db_session: MagicMock) -> None:
existing_provider = _make_voice_provider(id=1)
mock_db_session.scalar.return_value = existing_provider
mock_db_session.flush.return_value = None
mock_db_session.refresh.return_value = None
mock_db_session.execute.return_value = None
upsert_voice_provider(
db_session=mock_db_session,
provider_id=1,
name="Test",
provider_type="openai",
api_key=None,
api_key_changed=False,
activate_stt=True,
)
assert existing_provider.is_default_stt is True
def test_activates_tts_when_requested(self, mock_db_session: MagicMock) -> None:
existing_provider = _make_voice_provider(id=1)
mock_db_session.scalar.return_value = existing_provider
mock_db_session.flush.return_value = None
mock_db_session.refresh.return_value = None
mock_db_session.execute.return_value = None
upsert_voice_provider(
db_session=mock_db_session,
provider_id=1,
name="Test",
provider_type="openai",
api_key=None,
api_key_changed=False,
activate_tts=True,
)
assert existing_provider.is_default_tts is True
class TestDeleteVoiceProvider:
"""Tests for delete_voice_provider."""
def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
provider = _make_voice_provider(id=1)
mock_db_session.scalar.return_value = provider
delete_voice_provider(mock_db_session, 1)
assert provider.deleted is True
mock_db_session.flush.assert_called_once()
def test_does_nothing_when_provider_not_found(
self, mock_db_session: MagicMock
) -> None:
mock_db_session.scalar.return_value = None
delete_voice_provider(mock_db_session, 999)
mock_db_session.flush.assert_not_called()
class TestSetDefaultProviders:
"""Tests for set_default_stt_provider and set_default_tts_provider."""
def test_set_default_stt_provider_deactivates_others(
self, mock_db_session: MagicMock
) -> None:
provider = _make_voice_provider(id=1)
mock_db_session.scalar.return_value = provider
mock_db_session.execute.return_value = None
mock_db_session.flush.return_value = None
mock_db_session.refresh.return_value = None
result = set_default_stt_provider(db_session=mock_db_session, provider_id=1)
mock_db_session.execute.assert_called_once()
assert result.is_default_stt is True
def test_set_default_stt_provider_raises_when_not_found(
self, mock_db_session: MagicMock
) -> None:
mock_db_session.scalar.return_value = None
with pytest.raises(OnyxError) as exc_info:
set_default_stt_provider(db_session=mock_db_session, provider_id=999)
assert "No voice provider with id 999" in str(exc_info.value)
def test_set_default_tts_provider_deactivates_others(
self, mock_db_session: MagicMock
) -> None:
provider = _make_voice_provider(id=1)
mock_db_session.scalar.return_value = provider
mock_db_session.execute.return_value = None
mock_db_session.flush.return_value = None
mock_db_session.refresh.return_value = None
result = set_default_tts_provider(db_session=mock_db_session, provider_id=1)
mock_db_session.execute.assert_called_once()
assert result.is_default_tts is True
def test_set_default_tts_provider_updates_model_when_provided(
self, mock_db_session: MagicMock
) -> None:
provider = _make_voice_provider(id=1)
mock_db_session.scalar.return_value = provider
mock_db_session.execute.return_value = None
mock_db_session.flush.return_value = None
mock_db_session.refresh.return_value = None
result = set_default_tts_provider(
db_session=mock_db_session, provider_id=1, tts_model="tts-1-hd"
)
assert result.tts_model == "tts-1-hd"
def test_set_default_tts_provider_raises_when_not_found(
self, mock_db_session: MagicMock
) -> None:
mock_db_session.scalar.return_value = None
with pytest.raises(OnyxError) as exc_info:
set_default_tts_provider(db_session=mock_db_session, provider_id=999)
assert "No voice provider with id 999" in str(exc_info.value)
class TestDeactivateProviders:
"""Tests for deactivate_stt_provider and deactivate_tts_provider."""
def test_deactivate_stt_provider_sets_false(
self, mock_db_session: MagicMock
) -> None:
provider = _make_voice_provider(id=1, is_default_stt=True)
mock_db_session.scalar.return_value = provider
mock_db_session.flush.return_value = None
mock_db_session.refresh.return_value = None
result = deactivate_stt_provider(db_session=mock_db_session, provider_id=1)
assert result.is_default_stt is False
def test_deactivate_stt_provider_raises_when_not_found(
self, mock_db_session: MagicMock
) -> None:
mock_db_session.scalar.return_value = None
with pytest.raises(OnyxError) as exc_info:
deactivate_stt_provider(db_session=mock_db_session, provider_id=999)
assert "No voice provider with id 999" in str(exc_info.value)
def test_deactivate_tts_provider_sets_false(
self, mock_db_session: MagicMock
) -> None:
provider = _make_voice_provider(id=1, is_default_tts=True)
mock_db_session.scalar.return_value = provider
mock_db_session.flush.return_value = None
mock_db_session.refresh.return_value = None
result = deactivate_tts_provider(db_session=mock_db_session, provider_id=1)
assert result.is_default_tts is False
def test_deactivate_tts_provider_raises_when_not_found(
self, mock_db_session: MagicMock
) -> None:
mock_db_session.scalar.return_value = None
with pytest.raises(OnyxError) as exc_info:
deactivate_tts_provider(db_session=mock_db_session, provider_id=999)
assert "No voice provider with id 999" in str(exc_info.value)
class TestUpdateUserVoiceSettings:
"""Tests for update_user_voice_settings."""
def test_updates_auto_send(self, mock_db_session: MagicMock) -> None:
user_id = uuid4()
update_user_voice_settings(mock_db_session, user_id, auto_send=True)
mock_db_session.execute.assert_called_once()
mock_db_session.flush.assert_called_once()
def test_updates_auto_playback(self, mock_db_session: MagicMock) -> None:
user_id = uuid4()
update_user_voice_settings(mock_db_session, user_id, auto_playback=True)
mock_db_session.execute.assert_called_once()
mock_db_session.flush.assert_called_once()
def test_updates_playback_speed_within_range(
self, mock_db_session: MagicMock
) -> None:
user_id = uuid4()
update_user_voice_settings(mock_db_session, user_id, playback_speed=1.5)
mock_db_session.execute.assert_called_once()
def test_clamps_playback_speed_to_min(self, mock_db_session: MagicMock) -> None:
user_id = uuid4()
update_user_voice_settings(mock_db_session, user_id, playback_speed=0.1)
mock_db_session.execute.assert_called_once()
stmt = mock_db_session.execute.call_args[0][0]
compiled = stmt.compile(compile_kwargs={"literal_binds": True})
assert str(MIN_VOICE_PLAYBACK_SPEED) in str(compiled)
def test_clamps_playback_speed_to_max(self, mock_db_session: MagicMock) -> None:
user_id = uuid4()
update_user_voice_settings(mock_db_session, user_id, playback_speed=5.0)
mock_db_session.execute.assert_called_once()
stmt = mock_db_session.execute.call_args[0][0]
compiled = stmt.compile(compile_kwargs={"literal_binds": True})
assert str(MAX_VOICE_PLAYBACK_SPEED) in str(compiled)
def test_updates_multiple_settings(self, mock_db_session: MagicMock) -> None:
user_id = uuid4()
update_user_voice_settings(
mock_db_session,
user_id,
auto_send=True,
auto_playback=False,
playback_speed=1.25,
)
mock_db_session.execute.assert_called_once()
mock_db_session.flush.assert_called_once()
def test_does_nothing_when_no_settings_provided(
self, mock_db_session: MagicMock
) -> None:
user_id = uuid4()
update_user_voice_settings(mock_db_session, user_id)
mock_db_session.execute.assert_not_called()
mock_db_session.flush.assert_not_called()
class TestSpeedClampingLogic:
"""Tests for the speed clamping constants and logic."""
def test_min_speed_constant(self) -> None:
assert MIN_VOICE_PLAYBACK_SPEED == 0.5
def test_max_speed_constant(self) -> None:
assert MAX_VOICE_PLAYBACK_SPEED == 2.0
def test_clamping_formula(self) -> None:
"""Verify the clamping formula used in update_user_voice_settings."""
test_cases = [
(0.1, MIN_VOICE_PLAYBACK_SPEED),
(0.5, 0.5),
(1.0, 1.0),
(1.5, 1.5),
(2.0, 2.0),
(3.0, MAX_VOICE_PLAYBACK_SPEED),
]
for speed, expected in test_cases:
clamped = max(
MIN_VOICE_PLAYBACK_SPEED, min(MAX_VOICE_PLAYBACK_SPEED, speed)
)
assert (
clamped == expected
), f"speed={speed} expected={expected} got={clamped}"
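The clamping behaviour exercised by `TestSpeedClampingLogic` can be sketched as a standalone helper. This is a minimal sketch, not the actual body of `update_user_voice_settings`: `clamp_playback_speed` is a hypothetical name, and the two bounds are simply the values asserted by `test_min_speed_constant` and `test_max_speed_constant`.

```python
# Bounds as asserted by the tests above; redefined here so the sketch is
# self-contained (the real constants live in onyx.db.voice).
MIN_VOICE_PLAYBACK_SPEED = 0.5
MAX_VOICE_PLAYBACK_SPEED = 2.0


def clamp_playback_speed(speed: float) -> float:
    """Clamp a requested playback speed into the allowed range.

    Hypothetical helper: applies the same max/min formula that
    test_clamping_formula checks inline.
    """
    return max(MIN_VOICE_PLAYBACK_SPEED, min(MAX_VOICE_PLAYBACK_SPEED, speed))
```

Values already inside the range pass through unchanged, so only out-of-range requests are altered before the UPDATE statement is built.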

@@ -1,196 +0,0 @@
import io
import openpyxl
from onyx.file_processing.extract_file_text import xlsx_to_text
def _make_xlsx(sheets: dict[str, list[list[str]]]) -> io.BytesIO:
"""Create an in-memory xlsx file from a dict of sheet_name -> matrix of strings."""
wb = openpyxl.Workbook()
if wb.active is not None:
wb.remove(wb.active)
for sheet_name, rows in sheets.items():
ws = wb.create_sheet(title=sheet_name)
for row in rows:
ws.append(row)
buf = io.BytesIO()
wb.save(buf)
buf.seek(0)
return buf
class TestXlsxToText:
def test_single_sheet_basic(self) -> None:
xlsx = _make_xlsx(
{
"Sheet1": [
["Name", "Age"],
["Alice", "30"],
["Bob", "25"],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
assert len(lines) == 3
assert "Name" in lines[0]
assert "Age" in lines[0]
assert "Alice" in lines[1]
assert "30" in lines[1]
assert "Bob" in lines[2]
def test_multiple_sheets_separated(self) -> None:
xlsx = _make_xlsx(
{
"Sheet1": [["a", "b"]],
"Sheet2": [["c", "d"]],
}
)
result = xlsx_to_text(xlsx)
# TEXT_SECTION_SEPARATOR is "\n\n"
assert "\n\n" in result
parts = result.split("\n\n")
assert any("a" in p for p in parts)
assert any("c" in p for p in parts)
def test_empty_cells(self) -> None:
xlsx = _make_xlsx(
{
"Sheet1": [
["a", "", "b"],
["", "c", ""],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
assert len(lines) == 2
def test_commas_in_cells_are_quoted(self) -> None:
"""Cells containing commas should be quoted in CSV output."""
xlsx = _make_xlsx(
{
"Sheet1": [
["hello, world", "normal"],
]
}
)
result = xlsx_to_text(xlsx)
assert '"hello, world"' in result
def test_empty_workbook(self) -> None:
xlsx = _make_xlsx({"Sheet1": []})
result = xlsx_to_text(xlsx)
assert result.strip() == ""
def test_long_empty_row_run_capped(self) -> None:
"""Runs of >2 empty rows should be capped to 2."""
xlsx = _make_xlsx(
{
"Sheet1": [
["header"],
[""],
[""],
[""],
[""],
["data"],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
# 4 empty rows capped to 2, so: header + 2 empty + data = 4 lines
assert len(lines) == 4
assert "header" in lines[0]
assert "data" in lines[-1]
def test_long_empty_col_run_capped(self) -> None:
"""Runs of >2 empty columns should be capped to 2."""
xlsx = _make_xlsx(
{
"Sheet1": [
["a", "", "", "", "b"],
["c", "", "", "", "d"],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
assert len(lines) == 2
# Each row should have 4 fields (a + 2 empty + b), not 5
# csv format: a,,,b (3 commas = 4 fields)
first_line = lines[0].strip()
# Count commas to verify column reduction
assert first_line.count(",") == 3
def test_short_empty_runs_kept(self) -> None:
"""Runs of <=2 empty rows/cols should be preserved."""
xlsx = _make_xlsx(
{
"Sheet1": [
["a", "b"],
["", ""],
["", ""],
["c", "d"],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
# All 4 rows preserved (2 empty rows <= threshold)
assert len(lines) == 4
def test_bad_zip_file_returns_empty(self) -> None:
bad_file = io.BytesIO(b"not a zip file")
result = xlsx_to_text(bad_file, file_name="test.xlsx")
assert result == ""
def test_bad_zip_tilde_file_returns_empty(self) -> None:
bad_file = io.BytesIO(b"not a zip file")
result = xlsx_to_text(bad_file, file_name="~$temp.xlsx")
assert result == ""
def test_large_sparse_sheet(self) -> None:
"""A sheet with data, a big empty gap, and more data — gap is capped to 2."""
rows: list[list[str]] = [["row1_data"]]
rows.extend([[""] for _ in range(10)])
rows.append(["row2_data"])
xlsx = _make_xlsx({"Sheet1": rows})
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
# 10 empty rows capped to 2: row1_data + 2 empty + row2_data = 4
assert len(lines) == 4
assert "row1_data" in lines[0]
assert "row2_data" in lines[-1]
def test_quotes_in_cells(self) -> None:
"""Cells containing quotes should be properly escaped."""
xlsx = _make_xlsx(
{
"Sheet1": [
['say "hello"', "normal"],
]
}
)
result = xlsx_to_text(xlsx)
# csv.writer escapes quotes by doubling them
assert '""hello""' in result
def test_each_row_is_separate_line(self) -> None:
"""Each row should produce its own line (regression for writerow vs writerows)."""
xlsx = _make_xlsx(
{
"Sheet1": [
["r1c1", "r1c2"],
["r2c1", "r2c2"],
["r3c1", "r3c2"],
]
}
)
result = xlsx_to_text(xlsx)
lines = [line for line in result.strip().split("\n") if line.strip()]
assert len(lines) == 3
assert "r1c1" in lines[0] and "r1c2" in lines[0]
assert "r2c1" in lines[1] and "r2c2" in lines[1]
assert "r3c1" in lines[2] and "r3c2" in lines[2]

@@ -26,14 +26,6 @@ class TestIsTrueOpenAIModel:
"""Test that real OpenAI GPT-4o-mini model is correctly identified."""
assert is_true_openai_model(LlmProviderNames.OPENAI, "gpt-4o-mini") is True
- def test_real_openai_o1_preview(self) -> None:
- """Test that real OpenAI o1-preview reasoning model is correctly identified."""
- assert is_true_openai_model(LlmProviderNames.OPENAI, "o1-preview") is True
- def test_real_openai_o1_mini(self) -> None:
- """Test that real OpenAI o1-mini reasoning model is correctly identified."""
- assert is_true_openai_model(LlmProviderNames.OPENAI, "o1-mini") is True
def test_openai_with_provider_prefix(self) -> None:
"""Test that OpenAI model with provider prefix is correctly identified."""
assert is_true_openai_model(LlmProviderNames.OPENAI, "openai/gpt-4") is False

@@ -1,15 +1,19 @@
"""Tests for LLM model fetch endpoints.
These tests verify the full request/response flow for fetching models
- from dynamic providers (Ollama, OpenRouter), including the
+ from dynamic providers (Ollama, OpenRouter, Litellm), including the
sync-to-DB behavior when provider_name is specified.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
import httpx
import pytest
from onyx.error_handling.exceptions import OnyxError
from onyx.server.manage.llm.models import LitellmFinalModelResponse
from onyx.server.manage.llm.models import LitellmModelsRequest
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
from onyx.server.manage.llm.models import LMStudioModelsRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
@@ -614,3 +618,283 @@ class TestGetLMStudioAvailableModels:
request = LMStudioModelsRequest(api_base="http://localhost:1234")
with pytest.raises(OnyxError):
get_lm_studio_available_models(request, MagicMock(), mock_session)
class TestGetLitellmAvailableModels:
"""Tests for the Litellm proxy model fetch endpoint."""
@pytest.fixture
def mock_litellm_response(self) -> dict:
"""Mock response from Litellm /v1/models endpoint."""
return {
"data": [
{
"id": "gpt-4o",
"object": "model",
"created": 1700000000,
"owned_by": "openai",
},
{
"id": "claude-3-5-sonnet",
"object": "model",
"created": 1700000001,
"owned_by": "anthropic",
},
{
"id": "gemini-pro",
"object": "model",
"created": 1700000002,
"owned_by": "google",
},
]
}
def test_returns_model_list(self, mock_litellm_response: dict) -> None:
"""Test that endpoint returns properly formatted model list."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = mock_litellm_response
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
results = get_litellm_available_models(request, MagicMock(), mock_session)
assert len(results) == 3
assert all(isinstance(r, LitellmFinalModelResponse) for r in results)
def test_model_fields_parsed_correctly(self, mock_litellm_response: dict) -> None:
"""Test that provider_name and model_name are correctly extracted."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = mock_litellm_response
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
results = get_litellm_available_models(request, MagicMock(), mock_session)
gpt = next(r for r in results if r.model_name == "gpt-4o")
assert gpt.provider_name == "openai"
claude = next(r for r in results if r.model_name == "claude-3-5-sonnet")
assert claude.provider_name == "anthropic"
def test_results_sorted_by_model_name(self, mock_litellm_response: dict) -> None:
"""Test that results are alphabetically sorted by model_name."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = mock_litellm_response
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
results = get_litellm_available_models(request, MagicMock(), mock_session)
model_names = [r.model_name for r in results]
assert model_names == sorted(model_names, key=str.lower)
def test_empty_data_raises_onyx_error(self) -> None:
"""Test that empty model list raises OnyxError."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = {"data": []}
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
with pytest.raises(OnyxError, match="No models found"):
get_litellm_available_models(request, MagicMock(), mock_session)
def test_missing_data_key_raises_onyx_error(self) -> None:
"""Test that response without 'data' key raises OnyxError."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = {}
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
with pytest.raises(OnyxError):
get_litellm_available_models(request, MagicMock(), mock_session)
def test_skips_unparseable_entries(self) -> None:
"""Test that malformed model entries are skipped without failing."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
response_with_bad_entry = {
"data": [
{
"id": "gpt-4o",
"object": "model",
"created": 1700000000,
"owned_by": "openai",
},
# Missing required fields
{"bad_field": "bad_value"},
]
}
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = response_with_bad_entry
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
results = get_litellm_available_models(request, MagicMock(), mock_session)
assert len(results) == 1
assert results[0].model_name == "gpt-4o"
def test_all_entries_unparseable_raises_onyx_error(self) -> None:
"""Test that OnyxError is raised when all entries fail to parse."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
response_all_bad = {
"data": [
{"bad_field": "bad_value"},
{"another_bad": 123},
]
}
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = response_all_bad
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
with pytest.raises(OnyxError, match="No compatible models"):
get_litellm_available_models(request, MagicMock(), mock_session)
def test_api_base_trailing_slash_handled(self) -> None:
"""Test that trailing slashes in api_base are handled correctly."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
mock_litellm_response = {
"data": [
{
"id": "gpt-4o",
"object": "model",
"created": 1700000000,
"owned_by": "openai",
},
]
}
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.json.return_value = mock_litellm_response
mock_response.raise_for_status = MagicMock()
mock_get.return_value = mock_response
request = LitellmModelsRequest(
api_base="http://localhost:4000/",
api_key="test-key",
)
get_litellm_available_models(request, MagicMock(), mock_session)
# Should call /v1/models without double slashes
call_args = mock_get.call_args
assert call_args[0][0] == "http://localhost:4000/v1/models"
def test_connection_failure_raises_onyx_error(self) -> None:
"""Test that connection failures are wrapped in OnyxError."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_get.side_effect = Exception("Connection refused")
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
with pytest.raises(OnyxError, match="Failed to fetch LiteLLM models"):
get_litellm_available_models(request, MagicMock(), mock_session)
def test_401_raises_authentication_error(self) -> None:
"""Test that a 401 response raises OnyxError with authentication message."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.status_code = 401
mock_get.side_effect = httpx.HTTPStatusError(
"Unauthorized", request=MagicMock(), response=mock_response
)
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="bad-key",
)
with pytest.raises(OnyxError, match="Authentication failed"):
get_litellm_available_models(request, MagicMock(), mock_session)
def test_404_raises_not_found_error(self) -> None:
"""Test that a 404 response raises OnyxError with endpoint not found message."""
from onyx.server.manage.llm.api import get_litellm_available_models
mock_session = MagicMock()
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
mock_response = MagicMock()
mock_response.status_code = 404
mock_get.side_effect = httpx.HTTPStatusError(
"Not Found", request=MagicMock(), response=mock_response
)
request = LitellmModelsRequest(
api_base="http://localhost:4000",
api_key="test-key",
)
with pytest.raises(OnyxError, match="endpoint not found"):
get_litellm_available_models(request, MagicMock(), mock_session)
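Taken together, the `TestGetLitellmAvailableModels` cases describe a small parsing contract: strip a trailing slash from `api_base`, skip malformed entries, fail when nothing usable remains, and sort case-insensitively by model name. The sketch below illustrates that contract only. It is not the code in `onyx.server.manage.llm.api`; `ModelEntry`, `models_url`, and `parse_models` are hypothetical names, `ValueError` stands in for `OnyxError`, and the `owned_by` to `provider_name` / `id` to `model_name` mapping is inferred from the assertions above.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ModelEntry:
    """Hypothetical stand-in for LitellmFinalModelResponse."""

    provider_name: str
    model_name: str


def models_url(api_base: str) -> str:
    # Strip trailing slashes so the joined URL never contains "//v1/models".
    return f"{api_base.rstrip('/')}/v1/models"


def parse_models(payload: dict[str, Any]) -> list[ModelEntry]:
    data = payload.get("data")
    if not data:
        # Covers both a missing "data" key and an empty list.
        raise ValueError("No models found")

    parsed: list[ModelEntry] = []
    for item in data:
        try:
            parsed.append(
                ModelEntry(provider_name=item["owned_by"], model_name=item["id"])
            )
        except (KeyError, TypeError):
            # Malformed entry: skip it rather than failing the whole request.
            continue

    if not parsed:
        raise ValueError("No compatible models")

    # Case-insensitive alphabetical order by model name.
    return sorted(parsed, key=lambda m: m.model_name.lower())
```

Keeping the URL join and the payload parsing as pure functions means both can be tested without patching `httpx.get`, although the tests in this diff patch the HTTP call and exercise the endpoint function end to end.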

@@ -0,0 +1,23 @@
import pytest
from onyx.error_handling.exceptions import OnyxError
from onyx.server.manage.voice.api import _validate_voice_api_base
def test_validate_voice_api_base_blocks_private_for_non_azure() -> None:
with pytest.raises(OnyxError, match="Invalid target URI"):
_validate_voice_api_base("openai", "http://127.0.0.1:11434")
def test_validate_voice_api_base_allows_private_for_azure() -> None:
validated = _validate_voice_api_base("azure", "http://127.0.0.1:5000")
assert validated == "http://127.0.0.1:5000"
def test_validate_voice_api_base_blocks_metadata_for_azure() -> None:
with pytest.raises(OnyxError, match="Invalid target URI"):
_validate_voice_api_base("azure", "http://metadata.google.internal/")
def test_validate_voice_api_base_returns_none_for_none() -> None:
assert _validate_voice_api_base("openai", None) is None

@@ -14,6 +14,7 @@ from onyx.utils.url import _is_ip_private_or_reserved
from onyx.utils.url import _validate_and_resolve_url
from onyx.utils.url import ssrf_safe_get
from onyx.utils.url import SSRFException
from onyx.utils.url import validate_outbound_http_url
class TestIsIpPrivateOrReserved:
@@ -305,3 +306,22 @@ class TestSsrfSafeGet:
call_args = mock_get.call_args
assert call_args[1]["timeout"] == (5, 15)
class TestValidateOutboundHttpUrl:
    def test_rejects_private_ip_by_default(self) -> None:
        with pytest.raises(SSRFException, match="internal/private IP"):
            validate_outbound_http_url("http://127.0.0.1:8000")

    def test_allows_private_ip_when_explicitly_enabled(self) -> None:
        validated_url = validate_outbound_http_url(
            "http://127.0.0.1:8000", allow_private_network=True
        )
        assert validated_url == "http://127.0.0.1:8000"

    def test_blocks_metadata_hostname_when_private_is_enabled(self) -> None:
        with pytest.raises(SSRFException, match="not allowed"):
            validate_outbound_http_url(
                "http://metadata.google.internal/latest",
                allow_private_network=True,
            )


@@ -0,0 +1,30 @@
import pytest

from onyx.voice.providers.azure import AzureVoiceProvider


def test_azure_provider_extracts_region_from_target_uri() -> None:
    provider = AzureVoiceProvider(
        api_key="key",
        api_base="https://westus.api.cognitive.microsoft.com/",
        custom_config={},
    )
    assert provider.speech_region == "westus"


def test_azure_provider_normalizes_uppercase_region() -> None:
    provider = AzureVoiceProvider(
        api_key="key",
        api_base=None,
        custom_config={"speech_region": "WestUS2"},
    )
    assert provider.speech_region == "westus2"


def test_azure_provider_rejects_invalid_speech_region() -> None:
    with pytest.raises(ValueError, match="Invalid Azure speech_region"):
        AzureVoiceProvider(
            api_key="key",
            api_base=None,
            custom_config={"speech_region": "westus/../../etc"},
        )


@@ -0,0 +1,194 @@
import io
import struct
import wave

import pytest

from onyx.voice.providers.azure import AzureVoiceProvider


# --- _is_azure_cloud_url ---


def test_is_azure_cloud_url_speech_microsoft() -> None:
    assert AzureVoiceProvider._is_azure_cloud_url(
        "https://eastus.tts.speech.microsoft.com/cognitiveservices/v1"
    )


def test_is_azure_cloud_url_cognitive_microsoft() -> None:
    assert AzureVoiceProvider._is_azure_cloud_url(
        "https://westus.api.cognitive.microsoft.com/"
    )


def test_is_azure_cloud_url_rejects_custom_host() -> None:
    assert not AzureVoiceProvider._is_azure_cloud_url("https://my-custom-host.com/")


def test_is_azure_cloud_url_rejects_none() -> None:
    assert not AzureVoiceProvider._is_azure_cloud_url(None)


# --- _extract_speech_region_from_uri ---


def test_extract_region_from_tts_url() -> None:
    assert (
        AzureVoiceProvider._extract_speech_region_from_uri(
            "https://eastus.tts.speech.microsoft.com/cognitiveservices/v1"
        )
        == "eastus"
    )


def test_extract_region_from_cognitive_api_url() -> None:
    assert (
        AzureVoiceProvider._extract_speech_region_from_uri(
            "https://eastus.api.cognitive.microsoft.com/"
        )
        == "eastus"
    )


def test_extract_region_returns_none_for_custom_domain() -> None:
    """Custom domains use resource name, not region — must use speech_region config."""
    assert (
        AzureVoiceProvider._extract_speech_region_from_uri(
            "https://myresource.cognitiveservices.azure.com/"
        )
        is None
    )


def test_extract_region_returns_none_for_none() -> None:
    assert AzureVoiceProvider._extract_speech_region_from_uri(None) is None


# --- _validate_speech_region ---


def test_validate_region_normalizes_to_lowercase() -> None:
    assert AzureVoiceProvider._validate_speech_region("WestUS2") == "westus2"


def test_validate_region_accepts_hyphens() -> None:
    assert AzureVoiceProvider._validate_speech_region("us-east-1") == "us-east-1"


def test_validate_region_rejects_path_traversal() -> None:
    with pytest.raises(ValueError, match="Invalid Azure speech_region"):
        AzureVoiceProvider._validate_speech_region("westus/../../etc")


def test_validate_region_rejects_dots() -> None:
    with pytest.raises(ValueError, match="Invalid Azure speech_region"):
        AzureVoiceProvider._validate_speech_region("west.us")


# --- _pcm16_to_wav ---


def test_pcm16_to_wav_produces_valid_wav() -> None:
    samples = [32767, -32768, 0, 1234]
    pcm_data = struct.pack(f"<{len(samples)}h", *samples)
    wav_bytes = AzureVoiceProvider._pcm16_to_wav(pcm_data, sample_rate=16000)

    with wave.open(io.BytesIO(wav_bytes), "rb") as wav_file:
        assert wav_file.getnchannels() == 1
        assert wav_file.getsampwidth() == 2
        assert wav_file.getframerate() == 16000
        frames = wav_file.readframes(4)

    recovered = struct.unpack(f"<{len(samples)}h", frames)
    assert list(recovered) == samples


# --- URL Construction ---


def test_get_tts_url_cloud() -> None:
    provider = AzureVoiceProvider(
        api_key="key", api_base=None, custom_config={"speech_region": "eastus"}
    )
    assert (
        provider._get_tts_url()
        == "https://eastus.tts.speech.microsoft.com/cognitiveservices/v1"
    )


def test_get_stt_url_cloud() -> None:
    provider = AzureVoiceProvider(
        api_key="key", api_base=None, custom_config={"speech_region": "westus2"}
    )
    assert "westus2.stt.speech.microsoft.com" in provider._get_stt_url()


def test_get_tts_url_self_hosted() -> None:
    provider = AzureVoiceProvider(
        api_key="key", api_base="http://localhost:5000", custom_config={}
    )
    assert provider._get_tts_url() == "http://localhost:5000/cognitiveservices/v1"


def test_get_tts_url_self_hosted_strips_trailing_slash() -> None:
    provider = AzureVoiceProvider(
        api_key="key", api_base="http://localhost:5000/", custom_config={}
    )
    assert provider._get_tts_url() == "http://localhost:5000/cognitiveservices/v1"


# --- _is_self_hosted ---


def test_is_self_hosted_true_for_custom_endpoint() -> None:
    provider = AzureVoiceProvider(
        api_key="key", api_base="http://localhost:5000", custom_config={}
    )
    assert provider._is_self_hosted() is True


def test_is_self_hosted_false_for_azure_cloud() -> None:
    provider = AzureVoiceProvider(
        api_key="key",
        api_base="https://eastus.api.cognitive.microsoft.com/",
        custom_config={},
    )
    assert provider._is_self_hosted() is False


# --- Resampling ---


def test_resample_pcm16_passthrough() -> None:
    from onyx.voice.providers.azure import AzureStreamingTranscriber

    t = AzureStreamingTranscriber.__new__(AzureStreamingTranscriber)
    t.input_sample_rate = 16000
    t.target_sample_rate = 16000

    data = struct.pack("<4h", 100, 200, 300, 400)
    assert t._resample_pcm16(data) == data


def test_resample_pcm16_downsamples() -> None:
    from onyx.voice.providers.azure import AzureStreamingTranscriber

    t = AzureStreamingTranscriber.__new__(AzureStreamingTranscriber)
    t.input_sample_rate = 24000
    t.target_sample_rate = 16000

    input_samples = [1000, 2000, 3000, 4000, 5000, 6000]
    data = struct.pack(f"<{len(input_samples)}h", *input_samples)
    result = t._resample_pcm16(data)
    assert len(result) // 2 == 4


def test_resample_pcm16_empty_data() -> None:
    from onyx.voice.providers.azure import AzureStreamingTranscriber

    t = AzureStreamingTranscriber.__new__(AzureStreamingTranscriber)
    t.input_sample_rate = 24000
    t.target_sample_rate = 16000
    assert t._resample_pcm16(b"") == b""


@@ -0,0 +1,117 @@
import struct

from onyx.voice.providers.elevenlabs import _http_to_ws_url
from onyx.voice.providers.elevenlabs import DEFAULT_ELEVENLABS_API_BASE
from onyx.voice.providers.elevenlabs import ElevenLabsSTTMessageType
from onyx.voice.providers.elevenlabs import ElevenLabsVoiceProvider


# --- _http_to_ws_url ---


def test_http_to_ws_url_converts_https_to_wss() -> None:
    assert _http_to_ws_url("https://api.elevenlabs.io") == "wss://api.elevenlabs.io"


def test_http_to_ws_url_converts_http_to_ws() -> None:
    assert _http_to_ws_url("http://localhost:8080") == "ws://localhost:8080"


def test_http_to_ws_url_passes_through_other_schemes() -> None:
    assert _http_to_ws_url("wss://already.ws") == "wss://already.ws"


def test_http_to_ws_url_preserves_path() -> None:
    assert (
        _http_to_ws_url("https://api.elevenlabs.io/v1/tts")
        == "wss://api.elevenlabs.io/v1/tts"
    )


# --- StrEnum comparison ---


def test_stt_message_type_compares_as_string() -> None:
    """StrEnum members should work in string comparisons (e.g. from JSON)."""
    assert str(ElevenLabsSTTMessageType.COMMITTED_TRANSCRIPT) == "committed_transcript"
    assert isinstance(ElevenLabsSTTMessageType.ERROR, str)


# --- Resampling ---


def test_resample_pcm16_passthrough_when_same_rate() -> None:
    from onyx.voice.providers.elevenlabs import ElevenLabsStreamingTranscriber

    t = ElevenLabsStreamingTranscriber.__new__(ElevenLabsStreamingTranscriber)
    t.input_sample_rate = 16000
    t.target_sample_rate = 16000

    data = struct.pack("<4h", 100, 200, 300, 400)
    assert t._resample_pcm16(data) == data


def test_resample_pcm16_downsamples() -> None:
    """24kHz -> 16kHz should produce fewer samples (ratio 3:2)."""
    from onyx.voice.providers.elevenlabs import ElevenLabsStreamingTranscriber

    t = ElevenLabsStreamingTranscriber.__new__(ElevenLabsStreamingTranscriber)
    t.input_sample_rate = 24000
    t.target_sample_rate = 16000

    input_samples = [1000, 2000, 3000, 4000, 5000, 6000]
    data = struct.pack(f"<{len(input_samples)}h", *input_samples)
    result = t._resample_pcm16(data)
    output_samples = struct.unpack(f"<{len(result) // 2}h", result)
    assert len(output_samples) == 4


def test_resample_pcm16_clamps_to_int16_range() -> None:
    from onyx.voice.providers.elevenlabs import ElevenLabsStreamingTranscriber

    t = ElevenLabsStreamingTranscriber.__new__(ElevenLabsStreamingTranscriber)
    t.input_sample_rate = 24000
    t.target_sample_rate = 16000

    input_samples = [32767, -32768, 32767, -32768, 32767, -32768]
    data = struct.pack(f"<{len(input_samples)}h", *input_samples)
    result = t._resample_pcm16(data)
    output_samples = struct.unpack(f"<{len(result) // 2}h", result)
    for s in output_samples:
        assert -32768 <= s <= 32767


# --- Provider Model Defaulting ---


def test_provider_defaults_invalid_stt_model() -> None:
    provider = ElevenLabsVoiceProvider(api_key="test", stt_model="invalid_model")
    assert provider.stt_model == "scribe_v1"


def test_provider_defaults_invalid_tts_model() -> None:
    provider = ElevenLabsVoiceProvider(api_key="test", tts_model="invalid_model")
    assert provider.tts_model == "eleven_multilingual_v2"


def test_provider_accepts_valid_models() -> None:
    provider = ElevenLabsVoiceProvider(
        api_key="test", stt_model="scribe_v2_realtime", tts_model="eleven_turbo_v2_5"
    )
    assert provider.stt_model == "scribe_v2_realtime"
    assert provider.tts_model == "eleven_turbo_v2_5"


def test_provider_defaults_api_base() -> None:
    provider = ElevenLabsVoiceProvider(api_key="test")
    assert provider.api_base == DEFAULT_ELEVENLABS_API_BASE


def test_provider_get_available_voices_returns_copy() -> None:
    provider = ElevenLabsVoiceProvider(api_key="test")
    voices = provider.get_available_voices()
    voices.clear()
    assert len(provider.get_available_voices()) > 0


@@ -0,0 +1,97 @@
import io
import struct
import wave

from onyx.voice.providers.openai import _create_wav_header
from onyx.voice.providers.openai import _http_to_ws_url
from onyx.voice.providers.openai import OpenAIRealtimeMessageType
from onyx.voice.providers.openai import OpenAIVoiceProvider


# --- _http_to_ws_url ---


def test_http_to_ws_url_converts_https_to_wss() -> None:
    assert _http_to_ws_url("https://api.openai.com") == "wss://api.openai.com"


def test_http_to_ws_url_converts_http_to_ws() -> None:
    assert _http_to_ws_url("http://localhost:9090") == "ws://localhost:9090"


def test_http_to_ws_url_passes_through_ws() -> None:
    assert _http_to_ws_url("wss://already.ws") == "wss://already.ws"


# --- StrEnum comparison ---


def test_realtime_message_type_compares_as_string() -> None:
    assert str(OpenAIRealtimeMessageType.ERROR) == "error"
    assert (
        str(OpenAIRealtimeMessageType.TRANSCRIPTION_DELTA)
        == "conversation.item.input_audio_transcription.delta"
    )
    assert isinstance(OpenAIRealtimeMessageType.ERROR, str)


# --- _create_wav_header ---


def test_wav_header_is_44_bytes() -> None:
    assert len(_create_wav_header(1000)) == 44


def test_wav_header_chunk_size_matches_data_length() -> None:
    data_length = 2000
    header = _create_wav_header(data_length)
    chunk_size = struct.unpack_from("<I", header, 4)[0]
    assert chunk_size == 36 + data_length


def test_wav_header_byte_rate() -> None:
    header = _create_wav_header(100, sample_rate=24000, channels=1, bits_per_sample=16)
    byte_rate = struct.unpack_from("<I", header, 28)[0]
    assert byte_rate == 24000 * 1 * 16 // 8


def test_wav_header_produces_valid_wav() -> None:
    """Header + PCM data should parse as valid WAV."""
    data_length = 100
    pcm_data = b"\x00" * data_length
    header = _create_wav_header(data_length, sample_rate=24000)
    with wave.open(io.BytesIO(header + pcm_data), "rb") as wav_file:
        assert wav_file.getnchannels() == 1
        assert wav_file.getsampwidth() == 2
        assert wav_file.getframerate() == 24000
        assert wav_file.getnframes() == data_length // 2


# --- Provider Defaults ---


def test_provider_default_models() -> None:
    provider = OpenAIVoiceProvider(api_key="test")
    assert provider.stt_model == "whisper-1"
    assert provider.tts_model == "tts-1"
    assert provider.default_voice == "alloy"


def test_provider_custom_models() -> None:
    provider = OpenAIVoiceProvider(
        api_key="test",
        stt_model="gpt-4o-transcribe",
        tts_model="tts-1-hd",
        default_voice="nova",
    )
    assert provider.stt_model == "gpt-4o-transcribe"
    assert provider.tts_model == "tts-1-hd"
    assert provider.default_voice == "nova"


def test_provider_get_available_voices_returns_copy() -> None:
    provider = OpenAIVoiceProvider(api_key="test")
    voices = provider.get_available_voices()
    voices.clear()
    assert len(provider.get_available_voices()) > 0


@@ -38,6 +38,11 @@ services:
opensearch:
ports:
- "9200:9200"
# Rootless Docker can reject the base OpenSearch ulimit settings, so clear
# the inherited block entirely in the dev override.
ulimits: !reset null
environment:
- bootstrap.memory_lock=false
inference_model_server:
ports:


@@ -35,6 +35,7 @@ backend = [
"alembic==1.10.4",
"asyncpg==0.30.0",
"atlassian-python-api==3.41.16",
"azure-cognitiveservices-speech==1.38.0",
"beautifulsoup4==4.12.3",
"boto3==1.39.11",
"boto3-stubs[s3]==1.39.11",
@@ -91,7 +92,7 @@ backend = [
"python-gitlab==5.6.0",
"python-pptx==0.6.23",
"pypandoc_binary==1.16.2",
"pypdf==6.7.5",
"pypdf==6.8.0",
"pytest-mock==3.12.0",
"pytest-playwright==0.7.0",
"python-docx==1.1.2",
@@ -153,7 +154,7 @@ dev = [
"pytest-repeat==0.9.4",
"pytest-xdist==3.8.0",
"pytest==8.3.5",
"release-tag==0.4.3",
"release-tag==0.5.2",
"reorder-python-imports-black==3.14.0",
"ruff==0.12.0",
"types-beautifulsoup4==4.12.0.3",


@@ -0,0 +1,36 @@
package cmd

import (
	"fmt"

	"github.com/jmelahman/tag/git"
	"github.com/spf13/cobra"
)

// NewLatestStableTagCommand creates the latest-stable-tag command.
func NewLatestStableTagCommand() *cobra.Command {
	cmd := &cobra.Command{
		Use:   "latest-stable-tag",
		Short: "Print the git tag that should receive the 'latest' Docker tag",
		Long: `Print the highest stable (non-pre-release) semver tag in the repository.

This is used during deployment to decide whether a given tag should
receive the "latest" tag on Docker Hub. Only the highest vX.Y.Z tag
qualifies. Tags with pre-release suffixes (e.g. v1.2.3-beta,
v1.2.3-cloud.1) are excluded.`,
		Args: cobra.NoArgs,
		RunE: func(c *cobra.Command, _ []string) error {
			tag, err := git.GetLatestStableSemverTag("")
			if err != nil {
				return fmt.Errorf("get latest stable semver tag: %w", err)
			}
			if tag == "" {
				return fmt.Errorf("no stable semver tag found in repository")
			}
			fmt.Println(tag)
			return nil
		},
	}

	return cmd
}


@@ -52,6 +52,7 @@ func NewRootCommand() *cobra.Command {
cmd.AddCommand(NewScreenshotDiffCommand())
cmd.AddCommand(NewDesktopCommand())
cmd.AddCommand(NewWebCommand())
cmd.AddCommand(NewLatestStableTagCommand())
cmd.AddCommand(NewWhoisCommand())
return cmd


@@ -3,12 +3,13 @@ module github.com/onyx-dot-app/onyx/tools/ods
go 1.26.0
require (
github.com/jmelahman/tag v0.5.2
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.10.1
github.com/spf13/pflag v1.0.9
github.com/spf13/cobra v1.10.2
github.com/spf13/pflag v1.0.10
)
require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect
golang.org/x/sys v0.39.0 // indirect
)


@@ -4,20 +4,26 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jmelahman/tag v0.5.2 h1:g6A/aHehu5tkA31mPoDsXBNr1FigZ9A82Y8WVgb/WsM=
github.com/jmelahman/tag v0.5.2/go.mod h1:qmuqk19B1BKkpcg3kn7l/Eey+UqucLxgOWkteUGiG4Q=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

uv.lock generated

@@ -463,6 +463,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f8/00/3ed12264094ec91f534fae429945efbaa9f8c666f3aa7061cc3b2a26a0cd/authlib-1.6.7-py2.py3-none-any.whl", hash = "sha256:c637340d9a02789d2efa1d003a7437d10d3e565237bcb5fcbc6c134c7b95bab0", size = 244115, upload-time = "2026-02-06T14:04:12.141Z" },
]
[[package]]
name = "azure-cognitiveservices-speech"
version = "1.38.0"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/85/f4/4571c42cb00f8af317d5431f594b4ece1fbe59ab59f106947fea8e90cf89/azure_cognitiveservices_speech-1.38.0-py3-none-macosx_10_14_x86_64.whl", hash = "sha256:18dce915ab032711f687abb3297dd19176b9cbea562b322ee6fa7365ef4a5091", size = 6775838, upload-time = "2024-06-11T03:08:35.202Z" },
{ url = "https://files.pythonhosted.org/packages/86/22/0ca2c59a573119950cad1f53531fec9872fc38810c405a4e1827f3d13a8e/azure_cognitiveservices_speech-1.38.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:9dd0800fbc4a8438c6dfd5747a658251914fe2d205a29e9b46158cadac6ab381", size = 6687975, upload-time = "2024-06-11T03:08:38.797Z" },
{ url = "https://files.pythonhosted.org/packages/4d/96/5436c09de3af3a9aefaa8cc00533c3a0f5d17aef5bbe017c17f0a30ad66e/azure_cognitiveservices_speech-1.38.0-py3-none-manylinux1_x86_64.whl", hash = "sha256:1c344e8a6faadb063cea451f0301e13b44d9724e1242337039bff601e81e6f86", size = 40022287, upload-time = "2024-06-11T03:08:16.777Z" },
{ url = "https://files.pythonhosted.org/packages/a9/2d/ba20d05ff77ec9870cd489e6e7a474ba7fe820524bcf6fd202025e0c11cf/azure_cognitiveservices_speech-1.38.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1e002595a749471efeac3a54c80097946570b76c13049760b97a4b881d9d24af", size = 39788653, upload-time = "2024-06-11T03:08:30.405Z" },
{ url = "https://files.pythonhosted.org/packages/0c/21/25f8c37fb6868db4346ca977c287ede9e87f609885d932653243c9ed5f63/azure_cognitiveservices_speech-1.38.0-py3-none-win32.whl", hash = "sha256:16a530e6c646eb49ea0bc05cb45a9d28b99e4b67613f6c3a6c54e26e6bf65241", size = 1428364, upload-time = "2024-06-11T03:08:03.965Z" },
{ url = "https://files.pythonhosted.org/packages/14/05/a6414a3481c5ee30c4f32742abe055e5f3ce4ff69e936089d86ece354067/azure_cognitiveservices_speech-1.38.0-py3-none-win_amd64.whl", hash = "sha256:1d38d8c056fb3f513a9ff27ab4e77fd08ca487f8788cc7a6df772c1ab2c97b54", size = 1539297, upload-time = "2024-06-11T03:08:01.304Z" },
]
[[package]]
name = "babel"
version = "2.17.0"
@@ -4227,6 +4240,7 @@ backend = [
{ name = "asana" },
{ name = "asyncpg" },
{ name = "atlassian-python-api" },
{ name = "azure-cognitiveservices-speech" },
{ name = "beautifulsoup4" },
{ name = "boto3" },
{ name = "boto3-stubs", extra = ["s3"] },
@@ -4381,6 +4395,7 @@ requires-dist = [
{ name = "asana", marker = "extra == 'backend'", specifier = "==5.0.8" },
{ name = "asyncpg", marker = "extra == 'backend'", specifier = "==0.30.0" },
{ name = "atlassian-python-api", marker = "extra == 'backend'", specifier = "==3.41.16" },
{ name = "azure-cognitiveservices-speech", marker = "extra == 'backend'", specifier = "==1.38.0" },
{ name = "beautifulsoup4", marker = "extra == 'backend'", specifier = "==4.12.3" },
{ name = "black", marker = "extra == 'dev'", specifier = "==25.1.0" },
{ name = "boto3", marker = "extra == 'backend'", specifier = "==1.39.11" },
@@ -4466,7 +4481,7 @@ requires-dist = [
{ name = "pygithub", marker = "extra == 'backend'", specifier = "==2.5.0" },
{ name = "pympler", marker = "extra == 'backend'", specifier = "==1.1" },
{ name = "pypandoc-binary", marker = "extra == 'backend'", specifier = "==1.16.2" },
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.7.5" },
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.8.0" },
{ name = "pytest", marker = "extra == 'dev'", specifier = "==8.3.5" },
{ name = "pytest-alembic", marker = "extra == 'dev'", specifier = "==0.12.1" },
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==1.3.0" },
@@ -4485,7 +4500,7 @@ requires-dist = [
{ name = "pywikibot", marker = "extra == 'backend'", specifier = "==9.0.0" },
{ name = "rapidfuzz", marker = "extra == 'backend'", specifier = "==3.13.0" },
{ name = "redis", marker = "extra == 'backend'", specifier = "==5.0.8" },
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.4.3" },
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.5.2" },
{ name = "reorder-python-imports-black", marker = "extra == 'dev'", specifier = "==3.14.0" },
{ name = "requests", marker = "extra == 'backend'", specifier = "==2.32.5" },
{ name = "requests-oauthlib", marker = "extra == 'backend'", specifier = "==1.3.1" },
@@ -5713,11 +5728,11 @@ wheels = [
[[package]]
name = "pypdf"
version = "6.7.5"
version = "6.8.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f6/52/37cc0aa9e9d1bf7729a737a0d83f8b3f851c8eb137373d9f71eafb0a3405/pypdf-6.7.5.tar.gz", hash = "sha256:40bb2e2e872078655f12b9b89e2f900888bb505e88a82150b64f9f34fa25651d", size = 5304278, upload-time = "2026-03-02T09:05:21.464Z" }
sdist = { url = "https://files.pythonhosted.org/packages/b4/a3/e705b0805212b663a4c27b861c8a603dba0f8b4bb281f96f8e746576a50d/pypdf-6.8.0.tar.gz", hash = "sha256:cb7eaeaa4133ce76f762184069a854e03f4d9a08568f0e0623f7ea810407833b", size = 5307831, upload-time = "2026-03-09T13:37:40.591Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/05/89/336673efd0a88956562658aba4f0bbef7cb92a6fbcbcaf94926dbc82b408/pypdf-6.7.5-py3-none-any.whl", hash = "sha256:07ba7f1d6e6d9aa2a17f5452e320a84718d4ce863367f7ede2fd72280349ab13", size = 331421, upload-time = "2026-03-02T09:05:19.722Z" },
{ url = "https://files.pythonhosted.org/packages/8c/ec/4ccf3bb86b1afe5d7176e1c8abcdbf22b53dd682ec2eda50e1caadcf6846/pypdf-6.8.0-py3-none-any.whl", hash = "sha256:2a025080a8dd73f48123c89c57174a5ff3806c71763ee4e49572dc90454943c7", size = 332177, upload-time = "2026-03-09T13:37:38.774Z" },
]
[[package]]
@@ -6338,16 +6353,16 @@ wheels = [
[[package]]
name = "release-tag"
version = "0.4.3"
version = "0.5.2"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/39/18/c1d17d973f73f0aa7e2c45f852839ab909756e1bd9727d03babe400fcef0/release_tag-0.4.3-py3-none-any.whl", hash = "sha256:4206f4fa97df930c8176bfee4d3976a7385150ed14b317bd6bae7101ac8b66dd", size = 1181112, upload-time = "2025-12-03T00:18:19.445Z" },
{ url = "https://files.pythonhosted.org/packages/33/c7/ecc443953840ac313856b2181f55eb8d34fa2c733cdd1edd0bcceee0938d/release_tag-0.4.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:7a347a9ad3d2af16e5367e52b451fbc88a0b7b666850758e8f9a601554a8fb13", size = 1170517, upload-time = "2025-12-03T00:18:11.663Z" },
{ url = "https://files.pythonhosted.org/packages/ce/81/2f6ffa0d87c792364ca9958433fe088c8acc3d096ac9734040049c6ad506/release_tag-0.4.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2d1603aa37d8e4f5df63676bbfddc802fbc108a744ba28288ad25c997981c164", size = 1101663, upload-time = "2025-12-03T00:18:15.173Z" },
{ url = "https://files.pythonhosted.org/packages/7c/ed/9e4ebe400fc52e38dda6e6a45d9da9decd4535ab15e170b8d9b229a66730/release_tag-0.4.3-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:6db7b81a198e3ba6a87496a554684912c13f9297ea8db8600a80f4f971709d37", size = 1079322, upload-time = "2025-12-03T00:18:16.094Z" },
{ url = "https://files.pythonhosted.org/packages/2a/64/9e0ce6119e091ef9211fa82b9593f564eeec8bdd86eff6a97fe6e2fcb20f/release_tag-0.4.3-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:d79a9cf191dd2c29e1b3a35453fa364b08a7aadd15aeb2c556a7661c6cf4d5ad", size = 1181129, upload-time = "2025-12-03T00:18:15.82Z" },
{ url = "https://files.pythonhosted.org/packages/b8/09/d96acf18f0773b6355080a568ba48931faa9dbe91ab1abefc6f8c4df04a8/release_tag-0.4.3-py3-none-win_amd64.whl", hash = "sha256:3958b880375f2241d0cc2b9882363bf54b1d4d7ca8ffc6eecc63ab92f23307f0", size = 1260773, upload-time = "2025-12-03T00:18:14.723Z" },
{ url = "https://files.pythonhosted.org/packages/51/da/ecb6346df1ffb0752fe213e25062f802c10df2948717f0d5f9816c2df914/release_tag-0.4.3-py3-none-win_arm64.whl", hash = "sha256:7d5b08000e6e398d46f05a50139031046348fba6d47909f01e468bb7600c19df", size = 1142155, upload-time = "2025-12-03T00:18:20.647Z" },
{ url = "https://files.pythonhosted.org/packages/ab/92/01192a540b29cfadaa23850c8f6a2041d541b83a3fa1dc52a5f55212b3b6/release_tag-0.5.2-py3-none-any.whl", hash = "sha256:1e9ca7618bcfc63ad7a0728c84bbad52ef82d07586c4cc11365b44ea8f588069", size = 1264752, upload-time = "2026-03-11T00:27:18.674Z" },
{ url = "https://files.pythonhosted.org/packages/4f/77/81fb42a23cd0de61caf84266f7aac1950b1c324883788b7c48e5344f61ae/release_tag-0.5.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:8fbc61ff7bac2b96fab09566ec45c6508c201efc3f081f57702e1761bbc178d5", size = 1255075, upload-time = "2026-03-11T00:27:24.442Z" },
{ url = "https://files.pythonhosted.org/packages/98/e6/769f8be94304529c1a531e995f2f3ac83f3c54738ce488b0abde75b20851/release_tag-0.5.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa3d7e495a0c516858a81878d03803539712677a3d6e015503de21cce19bea5e", size = 1163627, upload-time = "2026-03-11T00:27:26.412Z" },
{ url = "https://files.pythonhosted.org/packages/45/68/7543e9daa0dfd41c487bf140d91fd5879327bb7c001a96aa5264667c30a1/release_tag-0.5.2-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:e8b60453218d6926da1fdcb99c2e17c851be0d7ab1975e97951f0bff5f32b565", size = 1140133, upload-time = "2026-03-11T00:27:20.633Z" },
{ url = "https://files.pythonhosted.org/packages/6a/30/9087825696271012d889d136310dbdf0811976ae2b2f5a490f4e437903e1/release_tag-0.5.2-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:0e302ed60c2bf8b7ba5634842be28a27d83cec995869e112b0348b3f01a84ff5", size = 1264767, upload-time = "2026-03-11T00:27:28.355Z" },
{ url = "https://files.pythonhosted.org/packages/79/a3/5b51b0cbdbf2299f545124beab182cfdfe01bf5b615efbc94aee3a64ea67/release_tag-0.5.2-py3-none-win_amd64.whl", hash = "sha256:e3c0629d373a16b9a3da965e89fca893640ce9878ec548865df3609b70989a89", size = 1340816, upload-time = "2026-03-11T00:27:22.622Z" },
{ url = "https://files.pythonhosted.org/packages/dd/6f/832c2023a8bd8414c93452bd8b43bf61cedfa5b9575f70c06fb911e51a29/release_tag-0.5.2-py3-none-win_arm64.whl", hash = "sha256:5f26b008e0be0c7a122acd8fcb1bb5c822f38e77fed0c0bf6c550cc226c6bf14", size = 1203191, upload-time = "2026-03-11T00:27:29.789Z" },
]
[[package]]


@@ -144,6 +144,7 @@ module.exports = {
"**/src/app/**/hooks/*.test.ts", // Pure packet processor tests
"**/src/refresh-components/**/*.test.ts",
"**/src/sections/**/*.test.ts",
"**/src/components/**/*.test.ts",
// Add more patterns here as you add more unit tests
],
},


@@ -1,5 +1,6 @@
import type { Meta, StoryObj } from "@storybook/react";
import { OpenButton } from "@opal/components";
import { Disabled as DisabledProvider } from "@opal/core";
import { SvgSettings } from "@opal/icons";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
@@ -32,16 +33,9 @@ export const WithIcon: Story = {
},
};
export const Selected: Story = {
args: {
selected: true,
children: "Selected",
},
};
export const Open: Story = {
args: {
transient: true,
interaction: "hover",
children: "Open state",
},
};
@@ -53,18 +47,27 @@ export const Disabled: Story = {
},
};
export const LightProminence: Story = {
export const Foldable: Story = {
args: {
prominence: "light",
children: "Light prominence",
foldable: true,
icon: SvgSettings,
children: "Settings",
},
};
export const HeavyProminence: Story = {
export const FoldableDisabled: Story = {
args: {
prominence: "heavy",
children: "Heavy prominence",
foldable: true,
icon: SvgSettings,
children: "Settings",
},
decorators: [
(Story) => (
<DisabledProvider disabled>
<Story />
</DisabledProvider>
),
],
};
export const Sizes: Story = {
@@ -78,3 +81,12 @@ export const Sizes: Story = {
</div>
),
};
export const WithTooltip: Story = {
args: {
icon: SvgSettings,
children: "Settings",
tooltip: "Open settings",
tooltipSide: "bottom",
},
};


@@ -17,7 +17,9 @@ OpenButton is a **tighter, specialized use-case** of SelectButton:
- It hardcodes `variant="select-heavy"` (SelectButton exposes `variant`)
- It adds a built-in chevron with CSS-driven rotation (SelectButton has no chevron)
- It auto-detects Radix `data-state="open"` to derive `interaction` (SelectButton has no Radix awareness)
- It does not support `foldable` or `rightIcon` (SelectButton does)
- It does not support `rightIcon` (SelectButton does)
Both components support `foldable` using the same pattern: `interactive-foldable-host` class + `Interactive.Foldable` wrapper around the label and trailing icon. When foldable, the left icon stays visible while the rest collapses. If you change the foldable implementation in one, update the other to match.
If you need a general-purpose stateful toggle, use `SelectButton`. If you need a popover/dropdown trigger with a chevron, use `OpenButton`.
@@ -26,10 +28,12 @@ If you need a general-purpose stateful toggle, use `SelectButton`. If you need a
```
Interactive.Stateful <- variant="select-heavy", interaction, state, disabled, onClick
└─ Interactive.Container <- height, rounding, padding (from `size`)
└─ div.opal-button.interactive-foreground
└─ div.opal-button.interactive-foreground [.interactive-foldable-host]
├─ div > Icon? (interactive-foreground-icon)
├─ <span>? .opal-button-label
└─ div > ChevronIcon .opal-open-button-chevron (interactive-foreground-icon)
├─ [Foldable]? (wraps label + chevron when foldable)
│ ├─ <span>? .opal-button-label
│ └─ div > ChevronIcon .opal-open-button-chevron
└─ <span>? / ChevronIcon (non-foldable)
```
- **`interaction` controls both the chevron and the hover visual state.** When `interaction` is `"hover"` (explicitly or via Radix `data-state="open"`), the chevron rotates 180° and the hover background activates.
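
The fallback described in this bullet can be sketched as a small pure function. This is only an illustration under stated assumptions: `resolveInteraction` and `chevronRotation` are hypothetical names that do not appear in the component, the README describes the chevron rotation as CSS-driven rather than computed in JavaScript, and the `"rest"` result for a closed or absent `data-state` is assumed here.

```typescript
// The three interaction values listed for the `interaction` prop.
type Interaction = "rest" | "hover" | "active";

// Hypothetical helper: an explicit `interaction` wins; otherwise a Radix
// `data-state="open"` attribute is treated as "hover". The "rest" default
// for any other state is an assumption of this sketch.
function resolveInteraction(
  explicit: Interaction | undefined,
  radixDataState: string | undefined,
): Interaction {
  if (explicit !== undefined) {
    return explicit;
  }
  return radixDataState === "open" ? "hover" : "rest";
}

// Hypothetical helper: the chevron is rotated 180 degrees only in "hover".
// In the real component this is reportedly done in CSS, not in code.
function chevronRotation(interaction: Interaction): number {
  return interaction === "hover" ? 180 : 0;
}

// A popover trigger that Radix has marked as open, with no explicit override.
const derived = resolveInteraction(undefined, "open");
console.log(derived, chevronRotation(derived));
```

Keeping the derivation in one place means the chevron and the hover background cannot disagree about whether the trigger is open.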
@@ -44,6 +48,7 @@ Interactive.Stateful <- variant="select-heavy", interaction, state, di
| `interaction` | `"rest" \| "hover" \| "active"` | auto | JS-controlled interaction override. Falls back to Radix `data-state="open"` when omitted. |
| `icon` | `IconFunctionComponent` | — | Left icon component |
| `children` | `string` | — | Content between icon and chevron |
| `foldable` | `boolean` | `false` | When `true`, requires both `icon` and `children`; the left icon stays visible while the label + chevron collapse when not hovered. If `tooltip` is omitted on a disabled foldable button, the label text is used as the tooltip. |
| `size` | `SizeVariant` | `"lg"` | Size preset controlling height, rounding, and padding |
| `width` | `WidthVariant` | — | Width preset |
| `tooltip` | `string` | — | Tooltip text shown on hover |
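The tooltip fallback noted in the `foldable` row can be written as a pure function. The sketch below is a standalone illustration: `resolveTooltip` and `TooltipInput` are hypothetical names, and `isDisabled` is passed in as a plain boolean rather than read from a hook.

```typescript
// Illustrative sketch; `resolveTooltip` and `TooltipInput` are hypothetical names.
interface TooltipInput {
  tooltip?: string;
  foldable?: boolean;
  isDisabled: boolean;
  children?: string;
}

function resolveTooltip({
  tooltip,
  foldable,
  isDisabled,
  children,
}: TooltipInput): string | undefined {
  // An explicit tooltip always wins.
  if (tooltip !== undefined) return tooltip;
  // A disabled foldable button falls back to its label text.
  return foldable && isDisabled && children ? children : undefined;
}
```

When the result is `undefined`, the button would be rendered without any tooltip wrapper.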


@@ -2,6 +2,7 @@ import "@opal/components/buttons/open-button/styles.css";
import "@opal/components/tooltip.css";
import {
Interactive,
useDisabled,
type InteractiveStatefulProps,
type InteractiveStatefulInteraction,
} from "@opal/core";
@@ -30,27 +31,46 @@ function ChevronIcon({ className, ...props }: IconProps) {
// Types
// ---------------------------------------------------------------------------
type OpenButtonProps = Omit<InteractiveStatefulProps, "variant"> & {
/** Left icon. */
icon?: IconFunctionComponent;
/**
* Content props — a discriminated union on `foldable` that enforces:
*
* - `foldable: true` → `icon` and `children` are required (icon stays visible,
* label + chevron fold away)
* - `foldable?: false` → at least one of `icon` or `children` must be provided
*/
type OpenButtonContentProps =
| {
foldable: true;
icon: IconFunctionComponent;
children: string;
}
| {
foldable?: false;
icon?: IconFunctionComponent;
children: string;
}
| {
foldable?: false;
icon: IconFunctionComponent;
children?: string;
};
/** Button label text. */
children?: string;
type OpenButtonProps = Omit<InteractiveStatefulProps, "variant"> &
OpenButtonContentProps & {
/**
* Size preset — controls gap, text size, and Container height/rounding.
*/
size?: SizeVariant;
/**
* Size preset — controls gap, text size, and Container height/rounding.
*/
size?: SizeVariant;
/** Width preset. */
width?: WidthVariant;
/** Width preset. */
width?: WidthVariant;
/** Tooltip text shown on hover. */
tooltip?: string;
/** Tooltip text shown on hover. */
tooltip?: string;
/** Which side the tooltip appears on. */
tooltipSide?: TooltipSide;
};
/** Which side the tooltip appears on. */
tooltipSide?: TooltipSide;
};
// ---------------------------------------------------------------------------
// OpenButton
@@ -60,12 +80,15 @@ function OpenButton({
icon: Icon,
children,
size = "lg",
foldable,
width,
tooltip,
tooltipSide = "top",
interaction,
...statefulProps
}: OpenButtonProps) {
const { isDisabled } = useDisabled();
// Derive open state: explicit prop → Radix data-state (injected via Slot chain)
const dataState = (statefulProps as Record<string, unknown>)["data-state"] as
| string
@@ -75,6 +98,17 @@ function OpenButton({
const isLarge = size === "lg";
const labelEl = children ? (
<span
className={cn(
"opal-button-label whitespace-nowrap",
isLarge ? "font-main-ui-body" : "font-secondary-body"
)}
>
{children}
</span>
) : null;
const button = (
<Interactive.Stateful
variant="select-heavy"
@@ -89,25 +123,34 @@ function OpenButton({
isLarge ? "default" : size === "2xs" ? "mini" : "compact"
}
>
<div className="opal-button interactive-foreground flex flex-row items-center gap-1">
{iconWrapper(Icon, size, false)}
{children && (
<span
className={cn(
"opal-button-label whitespace-nowrap",
isLarge ? "font-main-ui-body" : "font-secondary-body"
)}
>
{children}
</span>
<div
className={cn(
"opal-button interactive-foreground flex flex-row items-center gap-1",
foldable && "interactive-foldable-host"
)}
>
{iconWrapper(Icon, size, !foldable && !!children)}
{foldable ? (
<Interactive.Foldable>
{labelEl}
{iconWrapper(ChevronIcon, size, !!children)}
</Interactive.Foldable>
) : (
<>
{labelEl}
{iconWrapper(ChevronIcon, size, !!children)}
</>
)}
{iconWrapper(ChevronIcon, size, false)}
</div>
</Interactive.Container>
</Interactive.Stateful>
);
if (!tooltip) return button;
const resolvedTooltip =
tooltip ?? (foldable && isDisabled && children ? children : undefined);
if (!resolvedTooltip) return button;
return (
<TooltipPrimitive.Root>
@@ -118,7 +161,7 @@ function OpenButton({
side={tooltipSide}
sideOffset={4}
>
{tooltip}
{resolvedTooltip}
</TooltipPrimitive.Content>
</TooltipPrimitive.Portal>
</TooltipPrimitive.Root>


@@ -17,7 +17,9 @@ Interactive.Stateful → Interactive.Container → content row (icon + label + t
- OpenButton hardcodes `variant="select-heavy"` (SelectButton exposes `variant`)
- OpenButton adds a built-in chevron with CSS-driven rotation (SelectButton has no chevron)
- OpenButton auto-detects Radix `data-state="open"` to derive `interaction` (SelectButton has no Radix awareness)
- OpenButton does not support `foldable` or `rightIcon` (SelectButton does)
- OpenButton does not support `rightIcon` (SelectButton does)
Both components support `foldable` using the same pattern: `interactive-foldable-host` class + `Interactive.Foldable` wrapper around the label and trailing icon. When foldable, the left icon stays visible while the rest collapses. If you change the foldable implementation in one, update the other to match.
Use `SelectButton` for general-purpose stateful toggles. Use `OpenButton` for popover/dropdown triggers with a chevron.


@@ -0,0 +1,87 @@
import type { Meta, StoryObj } from "@storybook/react";
import { Card } from "@opal/components";
const BACKGROUND_VARIANTS = ["none", "light", "heavy"] as const;
const BORDER_VARIANTS = ["none", "dashed", "solid"] as const;
const SIZE_VARIANTS = ["lg", "md", "sm", "xs", "2xs", "fit"] as const;
const meta: Meta<typeof Card> = {
title: "opal/components/Card",
component: Card,
tags: ["autodocs"],
};
export default meta;
type Story = StoryObj<typeof Card>;
export const Default: Story = {
render: () => (
<Card>
<p>Default card with light background, no border, lg size.</p>
</Card>
),
};
export const BackgroundVariants: Story = {
render: () => (
<div className="flex flex-col gap-4 w-96">
{BACKGROUND_VARIANTS.map((bg) => (
<Card key={bg} backgroundVariant={bg} borderVariant="solid">
<p>backgroundVariant: {bg}</p>
</Card>
))}
</div>
),
};
export const BorderVariants: Story = {
render: () => (
<div className="flex flex-col gap-4 w-96">
{BORDER_VARIANTS.map((border) => (
<Card key={border} borderVariant={border}>
<p>borderVariant: {border}</p>
</Card>
))}
</div>
),
};
export const SizeVariants: Story = {
render: () => (
<div className="flex flex-col gap-4 w-96">
{SIZE_VARIANTS.map((size) => (
<Card key={size} sizeVariant={size} borderVariant="solid">
<p>sizeVariant: {size}</p>
</Card>
))}
</div>
),
};
export const AllCombinations: Story = {
render: () => (
<div className="flex flex-col gap-8">
{SIZE_VARIANTS.map((size) => (
<div key={size}>
<p className="font-bold pb-2">sizeVariant: {size}</p>
<div className="grid grid-cols-3 gap-4">
{BACKGROUND_VARIANTS.map((bg) =>
BORDER_VARIANTS.map((border) => (
<Card
key={`${size}-${bg}-${border}`}
sizeVariant={size}
backgroundVariant={bg}
borderVariant={border}
>
<p className="text-xs">
bg: {bg}, border: {border}
</p>
</Card>
))
)}
</div>
</div>
))}
</div>
),
};


@@ -0,0 +1,67 @@
# Card
**Import:** `import { Card, type CardProps } from "@opal/components";`
A plain container component with configurable background, border, padding, and rounding. Uses a simple `<div>` internally with `overflow-clip`.
## Architecture
`sizeVariant` controls both padding and border radius, using the same mapping as `Button` and `Interactive.Container`:
| Size | Padding | Rounding |
|-----------|---------|----------------|
| `lg` | `p-2` | `rounded-12` |
| `md` | `p-1` | `rounded-08` |
| `sm` | `p-1` | `rounded-08` |
| `xs` | `p-0.5` | `rounded-04` |
| `2xs` | `p-0.5` | `rounded-04` |
| `fit` | `p-0` | `rounded-12` |
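The table above can be read as a pair of lookups keyed by size. The sketch below inlines both for illustration. `cardClassesForSize` is a hypothetical helper, the `SizeVariant` union is assumed from the six sizes listed in the table, and the padding map is an assumption taken from the table; the padding may well come from a shared size scale rather than a local map.

```typescript
// Sketch of the size table as lookups; `cardClassesForSize` is a hypothetical helper.
type SizeVariant = "lg" | "md" | "sm" | "xs" | "2xs" | "fit";

// Padding per size, copied from the table (assumed to match the shared scale).
const paddingForSize: Record<SizeVariant, string> = {
  lg: "p-2",
  md: "p-1",
  sm: "p-1",
  xs: "p-0.5",
  "2xs": "p-0.5",
  fit: "p-0",
};

// Rounding per size, copied from the table.
const roundingForSize: Record<SizeVariant, string> = {
  lg: "rounded-12",
  md: "rounded-08",
  sm: "rounded-08",
  xs: "rounded-04",
  "2xs": "rounded-04",
  fit: "rounded-12",
};

// Builds the class string for a card of the given size (default "lg").
function cardClassesForSize(size: SizeVariant = "lg"): string {
  return ["opal-card", paddingForSize[size], roundingForSize[size]].join(" ");
}
```

Typing both maps as `Record<SizeVariant, string>` makes the compiler flag any size that is added to the union without a matching entry.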
## Props
| Prop | Type | Default | Description |
|------|------|---------|-------------|
| `sizeVariant` | `SizeVariant` | `"lg"` | Controls padding and border-radius |
| `backgroundVariant` | `"none" \| "light" \| "heavy"` | `"light"` | Background fill intensity |
| `borderVariant` | `"none" \| "dashed" \| "solid"` | `"none"` | Border style |
| `ref` | `React.Ref<HTMLDivElement>` | — | Ref forwarded to the root div |
| `children` | `React.ReactNode` | — | Card content |
## Background Variants
- **`none`** — Transparent background. Use for seamless inline content.
- **`light`** — Subtle tinted background (`bg-background-tint-00`). The default, suitable for most cards.
- **`heavy`** — Stronger tinted background (`bg-background-tint-01`). Use for emphasis or nested cards that need visual separation.
## Border Variants
- **`none`** — No border. Use when cards are visually grouped or in tight layouts.
- **`dashed`** — Dashed border. Use for placeholder or empty states.
- **`solid`** — Solid border. Use for prominent, standalone cards.
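In the component source included in this change set, the two variants are written to `data-background` and `data-border` attributes on the root element, and the stylesheet selects on those attributes. A minimal sketch of that mapping with the documented defaults applied follows; `cardDataAttributes` is a hypothetical helper and not part of `@opal/components`.

```typescript
// Illustrative sketch; `cardDataAttributes` is a hypothetical helper.
type BackgroundVariant = "none" | "light" | "heavy";
type BorderVariant = "none" | "dashed" | "solid";

interface CardDataAttributes {
  "data-background": BackgroundVariant;
  "data-border": BorderVariant;
}

// Applies the documented defaults: "light" background and "none" border.
function cardDataAttributes(
  backgroundVariant: BackgroundVariant = "light",
  borderVariant: BorderVariant = "none"
): CardDataAttributes {
  return {
    "data-background": backgroundVariant,
    "data-border": borderVariant,
  };
}
```

Driving the styles from data attributes keeps the variant values visible in the DOM, which can make inspection in browser dev tools easier.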
## Usage
```tsx
import { Card } from "@opal/components";
// Default card (light background, no border, lg padding + rounding)
<Card>
<h2>Card Title</h2>
<p>Card content</p>
</Card>
// Compact card with solid border
<Card borderVariant="solid" sizeVariant="sm">
<p>Compact card</p>
</Card>
// Empty state card
<Card backgroundVariant="none" borderVariant="dashed">
<p>No items yet</p>
</Card>
// Heavy background, tight padding
<Card backgroundVariant="heavy" sizeVariant="xs">
<p>Highlighted content</p>
</Card>
```


@@ -0,0 +1,101 @@
import "@opal/components/cards/card/styles.css";
import type { SizeVariant } from "@opal/shared";
import { sizeVariants } from "@opal/shared";
import { cn } from "@opal/utils";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
type BackgroundVariant = "none" | "light" | "heavy";
type BorderVariant = "none" | "dashed" | "solid";
type CardProps = {
/**
* Size preset — controls padding and border-radius.
*
* Padding comes from the shared size scale. Rounding follows the same
* mapping as `Button` / `Interactive.Container`:
*
* | Size | Rounding |
* |--------|------------|
* | `lg` | `default` |
* | `md`, `sm` | `compact` |
* | `xs`, `2xs` | `mini` |
* | `fit` | `default` |
*
* @default "lg"
*/
sizeVariant?: SizeVariant;
/**
* Background fill intensity.
* - `"none"`: transparent background.
* - `"light"`: subtle tinted background (`bg-background-tint-00`).
* - `"heavy"`: stronger tinted background (`bg-background-tint-01`).
*
* @default "light"
*/
backgroundVariant?: BackgroundVariant;
/**
* Border style.
* - `"none"`: no border.
* - `"dashed"`: dashed border.
* - `"solid"`: solid border.
*
* @default "none"
*/
borderVariant?: BorderVariant;
/** Ref forwarded to the root `<div>`. */
ref?: React.Ref<HTMLDivElement>;
children?: React.ReactNode;
};
// ---------------------------------------------------------------------------
// Rounding
// ---------------------------------------------------------------------------
/** Maps a size variant to a rounding class, mirroring the Button pattern. */
const roundingForSize: Record<SizeVariant, string> = {
  lg: "rounded-12",
  md: "rounded-08",
  sm: "rounded-08",
  xs: "rounded-04",
  "2xs": "rounded-04",
  fit: "rounded-12",
};
// ---------------------------------------------------------------------------
// Card
// ---------------------------------------------------------------------------
function Card({
  sizeVariant = "lg",
  backgroundVariant = "light",
  borderVariant = "none",
  ref,
  children,
}: CardProps) {
  const { padding } = sizeVariants[sizeVariant];
  const rounding = roundingForSize[sizeVariant];

  return (
    <div
      ref={ref}
      className={cn("opal-card", padding, rounding)}
      data-background={backgroundVariant}
      data-border={borderVariant}
    >
      {children}
    </div>
  );
}
// ---------------------------------------------------------------------------
// Exports
// ---------------------------------------------------------------------------
export { Card, type CardProps, type BackgroundVariant, type BorderVariant };


@@ -0,0 +1,29 @@
.opal-card {
  @apply w-full overflow-clip;
}

/* Background variants */
.opal-card[data-background="none"] {
  @apply bg-transparent;
}

.opal-card[data-background="light"] {
  @apply bg-background-tint-00;
}

.opal-card[data-background="heavy"] {
  @apply bg-background-tint-01;
}

/* Border variants */
.opal-card[data-border="none"] {
  border: none;
}

.opal-card[data-border="dashed"] {
  @apply border border-dashed;
}

.opal-card[data-border="solid"] {
  @apply border;
}


@@ -0,0 +1,51 @@
import type { Meta, StoryObj } from "@storybook/react";
import { EmptyMessageCard } from "@opal/components";
import { SvgSparkle, SvgUsers } from "@opal/icons";
const SIZE_VARIANTS = ["lg", "md", "sm", "xs", "2xs", "fit"] as const;
const meta: Meta<typeof EmptyMessageCard> = {
title: "opal/components/EmptyMessageCard",
component: EmptyMessageCard,
tags: ["autodocs"],
};
export default meta;
type Story = StoryObj<typeof EmptyMessageCard>;
export const Default: Story = {
args: {
title: "No items available.",
},
};
export const WithCustomIcon: Story = {
args: {
icon: SvgSparkle,
title: "No agents selected.",
},
};
export const SizeVariants: Story = {
render: () => (
<div className="flex flex-col gap-4 w-96">
{SIZE_VARIANTS.map((size) => (
<EmptyMessageCard
key={size}
sizeVariant={size}
title={`sizeVariant: ${size}`}
/>
))}
</div>
),
};
export const Multiple: Story = {
render: () => (
<div className="flex flex-col gap-4 w-96">
<EmptyMessageCard title="No models available." />
<EmptyMessageCard icon={SvgSparkle} title="No agents selected." />
<EmptyMessageCard icon={SvgUsers} title="No groups added." />
</div>
),
};


@@ -0,0 +1,30 @@
# EmptyMessageCard
**Import:** `import { EmptyMessageCard, type EmptyMessageCardProps } from "@opal/components";`
A pre-configured Card for empty states. Renders a transparent card with a dashed border containing a muted icon and message text using the `Content` layout.
## Props
| Prop | Type | Default | Description |
| ------------- | -------------------------- | ---------- | ------------------------------------------------ |
| `icon` | `IconFunctionComponent` | `SvgEmpty` | Icon displayed alongside the title |
| `title` | `string` | — | Primary message text (required) |
| `sizeVariant` | `SizeVariant` | `"lg"` | Size preset controlling padding and rounding |
| `ref` | `React.Ref<HTMLDivElement>` | — | Ref forwarded to the root div |
## Usage
```tsx
import { EmptyMessageCard } from "@opal/components";
import { SvgSparkle, SvgFileText } from "@opal/icons";
// Default empty state
<EmptyMessageCard title="No items yet." />
// With custom icon
<EmptyMessageCard icon={SvgSparkle} title="No agents selected." />
// With custom size
<EmptyMessageCard sizeVariant="sm" icon={SvgFileText} title="No documents available." />
```
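Since `EmptyMessageCard` is described as a pre-configured `Card`, the preset can be pictured as plain data. The sketch below mirrors the prop values passed in the component source from this change set; `emptyMessagePreset` and `EmptyMessagePreset` are hypothetical names, and the `SizeVariant` union is assumed from the sizes exercised in the stories.

```typescript
// Illustrative sketch; `emptyMessagePreset` is a hypothetical helper.
type SizeVariant = "lg" | "md" | "sm" | "xs" | "2xs" | "fit";

interface EmptyMessagePreset {
  card: {
    backgroundVariant: "none";
    borderVariant: "dashed";
    sizeVariant: SizeVariant;
  };
  content: {
    sizePreset: "secondary";
    variant: "body";
    prominence: "muted";
  };
}

// Only the size is caller-controlled; everything else is fixed by the preset.
function emptyMessagePreset(sizeVariant: SizeVariant = "lg"): EmptyMessagePreset {
  return {
    card: { backgroundVariant: "none", borderVariant: "dashed", sizeVariant },
    content: { sizePreset: "secondary", variant: "body", prominence: "muted" },
  };
}
```

Fixing the background and border in the preset keeps every empty state visually consistent; callers that need a different look would use `Card` directly.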


@@ -0,0 +1,57 @@
import { Card } from "@opal/components/cards/card/components";
import { Content } from "@opal/layouts";
import { SvgEmpty } from "@opal/icons";
import type { SizeVariant } from "@opal/shared";
import type { IconFunctionComponent } from "@opal/types";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
type EmptyMessageCardProps = {
/** Icon displayed alongside the title. */
icon?: IconFunctionComponent;
/** Primary message text. */
title: string;
/** Size preset controlling padding and rounding of the card. */
sizeVariant?: SizeVariant;
/** Ref forwarded to the root Card div. */
ref?: React.Ref<HTMLDivElement>;
};
// ---------------------------------------------------------------------------
// EmptyMessageCard
// ---------------------------------------------------------------------------
function EmptyMessageCard({
  icon = SvgEmpty,
  title,
  sizeVariant = "lg",
  ref,
}: EmptyMessageCardProps) {
  return (
    <Card
      ref={ref}
      backgroundVariant="none"
      borderVariant="dashed"
      sizeVariant={sizeVariant}
    >
      <Content
        icon={icon}
        title={title}
        sizePreset="secondary"
        variant="body"
        prominence="muted"
      />
    </Card>
  );
}
// ---------------------------------------------------------------------------
// Exports
// ---------------------------------------------------------------------------
export { EmptyMessageCard, type EmptyMessageCardProps };


@@ -31,3 +31,17 @@ export {
type TagProps,
type TagColor,
} from "@opal/components/tag/components";
/* Card */
export {
Card,
type CardProps,
type BackgroundVariant,
type BorderVariant,
} from "@opal/components/cards/card/components";
/* EmptyMessageCard */
export {
EmptyMessageCard,
type EmptyMessageCardProps,
} from "@opal/components/cards/empty-message-card/components";


@@ -1,5 +1,5 @@
import type { Meta, StoryObj } from "@storybook/react";
import { Interactive } from "@opal/core";
import { Interactive, Disabled } from "@opal/core";
// ---------------------------------------------------------------------------
// Variant / Prominence mappings for the matrix story
@@ -9,8 +9,6 @@ const VARIANT_PROMINENCE_MAP: Record<string, string[]> = {
default: ["primary", "secondary", "tertiary", "internal"],
action: ["primary", "secondary", "tertiary", "internal"],
danger: ["primary", "secondary", "tertiary", "internal"],
select: ["light", "heavy"],
sidebar: ["light"],
none: [],
};
@@ -35,39 +33,39 @@ export default meta;
// Stories
// ---------------------------------------------------------------------------
/** Basic Interactive.Base + Container with text content. */
/** Basic Interactive.Stateless + Container with text content. */
export const Default: StoryObj = {
render: () => (
<div style={{ display: "flex", gap: "0.75rem", alignItems: "center" }}>
<Interactive.Base
<Interactive.Stateless
variant="default"
prominence="secondary"
onClick={() => {}}
>
<Interactive.Container border>
<span>Secondary</span>
<span className="interactive-foreground">Secondary</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateless>
<Interactive.Base
<Interactive.Stateless
variant="default"
prominence="primary"
onClick={() => {}}
>
<Interactive.Container border>
<span>Primary</span>
<span className="interactive-foreground">Primary</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateless>
<Interactive.Base
<Interactive.Stateless
variant="default"
prominence="tertiary"
onClick={() => {}}
>
<Interactive.Container border>
<span>Tertiary</span>
<span className="interactive-foreground">Tertiary</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateless>
</div>
),
};
@@ -91,11 +89,13 @@ export const VariantMatrix: StoryObj = {
</div>
{prominences.length === 0 ? (
<Interactive.Base variant="none" onClick={() => {}}>
<Interactive.Stateless variant="none" onClick={() => {}}>
<Interactive.Container border>
<span>none (no prominence)</span>
<span style={{ color: "var(--text-01)" }}>
none (no prominence)
</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateless>
) : (
<div style={{ display: "flex", gap: "0.5rem", flexWrap: "wrap" }}>
{prominences.map((prominence) => (
@@ -108,16 +108,18 @@ export const VariantMatrix: StoryObj = {
gap: "0.25rem",
}}
>
<Interactive.Base
<Interactive.Stateless
// Cast required because the discriminated union can't be
// resolved from dynamic strings at the type level.
{...({ variant, prominence } as any)}
onClick={() => {}}
>
<Interactive.Container border>
<span>{prominence}</span>
<span className="interactive-foreground">
{prominence}
</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateless>
<span
style={{
fontSize: "0.625rem",
@@ -141,16 +143,16 @@ export const Sizes: StoryObj = {
render: () => (
<div style={{ display: "flex", alignItems: "center", gap: "0.75rem" }}>
{SIZE_VARIANTS.map((size) => (
<Interactive.Base
<Interactive.Stateless
key={size}
variant="default"
prominence="secondary"
onClick={() => {}}
>
<Interactive.Container border heightVariant={size}>
<span>{size}</span>
<span className="interactive-foreground">{size}</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateless>
))}
</div>
),
@@ -160,15 +162,15 @@ export const Sizes: StoryObj = {
export const WidthFull: StoryObj = {
render: () => (
<div style={{ width: 400 }}>
<Interactive.Base
<Interactive.Stateless
variant="default"
prominence="secondary"
onClick={() => {}}
>
<Interactive.Container border widthVariant="full">
<span>Full width container</span>
<span className="interactive-foreground">Full width container</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateless>
</div>
),
};
@@ -178,73 +180,86 @@ export const Rounding: StoryObj = {
render: () => (
<div style={{ display: "flex", gap: "0.75rem" }}>
{ROUNDING_VARIANTS.map((rounding) => (
<Interactive.Base
<Interactive.Stateless
key={rounding}
variant="default"
prominence="secondary"
onClick={() => {}}
>
<Interactive.Container border roundingVariant={rounding}>
<span>{rounding}</span>
<span className="interactive-foreground">{rounding}</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateless>
))}
</div>
),
};
/** Disabled state prevents clicks and shows disabled styling. */
export const Disabled: StoryObj = {
export const DisabledStory: StoryObj = {
name: "Disabled",
render: () => (
<div style={{ display: "flex", gap: "0.75rem" }}>
<Interactive.Base
variant="default"
prominence="secondary"
onClick={() => {}}
disabled
>
<Interactive.Container border>
<span>Disabled</span>
</Interactive.Container>
</Interactive.Base>
<Disabled disabled>
<Interactive.Stateless
variant="default"
prominence="secondary"
onClick={() => {}}
>
<Interactive.Container border>
<span className="interactive-foreground">Disabled</span>
</Interactive.Container>
</Interactive.Stateless>
</Disabled>
<Interactive.Base
<Interactive.Stateless
variant="default"
prominence="secondary"
onClick={() => {}}
>
<Interactive.Container border>
<span>Enabled</span>
<span className="interactive-foreground">Enabled</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateless>
</div>
),
};
/** Transient prop forces the hover/active visual state. */
export const Transient: StoryObj = {
/** Interaction override forces the hover/active visual state. */
export const Interaction: StoryObj = {
render: () => (
<div style={{ display: "flex", gap: "0.75rem" }}>
<Interactive.Base
<Interactive.Stateless
variant="default"
prominence="secondary"
interaction="hover"
onClick={() => {}}
transient
>
<Interactive.Container border>
<span>Forced hover</span>
<span className="interactive-foreground">Forced hover</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateless>
<Interactive.Base
<Interactive.Stateless
variant="default"
prominence="secondary"
interaction="active"
onClick={() => {}}
>
<Interactive.Container border>
<span className="interactive-foreground">Forced active</span>
</Interactive.Container>
</Interactive.Stateless>
<Interactive.Stateless
variant="default"
prominence="secondary"
onClick={() => {}}
>
<Interactive.Container border>
<span>Normal</span>
<span className="interactive-foreground">Normal (rest)</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateless>
</div>
),
};
@@ -253,25 +268,25 @@ export const Transient: StoryObj = {
export const WithBorder: StoryObj = {
render: () => (
<div style={{ display: "flex", gap: "0.75rem" }}>
<Interactive.Base
<Interactive.Stateless
variant="default"
prominence="secondary"
onClick={() => {}}
>
<Interactive.Container border>
<span>With border</span>
<span className="interactive-foreground">With border</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateless>
<Interactive.Base
<Interactive.Stateless
variant="default"
prominence="secondary"
onClick={() => {}}
>
<Interactive.Container>
<span>Without border</span>
<span className="interactive-foreground">Without border</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateless>
</div>
),
};
@@ -279,51 +294,57 @@ export const WithBorder: StoryObj = {
/** Using href to render as a link. */
export const AsLink: StoryObj = {
render: () => (
<Interactive.Base variant="action" href="/settings">
<Interactive.Stateless variant="action" href="/settings">
<Interactive.Container border>
<span>Go to Settings</span>
<span className="interactive-foreground">Go to Settings</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateless>
),
};
/** Select variant with selected and unselected states. */
/** Stateful select variant with selected and unselected states. */
export const SelectVariant: StoryObj = {
render: () => (
<div style={{ display: "flex", gap: "0.75rem" }}>
<Interactive.Base
variant="select"
prominence="light"
selected
<Interactive.Stateful
variant="select-light"
state="selected"
onClick={() => {}}
>
<Interactive.Container border>
<span>Selected (light)</span>
<span className="interactive-foreground">Selected (light)</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateful>
<Interactive.Base variant="select" prominence="light" onClick={() => {}}>
<Interactive.Container border>
<span>Unselected (light)</span>
</Interactive.Container>
</Interactive.Base>
<Interactive.Base
variant="select"
prominence="heavy"
selected
<Interactive.Stateful
variant="select-light"
state="empty"
onClick={() => {}}
>
<Interactive.Container border>
<span>Selected (heavy)</span>
<span className="interactive-foreground">Unselected (light)</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateful>
<Interactive.Base variant="select" prominence="heavy" onClick={() => {}}>
<Interactive.Stateful
variant="select-heavy"
state="selected"
onClick={() => {}}
>
<Interactive.Container border>
<span>Unselected (heavy)</span>
<span className="interactive-foreground">Selected (heavy)</span>
</Interactive.Container>
</Interactive.Base>
</Interactive.Stateful>
<Interactive.Stateful
variant="select-heavy"
state="empty"
onClick={() => {}}
>
<Interactive.Container border>
<span className="interactive-foreground">Unselected (heavy)</span>
</Interactive.Container>
</Interactive.Stateful>
</div>
),
};


@@ -0,0 +1,20 @@
import type { IconProps } from "@opal/types";
const SvgAudio = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M2 10V6M5 14V2M11 11V5M14 9V7M8 10V6"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgAudio;


@@ -17,6 +17,7 @@ export { default as SvgArrowUpDown } from "@opal/icons/arrow-up-down";
export { default as SvgArrowUpDot } from "@opal/icons/arrow-up-dot";
export { default as SvgArrowUpRight } from "@opal/icons/arrow-up-right";
export { default as SvgArrowWallRight } from "@opal/icons/arrow-wall-right";
export { default as SvgAudio } from "@opal/icons/audio";
export { default as SvgAudioEqSmall } from "@opal/icons/audio-eq-small";
export { default as SvgAws } from "@opal/icons/aws";
export { default as SvgAzure } from "@opal/icons/azure";
@@ -89,7 +90,7 @@ export { default as SvgHistory } from "@opal/icons/history";
export { default as SvgHourglass } from "@opal/icons/hourglass";
export { default as SvgImage } from "@opal/icons/image";
export { default as SvgImageSmall } from "@opal/icons/image-small";
export { default as SvgImport } from "@opal/icons/import";
export { default as SvgImport } from "@opal/icons/import-icon";
export { default as SvgInfo } from "@opal/icons/info";
export { default as SvgInfoSmall } from "@opal/icons/info-small";
export { default as SvgKey } from "@opal/icons/key";
@@ -106,6 +107,8 @@ export { default as SvgLogOut } from "@opal/icons/log-out";
export { default as SvgMaximize2 } from "@opal/icons/maximize-2";
export { default as SvgMcp } from "@opal/icons/mcp";
export { default as SvgMenu } from "@opal/icons/menu";
export { default as SvgMicrophone } from "@opal/icons/microphone";
export { default as SvgMicrophoneOff } from "@opal/icons/microphone-off";
export { default as SvgMinus } from "@opal/icons/minus";
export { default as SvgMinusCircle } from "@opal/icons/minus-circle";
export { default as SvgMoon } from "@opal/icons/moon";
@@ -176,6 +179,8 @@ export { default as SvgUserManage } from "@opal/icons/user-manage";
export { default as SvgUserPlus } from "@opal/icons/user-plus";
export { default as SvgUserSync } from "@opal/icons/user-sync";
export { default as SvgUsers } from "@opal/icons/users";
export { default as SvgVolume } from "@opal/icons/volume";
export { default as SvgVolumeOff } from "@opal/icons/volume-off";
export { default as SvgWallet } from "@opal/icons/wallet";
export { default as SvgWorkflow } from "@opal/icons/workflow";
export { default as SvgX } from "@opal/icons/x";


@@ -0,0 +1,29 @@
import type { IconProps } from "@opal/types";
const SvgMicrophoneOff = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
{/* Microphone body */}
<path
d="M12.5 7V7.5C12.5 9.98528 10.4853 12 8 12M3.5 7V7.5C3.5 9.98528 5.51472 12 8 12M8 12V14.5M8 14.5H5M8 14.5H11M8 9.5C6.89543 9.5 6 8.60457 6 7.5V3.5C6 2.39543 6.89543 1.5 8 1.5C9.10457 1.5 10 2.39543 10 3.5V7.5C10 8.60457 9.10457 9.5 8 9.5Z"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
{/* Diagonal slash */}
<path
d="M2 2L14 14"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgMicrophoneOff;


@@ -0,0 +1,21 @@
import type { IconProps } from "@opal/types";
const SvgMicrophone = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M12.5 7V7.5C12.5 9.98528 10.4853 12 8 12M3.5 7V7.5C3.5 9.98528 5.51472 12 8 12M8 12V14.5M8 14.5H5M8 14.5H11M8 9.5C6.89543 9.5 6 8.60457 6 7.5V3.5C6 2.39543 6.89543 1.5 8 1.5C9.10457 1.5 10 2.39543 10 3.5V7.5C10 8.60457 9.10457 9.5 8 9.5Z"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgMicrophone;


@@ -0,0 +1,26 @@
import type { IconProps } from "@opal/types";
const SvgVolumeOff = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M2 6V10H5L9 13V3L5 6H2Z"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M14 6L11 9M11 6L14 9"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgVolumeOff;


@@ -0,0 +1,26 @@
import type { IconProps } from "@opal/types";
const SvgVolume = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M2 6V10H5L9 13V3L5 6H2Z"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M11.5 5.5C12.3 6.3 12.8 7.4 12.8 8.5C12.8 9.6 12.3 10.7 11.5 11.5"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgVolume;


@@ -1,87 +0,0 @@
import type { Meta, StoryObj } from "@storybook/react";
import { BodyLayout } from "./BodyLayout";
import { SvgSettings, SvgStar, SvgRefreshCw } from "@opal/icons";
const meta = {
title: "Layouts/BodyLayout",
component: BodyLayout,
tags: ["autodocs"],
parameters: {
layout: "centered",
},
} satisfies Meta<typeof BodyLayout>;
export default meta;
type Story = StoryObj<typeof meta>;
// ---------------------------------------------------------------------------
// Size presets
// ---------------------------------------------------------------------------
export const MainContent: Story = {
args: {
sizePreset: "main-content",
title: "Last synced 2 minutes ago",
},
};
export const MainUi: Story = {
args: {
sizePreset: "main-ui",
title: "Document count: 1,234",
},
};
export const Secondary: Story = {
args: {
sizePreset: "secondary",
title: "Updated 5 min ago",
},
};
// ---------------------------------------------------------------------------
// With icon
// ---------------------------------------------------------------------------
export const WithIcon: Story = {
args: {
sizePreset: "main-ui",
title: "Settings",
icon: SvgSettings,
},
};
// ---------------------------------------------------------------------------
// Orientations
// ---------------------------------------------------------------------------
export const Vertical: Story = {
args: {
sizePreset: "main-ui",
title: "Stacked layout",
icon: SvgStar,
orientation: "vertical",
},
};
export const Reverse: Story = {
args: {
sizePreset: "main-ui",
title: "Reverse layout",
icon: SvgRefreshCw,
orientation: "reverse",
},
};
// ---------------------------------------------------------------------------
// Prominence
// ---------------------------------------------------------------------------
export const Muted: Story = {
args: {
sizePreset: "main-ui",
title: "Muted body text",
prominence: "muted",
},
};

@@ -1,98 +0,0 @@
import type { Meta, StoryObj } from "@storybook/react";
import { HeadingLayout } from "./HeadingLayout";
import { SvgSettings, SvgStar } from "@opal/icons";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
const meta = {
title: "Layouts/HeadingLayout",
component: HeadingLayout,
tags: ["autodocs"],
parameters: {
layout: "centered",
},
decorators: [
(Story) => (
<TooltipPrimitive.Provider>
<Story />
</TooltipPrimitive.Provider>
),
],
} satisfies Meta<typeof HeadingLayout>;
export default meta;
type Story = StoryObj<typeof meta>;
// ---------------------------------------------------------------------------
// Size presets
// ---------------------------------------------------------------------------
export const Headline: Story = {
args: {
sizePreset: "headline",
title: "Welcome to Onyx",
description: "Your enterprise search and AI assistant platform.",
},
};
export const Section: Story = {
args: {
sizePreset: "section",
title: "Configuration",
},
};
// ---------------------------------------------------------------------------
// With icon
// ---------------------------------------------------------------------------
export const WithIcon: Story = {
args: {
sizePreset: "headline",
title: "Settings",
icon: SvgSettings,
},
};
export const SectionWithIcon: Story = {
args: {
sizePreset: "section",
variant: "section",
title: "Favorites",
icon: SvgStar,
},
};
// ---------------------------------------------------------------------------
// Variants
// ---------------------------------------------------------------------------
export const SectionVariant: Story = {
args: {
sizePreset: "headline",
variant: "section",
title: "Inline Icon Heading",
icon: SvgSettings,
},
};
// ---------------------------------------------------------------------------
// Editable
// ---------------------------------------------------------------------------
export const Editable: Story = {
args: {
sizePreset: "headline",
title: "Click to edit me",
editable: true,
},
};
export const EditableSection: Story = {
args: {
sizePreset: "section",
title: "Editable Section Title",
editable: true,
description: "This title can be edited inline.",
},
};

@@ -1,154 +0,0 @@
import type { Meta, StoryObj } from "@storybook/react";
import { LabelLayout } from "./LabelLayout";
import { SvgSettings, SvgStar } from "@opal/icons";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
const meta = {
title: "Layouts/LabelLayout",
component: LabelLayout,
tags: ["autodocs"],
parameters: {
layout: "centered",
},
decorators: [
(Story) => (
<TooltipPrimitive.Provider>
<Story />
</TooltipPrimitive.Provider>
),
],
} satisfies Meta<typeof LabelLayout>;
export default meta;
type Story = StoryObj<typeof meta>;
// ---------------------------------------------------------------------------
// Size presets
// ---------------------------------------------------------------------------
export const MainContent: Story = {
args: {
sizePreset: "main-content",
title: "Display Name",
},
};
export const MainUi: Story = {
args: {
sizePreset: "main-ui",
title: "Email Address",
},
};
export const SecondaryPreset: Story = {
args: {
sizePreset: "secondary",
title: "API Key",
},
};
// ---------------------------------------------------------------------------
// With description
// ---------------------------------------------------------------------------
export const WithDescription: Story = {
args: {
sizePreset: "main-content",
title: "Workspace Name",
description: "The name displayed across your organization.",
},
};
// ---------------------------------------------------------------------------
// With icon
// ---------------------------------------------------------------------------
export const WithIcon: Story = {
args: {
sizePreset: "main-ui",
title: "Settings",
icon: SvgSettings,
},
};
// ---------------------------------------------------------------------------
// Optional
// ---------------------------------------------------------------------------
export const Optional: Story = {
args: {
sizePreset: "main-content",
title: "Phone Number",
optional: true,
},
};
// ---------------------------------------------------------------------------
// Aux icons
// ---------------------------------------------------------------------------
export const AuxInfoGray: Story = {
args: {
sizePreset: "main-content",
title: "Connection Status",
auxIcon: "info-gray",
},
};
export const AuxWarning: Story = {
args: {
sizePreset: "main-content",
title: "Rate Limit",
auxIcon: "warning",
},
};
export const AuxError: Story = {
args: {
sizePreset: "main-content",
title: "API Key",
auxIcon: "error",
},
};
// ---------------------------------------------------------------------------
// With tag
// ---------------------------------------------------------------------------
export const WithTag: Story = {
args: {
sizePreset: "main-ui",
title: "Knowledge Graph",
tag: { title: "Beta", color: "blue" },
},
};
// ---------------------------------------------------------------------------
// Editable
// ---------------------------------------------------------------------------
export const Editable: Story = {
args: {
sizePreset: "main-ui",
title: "Click to edit",
editable: true,
},
};
// ---------------------------------------------------------------------------
// Combined
// ---------------------------------------------------------------------------
export const FullFeatured: Story = {
args: {
sizePreset: "main-content",
title: "Custom Field",
icon: SvgStar,
description: "A custom field with all extras enabled.",
optional: true,
auxIcon: "info-blue",
tag: { title: "New", color: "green" },
editable: true,
},
};

@@ -1,134 +0,0 @@
"use client";
import type { IconFunctionComponent } from "@opal/types";
import { cn } from "@opal/utils";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
type BodySizePreset = "main-content" | "main-ui" | "secondary";
type BodyOrientation = "vertical" | "inline" | "reverse";
type BodyProminence = "default" | "muted";
interface BodyPresetConfig {
/** Icon width/height (CSS value). */
iconSize: string;
/** Tailwind padding class for the icon container. */
iconContainerPadding: string;
/** Tailwind font class for the title. */
titleFont: string;
/** Title line-height — also used as icon container min-height (CSS value). */
lineHeight: string;
/** Gap between icon container and title (CSS value). */
gap: string;
}
/** Props for {@link BodyLayout}. Does not support editing or descriptions. */
interface BodyLayoutProps {
/** Optional icon component. */
icon?: IconFunctionComponent;
/** Main title text (read-only — editing is not supported). */
title: string;
/** Size preset. Default: `"main-ui"`. */
sizePreset?: BodySizePreset;
/** Layout orientation. Default: `"inline"`. */
orientation?: BodyOrientation;
/** Title prominence. Default: `"default"`. */
prominence?: BodyProminence;
/** Ref forwarded to the root `<div>`. */
ref?: React.Ref<HTMLDivElement>;
}
// ---------------------------------------------------------------------------
// Presets
// ---------------------------------------------------------------------------
const BODY_PRESETS: Record<BodySizePreset, BodyPresetConfig> = {
"main-content": {
iconSize: "1rem",
iconContainerPadding: "p-1",
titleFont: "font-main-content-body",
lineHeight: "1.5rem",
gap: "0.125rem",
},
"main-ui": {
iconSize: "1rem",
iconContainerPadding: "p-0.5",
titleFont: "font-main-ui-action",
lineHeight: "1.25rem",
gap: "0.25rem",
},
secondary: {
iconSize: "0.75rem",
iconContainerPadding: "p-0.5",
titleFont: "font-secondary-action",
lineHeight: "1rem",
gap: "0.125rem",
},
};
// ---------------------------------------------------------------------------
// BodyLayout
// ---------------------------------------------------------------------------
function BodyLayout({
icon: Icon,
title,
sizePreset = "main-ui",
orientation = "inline",
prominence = "default",
ref,
}: BodyLayoutProps) {
const config = BODY_PRESETS[sizePreset];
const titleColorClass =
prominence === "muted" ? "text-text-03" : "text-text-04";
return (
<div
ref={ref}
className="opal-content-body"
data-orientation={orientation}
style={{ gap: config.gap }}
>
{Icon && (
<div
className={cn(
"opal-content-body-icon-container shrink-0",
config.iconContainerPadding
)}
style={{ minHeight: config.lineHeight }}
>
<Icon
className="opal-content-body-icon text-text-03"
style={{ width: config.iconSize, height: config.iconSize }}
/>
</div>
)}
<span
className={cn(
"opal-content-body-title",
config.titleFont,
titleColorClass
)}
style={{ height: config.lineHeight }}
>
{title}
</span>
</div>
);
}
export {
BodyLayout,
type BodyLayoutProps,
type BodySizePreset,
type BodyOrientation,
type BodyProminence,
};

@@ -1,218 +0,0 @@
"use client";
import { Button } from "@opal/components/buttons/button/components";
import type { SizeVariant } from "@opal/shared";
import SvgEdit from "@opal/icons/edit";
import type { IconFunctionComponent } from "@opal/types";
import { cn } from "@opal/utils";
import { useRef, useState } from "react";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
type HeadingSizePreset = "headline" | "section";
type HeadingVariant = "heading" | "section";
interface HeadingPresetConfig {
/** Icon width/height (CSS value). */
iconSize: string;
/** Tailwind padding class for the icon container. */
iconContainerPadding: string;
/** Gap between icon container and content (CSS value). */
gap: string;
/** Tailwind font class for the title. */
titleFont: string;
/** Title line-height — also used as icon container min-height (CSS value). */
lineHeight: string;
/** Button `size` prop for the edit button. Uses the shared `SizeVariant` scale. */
editButtonSize: SizeVariant;
/** Tailwind padding class for the edit button container. */
editButtonPadding: string;
}
interface HeadingLayoutProps {
/** Optional icon component. */
icon?: IconFunctionComponent;
/** Main title text. */
title: string;
/** Optional description below the title. */
description?: string;
/** Enable inline editing of the title. */
editable?: boolean;
/** Called when the user commits an edit. */
onTitleChange?: (newTitle: string) => void;
/** Size preset. Default: `"headline"`. */
sizePreset?: HeadingSizePreset;
/** Variant controls icon placement. `"heading"` = top, `"section"` = inline. Default: `"heading"`. */
variant?: HeadingVariant;
/** Ref forwarded to the root `<div>`. */
ref?: React.Ref<HTMLDivElement>;
}
// ---------------------------------------------------------------------------
// Presets
// ---------------------------------------------------------------------------
const HEADING_PRESETS: Record<HeadingSizePreset, HeadingPresetConfig> = {
headline: {
iconSize: "2rem",
iconContainerPadding: "p-0.5",
gap: "0.25rem",
titleFont: "font-heading-h2",
lineHeight: "2.25rem",
editButtonSize: "md",
editButtonPadding: "p-1",
},
section: {
iconSize: "1.25rem",
iconContainerPadding: "p-1",
gap: "0rem",
titleFont: "font-heading-h3",
lineHeight: "1.75rem",
editButtonSize: "sm",
editButtonPadding: "p-0.5",
},
};
// ---------------------------------------------------------------------------
// HeadingLayout
// ---------------------------------------------------------------------------
function HeadingLayout({
sizePreset = "headline",
variant = "heading",
icon: Icon,
title,
description,
editable,
onTitleChange,
ref,
}: HeadingLayoutProps) {
const [editing, setEditing] = useState(false);
const [editValue, setEditValue] = useState(title);
const inputRef = useRef<HTMLInputElement>(null);
const config = HEADING_PRESETS[sizePreset];
const iconPlacement = variant === "heading" ? "top" : "left";
function startEditing() {
setEditValue(title);
setEditing(true);
}
function commit() {
const value = editValue.trim();
if (value && value !== title) onTitleChange?.(value);
setEditing(false);
}
return (
<div
ref={ref}
className="opal-content-heading"
data-icon-placement={iconPlacement}
style={{ gap: iconPlacement === "left" ? config.gap : undefined }}
>
{Icon && (
<div
className={cn(
"opal-content-heading-icon-container shrink-0",
config.iconContainerPadding
)}
style={{ minHeight: config.lineHeight }}
>
<Icon
className="opal-content-heading-icon"
style={{ width: config.iconSize, height: config.iconSize }}
/>
</div>
)}
<div className="opal-content-heading-body">
<div className="opal-content-heading-title-row">
{editing ? (
<div className="opal-content-heading-input-sizer">
<span
className={cn(
"opal-content-heading-input-mirror",
config.titleFont
)}
>
{editValue || "\u00A0"}
</span>
<input
ref={inputRef}
className={cn(
"opal-content-heading-input",
config.titleFont,
"text-text-04"
)}
value={editValue}
onChange={(e) => setEditValue(e.target.value)}
size={1}
autoFocus
onFocus={(e) => e.currentTarget.select()}
onBlur={commit}
onKeyDown={(e) => {
if (e.key === "Enter") commit();
if (e.key === "Escape") {
setEditValue(title);
setEditing(false);
}
}}
style={{ height: config.lineHeight }}
/>
</div>
) : (
<span
className={cn(
"opal-content-heading-title",
config.titleFont,
"text-text-04",
editable && "cursor-pointer"
)}
onClick={editable ? startEditing : undefined}
style={{ height: config.lineHeight }}
>
{title}
</span>
)}
{editable && !editing && (
<div
className={cn(
"opal-content-heading-edit-button",
config.editButtonPadding
)}
>
<Button
icon={SvgEdit}
prominence="internal"
size={config.editButtonSize}
tooltip="Edit"
tooltipSide="right"
onClick={startEditing}
/>
</div>
)}
</div>
{description && (
<div className="opal-content-heading-description font-secondary-body text-text-03">
{description}
</div>
)}
</div>
</div>
);
}
export { HeadingLayout, type HeadingLayoutProps, type HeadingSizePreset };
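Both `HeadingLayout` and `LabelLayout` share the same commit rule for inline edits: trim the draft, and only report a change when the result is non-empty and differs from the current title. A minimal sketch of that rule as a pure function follows; `resolveTitleEdit` is a hypothetical name used for illustration and does not appear in the diff.

```typescript
// Hypothetical helper mirroring the `commit()` logic in the layouts above.
// Returns the new title to report via onTitleChange, or null when the
// edit should be discarded (blank draft or unchanged title).
function resolveTitleEdit(
  currentTitle: string,
  editValue: string
): string | null {
  const value = editValue.trim();
  if (value === "" || value === currentTitle) return null;
  return value;
}

// Usage: a whitespace-padded draft is trimmed before it is reported.
const next = resolveTitleEdit("Settings", "  Preferences ");
if (next !== null) {
  console.log(`title changed to "${next}"`);
}
```

Keeping the rule in a pure function like this would make it testable without rendering the component, though the original components simply inline it.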

@@ -1,286 +0,0 @@
"use client";
import { Button } from "@opal/components/buttons/button/components";
import { Tag, type TagProps } from "@opal/components/tag/components";
import type { SizeVariant } from "@opal/shared";
import SvgAlertCircle from "@opal/icons/alert-circle";
import SvgAlertTriangle from "@opal/icons/alert-triangle";
import SvgEdit from "@opal/icons/edit";
import SvgXOctagon from "@opal/icons/x-octagon";
import type { IconFunctionComponent } from "@opal/types";
import { cn } from "@opal/utils";
import { useRef, useState } from "react";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
type LabelSizePreset = "main-content" | "main-ui" | "secondary";
type LabelAuxIcon = "info-gray" | "info-blue" | "warning" | "error";
interface LabelPresetConfig {
iconSize: string;
iconContainerPadding: string;
iconColorClass: string;
titleFont: string;
lineHeight: string;
gap: string;
/** Button `size` prop for the edit button. Uses the shared `SizeVariant` scale. */
editButtonSize: SizeVariant;
editButtonPadding: string;
optionalFont: string;
/** Aux icon size = lineHeight - 2 × p-0.5. */
auxIconSize: string;
}
interface LabelLayoutProps {
/** Optional icon component. */
icon?: IconFunctionComponent;
/** Main title text. */
title: string;
/** Optional description text below the title. */
description?: string;
/** Enable inline editing of the title. */
editable?: boolean;
/** Called when the user commits an edit. */
onTitleChange?: (newTitle: string) => void;
/** When `true`, renders "(Optional)" beside the title. */
optional?: boolean;
/** Auxiliary status icon rendered beside the title. */
auxIcon?: LabelAuxIcon;
/** Tag rendered beside the title. */
tag?: TagProps;
/** Size preset. Default: `"main-ui"`. */
sizePreset?: LabelSizePreset;
/** Ref forwarded to the root `<div>`. */
ref?: React.Ref<HTMLDivElement>;
}
// ---------------------------------------------------------------------------
// Presets
// ---------------------------------------------------------------------------
const LABEL_PRESETS: Record<LabelSizePreset, LabelPresetConfig> = {
"main-content": {
iconSize: "1rem",
iconContainerPadding: "p-1",
iconColorClass: "text-text-04",
titleFont: "font-main-content-emphasis",
lineHeight: "1.5rem",
gap: "0.125rem",
editButtonSize: "sm",
editButtonPadding: "p-0",
optionalFont: "font-main-content-muted",
auxIconSize: "1.25rem",
},
"main-ui": {
iconSize: "1rem",
iconContainerPadding: "p-0.5",
iconColorClass: "text-text-03",
titleFont: "font-main-ui-action",
lineHeight: "1.25rem",
gap: "0.25rem",
editButtonSize: "xs",
editButtonPadding: "p-0",
optionalFont: "font-main-ui-muted",
auxIconSize: "1rem",
},
secondary: {
iconSize: "0.75rem",
iconContainerPadding: "p-0.5",
iconColorClass: "text-text-04",
titleFont: "font-secondary-action",
lineHeight: "1rem",
gap: "0.125rem",
editButtonSize: "2xs",
editButtonPadding: "p-0",
optionalFont: "font-secondary-action",
auxIconSize: "0.75rem",
},
};
// ---------------------------------------------------------------------------
// LabelLayout
// ---------------------------------------------------------------------------
const AUX_ICON_CONFIG: Record<
LabelAuxIcon,
{ icon: IconFunctionComponent; colorClass: string }
> = {
"info-gray": { icon: SvgAlertCircle, colorClass: "text-text-02" },
"info-blue": { icon: SvgAlertCircle, colorClass: "text-status-info-05" },
warning: { icon: SvgAlertTriangle, colorClass: "text-status-warning-05" },
error: { icon: SvgXOctagon, colorClass: "text-status-error-05" },
};
function LabelLayout({
icon: Icon,
title,
description,
editable,
onTitleChange,
optional,
auxIcon,
tag,
sizePreset = "main-ui",
ref,
}: LabelLayoutProps) {
const [editing, setEditing] = useState(false);
const [editValue, setEditValue] = useState(title);
const inputRef = useRef<HTMLInputElement>(null);
const config = LABEL_PRESETS[sizePreset];
function startEditing() {
setEditValue(title);
setEditing(true);
}
function commit() {
const value = editValue.trim();
if (value && value !== title) onTitleChange?.(value);
setEditing(false);
}
return (
<div ref={ref} className="opal-content-label" style={{ gap: config.gap }}>
{Icon && (
<div
className={cn(
"opal-content-label-icon-container shrink-0",
config.iconContainerPadding
)}
style={{ minHeight: config.lineHeight }}
>
<Icon
className={cn("opal-content-label-icon", config.iconColorClass)}
style={{ width: config.iconSize, height: config.iconSize }}
/>
</div>
)}
<div className="opal-content-label-body">
<div className="opal-content-label-title-row">
{editing ? (
<div className="opal-content-label-input-sizer">
<span
className={cn(
"opal-content-label-input-mirror",
config.titleFont
)}
>
{editValue || "\u00A0"}
</span>
<input
ref={inputRef}
className={cn(
"opal-content-label-input",
config.titleFont,
"text-text-04"
)}
value={editValue}
onChange={(e) => setEditValue(e.target.value)}
size={1}
autoFocus
onFocus={(e) => e.currentTarget.select()}
onBlur={commit}
onKeyDown={(e) => {
if (e.key === "Enter") commit();
if (e.key === "Escape") {
setEditValue(title);
setEditing(false);
}
}}
style={{ height: config.lineHeight }}
/>
</div>
) : (
<span
className={cn(
"opal-content-label-title",
config.titleFont,
"text-text-04",
editable && "cursor-pointer"
)}
onClick={editable ? startEditing : undefined}
style={{ height: config.lineHeight }}
>
{title}
</span>
)}
{optional && (
<span
className={cn(config.optionalFont, "text-text-03 shrink-0")}
style={{ height: config.lineHeight }}
>
(Optional)
</span>
)}
{auxIcon &&
(() => {
const { icon: AuxIcon, colorClass } = AUX_ICON_CONFIG[auxIcon];
return (
<div
className="opal-content-label-aux-icon shrink-0 p-0.5"
style={{ height: config.lineHeight }}
>
<AuxIcon
className={colorClass}
style={{
width: config.auxIconSize,
height: config.auxIconSize,
}}
/>
</div>
);
})()}
{tag && <Tag {...tag} />}
{editable && !editing && (
<div
className={cn(
"opal-content-label-edit-button",
config.editButtonPadding
)}
>
<Button
icon={SvgEdit}
prominence="internal"
size={config.editButtonSize}
tooltip="Edit"
tooltipSide="right"
onClick={startEditing}
/>
</div>
)}
</div>
{description && (
<div className="opal-content-label-description font-secondary-body text-text-03">
{description}
</div>
)}
</div>
</div>
);
}
export {
LabelLayout,
type LabelLayoutProps,
type LabelSizePreset,
type LabelAuxIcon,
};

@@ -59,7 +59,7 @@ const nextConfig = {
{
key: "Permissions-Policy",
value:
-"accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), cross-origin-isolated=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), navigation-override=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()",
+"accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), cross-origin-isolated=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(self), midi=(), navigation-override=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()",
},
],
},
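The only change in this hunk is `microphone=()` becoming `microphone=(self)`: an empty allowlist disables the feature everywhere, while `self` allows it for same-origin documents, which voice recording presumably needs. As a rough illustration, the sketch below parses a header value into per-feature allowlists and compares the two states. `parsePermissionsPolicy` and `allowsSelf` are hypothetical helpers, not part of the PR, and the parser is deliberately simplified (it splits on commas and does not handle quoted origins that contain commas).

```typescript
// Hypothetical, simplified parser for a Permissions-Policy header value.
// Maps each feature name to its allowlist tokens: "()" -> [], "(self)" -> ["self"].
function parsePermissionsPolicy(header: string): Map<string, string[]> {
  const policy = new Map<string, string[]>();
  for (const directive of header.split(",")) {
    const trimmed = directive.trim();
    const eq = trimmed.indexOf("=");
    if (eq === -1) continue; // skip empty or malformed entries
    const feature = trimmed.slice(0, eq).trim();
    const raw = trimmed.slice(eq + 1).trim();
    // Strip the surrounding parentheses of an inner list, if present.
    const inner =
      raw.startsWith("(") && raw.endsWith(")") ? raw.slice(1, -1) : raw;
    policy.set(
      feature,
      inner.split(/\s+/).filter((token) => token.length > 0)
    );
  }
  return policy;
}

// True/false when the header lists the feature; undefined when it does not
// (the browser's default allowlist for that feature would apply instead).
function allowsSelf(
  policy: Map<string, string[]>,
  feature: string
): boolean | undefined {
  const allowlist = policy.get(feature);
  if (allowlist === undefined) return undefined;
  return allowlist.includes("self") || allowlist.includes("*");
}

const before = parsePermissionsPolicy("camera=(), microphone=()");
const after = parsePermissionsPolicy("camera=(), microphone=(self)");
console.log(allowsSelf(before, "microphone"), allowsSelf(after, "microphone"));
```

Every other feature in the header keeps its empty allowlist, so the change only opens up microphone access for same-origin documents.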

@@ -0,0 +1,4 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M10.5 2H13V14H10.5V2Z" fill="currentColor"/>
<path d="M3 2H5.5V14H3V2Z" fill="currentColor"/>
</svg>

Some files were not shown because too many files have changed in this diff.