Compare commits


17 Commits

Author SHA1 Message Date
Jessica Singh
bba77749c3 fix: position recording bar above input on subsequent turns 2026-03-05 14:33:23 -08:00
Jessica Singh
3e9a66c8ff chore: add @types/sbd for TypeScript support 2026-03-05 14:25:12 -08:00
Jessica Singh
548b9d9e0e fix: remove unused type ignore in azure.py 2026-03-05 14:13:42 -08:00
Jessica Singh
0d3967baee mypy 2026-03-05 13:50:27 -08:00
Jessica Singh
6ed806eebb migration 2026-03-05 09:32:35 -08:00
Jessica Singh
3b6a35b2c4 recording bar + bug fixes 2026-03-04 21:55:42 -08:00
Jessica Singh
62e612f85f Merge branch 'main' into voice-mode 2026-03-04 21:25:37 -08:00
Jessica Singh
b375b7f0ff azure 2026-03-04 20:14:18 -08:00
Jessica Singh
c158ae2622 remove logs 2026-03-04 17:02:44 -08:00
Jessica Singh
698494626f eleven labs and bug fixes 2026-03-04 16:40:26 -08:00
Jessica Singh
93cefe7ef0 chore: trigger Greptile review 2026-03-03 23:04:53 -08:00
Jessica Singh
8a326c4089 address greptile review feedback (greploop iteration 2)
- Narrow WebSocket auth bypass to only voice endpoints in auth_check.py
- Add query param validation (max_length, ge/le) for TTS synthesize endpoint
- Fix ObjectURL memory leak in useVoicePlayback.ts

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-03-03 22:39:12 -08:00
Jessica Singh
0c5410f429 address greptile review feedback (greploop iteration 1)
- Add WebSocket authentication to /voice/transcribe/stream and /voice/synthesize/stream endpoints
- Fix useVoicePlayback.ts to use query params instead of JSON body (matches API signature)
- Fix delete_voice_provider to use flush() instead of commit() for consistency
- Disable Azure streaming STT until audio resampling is implemented
- Add SSML escaping to prevent injection in Azure TTS
- Remove debug console.log statements from voice components
- Fix blob URL memory leak in VoiceModeProvider

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-03-03 22:31:59 -08:00
Jessica Singh
0b05b9b235 fix(voice): move error toast to useEffect 2026-03-03 17:26:20 -08:00
Jessica Singh
59d8a988bd streaming tts 2026-03-03 03:37:18 -08:00
Jessica Singh
6d08cfb25a all changes 2026-03-02 13:16:39 -08:00
Jessica Singh
53a5ee2a6e stt and tts 2026-02-23 18:27:37 -08:00
57 changed files with 8088 additions and 307 deletions

View File

@@ -0,0 +1,119 @@
"""add_voice_provider_and_user_voice_prefs
Revision ID: 93a2e195e25c
Revises: a3b8d9e2f1c4
Create Date: 2026-02-23 15:16:39.507304
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "93a2e195e25c"
down_revision = "a3b8d9e2f1c4"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create voice_provider table
op.create_table(
"voice_provider",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("name", sa.String(), unique=True, nullable=False),
sa.Column("provider_type", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("api_base", sa.String(), nullable=True),
sa.Column("custom_config", postgresql.JSONB(), nullable=True),
sa.Column("stt_model", sa.String(), nullable=True),
sa.Column("tts_model", sa.String(), nullable=True),
sa.Column("default_voice", sa.String(), nullable=True),
sa.Column(
"is_default_stt", sa.Boolean(), nullable=False, server_default="false"
),
sa.Column(
"is_default_tts", sa.Boolean(), nullable=False, server_default="false"
),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
onupdate=sa.func.now(),
nullable=False,
),
)
# Add partial unique indexes to enforce only one default STT/TTS provider
op.execute(
"""
CREATE UNIQUE INDEX ix_voice_provider_one_default_stt
ON voice_provider (is_default_stt)
WHERE is_default_stt = true
"""
)
op.execute(
"""
CREATE UNIQUE INDEX ix_voice_provider_one_default_tts
ON voice_provider (is_default_tts)
WHERE is_default_tts = true
"""
)
# Add voice preference columns to user table
op.add_column(
"user",
sa.Column(
"voice_auto_send",
sa.Boolean(),
default=False,
nullable=False,
server_default="false",
),
)
op.add_column(
"user",
sa.Column(
"voice_auto_playback",
sa.Boolean(),
default=False,
nullable=False,
server_default="false",
),
)
op.add_column(
"user",
sa.Column(
"voice_playback_speed",
sa.Float(),
default=1.0,
nullable=False,
server_default="1.0",
),
)
op.add_column(
"user",
sa.Column("preferred_voice", sa.String(), nullable=True),
)
def downgrade() -> None:
# Remove user voice preference columns
op.drop_column("user", "preferred_voice")
op.drop_column("user", "voice_playback_speed")
op.drop_column("user", "voice_auto_playback")
op.drop_column("user", "voice_auto_send")
op.execute("DROP INDEX IF EXISTS ix_voice_provider_one_default_tts")
op.execute("DROP INDEX IF EXISTS ix_voice_provider_one_default_stt")
# Drop voice_provider table
op.drop_table("voice_provider")
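
The partial unique indexes above are what enforce "at most one default STT provider and one default TTS provider" at the database level: uniqueness only applies to rows where the flag is true, so any number of non-default rows can coexist. A minimal sketch of the mechanism, using SQLite's partial-index support as a stand-in for Postgres (table and index names here are illustrative, not the real schema):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE voice_provider ("
        "id INTEGER PRIMARY KEY, name TEXT, is_default_stt BOOLEAN NOT NULL DEFAULT 0)"
    )
    # Unique only over rows WHERE is_default_stt is true -> at most one default
    conn.execute(
        "CREATE UNIQUE INDEX ix_one_default_stt ON voice_provider (is_default_stt) "
        "WHERE is_default_stt = 1"
    )
    conn.execute("INSERT INTO voice_provider (name, is_default_stt) VALUES ('a', 1)")
    conn.execute("INSERT INTO voice_provider (name, is_default_stt) VALUES ('b', 0)")  # fine
    try:
        conn.execute("INSERT INTO voice_provider (name, is_default_stt) VALUES ('c', 1)")
    except sqlite3.IntegrityError as e:
        print(f"second default rejected: {e}")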

View File

@@ -28,6 +28,7 @@ from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi import WebSocket
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager
@@ -1599,6 +1600,91 @@ async def current_admin_user(user: User = Depends(current_user)) -> User:
return user
async def _get_user_from_token_data(token_data: dict) -> User | None:
"""Shared logic: token data dict → User object.
Args:
token_data: Decoded token data containing 'sub' (user ID).
Returns:
User object if found and active, None otherwise.
"""
user_id = token_data.get("sub")
if not user_id:
return None
try:
user_uuid = uuid.UUID(user_id)
except ValueError:
return None
async with get_async_session_context_manager() as async_db_session:
user = await async_db_session.get(User, user_uuid)
if user is None or not user.is_active:
return None
return user
async def current_user_from_websocket(
_websocket: WebSocket,
token: str = Query(..., description="WebSocket authentication token"),
) -> User:
"""
WebSocket authentication dependency using query parameter.
Validates the WS token from query param and returns the User.
Raises BasicAuthenticationError if authentication fails.
The token must be obtained from POST /voice/ws-token before connecting.
Tokens are single-use and expire after 60 seconds.
Usage:
1. POST /voice/ws-token -> {"token": "xxx"}
2. Connect to ws://host/path?token=xxx
This applies the same auth checks as current_user() for HTTP endpoints.
"""
from onyx.redis.redis_pool import retrieve_ws_token_data
# Validate WS token in Redis (single-use, deleted after retrieval)
try:
token_data = await retrieve_ws_token_data(token)
if token_data is None:
raise BasicAuthenticationError(
detail="Access denied. Invalid or expired authentication token."
)
except BasicAuthenticationError:
raise
except Exception as e:
logger.error(f"WS auth: error during token validation: {e}")
raise BasicAuthenticationError(
detail=f"Authentication verification failed: {str(e)}"
) from e
# Get user from token data
user = await _get_user_from_token_data(token_data)
if user is None:
logger.warning(f"WS auth: user not found for id={token_data.get('sub')}")
raise BasicAuthenticationError(
detail="Access denied. User not found or inactive."
)
logger.info(f"WS auth: user found: {user.email}")
# Apply same checks as HTTP auth (verification, OIDC expiry, role)
user = await double_check_user(user)
logger.info(f"WS auth: user verified: {user.email}, role={user.role}")
# Block LIMITED users (same as current_user)
if user.role == UserRole.LIMITED:
logger.warning(f"WS auth: user {user.email} has LIMITED role")
raise BasicAuthenticationError(
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
)
logger.info(f"WS auth: authentication successful for {user.email}")
return user
def get_default_admin_user_emails_() -> list[str]:
# No default seeding available for Onyx MIT
return []
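
From the client side, the docstring above implies a two-step flow: mint a token over authenticated HTTP, then pass it as a query parameter on the WebSocket URL. A hedged sketch of that flow, assuming the `httpx` and `websockets` packages, a server at localhost:8080 with no API prefix, and an existing session cookie (the cookie name below is an assumption):

    import asyncio
    import httpx
    import websockets

    async def open_transcribe_socket(session_cookie: str) -> None:
        async with httpx.AsyncClient(
            base_url="http://localhost:8080",
            cookies={"fastapiusersauth": session_cookie},  # assumed cookie name
        ) as client:
            # Step 1: mint a single-use token (valid for 60 seconds)
            resp = await client.post("/voice/ws-token")
            resp.raise_for_status()
            token = resp.json()["token"]
        # Step 2: pass the token as a query parameter when connecting
        uri = f"ws://localhost:8080/voice/transcribe/stream?token={token}"
        async with websockets.connect(uri) as ws:
            await ws.send(b"\x00" * 1024)  # a binary audio chunk would go here

    asyncio.run(open_transcribe_socket("..."))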

View File

@@ -284,6 +284,12 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# organized in typical structured fashion
# formatted as `displayName__provider__modelName`
# Voice preferences
voice_auto_send: Mapped[bool] = mapped_column(Boolean, default=False)
voice_auto_playback: Mapped[bool] = mapped_column(Boolean, default=False)
voice_playback_speed: Mapped[float] = mapped_column(Float, default=1.0)
preferred_voice: Mapped[str | None] = mapped_column(String, nullable=True)
# relationships
credentials: Mapped[list["Credential"]] = relationship(
"Credential", back_populates="user"
@@ -2964,6 +2970,63 @@ class ImageGenerationConfig(Base):
)
class VoiceProvider(Base):
"""Configuration for voice services (STT and TTS)."""
__tablename__ = "voice_provider"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
provider_type: Mapped[str] = mapped_column(
String
) # "openai", "azure", "elevenlabs"
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=True
)
api_base: Mapped[str | None] = mapped_column(String, nullable=True)
custom_config: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# Model/voice configuration
stt_model: Mapped[str | None] = mapped_column(
String, nullable=True
) # e.g., "whisper-1"
tts_model: Mapped[str | None] = mapped_column(
String, nullable=True
) # e.g., "tts-1", "tts-1-hd"
default_voice: Mapped[str | None] = mapped_column(
String, nullable=True
) # e.g., "alloy", "echo"
# STT and TTS can use different providers - only one provider per type
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
# Enforce only one default STT provider and one default TTS provider at DB level
__table_args__ = (
Index(
"ix_voice_provider_one_default_stt",
"is_default_stt",
unique=True,
postgresql_where=(is_default_stt == True), # noqa: E712
),
Index(
"ix_voice_provider_one_default_tts",
"is_default_tts",
unique=True,
postgresql_where=(is_default_tts == True), # noqa: E712
),
)
class CloudEmbeddingProvider(Base):
__tablename__ = "embedding_provider"

backend/onyx/db/voice.py (new file, 257 lines)
View File

@@ -0,0 +1,257 @@
from typing import Any
from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.db.models import User
from onyx.db.models import VoiceProvider
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
"""Fetch all voice providers."""
return list(
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
)
def fetch_voice_provider_by_id(
db_session: Session, provider_id: int
) -> VoiceProvider | None:
"""Fetch a voice provider by ID."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.id == provider_id)
)
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default STT provider."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
)
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default TTS provider."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
)
def fetch_voice_provider_by_type(
db_session: Session, provider_type: str
) -> VoiceProvider | None:
"""Fetch a voice provider by type."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
)
def upsert_voice_provider(
*,
db_session: Session,
provider_id: int | None,
name: str,
provider_type: str,
api_key: str | None,
api_key_changed: bool,
api_base: str | None = None,
custom_config: dict[str, Any] | None = None,
stt_model: str | None = None,
tts_model: str | None = None,
default_voice: str | None = None,
activate_stt: bool = False,
activate_tts: bool = False,
) -> VoiceProvider:
"""Create or update a voice provider."""
provider: VoiceProvider | None = None
if provider_id is not None:
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider is None:
raise ValueError(f"No voice provider with id {provider_id} exists.")
else:
provider = VoiceProvider()
db_session.add(provider)
# Apply updates
provider.name = name
provider.provider_type = provider_type
provider.api_base = api_base
provider.custom_config = custom_config
provider.stt_model = stt_model
provider.tts_model = tts_model
provider.default_voice = default_voice
# Only update API key if explicitly changed or if provider has no key
if api_key_changed or provider.api_key is None:
provider.api_key = api_key # type: ignore[assignment]
db_session.flush()
if activate_stt:
set_default_stt_provider(db_session=db_session, provider_id=provider.id)
if activate_tts:
set_default_tts_provider(db_session=db_session, provider_id=provider.id)
db_session.refresh(provider)
return provider
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
"""Delete a voice provider by ID."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider:
db_session.delete(provider)
db_session.flush()
def set_default_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
"""Set a voice provider as the default STT provider."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider is None:
raise ValueError(f"No voice provider with id {provider_id} exists.")
# Deactivate all other STT providers
db_session.execute(
update(VoiceProvider)
.where(
VoiceProvider.is_default_stt.is_(True),
VoiceProvider.id != provider_id,
)
.values(is_default_stt=False)
)
# Activate this provider
provider.is_default_stt = True
db_session.flush()
db_session.refresh(provider)
return provider
def set_default_tts_provider(
*, db_session: Session, provider_id: int, tts_model: str | None = None
) -> VoiceProvider:
"""Set a voice provider as the default TTS provider."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider is None:
raise ValueError(f"No voice provider with id {provider_id} exists.")
# Deactivate all other TTS providers
db_session.execute(
update(VoiceProvider)
.where(
VoiceProvider.is_default_tts.is_(True),
VoiceProvider.id != provider_id,
)
.values(is_default_tts=False)
)
# Activate this provider
provider.is_default_tts = True
# Update the TTS model if specified
if tts_model is not None:
provider.tts_model = tts_model
db_session.flush()
db_session.refresh(provider)
return provider
def deactivate_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
"""Remove the default STT status from a voice provider."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider is None:
raise ValueError(f"No voice provider with id {provider_id} exists.")
provider.is_default_stt = False
db_session.flush()
db_session.refresh(provider)
return provider
def deactivate_tts_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
"""Remove the default TTS status from a voice provider."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider is None:
raise ValueError(f"No voice provider with id {provider_id} exists.")
provider.is_default_tts = False
db_session.flush()
db_session.refresh(provider)
return provider
# User voice preferences
def update_user_voice_auto_send(
db_session: Session, user_id: UUID, auto_send: bool
) -> None:
"""Update user's voice auto-send setting."""
db_session.execute(
update(User).where(User.id == user_id).values(voice_auto_send=auto_send) # type: ignore[arg-type]
)
db_session.commit()
def update_user_voice_auto_playback(
db_session: Session, user_id: UUID, auto_playback: bool
) -> None:
"""Update user's voice auto-playback setting."""
db_session.execute(
update(User).where(User.id == user_id).values(voice_auto_playback=auto_playback) # type: ignore[arg-type]
)
db_session.commit()
def update_user_voice_playback_speed(
db_session: Session, user_id: UUID, speed: float
) -> None:
"""Update user's voice playback speed setting."""
# Clamp to valid range
speed = max(0.5, min(2.0, speed))
db_session.execute(
update(User).where(User.id == user_id).values(voice_playback_speed=speed) # type: ignore[arg-type]
)
db_session.commit()
def update_user_preferred_voice(
db_session: Session, user_id: UUID, voice: str | None
) -> None:
"""Update user's preferred voice setting."""
db_session.execute(
update(User).where(User.id == user_id).values(preferred_voice=voice) # type: ignore[arg-type]
)
db_session.commit()
def update_user_voice_settings(
db_session: Session,
user_id: UUID,
auto_send: bool | None = None,
auto_playback: bool | None = None,
playback_speed: float | None = None,
preferred_voice: str | None = None,
) -> None:
"""Update user's voice settings. Only updates fields that are not None."""
values: dict[str, Any] = {}
if auto_send is not None:
values["voice_auto_send"] = auto_send
if auto_playback is not None:
values["voice_auto_playback"] = auto_playback
if playback_speed is not None:
values["voice_playback_speed"] = max(0.5, min(2.0, playback_speed))
if preferred_voice is not None:
values["preferred_voice"] = preferred_voice
if values:
db_session.execute(update(User).where(User.id == user_id).values(**values)) # type: ignore[arg-type]
db_session.commit()
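
A short usage sketch of the helpers above (illustrative wiring; the `engine` binding is assumed). Note that the provider helpers only flush(), so the caller owns the commit:

    from sqlalchemy.orm import Session

    from onyx.db.voice import set_default_tts_provider, upsert_voice_provider

    with Session(engine) as db_session:
        provider = upsert_voice_provider(
            db_session=db_session,
            provider_id=None,       # None creates a new provider
            name="openai-voice",
            provider_type="openai",
            api_key="sk-...",       # placeholder key
            api_key_changed=True,   # must be true for the key to be written
            stt_model="whisper-1",
            activate_stt=True,      # also make it the default STT provider
        )
        set_default_tts_provider(
            db_session=db_session, provider_id=provider.id, tts_model="tts-1-hd"
        )
        db_session.commit()         # helpers only flush; commit is the caller's job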

View File

@@ -119,6 +119,9 @@ from onyx.server.manage.opensearch_migration.api import (
from onyx.server.manage.search_settings import router as search_settings_router
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router
from onyx.server.manage.voice.api import admin_router as voice_admin_router
from onyx.server.manage.voice.user_api import router as voice_router
from onyx.server.manage.voice.websocket_api import router as voice_websocket_router
from onyx.server.manage.web_search.api import (
admin_router as web_search_admin_router,
)
@@ -497,6 +500,9 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, embedding_router)
include_router_with_global_prefix_prepended(application, web_search_router)
include_router_with_global_prefix_prepended(application, web_search_admin_router)
include_router_with_global_prefix_prepended(application, voice_admin_router)
include_router_with_global_prefix_prepended(application, voice_router)
include_router_with_global_prefix_prepended(application, voice_websocket_router)
include_router_with_global_prefix_prepended(
application, opensearch_migration_admin_router
)

View File

@@ -419,12 +419,15 @@ async def get_async_redis_connection() -> aioredis.Redis:
return _async_redis_connection
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
if not token:
logger.debug("No auth token cookie found")
return None
async def retrieve_auth_token_data(token: str) -> dict | None:
"""Validate auth token against Redis and return token data.
Args:
token: The raw authentication token string.
Returns:
Token data dict if valid, None if invalid/expired.
"""
try:
redis = await get_async_redis_connection()
redis_key = REDIS_AUTH_KEY_PREFIX + token
@@ -439,12 +442,65 @@ async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
logger.error("Error decoding token data from Redis")
return None
except Exception as e:
logger.error(
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
)
raise ValueError(
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
)
logger.error(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
raise ValueError(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
"""Validate auth token from request cookie. Wrapper for backwards compatibility."""
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
if not token:
logger.debug("No auth token cookie found")
return None
return await retrieve_auth_token_data(token)
# WebSocket token prefix (separate from regular auth tokens)
REDIS_WS_TOKEN_PREFIX = "ws_token:"
# WebSocket tokens expire after 60 seconds
WS_TOKEN_TTL_SECONDS = 60
async def store_ws_token(token: str, user_id: str) -> None:
"""Store a short-lived WebSocket authentication token in Redis.
Args:
token: The generated WS token.
user_id: The user ID to associate with this token.
"""
redis = await get_async_redis_connection()
redis_key = REDIS_WS_TOKEN_PREFIX + token
token_data = json.dumps({"sub": user_id})
await redis.set(redis_key, token_data, ex=WS_TOKEN_TTL_SECONDS)
async def retrieve_ws_token_data(token: str) -> dict | None:
"""Validate a WebSocket token and return the token data.
Args:
token: The WS token to validate.
Returns:
Token data dict with 'sub' (user ID) if valid, None if invalid/expired.
"""
try:
redis = await get_async_redis_connection()
redis_key = REDIS_WS_TOKEN_PREFIX + token
token_data_str = await redis.get(redis_key)
if not token_data_str:
return None
# Delete the token after retrieval (single-use)
await redis.delete(redis_key)
return json.loads(token_data_str)
except json.JSONDecodeError:
logger.error("Error decoding WS token data from Redis")
return None
except Exception as e:
logger.error(f"Unexpected error in retrieve_ws_token_data: {str(e)}")
return None
def redis_lock_dump(lock: RedisLock, r: Redis) -> None:
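
The single-use property of WS tokens comes from deleting the Redis key on first read. A small driver to illustrate (assumes a reachable Redis and the functions above; the user ID is a placeholder):

    import asyncio
    import secrets

    from onyx.redis.redis_pool import retrieve_ws_token_data, store_ws_token

    async def main() -> None:
        token = secrets.token_urlsafe(32)
        await store_ws_token(token, user_id="00000000-0000-0000-0000-000000000000")
        print(await retrieve_ws_token_data(token))  # {"sub": "..."} on the first read
        print(await retrieve_ws_token_data(token))  # None: deleted on first retrieval

    asyncio.run(main())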

View File

@@ -9,6 +9,7 @@ from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
from onyx.auth.users import current_user_from_websocket
from onyx.auth.users import current_user_with_expired_token
from onyx.configs.app_configs import APP_API_PREFIX
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -129,6 +130,7 @@ def check_router_auth(
or depends_fn == current_curator_or_admin_user
or depends_fn == current_user_with_expired_token
or depends_fn == current_chat_accessible_user
or depends_fn == current_user_from_websocket
or depends_fn == control_plane_dep
or depends_fn == current_cloud_superuser
or depends_fn == verify_scim_token

View File

@@ -85,6 +85,12 @@ class UserPreferences(BaseModel):
chat_background: str | None = None
default_app_mode: DefaultAppMode = DefaultAppMode.CHAT
# Voice preferences
voice_auto_send: bool | None = None
voice_auto_playback: bool | None = None
voice_playback_speed: float | None = None
preferred_voice: str | None = None
# controls which tools are enabled for the user for a specific assistant
assistant_specific_configs: UserSpecificAssistantPreferences | None = None
@@ -164,6 +170,10 @@ class UserInfo(BaseModel):
theme_preference=user.theme_preference,
chat_background=user.chat_background,
default_app_mode=user.default_app_mode,
voice_auto_send=user.voice_auto_send,
voice_auto_playback=user.voice_auto_playback,
voice_playback_speed=user.voice_playback_speed,
preferred_voice=user.preferred_voice,
assistant_specific_configs=assistant_specific_configs,
)
),
@@ -240,6 +250,13 @@ class ChatBackgroundRequest(BaseModel):
chat_background: str | None
class VoiceSettingsUpdateRequest(BaseModel):
auto_send: bool | None = None
auto_playback: bool | None = None
playback_speed: float | None = Field(default=None, ge=0.5, le=2.0)
preferred_voice: str | None = None
class PersonalizationUpdateRequest(BaseModel):
name: str | None = None
role: str | None = None

View File

@@ -0,0 +1,282 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import User
from onyx.db.models import VoiceProvider
from onyx.db.voice import deactivate_stt_provider
from onyx.db.voice import deactivate_tts_provider
from onyx.db.voice import delete_voice_provider
from onyx.db.voice import fetch_voice_provider_by_id
from onyx.db.voice import fetch_voice_provider_by_type
from onyx.db.voice import fetch_voice_providers
from onyx.db.voice import set_default_stt_provider
from onyx.db.voice import set_default_tts_provider
from onyx.db.voice import upsert_voice_provider
from onyx.server.manage.voice.models import VoiceProviderTestRequest
from onyx.server.manage.voice.models import VoiceProviderUpsertRequest
from onyx.server.manage.voice.models import VoiceProviderView
from onyx.utils.logger import setup_logger
from onyx.voice.factory import get_voice_provider
logger = setup_logger()
admin_router = APIRouter(prefix="/admin/voice")
def _provider_to_view(provider: VoiceProvider) -> VoiceProviderView:
"""Convert a VoiceProvider model to a VoiceProviderView."""
return VoiceProviderView(
id=provider.id,
name=provider.name,
provider_type=provider.provider_type,
is_default_stt=provider.is_default_stt,
is_default_tts=provider.is_default_tts,
stt_model=provider.stt_model,
tts_model=provider.tts_model,
default_voice=provider.default_voice,
has_api_key=bool(provider.api_key),
target_uri=provider.api_base, # api_base stores the target URI for Azure
)
@admin_router.get("/providers")
def list_voice_providers(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[VoiceProviderView]:
"""List all configured voice providers."""
providers = fetch_voice_providers(db_session)
return [_provider_to_view(provider) for provider in providers]
@admin_router.post("/providers")
def upsert_voice_provider_endpoint(
request: VoiceProviderUpsertRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> VoiceProviderView:
"""Create or update a voice provider."""
api_key = request.api_key
api_key_changed = request.api_key_changed
# If llm_provider_id is specified, copy the API key from that LLM provider
if request.llm_provider_id is not None:
llm_provider = db_session.get(LLMProviderModel, request.llm_provider_id)
if llm_provider is None:
raise HTTPException(
status_code=404,
detail=f"LLM provider with id {request.llm_provider_id} not found.",
)
if llm_provider.api_key is None:
raise HTTPException(
status_code=400,
detail="Selected LLM provider has no API key configured.",
)
api_key = llm_provider.api_key.get_value(apply_mask=False)
api_key_changed = True
# Use target_uri if provided, otherwise fall back to api_base
api_base = request.target_uri or request.api_base
provider = upsert_voice_provider(
db_session=db_session,
provider_id=request.id,
name=request.name,
provider_type=request.provider_type,
api_key=api_key,
api_key_changed=api_key_changed,
api_base=api_base,
custom_config=request.custom_config,
stt_model=request.stt_model,
tts_model=request.tts_model,
default_voice=request.default_voice,
activate_stt=request.activate_stt,
activate_tts=request.activate_tts,
)
db_session.commit()
return _provider_to_view(provider)
@admin_router.delete(
"/providers/{provider_id}", status_code=204, response_class=Response
)
def delete_voice_provider_endpoint(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> Response:
"""Delete a voice provider."""
delete_voice_provider(db_session, provider_id)
db_session.commit()
return Response(status_code=204)
@admin_router.post("/providers/{provider_id}/activate-stt")
def activate_stt_provider_endpoint(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> VoiceProviderView:
"""Set a voice provider as the default STT provider."""
provider = set_default_stt_provider(db_session=db_session, provider_id=provider_id)
db_session.commit()
return _provider_to_view(provider)
@admin_router.post("/providers/{provider_id}/deactivate-stt")
def deactivate_stt_provider_endpoint(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Remove the default STT status from a voice provider."""
deactivate_stt_provider(db_session=db_session, provider_id=provider_id)
db_session.commit()
return {"status": "ok"}
@admin_router.post("/providers/{provider_id}/activate-tts")
def activate_tts_provider_endpoint(
provider_id: int,
tts_model: str | None = None,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> VoiceProviderView:
"""Set a voice provider as the default TTS provider."""
provider = set_default_tts_provider(
db_session=db_session, provider_id=provider_id, tts_model=tts_model
)
db_session.commit()
return _provider_to_view(provider)
@admin_router.post("/providers/{provider_id}/deactivate-tts")
def deactivate_tts_provider_endpoint(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Remove the default TTS status from a voice provider."""
deactivate_tts_provider(db_session=db_session, provider_id=provider_id)
db_session.commit()
return {"status": "ok"}
@admin_router.post("/providers/test")
def test_voice_provider(
request: VoiceProviderTestRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Test a voice provider connection."""
api_key = request.api_key
if request.use_stored_key:
existing_provider = fetch_voice_provider_by_type(
db_session, request.provider_type
)
if existing_provider is None or not existing_provider.api_key:
raise HTTPException(
status_code=400,
detail="No stored API key found for this provider type.",
)
api_key = existing_provider.api_key.get_value(apply_mask=False)
if not api_key:
raise HTTPException(
status_code=400,
detail="API key is required. Either provide api_key or set use_stored_key to true.",
)
# Use target_uri if provided, otherwise fall back to api_base
api_base = request.target_uri or request.api_base
# Create a temporary VoiceProvider for testing (not saved to DB)
temp_provider = VoiceProvider(
name="__test__",
provider_type=request.provider_type,
api_base=api_base,
custom_config=request.custom_config or {},
)
temp_provider.api_key = api_key # type: ignore[assignment]
try:
provider = get_voice_provider(temp_provider)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
# Test the provider by getting available voices (lightweight check)
try:
voices = provider.get_available_voices()
if not voices:
raise HTTPException(
status_code=400,
detail="Provider returned no available voices.",
)
except NotImplementedError:
# Provider not fully implemented yet (Azure, ElevenLabs placeholders)
pass
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Connection test failed: {str(e)}",
) from e
logger.info(f"Voice provider test succeeded for {request.provider_type}.")
return {"status": "ok"}
@admin_router.get("/providers/{provider_id}/voices")
def get_provider_voices(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[dict[str, str]]:
"""Get available voices for a provider."""
provider_db = fetch_voice_provider_by_id(db_session, provider_id)
if provider_db is None:
raise HTTPException(status_code=404, detail="Voice provider not found.")
if not provider_db.api_key:
raise HTTPException(
status_code=400, detail="Provider has no API key configured."
)
try:
provider = get_voice_provider(provider_db)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return provider.get_available_voices()
@admin_router.get("/voices")
def get_voices_by_type(
provider_type: str,
_: User = Depends(current_admin_user),
) -> list[dict[str, str]]:
"""Get available voices for a provider type.
For providers like ElevenLabs and OpenAI, this fetches voices
without requiring an existing provider configuration.
"""
# Create a temporary VoiceProvider to get static voice list
temp_provider = VoiceProvider(
name="__temp__",
provider_type=provider_type,
)
try:
provider = get_voice_provider(temp_provider)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return provider.get_available_voices()
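
End to end, an admin would typically test credentials, create the provider, and then query its voices. A hypothetical walkthrough against the endpoints above, assuming `httpx`, an admin session cookie (name assumed), no API prefix, and a server at localhost:8080:

    import httpx

    with httpx.Client(
        base_url="http://localhost:8080",
        cookies={"fastapiusersauth": "..."},  # assumed cookie name
    ) as client:
        # Dry-run the credentials before persisting anything
        client.post(
            "/admin/voice/providers/test",
            json={"provider_type": "openai", "api_key": "sk-..."},
        ).raise_for_status()
        # Create the provider and activate it for STT and TTS in one request
        view = client.post(
            "/admin/voice/providers",
            json={
                "name": "openai-voice",
                "provider_type": "openai",
                "api_key": "sk-...",
                "api_key_changed": True,
                "stt_model": "whisper-1",
                "tts_model": "tts-1",
                "activate_stt": True,
                "activate_tts": True,
            },
        ).json()
        # List the voices the stored provider exposes
        voices = client.get(f"/admin/voice/providers/{view['id']}/voices").json()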

View File

@@ -0,0 +1,90 @@
from typing import Any
from pydantic import BaseModel
from pydantic import Field
class VoiceProviderView(BaseModel):
"""Response model for voice provider listing."""
id: int
name: str
provider_type: str # "openai", "azure", "elevenlabs"
is_default_stt: bool
is_default_tts: bool
stt_model: str | None
tts_model: str | None
default_voice: str | None
has_api_key: bool = Field(
default=False,
description="Indicates whether an API key is stored for this provider.",
)
target_uri: str | None = Field(
default=None,
description="Target URI for Azure Speech Services.",
)
class VoiceProviderUpsertRequest(BaseModel):
"""Request model for creating or updating a voice provider."""
id: int | None = Field(default=None, description="Existing provider ID to update.")
name: str
provider_type: str # "openai", "azure", "elevenlabs"
api_key: str | None = Field(
default=None,
description="API key for the provider.",
)
api_key_changed: bool = Field(
default=False,
description="Set to true when providing a new API key for an existing provider.",
)
llm_provider_id: int | None = Field(
default=None,
description="If set, copies the API key from the specified LLM provider.",
)
api_base: str | None = None
target_uri: str | None = Field(
default=None,
description="Target URI for Azure Speech Services (maps to api_base).",
)
custom_config: dict[str, Any] | None = None
stt_model: str | None = None
tts_model: str | None = None
default_voice: str | None = None
activate_stt: bool = Field(
default=False,
description="If true, sets this provider as the default STT provider after upsert.",
)
activate_tts: bool = Field(
default=False,
description="If true, sets this provider as the default TTS provider after upsert.",
)
class VoiceProviderTestRequest(BaseModel):
"""Request model for testing a voice provider connection."""
provider_type: str
api_key: str | None = Field(
default=None,
description="API key for testing. If not provided, use_stored_key must be true.",
)
use_stored_key: bool = Field(
default=False,
description="If true, use the stored API key for this provider type.",
)
api_base: str | None = None
target_uri: str | None = Field(
default=None,
description="Target URI for Azure Speech Services (maps to api_base).",
)
custom_config: dict[str, Any] | None = None
class SynthesizeRequest(BaseModel):
"""Request model for text-to-speech synthesis."""
text: str = Field(..., min_length=1, max_length=4096)
voice: str | None = None
speed: float = Field(default=1.0, ge=0.5, le=2.0)
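
The Field constraints above are enforced at parse time, so out-of-range values are rejected before any endpoint logic runs. A quick illustration (assumes Pydantic v2 error types):

    from pydantic import ValidationError

    SynthesizeRequest(text="Hello there", speed=1.25)   # accepted
    try:
        SynthesizeRequest(text="Hello", speed=3.0)      # speed must be <= 2.0
    except ValidationError as exc:
        print(exc.errors()[0]["type"])                  # "less_than_equal"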

View File

@@ -0,0 +1,207 @@
import secrets
from collections.abc import AsyncIterator
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import HTTPException
from fastapi import Query
from fastapi import UploadFile
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.db.models import User
from onyx.db.voice import fetch_default_stt_provider
from onyx.db.voice import fetch_default_tts_provider
from onyx.db.voice import update_user_voice_settings
from onyx.redis.redis_pool import store_ws_token
from onyx.server.manage.models import VoiceSettingsUpdateRequest
from onyx.utils.logger import setup_logger
from onyx.voice.factory import get_voice_provider
logger = setup_logger()
router = APIRouter(prefix="/voice")
# Max audio file size: 25MB (Whisper limit)
MAX_AUDIO_SIZE = 25 * 1024 * 1024
@router.post("/transcribe")
async def transcribe_audio(
audio: UploadFile = File(...),
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Transcribe audio to text using the default STT provider."""
provider_db = fetch_default_stt_provider(db_session)
if provider_db is None:
raise HTTPException(
status_code=400,
detail="No speech-to-text provider configured. Please contact your administrator.",
)
if not provider_db.api_key:
raise HTTPException(
status_code=400,
detail="Voice provider API key not configured.",
)
audio_data = await audio.read()
if len(audio_data) > MAX_AUDIO_SIZE:
raise HTTPException(
status_code=400,
detail=f"Audio file too large. Maximum size is {MAX_AUDIO_SIZE // (1024 * 1024)}MB.",
)
# Extract format from filename
filename = audio.filename or "audio.webm"
audio_format = filename.rsplit(".", 1)[-1] if "." in filename else "webm"
try:
provider = get_voice_provider(provider_db)
except ValueError as exc:
raise HTTPException(status_code=500, detail=str(exc)) from exc
try:
text = await provider.transcribe(audio_data, audio_format)
return {"text": text}
except NotImplementedError as exc:
raise HTTPException(
status_code=501,
detail=f"Speech-to-text not implemented for {provider_db.provider_type}.",
) from exc
except Exception as exc:
logger.error(f"Transcription failed: {exc}")
raise HTTPException(
status_code=500,
detail=f"Transcription failed: {str(exc)}",
) from exc
@router.post("/synthesize")
async def synthesize_speech(
text: str | None = Query(
default=None, description="Text to synthesize", max_length=4096
),
voice: str | None = Query(default=None, description="Voice ID to use"),
speed: float | None = Query(
default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
),
user: User = Depends(current_user),
) -> StreamingResponse:
"""
Synthesize text to speech using the default TTS provider.
Accepts parameters via query string for streaming compatibility.
"""
logger.info(
f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
)
if not text:
raise HTTPException(status_code=400, detail="Text is required")
# Use short-lived session to fetch provider config, then release connection
# before starting the long-running streaming response
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
provider_db = fetch_default_tts_provider(db_session)
if provider_db is None:
logger.error("No TTS provider configured")
raise HTTPException(
status_code=400,
detail="No text-to-speech provider configured. Please contact your administrator.",
)
if not provider_db.api_key:
logger.error("TTS provider has no API key")
raise HTTPException(
status_code=400,
detail="Voice provider API key not configured.",
)
# Use request voice, or user's preferred voice, or provider default
final_voice = voice or user.preferred_voice or provider_db.default_voice
final_speed = speed or user.voice_playback_speed or 1.0
logger.info(
f"TTS using provider: {provider_db.provider_type}, voice: {final_voice}, speed: {final_speed}"
)
try:
provider = get_voice_provider(provider_db)
except ValueError as exc:
logger.error(f"Failed to get voice provider: {exc}")
raise HTTPException(status_code=500, detail=str(exc)) from exc
# Session is now closed - streaming response won't hold DB connection
async def audio_stream() -> AsyncIterator[bytes]:
try:
chunk_count = 0
async for chunk in provider.synthesize_stream(
text=text, voice=final_voice, speed=final_speed
):
chunk_count += 1
yield chunk
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
except NotImplementedError as exc:
logger.error(f"TTS not implemented: {exc}")
raise
except Exception as exc:
logger.error(f"Synthesis failed: {exc}")
raise
return StreamingResponse(
audio_stream(),
media_type="audio/mpeg",
headers={
"Content-Disposition": "inline; filename=speech.mp3",
# Allow streaming by not setting content-length
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no", # Disable nginx buffering
},
)
@router.patch("/settings")
def update_voice_settings(
request: VoiceSettingsUpdateRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Update user's voice settings."""
update_user_voice_settings(
db_session=db_session,
user_id=user.id,
auto_send=request.auto_send,
auto_playback=request.auto_playback,
playback_speed=request.playback_speed,
preferred_voice=request.preferred_voice,
)
return {"status": "ok"}
class WSTokenResponse(BaseModel):
token: str
@router.post("/ws-token")
async def get_ws_token(
user: User = Depends(current_user),
) -> WSTokenResponse:
"""
Generate a short-lived token for WebSocket authentication.
This token should be passed as a query parameter when connecting
to voice WebSocket endpoints (e.g., /voice/transcribe/stream?token=xxx).
The token expires after 60 seconds and is single-use.
"""
token = secrets.token_urlsafe(32)
await store_ws_token(token, str(user.id))
return WSTokenResponse(token=token)
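
Because /voice/synthesize takes its parameters as query strings and streams the response body, a client can write audio to disk as chunks arrive rather than buffering the whole file. A hedged sketch, assuming `httpx`, no API prefix, and an authenticated session cookie (name assumed):

    import httpx

    with httpx.Client(
        base_url="http://localhost:8080",
        cookies={"fastapiusersauth": "..."},  # assumed cookie name
    ) as client:
        params = {"text": "Hello from voice mode", "voice": "alloy", "speed": 1.0}
        with client.stream("POST", "/voice/synthesize", params=params) as resp:
            resp.raise_for_status()
            with open("speech.mp3", "wb") as out:
                for chunk in resp.iter_bytes():
                    out.write(chunk)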

View File

@@ -0,0 +1,778 @@
"""WebSocket API for streaming speech-to-text and text-to-speech."""
import asyncio
import io
import json
import os
from typing import Any
from fastapi import APIRouter
from fastapi import Depends
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from sqlalchemy.orm import Session
from onyx.auth.users import current_user_from_websocket
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.db.models import User
from onyx.db.voice import fetch_default_stt_provider
from onyx.db.voice import fetch_default_tts_provider
from onyx.utils.logger import setup_logger
from onyx.voice.factory import get_voice_provider
from onyx.voice.interface import StreamingSynthesizerProtocol
from onyx.voice.interface import StreamingTranscriberProtocol
from onyx.voice.interface import TranscriptResult
logger = setup_logger()
router = APIRouter(prefix="/voice")
# Transcribe every ~0.5 seconds of audio (webm/opus is ~2-4KB/s, so ~1-2KB per 0.5s)
MIN_CHUNK_BYTES = 1500
VOICE_DISABLE_STREAMING_FALLBACK = (
os.environ.get("VOICE_DISABLE_STREAMING_FALLBACK", "").lower() == "true"
)
class ChunkedTranscriber:
"""Fallback transcriber for providers without streaming support."""
def __init__(self, provider: Any, audio_format: str = "webm"):
self.provider = provider
self.audio_format = audio_format
self.chunk_buffer = io.BytesIO()
self.full_audio = io.BytesIO()
self.chunk_bytes = 0
self.transcripts: list[str] = []
async def add_chunk(self, chunk: bytes) -> str | None:
"""Add audio chunk. Returns transcript if enough audio accumulated."""
self.chunk_buffer.write(chunk)
self.full_audio.write(chunk)
self.chunk_bytes += len(chunk)
if self.chunk_bytes >= MIN_CHUNK_BYTES:
return await self._transcribe_chunk()
return None
async def _transcribe_chunk(self) -> str | None:
"""Transcribe current chunk and append to running transcript."""
audio_data = self.chunk_buffer.getvalue()
if not audio_data:
return None
try:
transcript = await self.provider.transcribe(audio_data, self.audio_format)
self.chunk_buffer = io.BytesIO()
self.chunk_bytes = 0
if transcript and transcript.strip():
self.transcripts.append(transcript.strip())
return " ".join(self.transcripts)
return None
except Exception as e:
logger.error(f"Transcription error: {e}")
self.chunk_buffer = io.BytesIO()
self.chunk_bytes = 0
return None
async def flush(self) -> str:
"""Get final transcript from full audio for best accuracy."""
full_audio_data = self.full_audio.getvalue()
if full_audio_data:
try:
transcript = await self.provider.transcribe(
full_audio_data, self.audio_format
)
if transcript and transcript.strip():
return transcript.strip()
except Exception as e:
logger.error(f"Final transcription error: {e}")
return " ".join(self.transcripts)
async def handle_streaming_transcription(
websocket: WebSocket,
transcriber: StreamingTranscriberProtocol,
) -> None:
"""Handle transcription using native streaming API."""
logger.info("Streaming transcription: starting handler")
last_transcript = ""
chunk_count = 0
total_bytes = 0
async def receive_transcripts() -> None:
"""Background task to receive and send transcripts."""
nonlocal last_transcript
logger.info("Streaming transcription: starting transcript receiver")
while True:
result: TranscriptResult | None = await transcriber.receive_transcript()
if result is None: # End of stream
logger.info("Streaming transcription: transcript stream ended")
break
# Send if text changed OR if VAD detected end of speech (for auto-send trigger)
if result.text and (result.text != last_transcript or result.is_vad_end):
last_transcript = result.text
logger.info(
f"Streaming transcription: got transcript: {result.text[:50]}... "
f"(is_vad_end={result.is_vad_end})"
)
await websocket.send_json(
{
"type": "transcript",
"text": result.text,
"is_final": result.is_vad_end,
}
)
# Start receiving transcripts in background
receive_task = asyncio.create_task(receive_transcripts())
try:
while True:
message = await websocket.receive()
msg_type = message.get("type", "unknown")
if msg_type == "websocket.disconnect":
logger.info(
f"Streaming transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
)
break
if "bytes" in message:
chunk_size = len(message["bytes"])
chunk_count += 1
total_bytes += chunk_size
logger.debug(
f"Streaming transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
)
await transcriber.send_audio(message["bytes"])
elif "text" in message:
try:
data = json.loads(message["text"])
logger.info(
f"Streaming transcription: received text message: {data}"
)
if data.get("type") == "end":
logger.info(
"Streaming transcription: end signal received, closing transcriber"
)
final_transcript = await transcriber.close()
receive_task.cancel()
logger.info(
"Streaming transcription: final transcript: "
f"{final_transcript[:100] if final_transcript else '(empty)'}..."
)
await websocket.send_json(
{
"type": "transcript",
"text": final_transcript,
"is_final": True,
}
)
break
elif data.get("type") == "reset":
# Reset accumulated transcript after auto-send
logger.info(
"Streaming transcription: reset signal received, clearing transcript"
)
transcriber.reset_transcript()
except json.JSONDecodeError:
logger.warning(
f"Streaming transcription: failed to parse JSON: {message.get('text', '')[:100]}"
)
except Exception as e:
logger.error(f"Streaming transcription: error: {e}", exc_info=True)
raise
finally:
receive_task.cancel()
try:
await receive_task
except asyncio.CancelledError:
pass
logger.info(
f"Streaming transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
)
async def handle_chunked_transcription(
websocket: WebSocket,
transcriber: ChunkedTranscriber,
) -> None:
"""Handle transcription using chunked batch API."""
logger.info("Chunked transcription: starting handler")
chunk_count = 0
total_bytes = 0
while True:
message = await websocket.receive()
msg_type = message.get("type", "unknown")
if msg_type == "websocket.disconnect":
logger.info(
f"Chunked transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
)
break
if "bytes" in message:
chunk_size = len(message["bytes"])
chunk_count += 1
total_bytes += chunk_size
logger.debug(
f"Chunked transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
)
transcript = await transcriber.add_chunk(message["bytes"])
if transcript:
logger.info(
f"Chunked transcription: got transcript: {transcript[:50]}..."
)
await websocket.send_json(
{
"type": "transcript",
"text": transcript,
"is_final": False,
}
)
elif "text" in message:
try:
data = json.loads(message["text"])
logger.info(f"Chunked transcription: received text message: {data}")
if data.get("type") == "end":
logger.info("Chunked transcription: end signal received, flushing")
final_transcript = await transcriber.flush()
logger.info(
f"Chunked transcription: final transcript: {final_transcript[:100] if final_transcript else '(empty)'}..."
)
await websocket.send_json(
{
"type": "transcript",
"text": final_transcript,
"is_final": True,
}
)
break
except json.JSONDecodeError:
logger.warning(
f"Chunked transcription: failed to parse JSON: {message.get('text', '')[:100]}"
)
logger.info(
f"Chunked transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
)
@router.websocket("/transcribe/stream")
async def websocket_transcribe(
websocket: WebSocket,
_user: User = Depends(current_user_from_websocket),
) -> None:
"""
WebSocket endpoint for streaming speech-to-text.
Protocol:
- Client sends binary audio chunks
- Server sends JSON: {"type": "transcript", "text": "...", "is_final": false}
- Client sends JSON {"type": "end"} to signal end
- Server responds with final transcript and closes
Authentication:
Requires `token` query parameter (e.g., /voice/transcribe/stream?token=xxx).
Applies same auth checks as HTTP endpoints (verification, role checks).
"""
logger.info("WebSocket transcribe: connection request received (authenticated)")
try:
await websocket.accept()
logger.info("WebSocket transcribe: connection accepted")
except Exception as e:
logger.error(f"WebSocket transcribe: failed to accept connection: {e}")
return
streaming_transcriber = None
provider = None
try:
# Get STT provider
logger.info("WebSocket transcribe: fetching STT provider from database")
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
provider_db = fetch_default_stt_provider(db_session)
if provider_db is None:
logger.warning(
"WebSocket transcribe: no default STT provider configured"
)
await websocket.send_json(
{
"type": "error",
"message": "No speech-to-text provider configured",
}
)
return
if not provider_db.api_key:
logger.warning("WebSocket transcribe: STT provider has no API key")
await websocket.send_json(
{
"type": "error",
"message": "Speech-to-text provider has no API key configured",
}
)
return
logger.info(
f"WebSocket transcribe: creating voice provider: {provider_db.provider_type}"
)
try:
provider = get_voice_provider(provider_db)
logger.info(
f"WebSocket transcribe: voice provider created, streaming supported: {provider.supports_streaming_stt()}"
)
except ValueError as e:
logger.error(
f"WebSocket transcribe: failed to create voice provider: {e}"
)
await websocket.send_json({"type": "error", "message": str(e)})
return
# Use native streaming if provider supports it
if provider.supports_streaming_stt():
logger.info("WebSocket transcribe: using native streaming STT")
try:
streaming_transcriber = await provider.create_streaming_transcriber()
logger.info(
"WebSocket transcribe: streaming transcriber created successfully"
)
await handle_streaming_transcription(websocket, streaming_transcriber)
except Exception as e:
logger.error(
f"WebSocket transcribe: failed to create streaming transcriber: {e}"
)
if VOICE_DISABLE_STREAMING_FALLBACK:
await websocket.send_json(
{"type": "error", "message": f"Streaming STT failed: {e}"}
)
return
logger.info("WebSocket transcribe: falling back to chunked STT")
# Browser stream provides raw PCM16 chunks over WebSocket.
chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
await handle_chunked_transcription(websocket, chunked_transcriber)
else:
# Fall back to chunked transcription
if VOICE_DISABLE_STREAMING_FALLBACK:
await websocket.send_json(
{
"type": "error",
"message": "Provider doesn't support streaming STT",
}
)
return
logger.info(
"WebSocket transcribe: using chunked STT (provider doesn't support streaming)"
)
chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
await handle_chunked_transcription(websocket, chunked_transcriber)
except WebSocketDisconnect:
logger.debug("WebSocket transcribe: client disconnected")
except Exception as e:
logger.error(f"WebSocket transcribe: unhandled error: {e}", exc_info=True)
try:
await websocket.send_json({"type": "error", "message": str(e)})
except Exception:
pass
finally:
if streaming_transcriber:
try:
await streaming_transcriber.close()
except Exception:
pass
try:
await websocket.close()
except Exception:
pass
logger.info("WebSocket transcribe: connection closed")
async def handle_streaming_synthesis(
websocket: WebSocket,
synthesizer: StreamingSynthesizerProtocol,
) -> None:
"""Handle TTS using native streaming API.
Buffers all text, then sends to ElevenLabs when complete.
This is more reliable than streaming chunks incrementally.
"""
logger.info("Streaming synthesis: starting handler")
async def send_audio() -> None:
"""Background task to send audio chunks to client."""
chunk_count = 0
total_bytes = 0
try:
while True:
audio_chunk = await synthesizer.receive_audio()
if audio_chunk is None:
logger.info(
f"Streaming synthesis: audio stream ended, sent {chunk_count} chunks, {total_bytes} bytes"
)
try:
await websocket.send_json({"type": "audio_done"})
logger.info("Streaming synthesis: sent audio_done to client")
except Exception as e:
logger.warning(
f"Streaming synthesis: failed to send audio_done: {e}"
)
break
if audio_chunk: # Skip empty chunks
chunk_count += 1
total_bytes += len(audio_chunk)
try:
await websocket.send_bytes(audio_chunk)
except Exception as e:
logger.warning(
f"Streaming synthesis: failed to send chunk: {e}"
)
break
except asyncio.CancelledError:
logger.info(
f"Streaming synthesis: send_audio cancelled after {chunk_count} chunks"
)
except Exception as e:
logger.error(f"Streaming synthesis: send_audio error: {e}")
send_task: asyncio.Task | None = None
text_buffer: list[str] = [] # Buffer text chunks until "end"
disconnected = False
try:
while not disconnected:
try:
message = await websocket.receive()
except RuntimeError as e:
if "disconnect" in str(e).lower():
logger.info("Streaming synthesis: client disconnected")
break
raise
msg_type = message.get("type", "unknown") # type: ignore[possibly-undefined]
if msg_type == "websocket.disconnect":
logger.info("Streaming synthesis: client disconnected")
disconnected = True
break
if "text" in message:
try:
data = json.loads(message["text"])
if data.get("type") == "synthesize":
text = data.get("text", "")
if not text:
for key, value in data.items():
if key != "type" and isinstance(value, str) and value:
text = value
break
if text:
# Buffer text instead of sending immediately
text_buffer.append(text)
logger.info(
f"Streaming synthesis: buffered text ({len(text)} chars), "
f"total buffered: {len(text_buffer)} chunks"
)
elif data.get("type") == "end":
logger.info("Streaming synthesis: end signal received")
if text_buffer:
# Combine all buffered text
full_text = " ".join(text_buffer)
logger.info(
f"Streaming synthesis: sending full text ({len(full_text)} chars): "
f"'{full_text[:100]}...'"
)
# Start audio receiver
send_task = asyncio.create_task(send_audio())
# Send all text at once
await synthesizer.send_text(full_text)
logger.info(
"Streaming synthesis: full text sent to synthesizer"
)
# Signal end of input
if hasattr(synthesizer, "flush"):
await synthesizer.flush()
# Wait for all audio to be sent
logger.info(
"Streaming synthesis: waiting for audio stream to complete"
)
try:
await asyncio.wait_for(send_task, timeout=60.0)
except asyncio.TimeoutError:
logger.warning(
"Streaming synthesis: timeout waiting for audio"
)
break
except json.JSONDecodeError:
logger.warning(
f"Streaming synthesis: failed to parse JSON: {message.get('text', '')[:100]}"
)
except Exception as e:
if "disconnect" not in str(e).lower():
logger.error(f"Streaming synthesis: error: {e}", exc_info=True)
finally:
if send_task and not send_task.done():
logger.info("Streaming synthesis: waiting for send_task to finish")
try:
await asyncio.wait_for(send_task, timeout=30.0)
except asyncio.TimeoutError:
logger.warning("Streaming synthesis: timeout waiting for send_task")
send_task.cancel()
try:
await send_task
except asyncio.CancelledError:
pass
except asyncio.CancelledError:
pass
logger.info("Streaming synthesis: handler finished")
async def handle_chunked_synthesis(websocket: WebSocket, provider: Any) -> None:
"""Fallback TTS handler using provider.synthesize_stream."""
logger.info("Chunked synthesis: starting handler")
text_buffer: list[str] = []
voice: str | None = None
speed = 1.0
try:
while True:
message = await websocket.receive()
msg_type = message.get("type", "unknown")
if msg_type == "websocket.disconnect":
logger.info("Chunked synthesis: client disconnected")
break
if "text" not in message:
continue
try:
data = json.loads(message["text"])
except json.JSONDecodeError:
logger.warning(
"Chunked synthesis: failed to parse JSON: "
f"{message.get('text', '')[:100]}"
)
continue
msg_data_type = data.get("type") # type: ignore[possibly-undefined]
if msg_data_type == "synthesize":
text = data.get("text", "")
if not text:
for key, value in data.items():
if key != "type" and isinstance(value, str) and value:
text = value
break
if text:
text_buffer.append(text)
logger.info(
f"Chunked synthesis: buffered text ({len(text)} chars), "
f"total buffered: {len(text_buffer)} chunks"
)
if isinstance(data.get("voice"), str) and data["voice"]:
voice = data["voice"]
if isinstance(data.get("speed"), (int, float)):
speed = float(data["speed"])
elif msg_data_type == "end":
logger.info("Chunked synthesis: end signal received")
full_text = " ".join(text_buffer).strip()
if not full_text:
await websocket.send_json({"type": "audio_done"})
logger.info("Chunked synthesis: no text, sent audio_done")
break
chunk_count = 0
total_bytes = 0
logger.info(
f"Chunked synthesis: sending full text ({len(full_text)} chars)"
)
async for audio_chunk in provider.synthesize_stream(
full_text, voice=voice, speed=speed
):
if not audio_chunk:
continue
chunk_count += 1
total_bytes += len(audio_chunk)
await websocket.send_bytes(audio_chunk)
await websocket.send_json({"type": "audio_done"})
logger.info(
f"Chunked synthesis: sent audio_done after {chunk_count} chunks, {total_bytes} bytes"
)
break
except Exception as e:
if "disconnect" not in str(e).lower():
logger.error(f"Chunked synthesis: error: {e}", exc_info=True)
raise
finally:
logger.info("Chunked synthesis: handler finished")
@router.websocket("/synthesize/stream")
async def websocket_synthesize(
websocket: WebSocket,
_user: User = Depends(current_user_from_websocket),
) -> None:
"""
WebSocket endpoint for streaming text-to-speech.
Protocol:
- Client sends JSON: {"type": "synthesize", "text": "...", "voice": "...", "speed": 1.0}; may repeat, text is buffered server-side
- Client sends JSON: {"type": "end"} once all text has been sent, which triggers synthesis
- Server sends binary audio chunks
- Server sends JSON: {"type": "audio_done"} when synthesis completes
Authentication:
Requires `token` query parameter (e.g., /voice/synthesize/stream?token=xxx).
Applies same auth checks as HTTP endpoints (verification, role checks).
"""
logger.info("WebSocket synthesize: connection request received (authenticated)")
try:
await websocket.accept()
logger.info("WebSocket synthesize: connection accepted")
except Exception as e:
logger.error(f"WebSocket synthesize: failed to accept connection: {e}")
return
streaming_synthesizer: StreamingSynthesizerProtocol | None = None
provider = None
try:
# Get TTS provider
logger.info("WebSocket synthesize: fetching TTS provider from database")
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
provider_db = fetch_default_tts_provider(db_session)
if provider_db is None:
logger.warning(
"WebSocket synthesize: no default TTS provider configured"
)
await websocket.send_json(
{
"type": "error",
"message": "No text-to-speech provider configured",
}
)
return
if not provider_db.api_key:
logger.warning("WebSocket synthesize: TTS provider has no API key")
await websocket.send_json(
{
"type": "error",
"message": "Text-to-speech provider has no API key configured",
}
)
return
logger.info(
f"WebSocket synthesize: creating voice provider: {provider_db.provider_type}"
)
try:
provider = get_voice_provider(provider_db)
logger.info(
f"WebSocket synthesize: voice provider created, streaming TTS supported: {provider.supports_streaming_tts()}"
)
except ValueError as e:
logger.error(
f"WebSocket synthesize: failed to create voice provider: {e}"
)
await websocket.send_json({"type": "error", "message": str(e)})
return
# Use native streaming if provider supports it
if provider.supports_streaming_tts():
logger.info("WebSocket synthesize: using native streaming TTS")
try:
# Wait for initial config message with voice/speed. Any "text" in this
# first frame is not buffered; clients should send config before text.
message = await websocket.receive()
voice = None
speed = 1.0
if "text" in message:
try:
data = json.loads(message["text"])
voice = data.get("voice")
speed = data.get("speed", 1.0)
except json.JSONDecodeError:
pass
streaming_synthesizer = await provider.create_streaming_synthesizer(
voice=voice, speed=speed
)
logger.info(
"WebSocket synthesize: streaming synthesizer created successfully"
)
await handle_streaming_synthesis(websocket, streaming_synthesizer)
except Exception as e:
logger.error(
f"WebSocket synthesize: failed to create streaming synthesizer: {e}"
)
if VOICE_DISABLE_STREAMING_FALLBACK:
await websocket.send_json(
{"type": "error", "message": f"Streaming TTS failed: {e}"}
)
return
logger.info(
"WebSocket synthesize: falling back to chunked TTS synthesis"
)
await handle_chunked_synthesis(websocket, provider)
else:
if VOICE_DISABLE_STREAMING_FALLBACK:
await websocket.send_json(
{
"type": "error",
"message": "Provider doesn't support streaming TTS",
}
)
return
logger.info(
"WebSocket synthesize: using chunked TTS (provider doesn't support streaming)"
)
await handle_chunked_synthesis(websocket, provider)
except WebSocketDisconnect:
logger.debug("WebSocket synthesize: client disconnected")
except RuntimeError as e:
if "disconnect" in str(e).lower():
logger.debug("WebSocket synthesize: client disconnected")
else:
logger.error(f"WebSocket synthesize: runtime error: {e}")
except Exception as e:
error_str = str(e).lower()
if "disconnect" in error_str or "websocket.close" in error_str:
logger.debug("WebSocket synthesize: client disconnected")
else:
logger.error(f"WebSocket synthesize: unhandled error: {e}", exc_info=True)
try:
await websocket.send_json({"type": "error", "message": str(e)})
except Exception:
pass
finally:
if streaming_synthesizer:
try:
await streaming_synthesizer.close()
except Exception:
pass
try:
await websocket.close()
except Exception:
pass
logger.info("WebSocket synthesize: connection closed")

View File

@@ -0,0 +1,70 @@
from onyx.db.models import VoiceProvider
from onyx.voice.interface import VoiceProviderInterface
def get_voice_provider(provider: VoiceProvider) -> VoiceProviderInterface:
"""
Factory function to get the appropriate voice provider implementation.
Args:
provider: VoiceProvider model instance (can be from DB or constructed temporarily)
Returns:
VoiceProviderInterface implementation
Raises:
ValueError: If provider_type is not supported
"""
provider_type = provider.provider_type.lower()
# Handle both SensitiveValue (from DB) and plain string (from temp model)
if provider.api_key is None:
api_key = None
elif hasattr(provider.api_key, "get_value"):
# SensitiveValue from database
api_key = provider.api_key.get_value(apply_mask=False)
else:
# Plain string from temporary model
api_key = provider.api_key # type: ignore[assignment]
api_base = provider.api_base
custom_config = provider.custom_config
stt_model = provider.stt_model
tts_model = provider.tts_model
default_voice = provider.default_voice
if provider_type == "openai":
from onyx.voice.providers.openai import OpenAIVoiceProvider
return OpenAIVoiceProvider(
api_key=api_key,
api_base=api_base,
stt_model=stt_model,
tts_model=tts_model,
default_voice=default_voice,
)
elif provider_type == "azure":
from onyx.voice.providers.azure import AzureVoiceProvider
return AzureVoiceProvider(
api_key=api_key,
api_base=api_base,
custom_config=custom_config or {},
stt_model=stt_model,
tts_model=tts_model,
default_voice=default_voice,
)
elif provider_type == "elevenlabs":
from onyx.voice.providers.elevenlabs import ElevenLabsVoiceProvider
return ElevenLabsVoiceProvider(
api_key=api_key,
api_base=api_base,
stt_model=stt_model,
tts_model=tts_model,
default_voice=default_voice,
)
else:
raise ValueError(f"Unsupported voice provider type: {provider_type}")

View File

@@ -0,0 +1,175 @@
from abc import ABC
from abc import abstractmethod
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Protocol
@dataclass
class TranscriptResult:
"""Result from streaming transcription."""
text: str
"""The accumulated transcript text."""
is_vad_end: bool = False
"""True if VAD detected end of speech (silence). Use for auto-send."""
class StreamingTranscriberProtocol(Protocol):
"""Protocol for streaming transcription sessions."""
async def send_audio(self, chunk: bytes) -> None:
"""Send an audio chunk for transcription."""
...
async def receive_transcript(self) -> TranscriptResult | None:
"""
Receive next transcript update.
Returns:
TranscriptResult with accumulated text and VAD status, or None when stream ends.
"""
...
async def close(self) -> str:
"""Close the session and return final transcript."""
...
def reset_transcript(self) -> None:
"""Reset accumulated transcript. Call after auto-send to start fresh."""
...
class StreamingSynthesizerProtocol(Protocol):
"""Protocol for streaming TTS sessions (real-time text-to-speech)."""
async def connect(self) -> None:
"""Establish connection to TTS provider."""
...
async def send_text(self, text: str) -> None:
"""Send text to be synthesized."""
...
async def receive_audio(self) -> bytes | None:
"""
Receive next audio chunk.
Returns:
Audio bytes, or None when stream ends.
"""
...
async def flush(self) -> None:
"""Signal end of text input and wait for pending audio."""
...
async def close(self) -> None:
"""Close the session."""
...
class VoiceProviderInterface(ABC):
"""Abstract base class for voice providers (STT and TTS)."""
@abstractmethod
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
"""
Convert audio to text (Speech-to-Text).
Args:
audio_data: Raw audio bytes
audio_format: Audio format (e.g., "webm", "wav", "mp3")
Returns:
Transcribed text
"""
@abstractmethod
def synthesize_stream(
self, text: str, voice: str | None = None, speed: float = 1.0
) -> AsyncIterator[bytes]:
"""
Convert text to audio stream (Text-to-Speech).
Streams audio chunks progressively for lower latency playback.
Args:
text: Text to convert to speech
voice: Voice identifier (e.g., "alloy", "echo"), or None for default
speed: Playback speed multiplier (0.25 to 4.0)
Yields:
Audio data chunks
"""
@abstractmethod
def get_available_voices(self) -> list[dict[str, str]]:
"""
Get list of available voices for this provider.
Returns:
List of voice dictionaries with 'id' and 'name' keys
"""
@abstractmethod
def get_available_stt_models(self) -> list[dict[str, str]]:
"""
Get list of available STT models for this provider.
Returns:
List of model dictionaries with 'id' and 'name' keys
"""
@abstractmethod
def get_available_tts_models(self) -> list[dict[str, str]]:
"""
Get list of available TTS models for this provider.
Returns:
List of model dictionaries with 'id' and 'name' keys
"""
def supports_streaming_stt(self) -> bool:
"""Returns True if this provider supports streaming STT."""
return False
def supports_streaming_tts(self) -> bool:
"""Returns True if this provider supports real-time streaming TTS."""
return False
async def create_streaming_transcriber(
self, audio_format: str = "webm"
) -> StreamingTranscriberProtocol:
"""
Create a streaming transcription session.
Args:
audio_format: Audio format being sent (e.g., "webm", "pcm16")
Returns:
A streaming transcriber that can send audio chunks and receive transcripts
Raises:
NotImplementedError: If streaming STT is not supported
"""
raise NotImplementedError("Streaming STT not supported by this provider")
async def create_streaming_synthesizer(
self, voice: str | None = None, speed: float = 1.0
) -> "StreamingSynthesizerProtocol":
"""
Create a streaming TTS session for real-time audio synthesis.
Args:
voice: Voice identifier
speed: Playback speed multiplier
Returns:
A streaming synthesizer that can send text and receive audio chunks
Raises:
NotImplementedError: If streaming TTS is not supported
"""
raise NotImplementedError("Streaming TTS not supported by this provider")

View File

@@ -0,0 +1,471 @@
import asyncio
import io
import re
import struct
import wave
from collections.abc import AsyncIterator
from typing import Any
from xml.sax.saxutils import escape
import aiohttp
from onyx.voice.interface import StreamingSynthesizerProtocol
from onyx.voice.interface import StreamingTranscriberProtocol
from onyx.voice.interface import TranscriptResult
from onyx.voice.interface import VoiceProviderInterface
# Common Azure Neural voices
AZURE_VOICES = [
{"id": "en-US-JennyNeural", "name": "Jenny (en-US, Female)"},
{"id": "en-US-GuyNeural", "name": "Guy (en-US, Male)"},
{"id": "en-US-AriaNeural", "name": "Aria (en-US, Female)"},
{"id": "en-US-DavisNeural", "name": "Davis (en-US, Male)"},
{"id": "en-US-AmberNeural", "name": "Amber (en-US, Female)"},
{"id": "en-US-AnaNeural", "name": "Ana (en-US, Female)"},
{"id": "en-US-BrandonNeural", "name": "Brandon (en-US, Male)"},
{"id": "en-US-ChristopherNeural", "name": "Christopher (en-US, Male)"},
{"id": "en-US-CoraNeural", "name": "Cora (en-US, Female)"},
{"id": "en-GB-SoniaNeural", "name": "Sonia (en-GB, Female)"},
{"id": "en-GB-RyanNeural", "name": "Ryan (en-GB, Male)"},
]
class AzureStreamingTranscriber(StreamingTranscriberProtocol):
"""Streaming transcription using Azure Speech SDK."""
def __init__(
self,
api_key: str,
region: str,
input_sample_rate: int = 24000,
target_sample_rate: int = 16000,
):
self.api_key = api_key
self.region = region
self.input_sample_rate = input_sample_rate
self.target_sample_rate = target_sample_rate
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
self._accumulated_transcript = ""
self._recognizer: Any = None
self._audio_stream: Any = None
self._closed = False
self._loop: asyncio.AbstractEventLoop | None = None
async def connect(self) -> None:
"""Initialize Azure Speech recognizer with push stream."""
try:
import azure.cognitiveservices.speech as speechsdk # type: ignore[import-not-found]
except ImportError as e:
raise RuntimeError(
"Azure Speech SDK is required for streaming STT. "
"Install `azure-cognitiveservices-speech`."
) from e
self._loop = asyncio.get_running_loop()
speech_config = speechsdk.SpeechConfig(
subscription=self.api_key,
region=self.region,
)
audio_format = speechsdk.audio.AudioStreamFormat(
samples_per_second=self.target_sample_rate,
bits_per_sample=16,
channels=1,
)
self._audio_stream = speechsdk.audio.PushAudioInputStream(audio_format)
audio_config = speechsdk.audio.AudioConfig(stream=self._audio_stream)
self._recognizer = speechsdk.SpeechRecognizer(
speech_config=speech_config,
audio_config=audio_config,
)
transcriber = self
def on_recognizing(evt: Any) -> None:
if evt.result.text and transcriber._loop and not transcriber._closed:
full_text = transcriber._accumulated_transcript
if full_text:
full_text += " " + evt.result.text
else:
full_text = evt.result.text
transcriber._loop.call_soon_threadsafe(
transcriber._transcript_queue.put_nowait,
TranscriptResult(text=full_text, is_vad_end=False),
)
def on_recognized(evt: Any) -> None:
if evt.result.text and transcriber._loop and not transcriber._closed:
if transcriber._accumulated_transcript:
transcriber._accumulated_transcript += " " + evt.result.text
else:
transcriber._accumulated_transcript = evt.result.text
transcriber._loop.call_soon_threadsafe(
transcriber._transcript_queue.put_nowait,
TranscriptResult(
text=transcriber._accumulated_transcript, is_vad_end=True
),
)
self._recognizer.recognizing.connect(on_recognizing)
self._recognizer.recognized.connect(on_recognized)
self._recognizer.start_continuous_recognition_async()
async def send_audio(self, chunk: bytes) -> None:
"""Send audio chunk to Azure."""
if self._audio_stream and not self._closed:
self._audio_stream.write(self._resample_pcm16(chunk))
def _resample_pcm16(self, data: bytes) -> bytes:
"""Resample PCM16 audio from input_sample_rate to target_sample_rate."""
if self.input_sample_rate == self.target_sample_rate:
return data
num_samples = len(data) // 2
if num_samples == 0:
return b""
samples = list(struct.unpack(f"<{num_samples}h", data))
ratio = self.input_sample_rate / self.target_sample_rate
new_length = int(num_samples / ratio)
resampled: list[int] = []
for i in range(new_length):
src_idx = i * ratio
idx_floor = int(src_idx)
idx_ceil = min(idx_floor + 1, num_samples - 1)
frac = src_idx - idx_floor
sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
sample = max(-32768, min(32767, sample))
resampled.append(sample)
return struct.pack(f"<{len(resampled)}h", *resampled)
async def receive_transcript(self) -> TranscriptResult | None:
"""Receive next transcript."""
try:
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return TranscriptResult(text="", is_vad_end=False)
async def close(self) -> str:
"""Stop recognition and return final transcript."""
self._closed = True
if self._recognizer:
self._recognizer.stop_continuous_recognition_async()
if self._audio_stream:
self._audio_stream.close()
self._loop = None
return self._accumulated_transcript
def reset_transcript(self) -> None:
"""Reset accumulated transcript."""
self._accumulated_transcript = ""
class AzureStreamingSynthesizer(StreamingSynthesizerProtocol):
"""Real-time streaming TTS using Azure Speech SDK."""
def __init__(self, api_key: str, region: str, voice: str = "en-US-JennyNeural"):
from onyx.utils.logger import setup_logger
self._logger = setup_logger()
self.api_key = api_key
self.region = region
self.voice = voice
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
self._synthesizer: Any = None
self._closed = False
self._loop: asyncio.AbstractEventLoop | None = None
async def connect(self) -> None:
"""Initialize Azure Speech synthesizer with push stream."""
try:
import azure.cognitiveservices.speech as speechsdk
except ImportError as e:
raise RuntimeError(
"Azure Speech SDK is required for streaming TTS. "
"Install `azure-cognitiveservices-speech`."
) from e
self._logger.info("AzureStreamingSynthesizer: connecting")
# Store the event loop for thread-safe queue operations
self._loop = asyncio.get_running_loop()
speech_config = speechsdk.SpeechConfig(
subscription=self.api_key,
region=self.region,
)
speech_config.speech_synthesis_voice_name = self.voice
# Use MP3 format for streaming - compatible with MediaSource Extensions
speech_config.set_speech_synthesis_output_format(
speechsdk.SpeechSynthesisOutputFormat.Audio16Khz64KBitRateMonoMp3
)
# Create synthesizer with pull audio output stream
self._synthesizer = speechsdk.SpeechSynthesizer(
speech_config=speech_config,
audio_config=None, # We'll manually handle audio
)
# Connect to synthesis events
self._synthesizer.synthesizing.connect(self._on_synthesizing)
self._synthesizer.synthesis_completed.connect(self._on_completed)
self._logger.info("AzureStreamingSynthesizer: connected")
def _on_synthesizing(self, evt: Any) -> None:
"""Called when audio chunk is available (runs in Azure SDK thread)."""
if evt.result.audio_data and self._loop and not self._closed:
# Thread-safe way to put item in async queue
self._loop.call_soon_threadsafe(
self._audio_queue.put_nowait, evt.result.audio_data
)
def _on_completed(self, _evt: Any) -> None:
"""Called when synthesis is complete (runs in Azure SDK thread)."""
if self._loop and not self._closed:
self._loop.call_soon_threadsafe(self._audio_queue.put_nowait, None)
async def send_text(self, text: str) -> None:
"""Send text to be synthesized."""
if self._synthesizer and not self._closed:
# Start synthesis asynchronously
self._synthesizer.speak_text_async(text)
async def receive_audio(self) -> bytes | None:
"""Receive next audio chunk."""
try:
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return b"" # No audio yet, but not done
async def flush(self) -> None:
"""Signal end of text input - wait for pending audio."""
# Azure SDK handles flushing automatically
async def close(self) -> None:
"""Close the session."""
self._closed = True
if self._synthesizer:
self._synthesizer.synthesis_completed.disconnect_all()
self._synthesizer.synthesizing.disconnect_all()
self._loop = None
class AzureVoiceProvider(VoiceProviderInterface):
"""Azure Speech Services voice provider."""
def __init__(
self,
api_key: str | None,
api_base: str | None,
custom_config: dict[str, Any],
stt_model: str | None = None,
tts_model: str | None = None,
default_voice: str | None = None,
):
self.api_key = api_key
self.api_base = api_base
self.custom_config = custom_config
self.speech_region = (
custom_config.get("speech_region")
or self._extract_speech_region_from_uri(api_base)
or ""
)
self.stt_model = stt_model
self.tts_model = tts_model
self.default_voice = default_voice or "en-US-JennyNeural"
@staticmethod
def _extract_speech_region_from_uri(uri: str | None) -> str | None:
"""Extract Azure speech region from endpoint URI."""
if not uri:
return None
# Accepted examples:
# - https://eastus.tts.speech.microsoft.com/cognitiveservices/v1
# - https://eastus.stt.speech.microsoft.com/speech/recognition/...
# - https://westus.api.cognitive.microsoft.com/
# - https://<resource>.cognitiveservices.azure.com/ (fallback to first label)
patterns = [
r"https?://([^.]+)\.(?:tts|stt)\.speech\.microsoft\.com",
r"https?://([^.]+)\.api\.cognitive\.microsoft\.com",
r"https?://([^.]+)\.cognitiveservices\.azure\.com",
]
for pattern in patterns:
match = re.search(pattern, uri)
if match:
return match.group(1)
return None
@staticmethod
def _pcm16_to_wav(pcm_data: bytes, sample_rate: int = 24000) -> bytes:
"""Wrap raw PCM16 mono bytes into a WAV container."""
buffer = io.BytesIO()
with wave.open(buffer, "wb") as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(sample_rate)
wav_file.writeframes(pcm_data)
return buffer.getvalue()
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
if not self.api_key:
raise ValueError("Azure API key required for STT")
if not self.speech_region:
raise ValueError("Azure speech region required for STT")
normalized_format = audio_format.lower()
payload = audio_data
content_type = f"audio/{normalized_format}"
# WebSocket chunked fallback sends raw PCM16 bytes.
if normalized_format in {"pcm", "pcm16", "raw"}:
payload = self._pcm16_to_wav(audio_data, sample_rate=24000)
content_type = "audio/wav"
elif normalized_format in {"wav", "wave"}:
content_type = "audio/wav"
elif normalized_format == "webm":
content_type = "audio/webm; codecs=opus"
url = (
f"https://{self.speech_region}.stt.speech.microsoft.com/"
"speech/recognition/conversation/cognitiveservices/v1"
)
params = {"language": "en-US", "format": "detailed"}
headers = {
"Ocp-Apim-Subscription-Key": self.api_key,
"Content-Type": content_type,
"Accept": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.post(
url, params=params, headers=headers, data=payload
) as response:
if response.status != 200:
error_text = await response.text()
raise RuntimeError(f"Azure STT failed: {error_text}")
result = await response.json()
if result.get("RecognitionStatus") != "Success":
return ""
nbest = result.get("NBest") or []
if nbest and isinstance(nbest, list):
display = nbest[0].get("Display")
if isinstance(display, str):
return display
display_text = result.get("DisplayText", "")
return display_text if isinstance(display_text, str) else ""
async def synthesize_stream(
self, text: str, voice: str | None = None, speed: float = 1.0
) -> AsyncIterator[bytes]:
"""
Convert text to audio using Azure TTS with streaming.
Args:
text: Text to convert to speech
voice: Voice name (defaults to provider's default voice)
speed: Playback speed multiplier (0.5 to 2.0)
Yields:
Audio data chunks (mp3 format)
"""
if not self.api_key:
raise ValueError("Azure API key required for TTS")
if not self.speech_region:
raise ValueError("Azure speech region required for TTS")
voice_name = voice or self.default_voice
# Clamp speed to valid range and convert to rate format
speed = max(0.5, min(2.0, speed))
rate = f"{int((speed - 1) * 100):+d}%" # e.g., 1.0 -> "+0%", 1.5 -> "+50%"
# Build SSML with escaped text to prevent injection
escaped_text = escape(text)
ssml = f"""<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>
<voice name='{voice_name}'>
<prosody rate='{rate}'>{escaped_text}</prosody>
</voice>
</speak>"""
url = f"https://{self.speech_region}.tts.speech.microsoft.com/cognitiveservices/v1"
headers = {
"Ocp-Apim-Subscription-Key": self.api_key,
"Content-Type": "application/ssml+xml",
"X-Microsoft-OutputFormat": "audio-16khz-128kbitrate-mono-mp3",
"User-Agent": "Onyx",
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, data=ssml) as response:
if response.status != 200:
error_text = await response.text()
raise RuntimeError(f"Azure TTS failed: {error_text}")
# Use 8192 byte chunks for smoother streaming
async for chunk in response.content.iter_chunked(8192):
if chunk:
yield chunk
def get_available_voices(self) -> list[dict[str, str]]:
"""Return common Azure Neural voices."""
return AZURE_VOICES.copy()
def get_available_stt_models(self) -> list[dict[str, str]]:
return [
{"id": "default", "name": "Azure Speech Recognition"},
]
def get_available_tts_models(self) -> list[dict[str, str]]:
return [
{"id": "neural", "name": "Neural TTS"},
]
def supports_streaming_stt(self) -> bool:
"""Azure supports streaming STT via Speech SDK."""
return True
def supports_streaming_tts(self) -> bool:
"""Azure supports real-time streaming TTS via Speech SDK."""
return True
async def create_streaming_transcriber(
self, audio_format: str = "webm"
) -> AzureStreamingTranscriber:
"""Create a streaming transcription session."""
_ = audio_format  # Azure push stream is configured for PCM16; the format hint is unused.
if not self.api_key:
raise ValueError("API key required for streaming transcription")
if not self.speech_region:
raise ValueError("Speech region required for Azure streaming transcription")
transcriber = AzureStreamingTranscriber(
api_key=self.api_key,
region=self.speech_region,
input_sample_rate=24000,
target_sample_rate=16000,
)
await transcriber.connect()
return transcriber
async def create_streaming_synthesizer(
self, voice: str | None = None, speed: float = 1.0
) -> AzureStreamingSynthesizer:
"""Create a streaming TTS session."""
_ = speed # Azure SDK streaming path does not currently support runtime speed control.
if not self.api_key:
raise ValueError("API key required for streaming TTS")
if not self.speech_region:
raise ValueError("Speech region required for Azure streaming TTS")
synthesizer = AzureStreamingSynthesizer(
api_key=self.api_key,
region=self.speech_region,
voice=voice or self.default_voice or "en-US-JennyNeural",
)
await synthesizer.connect()
return synthesizer
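Two quick worked examples for helpers in this file, assuming the module is importable; the expected values follow directly from the code in _extract_speech_region_from_uri and synthesize_stream.

# Region extraction from an Azure endpoint URI:
region = AzureVoiceProvider._extract_speech_region_from_uri(
    "https://eastus.tts.speech.microsoft.com/cognitiveservices/v1"
)
assert region == "eastus"

# The prosody-rate mapping used when building SSML (after clamping to 0.5-2.0):
for speed, expected in [(0.5, "-50%"), (1.0, "+0%"), (1.5, "+50%"), (2.0, "+100%")]:
    assert f"{int((speed - 1) * 100):+d}%" == expected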

View File

@@ -0,0 +1,751 @@
import asyncio
import base64
import json
from collections.abc import AsyncIterator
import aiohttp
from onyx.voice.interface import StreamingSynthesizerProtocol
from onyx.voice.interface import StreamingTranscriberProtocol
from onyx.voice.interface import TranscriptResult
from onyx.voice.interface import VoiceProviderInterface
# Common ElevenLabs voices
ELEVENLABS_VOICES = [
{"id": "21m00Tcm4TlvDq8ikWAM", "name": "Rachel"},
{"id": "AZnzlk1XvdvUeBnXmlld", "name": "Domi"},
{"id": "EXAVITQu4vr4xnSDxMaL", "name": "Bella"},
{"id": "ErXwobaYiN019PkySvjV", "name": "Antoni"},
{"id": "MF3mGyEYCl7XYWbV9V6O", "name": "Elli"},
{"id": "TxGEqnHWrfWFTfGW9XjX", "name": "Josh"},
{"id": "VR6AewLTigWG4xSOukaG", "name": "Arnold"},
{"id": "pNInz6obpgDQGcFmaJgB", "name": "Adam"},
{"id": "yoZ06aMxZJJ28mfd3POQ", "name": "Sam"},
]
class ElevenLabsStreamingTranscriber(StreamingTranscriberProtocol):
"""Streaming transcription session using ElevenLabs Scribe Realtime API."""
def __init__(
self,
api_key: str,
model: str = "scribe_v2_realtime",
input_sample_rate: int = 24000, # What frontend sends
target_sample_rate: int = 16000, # What ElevenLabs expects
language_code: str = "en",
):
# Import logger first
from onyx.utils.logger import setup_logger
self._logger = setup_logger()
self._logger.info(
f"ElevenLabsStreamingTranscriber: initializing with model {model}"
)
self.api_key = api_key
self.model = model
self.input_sample_rate = input_sample_rate
self.target_sample_rate = target_sample_rate
self.language_code = language_code
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._session: aiohttp.ClientSession | None = None
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
self._final_transcript = ""
self._receive_task: asyncio.Task | None = None
self._closed = False
async def connect(self) -> None:
"""Establish WebSocket connection to ElevenLabs."""
self._logger.info(
"ElevenLabsStreamingTranscriber: connecting to ElevenLabs API"
)
self._session = aiohttp.ClientSession()
# VAD is configured via query parameters
# commit_strategy=vad enables automatic transcript commit on silence detection
url = (
f"wss://api.elevenlabs.io/v1/speech-to-text/realtime"
f"?model_id={self.model}"
f"&sample_rate={self.target_sample_rate}"
f"&language_code={self.language_code}"
f"&commit_strategy=vad"
f"&vad_silence_threshold_secs=1.0"
f"&vad_threshold=0.4"
f"&min_speech_duration_ms=100"
f"&min_silence_duration_ms=300"
)
self._logger.info(
f"ElevenLabsStreamingTranscriber: connecting to {url} "
f"(input={self.input_sample_rate}Hz, target={self.target_sample_rate}Hz)"
)
try:
self._ws = await self._session.ws_connect(
url,
headers={"xi-api-key": self.api_key},
)
self._logger.info(
f"ElevenLabsStreamingTranscriber: connected successfully, "
f"ws.closed={self._ws.closed}, close_code={self._ws.close_code}"
)
except Exception as e:
self._logger.error(
f"ElevenLabsStreamingTranscriber: failed to connect: {e}"
)
if self._session:
await self._session.close()
raise
# Start receiving transcripts in background
self._receive_task = asyncio.create_task(self._receive_loop())
async def _receive_loop(self) -> None:
"""Background task to receive transcripts from WebSocket."""
self._logger.info("ElevenLabsStreamingTranscriber: receive loop started")
if not self._ws:
self._logger.warning(
"ElevenLabsStreamingTranscriber: no WebSocket connection"
)
return
try:
async for msg in self._ws:
self._logger.debug(
f"ElevenLabsStreamingTranscriber: raw message type: {msg.type}"
)
if msg.type == aiohttp.WSMsgType.TEXT:
try:
data = json.loads(msg.data)
except json.JSONDecodeError:
self._logger.error(
f"ElevenLabsStreamingTranscriber: failed to parse JSON: {msg.data[:200]}"
)
continue
# ElevenLabs uses message_type field
msg_type = data.get("message_type", data.get("type", "")) # type: ignore[possibly-undefined]
self._logger.info(
f"ElevenLabsStreamingTranscriber: received message_type: '{msg_type}', data keys: {list(data.keys())}"
)
# Check for error in various formats
if "error" in data or msg_type == "error":
error_msg = data.get("error", data.get("message", data))
self._logger.error(
f"ElevenLabsStreamingTranscriber: API error: {error_msg}"
)
continue
# Handle different message types from ElevenLabs Scribe API
if msg_type == "session_started":
# Session started successfully
self._logger.info(
f"ElevenLabsStreamingTranscriber: session started, "
f"id={data.get('session_id')}, config={data.get('config')}"
)
elif msg_type == "partial_transcript":
# Partial transcript (interim result)
text = data.get("text", "")
if text:
self._logger.info(
f"ElevenLabsStreamingTranscriber: partial_transcript: {text[:50]}..."
)
self._final_transcript = text
await self._transcript_queue.put(
TranscriptResult(text=text, is_vad_end=False)
)
elif msg_type == "committed_transcript":
# Final/committed transcript (VAD detected end of utterance)
text = data.get("text", "")
if text:
self._logger.info(
f"ElevenLabsStreamingTranscriber: committed_transcript: {text[:50]}..."
)
self._final_transcript = text
await self._transcript_queue.put(
TranscriptResult(text=text, is_vad_end=True)
)
elif msg_type == "utterance_end":
# VAD detected end of speech
text = data.get("text", "") or self._final_transcript
if text:
self._logger.info(
f"ElevenLabsStreamingTranscriber: utterance_end: {text[:50]}..."
)
self._final_transcript = text
await self._transcript_queue.put(
TranscriptResult(text=text, is_vad_end=True)
)
elif msg_type == "session_ended":
self._logger.info(
"ElevenLabsStreamingTranscriber: session ended"
)
break
else:
# Log unhandled message types with full data for debugging
self._logger.warning(
f"ElevenLabsStreamingTranscriber: unhandled message_type: {msg_type}, full data: {data}"
)
elif msg.type == aiohttp.WSMsgType.BINARY:
self._logger.debug(
f"ElevenLabsStreamingTranscriber: received binary message: {len(msg.data)} bytes"
)
elif msg.type == aiohttp.WSMsgType.CLOSED:
close_code = self._ws.close_code if self._ws else "N/A"
self._logger.info(
"ElevenLabsStreamingTranscriber: WebSocket closed by "
f"server, close_code={close_code}"
)
break
elif msg.type == aiohttp.WSMsgType.ERROR:
self._logger.error(
f"ElevenLabsStreamingTranscriber: WebSocket error: {self._ws.exception() if self._ws else 'N/A'}"
)
break
elif msg.type == aiohttp.WSMsgType.CLOSE:
self._logger.info(
f"ElevenLabsStreamingTranscriber: WebSocket CLOSE frame received, data={msg.data}, extra={msg.extra}"
)
break
except Exception as e:
self._logger.error(
f"ElevenLabsStreamingTranscriber: error in receive loop: {e}",
exc_info=True,
)
finally:
close_code = self._ws.close_code if self._ws else "N/A"
self._logger.info(
f"ElevenLabsStreamingTranscriber: receive loop ended, close_code={close_code}"
)
await self._transcript_queue.put(None) # Signal end
def _resample_pcm16(self, data: bytes) -> bytes:
"""Resample PCM16 audio from input_sample_rate to target_sample_rate."""
import struct
if self.input_sample_rate == self.target_sample_rate:
return data
# Parse int16 samples
num_samples = len(data) // 2
samples = list(struct.unpack(f"<{num_samples}h", data))
# Calculate resampling ratio
ratio = self.input_sample_rate / self.target_sample_rate
new_length = int(num_samples / ratio)
# Linear interpolation resampling
resampled = []
for i in range(new_length):
src_idx = i * ratio
idx_floor = int(src_idx)
idx_ceil = min(idx_floor + 1, num_samples - 1)
frac = src_idx - idx_floor
sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
# Clamp to int16 range
sample = max(-32768, min(32767, sample))
resampled.append(sample)
return struct.pack(f"<{len(resampled)}h", *resampled)
async def send_audio(self, chunk: bytes) -> None:
"""Send an audio chunk for transcription."""
if not self._ws:
self._logger.warning("send_audio: no WebSocket connection")
return
if self._closed:
self._logger.warning("send_audio: transcriber is closed")
return
if self._ws.closed:
self._logger.warning(
f"send_audio: WebSocket is closed, close_code={self._ws.close_code}"
)
return
try:
# Resample from input rate (24kHz) to target rate (16kHz)
resampled = self._resample_pcm16(chunk)
# ElevenLabs expects input_audio_chunk message format with audio_base_64
audio_b64 = base64.b64encode(resampled).decode("utf-8")
message = {
"message_type": "input_audio_chunk",
"audio_base_64": audio_b64,
"sample_rate": self.target_sample_rate,
}
self._logger.info(
f"send_audio: {len(chunk)} bytes -> {len(resampled)} bytes (resampled) -> {len(audio_b64)} chars base64"
)
await self._ws.send_str(json.dumps(message))
self._logger.info("send_audio: message sent successfully")
except Exception as e:
self._logger.error(f"send_audio: failed to send: {e}", exc_info=True)
async def receive_transcript(self) -> TranscriptResult | None:
"""Receive next transcript. Returns None when done."""
try:
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return TranscriptResult(
text="", is_vad_end=False
) # No transcript yet, but not done
async def close(self) -> str:
"""Close the session and return final transcript."""
self._logger.info("ElevenLabsStreamingTranscriber: closing session")
self._closed = True
if self._ws and not self._ws.closed:
try:
# Just close the WebSocket - ElevenLabs Scribe doesn't need a special end message
self._logger.info(
"ElevenLabsStreamingTranscriber: closing WebSocket connection"
)
await self._ws.close()
except Exception as e:
self._logger.debug(f"Error closing WebSocket: {e}")
if self._receive_task and not self._receive_task.done():
self._receive_task.cancel()
try:
await self._receive_task
except asyncio.CancelledError:
pass
if self._session and not self._session.closed:
await self._session.close()
return self._final_transcript
def reset_transcript(self) -> None:
"""Reset accumulated transcript. Call after auto-send to start fresh."""
self._final_transcript = ""
class ElevenLabsStreamingSynthesizer(StreamingSynthesizerProtocol):
"""Real-time streaming TTS using ElevenLabs WebSocket API.
Uses ElevenLabs' stream-input WebSocket which processes text as one
continuous stream and returns audio in order.
"""
def __init__(
self,
api_key: str,
voice_id: str,
model_id: str = "eleven_multilingual_v2",
output_format: str = "mp3_44100_64",
):
from onyx.utils.logger import setup_logger
self._logger = setup_logger()
self.api_key = api_key
self.voice_id = voice_id
self.model_id = model_id
self.output_format = output_format
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._session: aiohttp.ClientSession | None = None
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
self._receive_task: asyncio.Task | None = None
self._closed = False
async def connect(self) -> None:
"""Establish WebSocket connection to ElevenLabs TTS."""
self._logger.info("ElevenLabsStreamingSynthesizer: connecting")
self._session = aiohttp.ClientSession()
# Stream-input TTS WebSocket URL. mp3_44100_64 output balances quality
# against chunk size for real-time playback.
url = (
f"wss://api.elevenlabs.io/v1/text-to-speech/{self.voice_id}/stream-input"
f"?model_id={self.model_id}&output_format={self.output_format}"
)
self._ws = await self._session.ws_connect(
url,
headers={"xi-api-key": self.api_key},
)
# Send initial configuration with generation settings optimized for streaming
await self._ws.send_str(
json.dumps(
{
"text": " ", # Initial space to start the stream
"voice_settings": {
"stability": 0.5,
"similarity_boost": 0.75,
},
"generation_config": {
"chunk_length_schedule": [
120,
160,
250,
290,
], # Optimized chunk sizes for streaming
},
"xi_api_key": self.api_key,
}
)
)
# Start receiving audio in background
self._receive_task = asyncio.create_task(self._receive_loop())
self._logger.info("ElevenLabsStreamingSynthesizer: connected")
async def _receive_loop(self) -> None:
"""Background task to receive audio chunks from WebSocket.
Audio is returned in order as one continuous stream.
"""
if not self._ws:
return
chunk_count = 0
total_bytes = 0
try:
async for msg in self._ws:
if self._closed:
self._logger.info(
"ElevenLabsStreamingSynthesizer: closed flag set, stopping "
"receive loop"
)
break
if msg.type == aiohttp.WSMsgType.TEXT:
data = json.loads(msg.data)
# Process audio if present
if "audio" in data and data["audio"]:
audio_bytes = base64.b64decode(data["audio"])
chunk_count += 1
total_bytes += len(audio_bytes)
await self._audio_queue.put(audio_bytes)
# Check isFinal separately - a message can have both audio AND isFinal
if "isFinal" in data:
self._logger.info(
f"ElevenLabsStreamingSynthesizer: received isFinal={data['isFinal']}, "
f"chunks so far: {chunk_count}, bytes: {total_bytes}"
)
if data.get("isFinal"):
self._logger.info(
"ElevenLabsStreamingSynthesizer: isFinal=true, signaling end of audio"
)
await self._audio_queue.put(None)
# Check for errors
if "error" in data or data.get("type") == "error":
self._logger.error(
f"ElevenLabsStreamingSynthesizer: received error: {data}"
)
elif msg.type == aiohttp.WSMsgType.BINARY:
chunk_count += 1
total_bytes += len(msg.data)
await self._audio_queue.put(msg.data)
elif msg.type in (
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.ERROR,
):
self._logger.info(
f"ElevenLabsStreamingSynthesizer: WebSocket closed/error, type={msg.type}"
)
break
except Exception as e:
self._logger.error(f"ElevenLabsStreamingSynthesizer receive error: {e}")
finally:
self._logger.info(
f"ElevenLabsStreamingSynthesizer: receive loop ended, {chunk_count} chunks, {total_bytes} bytes"
)
await self._audio_queue.put(None) # Signal end of stream
async def send_text(self, text: str) -> None:
"""Send text to be synthesized.
ElevenLabs processes text as a continuous stream and returns
audio in order. We let ElevenLabs handle buffering via chunk_length_schedule
and only force generation when flush() is called at the end.
Args:
text: Text to synthesize
"""
if self._ws and not self._closed and text.strip():
self._logger.info(
f"ElevenLabsStreamingSynthesizer: sending text ({len(text)} chars): '{text}'"
)
# Let ElevenLabs buffer and auto-generate based on chunk_length_schedule
# Don't trigger generation here - wait for flush() at the end
await self._ws.send_str(
json.dumps(
{
"text": text + " ", # Space for natural speech flow
}
)
)
self._logger.info("ElevenLabsStreamingSynthesizer: text sent successfully")
else:
self._logger.warning(
f"ElevenLabsStreamingSynthesizer: skipping send_text - "
f"ws={self._ws is not None}, closed={self._closed}, text='{text[:30] if text else ''}'"
)
async def receive_audio(self) -> bytes | None:
"""Receive next audio chunk."""
try:
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return b"" # No audio yet, but not done
async def flush(self) -> None:
"""Signal end of text input. ElevenLabs will generate remaining audio and close."""
if self._ws and not self._closed:
# Send empty string to signal end of input
# ElevenLabs will generate any remaining buffered text,
# send all audio chunks, send isFinal, then close the connection
self._logger.info(
"ElevenLabsStreamingSynthesizer: sending end-of-input (empty string)"
)
await self._ws.send_str(json.dumps({"text": ""}))
self._logger.info("ElevenLabsStreamingSynthesizer: end-of-input sent")
else:
self._logger.warning(
f"ElevenLabsStreamingSynthesizer: skipping flush - "
f"ws={self._ws is not None}, closed={self._closed}"
)
async def close(self) -> None:
"""Close the session."""
self._closed = True
if self._ws:
await self._ws.close()
if self._receive_task:
self._receive_task.cancel()
try:
await self._receive_task
except asyncio.CancelledError:
pass
if self._session:
await self._session.close()
# Valid ElevenLabs model IDs
ELEVENLABS_STT_MODELS = {"scribe_v1", "scribe_v2_realtime"}
ELEVENLABS_TTS_MODELS = {
"eleven_multilingual_v2",
"eleven_turbo_v2_5",
"eleven_monolingual_v1",
"eleven_flash_v2_5",
"eleven_flash_v2",
}
class ElevenLabsVoiceProvider(VoiceProviderInterface):
"""ElevenLabs voice provider."""
def __init__(
self,
api_key: str | None,
api_base: str | None = None,
stt_model: str | None = None,
tts_model: str | None = None,
default_voice: str | None = None,
):
self.api_key = api_key
self.api_base = api_base or "https://api.elevenlabs.io"
# Validate and default models - use valid ElevenLabs model IDs
self.stt_model = (
stt_model if stt_model in ELEVENLABS_STT_MODELS else "scribe_v1"
)
self.tts_model = (
tts_model
if tts_model in ELEVENLABS_TTS_MODELS
else "eleven_multilingual_v2"
)
self.default_voice = default_voice
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
"""
Transcribe audio using ElevenLabs Speech-to-Text API.
Args:
audio_data: Raw audio bytes
audio_format: Format of the audio (e.g., 'webm', 'mp3', 'wav')
Returns:
Transcribed text
"""
if not self.api_key:
raise ValueError("ElevenLabs API key required for transcription")
from onyx.utils.logger import setup_logger
logger = setup_logger()
url = f"{self.api_base}/v1/speech-to-text"
# Map common formats to MIME types
mime_types = {
"webm": "audio/webm",
"mp3": "audio/mpeg",
"wav": "audio/wav",
"ogg": "audio/ogg",
"flac": "audio/flac",
"m4a": "audio/mp4",
}
mime_type = mime_types.get(audio_format.lower(), f"audio/{audio_format}")
headers = {
"xi-api-key": self.api_key,
}
# ElevenLabs expects multipart form data
form_data = aiohttp.FormData()
form_data.add_field(
"audio",
audio_data,
filename=f"audio.{audio_format}",
content_type=mime_type,
)
# For batch STT, use scribe_v1 (not the realtime model)
batch_model = "scribe_v1"
form_data.add_field("model_id", batch_model)
logger.info(
f"ElevenLabs transcribe: sending {len(audio_data)} bytes, format={audio_format}"
)
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, data=form_data) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"ElevenLabs transcribe failed: {error_text}")
raise RuntimeError(f"ElevenLabs transcription failed: {error_text}")
result = await response.json()
text = result.get("text", "")
logger.info(f"ElevenLabs transcribe: got result: {text[:50]}...")
return text
async def synthesize_stream(
self, text: str, voice: str | None = None, speed: float = 1.0
) -> AsyncIterator[bytes]:
"""
Convert text to audio using ElevenLabs TTS with streaming.
Args:
text: Text to convert to speech
voice: Voice ID (defaults to provider's default voice or Rachel)
speed: Playback speed multiplier (not directly supported, ignored)
Yields:
Audio data chunks (mp3 format)
"""
from onyx.utils.logger import setup_logger
logger = setup_logger()
if not self.api_key:
raise ValueError("ElevenLabs API key required for TTS")
voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM" # Rachel
url = f"{self.api_base}/v1/text-to-speech/{voice_id}/stream"
logger.info(
f"ElevenLabs TTS: starting synthesis, text='{text[:50]}...', "
f"voice={voice_id}, model={self.tts_model}"
)
headers = {
"xi-api-key": self.api_key,
"Content-Type": "application/json",
"Accept": "audio/mpeg",
}
payload = {
"text": text,
"model_id": self.tts_model,
"voice_settings": {
"stability": 0.5,
"similarity_boost": 0.75,
},
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=payload) as response:
logger.info(
f"ElevenLabs TTS: got response status={response.status}, "
f"content-type={response.headers.get('content-type')}"
)
if response.status != 200:
error_text = await response.text()
logger.error(f"ElevenLabs TTS failed: {error_text}")
raise RuntimeError(f"ElevenLabs TTS failed: {error_text}")
# Use 8192 byte chunks for smoother streaming
chunk_count = 0
total_bytes = 0
async for chunk in response.content.iter_chunked(8192):
if chunk:
chunk_count += 1
total_bytes += len(chunk)
yield chunk
logger.info(
f"ElevenLabs TTS: streaming complete, {chunk_count} chunks, "
f"{total_bytes} total bytes"
)
def get_available_voices(self) -> list[dict[str, str]]:
"""Return common ElevenLabs voices."""
return ELEVENLABS_VOICES.copy()
def get_available_stt_models(self) -> list[dict[str, str]]:
return [
{"id": "scribe_v2_realtime", "name": "Scribe v2 Realtime (Streaming)"},
{"id": "scribe_v1", "name": "Scribe v1 (Batch)"},
]
def get_available_tts_models(self) -> list[dict[str, str]]:
return [
{"id": "eleven_multilingual_v2", "name": "Multilingual v2"},
{"id": "eleven_turbo_v2_5", "name": "Turbo v2.5"},
{"id": "eleven_monolingual_v1", "name": "Monolingual v1"},
]
def supports_streaming_stt(self) -> bool:
"""ElevenLabs supports streaming via Scribe Realtime API."""
return True
def supports_streaming_tts(self) -> bool:
"""ElevenLabs supports real-time streaming TTS via WebSocket."""
return True
async def create_streaming_transcriber(
self, audio_format: str = "webm"
) -> ElevenLabsStreamingTranscriber:
"""Create a streaming transcription session."""
_ = audio_format  # Realtime STT always consumes PCM16; the format hint is unused.
if not self.api_key:
raise ValueError("API key required for streaming transcription")
# ElevenLabs realtime STT requires scribe_v2_realtime model
# Frontend sends PCM16 at 24kHz, but ElevenLabs expects 16kHz
# The transcriber will resample automatically
transcriber = ElevenLabsStreamingTranscriber(
api_key=self.api_key,
model="scribe_v2_realtime",
input_sample_rate=24000, # What frontend sends
target_sample_rate=16000, # What ElevenLabs expects
language_code="en",
)
await transcriber.connect()
return transcriber
async def create_streaming_synthesizer(
self, voice: str | None = None, speed: float = 1.0
) -> ElevenLabsStreamingSynthesizer:
"""Create a streaming TTS session."""
_ = speed  # ElevenLabs stream-input does not expose runtime speed control.
if not self.api_key:
raise ValueError("API key required for streaming TTS")
voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM"
synthesizer = ElevenLabsStreamingSynthesizer(
api_key=self.api_key,
voice_id=voice_id,
model_id=self.tts_model,
# Use mp3_44100_64 for streaming - good balance of quality and chunk size
output_format="mp3_44100_64",
)
await synthesizer.connect()
return synthesizer
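A small worked example of the 24 kHz to 16 kHz linear-interpolation resampler above; constructing the transcriber performs no network I/O before connect(), though it does assume onyx's logger is importable.

import struct

transcriber = ElevenLabsStreamingTranscriber(api_key="dummy")
pcm_24k = struct.pack("<6h", 0, 300, 600, 900, 1200, 1500)  # six samples at 24 kHz
pcm_16k = transcriber._resample_pcm16(pcm_24k)
# ratio = 1.5, so 6 input samples become 4; midpoints are linearly interpolated
assert struct.unpack("<4h", pcm_16k) == (0, 450, 900, 1350)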

View File

@@ -0,0 +1,567 @@
import asyncio
import base64
import io
import json
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING
import aiohttp
from onyx.voice.interface import StreamingSynthesizerProtocol
from onyx.voice.interface import StreamingTranscriberProtocol
from onyx.voice.interface import TranscriptResult
from onyx.voice.interface import VoiceProviderInterface
if TYPE_CHECKING:
from openai import AsyncOpenAI
class OpenAIStreamingTranscriber(StreamingTranscriberProtocol):
"""Streaming transcription using OpenAI Realtime API."""
def __init__(self, api_key: str, model: str = "whisper-1"):
# Import logger first
from onyx.utils.logger import setup_logger
self._logger = setup_logger()
self._logger.info(
f"OpenAIStreamingTranscriber: initializing with model {model}"
)
self.api_key = api_key
self.model = model
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._session: aiohttp.ClientSession | None = None
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
self._current_turn_transcript = "" # Transcript for current VAD turn
self._accumulated_transcript = "" # Accumulated across all turns
self._receive_task: asyncio.Task | None = None
self._closed = False
async def connect(self) -> None:
"""Establish WebSocket connection to OpenAI Realtime API."""
self._session = aiohttp.ClientSession()
# OpenAI Realtime transcription endpoint
url = "wss://api.openai.com/v1/realtime?intent=transcription"
headers = {
"Authorization": f"Bearer {self.api_key}",
"OpenAI-Beta": "realtime=v1",
}
try:
self._ws = await self._session.ws_connect(url, headers=headers)
self._logger.info("Connected to OpenAI Realtime API")
except Exception as e:
self._logger.error(f"Failed to connect to OpenAI Realtime API: {e}")
raise
# Configure the session for transcription
# Enable server-side VAD (Voice Activity Detection) for automatic speech detection
config_message = {
"type": "transcription_session.update",
"session": {
"input_audio_format": "pcm16", # 16-bit PCM at 24kHz mono
"input_audio_transcription": {
"model": self.model,
},
"turn_detection": {
"type": "server_vad",
"threshold": 0.5,
"prefix_padding_ms": 300,
"silence_duration_ms": 500,
},
},
}
await self._ws.send_str(json.dumps(config_message))
self._logger.info(f"Sent config for model: {self.model} with server VAD")
# Start receiving transcripts
self._receive_task = asyncio.create_task(self._receive_loop())
async def _receive_loop(self) -> None:
"""Background task to receive transcripts."""
if not self._ws:
return
try:
async for msg in self._ws:
if msg.type == aiohttp.WSMsgType.TEXT:
data = json.loads(msg.data)
msg_type = data.get("type", "")
self._logger.debug(f"Received message type: {msg_type}")
# Handle errors
if msg_type == "error":
error = data.get("error", {})
self._logger.error(f"OpenAI error: {error}")
continue
# Handle VAD events
if msg_type == "input_audio_buffer.speech_started":
self._logger.info("OpenAI: Speech started")
# Reset current turn transcript for new speech
self._current_turn_transcript = ""
continue
elif msg_type == "input_audio_buffer.speech_stopped":
self._logger.info(
"OpenAI: Speech stopped (VAD detected silence)"
)
continue
elif msg_type == "input_audio_buffer.committed":
self._logger.info("OpenAI: Audio buffer committed")
continue
# Handle transcription events
if msg_type == "conversation.item.input_audio_transcription.delta":
delta = data.get("delta", "")
if delta:
self._logger.info(f"OpenAI: Transcription delta: {delta}")
self._current_turn_transcript += delta
# Show accumulated + current turn transcript
full_transcript = self._accumulated_transcript
if full_transcript and self._current_turn_transcript:
full_transcript += " "
full_transcript += self._current_turn_transcript
await self._transcript_queue.put(
TranscriptResult(text=full_transcript, is_vad_end=False)
)
elif (
msg_type
== "conversation.item.input_audio_transcription.completed"
):
transcript = data.get("transcript", "")
if transcript:
self._logger.info(
f"OpenAI: Transcription completed (VAD turn end): {transcript[:50]}..."
)
# This is the final transcript for this VAD turn
self._current_turn_transcript = transcript
# Accumulate this turn's transcript
if self._accumulated_transcript:
self._accumulated_transcript += " " + transcript
else:
self._accumulated_transcript = transcript
# Send with is_vad_end=True to trigger auto-send
await self._transcript_queue.put(
TranscriptResult(
text=self._accumulated_transcript,
is_vad_end=True,
)
)
elif msg_type not in (
"transcription_session.created",
"transcription_session.updated",
"conversation.item.created",
):
# Log any other message types we might be missing
self._logger.info(
f"OpenAI: Unhandled message type '{msg_type}': {data}"
)
elif msg.type == aiohttp.WSMsgType.ERROR:
self._logger.error(f"WebSocket error: {self._ws.exception()}")
break
elif msg.type == aiohttp.WSMsgType.CLOSED:
self._logger.info("WebSocket closed by server")
break
except Exception as e:
self._logger.error(f"Error in receive loop: {e}")
finally:
await self._transcript_queue.put(None)
async def send_audio(self, chunk: bytes) -> None:
"""Send audio chunk to OpenAI."""
if self._ws and not self._closed:
# OpenAI expects base64-encoded PCM16 audio at 24kHz mono
# PCM16 at 24kHz: 24000 samples/sec * 2 bytes/sample = 48000 bytes/sec
# So chunk_bytes / 48000 = duration in seconds
duration_ms = (len(chunk) / 48000) * 1000
self._logger.debug(
f"Sending {len(chunk)} bytes ({duration_ms:.1f}ms) of audio to OpenAI. "
f"First 10 bytes: {chunk[:10].hex() if len(chunk) >= 10 else chunk.hex()}"
)
message = {
"type": "input_audio_buffer.append",
"audio": base64.b64encode(chunk).decode("utf-8"),
}
await self._ws.send_str(json.dumps(message))
def reset_transcript(self) -> None:
"""Reset accumulated transcript. Call after auto-send to start fresh."""
self._logger.info("OpenAI: Resetting accumulated transcript")
self._accumulated_transcript = ""
self._current_turn_transcript = ""
async def receive_transcript(self) -> TranscriptResult | None:
"""Receive next transcript."""
try:
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return TranscriptResult(text="", is_vad_end=False)
async def close(self) -> str:
"""Close session and return final transcript."""
self._closed = True
if self._ws:
# With server VAD, the buffer is auto-committed when speech stops.
# But we should still commit any remaining audio and wait for transcription.
try:
await self._ws.send_str(
json.dumps({"type": "input_audio_buffer.commit"})
)
except Exception as e:
self._logger.debug(f"Error sending commit (may be expected): {e}")
# Wait for transcription to arrive (up to 5 seconds)
self._logger.info("Waiting for transcription to complete...")
for _ in range(50): # 50 * 100ms = 5 seconds max
await asyncio.sleep(0.1)
if self._accumulated_transcript:
self._logger.info(
f"Got final transcript: {self._accumulated_transcript[:50]}..."
)
break
else:
self._logger.warning("Timed out waiting for transcription")
await self._ws.close()
if self._receive_task:
self._receive_task.cancel()
try:
await self._receive_task
except asyncio.CancelledError:
pass
if self._session:
await self._session.close()
return self._accumulated_transcript
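# --- Illustrative usage sketch (hypothetical; not part of this module) ---
# Drive the transcriber above with 24 kHz mono PCM16 chunks, surface interim
# transcripts, and return the final accumulated text on close().
async def _transcriber_demo(api_key: str, pcm_chunks: list[bytes]) -> str:
    transcriber = OpenAIStreamingTranscriber(api_key=api_key)
    await transcriber.connect()
    for chunk in pcm_chunks:
        await transcriber.send_audio(chunk)
        result = await transcriber.receive_transcript()
        if result is None:  # queue signalled end of stream
            break
        if result.text:
            # is_vad_end=True marks a committed turn (the auto-send point)
            print("interim:", result.text, "| vad_end:", result.is_vad_end)
    return await transcriber.close()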
# OpenAI available voices for TTS
OPENAI_VOICES = [
{"id": "alloy", "name": "Alloy"},
{"id": "echo", "name": "Echo"},
{"id": "fable", "name": "Fable"},
{"id": "onyx", "name": "Onyx"},
{"id": "nova", "name": "Nova"},
{"id": "shimmer", "name": "Shimmer"},
]
# OpenAI available STT models (all support streaming via Realtime API)
OPENAI_STT_MODELS = [
{"id": "whisper-1", "name": "Whisper v1"},
{"id": "gpt-4o-transcribe", "name": "GPT-4o Transcribe"},
{"id": "gpt-4o-mini-transcribe", "name": "GPT-4o Mini Transcribe"},
]
# OpenAI available TTS models
OPENAI_TTS_MODELS = [
{"id": "tts-1", "name": "TTS-1 (Standard)"},
{"id": "tts-1-hd", "name": "TTS-1 HD (High Quality)"},
]
def _create_wav_header(
data_length: int,
sample_rate: int = 24000,
channels: int = 1,
bits_per_sample: int = 16,
) -> bytes:
"""Create a WAV file header for PCM audio data."""
import struct
byte_rate = sample_rate * channels * bits_per_sample // 8
block_align = channels * bits_per_sample // 8
# WAV header is 44 bytes
header = struct.pack(
"<4sI4s4sIHHIIHH4sI",
b"RIFF", # ChunkID
36 + data_length, # ChunkSize
b"WAVE", # Format
b"fmt ", # Subchunk1ID
16, # Subchunk1Size (PCM)
1, # AudioFormat (1 = PCM)
channels, # NumChannels
sample_rate, # SampleRate
byte_rate, # ByteRate
block_align, # BlockAlign
bits_per_sample, # BitsPerSample
b"data", # Subchunk2ID
data_length, # Subchunk2Size
)
return header
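# --- Illustrative self-check (hypothetical; not part of this module) ---
# The 44-byte header above, prepended to raw PCM16, yields a playable WAV;
# the magic bytes and sizes follow directly from the struct layout.
def _wav_header_selftest() -> None:
    pcm = b"\x00\x00" * 24000  # one second of silence at 24 kHz mono
    wav = _create_wav_header(len(pcm)) + pcm
    assert wav[:4] == b"RIFF" and wav[8:12] == b"WAVE"
    assert len(wav) == 44 + len(pcm)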
class OpenAIStreamingSynthesizer(StreamingSynthesizerProtocol):
"""Streaming TTS using OpenAI HTTP TTS API with streaming responses."""
def __init__(
self,
api_key: str,
voice: str = "alloy",
model: str = "tts-1",
speed: float = 1.0,
):
from onyx.utils.logger import setup_logger
self._logger = setup_logger()
self.api_key = api_key
self.voice = voice
self.model = model
self.speed = max(0.25, min(4.0, speed))
self._session: aiohttp.ClientSession | None = None
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
self._text_queue: asyncio.Queue[str | None] = asyncio.Queue()
self._synthesis_task: asyncio.Task | None = None
self._closed = False
self._flushed = False
async def connect(self) -> None:
"""Initialize HTTP session for TTS requests."""
self._logger.info("OpenAIStreamingSynthesizer: connecting")
self._session = aiohttp.ClientSession()
# Start background task to process text queue
self._synthesis_task = asyncio.create_task(self._process_text_queue())
self._logger.info("OpenAIStreamingSynthesizer: connected")
async def _process_text_queue(self) -> None:
"""Background task to process queued text for synthesis."""
while not self._closed:
try:
text = await asyncio.wait_for(self._text_queue.get(), timeout=0.1)
if text is None:
break
await self._synthesize_text(text)
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
break
except Exception as e:
self._logger.error(f"Error processing text queue: {e}")
async def _synthesize_text(self, text: str) -> None:
"""Make HTTP TTS request and stream audio to queue."""
if not self._session or self._closed:
return
url = "https://api.openai.com/v1/audio/speech"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
payload = {
"model": self.model,
"voice": self.voice,
"input": text,
"speed": self.speed,
"response_format": "mp3",
}
try:
async with self._session.post(
url, headers=headers, json=payload
) as response:
if response.status != 200:
error_text = await response.text()
self._logger.error(f"OpenAI TTS error: {error_text}")
return
# Use 8192 byte chunks for smoother streaming
# (larger chunks = more complete MP3 frames, better playback)
async for chunk in response.content.iter_chunked(8192):
if self._closed:
break
if chunk:
await self._audio_queue.put(chunk)
except Exception as e:
self._logger.error(f"OpenAIStreamingSynthesizer synthesis error: {e}")
async def send_text(self, text: str) -> None:
"""Queue text to be synthesized via HTTP streaming."""
if not text.strip() or self._closed:
return
await self._text_queue.put(text)
async def receive_audio(self) -> bytes | None:
"""Receive next audio chunk (MP3 format)."""
try:
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return b"" # No audio yet, but not done
async def flush(self) -> None:
"""Signal end of text input - wait for synthesis to complete."""
if self._flushed:
return
self._flushed = True
# Signal end of text input
await self._text_queue.put(None)
# Wait for synthesis task to complete processing all text
if self._synthesis_task and not self._synthesis_task.done():
try:
await asyncio.wait_for(self._synthesis_task, timeout=60.0)
except asyncio.TimeoutError:
self._logger.warning("OpenAIStreamingSynthesizer: flush timeout")
except asyncio.CancelledError:
pass
# Signal end of audio stream
await self._audio_queue.put(None)
async def close(self) -> None:
"""Close the session."""
if self._closed:
return
self._closed = True
# Signal end of queues only if flush wasn't already called
if not self._flushed:
await self._text_queue.put(None)
await self._audio_queue.put(None)
if self._synthesis_task and not self._synthesis_task.done():
self._synthesis_task.cancel()
try:
await self._synthesis_task
except asyncio.CancelledError:
pass
if self._session:
await self._session.close()
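# Illustrative usage sketch (assumed calling pattern, not part of this diff):
# queue text with send_text(), call flush() to signal end of input, then drain
# receive_audio() until the None sentinel; b"" only means "no audio yet".
async def _example_stream_tts(api_key: str, sentences: list[str]) -> bytes:
    synth = OpenAIStreamingSynthesizer(api_key=api_key, voice="alloy")
    await synth.connect()
    for sentence in sentences:
        await synth.send_text(sentence)
    await synth.flush()  # queues the None sentinel once synthesis completes
    audio = b""
    while True:
        chunk = await synth.receive_audio()
        if chunk is None:  # end of stream
            break
        audio += chunk
    await synth.close()
    return audio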
class OpenAIVoiceProvider(VoiceProviderInterface):
"""OpenAI voice provider using Whisper for STT and TTS API for speech synthesis."""
def __init__(
self,
api_key: str | None,
api_base: str | None = None,
stt_model: str | None = None,
tts_model: str | None = None,
default_voice: str | None = None,
):
self.api_key = api_key
self.api_base = api_base
self.stt_model = stt_model or "whisper-1"
self.tts_model = tts_model or "tts-1"
self.default_voice = default_voice or "alloy"
self._client: "AsyncOpenAI | None" = None
def _get_client(self) -> "AsyncOpenAI":
if self._client is None:
from openai import AsyncOpenAI
self._client = AsyncOpenAI(
api_key=self.api_key,
base_url=self.api_base,
)
return self._client
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
"""
Transcribe audio using OpenAI Whisper.
Args:
audio_data: Raw audio bytes
audio_format: Audio format (e.g., "webm", "wav", "mp3")
Returns:
Transcribed text
"""
client = self._get_client()
# Create a file-like object from the audio bytes
audio_file = io.BytesIO(audio_data)
audio_file.name = f"audio.{audio_format}"
response = await client.audio.transcriptions.create(
model=self.stt_model,
file=audio_file,
)
return response.text
async def synthesize_stream(
self, text: str, voice: str | None = None, speed: float = 1.0
) -> AsyncIterator[bytes]:
"""
Convert text to audio using OpenAI TTS with streaming.
Args:
text: Text to convert to speech
voice: Voice identifier (defaults to provider's default voice)
speed: Playback speed multiplier (0.25 to 4.0)
Yields:
Audio data chunks (mp3 format)
"""
client = self._get_client()
# Clamp speed to valid range
speed = max(0.25, min(4.0, speed))
# Use with_streaming_response for proper async streaming
# Using 8192 byte chunks for better streaming performance
# (larger chunks = fewer round-trips, more complete MP3 frames)
async with client.audio.speech.with_streaming_response.create(
model=self.tts_model,
voice=voice or self.default_voice,
input=text,
speed=speed,
response_format="mp3",
) as response:
async for chunk in response.iter_bytes(chunk_size=8192):
yield chunk
def get_available_voices(self) -> list[dict[str, str]]:
"""Get available OpenAI TTS voices."""
return OPENAI_VOICES.copy()
def get_available_stt_models(self) -> list[dict[str, str]]:
"""Get available OpenAI STT models."""
return OPENAI_STT_MODELS.copy()
def get_available_tts_models(self) -> list[dict[str, str]]:
"""Get available OpenAI TTS models."""
return OPENAI_TTS_MODELS.copy()
def supports_streaming_stt(self) -> bool:
"""OpenAI supports streaming via Realtime API for all STT models."""
return True
def supports_streaming_tts(self) -> bool:
"""OpenAI supports real-time streaming TTS via Realtime API."""
return True
async def create_streaming_transcriber(
self, _audio_format: str = "webm"
) -> OpenAIStreamingTranscriber:
"""Create a streaming transcription session using Realtime API."""
if not self.api_key:
raise ValueError("API key required for streaming transcription")
transcriber = OpenAIStreamingTranscriber(
api_key=self.api_key,
model=self.stt_model,
)
await transcriber.connect()
return transcriber
async def create_streaming_synthesizer(
self, voice: str | None = None, speed: float = 1.0
) -> OpenAIStreamingSynthesizer:
"""Create a streaming TTS session using HTTP streaming API."""
if not self.api_key:
raise ValueError("API key required for streaming TTS")
synthesizer = OpenAIStreamingSynthesizer(
api_key=self.api_key,
voice=voice or self.default_voice or "alloy",
model=self.tts_model or "tts-1",
speed=speed,
)
await synthesizer.connect()
return synthesizer
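# Illustrative round-trip sketch (assumes a valid key and webm input bytes, not
# part of this diff): non-streaming transcription followed by streaming
# synthesis of the resulting text.
async def _example_roundtrip(api_key: str, webm_audio: bytes) -> bytes:
    provider = OpenAIVoiceProvider(api_key=api_key)
    text = await provider.transcribe(webm_audio, audio_format="webm")
    mp3 = b""
    async for chunk in provider.synthesize_stream(text):
        mp3 += chunk
    return mp3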

View File

@@ -0,0 +1,20 @@
import type { IconProps } from "@opal/types";
const SvgAudio = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 32 32"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M4 20V12M10 28V4M22 22V10M28 18V14M16 20V12"
strokeWidth={2.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgAudio;

View File

@@ -17,6 +17,7 @@ export { default as SvgArrowUpDown } from "@opal/icons/arrow-up-down";
export { default as SvgArrowUpDot } from "@opal/icons/arrow-up-dot";
export { default as SvgArrowUpRight } from "@opal/icons/arrow-up-right";
export { default as SvgArrowWallRight } from "@opal/icons/arrow-wall-right";
export { default as SvgAudio } from "@opal/icons/audio";
export { default as SvgAudioEqSmall } from "@opal/icons/audio-eq-small";
export { default as SvgAws } from "@opal/icons/aws";
export { default as SvgAzure } from "@opal/icons/azure";
@@ -105,6 +106,8 @@ export { default as SvgLogOut } from "@opal/icons/log-out";
export { default as SvgMaximize2 } from "@opal/icons/maximize-2";
export { default as SvgMcp } from "@opal/icons/mcp";
export { default as SvgMenu } from "@opal/icons/menu";
export { default as SvgMicrophone } from "@opal/icons/microphone";
export { default as SvgMicrophoneOff } from "@opal/icons/microphone-off";
export { default as SvgMinus } from "@opal/icons/minus";
export { default as SvgMinusCircle } from "@opal/icons/minus-circle";
export { default as SvgMoon } from "@opal/icons/moon";

View File

@@ -0,0 +1,29 @@
import type { IconProps } from "@opal/types";
const SvgMicrophoneOff = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
{/* Microphone body */}
<path
d="M12.5 7V7.5C12.5 9.98528 10.4853 12 8 12M3.5 7V7.5C3.5 9.98528 5.51472 12 8 12M8 12V14.5M8 14.5H5M8 14.5H11M8 9.5C6.89543 9.5 6 8.60457 6 7.5V3.5C6 2.39543 6.89543 1.5 8 1.5C9.10457 1.5 10 2.39543 10 3.5V7.5C10 8.60457 9.10457 9.5 8 9.5Z"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
{/* Diagonal slash */}
<path
d="M2 2L14 14"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgMicrophoneOff;

View File

@@ -0,0 +1,21 @@
import type { IconProps } from "@opal/types";
const SvgMicrophone = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M12.5 7V7.5C12.5 9.98528 10.4853 12 8 12M3.5 7V7.5C3.5 9.98528 5.51472 12 8 12M8 12V14.5M8 14.5H5M8 14.5H11M8 9.5C6.89543 9.5 6 8.60457 6 7.5V3.5C6 2.39543 6.89543 1.5 8 1.5C9.10457 1.5 10 2.39543 10 3.5V7.5C10 8.60457 9.10457 9.5 8 9.5Z"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgMicrophone;

View File

@@ -58,7 +58,7 @@ const nextConfig = {
{
key: "Permissions-Policy",
value:
"accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), cross-origin-isolated=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), navigation-override=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()",
"accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), cross-origin-isolated=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(self), midi=(), navigation-override=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()",
},
],
},

web/package-lock.json generated
View File

@@ -105,6 +105,7 @@
"@types/node": "18.15.11",
"@types/react": "19.2.10",
"@types/react-dom": "19.2.3",
"@types/sbd": "^1.0.5",
"@types/stats.js": "^0.17.4",
"@types/uuid": "^9.0.8",
"@typescript/native-preview": "7.0.0-dev.20251222.1",
@@ -5543,6 +5544,13 @@
"integrity": "sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ==",
"license": "MIT"
},
"node_modules/@types/sbd": {
"version": "1.0.5",
"resolved": "https://registry.npmjs.org/@types/sbd/-/sbd-1.0.5.tgz",
"integrity": "sha512-60PxBBWhg0C3yb5bTP+wwWYGTKMcuB0S6mTEa1sedMC79tYY0Ei7YjU4qsWzGn++lWscLQde16SnElJrf5/aTw==",
"dev": true,
"license": "MIT"
},
"node_modules/@types/stack-utils": {
"version": "2.0.3",
"dev": true,

View File

@@ -121,6 +121,7 @@
"@types/node": "18.15.11",
"@types/react": "19.2.10",
"@types/react-dom": "19.2.3",
"@types/sbd": "^1.0.5",
"@types/stats.js": "^0.17.4",
"@types/uuid": "^9.0.8",
"@typescript/native-preview": "7.0.0-dev.20251222.1",

View File

@@ -0,0 +1,4 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M10.5 2H13V14H10.5V2Z" fill="black"/>
<path d="M3 2H5.5V14H3V2Z" fill="black"/>
</svg>

View File

@@ -0,0 +1,466 @@
"use client";
import Image from "next/image";
import { FunctionComponent, useMemo, useState, useEffect } from "react";
import Modal from "@/refresh-components/Modal";
import Button from "@/refresh-components/buttons/Button";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
import { Vertical, Horizontal } from "@/layouts/input-layouts";
import { toast } from "@/hooks/useToast";
import { Section } from "@/layouts/general-layouts";
import { SvgArrowExchange, SvgOnyxLogo } from "@opal/icons";
import type { IconProps } from "@opal/types";
interface VoiceOption {
value: string;
label: string;
}
interface LLMProviderView {
id: number;
name: string;
provider: string;
api_key: string | null;
}
interface ApiKeyOption {
value: string;
label: string;
description?: string;
}
interface VoiceProviderView {
id: number;
name: string;
provider_type: string;
is_default_stt: boolean;
is_default_tts: boolean;
stt_model: string | null;
tts_model: string | null;
default_voice: string | null;
has_api_key: boolean;
target_uri: string | null;
}
interface VoiceProviderSetupModalProps {
providerType: string;
existingProvider: VoiceProviderView | null;
mode: "stt" | "tts";
defaultModelId?: string | null;
onClose: () => void;
onSuccess: () => void;
}
const PROVIDER_LABELS: Record<string, string> = {
openai: "OpenAI",
azure: "Azure Speech Services",
elevenlabs: "ElevenLabs",
};
const PROVIDER_API_KEY_URLS: Record<string, string> = {
openai: "https://platform.openai.com/api-keys",
azure: "https://portal.azure.com/",
elevenlabs: "https://elevenlabs.io/app/settings/api-keys",
};
const PROVIDER_LOGO_URLS: Record<string, string> = {
openai: "/Openai.svg",
azure: "/Azure.png",
elevenlabs: "/ElevenLabs.svg",
};
const PROVIDER_DOCS_URLS: Record<string, string> = {
openai: "https://platform.openai.com/docs/guides/text-to-speech",
azure: "https://learn.microsoft.com/en-us/azure/ai-services/speech-service/",
elevenlabs: "https://elevenlabs.io/docs",
};
const OPENAI_STT_MODELS = [{ id: "whisper-1", name: "Whisper v1" }];
const OPENAI_TTS_MODELS = [
{ id: "tts-1", name: "TTS-1" },
{ id: "tts-1-hd", name: "TTS-1 HD" },
];
// Map model IDs from cards to actual API model IDs
const MODEL_ID_MAP: Record<string, string> = {
"tts-1": "tts-1",
"tts-1-hd": "tts-1-hd",
whisper: "whisper-1",
};
export default function VoiceProviderSetupModal({
providerType,
existingProvider,
mode,
defaultModelId,
onClose,
onSuccess,
}: VoiceProviderSetupModalProps) {
// Map the card model ID to the actual API model ID.
// Prioritize defaultModelId (from the clicked TTS card) over the stored value;
// ignore it in STT mode so an STT card ID (e.g. "whisper") never leaks into
// the TTS model field.
const initialTtsModel =
mode === "tts" && defaultModelId
? MODEL_ID_MAP[defaultModelId] ?? "tts-1"
: existingProvider?.tts_model ?? "tts-1";
const [apiKey, setApiKey] = useState("");
const [apiKeyChanged, setApiKeyChanged] = useState(false);
const [targetUri, setTargetUri] = useState(
existingProvider?.target_uri ?? ""
);
const [selectedLlmProviderId, setSelectedLlmProviderId] = useState<
number | null
>(null);
const [sttModel, setSttModel] = useState(
existingProvider?.stt_model ?? "whisper-1"
);
const [ttsModel, setTtsModel] = useState(initialTtsModel);
const [defaultVoice, setDefaultVoice] = useState(
existingProvider?.default_voice ?? ""
);
const [isSubmitting, setIsSubmitting] = useState(false);
// Dynamic voices fetched from backend
const [voiceOptions, setVoiceOptions] = useState<VoiceOption[]>([]);
const [isLoadingVoices, setIsLoadingVoices] = useState(false);
// Existing OpenAI LLM providers for API key reuse
const [existingApiKeyOptions, setExistingApiKeyOptions] = useState<
ApiKeyOption[]
>([]);
const [llmProviderMap, setLlmProviderMap] = useState<Map<string, number>>(
new Map()
);
// Fetch existing OpenAI LLM providers (for API key reuse)
useEffect(() => {
if (providerType !== "openai") return;
fetch("/api/admin/llm/provider")
.then((res) => res.json())
.then((data: LLMProviderView[]) => {
const openaiProviders = data.filter(
(p) => p.provider === "openai" && p.api_key
);
const options: ApiKeyOption[] = openaiProviders.map((p) => ({
value: p.api_key!,
label: p.api_key!,
description: `Used for LLM provider ${p.name}`,
}));
setExistingApiKeyOptions(options);
// Map masked API keys to provider IDs for lookup on selection
const providerMap = new Map<string, number>();
openaiProviders.forEach((p) => {
if (p.api_key) {
providerMap.set(p.api_key, p.id);
}
});
setLlmProviderMap(providerMap);
})
.catch(() => {
setExistingApiKeyOptions([]);
});
}, [providerType]);
// Fetch voices on mount (works without API key for ElevenLabs/OpenAI)
useEffect(() => {
setIsLoadingVoices(true);
fetch(`/api/admin/voice/voices?provider_type=${providerType}`)
.then((res) => res.json())
.then((data: Array<{ id: string; name: string }>) => {
const options = data.map((v) => ({ value: v.id, label: v.name }));
setVoiceOptions(options);
// Set default voice to first option if not already set
const firstOption = options[0];
if (firstOption) {
setDefaultVoice((prev) => prev || firstOption.value);
}
})
.catch(() => {
setVoiceOptions([]);
})
.finally(() => {
setIsLoadingVoices(false);
});
}, [providerType]);
const isEditing = !!existingProvider;
const label = PROVIDER_LABELS[providerType] ?? providerType;
// Create a logo arrangement component for the modal header
const LogoArrangement: FunctionComponent<IconProps> = useMemo(() => {
const Component: FunctionComponent<IconProps> = () => (
<div className="flex items-center gap-1">
<div className="flex items-center justify-center size-7 shrink-0 overflow-clip">
<Image
src={PROVIDER_LOGO_URLS[providerType] ?? "/Openai.svg"}
alt={`${label} logo`}
width={24}
height={24}
/>
</div>
<div className="flex items-center justify-center size-4 p-0.5 shrink-0">
<SvgArrowExchange className="size-3 text-text-04" />
</div>
<div className="flex items-center justify-center size-7 p-0.5 shrink-0 overflow-clip">
<SvgOnyxLogo size={24} className="text-text-04 shrink-0" />
</div>
</div>
);
return Component;
}, [providerType, label]);
const handleSubmit = async () => {
if (!isEditing && !apiKey && !selectedLlmProviderId) {
toast.error("API key is required");
return;
}
if (providerType === "azure" && !isEditing && !targetUri) {
toast.error("Target URI is required");
return;
}
setIsSubmitting(true);
try {
// Test the connection first (skip if reusing LLM provider key - it's already validated)
if (!selectedLlmProviderId) {
const testResponse = await fetch("/api/admin/voice/providers/test", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider_type: providerType,
api_key: apiKeyChanged ? apiKey : undefined,
target_uri: targetUri || undefined,
use_stored_key: isEditing && !apiKeyChanged,
}),
});
if (!testResponse.ok) {
const data = await testResponse.json();
toast.error(data.detail || "Connection test failed");
setIsSubmitting(false);
return;
}
}
// Save the provider
const response = await fetch("/api/admin/voice/providers", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
id: existingProvider?.id,
name: label,
provider_type: providerType,
api_key: selectedLlmProviderId
? undefined
: apiKeyChanged
? apiKey
: undefined,
api_key_changed: selectedLlmProviderId ? false : apiKeyChanged,
target_uri: targetUri || undefined,
llm_provider_id: selectedLlmProviderId,
stt_model: sttModel,
tts_model: ttsModel,
default_voice: defaultVoice,
activate_stt: mode === "stt",
activate_tts: mode === "tts",
}),
});
if (response.ok) {
toast.success(isEditing ? "Provider updated" : "Provider connected");
onSuccess();
} else {
const data = await response.json();
toast.error(data.detail || "Failed to save provider");
}
} catch {
toast.error("Failed to save provider");
} finally {
setIsSubmitting(false);
}
};
return (
<Modal open onOpenChange={(isOpen) => !isOpen && onClose()}>
<Modal.Content width="sm">
<Modal.Header
icon={LogoArrangement}
title={isEditing ? `Edit ${label}` : `Set up ${label}`}
description={`Connect to ${label} and set up your voice models.`}
onClose={onClose}
/>
<Modal.Body>
<Section gap={1} alignItems="stretch">
<Vertical
title="API Key"
subDescription={
isEditing ? (
"Leave blank to keep existing key"
) : (
<>
Paste your{" "}
<a
href={PROVIDER_API_KEY_URLS[providerType]}
target="_blank"
rel="noopener noreferrer"
className="underline"
>
API key
</a>{" "}
from {label} to access your models.
</>
)
}
nonInteractive
>
{providerType === "openai" && existingApiKeyOptions.length > 0 ? (
<InputComboBox
placeholder={isEditing ? "••••••••" : "Enter API key"}
value={apiKey}
onChange={(e) => {
setApiKey(e.target.value);
setApiKeyChanged(true);
setSelectedLlmProviderId(null);
}}
onValueChange={(value) => {
setApiKey(value);
// Check if this is an existing key
const llmProviderId = llmProviderMap.get(value);
if (llmProviderId) {
setSelectedLlmProviderId(llmProviderId);
setApiKeyChanged(false);
} else {
setSelectedLlmProviderId(null);
setApiKeyChanged(true);
}
}}
options={existingApiKeyOptions}
separatorLabel="Reuse OpenAI API Keys"
strict={false}
showAddPrefix
/>
) : (
<InputTypeIn
type="password"
placeholder={isEditing ? "••••••••" : "Enter API key"}
value={apiKey}
onChange={(e) => {
setApiKey(e.target.value);
setApiKeyChanged(true);
}}
/>
)}
</Vertical>
{providerType === "azure" && (
<Vertical
title="Target URI"
subDescription={
<>
Paste the endpoint shown in{" "}
<a
href="https://portal.azure.com/"
target="_blank"
rel="noopener noreferrer"
className="underline"
>
Azure Portal (Keys and Endpoint)
</a>
. Onyx extracts the speech region from this URL. Examples:
https://westus.api.cognitive.microsoft.com/ or
https://westus.tts.speech.microsoft.com/.
</>
}
nonInteractive
>
<InputTypeIn
placeholder={
isEditing
? "Leave blank to keep existing"
: "https://<region>.api.cognitive.microsoft.com/"
}
value={targetUri}
onChange={(e) => setTargetUri(e.target.value)}
/>
</Vertical>
)}
{providerType === "openai" && mode === "stt" && (
<Horizontal title="STT Model" center nonInteractive>
<InputSelect value={sttModel} onValueChange={setSttModel}>
<InputSelect.Trigger />
<InputSelect.Content>
{OPENAI_STT_MODELS.map((model) => (
<InputSelect.Item key={model.id} value={model.id}>
{model.name}
</InputSelect.Item>
))}
</InputSelect.Content>
</InputSelect>
</Horizontal>
)}
{providerType === "openai" && mode === "tts" && (
<Vertical
title="Default Model"
subDescription="This model will be used by Onyx by default for text-to-speech."
nonInteractive
>
<InputSelect value={ttsModel} onValueChange={setTtsModel}>
<InputSelect.Trigger />
<InputSelect.Content>
{OPENAI_TTS_MODELS.map((model) => (
<InputSelect.Item key={model.id} value={model.id}>
{model.name}
</InputSelect.Item>
))}
</InputSelect.Content>
</InputSelect>
</Vertical>
)}
{mode === "tts" && (
<Vertical
title="Voice"
subDescription="This voice will be used for spoken responses."
nonInteractive
>
<InputSelect
value={defaultVoice}
onValueChange={setDefaultVoice}
disabled={isLoadingVoices}
>
<InputSelect.Trigger
placeholder={
isLoadingVoices ? "Loading voices..." : "Select voice"
}
/>
<InputSelect.Content>
{voiceOptions.map((voice) => (
<InputSelect.Item key={voice.value} value={voice.value}>
{voice.label}
</InputSelect.Item>
))}
</InputSelect.Content>
</InputSelect>
</Vertical>
)}
</Section>
</Modal.Body>
<Modal.Footer>
<Button secondary onClick={onClose}>
Cancel
</Button>
<Button onClick={handleSubmit} disabled={isSubmitting}>
{isSubmitting ? "Connecting..." : isEditing ? "Save" : "Connect"}
</Button>
</Modal.Footer>
</Modal.Content>
</Modal>
);
}

View File

@@ -0,0 +1,633 @@
"use client";
import Image from "next/image";
import { useMemo, useState } from "react";
import { AdminPageTitle } from "@/components/admin/Title";
import { InfoIcon } from "@/components/icons/icons";
import Text from "@/refresh-components/texts/Text";
import Separator from "@/refresh-components/Separator";
import useSWR from "swr";
import { errorHandlingFetcher, FetchError } from "@/lib/fetcher";
import { ThreeDotsLoader } from "@/components/Loading";
import { Callout } from "@/components/ui/callout";
import Button from "@/refresh-components/buttons/Button";
import { Button as OpalButton } from "@opal/components";
import { cn } from "@/lib/utils";
import {
SvgArrowExchange,
SvgArrowRightCircle,
SvgAudio,
SvgCheckSquare,
SvgEdit,
SvgMicrophone,
SvgX,
} from "@opal/icons";
import VoiceProviderSetupModal from "./VoiceProviderSetupModal";
const VOICE_PROVIDERS_URL = "/api/admin/voice/providers";
interface VoiceProviderView {
id: number;
name: string;
provider_type: string;
is_default_stt: boolean;
is_default_tts: boolean;
stt_model: string | null;
tts_model: string | null;
default_voice: string | null;
has_api_key: boolean;
target_uri: string | null;
}
interface ModelDetails {
id: string;
label: string;
subtitle: string;
logoSrc?: string;
providerType: string;
}
interface ProviderGroup {
providerType: string;
providerLabel: string;
logoSrc?: string;
models: ModelDetails[];
}
// STT Models - individual cards
const STT_MODELS: ModelDetails[] = [
{
id: "whisper",
label: "Whisper",
subtitle: "OpenAI's general purpose speech recognition model.",
logoSrc: "/Openai.svg",
providerType: "openai",
},
{
id: "azure-speech-stt",
label: "Azure Speech",
subtitle: "Speech to text in Microsoft Foundry Tools.",
logoSrc: "/Azure.png",
providerType: "azure",
},
{
id: "elevenlabs-stt",
label: "ElevenAPI",
subtitle: "ElevenLabs Speech to Text API.",
logoSrc: "/ElevenLabs.svg",
providerType: "elevenlabs",
},
];
// TTS Models - grouped by provider
const TTS_PROVIDER_GROUPS: ProviderGroup[] = [
{
providerType: "openai",
providerLabel: "OpenAI",
logoSrc: "/Openai.svg",
models: [
{
id: "tts-1",
label: "TTS-1",
subtitle: "OpenAI's text-to-speech model optimized for speed.",
logoSrc: "/Openai.svg",
providerType: "openai",
},
{
id: "tts-1-hd",
label: "TTS-1 HD",
subtitle: "OpenAI's text-to-speech model optimized for quality.",
logoSrc: "/Openai.svg",
providerType: "openai",
},
],
},
{
providerType: "azure",
providerLabel: "Azure",
logoSrc: "/Azure.png",
models: [
{
id: "azure-speech-tts",
label: "Azure Speech",
subtitle: "Text to speech in Microsoft Foundry Tools.",
logoSrc: "/Azure.png",
providerType: "azure",
},
],
},
{
providerType: "elevenlabs",
providerLabel: "ElevenLabs",
logoSrc: "/ElevenLabs.svg",
models: [
{
id: "elevenlabs-tts",
label: "ElevenAPI",
subtitle: "ElevenLabs Text to Speech API.",
logoSrc: "/ElevenLabs.svg",
providerType: "elevenlabs",
},
],
},
];
interface HoverIconButtonProps extends React.ComponentProps<typeof Button> {
isHovered: boolean;
onMouseEnter: () => void;
onMouseLeave: () => void;
children: React.ReactNode;
}
function HoverIconButton({
isHovered,
onMouseEnter,
onMouseLeave,
children,
...buttonProps
}: HoverIconButtonProps) {
return (
<div onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
<Button {...buttonProps} rightIcon={isHovered ? SvgX : SvgCheckSquare}>
{children}
</Button>
</div>
);
}
type ProviderMode = "stt" | "tts";
export default function VoiceConfigurationPage() {
const [modalOpen, setModalOpen] = useState(false);
const [selectedProvider, setSelectedProvider] = useState<string | null>(null);
const [editingProvider, setEditingProvider] =
useState<VoiceProviderView | null>(null);
const [modalMode, setModalMode] = useState<ProviderMode>("stt");
const [selectedModelId, setSelectedModelId] = useState<string | null>(null);
const [sttActivationError, setSTTActivationError] = useState<string | null>(
null
);
const [ttsActivationError, setTTSActivationError] = useState<string | null>(
null
);
const [hoveredButtonKey, setHoveredButtonKey] = useState<string | null>(null);
const {
data: providers,
error,
isLoading,
mutate,
} = useSWR<VoiceProviderView[]>(VOICE_PROVIDERS_URL, errorHandlingFetcher);
const handleConnect = (
providerType: string,
mode: ProviderMode,
modelId?: string
) => {
setSelectedProvider(providerType);
setEditingProvider(null);
setModalMode(mode);
setSelectedModelId(modelId ?? null);
setModalOpen(true);
setSTTActivationError(null);
setTTSActivationError(null);
};
const handleEdit = (
provider: VoiceProviderView,
mode: ProviderMode,
modelId?: string
) => {
setSelectedProvider(provider.provider_type);
setEditingProvider(provider);
setModalMode(mode);
setSelectedModelId(modelId ?? null);
setModalOpen(true);
};
const handleSetDefault = async (
providerId: number,
mode: ProviderMode,
modelId?: string
) => {
const setError =
mode === "stt" ? setSTTActivationError : setTTSActivationError;
setError(null);
try {
const url = new URL(
`${VOICE_PROVIDERS_URL}/${providerId}/activate-${mode}`,
window.location.origin
);
if (mode === "tts" && modelId) {
url.searchParams.set("tts_model", modelId);
}
const response = await fetch(url.toString(), { method: "POST" });
if (!response.ok) {
const errorBody = await response.json().catch(() => ({}));
throw new Error(
typeof errorBody?.detail === "string"
? errorBody.detail
: `Failed to set provider as default ${mode.toUpperCase()}.`
);
}
await mutate();
} catch (err) {
const message =
err instanceof Error ? err.message : "Unexpected error occurred.";
setError(message);
}
};
const handleDeactivate = async (providerId: number, mode: ProviderMode) => {
const setError =
mode === "stt" ? setSTTActivationError : setTTSActivationError;
setError(null);
try {
const response = await fetch(
`${VOICE_PROVIDERS_URL}/${providerId}/deactivate-${mode}`,
{ method: "POST" }
);
if (!response.ok) {
const errorBody = await response.json().catch(() => ({}));
throw new Error(
typeof errorBody?.detail === "string"
? errorBody.detail
: `Failed to deactivate ${mode.toUpperCase()} provider.`
);
}
await mutate();
} catch (err) {
const message =
err instanceof Error ? err.message : "Unexpected error occurred.";
setError(message);
}
};
const handleModalClose = () => {
setModalOpen(false);
setSelectedProvider(null);
setEditingProvider(null);
setSelectedModelId(null);
};
const handleModalSuccess = () => {
mutate();
handleModalClose();
};
const isProviderConfigured = (provider?: VoiceProviderView): boolean => {
return !!provider?.has_api_key;
};
// Map provider types to their configured provider data
const providersByType = useMemo(() => {
return new Map((providers ?? []).map((p) => [p.provider_type, p] as const));
}, [providers]);
const hasActiveSTTProvider =
providers?.some((p) => p.is_default_stt) ?? false;
const hasActiveTTSProvider =
providers?.some((p) => p.is_default_tts) ?? false;
const renderLogo = ({
logoSrc,
alt,
size = 16,
}: {
logoSrc?: string;
alt: string;
size?: number;
}) => {
const containerSizeClass = size === 24 ? "size-7" : "size-5";
return (
<div
className={cn(
"flex items-center justify-center px-0.5 py-0 shrink-0 overflow-clip",
containerSizeClass
)}
>
{logoSrc ? (
<Image src={logoSrc} alt={alt} width={size} height={size} />
) : (
<SvgMicrophone size={size} className="text-text-02" />
)}
</div>
);
};
const renderModelCard = ({
model,
mode,
}: {
model: ModelDetails;
mode: ProviderMode;
}) => {
const provider = providersByType.get(model.providerType);
const isConfigured = isProviderConfigured(provider);
// For TTS, also check that this specific model is the default (not just the provider)
const isActive =
mode === "stt"
? provider?.is_default_stt
: provider?.is_default_tts && provider?.tts_model === model.id;
const isHighlighted = isActive ?? false;
const providerId = provider?.id;
const buttonState = (() => {
if (!provider || !isConfigured) {
return {
label: "Connect",
disabled: false,
icon: "arrow" as const,
onClick: () => handleConnect(model.providerType, mode, model.id),
};
}
if (isActive) {
return {
label: "Current Default",
disabled: false,
icon: "check" as const,
onClick: providerId
? () => handleDeactivate(providerId, mode)
: undefined,
};
}
return {
label: "Set as Default",
disabled: false,
icon: "arrow-circle" as const,
onClick: providerId
? () => handleSetDefault(providerId, mode, model.id)
: undefined,
};
})();
const buttonKey = `${mode}-${model.id}`;
const isButtonHovered = hoveredButtonKey === buttonKey;
const isCardClickable =
buttonState.icon === "arrow" &&
typeof buttonState.onClick === "function" &&
!buttonState.disabled;
const handleCardClick = () => {
if (isCardClickable) {
buttonState.onClick?.();
}
};
return (
<div
key={`${mode}-${model.id}`}
onClick={isCardClickable ? handleCardClick : undefined}
className={cn(
"flex items-start justify-between gap-3 rounded-16 border p-2 bg-background-neutral-01",
isHighlighted ? "border-action-link-05" : "border-border-01",
isCardClickable &&
"cursor-pointer hover:bg-background-tint-01 transition-colors"
)}
>
<div className="flex flex-1 items-start gap-1 p-2">
{renderLogo({
logoSrc: model.logoSrc,
alt: `${model.label} logo`,
size: 16,
})}
<div className="flex flex-col">
<Text as="p" mainUiAction text04>
{model.label}
</Text>
<Text as="p" secondaryBody text03>
{model.subtitle}
</Text>
</div>
</div>
<div className="flex items-center justify-end gap-2">
{isConfigured && (
<OpalButton
icon={SvgEdit}
tooltip="Edit"
prominence="tertiary"
size="sm"
onClick={(e) => {
e.stopPropagation();
if (provider) handleEdit(provider, mode, model.id);
}}
aria-label={`Edit ${model.label}`}
/>
)}
{buttonState.icon === "check" ? (
<HoverIconButton
isHovered={isButtonHovered}
onMouseEnter={() => setHoveredButtonKey(buttonKey)}
onMouseLeave={() => setHoveredButtonKey(null)}
action={true}
tertiary
disabled={buttonState.disabled}
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
>
{buttonState.label}
</HoverIconButton>
) : (
<Button
action={false}
tertiary
disabled={buttonState.disabled || !buttonState.onClick}
onClick={(e) => {
e.stopPropagation();
buttonState.onClick?.();
}}
rightIcon={
buttonState.icon === "arrow"
? SvgArrowExchange
: buttonState.icon === "arrow-circle"
? SvgArrowRightCircle
: undefined
}
>
{buttonState.label}
</Button>
)}
</div>
</div>
);
};
if (error) {
const message = error?.message || "Unable to load voice configuration.";
const detail =
error instanceof FetchError && typeof error.info?.detail === "string"
? error.info.detail
: undefined;
return (
<>
<AdminPageTitle
title="Voice"
icon={SvgMicrophone}
includeDivider={false}
/>
<Callout type="danger" title="Failed to load voice settings">
{message}
{detail && (
<Text as="p" className="mt-2 text-text-03" mainContentBody text03>
{detail}
</Text>
)}
</Callout>
</>
);
}
if (isLoading) {
return (
<>
<AdminPageTitle
title="Voice"
icon={SvgMicrophone}
includeDivider={false}
/>
<div className="mt-8">
<ThreeDotsLoader />
</div>
</>
);
}
return (
<>
<AdminPageTitle icon={SvgAudio} title="Voice" />
<div className="pt-4 pb-4">
<Text as="p" secondaryBody text03>
Speech to text (STT) and text to speech (TTS) capabilities.
</Text>
</div>
<Separator />
<div className="flex w-full flex-col gap-8 pb-6">
{/* Speech-to-Text Section */}
<div className="flex w-full max-w-[960px] flex-col gap-3">
<div className="flex flex-col">
<Text as="p" mainContentEmphasis text04>
Speech to Text
</Text>
<Text as="p" secondaryBody text03>
Select a model to transcribe speech to text in chats.
</Text>
</div>
{sttActivationError && (
<Callout type="danger" title="Unable to update STT provider">
{sttActivationError}
</Callout>
)}
{!hasActiveSTTProvider && (
<div
className="flex items-start rounded-16 border p-2"
style={{
backgroundColor: "var(--status-info-00)",
borderColor: "var(--status-info-02)",
}}
>
<div className="flex items-start gap-1 p-2">
<div
className="flex size-5 items-center justify-center rounded-full p-0.5"
style={{
backgroundColor: "var(--status-info-01)",
}}
>
<div style={{ color: "var(--status-text-info-05)" }}>
<InfoIcon size={16} />
</div>
</div>
<Text as="p" className="flex-1 px-0.5" mainUiBody text04>
Connect a speech to text provider to use in chat.
</Text>
</div>
</div>
)}
<div className="flex flex-col gap-2">
{STT_MODELS.map((model) => renderModelCard({ model, mode: "stt" }))}
</div>
</div>
{/* Text-to-Speech Section */}
<div className="flex w-full max-w-[960px] flex-col gap-3">
<div className="flex flex-col">
<Text as="p" mainContentEmphasis text04>
Text to Speech
</Text>
<Text as="p" secondaryBody text03>
Select a model to read chat responses aloud.
</Text>
</div>
{ttsActivationError && (
<Callout type="danger" title="Unable to update TTS provider">
{ttsActivationError}
</Callout>
)}
{!hasActiveTTSProvider && (
<div
className="flex items-start rounded-16 border p-2"
style={{
backgroundColor: "var(--status-info-00)",
borderColor: "var(--status-info-02)",
}}
>
<div className="flex items-start gap-1 p-2">
<div
className="flex size-5 items-center justify-center rounded-full p-0.5"
style={{
backgroundColor: "var(--status-info-01)",
}}
>
<div style={{ color: "var(--status-text-info-05)" }}>
<InfoIcon size={16} />
</div>
</div>
<Text as="p" className="flex-1 px-0.5" mainUiBody text04>
Connect a text to speech provider to use in chat.
</Text>
</div>
</div>
)}
<div className="flex flex-col gap-4">
{TTS_PROVIDER_GROUPS.map((group) => (
<div key={group.providerType} className="flex flex-col gap-2">
<Text as="p" secondaryBody text03 className="px-0.5">
{group.providerLabel}
</Text>
<div className="flex flex-col gap-2">
{group.models.map((model) =>
renderModelCard({ model, mode: "tts" })
)}
</div>
</div>
))}
</div>
</div>
</div>
{modalOpen && selectedProvider && (
<VoiceProviderSetupModal
providerType={selectedProvider}
existingProvider={editingProvider}
mode={modalMode}
defaultModelId={selectedModelId}
onClose={handleModalClose}
onSuccess={handleModalSuccess}
/>
)}
</>
);
}

View File

@@ -1,133 +0,0 @@
import { SvgArrowUpRight, SvgUserSync } from "@opal/icons";
import { ContentAction } from "@opal/layouts";
import { Button } from "@opal/components";
import { Section } from "@/layouts/general-layouts";
import Card from "@/refresh-components/cards/Card";
import Separator from "@/refresh-components/Separator";
import Text from "@/refresh-components/texts/Text";
import Link from "next/link";
import { ADMIN_PATHS } from "@/lib/admin-routes";
// ---------------------------------------------------------------------------
// Stats cell
// ---------------------------------------------------------------------------
interface StatCellProps {
value: number | null;
label: string;
}
function StatCell({ value, label }: StatCellProps) {
const display = value === null ? "—" : value.toLocaleString();
return (
<Section alignItems="start" gap={0.25} width="fit" padding={0.5}>
<Text as="p" headingH3 text05>
{display}
</Text>
<Text as="p" mainUiMuted text03>
{label}
</Text>
</Section>
);
}
// ---------------------------------------------------------------------------
// SCIM card
// ---------------------------------------------------------------------------
function ScimCard() {
return (
<Card gap={0.5} padding={0.75}>
<ContentAction
icon={SvgUserSync}
title="SCIM Sync"
description="Users are synced from your identity provider."
sizePreset="main-ui"
variant="section"
paddingVariant="fit"
rightChildren={
<Link href={ADMIN_PATHS.SCIM}>
<Button prominence="tertiary" rightIcon={SvgArrowUpRight} size="sm">
Manage
</Button>
</Link>
}
/>
</Card>
);
}
// ---------------------------------------------------------------------------
// Stats bar — layout varies by SCIM status
// ---------------------------------------------------------------------------
interface StatsBarProps {
activeUsers: number | null;
pendingInvites: number | null;
requests: number | null;
showScim: boolean;
}
export default function StatsBar({
activeUsers,
pendingInvites,
requests,
showScim,
}: StatsBarProps) {
if (showScim) {
// With SCIM: one card containing all 3 stats (dividers) + separate SCIM card
return (
<Section
flexDirection="row"
justifyContent="start"
alignItems="stretch"
gap={0.5}
>
<Card padding={0}>
<Section
flexDirection="row"
alignItems="stretch"
gap={0}
width="fit"
height="auto"
>
<StatCell value={activeUsers} label="active users" />
<Separator orientation="vertical" noPadding />
<StatCell value={pendingInvites} label="pending invites" />
{requests !== null && (
<>
<Separator orientation="vertical" noPadding />
<StatCell value={requests} label="requests to join" />
</>
)}
</Section>
</Card>
<ScimCard />
</Section>
);
}
// Without SCIM: 3 separate cards
return (
<Section
flexDirection="row"
justifyContent="start"
alignItems="stretch"
gap={0.5}
>
<Card padding={0.5}>
<StatCell value={activeUsers} label="active users" />
</Card>
<Card padding={0.5}>
<StatCell value={pendingInvites} label="pending invites" />
</Card>
{requests !== null && (
<Card padding={0.5}>
<StatCell value={requests} label="requests to join" />
</Card>
)}
</Section>
);
}

View File

@@ -1,94 +0,0 @@
"use client";
import { useState } from "react";
import { SvgUser, SvgUserPlus } from "@opal/icons";
import { Button } from "@opal/components";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { useScimToken } from "@/hooks/useScimToken";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import type { InvitedUserSnapshot } from "@/lib/types";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import StatsBar from "./StatsBar";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
interface PaginatedResponse {
items: unknown[];
total_items: number;
}
// ---------------------------------------------------------------------------
// Users page content
// ---------------------------------------------------------------------------
function UsersContent() {
const isEe = usePaidEnterpriseFeaturesEnabled();
const { data: scimToken } = useScimToken();
const showScim = isEe && !!scimToken;
// Active user count — lightweight fetch (page_size=1 to minimize payload)
const { data: activeData } = useSWR<PaginatedResponse>(
"/api/manage/users/accepted?page_num=0&page_size=1",
errorHandlingFetcher
);
const { data: invitedUsers } = useSWR<InvitedUserSnapshot[]>(
"/api/manage/users/invited",
errorHandlingFetcher
);
const { data: pendingUsers } = useSWR<InvitedUserSnapshot[]>(
NEXT_PUBLIC_CLOUD_ENABLED ? "/api/tenants/users/pending" : null,
errorHandlingFetcher
);
const activeCount = activeData?.total_items ?? null;
const invitedCount = invitedUsers?.length ?? null;
const pendingCount = pendingUsers?.length ?? null;
return (
<>
<StatsBar
activeUsers={activeCount}
pendingInvites={invitedCount}
requests={pendingCount}
showScim={showScim}
/>
{/* Table and filters will be added in subsequent PRs */}
</>
);
}
// ---------------------------------------------------------------------------
// Page
// ---------------------------------------------------------------------------
export default function Page() {
// TODO (ENG-3806): Wire up invite modal in a future PR
const [_showInviteModal, setShowInviteModal] = useState(false);
return (
<SettingsLayouts.Root width="lg">
<SettingsLayouts.Header
title="Users & Requests"
icon={SvgUser}
rightChildren={
<Button icon={SvgUserPlus} onClick={() => setShowInviteModal(true)}>
Invite Users
</Button>
}
/>
<SettingsLayouts.Body>
<UsersContent />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -3,6 +3,7 @@ import type { Route } from "next";
import { unstable_noStore as noStore } from "next/cache";
import { requireAuth } from "@/lib/auth/requireAuth";
import { ProjectsProvider } from "@/providers/ProjectsContext";
import { VoiceModeProvider } from "@/providers/VoiceModeProvider";
import AppSidebar from "@/sections/sidebar/AppSidebar";
export interface LayoutProps {
@@ -21,10 +22,12 @@ export default async function Layout({ children }: LayoutProps) {
return (
<ProjectsProvider>
<div className="flex flex-row w-full h-full">
<AppSidebar />
{children}
</div>
<VoiceModeProvider>
<div className="flex flex-row w-full h-full">
<AppSidebar />
{children}
</div>
</VoiceModeProvider>
</ProjectsProvider>
);
}

View File

@@ -1,6 +1,6 @@
"use client";
import React, { useRef, RefObject, useMemo } from "react";
import React, { useRef, RefObject, useMemo, useEffect } from "react";
import { Packet, StopReason } from "@/app/app/services/streamingModels";
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
import { FeedbackType } from "@/app/app/interfaces";
@@ -14,6 +14,9 @@ import { LlmDescriptor, LlmManager } from "@/lib/hooks";
import { Message } from "@/app/app/interfaces";
import Text from "@/refresh-components/texts/Text";
import { AgentTimeline } from "@/app/app/message/messageComponents/timeline/AgentTimeline";
import { useVoiceMode } from "@/providers/VoiceModeProvider";
import { getTextContent } from "@/app/app/services/packetUtils";
import { removeThinkingTokens } from "@/app/app/services/thinkingTokens";
// Type for the regeneration factory function passed from ChatUI
export type RegenerationFactory = (regenerationRequest: {
@@ -73,6 +76,7 @@ function arePropsEqual(
const AgentMessage = React.memo(function AgentMessage({
rawPackets,
packetCount,
chatState,
nodeId,
messageId,
@@ -158,6 +162,46 @@ const AgentMessage = React.memo(function AgentMessage({
onMessageSelection,
});
// Streaming TTS integration
const { streamTTS, resetTTS } = useVoiceMode();
const ttsCompletedRef = useRef(false);
const streamTTSRef = useRef(streamTTS);
// Keep streamTTS ref in sync without triggering effect re-runs
useEffect(() => {
streamTTSRef.current = streamTTS;
}, [streamTTS]);
// Stream TTS as text content arrives - only for messages still streaming
// Uses ref for streamTTS to avoid re-triggering when its identity changes
// Note: packetCount is used instead of rawPackets because the array is mutated in place
useEffect(() => {
// Skip if we've already finished TTS for this message
if (ttsCompletedRef.current) return;
const textContent = removeThinkingTokens(getTextContent(rawPackets));
if (typeof textContent === "string" && textContent.length > 0) {
streamTTSRef.current(textContent, isComplete);
// Mark as completed once the message is done streaming
if (isComplete) {
ttsCompletedRef.current = true;
}
}
}, [packetCount, isComplete, rawPackets]); // packetCount triggers on new packets since rawPackets is mutated in place
// Reset TTS completed flag when nodeId changes (new message)
useEffect(() => {
ttsCompletedRef.current = false;
}, [nodeId]);
// Reset TTS when component unmounts or nodeId changes
useEffect(() => {
return () => {
resetTTS();
};
}, [nodeId, resetTTS]);
return (
<div
className="flex flex-col gap-3"

View File

@@ -29,6 +29,7 @@ import FeedbackModal, {
FeedbackModalProps,
} from "@/sections/modals/FeedbackModal";
import { Button } from "@opal/components";
import TTSButton from "./TTSButton";
// Wrapper component for SourceTag in toolbar to handle memoization
const SourcesTagWrapper = React.memo(function SourcesTagWrapper({
@@ -268,6 +269,9 @@ export default function MessageToolbar({
}
data-testid="AgentMessage/dislike-button"
/>
<TTSButton
text={removeThinkingTokens(getTextContent(rawPackets)) as string}
/>
{onRegenerate &&
messageId !== undefined &&

View File

@@ -0,0 +1,84 @@
"use client";
import { useCallback, useEffect } from "react";
import { SvgPlayCircle, SvgStop } from "@opal/icons";
import { Button } from "@opal/components";
import { useVoicePlayback } from "@/hooks/useVoicePlayback";
import { useVoiceMode } from "@/providers/VoiceModeProvider";
import { toast } from "@/hooks/useToast";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
interface TTSButtonProps {
text: string;
voice?: string;
speed?: number;
}
function TTSButton({ text, voice, speed }: TTSButtonProps) {
const { isPlaying, isLoading, error, play, pause, stop } = useVoicePlayback();
const { isTTSPlaying, isTTSLoading, stopTTS } = useVoiceMode();
const isGlobalTTSActive = isTTSPlaying || isTTSLoading;
const isButtonPlaying = isGlobalTTSActive || isPlaying;
const isButtonLoading = !isGlobalTTSActive && isLoading;
const handleClick = useCallback(async () => {
if (isGlobalTTSActive) {
// Stop auto-playback voice mode stream from the toolbar button.
stopTTS({ manual: true });
stop();
} else if (isPlaying) {
pause();
} else if (isButtonLoading) {
stop();
} else {
try {
// Ensure no voice-mode stream is active before starting manual playback.
stopTTS();
await play(text, voice, speed);
} catch {
toast.error("Could not play audio");
}
}
}, [
isGlobalTTSActive,
isPlaying,
isButtonLoading,
text,
voice,
speed,
play,
pause,
stop,
stopTTS,
]);
useEffect(() => {
if (error) {
toast.error(error);
}
}, [error]);
const icon = isButtonLoading
? SimpleLoader
: isButtonPlaying
? SvgStop
: SvgPlayCircle;
const tooltip = isButtonPlaying
? "Stop playback"
: isButtonLoading
? "Loading..."
: "Read aloud";
return (
<Button
icon={icon}
onClick={handleClick}
tooltip={tooltip}
data-testid="AgentMessage/tts-button"
/>
);
}
export default TTSButton;

View File

@@ -566,6 +566,21 @@ textarea {
animation: fadeIn 0.2s ease-out forwards;
}
/* Recording waveform animation */
@keyframes waveform {
0%,
100% {
transform: scaleY(0.3);
}
50% {
transform: scaleY(1);
}
}
.animate-waveform {
animation: waveform 0.8s ease-in-out infinite;
}
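/* Usage sketch (markup illustrative, not part of this diff): stagger several
   bars to form the wave, e.g.
   <span class="animate-waveform" style="animation-delay: 0ms"></span>
   <span class="animate-waveform" style="animation-delay: 150ms"></span>
   <span class="animate-waveform" style="animation-delay: 300ms"></span> */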
.container {
margin-bottom: 1rem;
}

View File

@@ -31,7 +31,6 @@ const SETTINGS_LAYOUT_PREFIXES = [
ADMIN_PATHS.LLM_MODELS,
ADMIN_PATHS.AGENTS,
ADMIN_PATHS.USERS,
ADMIN_PATHS.USERS_V2,
ADMIN_PATHS.TOKEN_RATE_LIMITS,
ADMIN_PATHS.SEARCH_SETTINGS,
ADMIN_PATHS.DOCUMENT_PROCESSING,

View File

@@ -0,0 +1,122 @@
import { useState, useRef, useCallback } from "react";
export interface UseVoicePlaybackReturn {
isPlaying: boolean;
isLoading: boolean;
error: string | null;
play: (text: string, voice?: string, speed?: number) => Promise<void>;
pause: () => void;
stop: () => void;
}
export function useVoicePlayback(): UseVoicePlaybackReturn {
const [isPlaying, setIsPlaying] = useState(false);
const [isLoading, setIsLoading] = useState(false);
const [error, setError] = useState<string | null>(null);
const audioRef = useRef<HTMLAudioElement | null>(null);
const audioUrlRef = useRef<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
const stop = useCallback(() => {
// Revoke object URL to prevent memory leak
if (audioUrlRef.current) {
URL.revokeObjectURL(audioUrlRef.current);
audioUrlRef.current = null;
}
if (audioRef.current) {
audioRef.current.pause();
audioRef.current.src = "";
audioRef.current = null;
}
if (abortControllerRef.current) {
abortControllerRef.current.abort();
abortControllerRef.current = null;
}
setIsPlaying(false);
setIsLoading(false);
}, []);
const pause = useCallback(() => {
if (audioRef.current && isPlaying) {
audioRef.current.pause();
setIsPlaying(false);
}
}, [isPlaying]);
const play = useCallback(
async (text: string, voice?: string, speed?: number) => {
// Stop any existing playback
stop();
setError(null);
setIsLoading(true);
try {
abortControllerRef.current = new AbortController();
const params = new URLSearchParams();
params.set("text", text);
if (voice) params.set("voice", voice);
if (speed !== undefined) params.set("speed", speed.toString());
const response = await fetch(`/api/voice/synthesize?${params}`, {
method: "POST",
signal: abortControllerRef.current.signal,
credentials: "include",
});
if (!response.ok) {
const errorData = await response.json();
throw new Error(errorData.detail || "Speech synthesis failed");
}
const audioBlob = await response.blob();
const audioUrl = URL.createObjectURL(audioBlob);
audioUrlRef.current = audioUrl;
const audio = new Audio(audioUrl);
audioRef.current = audio;
audio.onended = () => {
setIsPlaying(false);
if (audioUrlRef.current) {
URL.revokeObjectURL(audioUrlRef.current);
audioUrlRef.current = null;
}
};
audio.onerror = () => {
setError("Audio playback failed");
setIsPlaying(false);
if (audioUrlRef.current) {
URL.revokeObjectURL(audioUrlRef.current);
audioUrlRef.current = null;
}
};
setIsLoading(false);
await audio.play();
// Only mark as playing once playback has actually started; if play() rejects
// (e.g. autoplay blocked), the catch below reports the error instead.
setIsPlaying(true);
} catch (err) {
if (err instanceof Error && err.name === "AbortError") {
// Request was cancelled, not an error
return;
}
const message =
err instanceof Error ? err.message : "Speech synthesis failed";
setError(message);
setIsLoading(false);
}
},
[stop]
);
return {
isPlaying,
isLoading,
error,
play,
pause,
stop,
};
}
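// Illustrative usage sketch (component and props are hypothetical, not part of
// this diff): wiring the hook to a minimal play/stop control.
export function ExampleReadAloudButton({ text }: { text: string }) {
  const { isPlaying, isLoading, play, stop } = useVoicePlayback();
  return (
    <button
      disabled={isLoading}
      onClick={() => (isPlaying ? stop() : void play(text))}
    >
      {isPlaying ? "Stop" : isLoading ? "Loading..." : "Read aloud"}
    </button>
  );
}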

View File

@@ -0,0 +1,462 @@
import { useState, useRef, useCallback, useEffect } from "react";
// Target format for OpenAI Realtime API
const TARGET_SAMPLE_RATE = 24000;
const CHUNK_INTERVAL_MS = 250;
const SILENCE_TIMEOUT_MS = 5000; // Stop recording if no speech detected for 5 seconds
interface TranscriptMessage {
type: "transcript" | "error";
text?: string;
message?: string;
is_final?: boolean;
}
export interface UseVoiceRecorderOptions {
/** Called when VAD detects silence and final transcript is received */
onFinalTranscript?: (text: string) => void;
/** If true, automatically stop recording when VAD detects silence */
autoStopOnSilence?: boolean;
}
export interface UseVoiceRecorderReturn {
isRecording: boolean;
isProcessing: boolean;
isMuted: boolean;
error: string | null;
liveTranscript: string;
startRecording: () => Promise<void>;
stopRecording: () => Promise<string | null>;
setMuted: (muted: boolean) => void;
}
/**
* Encapsulates all browser resources for a voice recording session.
* Manages WebSocket, Web Audio API, and audio buffering.
*/
class VoiceRecorderSession {
// Browser resources
private websocket: WebSocket | null = null;
private audioContext: AudioContext | null = null;
private scriptNode: ScriptProcessorNode | null = null;
private sourceNode: MediaStreamAudioSourceNode | null = null;
private mediaStream: MediaStream | null = null;
private sendInterval: NodeJS.Timeout | null = null;
// State
private audioBuffer: Float32Array[] = [];
private transcript = "";
private stopResolver: ((text: string | null) => void) | null = null;
private isActive = false;
private silenceTimeout: NodeJS.Timeout | null = null;
// Callbacks to update React state
private onTranscriptChange: (text: string) => void;
private onFinalTranscript: ((text: string) => void) | null;
private onError: (error: string) => void;
private onSilenceTimeout: (() => void) | null;
private onVADStop: (() => void) | null;
private autoStopOnSilence: boolean;
constructor(
onTranscriptChange: (text: string) => void,
onFinalTranscript: ((text: string) => void) | null,
onError: (error: string) => void,
onSilenceTimeout?: () => void,
autoStopOnSilence?: boolean,
onVADStop?: () => void
) {
this.onTranscriptChange = onTranscriptChange;
this.onFinalTranscript = onFinalTranscript;
this.onError = onError;
this.onSilenceTimeout = onSilenceTimeout || null;
this.autoStopOnSilence = autoStopOnSilence ?? false;
this.onVADStop = onVADStop || null;
}
get recording(): boolean {
return this.isActive;
}
get currentTranscript(): string {
return this.transcript;
}
setMuted(muted: boolean): void {
if (this.mediaStream) {
this.mediaStream.getAudioTracks().forEach((track) => {
track.enabled = !muted;
});
}
}
async start(): Promise<void> {
if (this.isActive) return;
this.cleanup();
this.transcript = "";
this.audioBuffer = [];
// Get microphone
this.mediaStream = await navigator.mediaDevices.getUserMedia({
audio: {
channelCount: 1,
sampleRate: { ideal: TARGET_SAMPLE_RATE },
echoCancellation: true,
noiseSuppression: true,
},
});
// Get WS token and connect WebSocket
const wsUrl = await this.getWebSocketUrl();
this.websocket = new WebSocket(wsUrl);
this.websocket.onmessage = this.handleMessage;
this.websocket.onerror = () => this.onError("Connection failed");
this.websocket.onclose = () => {
if (this.stopResolver) {
this.stopResolver(this.transcript || null);
this.stopResolver = null;
}
};
await this.waitForConnection();
// Set up audio capture
this.audioContext = new AudioContext({ sampleRate: TARGET_SAMPLE_RATE });
this.sourceNode = this.audioContext.createMediaStreamSource(
this.mediaStream
);
this.scriptNode = this.audioContext.createScriptProcessor(4096, 1, 1);
this.scriptNode.onaudioprocess = (event) => {
const inputData = event.inputBuffer.getChannelData(0);
this.audioBuffer.push(new Float32Array(inputData));
};
this.sourceNode.connect(this.scriptNode);
this.scriptNode.connect(this.audioContext.destination);
// Start sending audio chunks
this.sendInterval = setInterval(
() => this.sendAudioBuffer(),
CHUNK_INTERVAL_MS
);
this.isActive = true;
}
async stop(): Promise<string | null> {
if (!this.isActive) return this.transcript || null;
    // Flush any buffered audio first so the tail of the utterance reaches the
    // server, then stop audio capture
    if (this.sendInterval) {
      clearInterval(this.sendInterval);
      this.sendInterval = null;
    }
    this.sendAudioBuffer();
if (this.scriptNode) {
this.scriptNode.disconnect();
this.scriptNode = null;
}
if (this.sourceNode) {
this.sourceNode.disconnect();
this.sourceNode = null;
}
if (this.audioContext) {
this.audioContext.close();
this.audioContext = null;
}
if (this.mediaStream) {
this.mediaStream.getTracks().forEach((track) => track.stop());
this.mediaStream = null;
}
this.audioBuffer = [];
this.isActive = false;
// Get final transcript from server
if (this.websocket?.readyState === WebSocket.OPEN) {
return new Promise((resolve) => {
this.stopResolver = resolve;
this.websocket!.send(JSON.stringify({ type: "end" }));
// Timeout fallback
setTimeout(() => {
if (this.stopResolver) {
this.stopResolver(this.transcript || null);
this.stopResolver = null;
}
}, 3000);
});
}
return this.transcript || null;
}
cleanup(): void {
if (this.sendInterval) clearInterval(this.sendInterval);
if (this.scriptNode) this.scriptNode.disconnect();
if (this.sourceNode) this.sourceNode.disconnect();
if (this.audioContext) this.audioContext.close();
if (this.mediaStream) this.mediaStream.getTracks().forEach((t) => t.stop());
if (this.websocket) this.websocket.close();
this.sendInterval = null;
this.scriptNode = null;
this.sourceNode = null;
this.audioContext = null;
this.mediaStream = null;
this.websocket = null;
this.isActive = false;
}
private async getWebSocketUrl(): Promise<string> {
// Fetch short-lived WS token
const tokenResponse = await fetch("/api/voice/ws-token", {
method: "POST",
credentials: "include",
});
if (!tokenResponse.ok) {
throw new Error("Failed to get WebSocket authentication token");
}
const { token } = await tokenResponse.json();
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const isDev = window.location.port === "3000";
const host = isDev ? "localhost:8080" : window.location.host;
const path = isDev
? "/voice/transcribe/stream"
: "/api/voice/transcribe/stream";
return `${protocol}//${host}${path}?token=${encodeURIComponent(token)}`;
}
private waitForConnection(): Promise<void> {
return new Promise((resolve, reject) => {
if (!this.websocket) return reject(new Error("No WebSocket"));
const timeout = setTimeout(
() => reject(new Error("Connection timeout")),
5000
);
    this.websocket.onopen = () => {
      clearTimeout(timeout);
      // Restore the runtime error handler that this method temporarily
      // replaced, so post-connection failures still surface via onError
      this.websocket!.onerror = () => this.onError("Connection failed");
      resolve();
    };
this.websocket.onerror = () => {
clearTimeout(timeout);
reject(new Error("Connection failed"));
};
});
}
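  // Wire protocol (shapes assumed from this class's usage): the client streams
  // binary PCM16 chunks plus JSON control frames { type: "end" } and
  // { type: "reset" }; the server replies with JSON frames such as
  //   { "type": "transcript", "text": "...", "is_final": boolean }
  //   { "type": "error", "message": "..." }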
private handleMessage = (event: MessageEvent): void => {
try {
const data: TranscriptMessage = JSON.parse(event.data);
if (data.type === "transcript") {
if (data.text) {
this.transcript = data.text;
this.onTranscriptChange(data.text);
}
if (data.is_final && data.text) {
// VAD detected silence - trigger callback
if (this.onFinalTranscript) {
this.onFinalTranscript(data.text);
}
// Auto-stop recording if enabled
if (this.autoStopOnSilence) {
// Trigger stop callback to update React state
if (this.onVADStop) {
this.onVADStop();
}
} else {
// If not auto-stopping, reset for next utterance
this.transcript = "";
this.onTranscriptChange("");
this.resetBackendTranscript();
}
// Resolve stop promise if waiting
if (this.stopResolver) {
this.stopResolver(data.text);
this.stopResolver = null;
}
}
} else if (data.type === "error") {
this.onError(data.message || "Transcription error");
}
} catch (e) {
console.error("Failed to parse transcript message:", e);
}
};
private resetBackendTranscript(): void {
if (this.websocket?.readyState === WebSocket.OPEN) {
this.websocket.send(JSON.stringify({ type: "reset" }));
}
}
private sendAudioBuffer(): void {
if (
!this.websocket ||
this.websocket.readyState !== WebSocket.OPEN ||
!this.audioContext ||
this.audioBuffer.length === 0
) {
return;
}
// Concatenate buffered chunks
const totalLength = this.audioBuffer.reduce(
(sum, chunk) => sum + chunk.length,
0
);
    // Guard against runaway buffering: if more than ~1 second of samples has
    // piled up (sampleRate * 0.5 * 2), keep only the most recent chunks and
    // skip this send
    if (totalLength > this.audioContext.sampleRate * 0.5 * 2) {
      this.audioBuffer = this.audioBuffer.slice(-10);
      return;
    }
const concatenated = new Float32Array(totalLength);
let offset = 0;
for (const chunk of this.audioBuffer) {
concatenated.set(chunk, offset);
offset += chunk.length;
}
this.audioBuffer = [];
// Resample and convert to PCM16
const resampled = this.resampleAudio(
concatenated,
this.audioContext.sampleRate
);
const pcm16 = this.float32ToInt16(resampled);
this.websocket.send(pcm16.buffer);
}
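  // Linear-interpolation resampler. Illustrative (rates are hypothetical): a
  // 48 kHz capture with a 16 kHz target gives ratio = 3, so each output sample
  // blends the two input samples nearest srcIndex = i * 3.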
private resampleAudio(input: Float32Array, inputRate: number): Float32Array {
if (inputRate === TARGET_SAMPLE_RATE) return input;
const ratio = inputRate / TARGET_SAMPLE_RATE;
const outputLength = Math.round(input.length / ratio);
const output = new Float32Array(outputLength);
for (let i = 0; i < outputLength; i++) {
const srcIndex = i * ratio;
const floor = Math.floor(srcIndex);
const ceil = Math.min(floor + 1, input.length - 1);
const fraction = srcIndex - floor;
output[i] = input[floor]! * (1 - fraction) + input[ceil]! * fraction;
}
return output;
}
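  // Convert [-1, 1] float samples to signed 16-bit PCM. Negative values scale
  // by 0x8000 and positive by 0x7fff so both ends map onto the full int16
  // range without overflow.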
private float32ToInt16(float32: Float32Array): Int16Array {
const int16 = new Int16Array(float32.length);
for (let i = 0; i < float32.length; i++) {
const s = Math.max(-1, Math.min(1, float32[i]!));
int16[i] = s < 0 ? s * 0x8000 : s * 0x7fff;
}
return int16;
}
}
/**
* Hook for voice recording with streaming transcription.
*/
export function useVoiceRecorder(
options?: UseVoiceRecorderOptions
): UseVoiceRecorderReturn {
const [isRecording, setIsRecording] = useState(false);
const [isProcessing, setIsProcessing] = useState(false);
const [isMuted, setIsMutedState] = useState(false);
const [error, setError] = useState<string | null>(null);
const [liveTranscript, setLiveTranscript] = useState("");
const sessionRef = useRef<VoiceRecorderSession | null>(null);
const onFinalTranscriptRef = useRef(options?.onFinalTranscript);
const autoStopOnSilenceRef = useRef(options?.autoStopOnSilence ?? true); // Default to true
// Keep callback ref in sync
useEffect(() => {
onFinalTranscriptRef.current = options?.onFinalTranscript;
autoStopOnSilenceRef.current = options?.autoStopOnSilence ?? true;
}, [options?.onFinalTranscript, options?.autoStopOnSilence]);
// Cleanup on unmount
useEffect(() => {
return () => {
sessionRef.current?.cleanup();
};
}, []);
const startRecording = useCallback(async () => {
if (sessionRef.current?.recording) return;
setError(null);
setLiveTranscript("");
// Create VAD stop handler that will stop the session
const handleVADStop = () => {
if (sessionRef.current) {
sessionRef.current.stop().then(() => {
setIsRecording(false);
});
}
};
sessionRef.current = new VoiceRecorderSession(
setLiveTranscript,
(text) => onFinalTranscriptRef.current?.(text),
setError,
undefined, // onSilenceTimeout
autoStopOnSilenceRef.current,
handleVADStop
);
try {
await sessionRef.current.start();
setIsRecording(true);
} catch (err) {
setError(
err instanceof Error ? err.message : "Failed to start recording"
);
throw err;
}
}, []);
const stopRecording = useCallback(async (): Promise<string | null> => {
if (!sessionRef.current) return null;
setIsProcessing(true);
try {
const transcript = await sessionRef.current.stop();
return transcript;
} finally {
setIsRecording(false);
setIsProcessing(false);
setIsMutedState(false); // Reset mute state when recording stops
}
}, []);
const setMuted = useCallback((muted: boolean) => {
setIsMutedState(muted);
sessionRef.current?.setMuted(muted);
}, []);
return {
isRecording,
isProcessing,
isMuted,
error,
liveTranscript,
startRecording,
stopRecording,
setMuted,
};
}
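For orientation, a minimal sketch of driving this hook from a hypothetical push-to-talk button (the component and its onText prop are illustrative, not part of this change):

function PushToTalkButton({ onText }: { onText: (text: string) => void }) {
  const { isRecording, liveTranscript, startRecording, stopRecording } =
    useVoiceRecorder({ onFinalTranscript: onText, autoStopOnSilence: true });
  return (
    <button
      onClick={async () => {
        if (isRecording) {
          const text = await stopRecording();
          if (text) onText(text);
        } else {
          await startRecording();
        }
      }}
    >
      {isRecording ? liveTranscript || "Listening..." : "Record"}
    </button>
  );
}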

View File

@@ -0,0 +1,149 @@
import { useState, useRef, useCallback, useEffect } from "react";
export type WebSocketStatus =
| "connecting"
| "connected"
| "disconnected"
| "error";
export interface UseWebSocketOptions<T> {
/** URL to connect to */
url: string;
/** Called when a message is received */
onMessage?: (data: T) => void;
/** Called when connection opens */
onOpen?: () => void;
/** Called when connection closes */
onClose?: () => void;
/** Called on error */
onError?: (error: Event) => void;
/** Auto-connect on mount */
autoConnect?: boolean;
}
export interface UseWebSocketReturn<T> {
/** Current connection status */
status: WebSocketStatus;
/** Send JSON data */
sendJson: (data: T) => void;
/** Send binary data */
sendBinary: (data: Blob | ArrayBuffer) => void;
/** Connect to WebSocket */
connect: () => Promise<void>;
/** Disconnect from WebSocket */
disconnect: () => void;
}
export function useWebSocket<TReceive = unknown, TSend = unknown>({
url,
onMessage,
onOpen,
onClose,
onError,
autoConnect = false,
}: UseWebSocketOptions<TReceive>): UseWebSocketReturn<TSend> {
const [status, setStatus] = useState<WebSocketStatus>("disconnected");
const wsRef = useRef<WebSocket | null>(null);
const onMessageRef = useRef(onMessage);
const onOpenRef = useRef(onOpen);
const onCloseRef = useRef(onClose);
const onErrorRef = useRef(onError);
// Keep refs updated
useEffect(() => {
onMessageRef.current = onMessage;
onOpenRef.current = onOpen;
onCloseRef.current = onClose;
onErrorRef.current = onError;
}, [onMessage, onOpen, onClose, onError]);
const connect = useCallback(async (): Promise<void> => {
if (
wsRef.current?.readyState === WebSocket.OPEN ||
wsRef.current?.readyState === WebSocket.CONNECTING
) {
return;
}
setStatus("connecting");
return new Promise((resolve, reject) => {
const ws = new WebSocket(url);
wsRef.current = ws;
const timeout = setTimeout(() => {
ws.close();
reject(new Error("WebSocket connection timeout"));
}, 10000);
ws.onopen = () => {
clearTimeout(timeout);
setStatus("connected");
onOpenRef.current?.();
resolve();
};
ws.onmessage = (event) => {
try {
const data = JSON.parse(event.data) as TReceive;
onMessageRef.current?.(data);
} catch {
// Non-JSON message, ignore or handle differently
}
};
ws.onclose = () => {
setStatus("disconnected");
onCloseRef.current?.();
wsRef.current = null;
};
ws.onerror = (error) => {
clearTimeout(timeout);
setStatus("error");
onErrorRef.current?.(error);
reject(new Error("WebSocket connection failed"));
};
});
}, [url]);
const disconnect = useCallback(() => {
if (wsRef.current) {
wsRef.current.close();
wsRef.current = null;
}
setStatus("disconnected");
}, []);
const sendJson = useCallback((data: TSend) => {
if (wsRef.current?.readyState === WebSocket.OPEN) {
wsRef.current.send(JSON.stringify(data));
}
}, []);
const sendBinary = useCallback((data: Blob | ArrayBuffer) => {
if (wsRef.current?.readyState === WebSocket.OPEN) {
wsRef.current.send(data);
}
}, []);
// Auto-connect if enabled
useEffect(() => {
if (autoConnect) {
connect().catch(() => {
// Error handled via onError callback
});
}
return () => {
disconnect();
};
}, [autoConnect, connect, disconnect]);
return {
status,
sendJson,
sendBinary,
connect,
disconnect,
};
}
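A minimal usage sketch, inside a React component (the URL and message shapes are illustrative):

type ServerMsg = { type: string };
type ClientMsg = { type: string };

const { status, sendJson } = useWebSocket<ServerMsg, ClientMsg>({
  url: "wss://example.invalid/stream",
  onMessage: (msg) => console.log("received", msg.type),
  autoConnect: true,
});
// Once status === "connected":
sendJson({ type: "ping" });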

View File

@@ -86,7 +86,7 @@ function SettingsRoot({ width = "md", ...props }: SettingsRootProps) {
return (
<div
id="page-wrapper-scroll-container"
className="w-full h-full flex flex-col items-center overflow-y-auto pt-10"
className="w-full h-full flex flex-col items-center overflow-y-auto"
>
{/* WARNING: The id="page-wrapper-scroll-container" above is used by SettingsHeader
to detect scroll position and show/hide the scroll shadow.

View File

@@ -58,7 +58,6 @@ export const ADMIN_PATHS = {
DOCUMENT_PROCESSING: "/admin/configuration/document-processing",
KNOWLEDGE_GRAPH: "/admin/kg",
USERS: "/admin/users",
USERS_V2: "/admin/users2",
API_KEYS: "/admin/api-key",
TOKEN_RATE_LIMITS: "/admin/token-rate-limits",
USAGE: "/admin/performance/usage",
@@ -191,11 +190,6 @@ export const ADMIN_ROUTE_CONFIG: Record<string, AdminRouteConfig> = {
title: "Manage Users",
sidebarLabel: "Users",
},
[ADMIN_PATHS.USERS_V2]: {
icon: SvgUser,
title: "Users & Requests",
sidebarLabel: "Users v2",
},
[ADMIN_PATHS.API_KEYS]: {
icon: SvgKey,
title: "API Keys",

View File

@@ -0,0 +1,37 @@
/**
* Sentence detection for streaming TTS using 'sbd' library.
*/
import sbd from "sbd";
/**
* Split text into sentences. Returns complete sentences and remaining buffer.
*/
export function detectSentences(
text: string,
isComplete: boolean = false
): { sentences: string[]; buffer: string } {
const sentences = sbd.sentences(text, {
newline_boundaries: true,
html_boundaries: false,
sanitize: false,
});
if (sentences.length === 0) {
return { sentences: [], buffer: text };
}
// Check if text ends with sentence-ending punctuation
const endsWithPunctuation = /[.!?]["']?\s*$/.test(text.trim());
if (isComplete || endsWithPunctuation) {
// All sentences are complete
return { sentences: sentences.map((s) => s.trim()), buffer: "" };
}
// Last sentence might be incomplete - keep it in buffer
const complete = sentences.slice(0, -1).map((s) => s.trim());
const buffer = sentences[sentences.length - 1] ?? "";
return { sentences: complete, buffer };
}
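A sketch of how this pairs with a token stream (onToken and speak are illustrative stand-ins for the streaming handler and the TTS call):

declare function speak(sentence: string): void; // stand-in for the TTS call

let pending = "";
function onToken(token: string, isLast = false): void {
  pending += token;
  const { sentences, buffer } = detectSentences(pending, isLast);
  for (const sentence of sentences) speak(sentence); // flush complete sentences
  pending = buffer; // keep the possibly incomplete tail
}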

web/src/lib/streamingTTS.ts (new file, 588 lines added)
View File

@@ -0,0 +1,588 @@
/**
* Real-time streaming TTS using HTTP streaming with MediaSource Extensions.
* Plays audio chunks as they arrive for smooth, low-latency playback.
*/
/**
* HTTPStreamingTTSPlayer - Uses HTTP streaming with MediaSource Extensions
* for smooth, gapless audio playback. This is the recommended approach for
* real-time TTS as it properly handles MP3 frame boundaries.
*/
export class HTTPStreamingTTSPlayer {
private mediaSource: MediaSource | null = null;
private sourceBuffer: SourceBuffer | null = null;
private audioElement: HTMLAudioElement | null = null;
private pendingChunks: Uint8Array[] = [];
private isAppending: boolean = false;
private isPlaying: boolean = false;
private streamComplete: boolean = false;
private onPlayingChange?: (playing: boolean) => void;
private onError?: (error: string) => void;
private abortController: AbortController | null = null;
constructor(options?: {
onPlayingChange?: (playing: boolean) => void;
onError?: (error: string) => void;
}) {
this.onPlayingChange = options?.onPlayingChange;
this.onError = options?.onError;
}
private getAPIUrl(): string {
// Always go through the frontend proxy to ensure cookies are sent correctly
// The Next.js proxy at /api/* forwards to the backend
return "/api/voice/synthesize";
}
/**
* Speak text using HTTP streaming with real-time playback.
* Audio begins playing as soon as the first chunks arrive.
*/
async speak(
text: string,
voice?: string,
speed: number = 1.0
): Promise<void> {
// Cleanup any previous playback
this.cleanup();
// Create abort controller for this request
this.abortController = new AbortController();
// Build URL with query params
const params = new URLSearchParams();
params.set("text", text);
if (voice) params.set("voice", voice);
params.set("speed", speed.toString());
const url = `${this.getAPIUrl()}?${params}`;
// Check if MediaSource is supported
if (!window.MediaSource || !MediaSource.isTypeSupported("audio/mpeg")) {
// Fallback to simple buffered playback
return this.fallbackSpeak(url);
}
// Create MediaSource and audio element
this.mediaSource = new MediaSource();
this.audioElement = new Audio();
this.audioElement.src = URL.createObjectURL(this.mediaSource);
// Set up audio element event handlers
this.audioElement.onplay = () => {
if (!this.isPlaying) {
this.isPlaying = true;
this.onPlayingChange?.(true);
}
};
this.audioElement.onended = () => {
this.isPlaying = false;
this.onPlayingChange?.(false);
};
this.audioElement.onerror = () => {
this.onError?.("Audio playback error");
this.isPlaying = false;
this.onPlayingChange?.(false);
};
// Wait for MediaSource to be ready
await new Promise<void>((resolve, reject) => {
if (!this.mediaSource) {
reject(new Error("MediaSource not initialized"));
return;
}
this.mediaSource.onsourceopen = () => {
try {
// Create SourceBuffer for MP3
this.sourceBuffer = this.mediaSource!.addSourceBuffer("audio/mpeg");
this.sourceBuffer.mode = "sequence";
this.sourceBuffer.onupdateend = () => {
this.isAppending = false;
this.processNextChunk();
};
resolve();
} catch (err) {
reject(err);
}
};
// MediaSource doesn't have onerror in all browsers, use onsourceclose as fallback
this.mediaSource.onsourceclose = () => {
if (this.mediaSource?.readyState === "closed") {
reject(new Error("MediaSource closed unexpectedly"));
}
};
});
// Start fetching and streaming audio
try {
const response = await fetch(url, {
method: "POST",
signal: this.abortController.signal,
credentials: "include", // Include cookies for authentication
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(
`TTS request failed: ${response.status} - ${errorText}`
);
}
const reader = response.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
// Start playback as soon as we have some data
let firstChunk = true;
while (true) {
const { done, value } = await reader.read();
if (done) {
this.streamComplete = true;
// End the stream when all chunks are appended
this.finalizeStream();
break;
}
if (value) {
this.pendingChunks.push(value);
this.processNextChunk();
// Start playback after first chunk
if (firstChunk && this.audioElement) {
firstChunk = false;
// Small delay to buffer a bit before starting
setTimeout(() => {
this.audioElement?.play().catch(() => {
// Ignore playback start errors
});
}, 100);
}
}
}
} catch (err) {
if (err instanceof Error && err.name === "AbortError") {
return;
}
this.onError?.(err instanceof Error ? err.message : "TTS error");
throw err;
}
}
/**
* Process next chunk from the queue.
*/
private processNextChunk(): void {
if (
this.isAppending ||
this.pendingChunks.length === 0 ||
!this.sourceBuffer ||
this.sourceBuffer.updating
) {
return;
}
const chunk = this.pendingChunks.shift();
if (chunk) {
this.isAppending = true;
try {
        // Copy this view's bytes into a standalone ArrayBuffer (handles a
        // non-zero byteOffset and satisfies appendBuffer's typing)
const buffer = chunk.buffer.slice(
chunk.byteOffset,
chunk.byteOffset + chunk.byteLength
) as ArrayBuffer;
this.sourceBuffer.appendBuffer(buffer);
} catch {
this.isAppending = false;
// Try next chunk
this.processNextChunk();
}
}
}
/**
* Finalize the stream when all data has been received.
*/
private finalizeStream(): void {
if (this.pendingChunks.length > 0 || this.isAppending) {
// Wait for remaining chunks to be appended
setTimeout(() => this.finalizeStream(), 50);
return;
}
if (
this.mediaSource &&
this.mediaSource.readyState === "open" &&
this.sourceBuffer &&
!this.sourceBuffer.updating
) {
try {
this.mediaSource.endOfStream();
} catch {
// Ignore errors when ending stream
}
}
}
/**
* Fallback for browsers that don't support MediaSource Extensions.
* Buffers all audio before playing.
*/
private async fallbackSpeak(url: string): Promise<void> {
const response = await fetch(url, {
method: "POST",
signal: this.abortController?.signal,
credentials: "include", // Include cookies for authentication
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`TTS request failed: ${response.status} - ${errorText}`);
}
const audioData = await response.arrayBuffer();
const blob = new Blob([audioData], { type: "audio/mpeg" });
const audioUrl = URL.createObjectURL(blob);
this.audioElement = new Audio(audioUrl);
this.audioElement.onplay = () => {
this.isPlaying = true;
this.onPlayingChange?.(true);
};
this.audioElement.onended = () => {
this.isPlaying = false;
this.onPlayingChange?.(false);
URL.revokeObjectURL(audioUrl);
};
this.audioElement.onerror = () => {
this.onError?.("Audio playback error");
};
await this.audioElement.play();
}
/**
* Stop playback and cleanup resources.
*/
stop(): void {
// Abort any ongoing request
if (this.abortController) {
this.abortController.abort();
this.abortController = null;
}
this.cleanup();
}
/**
* Cleanup all resources.
*/
private cleanup(): void {
    // Stop and clean up the audio element, revoking the MediaSource object
    // URL so it is not leaked across playbacks
    if (this.audioElement) {
      this.audioElement.pause();
      if (this.audioElement.src.startsWith("blob:")) {
        URL.revokeObjectURL(this.audioElement.src);
      }
      this.audioElement.src = "";
      this.audioElement = null;
    }
// Cleanup MediaSource
if (this.mediaSource && this.mediaSource.readyState === "open") {
try {
if (this.sourceBuffer) {
this.mediaSource.removeSourceBuffer(this.sourceBuffer);
}
this.mediaSource.endOfStream();
} catch {
// Ignore cleanup errors
}
}
this.mediaSource = null;
this.sourceBuffer = null;
this.pendingChunks = [];
this.isAppending = false;
this.streamComplete = false;
if (this.isPlaying) {
this.isPlaying = false;
this.onPlayingChange?.(false);
}
}
get playing(): boolean {
return this.isPlaying;
}
}
/**
* WebSocketStreamingTTSPlayer - Uses WebSocket for bidirectional streaming.
* Useful for scenarios where you want to stream text in and get audio out
* incrementally (e.g., as LLM generates text).
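 *
 * Wire protocol, as used by this class: the client sends JSON control frames
 * ({ type: "config", voice, speed }, then { type: "synthesize", text }
 * repeatedly, then { type: "end" }); the server replies with binary MP3
 * chunks plus { type: "audio_done" } and { type: "error" } JSON frames.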
*/
export class WebSocketStreamingTTSPlayer {
private websocket: WebSocket | null = null;
private mediaSource: MediaSource | null = null;
private sourceBuffer: SourceBuffer | null = null;
private audioElement: HTMLAudioElement | null = null;
private pendingChunks: Uint8Array[] = [];
private isAppending: boolean = false;
private isPlaying: boolean = false;
private onPlayingChange?: (playing: boolean) => void;
private onError?: (error: string) => void;
private hasStartedPlayback: boolean = false;
constructor(options?: {
onPlayingChange?: (playing: boolean) => void;
onError?: (error: string) => void;
}) {
this.onPlayingChange = options?.onPlayingChange;
this.onError = options?.onError;
}
private async getWebSocketUrl(): Promise<string> {
// Fetch short-lived WS token
const tokenResponse = await fetch("/api/voice/ws-token", {
method: "POST",
credentials: "include",
});
if (!tokenResponse.ok) {
throw new Error("Failed to get WebSocket authentication token");
}
const { token } = await tokenResponse.json();
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const isDev = window.location.port === "3000";
const host = isDev ? "localhost:8080" : window.location.host;
const path = isDev
? "/voice/synthesize/stream"
: "/api/voice/synthesize/stream";
return `${protocol}//${host}${path}?token=${encodeURIComponent(token)}`;
}
async connect(voice?: string, speed?: number): Promise<void> {
// Cleanup any previous connection
this.cleanup();
// Check MediaSource support
if (!window.MediaSource || !MediaSource.isTypeSupported("audio/mpeg")) {
throw new Error("MediaSource Extensions not supported");
}
// Create MediaSource and audio element
this.mediaSource = new MediaSource();
this.audioElement = new Audio();
this.audioElement.src = URL.createObjectURL(this.mediaSource);
this.audioElement.onplay = () => {
if (!this.isPlaying) {
this.isPlaying = true;
this.onPlayingChange?.(true);
}
};
this.audioElement.onended = () => {
this.isPlaying = false;
this.onPlayingChange?.(false);
};
// Wait for MediaSource to be ready
await new Promise<void>((resolve, reject) => {
this.mediaSource!.onsourceopen = () => {
try {
this.sourceBuffer = this.mediaSource!.addSourceBuffer("audio/mpeg");
this.sourceBuffer.mode = "sequence";
this.sourceBuffer.onupdateend = () => {
this.isAppending = false;
this.processNextChunk();
};
resolve();
} catch (err) {
reject(err);
}
};
});
// Connect WebSocket
const url = await this.getWebSocketUrl();
return new Promise((resolve, reject) => {
this.websocket = new WebSocket(url);
this.websocket.onopen = () => {
// Send initial config
this.websocket?.send(
JSON.stringify({
type: "config",
voice: voice,
speed: speed || 1.0,
})
);
resolve();
};
this.websocket.onerror = () => {
reject(new Error("WebSocket connection failed"));
};
this.websocket.onmessage = async (event) => {
if (event.data instanceof Blob) {
// Audio chunk received
const arrayBuffer = await event.data.arrayBuffer();
this.pendingChunks.push(new Uint8Array(arrayBuffer));
this.processNextChunk();
// Start playback after first chunk
if (!this.hasStartedPlayback && this.audioElement) {
this.hasStartedPlayback = true;
setTimeout(() => {
this.audioElement?.play().catch(() => {
// Ignore playback errors
});
}, 100);
}
} else {
// JSON message
try {
const data = JSON.parse(event.data);
if (data.type === "audio_done") {
this.finalizeStream();
} else if (data.type === "error") {
this.onError?.(data.message);
}
} catch {
// Ignore parse errors
}
}
};
this.websocket.onclose = () => {
this.finalizeStream();
};
});
}
private processNextChunk(): void {
if (
this.isAppending ||
this.pendingChunks.length === 0 ||
!this.sourceBuffer ||
this.sourceBuffer.updating
) {
return;
}
const chunk = this.pendingChunks.shift();
if (chunk) {
this.isAppending = true;
try {
        // Copy this view's bytes into a standalone ArrayBuffer (handles a
        // non-zero byteOffset and satisfies appendBuffer's typing)
const buffer = chunk.buffer.slice(
chunk.byteOffset,
chunk.byteOffset + chunk.byteLength
) as ArrayBuffer;
this.sourceBuffer.appendBuffer(buffer);
} catch {
this.isAppending = false;
this.processNextChunk();
}
}
}
private finalizeStream(): void {
if (this.pendingChunks.length > 0 || this.isAppending) {
setTimeout(() => this.finalizeStream(), 50);
return;
}
if (
this.mediaSource &&
this.mediaSource.readyState === "open" &&
this.sourceBuffer &&
!this.sourceBuffer.updating
) {
try {
this.mediaSource.endOfStream();
} catch {
// Ignore
}
}
}
async speak(text: string): Promise<void> {
if (!this.websocket || this.websocket.readyState !== WebSocket.OPEN) {
throw new Error("WebSocket not connected");
}
this.websocket.send(
JSON.stringify({
type: "synthesize",
text: text,
})
);
}
stop(): void {
this.cleanup();
}
disconnect(): void {
if (this.websocket && this.websocket.readyState === WebSocket.OPEN) {
this.websocket.send(JSON.stringify({ type: "end" }));
this.websocket.close();
}
this.cleanup();
}
private cleanup(): void {
if (this.websocket) {
this.websocket.close();
this.websocket = null;
}
    if (this.audioElement) {
      this.audioElement.pause();
      // Revoke the MediaSource object URL to avoid leaking it across sessions
      if (this.audioElement.src.startsWith("blob:")) {
        URL.revokeObjectURL(this.audioElement.src);
      }
      this.audioElement.src = "";
      this.audioElement = null;
    }
if (this.mediaSource && this.mediaSource.readyState === "open") {
try {
if (this.sourceBuffer) {
this.mediaSource.removeSourceBuffer(this.sourceBuffer);
}
this.mediaSource.endOfStream();
} catch {
// Ignore
}
}
this.mediaSource = null;
this.sourceBuffer = null;
this.pendingChunks = [];
this.isAppending = false;
this.hasStartedPlayback = false;
if (this.isPlaying) {
this.isPlaying = false;
this.onPlayingChange?.(false);
}
}
get playing(): boolean {
return this.isPlaying;
}
}
// Export the HTTP player as the default/recommended option
export { HTTPStreamingTTSPlayer as StreamingTTSPlayer };
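A minimal sketch of both players (text, voice, and callbacks are illustrative):

// One-shot HTTP streaming: hand over the full text; audio starts on the first chunk
const httpPlayer = new StreamingTTSPlayer({
  onPlayingChange: (playing) => console.log("playing:", playing),
  onError: (err) => console.error(err),
});
await httpPlayer.speak("Hello from streaming TTS.");

// Incremental WebSocket streaming: feed text as it is generated
const wsPlayer = new WebSocketStreamingTTSPlayer();
await wsPlayer.connect(undefined, 1.0);
await wsPlayer.speak("First sentence.");
await wsPlayer.speak("Second sentence.");
wsPlayer.disconnect(); // sends { type: "end" } and closes the socket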

View File

@@ -32,6 +32,11 @@ interface UserPreferences {
theme_preference: ThemePreference | null;
chat_background: string | null;
default_app_mode: "AUTO" | "CHAT" | "SEARCH";
// Voice preferences
voice_auto_send?: boolean;
voice_auto_playback?: boolean;
voice_playback_speed?: number;
preferred_voice?: string;
}
export interface MemoryItem {

View File

@@ -46,6 +46,12 @@ interface UserContextType {
updateUserChatBackground: (chatBackground: string | null) => Promise<void>;
updateUserDefaultModel: (defaultModel: string | null) => Promise<void>;
updateUserDefaultAppMode: (mode: "CHAT" | "SEARCH") => Promise<void>;
updateUserVoiceSettings: (settings: {
auto_send?: boolean;
auto_playback?: boolean;
playback_speed?: number;
preferred_voice?: string;
}) => Promise<void>;
}
const UserContext = createContext<UserContextType | undefined>(undefined);
@@ -460,6 +466,54 @@ export function UserProvider({
}
};
const updateUserVoiceSettings = async (settings: {
auto_send?: boolean;
auto_playback?: boolean;
playback_speed?: number;
preferred_voice?: string;
}) => {
try {
setUpToDateUser((prevUser) => {
if (prevUser) {
return {
...prevUser,
preferences: {
...prevUser.preferences,
voice_auto_send:
settings.auto_send ?? prevUser.preferences.voice_auto_send,
voice_auto_playback:
settings.auto_playback ??
prevUser.preferences.voice_auto_playback,
voice_playback_speed:
settings.playback_speed ??
prevUser.preferences.voice_playback_speed,
preferred_voice:
settings.preferred_voice ??
prevUser.preferences.preferred_voice,
},
};
}
return prevUser;
});
const response = await fetch("/api/voice/settings", {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(settings),
});
if (!response.ok) {
await refreshUser();
throw new Error("Failed to update voice settings");
}
} catch (error) {
console.error("Error updating voice settings:", error);
throw error;
}
};
const refreshUser = async () => {
await fetchUser();
};
@@ -478,6 +532,7 @@ export function UserProvider({
updateUserChatBackground,
updateUserDefaultModel,
updateUserDefaultAppMode,
updateUserVoiceSettings,
toggleAgentPinnedStatus,
isAdmin: upToDateUser?.role === UserRole.ADMIN,
// Curator status applies for either global or basic curator

View File

@@ -0,0 +1,771 @@
"use client";
import React, {
createContext,
useContext,
useState,
useCallback,
useRef,
useEffect,
} from "react";
import { useUser } from "@/providers/UserProvider";
interface VoiceModeContextType {
/** Whether TTS audio is currently playing */
isTTSPlaying: boolean;
/** Whether TTS is loading/generating audio */
isTTSLoading: boolean;
/** Text that has been spoken so far (for synced display) */
spokenText: string;
/** Stream text for TTS - speaks sentences as they complete */
streamTTS: (text: string, isComplete?: boolean) => void;
/** Stop TTS playback */
stopTTS: (options?: { manual?: boolean }) => void;
/** Increments when TTS is manually stopped by the user */
manualStopCount: number;
/** Reset state for new message */
resetTTS: () => void;
}
const VoiceModeContext = createContext<VoiceModeContextType | null>(null);
/**
* Clean text for TTS - remove markdown formatting
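 *
 * Illustrative: cleanTextForTTS("**Bold** see [docs](https://x)") returns
 * "Bold see docs".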
*/
function cleanTextForTTS(text: string): string {
return text
.replace(/\*\*/g, "") // Remove bold markers
.replace(/\*/g, "") // Remove italic markers
.replace(/`{1,3}/g, "") // Remove code markers
.replace(/#{1,6}\s*/g, "") // Remove headers
.replace(/\[([^\]]+)\]\([^)]+\)/g, "$1") // Convert links to just text
.replace(/\n+/g, " ") // Replace newlines with spaces
.replace(/\s+/g, " ") // Normalize whitespace
.trim();
}
/**
* Find the next natural chunk boundary in text.
* Prefers sentence endings for natural speech rhythm.
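 *
 * Illustrative: findChunkBoundary("Hello there. How are") returns 12, the
 * index just past the first ".", so "Hello there." is flushed to TTS while
 * "How are" stays buffered.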
*/
function findChunkBoundary(text: string): number {
// Look for sentence endings (. ! ?) - these are natural speech breaks
const sentenceRegex = /[.!?](?:\s|$)/g;
let match;
let lastSentenceEnd = -1;
while ((match = sentenceRegex.exec(text)) !== null) {
const endPos = match.index + 1;
if (endPos >= 10) {
lastSentenceEnd = endPos;
if (endPos >= 30) return endPos;
}
}
if (lastSentenceEnd > 0) return lastSentenceEnd;
// Only break at clauses for very long text (150+ chars)
if (text.length >= 150) {
const clauseRegex = /[,;:]\s/g;
while ((match = clauseRegex.exec(text)) !== null) {
const endPos = match.index + 1;
if (endPos >= 70) return endPos;
}
}
// Break at word boundary for extremely long text (200+ chars)
if (text.length >= 200) {
const spaceIndex = text.lastIndexOf(" ", 120);
if (spaceIndex > 80) return spaceIndex;
}
return -1;
}
export function VoiceModeProvider({ children }: { children: React.ReactNode }) {
const { user } = useUser();
const autoPlayback = user?.preferences?.voice_auto_playback ?? false;
const playbackSpeed = user?.preferences?.voice_playback_speed ?? 1.0;
const preferredVoice = user?.preferences?.preferred_voice;
const [isTTSPlaying, setIsTTSPlaying] = useState(false);
const [isTTSLoading, setIsTTSLoading] = useState(false);
const [spokenText, setSpokenText] = useState("");
const [manualStopCount, setManualStopCount] = useState(0);
// WebSocket and audio state
const wsRef = useRef<WebSocket | null>(null);
const mediaSourceRef = useRef<MediaSource | null>(null);
const sourceBufferRef = useRef<SourceBuffer | null>(null);
const audioElementRef = useRef<HTMLAudioElement | null>(null);
const audioUrlRef = useRef<string | null>(null);
const pendingChunksRef = useRef<Uint8Array[]>([]);
const isAppendingRef = useRef(false);
const isPlayingRef = useRef(false);
const hasStartedPlaybackRef = useRef(false);
// Text tracking
const committedPositionRef = useRef(0);
const lastRawTextRef = useRef("");
const pendingTextRef = useRef<string[]>([]);
const isConnectingRef = useRef(false);
// Timers
const flushTimerRef = useRef<NodeJS.Timeout | null>(null);
const fastStartTimerRef = useRef<NodeJS.Timeout | null>(null);
const loadingTimeoutRef = useRef<NodeJS.Timeout | null>(null);
const endCheckIntervalRef = useRef<NodeJS.Timeout | null>(null);
const hasSpokenFirstChunkRef = useRef(false);
const hasSignaledEndRef = useRef(false);
const streamEndedRef = useRef(false);
// Process next chunk from the pending queue
const processNextChunk = useCallback(() => {
if (
isAppendingRef.current ||
pendingChunksRef.current.length === 0 ||
!sourceBufferRef.current ||
sourceBufferRef.current.updating
) {
return;
}
const chunk = pendingChunksRef.current.shift();
if (chunk) {
isAppendingRef.current = true;
try {
const buffer = chunk.buffer.slice(
chunk.byteOffset,
chunk.byteOffset + chunk.byteLength
) as ArrayBuffer;
sourceBufferRef.current.appendBuffer(buffer);
} catch {
isAppendingRef.current = false;
processNextChunk();
}
}
}, []);
// Finalize the media stream when done
const finalizeStream = useCallback(() => {
if (pendingChunksRef.current.length > 0 || isAppendingRef.current) {
setTimeout(() => finalizeStream(), 50);
return;
}
streamEndedRef.current = true;
if (
mediaSourceRef.current &&
mediaSourceRef.current.readyState === "open" &&
sourceBufferRef.current &&
!sourceBufferRef.current.updating
) {
try {
mediaSourceRef.current.endOfStream();
} catch {
// Ignore errors when ending stream
}
}
// Clear any existing end check interval
if (endCheckIntervalRef.current) {
clearInterval(endCheckIntervalRef.current);
endCheckIntervalRef.current = null;
}
    // More aggressive end detection: poll every 200ms for whether audio has
    // ended. This handles cases where the onended event doesn't fire on
    // MediaSource-backed elements
endCheckIntervalRef.current = setInterval(() => {
const audioEl = audioElementRef.current;
// If audio element is gone or stream was reset, clean up
if (!audioEl || !streamEndedRef.current) {
if (endCheckIntervalRef.current) {
clearInterval(endCheckIntervalRef.current);
endCheckIntervalRef.current = null;
}
return;
}
// Check if audio has ended (either via ended property or by reaching duration)
const hasEnded =
audioEl.ended ||
audioEl.paused ||
(audioEl.duration > 0 &&
audioEl.currentTime >= audioEl.duration - 0.1) ||
audioEl.readyState === 0;
if (hasEnded && isPlayingRef.current) {
isPlayingRef.current = false;
setIsTTSPlaying(false);
if (endCheckIntervalRef.current) {
clearInterval(endCheckIntervalRef.current);
endCheckIntervalRef.current = null;
}
}
}, 200);
// Fallback: if audio doesn't finish playing within 10s after stream ends,
// reset the playing state to prevent mic button from being stuck disabled
setTimeout(() => {
if (endCheckIntervalRef.current) {
clearInterval(endCheckIntervalRef.current);
endCheckIntervalRef.current = null;
}
if (isPlayingRef.current) {
isPlayingRef.current = false;
setIsTTSPlaying(false);
}
}, 10000);
}, []);
// Initialize MediaSource for streaming audio
const initMediaSource = useCallback(async () => {
// Check if MediaSource is supported
if (!window.MediaSource || !MediaSource.isTypeSupported("audio/mpeg")) {
return false;
}
// Create MediaSource and audio element
mediaSourceRef.current = new MediaSource();
audioElementRef.current = new Audio();
audioUrlRef.current = URL.createObjectURL(mediaSourceRef.current);
audioElementRef.current.src = audioUrlRef.current;
audioElementRef.current.onplay = () => {
if (!isPlayingRef.current) {
isPlayingRef.current = true;
setIsTTSPlaying(true);
}
};
audioElementRef.current.onended = () => {
isPlayingRef.current = false;
setIsTTSPlaying(false);
};
audioElementRef.current.onerror = () => {
isPlayingRef.current = false;
setIsTTSPlaying(false);
};
// Wait for MediaSource to be ready
await new Promise<void>((resolve, reject) => {
if (!mediaSourceRef.current) {
reject(new Error("MediaSource not initialized"));
return;
}
mediaSourceRef.current.onsourceopen = () => {
try {
sourceBufferRef.current =
mediaSourceRef.current!.addSourceBuffer("audio/mpeg");
sourceBufferRef.current.mode = "sequence";
sourceBufferRef.current.onupdateend = () => {
isAppendingRef.current = false;
processNextChunk();
};
resolve();
} catch (err) {
reject(err);
}
};
mediaSourceRef.current.onsourceclose = () => {
if (mediaSourceRef.current?.readyState === "closed") {
reject(new Error("MediaSource closed unexpectedly"));
}
};
});
return true;
}, [processNextChunk]);
// Handle incoming audio data from WebSocket
const handleAudioData = useCallback(
async (data: ArrayBuffer) => {
pendingChunksRef.current.push(new Uint8Array(data));
processNextChunk();
// Start playback after first chunk
if (!hasStartedPlaybackRef.current && audioElementRef.current) {
// Small delay to buffer a bit before starting
setTimeout(() => {
const audioEl = audioElementRef.current;
if (!audioEl || hasStartedPlaybackRef.current) {
return;
}
audioEl
.play()
.then(() => {
hasStartedPlaybackRef.current = true;
})
.catch(() => {
// Keep hasStartedPlaybackRef as false so we retry on next audio chunk.
});
}, 100);
}
},
[processNextChunk]
);
// Get WebSocket URL for TTS with authentication token
const getWebSocketUrl = useCallback(async () => {
// Fetch short-lived WS token
const tokenResponse = await fetch("/api/voice/ws-token", {
method: "POST",
credentials: "include",
});
if (!tokenResponse.ok) {
throw new Error("Failed to get WebSocket authentication token");
}
const { token } = await tokenResponse.json();
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const isDev = window.location.port === "3000";
const host = isDev ? "localhost:8080" : window.location.host;
const path = isDev
? "/voice/synthesize/stream"
: "/api/voice/synthesize/stream";
return `${protocol}//${host}${path}?token=${encodeURIComponent(token)}`;
}, []);
// Connect to WebSocket TTS
const connectWebSocket = useCallback(async () => {
// Skip if already connected, connecting, or in the process of connecting
if (
wsRef.current?.readyState === WebSocket.OPEN ||
wsRef.current?.readyState === WebSocket.CONNECTING ||
isConnectingRef.current
) {
return;
}
// Set connecting flag to prevent concurrent connection attempts
isConnectingRef.current = true;
try {
// Initialize MediaSource first
const initialized = await initMediaSource();
if (!initialized) {
isConnectingRef.current = false;
return;
}
// Get WebSocket URL with auth token
const wsUrl = await getWebSocketUrl();
const ws = new WebSocket(wsUrl);
ws.onopen = () => {
isConnectingRef.current = false;
// Send initial config
ws.send(
JSON.stringify({
type: "config",
voice: preferredVoice,
speed: playbackSpeed,
})
);
// Send any pending text
for (const text of pendingTextRef.current) {
ws.send(JSON.stringify({ type: "synthesize", text }));
}
pendingTextRef.current = [];
};
ws.onmessage = async (event) => {
if (event.data instanceof Blob) {
const arrayBuffer = await event.data.arrayBuffer();
handleAudioData(arrayBuffer);
} else if (typeof event.data === "string") {
try {
const msg = JSON.parse(event.data);
if (msg.type === "audio_done") {
if (loadingTimeoutRef.current) {
clearTimeout(loadingTimeoutRef.current);
loadingTimeoutRef.current = null;
}
setIsTTSLoading(false);
finalizeStream();
}
} catch {
// Ignore parse errors
}
}
};
ws.onerror = () => {
isConnectingRef.current = false;
setIsTTSLoading(false);
};
ws.onclose = () => {
wsRef.current = null;
isConnectingRef.current = false;
setIsTTSLoading(false);
finalizeStream();
};
wsRef.current = ws;
} catch {
isConnectingRef.current = false;
}
}, [
preferredVoice,
playbackSpeed,
handleAudioData,
getWebSocketUrl,
initMediaSource,
finalizeStream,
]);
// Send text to TTS via WebSocket
const sendTextToTTS = useCallback(
(text: string) => {
if (!text.trim()) return;
setIsTTSLoading(true);
setSpokenText((prev) => (prev ? prev + " " + text : text));
// Set a timeout to reset loading state if TTS doesn't complete
if (loadingTimeoutRef.current) {
clearTimeout(loadingTimeoutRef.current);
}
loadingTimeoutRef.current = setTimeout(() => {
setIsTTSLoading(false);
setIsTTSPlaying(false);
}, 60000);
if (wsRef.current?.readyState === WebSocket.OPEN) {
wsRef.current.send(JSON.stringify({ type: "synthesize", text }));
} else {
pendingTextRef.current.push(text);
connectWebSocket();
}
},
[connectWebSocket]
);
const streamTTS = useCallback(
(text: string, isComplete: boolean = false) => {
if (!autoPlayback) {
return;
}
// Skip if text hasn't changed
if (text === lastRawTextRef.current && !isComplete) return;
lastRawTextRef.current = text;
// Clear pending timers
if (flushTimerRef.current) {
clearTimeout(flushTimerRef.current);
flushTimerRef.current = null;
}
if (fastStartTimerRef.current) {
clearTimeout(fastStartTimerRef.current);
fastStartTimerRef.current = null;
}
// Clean the full text
const cleanedText = cleanTextForTTS(text);
const uncommittedText = cleanedText.slice(committedPositionRef.current);
// On completion, we must still signal "end" even if there's no new text.
// Otherwise ElevenLabs waits for more input and eventually times out.
if (uncommittedText.length === 0) {
if (isComplete && !hasSignaledEndRef.current) {
hasSignaledEndRef.current = true;
if (wsRef.current?.readyState === WebSocket.OPEN) {
wsRef.current.send(JSON.stringify({ type: "end" }));
} else {
const sendEnd = () => {
if (wsRef.current?.readyState === WebSocket.OPEN) {
if (pendingTextRef.current.length === 0) {
wsRef.current.send(JSON.stringify({ type: "end" }));
} else {
setTimeout(sendEnd, 100);
}
} else if (wsRef.current?.readyState === WebSocket.CONNECTING) {
setTimeout(sendEnd, 100);
}
};
setTimeout(sendEnd, 100);
}
}
return;
}
// Find chunk boundaries and send immediately
let remaining = uncommittedText;
let offset = 0;
while (remaining.length > 0) {
const boundaryIndex = findChunkBoundary(remaining);
if (boundaryIndex > 0) {
const chunkText = remaining.slice(0, boundaryIndex).trim();
if (chunkText.length > 0) {
sendTextToTTS(chunkText);
hasSpokenFirstChunkRef.current = true;
}
offset += boundaryIndex;
remaining = remaining.slice(boundaryIndex).trim();
} else {
break;
}
}
committedPositionRef.current += offset;
// Handle remaining text when stream is complete
if (isComplete && remaining.trim().length > 0) {
sendTextToTTS(remaining.trim());
committedPositionRef.current = cleanedText.length;
hasSpokenFirstChunkRef.current = true;
}
// When streaming is complete, signal end to flush remaining audio
if (isComplete && !hasSignaledEndRef.current) {
hasSignaledEndRef.current = true;
if (wsRef.current?.readyState === WebSocket.OPEN) {
wsRef.current.send(JSON.stringify({ type: "end" }));
} else {
const sendEnd = () => {
if (wsRef.current?.readyState === WebSocket.OPEN) {
if (pendingTextRef.current.length === 0) {
wsRef.current.send(JSON.stringify({ type: "end" }));
} else {
setTimeout(sendEnd, 100);
}
} else if (wsRef.current?.readyState === WebSocket.CONNECTING) {
setTimeout(sendEnd, 100);
}
};
setTimeout(sendEnd, 100);
}
}
const currentUncommitted = cleanedText
.slice(committedPositionRef.current)
.trim();
// Fast start: if we haven't spoken yet and have 20+ chars, send after 200ms
if (
!hasSpokenFirstChunkRef.current &&
currentUncommitted.length >= 20 &&
!isComplete
) {
fastStartTimerRef.current = setTimeout(() => {
if (hasSpokenFirstChunkRef.current) return;
const nowCleaned = cleanTextForTTS(lastRawTextRef.current);
const nowUncommitted = nowCleaned
.slice(committedPositionRef.current)
.trim();
if (nowUncommitted.length >= 20) {
// Find a reasonable break point
let breakPoint = nowUncommitted.length;
const spaceIdx = nowUncommitted.lastIndexOf(" ", 50);
if (spaceIdx >= 15) breakPoint = spaceIdx;
const chunk = nowUncommitted.slice(0, breakPoint).trim();
if (chunk.length > 0) {
sendTextToTTS(chunk);
committedPositionRef.current += breakPoint;
hasSpokenFirstChunkRef.current = true;
}
}
}, 200);
}
// Flush timer for text ending with punctuation
if (
currentUncommitted.length > 0 &&
!isComplete &&
/[.!?]$/.test(currentUncommitted)
) {
flushTimerRef.current = setTimeout(() => {
const nowCleaned = cleanTextForTTS(lastRawTextRef.current);
const nowUncommitted = nowCleaned
.slice(committedPositionRef.current)
.trim();
if (nowUncommitted.length > 0) {
sendTextToTTS(nowUncommitted);
committedPositionRef.current = nowCleaned.length;
hasSpokenFirstChunkRef.current = true;
}
}, 250);
}
},
[autoPlayback, sendTextToTTS]
);
const stopTTS = useCallback((options?: { manual?: boolean }) => {
// Clear timers
if (flushTimerRef.current) {
clearTimeout(flushTimerRef.current);
flushTimerRef.current = null;
}
if (fastStartTimerRef.current) {
clearTimeout(fastStartTimerRef.current);
fastStartTimerRef.current = null;
}
if (loadingTimeoutRef.current) {
clearTimeout(loadingTimeoutRef.current);
loadingTimeoutRef.current = null;
}
if (endCheckIntervalRef.current) {
clearInterval(endCheckIntervalRef.current);
endCheckIntervalRef.current = null;
}
// Revoke blob URL to prevent memory leak
if (audioUrlRef.current) {
URL.revokeObjectURL(audioUrlRef.current);
audioUrlRef.current = null;
}
// Stop audio element
if (audioElementRef.current) {
audioElementRef.current.pause();
audioElementRef.current.src = "";
audioElementRef.current = null;
}
// Cleanup MediaSource
if (
mediaSourceRef.current &&
mediaSourceRef.current.readyState === "open"
) {
try {
if (sourceBufferRef.current) {
mediaSourceRef.current.removeSourceBuffer(sourceBufferRef.current);
}
mediaSourceRef.current.endOfStream();
} catch {
// Ignore cleanup errors
}
}
mediaSourceRef.current = null;
sourceBufferRef.current = null;
pendingChunksRef.current = [];
isAppendingRef.current = false;
hasStartedPlaybackRef.current = false;
pendingTextRef.current = [];
isPlayingRef.current = false;
hasSignaledEndRef.current = false;
isConnectingRef.current = false;
streamEndedRef.current = false;
// Close WebSocket
if (wsRef.current) {
try {
wsRef.current.send(JSON.stringify({ type: "end" }));
wsRef.current.close();
} catch {
// Ignore
}
wsRef.current = null;
}
setIsTTSPlaying(false);
setIsTTSLoading(false);
if (options?.manual) {
setManualStopCount((count) => count + 1);
}
}, []);
const resetTTS = useCallback(() => {
stopTTS();
committedPositionRef.current = 0;
lastRawTextRef.current = "";
hasSpokenFirstChunkRef.current = false;
hasSignaledEndRef.current = false;
setSpokenText("");
}, [stopTTS]);
// Reset TTS state when voice auto-playback is disabled
// This prevents the mic button from being stuck disabled
const prevAutoPlaybackRef = useRef(autoPlayback);
useEffect(() => {
if (prevAutoPlaybackRef.current && !autoPlayback) {
// Auto-playback was just disabled, clean up TTS state
resetTTS();
}
prevAutoPlaybackRef.current = autoPlayback;
}, [autoPlayback, resetTTS]);
// Cleanup on unmount
useEffect(() => {
return () => {
if (flushTimerRef.current) clearTimeout(flushTimerRef.current);
if (fastStartTimerRef.current) clearTimeout(fastStartTimerRef.current);
if (loadingTimeoutRef.current) clearTimeout(loadingTimeoutRef.current);
if (endCheckIntervalRef.current)
clearInterval(endCheckIntervalRef.current);
if (audioUrlRef.current) {
URL.revokeObjectURL(audioUrlRef.current);
}
if (wsRef.current) {
try {
wsRef.current.close();
} catch {
// Ignore
}
}
if (audioElementRef.current) {
try {
audioElementRef.current.pause();
audioElementRef.current.src = "";
} catch {
// Ignore
}
}
if (
mediaSourceRef.current &&
mediaSourceRef.current.readyState === "open"
) {
try {
mediaSourceRef.current.endOfStream();
} catch {
// Ignore
}
}
};
}, []);
return (
<VoiceModeContext.Provider
value={{
isTTSPlaying,
isTTSLoading,
spokenText,
streamTTS,
stopTTS,
manualStopCount,
resetTTS,
}}
>
{children}
</VoiceModeContext.Provider>
);
}
export function useVoiceMode(): VoiceModeContextType {
const context = useContext(VoiceModeContext);
if (!context) {
throw new Error("useVoiceMode must be used within VoiceModeProvider");
}
return context;
}
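A sketch of how a message component is expected to drive this provider as LLM tokens stream in (the partialText/finalText variables are illustrative):

const { streamTTS, resetTTS, stopTTS } = useVoiceMode();

resetTTS();                 // before a new assistant message starts
streamTTS(partialText);     // call with the accumulated text on each delta
streamTTS(finalText, true); // signal completion so the tail is flushed
stopTTS({ manual: true });  // user-initiated stop (bumps manualStopCount)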

View File

@@ -129,6 +129,7 @@ const InputComboBox = ({
leftSearchIcon = false,
rightSection,
separatorLabel = "Other options",
showAddPrefix = false,
showOtherOptions = false,
...rest
}: WithoutStyles<InputComboBoxProps>) => {
@@ -156,14 +157,11 @@ const InputComboBox = ({
const visibleUnmatchedOptions =
hasSearchTerm && showOtherOptions ? unmatchedOptions : [];
// Whether to show the create option (only when no partial matches)
const showCreateOption =
!strict &&
hasSearchTerm &&
inputValue.trim() !== "" &&
matchedOptions.length === 0;
// Whether to show the create option (always show when typing in non-strict mode)
const showCreateOption = !strict && hasSearchTerm && inputValue.trim() !== "";
// Combined list for keyboard navigation (includes create option when shown)
// Only show matched options when searching (hide unmatched)
const allVisibleOptions = useMemo(() => {
const baseOptions = [...matchedOptions, ...visibleUnmatchedOptions];
if (showCreateOption) {
@@ -440,6 +438,7 @@ const InputComboBox = ({
inputValue={inputValue}
allowCreate={!strict}
showCreateOption={showCreateOption}
showAddPrefix={showAddPrefix}
/>
</>

View File

@@ -27,6 +27,8 @@ interface ComboBoxDropdownProps {
allowCreate: boolean;
/** Whether to show create option (pre-computed by parent) */
showCreateOption: boolean;
/** Show "Add" prefix in create option */
showAddPrefix: boolean;
}
/**
@@ -58,6 +60,7 @@ export const ComboBoxDropdown = forwardRef<
inputValue,
allowCreate,
showCreateOption,
showAddPrefix,
},
ref
) => {
@@ -132,6 +135,7 @@ export const ComboBoxDropdown = forwardRef<
inputValue={inputValue}
allowCreate={allowCreate}
showCreateOption={showCreateOption}
showAddPrefix={showAddPrefix}
/>
</div>,
document.body

View File

@@ -24,6 +24,8 @@ interface OptionsListProps {
allowCreate: boolean;
/** Whether to show create option (pre-computed by parent) */
showCreateOption: boolean;
/** Show "Add" prefix in create option */
showAddPrefix: boolean;
}
/**
@@ -45,6 +47,7 @@ export const OptionsList: React.FC<OptionsListProps> = ({
inputValue,
allowCreate,
showCreateOption,
showAddPrefix,
}) => {
// Index offset for other options when create option is shown
const indexOffset = showCreateOption ? 1 : 0;
@@ -70,7 +73,7 @@ export const OptionsList: React.FC<OptionsListProps> = ({
data-index={0}
role="option"
aria-selected={false}
aria-label={`Create "${inputValue}"`}
aria-label={`${showAddPrefix ? "Add" : "Create"} "${inputValue}"`}
onClick={(e) => {
e.stopPropagation();
onSelect({ value: inputValue, label: inputValue });
@@ -81,19 +84,48 @@ export const OptionsList: React.FC<OptionsListProps> = ({
onMouseEnter={() => onMouseEnter(0)}
onMouseMove={onMouseMove}
className={cn(
"px-3 py-2 cursor-pointer transition-colors",
"cursor-pointer transition-colors",
"flex items-center justify-between rounded-08",
highlightedIndex === 0 && "bg-background-tint-02",
"hover:bg-background-tint-02"
"hover:bg-background-tint-02",
showAddPrefix ? "px-1.5 py-1.5" : "px-3 py-2"
)}
>
<span className="font-main-ui-action text-text-04 truncate min-w-0">
{inputValue}
<span
className={cn(
"font-main-ui-action truncate min-w-0",
showAddPrefix ? "px-1" : ""
)}
>
{showAddPrefix ? (
<>
<span className="text-text-03">Add</span>
<span className="text-text-04">{` ${inputValue}`}</span>
</>
) : (
<span className="text-text-04">{inputValue}</span>
)}
</span>
<SvgPlus className="w-4 h-4 text-text-03 flex-shrink-0 ml-2" />
<SvgPlus
className={cn(
"w-4 h-4 flex-shrink-0",
showAddPrefix ? "text-text-04 mx-1" : "text-text-03 ml-2"
)}
/>
</div>
)}
{/* Separator - show when there are options to display */}
{separatorLabel &&
(matchedOptions.length > 0 ||
(!hasSearchTerm && unmatchedOptions.length > 0)) && (
<div className="px-3 py-1">
<Text as="p" text03 secondaryBody>
{separatorLabel}
</Text>
</div>
)}
{/* Matched/Filtered Options */}
{matchedOptions.map((option, idx) => {
const globalIndex = idx + indexOffset;
@@ -116,37 +148,27 @@ export const OptionsList: React.FC<OptionsListProps> = ({
);
})}
{/* Separator - only show if there are unmatched options and a search term */}
{hasSearchTerm && unmatchedOptions.length > 0 && (
<div className="px-3 py-2 pt-3">
<div className="border-t border-border-01 pt-2">
<Text as="p" text04 secondaryBody className="text-text-02">
{separatorLabel}
</Text>
</div>
</div>
)}
{/* Unmatched Options */}
{unmatchedOptions.map((option, idx) => {
const globalIndex = matchedOptions.length + idx + indexOffset;
const isExact = isExactMatch(option);
return (
<OptionItem
key={option.value}
option={option}
index={globalIndex}
fieldId={fieldId}
isHighlighted={globalIndex === highlightedIndex}
isSelected={value === option.value}
isExact={isExact}
onSelect={onSelect}
onMouseEnter={onMouseEnter}
onMouseMove={onMouseMove}
searchTerm={inputValue}
/>
);
})}
{/* Unmatched Options - only show when NOT searching */}
{!hasSearchTerm &&
unmatchedOptions.map((option, idx) => {
const globalIndex = matchedOptions.length + idx + indexOffset;
const isExact = isExactMatch(option);
return (
<OptionItem
key={option.value}
option={option}
index={globalIndex}
fieldId={fieldId}
isHighlighted={globalIndex === highlightedIndex}
isSelected={value === option.value}
isExact={isExact}
onSelect={onSelect}
onMouseEnter={onMouseEnter}
onMouseMove={onMouseMove}
searchTerm={inputValue}
/>
);
})}
</>
);
};

View File

@@ -40,6 +40,8 @@ export interface InputComboBoxProps
rightSection?: React.ReactNode;
/** Label for the separator between matched and unmatched options */
separatorLabel?: string;
/** Show "Add" prefix in create option (e.g., "Add [value]") */
showAddPrefix?: boolean;
/**
* When true, keep non-matching options visible under a separator while searching.
* Defaults to false so search results are strictly filtered.

View File

@@ -748,6 +748,7 @@ function ChatPreferencesSettings() {
updateUserShortcuts,
updateUserDefaultModel,
updateUserDefaultAppMode,
updateUserVoiceSettings,
} = useUser();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
const settings = useSettingsContext();
@@ -933,6 +934,64 @@ function ChatPreferencesSettings() {
{user?.preferences?.shortcut_enabled && <PromptShortcuts />}
</Card>
</Section>
<Section gap={0.75}>
<Content
title="Voice"
sizePreset="main-content"
variant="section"
widthVariant="full"
/>
<Card>
<InputLayouts.Horizontal
title="Auto-Send"
description="Automatically send voice input when recording stops."
>
<Switch
checked={user?.preferences.voice_auto_send ?? false}
onCheckedChange={(checked) => {
void updateUserVoiceSettings({ auto_send: checked });
}}
/>
</InputLayouts.Horizontal>
<InputLayouts.Horizontal
title="Auto-Playback"
description="Automatically play voice responses."
>
<Switch
checked={user?.preferences.voice_auto_playback ?? false}
onCheckedChange={(checked) => {
void updateUserVoiceSettings({ auto_playback: checked });
}}
/>
</InputLayouts.Horizontal>
<InputLayouts.Horizontal
title="Playback Speed"
description="Adjust the speed of voice playback."
>
<div className="flex items-center gap-3">
<input
type="range"
min="0.5"
max="2"
step="0.1"
value={user?.preferences.voice_playback_speed ?? 1}
onChange={(e) => {
void updateUserVoiceSettings({
playback_speed: parseFloat(e.target.value),
});
}}
className="w-24 h-2 rounded-lg appearance-none cursor-pointer bg-background-neutral-02"
/>
<span className="text-sm text-text-02 w-10">
{(user?.preferences.voice_playback_speed ?? 1).toFixed(1)}x
</span>
</div>
</InputLayouts.Horizontal>
</Card>
</Section>
</Section>
);
}

View File

@@ -56,6 +56,9 @@ import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import { useQueryController } from "@/providers/QueryControllerProvider";
import { Section } from "@/layouts/general-layouts";
import Spacer from "@/refresh-components/Spacer";
import MicrophoneButton from "@/sections/input/MicrophoneButton";
import RecordingWaveform from "@/sections/input/RecordingWaveform";
import { useVoiceMode } from "@/providers/VoiceModeProvider";
const MIN_INPUT_HEIGHT = 44;
const MAX_INPUT_HEIGHT = 200;
@@ -114,6 +117,12 @@ const AppInputBar = React.memo(
}: AppInputBarProps) => {
// Internal message state - kept local to avoid parent re-renders on every keystroke
const [message, setMessage] = useState(initialMessage);
const [isRecording, setIsRecording] = useState(false);
const [isMuted, setIsMuted] = useState(false);
const stopRecordingRef = useRef<(() => Promise<string | null>) | null>(
null
);
const setMutedRef = useRef<((muted: boolean) => void) | null>(null);
const textAreaRef = useRef<HTMLTextAreaElement>(null);
const textAreaWrapperRef = useRef<HTMLDivElement>(null);
const filesWrapperRef = useRef<HTMLDivElement>(null);
@@ -121,6 +130,16 @@ const AppInputBar = React.memo(
const containerRef = useRef<HTMLDivElement>(null);
const { user } = useUser();
const { isClassifying, classification } = useQueryController();
const { stopTTS } = useVoiceMode();
// Wrapper for onSubmit that stops TTS first to prevent overlapping voices
const handleSubmit = useCallback(
(text: string) => {
stopTTS();
onSubmit(text);
},
[stopTTS, onSubmit]
);
// Expose reset and focus methods to parent via ref
React.useImperativeHandle(ref, () => ({
@@ -543,6 +562,21 @@ const AppInputBar = React.memo(
disabled={disabled}
/>
</div>
<MicrophoneButton
onTranscription={(text) => setMessage(text)}
disabled={disabled || chatState === "streaming"}
autoSend={user?.preferences?.voice_auto_send ?? false}
autoListen={user?.preferences?.voice_auto_playback ?? false}
chatState={chatState}
onRecordingChange={setIsRecording}
stopRecordingRef={stopRecordingRef}
onRecordingStart={() => setMessage("")}
onAutoSend={(text) => {
handleSubmit(text);
}}
onMuteChange={setIsMuted}
setMutedRef={setMutedRef}
/>
<Button
id="onyx-chat-input-send-button"
icon={
@@ -575,7 +609,7 @@ const AppInputBar = React.memo(
ref={containerRef}
id="onyx-chat-input"
className={cn(
"w-full flex flex-col shadow-01 bg-background-neutral-00 rounded-16"
"w-full flex flex-col shadow-01 bg-background-neutral-00 rounded-16 relative"
// # Note (from @raunakab):
//
// `shadow-01` extends ~14px below the element (2px offset + 12px blur).
@@ -639,9 +673,11 @@ const AppInputBar = React.memo(
style={{ scrollbarWidth: "thin" }}
aria-multiline={true}
placeholder={
isSearchMode
? "Search connected sources"
: "How can I help you today?"
isRecording
? "Listening..."
: isSearchMode
? "Search connected sources"
: "How can I help you today?"
}
value={message}
onKeyDown={(event) => {
@@ -658,7 +694,7 @@ const AppInputBar = React.memo(
!isClassifying &&
!hasUploadingFiles
) {
onSubmit(message);
handleSubmit(message);
}
}
}}
@@ -722,7 +758,7 @@ const AppInputBar = React.memo(
if (chatState == "streaming") {
stopGenerating();
} else if (message) {
onSubmit(message);
handleSubmit(message);
}
}}
prominence="tertiary"
@@ -733,6 +769,28 @@ const AppInputBar = React.memo(
</div>
{chatControls}
{/* Recording waveform - position depends on session state:
- Fresh chat (input centered): bar floats BELOW input
- Subsequent turns (input at bottom): bar floats ABOVE input */}
{isRecording && (
<div
className={cn(
"absolute left-0 right-0 px-1.5",
appFocus.isNewSession()
? "-bottom-[46px]" // Fresh chat: below input
: "-top-[46px]" // Subsequent turns: above input
)}
>
<RecordingWaveform
isRecording={isRecording}
isMuted={isMuted}
onMuteToggle={() => {
setMutedRef.current?.(!isMuted);
}}
/>
</div>
)}
</div>
</Disabled>
);
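AppInputBar controls the microphone imperatively through two mutable refs (stopRecordingRef, setMutedRef) that MicrophoneButton fills in, avoiding the re-renders a callback-prop or lifted-state approach would cause. A stripped-down sketch of the pattern (generic names, not the actual components):

import { useEffect, useRef, type MutableRefObject } from "react";

// Child publishes an imperative handler into a parent-owned ref.
function Child({
stopRef,
}: {
stopRef: MutableRefObject<(() => Promise<string | null>) | null>;
}) {
useEffect(() => {
stopRef.current = async () => "final transcript"; // placeholder body
return () => {
stopRef.current = null; // clear on unmount to avoid stale calls
};
}, [stopRef]);
return null;
}

function Parent() {
const stopRef = useRef<(() => Promise<string | null>) | null>(null);
// Later, e.g. before submit: const text = await stopRef.current?.();
return <Child stopRef={stopRef} />;
}

Note the diff itself never clears the refs on unmount; the cleanup above is a defensive addition, not part of the original code.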

View File

@@ -0,0 +1,224 @@
"use client";
import { useCallback, useEffect, useRef } from "react";
import { Button } from "@opal/components";
import { SvgMicrophone } from "@opal/icons";
import { useVoiceRecorder } from "@/hooks/useVoiceRecorder";
import { useVoiceMode } from "@/providers/VoiceModeProvider";
import { toast } from "@/hooks/useToast";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import { ChatState } from "@/app/app/interfaces";
interface MicrophoneButtonProps {
onTranscription: (text: string) => void;
disabled?: boolean;
autoSend?: boolean;
/** Called with transcribed text when autoSend is enabled */
onAutoSend?: (text: string) => void;
/**
* Internal prop: auto-start listening when TTS finishes or chat response completes.
* Tied to voice_auto_playback user preference.
* Enables conversation flow: speak → AI responds → auto-listen again.
* Note: autoSend is separate - it controls whether message auto-submits after recording.
*/
autoListen?: boolean;
/** Current chat state - used to detect when response streaming finishes */
chatState?: ChatState;
/** Called when recording state changes */
onRecordingChange?: (isRecording: boolean) => void;
/** Ref to expose stop recording function to parent */
stopRecordingRef?: React.MutableRefObject<
(() => Promise<string | null>) | null
>;
/** Called when recording starts to clear input */
onRecordingStart?: () => void;
/** Called when mute state changes */
onMuteChange?: (isMuted: boolean) => void;
/** Ref to expose setMuted function to parent */
setMutedRef?: React.MutableRefObject<((muted: boolean) => void) | null>;
}
function MicrophoneButton({
onTranscription,
disabled = false,
autoSend = false,
onAutoSend,
autoListen = false,
chatState,
onRecordingChange,
stopRecordingRef,
onRecordingStart,
onMuteChange,
setMutedRef,
}: MicrophoneButtonProps) {
const { isTTSPlaying, isTTSLoading, manualStopCount } = useVoiceMode();
// Refs for tracking state across renders
const wasTTSPlayingRef = useRef(false);
const manualStopRequestedRef = useRef(false);
const lastHandledManualStopCountRef = useRef(manualStopCount);
const autoListenCooldownTimerRef = useRef<NodeJS.Timeout | null>(null);
// Handler for VAD-triggered auto-send (when server detects silence)
const handleFinalTranscript = useCallback(
(text: string) => {
onTranscription(text);
const isManualStop = manualStopRequestedRef.current;
// Only auto-send if chat is ready for input (not streaming)
if (!isManualStop && autoSend && onAutoSend && chatState === "input") {
onAutoSend(text);
}
},
[onTranscription, autoSend, onAutoSend, chatState]
);
const {
isRecording,
isProcessing,
isMuted,
error,
liveTranscript,
startRecording,
stopRecording,
setMuted,
} = useVoiceRecorder({ onFinalTranscript: handleFinalTranscript });
// Expose stopRecording to parent
useEffect(() => {
if (stopRecordingRef) {
stopRecordingRef.current = stopRecording;
}
}, [stopRecording, stopRecordingRef]);
// Expose setMuted to parent
useEffect(() => {
if (setMutedRef) {
setMutedRef.current = setMuted;
}
}, [setMuted, setMutedRef]);
// Notify parent when mute state changes
useEffect(() => {
onMuteChange?.(isMuted);
}, [isMuted, onMuteChange]);
// Notify parent when recording state changes
useEffect(() => {
onRecordingChange?.(isRecording);
}, [isRecording, onRecordingChange]);
// Update input with live transcript as user speaks
useEffect(() => {
if (isRecording && liveTranscript) {
onTranscription(liveTranscript);
}
}, [isRecording, liveTranscript, onTranscription]);
const handleClick = useCallback(async () => {
if (isRecording) {
// When recording, clicking the mic button stops recording
manualStopRequestedRef.current = true;
try {
await stopRecording();
} finally {
manualStopRequestedRef.current = false;
}
} else {
try {
// Clear input before starting recording
onRecordingStart?.();
await startRecording();
} catch {
toast.error("Could not access microphone");
}
}
}, [isRecording, startRecording, stopRecording, onRecordingStart]);
// Auto-start listening shortly after TTS finishes (only if autoListen is enabled).
// A small cooldown reduces the chance of residual playback audio being re-captured by the microphone.
useEffect(() => {
if (autoListenCooldownTimerRef.current) {
clearTimeout(autoListenCooldownTimerRef.current);
autoListenCooldownTimerRef.current = null;
}
const stoppedManually =
manualStopCount !== lastHandledManualStopCountRef.current;
if (
wasTTSPlayingRef.current &&
!isTTSPlaying &&
!isTTSLoading &&
autoListen &&
!disabled &&
!isRecording &&
!stoppedManually
) {
autoListenCooldownTimerRef.current = setTimeout(() => {
autoListenCooldownTimerRef.current = null;
if (
!autoListen ||
disabled ||
isRecording ||
isTTSPlaying ||
isTTSLoading
) {
return;
}
startRecording().catch(() => {
// Silently ignore auto-start failures
});
}, 400);
}
if (stoppedManually) {
lastHandledManualStopCountRef.current = manualStopCount;
}
wasTTSPlayingRef.current = isTTSPlaying || isTTSLoading;
}, [
isTTSPlaying,
isTTSLoading,
autoListen,
disabled,
isRecording,
startRecording,
manualStopCount,
]);
useEffect(() => {
return () => {
if (autoListenCooldownTimerRef.current) {
clearTimeout(autoListenCooldownTimerRef.current);
autoListenCooldownTimerRef.current = null;
}
};
}, []);
useEffect(() => {
if (error) {
toast.error(error);
}
}, [error]);
// Icon: show loader when processing, otherwise mic
const icon = isProcessing ? SimpleLoader : SvgMicrophone;
// Disable when processing or TTS is playing (don't want to pick up TTS audio)
const isDisabled = disabled || isProcessing || isTTSPlaying || isTTSLoading;
// Recording = darkened (primary), not recording = light (tertiary)
const prominence = isRecording ? "primary" : "tertiary";
return (
<Button
icon={icon}
disabled={isDisabled}
onClick={handleClick}
aria-label={isRecording ? "Stop recording" : "Start recording"}
prominence={prominence}
/>
);
}
export default MicrophoneButton;
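MicrophoneButton leans entirely on useVoiceRecorder, whose implementation is outside this excerpt. From the destructuring above, its surface is roughly the following (a sketch; the real hook in @/hooks/useVoiceRecorder may differ in details, and the error type is an inference from its use with toast.error):

interface UseVoiceRecorderOptions {
// Fired with the final transcript when recording ends (manual stop or server-side VAD silence detection).
onFinalTranscript?: (text: string) => void;
}

interface UseVoiceRecorderResult {
isRecording: boolean;
isProcessing: boolean; // awaiting the final transcript after stop
isMuted: boolean;
error: string | null; // surfaced via toast in MicrophoneButton
liveTranscript: string; // partial results streamed while speaking
startRecording: () => Promise<void>; // rejects if microphone access fails
stopRecording: () => Promise<string | null>; // resolves with the final transcript
setMuted: (muted: boolean) => void;
}

declare function useVoiceRecorder(
options: UseVoiceRecorderOptions
): UseVoiceRecorderResult;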

View File

@@ -0,0 +1,97 @@
"use client";
import { useEffect, useState, useMemo } from "react";
import { cn } from "@/lib/utils";
import { Button } from "@opal/components";
import { SvgMicrophone, SvgMicrophoneOff } from "@opal/icons";
interface RecordingWaveformProps {
isRecording: boolean;
isMuted?: boolean;
onMuteToggle?: () => void;
}
function RecordingWaveform({
isRecording,
isMuted = false,
onMuteToggle,
}: RecordingWaveformProps) {
const [elapsedSeconds, setElapsedSeconds] = useState(0);
// Count elapsed seconds while recording; reset the counter when recording stops
useEffect(() => {
if (!isRecording) {
setElapsedSeconds(0);
return;
}
const interval = setInterval(() => {
setElapsedSeconds((prev) => prev + 1);
}, 1000);
return () => clearInterval(interval);
}, [isRecording]);
// Format time as MM:SS
const formattedTime = useMemo(() => {
const minutes = Math.floor(elapsedSeconds / 60);
const seconds = elapsedSeconds % 60;
return `${minutes.toString().padStart(2, "0")}:${seconds
.toString()
.padStart(2, "0")}`;
}, [elapsedSeconds]);
// Generate bar heights in a fixed sine-wave pattern with staggered animation delays
const bars = useMemo(() => {
return Array.from({ length: 50 }, (_, i) => ({
id: i,
// Deterministic sine-based heights (range 2-14px); stable across renders via useMemo
baseHeight: Math.sin(i * 0.3) * 6 + 8,
delay: i * 0.02,
}));
}, []);
if (!isRecording) {
return null;
}
return (
<div className="flex items-center gap-3 px-3 py-2 bg-background-tint-00 rounded-12 min-h-[32px]">
{/* Waveform visualization */}
<div className="flex-1 flex items-center justify-center gap-[2px] h-4 overflow-hidden">
{bars.map((bar) => (
<div
key={bar.id}
className={cn(
"w-[2px] bg-text-03 rounded-full",
!isMuted && "animate-waveform"
)}
style={{
// When muted, show flat bars (2px height), otherwise animate with base height
height: isMuted ? "2px" : `${bar.baseHeight}px`,
animationDelay: isMuted ? undefined : `${bar.delay}s`,
}}
/>
))}
</div>
{/* Timer */}
<span className="font-mono text-xs text-text-03 tabular-nums shrink-0">
{formattedTime}
</span>
{/* Mute button */}
{onMuteToggle && (
<Button
icon={isMuted ? SvgMicrophoneOff : SvgMicrophone}
onClick={onMuteToggle}
prominence="tertiary"
size="sm"
aria-label={isMuted ? "Unmute microphone" : "Mute microphone"}
/>
)}
</div>
);
}
export default RecordingWaveform;
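The animate-waveform utility is not a stock Tailwind class, so the Tailwind config presumably defines matching keyframes. One plausible definition as a tailwind.config.ts fragment (an assumption; the actual keyframes and timing are not in this diff):

import type { Config } from "tailwindcss";

// Assumed keyframes: each bar pulses its height around the inline baseHeight,
// with the per-bar animationDelay producing the rolling-wave effect.
const config: Partial<Config> = {
theme: {
extend: {
keyframes: {
waveform: {
"0%, 100%": { transform: "scaleY(0.4)" },
"50%": { transform: "scaleY(1)" },
},
},
animation: {
waveform: "waveform 1.2s ease-in-out infinite",
},
},
},
};

export default config;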

View File

@@ -19,7 +19,35 @@ import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidE
import { CombinedSettings } from "@/interfaces/settings";
import SidebarTab from "@/refresh-components/buttons/SidebarTab";
import SidebarBody from "@/sections/sidebar/SidebarBody";
import { SvgArrowUpCircle } from "@opal/icons";
import {
SvgActions,
SvgActivity,
SvgArrowUpCircle,
SvgBarChart,
SvgCpu,
SvgFileText,
SvgFolder,
SvgGlobe,
SvgArrowExchange,
SvgImage,
SvgKey,
SvgOnyxLogo,
SvgOnyxOctagon,
SvgSearch,
SvgServer,
SvgSettings,
SvgShield,
SvgThumbsUp,
SvgUploadCloud,
SvgUser,
SvgUsers,
SvgZoomIn,
SvgPaintBrush,
SvgDiscordMono,
SvgWallet,
SvgAudio,
} from "@opal/icons";
import SvgMcp from "@opal/icons/mcp";
import { ADMIN_PATHS, sidebarItem } from "@/lib/admin-routes";
import UserAvatarPopover from "@/sections/sidebar/UserAvatarPopover";
@@ -105,6 +133,11 @@ const collections = (
sidebarItem(ADMIN_PATHS.LLM_MODELS),
sidebarItem(ADMIN_PATHS.WEB_SEARCH),
sidebarItem(ADMIN_PATHS.IMAGE_GENERATION),
{
name: "Voice",
icon: SvgAudio,
link: "/admin/configuration/voice",
},
sidebarItem(ADMIN_PATHS.CODE_INTERPRETER),
...(!enableCloud && vectorDbEnabled
? [
@@ -131,11 +164,7 @@ const collections = (
? [
{
name: "Permissions",
items: [
// TODO: Uncomment once Users v2 page is complete
// sidebarItem(ADMIN_PATHS.USERS_V2),
sidebarItem(ADMIN_PATHS.SCIM),
],
items: [sidebarItem(ADMIN_PATHS.SCIM)],
},
]
: []),
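Unlike its siblings, the new Voice entry is declared as an inline object rather than through the sidebarItem helper. If a VOICE entry were registered in ADMIN_PATHS (an assumption; no such entry appears in this diff), the item could follow the surrounding convention:

// Hypothetical: assumes admin-routes registers ADMIN_PATHS.VOICE with
// name "Voice", icon SvgAudio, and link "/admin/configuration/voice".
sidebarItem(ADMIN_PATHS.VOICE),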