mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-05 23:55:47 +00:00
Compare commits
17 Commits
nikg/admin
...
voice-mode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bba77749c3 | ||
|
|
3e9a66c8ff | ||
|
|
548b9d9e0e | ||
|
|
0d3967baee | ||
|
|
6ed806eebb | ||
|
|
3b6a35b2c4 | ||
|
|
62e612f85f | ||
|
|
b375b7f0ff | ||
|
|
c158ae2622 | ||
|
|
698494626f | ||
|
|
93cefe7ef0 | ||
|
|
8a326c4089 | ||
|
|
0c5410f429 | ||
|
|
0b05b9b235 | ||
|
|
59d8a988bd | ||
|
|
6d08cfb25a | ||
|
|
53a5ee2a6e |
@@ -0,0 +1,119 @@
|
||||
"""add_voice_provider_and_user_voice_prefs
|
||||
|
||||
Revision ID: 93a2e195e25c
|
||||
Revises: 631fd2504136
|
||||
Create Date: 2026-02-23 15:16:39.507304
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "93a2e195e25c"
|
||||
down_revision = "a3b8d9e2f1c4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the voice_provider table and per-user voice preference columns."""
    # Registry of configured voice backends: one row per STT/TTS provider.
    op.create_table(
        "voice_provider",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("name", sa.String(), unique=True, nullable=False),
        sa.Column("provider_type", sa.String(), nullable=False),
        sa.Column("api_key", sa.LargeBinary(), nullable=True),
        sa.Column("api_base", sa.String(), nullable=True),
        sa.Column("custom_config", postgresql.JSONB(), nullable=True),
        sa.Column("stt_model", sa.String(), nullable=True),
        sa.Column("tts_model", sa.String(), nullable=True),
        sa.Column("default_voice", sa.String(), nullable=True),
        sa.Column(
            "is_default_stt", sa.Boolean(), nullable=False, server_default="false"
        ),
        sa.Column(
            "is_default_tts", sa.Boolean(), nullable=False, server_default="false"
        ),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            onupdate=sa.func.now(),
            nullable=False,
        ),
    )

    # Partial unique indexes: at most one row may carry each default flag,
    # enforcing a single default STT and a single default TTS provider.
    op.execute(
        """
        CREATE UNIQUE INDEX ix_voice_provider_one_default_stt
        ON voice_provider (is_default_stt)
        WHERE is_default_stt = true
        """
    )
    op.execute(
        """
        CREATE UNIQUE INDEX ix_voice_provider_one_default_tts
        ON voice_provider (is_default_tts)
        WHERE is_default_tts = true
        """
    )

    # Per-user voice preference columns on the user table.
    user_voice_columns = [
        sa.Column(
            "voice_auto_send",
            sa.Boolean(),
            default=False,
            nullable=False,
            server_default="false",
        ),
        sa.Column(
            "voice_auto_playback",
            sa.Boolean(),
            default=False,
            nullable=False,
            server_default="false",
        ),
        sa.Column(
            "voice_playback_speed",
            sa.Float(),
            default=1.0,
            nullable=False,
            server_default="1.0",
        ),
        sa.Column("preferred_voice", sa.String(), nullable=True),
    ]
    for column in user_voice_columns:
        op.add_column("user", column)
||||
def downgrade() -> None:
    """Revert the voice preference columns and drop the voice_provider table."""
    # Drop user preference columns (reverse of the order they were added).
    for column_name in (
        "preferred_voice",
        "voice_playback_speed",
        "voice_auto_playback",
        "voice_auto_send",
    ):
        op.drop_column("user", column_name)

    # Remove the partial unique indexes before dropping the table.
    op.execute("DROP INDEX IF EXISTS ix_voice_provider_one_default_tts")
    op.execute("DROP INDEX IF EXISTS ix_voice_provider_one_default_stt")

    op.drop_table("voice_provider")
||||
@@ -28,6 +28,7 @@ from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi import WebSocket
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi_users import BaseUserManager
|
||||
@@ -1599,6 +1600,91 @@ async def current_admin_user(user: User = Depends(current_user)) -> User:
|
||||
return user
|
||||
|
||||
|
||||
async def _get_user_from_token_data(token_data: dict) -> User | None:
    """Resolve a decoded token payload to an active User.

    Args:
        token_data: Decoded token data; the user ID is read from 'sub'.

    Returns:
        The matching User if it exists and is active, otherwise None.
    """
    raw_user_id = token_data.get("sub")
    if not raw_user_id:
        return None

    try:
        parsed_user_id = uuid.UUID(raw_user_id)
    except ValueError:
        # 'sub' was not a well-formed UUID
        return None

    async with get_async_session_context_manager() as session:
        candidate = await session.get(User, parsed_user_id)
        if candidate is not None and candidate.is_active:
            return candidate
        return None
|
||||
|
||||
async def current_user_from_websocket(
    _websocket: WebSocket,
    token: str = Query(..., description="WebSocket authentication token"),
) -> User:
    """
    WebSocket authentication dependency using a query-parameter token.

    Validates the single-use WS token (obtained via POST /voice/ws-token,
    60-second TTL) and returns the authenticated User, applying the same
    checks as current_user() does for HTTP endpoints.

    Raises BasicAuthenticationError on any authentication failure.

    Usage:
        1. POST /voice/ws-token -> {"token": "xxx"}
        2. Connect to ws://host/path?token=xxx
    """
    from onyx.redis.redis_pool import retrieve_ws_token_data

    # Validate WS token in Redis (single-use, deleted after retrieval)
    try:
        token_data = await retrieve_ws_token_data(token)
    except Exception as e:
        logger.error(f"WS auth: error during token validation: {e}")
        raise BasicAuthenticationError(
            detail=f"Authentication verification failed: {str(e)}"
        ) from e

    if token_data is None:
        raise BasicAuthenticationError(
            detail="Access denied. Invalid or expired authentication token."
        )

    # Resolve the token's 'sub' claim to an active user
    resolved_user = await _get_user_from_token_data(token_data)
    if resolved_user is None:
        logger.warning(f"WS auth: user not found for id={token_data.get('sub')}")
        raise BasicAuthenticationError(
            detail="Access denied. User not found or inactive."
        )
    logger.info(f"WS auth: user found: {resolved_user.email}")

    # Apply same checks as HTTP auth (verification, OIDC expiry, role)
    resolved_user = await double_check_user(resolved_user)
    logger.info(
        f"WS auth: user verified: {resolved_user.email}, role={resolved_user.role}"
    )

    # Block LIMITED users (same as current_user)
    if resolved_user.role == UserRole.LIMITED:
        logger.warning(f"WS auth: user {resolved_user.email} has LIMITED role")
        raise BasicAuthenticationError(
            detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
        )

    logger.info(f"WS auth: authentication successful for {resolved_user.email}")
    return resolved_user
|
||||
|
||||
def get_default_admin_user_emails_() -> list[str]:
    """Return the seed admin e-mail list; Onyx MIT ships with none."""
    # No default seeding available for Onyx MIT
    return list()
||||
|
||||
@@ -284,6 +284,12 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
# organized in typical structured fashion
|
||||
# formatted as `displayName__provider__modelName`
|
||||
|
||||
# Voice preferences
|
||||
voice_auto_send: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
voice_auto_playback: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
voice_playback_speed: Mapped[float] = mapped_column(Float, default=1.0)
|
||||
preferred_voice: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# relationships
|
||||
credentials: Mapped[list["Credential"]] = relationship(
|
||||
"Credential", back_populates="user"
|
||||
@@ -2964,6 +2970,63 @@ class ImageGenerationConfig(Base):
|
||||
)
|
||||
|
||||
|
||||
class VoiceProvider(Base):
    """Configuration for voice services (STT and TTS)."""

    __tablename__ = "voice_provider"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Human-readable, unique name for this provider entry
    name: Mapped[str] = mapped_column(String, unique=True)
    provider_type: Mapped[str] = mapped_column(
        String
    )  # "openai", "azure", "elevenlabs"
    # API key is stored encrypted at rest via EncryptedString
    api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
        EncryptedString(), nullable=True
    )
    # Base URL override; also used to hold the target URI for Azure
    api_base: Mapped[str | None] = mapped_column(String, nullable=True)
    # Free-form provider-specific settings
    custom_config: Mapped[dict[str, Any] | None] = mapped_column(
        postgresql.JSONB(), nullable=True
    )

    # Model/voice configuration
    stt_model: Mapped[str | None] = mapped_column(
        String, nullable=True
    )  # e.g., "whisper-1"
    tts_model: Mapped[str | None] = mapped_column(
        String, nullable=True
    )  # e.g., "tts-1", "tts-1-hd"
    default_voice: Mapped[str | None] = mapped_column(
        String, nullable=True
    )  # e.g., "alloy", "echo"

    # STT and TTS can use different providers - only one provider per type
    is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )

    # Enforce only one default STT provider and one default TTS provider at DB level
    # (partial unique indexes; these mirror the ones created by the Alembic migration)
    __table_args__ = (
        Index(
            "ix_voice_provider_one_default_stt",
            "is_default_stt",
            unique=True,
            postgresql_where=(is_default_stt == True),  # noqa: E712
        ),
        Index(
            "ix_voice_provider_one_default_tts",
            "is_default_tts",
            unique=True,
            postgresql_where=(is_default_tts == True),  # noqa: E712
        ),
    )
|
||||
|
||||
class CloudEmbeddingProvider(Base):
|
||||
__tablename__ = "embedding_provider"
|
||||
|
||||
|
||||
257
backend/onyx/db/voice.py
Normal file
257
backend/onyx/db/voice.py
Normal file
@@ -0,0 +1,257 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import VoiceProvider
|
||||
|
||||
|
||||
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
    """Fetch all voice providers, ordered by name."""
    stmt = select(VoiceProvider).order_by(VoiceProvider.name)
    return list(db_session.scalars(stmt).all())


def fetch_voice_provider_by_id(
    db_session: Session, provider_id: int
) -> VoiceProvider | None:
    """Fetch a voice provider by ID."""
    stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
    return db_session.scalar(stmt)


def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
    """Fetch the default STT provider."""
    stmt = select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
    return db_session.scalar(stmt)


def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
    """Fetch the default TTS provider."""
    stmt = select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
    return db_session.scalar(stmt)


def fetch_voice_provider_by_type(
    db_session: Session, provider_type: str
) -> VoiceProvider | None:
    """Fetch a voice provider by type."""
    stmt = select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
    return db_session.scalar(stmt)
||||
|
||||
|
||||
def upsert_voice_provider(
    *,
    db_session: Session,
    provider_id: int | None,
    name: str,
    provider_type: str,
    api_key: str | None,
    api_key_changed: bool,
    api_base: str | None = None,
    custom_config: dict[str, Any] | None = None,
    stt_model: str | None = None,
    tts_model: str | None = None,
    default_voice: str | None = None,
    activate_stt: bool = False,
    activate_tts: bool = False,
) -> VoiceProvider:
    """Create or update a voice provider.

    Raises:
        ValueError: if provider_id is given but no such provider exists.
    """
    if provider_id is None:
        # New provider
        provider = VoiceProvider()
        db_session.add(provider)
    else:
        existing = fetch_voice_provider_by_id(db_session, provider_id)
        if existing is None:
            raise ValueError(f"No voice provider with id {provider_id} exists.")
        provider = existing

    # Apply updates
    provider.name = name
    provider.provider_type = provider_type
    provider.api_base = api_base
    provider.custom_config = custom_config
    provider.stt_model = stt_model
    provider.tts_model = tts_model
    provider.default_voice = default_voice

    # Only update API key if explicitly changed or if provider has no key
    if api_key_changed or provider.api_key is None:
        provider.api_key = api_key  # type: ignore[assignment]

    db_session.flush()

    # Optionally promote this provider to the default STT/TTS slot
    if activate_stt:
        set_default_stt_provider(db_session=db_session, provider_id=provider.id)
    if activate_tts:
        set_default_tts_provider(db_session=db_session, provider_id=provider.id)

    db_session.refresh(provider)
    return provider
||||
|
||||
|
||||
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
    """Delete a voice provider by ID (no-op when the ID is unknown)."""
    provider = fetch_voice_provider_by_id(db_session, provider_id)
    if provider is None:
        return
    db_session.delete(provider)
    db_session.flush()
||||
|
||||
|
||||
def set_default_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
    """Set a voice provider as the default STT provider.

    Raises:
        ValueError: if no provider with the given ID exists.
    """
    provider = fetch_voice_provider_by_id(db_session, provider_id)
    if provider is None:
        raise ValueError(f"No voice provider with id {provider_id} exists.")

    # Demote every other STT default first so the partial unique index holds
    demote_stmt = (
        update(VoiceProvider)
        .where(
            VoiceProvider.is_default_stt.is_(True),
            VoiceProvider.id != provider_id,
        )
        .values(is_default_stt=False)
    )
    db_session.execute(demote_stmt)

    provider.is_default_stt = True
    db_session.flush()
    db_session.refresh(provider)
    return provider


def set_default_tts_provider(
    *, db_session: Session, provider_id: int, tts_model: str | None = None
) -> VoiceProvider:
    """Set a voice provider as the default TTS provider.

    Optionally updates its TTS model at the same time.

    Raises:
        ValueError: if no provider with the given ID exists.
    """
    provider = fetch_voice_provider_by_id(db_session, provider_id)
    if provider is None:
        raise ValueError(f"No voice provider with id {provider_id} exists.")

    # Demote every other TTS default first so the partial unique index holds
    demote_stmt = (
        update(VoiceProvider)
        .where(
            VoiceProvider.is_default_tts.is_(True),
            VoiceProvider.id != provider_id,
        )
        .values(is_default_tts=False)
    )
    db_session.execute(demote_stmt)

    provider.is_default_tts = True
    if tts_model is not None:
        provider.tts_model = tts_model

    db_session.flush()
    db_session.refresh(provider)
    return provider
||||
|
||||
|
||||
def deactivate_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
    """Remove the default STT status from a voice provider.

    Raises:
        ValueError: if no provider with the given ID exists.
    """
    provider = fetch_voice_provider_by_id(db_session, provider_id)
    if provider is None:
        raise ValueError(f"No voice provider with id {provider_id} exists.")
    provider.is_default_stt = False
    db_session.flush()
    db_session.refresh(provider)
    return provider


def deactivate_tts_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
    """Remove the default TTS status from a voice provider.

    Raises:
        ValueError: if no provider with the given ID exists.
    """
    provider = fetch_voice_provider_by_id(db_session, provider_id)
    if provider is None:
        raise ValueError(f"No voice provider with id {provider_id} exists.")
    provider.is_default_tts = False
    db_session.flush()
    db_session.refresh(provider)
    return provider
||||
|
||||
|
||||
# User voice preferences


def update_user_voice_auto_send(
    db_session: Session, user_id: UUID, auto_send: bool
) -> None:
    """Update user's voice auto-send setting."""
    stmt = update(User).where(User.id == user_id).values(voice_auto_send=auto_send)  # type: ignore[arg-type]
    db_session.execute(stmt)
    db_session.commit()


def update_user_voice_auto_playback(
    db_session: Session, user_id: UUID, auto_playback: bool
) -> None:
    """Update user's voice auto-playback setting."""
    stmt = update(User).where(User.id == user_id).values(voice_auto_playback=auto_playback)  # type: ignore[arg-type]
    db_session.execute(stmt)
    db_session.commit()


def update_user_voice_playback_speed(
    db_session: Session, user_id: UUID, speed: float
) -> None:
    """Update user's voice playback speed, clamped to the supported [0.5, 2.0] range."""
    clamped_speed = min(2.0, max(0.5, speed))
    stmt = update(User).where(User.id == user_id).values(voice_playback_speed=clamped_speed)  # type: ignore[arg-type]
    db_session.execute(stmt)
    db_session.commit()


def update_user_preferred_voice(
    db_session: Session, user_id: UUID, voice: str | None
) -> None:
    """Update user's preferred voice setting (None clears it)."""
    stmt = update(User).where(User.id == user_id).values(preferred_voice=voice)  # type: ignore[arg-type]
    db_session.execute(stmt)
    db_session.commit()
||||
|
||||
|
||||
def update_user_voice_settings(
    db_session: Session,
    user_id: UUID,
    auto_send: bool | None = None,
    auto_playback: bool | None = None,
    playback_speed: float | None = None,
    preferred_voice: str | None = None,
) -> None:
    """Update user's voice settings. Only updates fields that are not None."""
    updates: dict[str, Any] = {}

    if auto_send is not None:
        updates["voice_auto_send"] = auto_send
    if auto_playback is not None:
        updates["voice_auto_playback"] = auto_playback
    if playback_speed is not None:
        # Clamp to the supported [0.5, 2.0] range
        updates["voice_playback_speed"] = min(2.0, max(0.5, playback_speed))
    if preferred_voice is not None:
        updates["preferred_voice"] = preferred_voice

    if not updates:
        # Nothing to change — avoid an empty UPDATE
        return
    db_session.execute(update(User).where(User.id == user_id).values(**updates))  # type: ignore[arg-type]
    db_session.commit()
|
||||
@@ -119,6 +119,9 @@ from onyx.server.manage.opensearch_migration.api import (
|
||||
from onyx.server.manage.search_settings import router as search_settings_router
|
||||
from onyx.server.manage.slack_bot import router as slack_bot_management_router
|
||||
from onyx.server.manage.users import router as user_router
|
||||
from onyx.server.manage.voice.api import admin_router as voice_admin_router
|
||||
from onyx.server.manage.voice.user_api import router as voice_router
|
||||
from onyx.server.manage.voice.websocket_api import router as voice_websocket_router
|
||||
from onyx.server.manage.web_search.api import (
|
||||
admin_router as web_search_admin_router,
|
||||
)
|
||||
@@ -497,6 +500,9 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, embedding_router)
|
||||
include_router_with_global_prefix_prepended(application, web_search_router)
|
||||
include_router_with_global_prefix_prepended(application, web_search_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_websocket_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, opensearch_migration_admin_router
|
||||
)
|
||||
|
||||
@@ -419,12 +419,15 @@ async def get_async_redis_connection() -> aioredis.Redis:
|
||||
return _async_redis_connection
|
||||
|
||||
|
||||
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
|
||||
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
if not token:
|
||||
logger.debug("No auth token cookie found")
|
||||
return None
|
||||
async def retrieve_auth_token_data(token: str) -> dict | None:
|
||||
"""Validate auth token against Redis and return token data.
|
||||
|
||||
Args:
|
||||
token: The raw authentication token string.
|
||||
|
||||
Returns:
|
||||
Token data dict if valid, None if invalid/expired.
|
||||
"""
|
||||
try:
|
||||
redis = await get_async_redis_connection()
|
||||
redis_key = REDIS_AUTH_KEY_PREFIX + token
|
||||
@@ -439,12 +442,65 @@ async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
|
||||
logger.error("Error decoding token data from Redis")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
|
||||
)
|
||||
logger.error(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
|
||||
raise ValueError(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
|
||||
|
||||
|
||||
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
    """Validate auth token from request cookie. Wrapper for backwards compatibility."""
    cookie_token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
    if cookie_token:
        return await retrieve_auth_token_data(cookie_token)
    logger.debug("No auth token cookie found")
    return None
||||
|
||||
|
||||
# WebSocket token prefix (separate from regular auth tokens)
|
||||
REDIS_WS_TOKEN_PREFIX = "ws_token:"
|
||||
# WebSocket tokens expire after 60 seconds
|
||||
WS_TOKEN_TTL_SECONDS = 60
|
||||
|
||||
|
||||
async def store_ws_token(token: str, user_id: str) -> None:
    """Store a short-lived WebSocket authentication token in Redis.

    The token expires after WS_TOKEN_TTL_SECONDS.

    Args:
        token: The generated WS token.
        user_id: The user ID to associate with this token.
    """
    redis = await get_async_redis_connection()
    payload = json.dumps({"sub": user_id})
    await redis.set(REDIS_WS_TOKEN_PREFIX + token, payload, ex=WS_TOKEN_TTL_SECONDS)
||||
|
||||
|
||||
async def retrieve_ws_token_data(token: str) -> dict | None:
    """Validate a WebSocket token and return the token data.

    Args:
        token: The WS token to validate.

    Returns:
        Token data dict with 'sub' (user ID) if valid, None if invalid/expired.
    """
    try:
        redis = await get_async_redis_connection()
        redis_key = REDIS_WS_TOKEN_PREFIX + token
        raw_token_data = await redis.get(redis_key)

        if not raw_token_data:
            return None

        # Delete the token after retrieval (single-use).
        # NOTE(review): GET followed by DELETE is not atomic — two concurrent
        # requests could both redeem the same token in the window between the
        # two calls; consider GETDEL if the Redis version supports it.
        await redis.delete(redis_key)

        return json.loads(raw_token_data)
    except json.JSONDecodeError:
        logger.error("Error decoding WS token data from Redis")
        return None
    except Exception as e:
        logger.error(f"Unexpected error in retrieve_ws_token_data: {str(e)}")
        return None
||||
|
||||
|
||||
def redis_lock_dump(lock: RedisLock, r: Redis) -> None:
|
||||
|
||||
@@ -9,6 +9,7 @@ from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import current_user_from_websocket
|
||||
from onyx.auth.users import current_user_with_expired_token
|
||||
from onyx.configs.app_configs import APP_API_PREFIX
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
@@ -129,6 +130,7 @@ def check_router_auth(
|
||||
or depends_fn == current_curator_or_admin_user
|
||||
or depends_fn == current_user_with_expired_token
|
||||
or depends_fn == current_chat_accessible_user
|
||||
or depends_fn == current_user_from_websocket
|
||||
or depends_fn == control_plane_dep
|
||||
or depends_fn == current_cloud_superuser
|
||||
or depends_fn == verify_scim_token
|
||||
|
||||
@@ -85,6 +85,12 @@ class UserPreferences(BaseModel):
|
||||
chat_background: str | None = None
|
||||
default_app_mode: DefaultAppMode = DefaultAppMode.CHAT
|
||||
|
||||
# Voice preferences
|
||||
voice_auto_send: bool | None = None
|
||||
voice_auto_playback: bool | None = None
|
||||
voice_playback_speed: float | None = None
|
||||
preferred_voice: str | None = None
|
||||
|
||||
# controls which tools are enabled for the user for a specific assistant
|
||||
assistant_specific_configs: UserSpecificAssistantPreferences | None = None
|
||||
|
||||
@@ -164,6 +170,10 @@ class UserInfo(BaseModel):
|
||||
theme_preference=user.theme_preference,
|
||||
chat_background=user.chat_background,
|
||||
default_app_mode=user.default_app_mode,
|
||||
voice_auto_send=user.voice_auto_send,
|
||||
voice_auto_playback=user.voice_auto_playback,
|
||||
voice_playback_speed=user.voice_playback_speed,
|
||||
preferred_voice=user.preferred_voice,
|
||||
assistant_specific_configs=assistant_specific_configs,
|
||||
)
|
||||
),
|
||||
@@ -240,6 +250,13 @@ class ChatBackgroundRequest(BaseModel):
|
||||
chat_background: str | None
|
||||
|
||||
|
||||
class VoiceSettingsUpdateRequest(BaseModel):
    # Partial-update payload for a user's voice settings; None fields are
    # left unchanged by the handler.
    auto_send: bool | None = None
    auto_playback: bool | None = None
    # Validated to the supported playback range [0.5, 2.0]
    playback_speed: float | None = Field(default=None, ge=0.5, le=2.0)
    preferred_voice: str | None = None
||||
|
||||
class PersonalizationUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
role: str | None = None
|
||||
|
||||
0
backend/onyx/server/manage/voice/__init__.py
Normal file
0
backend/onyx/server/manage/voice/__init__.py
Normal file
282
backend/onyx/server/manage/voice/api.py
Normal file
282
backend/onyx/server/manage/voice/api.py
Normal file
@@ -0,0 +1,282 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.db.voice import deactivate_stt_provider
|
||||
from onyx.db.voice import deactivate_tts_provider
|
||||
from onyx.db.voice import delete_voice_provider
|
||||
from onyx.db.voice import fetch_voice_provider_by_id
|
||||
from onyx.db.voice import fetch_voice_provider_by_type
|
||||
from onyx.db.voice import fetch_voice_providers
|
||||
from onyx.db.voice import set_default_stt_provider
|
||||
from onyx.db.voice import set_default_tts_provider
|
||||
from onyx.db.voice import upsert_voice_provider
|
||||
from onyx.server.manage.voice.models import VoiceProviderTestRequest
|
||||
from onyx.server.manage.voice.models import VoiceProviderUpsertRequest
|
||||
from onyx.server.manage.voice.models import VoiceProviderView
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/voice")
|
||||
|
||||
|
||||
def _provider_to_view(provider: VoiceProvider) -> VoiceProviderView:
    """Convert a VoiceProvider model to a VoiceProviderView.

    The raw API key is never exposed — only a has_api_key flag.
    """
    view = VoiceProviderView(
        id=provider.id,
        name=provider.name,
        provider_type=provider.provider_type,
        is_default_stt=provider.is_default_stt,
        is_default_tts=provider.is_default_tts,
        stt_model=provider.stt_model,
        tts_model=provider.tts_model,
        default_voice=provider.default_voice,
        has_api_key=bool(provider.api_key),
        target_uri=provider.api_base,  # api_base stores the target URI for Azure
    )
    return view
||||
|
||||
|
||||
@admin_router.get("/providers")
def list_voice_providers(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[VoiceProviderView]:
    """List all configured voice providers."""
    return [_provider_to_view(p) for p in fetch_voice_providers(db_session)]
||||
|
||||
|
||||
@admin_router.post("/providers")
def upsert_voice_provider_endpoint(
    request: VoiceProviderUpsertRequest,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> VoiceProviderView:
    """Create or update a voice provider."""
    api_key = request.api_key
    api_key_changed = request.api_key_changed

    # If llm_provider_id is specified, copy the API key from that LLM provider
    if request.llm_provider_id is not None:
        llm_provider = db_session.get(LLMProviderModel, request.llm_provider_id)
        if llm_provider is None:
            raise HTTPException(
                status_code=404,
                detail=f"LLM provider with id {request.llm_provider_id} not found.",
            )
        if llm_provider.api_key is None:
            raise HTTPException(
                status_code=400,
                detail="Selected LLM provider has no API key configured.",
            )
        api_key = llm_provider.api_key.get_value(apply_mask=False)
        api_key_changed = True

    provider = upsert_voice_provider(
        db_session=db_session,
        provider_id=request.id,
        name=request.name,
        provider_type=request.provider_type,
        api_key=api_key,
        api_key_changed=api_key_changed,
        # Use target_uri if provided, otherwise fall back to api_base
        api_base=request.target_uri or request.api_base,
        custom_config=request.custom_config,
        stt_model=request.stt_model,
        tts_model=request.tts_model,
        default_voice=request.default_voice,
        activate_stt=request.activate_stt,
        activate_tts=request.activate_tts,
    )
    db_session.commit()
    return _provider_to_view(provider)
||||
|
||||
|
||||
@admin_router.delete(
    "/providers/{provider_id}", status_code=204, response_class=Response
)
def delete_voice_provider_endpoint(
    provider_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> Response:
    """Delete a voice provider. Idempotent: unknown IDs still return 204."""
    delete_voice_provider(db_session, provider_id)
    db_session.commit()
    return Response(status_code=204)
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/activate-stt")
def activate_stt_provider_endpoint(
    provider_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> VoiceProviderView:
    """Mark the given provider as the default speech-to-text provider.

    Admin-only. Commits the change and returns the updated provider view.
    """
    updated = set_default_stt_provider(
        db_session=db_session, provider_id=provider_id
    )
    db_session.commit()
    return _provider_to_view(updated)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/deactivate-stt")
def deactivate_stt_provider_endpoint(
    provider_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict[str, str]:
    """Clear the default-STT flag from the given voice provider.

    Admin-only. Commits the change and returns a simple status payload.
    """
    deactivate_stt_provider(db_session=db_session, provider_id=provider_id)
    db_session.commit()
    return {"status": "ok"}
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/activate-tts")
def activate_tts_provider_endpoint(
    provider_id: int,
    tts_model: str | None = None,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> VoiceProviderView:
    """Mark the given provider as the default text-to-speech provider.

    An optional ``tts_model`` override is forwarded to the DB helper.
    Admin-only. Commits the change and returns the updated provider view.
    """
    updated = set_default_tts_provider(
        db_session=db_session,
        provider_id=provider_id,
        tts_model=tts_model,
    )
    db_session.commit()
    return _provider_to_view(updated)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/deactivate-tts")
def deactivate_tts_provider_endpoint(
    provider_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict[str, str]:
    """Clear the default-TTS flag from the given voice provider.

    Admin-only. Commits the change and returns a simple status payload.
    """
    deactivate_tts_provider(db_session=db_session, provider_id=provider_id)
    db_session.commit()
    return {"status": "ok"}
|
||||
|
||||
|
||||
@admin_router.post("/providers/test")
def test_voice_provider(
    request: VoiceProviderTestRequest,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict[str, str]:
    """Test a voice provider connection without persisting anything.

    Resolves an API key (either supplied in the request or copied from the
    stored provider of the same type when ``use_stored_key`` is set), builds a
    throwaway ``VoiceProvider``, and performs a lightweight connectivity check
    by listing available voices.

    Raises:
        HTTPException(400): missing/unresolvable API key, bad provider config,
            empty voice list, or a failed connection attempt.
    """
    api_key = request.api_key

    if request.use_stored_key:
        existing_provider = fetch_voice_provider_by_type(
            db_session, request.provider_type
        )
        if existing_provider is None or not existing_provider.api_key:
            raise HTTPException(
                status_code=400,
                detail="No stored API key found for this provider type.",
            )
        # Unmask the stored secret so the real key is used for the test.
        api_key = existing_provider.api_key.get_value(apply_mask=False)

    if not api_key:
        raise HTTPException(
            status_code=400,
            detail="API key is required. Either provide api_key or set use_stored_key to true.",
        )

    # Use target_uri if provided, otherwise fall back to api_base
    api_base = request.target_uri or request.api_base

    # Create a temporary VoiceProvider for testing (not saved to DB)
    temp_provider = VoiceProvider(
        name="__test__",
        provider_type=request.provider_type,
        api_base=api_base,
        custom_config=request.custom_config or {},
    )
    temp_provider.api_key = api_key  # type: ignore[assignment]

    try:
        provider = get_voice_provider(temp_provider)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Test the provider by getting available voices (lightweight check)
    try:
        voices = provider.get_available_voices()
        if not voices:
            raise HTTPException(
                status_code=400,
                detail="Provider returned no available voices.",
            )
    except NotImplementedError:
        # Provider not fully implemented yet (Azure, ElevenLabs placeholders)
        pass
    except HTTPException:
        # Bug fix: without this, the "no available voices" HTTPException above
        # was caught by the broad handler below and re-wrapped as
        # "Connection test failed: 400: ..." — re-raise it unchanged instead.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Connection test failed: {str(e)}",
        ) from e

    logger.info(f"Voice provider test succeeded for {request.provider_type}.")
    return {"status": "ok"}
|
||||
|
||||
|
||||
@admin_router.get("/providers/{provider_id}/voices")
def get_provider_voices(
    provider_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[dict[str, str]]:
    """Return the list of voices offered by a configured provider.

    Raises:
        HTTPException(404): no provider with the given id exists.
        HTTPException(400): provider has no API key, or its configuration
            is rejected when constructing the provider client.
    """
    record = fetch_voice_provider_by_id(db_session, provider_id)
    if record is None:
        raise HTTPException(status_code=404, detail="Voice provider not found.")

    if not record.api_key:
        raise HTTPException(
            status_code=400, detail="Provider has no API key configured."
        )

    try:
        client = get_voice_provider(record)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    return client.get_available_voices()
|
||||
|
||||
|
||||
@admin_router.get("/voices")
def get_voices_by_type(
    provider_type: str,
    _: User = Depends(current_admin_user),
) -> list[dict[str, str]]:
    """Get available voices for a provider type.

    For providers like ElevenLabs and OpenAI, this fetches voices
    without requiring an existing provider configuration.
    """
    # A throwaway VoiceProvider is enough to resolve the static voice list.
    placeholder = VoiceProvider(
        name="__temp__",
        provider_type=provider_type,
    )

    try:
        client = get_voice_provider(placeholder)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    return client.get_available_voices()
|
||||
90
backend/onyx/server/manage/voice/models.py
Normal file
90
backend/onyx/server/manage/voice/models.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class VoiceProviderView(BaseModel):
    """Response model for voice provider listing.

    Exposes only non-secret configuration; the API key itself is never
    returned, only the `has_api_key` flag.
    """

    id: int
    name: str
    provider_type: str  # "openai", "azure", "elevenlabs"
    # True when this provider is the default speech-to-text provider.
    is_default_stt: bool
    # True when this provider is the default text-to-speech provider.
    is_default_tts: bool
    stt_model: str | None
    tts_model: str | None
    # Voice id used for synthesis when neither request nor user specifies one.
    default_voice: str | None
    has_api_key: bool = Field(
        default=False,
        description="Indicates whether an API key is stored for this provider.",
    )
    target_uri: str | None = Field(
        default=None,
        description="Target URI for Azure Speech Services.",
    )
|
||||
|
||||
|
||||
class VoiceProviderUpsertRequest(BaseModel):
    """Request model for creating or updating a voice provider.

    When `id` is omitted a new provider is created; otherwise the existing
    row is updated. Exactly one API-key source is expected: an explicit
    `api_key` (with `api_key_changed` set for updates) or `llm_provider_id`
    to copy the key from an existing LLM provider.
    """

    id: int | None = Field(default=None, description="Existing provider ID to update.")
    name: str
    provider_type: str  # "openai", "azure", "elevenlabs"
    api_key: str | None = Field(
        default=None,
        description="API key for the provider.",
    )
    api_key_changed: bool = Field(
        default=False,
        description="Set to true when providing a new API key for an existing provider.",
    )
    llm_provider_id: int | None = Field(
        default=None,
        description="If set, copies the API key from the specified LLM provider.",
    )
    api_base: str | None = None
    # target_uri takes precedence over api_base when both are supplied.
    target_uri: str | None = Field(
        default=None,
        description="Target URI for Azure Speech Services (maps to api_base).",
    )
    custom_config: dict[str, Any] | None = None
    stt_model: str | None = None
    tts_model: str | None = None
    default_voice: str | None = None
    activate_stt: bool = Field(
        default=False,
        description="If true, sets this provider as the default STT provider after upsert.",
    )
    activate_tts: bool = Field(
        default=False,
        description="If true, sets this provider as the default TTS provider after upsert.",
    )
|
||||
|
||||
|
||||
class VoiceProviderTestRequest(BaseModel):
    """Request model for testing a voice provider connection.

    The test needs a usable key: either `api_key` directly, or
    `use_stored_key=True` to reuse the key already stored for this
    provider type.
    """

    provider_type: str
    api_key: str | None = Field(
        default=None,
        description="API key for testing. If not provided, use_stored_key must be true.",
    )
    use_stored_key: bool = Field(
        default=False,
        description="If true, use the stored API key for this provider type.",
    )
    api_base: str | None = None
    # target_uri takes precedence over api_base when both are supplied.
    target_uri: str | None = Field(
        default=None,
        description="Target URI for Azure Speech Services (maps to api_base).",
    )
    custom_config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SynthesizeRequest(BaseModel):
    """Request model for text-to-speech synthesis."""

    # Text length is capped at 4096 characters; empty text is rejected.
    text: str = Field(..., min_length=1, max_length=4096)
    # Voice id; None falls back to user/provider defaults downstream.
    voice: str | None = None
    # Playback speed multiplier, clamped to the 0.5x–2.0x range.
    speed: float = Field(default=1.0, ge=0.5, le=2.0)
|
||||
207
backend/onyx/server/manage/voice/user_api.py
Normal file
207
backend/onyx/server/manage/voice/user_api.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import secrets
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
from onyx.db.models import User
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.db.voice import update_user_voice_settings
|
||||
from onyx.redis.redis_pool import store_ws_token
|
||||
from onyx.server.manage.models import VoiceSettingsUpdateRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/voice")
|
||||
|
||||
# Max audio file size: 25MB (Whisper limit)
|
||||
MAX_AUDIO_SIZE = 25 * 1024 * 1024
|
||||
|
||||
|
||||
@router.post("/transcribe")
async def transcribe_audio(
    audio: UploadFile = File(...),
    _: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> dict[str, str]:
    """Transcribe an uploaded audio file with the default STT provider.

    Returns:
        ``{"text": <transcript>}`` on success.

    Raises:
        HTTPException(400): no STT provider configured, provider lacks an
            API key, or the upload exceeds the size cap.
        HTTPException(500): provider construction or transcription failed.
        HTTPException(501): the provider has no STT implementation.
    """
    stt_config = fetch_default_stt_provider(db_session)
    if stt_config is None:
        raise HTTPException(
            status_code=400,
            detail="No speech-to-text provider configured. Please contact your administrator.",
        )
    if not stt_config.api_key:
        raise HTTPException(
            status_code=400,
            detail="Voice provider API key not configured.",
        )

    payload = await audio.read()
    if len(payload) > MAX_AUDIO_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"Audio file too large. Maximum size is {MAX_AUDIO_SIZE // (1024 * 1024)}MB.",
        )

    # Derive the container format from the uploaded filename's extension.
    upload_name = audio.filename or "audio.webm"
    fmt = "webm" if "." not in upload_name else upload_name.rsplit(".", 1)[-1]

    try:
        stt = get_voice_provider(stt_config)
    except ValueError as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    try:
        transcript = await stt.transcribe(payload, fmt)
    except NotImplementedError as exc:
        raise HTTPException(
            status_code=501,
            detail=f"Speech-to-text not implemented for {stt_config.provider_type}.",
        ) from exc
    except Exception as exc:
        logger.error(f"Transcription failed: {exc}")
        raise HTTPException(
            status_code=500,
            detail=f"Transcription failed: {str(exc)}",
        ) from exc
    return {"text": transcript}
|
||||
|
||||
|
||||
@router.post("/synthesize")
async def synthesize_speech(
    text: str | None = Query(
        default=None, description="Text to synthesize", max_length=4096
    ),
    voice: str | None = Query(default=None, description="Voice ID to use"),
    speed: float | None = Query(
        default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
    ),
    user: User = Depends(current_user),
) -> StreamingResponse:
    """
    Synthesize text to speech using the default TTS provider.

    Accepts parameters via query string for streaming compatibility.

    Provider configuration is resolved inside a short-lived session so the
    DB connection is released before the (potentially long) audio stream
    starts; the returned StreamingResponse yields raw audio bytes.
    """
    logger.info(
        f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
    )

    if not text:
        raise HTTPException(status_code=400, detail="Text is required")

    # Use short-lived session to fetch provider config, then release connection
    # before starting the long-running streaming response
    engine = get_sqlalchemy_engine()
    with Session(engine) as db_session:
        provider_db = fetch_default_tts_provider(db_session)
        if provider_db is None:
            logger.error("No TTS provider configured")
            raise HTTPException(
                status_code=400,
                detail="No text-to-speech provider configured. Please contact your administrator.",
            )

        if not provider_db.api_key:
            logger.error("TTS provider has no API key")
            raise HTTPException(
                status_code=400,
                detail="Voice provider API key not configured.",
            )

        # Use request voice, or user's preferred voice, or provider default
        final_voice = voice or user.preferred_voice or provider_db.default_voice
        # NOTE: `or` treats a stored playback speed of 0 (or None) as unset,
        # so it falls back to 1.0.
        final_speed = speed or user.voice_playback_speed or 1.0

        logger.info(
            f"TTS using provider: {provider_db.provider_type}, voice: {final_voice}, speed: {final_speed}"
        )

        try:
            provider = get_voice_provider(provider_db)
        except ValueError as exc:
            logger.error(f"Failed to get voice provider: {exc}")
            raise HTTPException(status_code=500, detail=str(exc)) from exc

    # Session is now closed - streaming response won't hold DB connection
    async def audio_stream() -> AsyncIterator[bytes]:
        # Generator captures `provider`/`final_voice`/`final_speed`; it must
        # not touch the (closed) db_session.
        try:
            chunk_count = 0
            async for chunk in provider.synthesize_stream(
                text=text, voice=final_voice, speed=final_speed
            ):
                chunk_count += 1
                yield chunk
            logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
        except NotImplementedError as exc:
            logger.error(f"TTS not implemented: {exc}")
            raise
        except Exception as exc:
            logger.error(f"Synthesis failed: {exc}")
            raise

    # media_type assumes the provider emits MP3 — TODO confirm for each
    # provider implementation.
    return StreamingResponse(
        audio_stream(),
        media_type="audio/mpeg",
        headers={
            "Content-Disposition": "inline; filename=speech.mp3",
            # Allow streaming by not setting content-length
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
        },
    )
|
||||
|
||||
|
||||
@router.patch("/settings")
def update_voice_settings(
    request: VoiceSettingsUpdateRequest,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> dict[str, str]:
    """Persist the calling user's voice preferences.

    Forwards auto-send, auto-playback, playback speed, and preferred-voice
    settings to the DB helper and returns a simple status payload.
    """
    update_user_voice_settings(
        db_session=db_session,
        user_id=user.id,
        auto_send=request.auto_send,
        auto_playback=request.auto_playback,
        playback_speed=request.playback_speed,
        preferred_voice=request.preferred_voice,
    )
    return {"status": "ok"}
|
||||
|
||||
|
||||
class WSTokenResponse(BaseModel):
    """Response wrapper for a short-lived WebSocket auth token."""

    token: str
|
||||
|
||||
|
||||
@router.post("/ws-token")
async def get_ws_token(
    user: User = Depends(current_user),
) -> WSTokenResponse:
    """
    Generate a short-lived token for WebSocket authentication.

    This token should be passed as a query parameter when connecting
    to voice WebSocket endpoints (e.g., /voice/transcribe/stream?token=xxx).

    The token expires after 60 seconds and is single-use.
    """
    # 32 random bytes via `secrets` — cryptographically suitable for auth.
    ws_token = secrets.token_urlsafe(32)
    await store_ws_token(ws_token, str(user.id))
    return WSTokenResponse(token=ws_token)
|
||||
778
backend/onyx/server/manage/voice/websocket_api.py
Normal file
778
backend/onyx/server/manage/voice/websocket_api.py
Normal file
@@ -0,0 +1,778 @@
|
||||
"""WebSocket API for streaming speech-to-text and text-to-speech."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import WebSocket
|
||||
from fastapi import WebSocketDisconnect
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user_from_websocket
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
from onyx.db.models import User
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/voice")
|
||||
|
||||
|
||||
# Transcribe every ~0.5 seconds of audio (webm/opus is ~2-4KB/s, so ~1-2KB per 0.5s)
|
||||
MIN_CHUNK_BYTES = 1500
|
||||
VOICE_DISABLE_STREAMING_FALLBACK = (
|
||||
os.environ.get("VOICE_DISABLE_STREAMING_FALLBACK", "").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
class ChunkedTranscriber:
    """Fallback transcriber for providers without streaming support.

    Buffers incoming audio and runs a batch transcription roughly every
    MIN_CHUNK_BYTES of input, while also retaining the complete audio so a
    final, more accurate pass can be done in flush().
    """

    def __init__(self, provider: Any, audio_format: str = "webm"):
        self.provider = provider
        self.audio_format = audio_format
        self.chunk_buffer = io.BytesIO()
        self.full_audio = io.BytesIO()
        self.chunk_bytes = 0
        self.transcripts: list[str] = []

    async def add_chunk(self, chunk: bytes) -> str | None:
        """Add audio chunk. Returns transcript if enough audio accumulated."""
        self.chunk_buffer.write(chunk)
        self.full_audio.write(chunk)
        self.chunk_bytes += len(chunk)
        if self.chunk_bytes < MIN_CHUNK_BYTES:
            return None
        return await self._transcribe_chunk()

    async def _transcribe_chunk(self) -> str | None:
        """Transcribe current chunk and append to running transcript."""
        audio_data = self.chunk_buffer.getvalue()
        if not audio_data:
            return None
        try:
            transcript = await self.provider.transcribe(audio_data, self.audio_format)
        except Exception as e:
            logger.error(f"Transcription error: {e}")
            return None
        finally:
            # The per-chunk buffer is reset whether or not the call succeeded.
            self.chunk_buffer = io.BytesIO()
            self.chunk_bytes = 0
        cleaned = transcript.strip() if transcript else ""
        if not cleaned:
            return None
        self.transcripts.append(cleaned)
        return " ".join(self.transcripts)

    async def flush(self) -> str:
        """Get final transcript from full audio for best accuracy."""
        full_audio_data = self.full_audio.getvalue()
        if not full_audio_data:
            return " ".join(self.transcripts)
        try:
            transcript = await self.provider.transcribe(
                full_audio_data, self.audio_format
            )
        except Exception as e:
            logger.error(f"Final transcription error: {e}")
            return " ".join(self.transcripts)
        cleaned = transcript.strip() if transcript else ""
        # Fall back to the accumulated per-chunk transcripts if the full
        # pass produced nothing.
        return cleaned if cleaned else " ".join(self.transcripts)
|
||||
|
||||
|
||||
async def handle_streaming_transcription(
    websocket: WebSocket,
    transcriber: StreamingTranscriberProtocol,
) -> None:
    """Handle transcription using native streaming API.

    Pumps binary audio frames from the client WebSocket into the provider's
    streaming transcriber while a background task forwards interim transcripts
    back to the client as JSON. A client text message {"type": "end"} closes
    the transcriber and sends the final transcript; {"type": "reset"} clears
    the accumulated transcript (used after an auto-send).
    """
    logger.info("Streaming transcription: starting handler")
    last_transcript = ""
    chunk_count = 0
    total_bytes = 0

    async def receive_transcripts() -> None:
        """Background task to receive and send transcripts."""
        nonlocal last_transcript
        logger.info("Streaming transcription: starting transcript receiver")
        while True:
            result: TranscriptResult | None = await transcriber.receive_transcript()
            if result is None:  # End of stream
                logger.info("Streaming transcription: transcript stream ended")
                break
            # Send if text changed OR if VAD detected end of speech (for auto-send trigger)
            if result.text and (result.text != last_transcript or result.is_vad_end):
                last_transcript = result.text
                logger.info(
                    f"Streaming transcription: got transcript: {result.text[:50]}... "
                    f"(is_vad_end={result.is_vad_end})"
                )
                await websocket.send_json(
                    {
                        "type": "transcript",
                        "text": result.text,
                        "is_final": result.is_vad_end,
                    }
                )

    # Start receiving transcripts in background
    receive_task = asyncio.create_task(receive_transcripts())

    try:
        while True:
            message = await websocket.receive()
            msg_type = message.get("type", "unknown")

            if msg_type == "websocket.disconnect":
                logger.info(
                    f"Streaming transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
                )
                break

            if "bytes" in message:
                # Binary frame: raw audio, forwarded straight to the provider.
                chunk_size = len(message["bytes"])
                chunk_count += 1
                total_bytes += chunk_size
                logger.debug(
                    f"Streaming transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
                )
                await transcriber.send_audio(message["bytes"])

            elif "text" in message:
                # Text frame: JSON control messages ("end" / "reset").
                try:
                    data = json.loads(message["text"])
                    logger.info(
                        f"Streaming transcription: received text message: {data}"
                    )
                    if data.get("type") == "end":
                        logger.info(
                            "Streaming transcription: end signal received, closing transcriber"
                        )
                        final_transcript = await transcriber.close()
                        # Cancel the receiver before sending the final frame so
                        # it cannot race an interim transcript onto the socket.
                        receive_task.cancel()
                        logger.info(
                            "Streaming transcription: final transcript: "
                            f"{final_transcript[:100] if final_transcript else '(empty)'}..."
                        )
                        await websocket.send_json(
                            {
                                "type": "transcript",
                                "text": final_transcript,
                                "is_final": True,
                            }
                        )
                        break
                    elif data.get("type") == "reset":
                        # Reset accumulated transcript after auto-send
                        logger.info(
                            "Streaming transcription: reset signal received, clearing transcript"
                        )
                        transcriber.reset_transcript()
                except json.JSONDecodeError:
                    logger.warning(
                        f"Streaming transcription: failed to parse JSON: {message.get('text', '')[:100]}"
                    )
    except Exception as e:
        logger.error(f"Streaming transcription: error: {e}", exc_info=True)
        raise
    finally:
        # Always tear down the background receiver, even on error paths.
        receive_task.cancel()
        try:
            await receive_task
        except asyncio.CancelledError:
            pass
        logger.info(
            f"Streaming transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
        )
|
||||
|
||||
|
||||
async def handle_chunked_transcription(
    websocket: WebSocket,
    transcriber: ChunkedTranscriber,
) -> None:
    """Handle transcription using chunked batch API.

    Feeds each binary audio frame into the ChunkedTranscriber; whenever
    enough audio has accumulated, an interim transcript is sent back to the
    client. A {"type": "end"} text frame triggers a final full-audio pass
    and ends the loop.
    """
    logger.info("Chunked transcription: starting handler")
    chunk_count = 0
    total_bytes = 0

    while True:
        message = await websocket.receive()
        msg_type = message.get("type", "unknown")

        if msg_type == "websocket.disconnect":
            logger.info(
                f"Chunked transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
            )
            break

        if "bytes" in message:
            # Binary frame: accumulate audio; add_chunk returns a transcript
            # only once enough bytes have been buffered.
            chunk_size = len(message["bytes"])
            chunk_count += 1
            total_bytes += chunk_size
            logger.debug(
                f"Chunked transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
            )

            transcript = await transcriber.add_chunk(message["bytes"])
            if transcript:
                logger.info(
                    f"Chunked transcription: got transcript: {transcript[:50]}..."
                )
                await websocket.send_json(
                    {
                        "type": "transcript",
                        "text": transcript,
                        "is_final": False,
                    }
                )

        elif "text" in message:
            # Text frame: JSON control message; only "end" is recognized here.
            try:
                data = json.loads(message["text"])
                logger.info(f"Chunked transcription: received text message: {data}")
                if data.get("type") == "end":
                    logger.info("Chunked transcription: end signal received, flushing")
                    final_transcript = await transcriber.flush()
                    logger.info(
                        f"Chunked transcription: final transcript: {final_transcript[:100] if final_transcript else '(empty)'}..."
                    )
                    await websocket.send_json(
                        {
                            "type": "transcript",
                            "text": final_transcript,
                            "is_final": True,
                        }
                    )
                    break
            except json.JSONDecodeError:
                logger.warning(
                    f"Chunked transcription: failed to parse JSON: {message.get('text', '')[:100]}"
                )

    logger.info(
        f"Chunked transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
    )
|
||||
|
||||
|
||||
@router.websocket("/transcribe/stream")
async def websocket_transcribe(
    websocket: WebSocket,
    _user: User = Depends(current_user_from_websocket),
) -> None:
    """
    WebSocket endpoint for streaming speech-to-text.

    Protocol:
    - Client sends binary audio chunks
    - Server sends JSON: {"type": "transcript", "text": "...", "is_final": false}
    - Client sends JSON {"type": "end"} to signal end
    - Server responds with final transcript and closes

    Authentication:
    Requires `token` query parameter (e.g., /voice/transcribe/stream?token=xxx).
    Applies same auth checks as HTTP endpoints (verification, role checks).
    """
    logger.info("WebSocket transcribe: connection request received (authenticated)")

    try:
        await websocket.accept()
        logger.info("WebSocket transcribe: connection accepted")
    except Exception as e:
        logger.error(f"WebSocket transcribe: failed to accept connection: {e}")
        return

    streaming_transcriber = None
    provider = None

    try:
        # Get STT provider
        logger.info("WebSocket transcribe: fetching STT provider from database")
        engine = get_sqlalchemy_engine()
        with Session(engine) as db_session:
            provider_db = fetch_default_stt_provider(db_session)
            if provider_db is None:
                logger.warning(
                    "WebSocket transcribe: no default STT provider configured"
                )
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "No speech-to-text provider configured",
                    }
                )
                return

            if not provider_db.api_key:
                logger.warning("WebSocket transcribe: STT provider has no API key")
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "Speech-to-text provider has no API key configured",
                    }
                )
                return

            logger.info(
                f"WebSocket transcribe: creating voice provider: {provider_db.provider_type}"
            )
            try:
                provider = get_voice_provider(provider_db)
                logger.info(
                    f"WebSocket transcribe: voice provider created, streaming supported: {provider.supports_streaming_stt()}"
                )
            except ValueError as e:
                logger.error(
                    f"WebSocket transcribe: failed to create voice provider: {e}"
                )
                await websocket.send_json({"type": "error", "message": str(e)})
                return

        # Use native streaming if provider supports it
        if provider.supports_streaming_stt():
            logger.info("WebSocket transcribe: using native streaming STT")
            try:
                streaming_transcriber = await provider.create_streaming_transcriber()
                logger.info(
                    "WebSocket transcribe: streaming transcriber created successfully"
                )
                await handle_streaming_transcription(websocket, streaming_transcriber)
            except Exception as e:
                logger.error(
                    f"WebSocket transcribe: failed to create streaming transcriber: {e}"
                )
                # Operators can force hard failure instead of the chunked
                # fallback via VOICE_DISABLE_STREAMING_FALLBACK.
                if VOICE_DISABLE_STREAMING_FALLBACK:
                    await websocket.send_json(
                        {"type": "error", "message": f"Streaming STT failed: {e}"}
                    )
                    return
                logger.info("WebSocket transcribe: falling back to chunked STT")
                # Browser stream provides raw PCM16 chunks over WebSocket.
                chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
                await handle_chunked_transcription(websocket, chunked_transcriber)
        else:
            # Fall back to chunked transcription
            if VOICE_DISABLE_STREAMING_FALLBACK:
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "Provider doesn't support streaming STT",
                    }
                )
                return
            logger.info(
                "WebSocket transcribe: using chunked STT (provider doesn't support streaming)"
            )
            chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
            await handle_chunked_transcription(websocket, chunked_transcriber)

    except WebSocketDisconnect:
        logger.debug("WebSocket transcribe: client disconnected")
    except Exception as e:
        logger.error(f"WebSocket transcribe: unhandled error: {e}", exc_info=True)
        try:
            await websocket.send_json({"type": "error", "message": str(e)})
        except Exception:
            pass
    finally:
        # Best-effort teardown: close the transcriber and the socket even if
        # the handler exited via an error path.
        if streaming_transcriber:
            try:
                await streaming_transcriber.close()
            except Exception:
                pass
        try:
            await websocket.close()
        except Exception:
            pass
        logger.info("WebSocket transcribe: connection closed")
|
||||
|
||||
|
||||
async def handle_streaming_synthesis(
    websocket: WebSocket,
    synthesizer: StreamingSynthesizerProtocol,
) -> None:
    """Handle TTS using native streaming API.

    Buffers all text, then sends to ElevenLabs when complete.
    This is more reliable than streaming chunks incrementally.

    Args:
        websocket: Accepted client WebSocket carrying the JSON control protocol.
        synthesizer: Already-connected streaming TTS session.
    """
    logger.info("Streaming synthesis: starting handler")

    async def send_audio() -> None:
        """Background task to send audio chunks to client."""
        chunk_count = 0
        total_bytes = 0
        try:
            while True:
                # None signals end of the synthesizer's audio stream;
                # empty bytes mean "no audio yet" and are skipped below.
                audio_chunk = await synthesizer.receive_audio()
                if audio_chunk is None:
                    logger.info(
                        f"Streaming synthesis: audio stream ended, sent {chunk_count} chunks, {total_bytes} bytes"
                    )
                    try:
                        await websocket.send_json({"type": "audio_done"})
                        logger.info("Streaming synthesis: sent audio_done to client")
                    except Exception as e:
                        logger.warning(
                            f"Streaming synthesis: failed to send audio_done: {e}"
                        )
                    break
                if audio_chunk:  # Skip empty chunks
                    chunk_count += 1
                    total_bytes += len(audio_chunk)
                    try:
                        await websocket.send_bytes(audio_chunk)
                    except Exception as e:
                        # Client likely went away; stop streaming audio.
                        logger.warning(
                            f"Streaming synthesis: failed to send chunk: {e}"
                        )
                        break
        except asyncio.CancelledError:
            logger.info(
                f"Streaming synthesis: send_audio cancelled after {chunk_count} chunks"
            )
        except Exception as e:
            logger.error(f"Streaming synthesis: send_audio error: {e}")

    send_task: asyncio.Task | None = None
    text_buffer: list[str] = []  # Buffer text chunks until "end"
    disconnected = False

    try:
        while not disconnected:
            try:
                message = await websocket.receive()
            except RuntimeError as e:
                # Starlette raises RuntimeError when receiving on a
                # disconnected socket; treat that as a normal shutdown.
                if "disconnect" in str(e).lower():
                    logger.info("Streaming synthesis: client disconnected")
                    break
                raise

            msg_type = message.get("type", "unknown")  # type: ignore[possibly-undefined]

            if msg_type == "websocket.disconnect":
                logger.info("Streaming synthesis: client disconnected")
                disconnected = True
                break

            if "text" in message:
                try:
                    data = json.loads(message["text"])

                    if data.get("type") == "synthesize":
                        text = data.get("text", "")
                        if not text:
                            # Lenient fallback: accept the first non-empty
                            # string field if the client used a different key.
                            for key, value in data.items():
                                if key != "type" and isinstance(value, str) and value:
                                    text = value
                                    break
                        if text:
                            # Buffer text instead of sending immediately
                            text_buffer.append(text)
                            logger.info(
                                f"Streaming synthesis: buffered text ({len(text)} chars), "
                                f"total buffered: {len(text_buffer)} chunks"
                            )

                    elif data.get("type") == "end":
                        logger.info("Streaming synthesis: end signal received")

                        if text_buffer:
                            # Combine all buffered text
                            full_text = " ".join(text_buffer)
                            logger.info(
                                f"Streaming synthesis: sending full text ({len(full_text)} chars): "
                                f"'{full_text[:100]}...'"
                            )

                            # Start audio receiver
                            send_task = asyncio.create_task(send_audio())

                            # Send all text at once
                            await synthesizer.send_text(full_text)
                            logger.info(
                                "Streaming synthesis: full text sent to synthesizer"
                            )

                            # Signal end of input
                            # hasattr guard: flush() is optional on some
                            # synthesizer implementations.
                            if hasattr(synthesizer, "flush"):
                                await synthesizer.flush()

                            # Wait for all audio to be sent
                            logger.info(
                                "Streaming synthesis: waiting for audio stream to complete"
                            )
                            try:
                                await asyncio.wait_for(send_task, timeout=60.0)
                            except asyncio.TimeoutError:
                                logger.warning(
                                    "Streaming synthesis: timeout waiting for audio"
                                )
                        # "end" terminates the session even with no buffered text.
                        break

                except json.JSONDecodeError:
                    logger.warning(
                        f"Streaming synthesis: failed to parse JSON: {message.get('text', '')[:100]}"
                    )

    except Exception as e:
        if "disconnect" not in str(e).lower():
            logger.error(f"Streaming synthesis: error: {e}", exc_info=True)
    finally:
        # Give any in-flight audio a grace period before cancelling.
        if send_task and not send_task.done():
            logger.info("Streaming synthesis: waiting for send_task to finish")
            try:
                await asyncio.wait_for(send_task, timeout=30.0)
            except asyncio.TimeoutError:
                logger.warning("Streaming synthesis: timeout waiting for send_task")
                send_task.cancel()
                try:
                    await send_task
                except asyncio.CancelledError:
                    pass
            except asyncio.CancelledError:
                pass
        logger.info("Streaming synthesis: handler finished")
|
||||
|
||||
|
||||
async def handle_chunked_synthesis(websocket: WebSocket, provider: Any) -> None:
    """Fallback TTS handler using provider.synthesize_stream.

    Buffers "synthesize" messages until an "end" message arrives, then runs a
    single synthesize_stream call over the joined text and relays its audio
    chunks to the client, finishing with an "audio_done" JSON message.

    Args:
        websocket: Accepted client WebSocket.
        provider: Voice provider exposing an async ``synthesize_stream``
            (typed ``Any`` so any provider implementation can be passed).
    """
    logger.info("Chunked synthesis: starting handler")
    text_buffer: list[str] = []
    voice: str | None = None  # last voice requested by the client, if any
    speed = 1.0  # playback speed multiplier; overridden by client messages

    try:
        while True:
            message = await websocket.receive()
            msg_type = message.get("type", "unknown")

            if msg_type == "websocket.disconnect":
                logger.info("Chunked synthesis: client disconnected")
                break

            # Binary frames are not part of this protocol; ignore them.
            if "text" not in message:
                continue

            try:
                data = json.loads(message["text"])
            except json.JSONDecodeError:
                logger.warning(
                    "Chunked synthesis: failed to parse JSON: "
                    f"{message.get('text', '')[:100]}"
                )
                continue

            msg_data_type = data.get("type")  # type: ignore[possibly-undefined]
            if msg_data_type == "synthesize":
                text = data.get("text", "")
                if not text:
                    # Lenient fallback: take the first non-empty string field
                    # if the client used a non-standard key.
                    for key, value in data.items():
                        if key != "type" and isinstance(value, str) and value:
                            text = value
                            break
                if text:
                    text_buffer.append(text)
                    logger.info(
                        f"Chunked synthesis: buffered text ({len(text)} chars), "
                        f"total buffered: {len(text_buffer)} chunks"
                    )
                # Voice/speed may be (re)set on any synthesize message;
                # the last values win.
                if isinstance(data.get("voice"), str) and data["voice"]:
                    voice = data["voice"]
                if isinstance(data.get("speed"), (int, float)):
                    speed = float(data["speed"])
            elif msg_data_type == "end":
                logger.info("Chunked synthesis: end signal received")
                full_text = " ".join(text_buffer).strip()
                if not full_text:
                    # Nothing to synthesize; still complete the protocol.
                    await websocket.send_json({"type": "audio_done"})
                    logger.info("Chunked synthesis: no text, sent audio_done")
                    break

                chunk_count = 0
                total_bytes = 0
                logger.info(
                    f"Chunked synthesis: sending full text ({len(full_text)} chars)"
                )
                async for audio_chunk in provider.synthesize_stream(
                    full_text, voice=voice, speed=speed
                ):
                    if not audio_chunk:
                        continue
                    chunk_count += 1
                    total_bytes += len(audio_chunk)
                    await websocket.send_bytes(audio_chunk)
                await websocket.send_json({"type": "audio_done"})
                logger.info(
                    f"Chunked synthesis: sent audio_done after {chunk_count} chunks, {total_bytes} bytes"
                )
                # One synthesis per connection: session ends after "end".
                break
    except Exception as e:
        if "disconnect" not in str(e).lower():
            logger.error(f"Chunked synthesis: error: {e}", exc_info=True)
            # Re-raise non-disconnect errors so the endpoint can report them.
            raise
    finally:
        logger.info("Chunked synthesis: handler finished")
|
||||
|
||||
|
||||
@router.websocket("/synthesize/stream")
async def websocket_synthesize(
    websocket: WebSocket,
    _user: User = Depends(current_user_from_websocket),
) -> None:
    """
    WebSocket endpoint for streaming text-to-speech.

    Protocol:
    - Client sends JSON: {"type": "synthesize", "text": "...", "voice": "...", "speed": 1.0}
    - Server sends binary audio chunks
    - Server sends JSON: {"type": "audio_done"} when synthesis completes
    - Client sends JSON {"type": "end"} to close connection

    Authentication:
        Requires `token` query parameter (e.g., /voice/synthesize/stream?token=xxx).
        Applies same auth checks as HTTP endpoints (verification, role checks).
    """
    logger.info("WebSocket synthesize: connection request received (authenticated)")

    try:
        await websocket.accept()
        logger.info("WebSocket synthesize: connection accepted")
    except Exception as e:
        logger.error(f"WebSocket synthesize: failed to accept connection: {e}")
        return

    streaming_synthesizer: StreamingSynthesizerProtocol | None = None
    provider = None

    try:
        # Get TTS provider
        logger.info("WebSocket synthesize: fetching TTS provider from database")
        engine = get_sqlalchemy_engine()
        with Session(engine) as db_session:
            provider_db = fetch_default_tts_provider(db_session)
            if provider_db is None:
                logger.warning(
                    "WebSocket synthesize: no default TTS provider configured"
                )
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "No text-to-speech provider configured",
                    }
                )
                return

            if not provider_db.api_key:
                logger.warning("WebSocket synthesize: TTS provider has no API key")
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "Text-to-speech provider has no API key configured",
                    }
                )
                return

            logger.info(
                f"WebSocket synthesize: creating voice provider: {provider_db.provider_type}"
            )
            try:
                # Provider is built while the DB session is open so the
                # encrypted api_key can be read from the ORM object.
                provider = get_voice_provider(provider_db)
                logger.info(
                    f"WebSocket synthesize: voice provider created, streaming TTS supported: {provider.supports_streaming_tts()}"
                )
            except ValueError as e:
                logger.error(
                    f"WebSocket synthesize: failed to create voice provider: {e}"
                )
                await websocket.send_json({"type": "error", "message": str(e)})
                return

        # Use native streaming if provider supports it
        if provider.supports_streaming_tts():
            logger.info("WebSocket synthesize: using native streaming TTS")
            try:
                # Wait for initial config message with voice/speed
                # NOTE(review): only "voice"/"speed" are read from this first
                # message; if the client's first frame is a "synthesize"
                # message, its "text" field appears to be discarded — confirm
                # against the frontend protocol.
                message = await websocket.receive()
                voice = None
                speed = 1.0
                if "text" in message:
                    try:
                        data = json.loads(message["text"])
                        voice = data.get("voice")
                        speed = data.get("speed", 1.0)
                    except json.JSONDecodeError:
                        pass

                streaming_synthesizer = await provider.create_streaming_synthesizer(
                    voice=voice, speed=speed
                )
                logger.info(
                    "WebSocket synthesize: streaming synthesizer created successfully"
                )
                await handle_streaming_synthesis(websocket, streaming_synthesizer)
            except Exception as e:
                logger.error(
                    f"WebSocket synthesize: failed to create streaming synthesizer: {e}"
                )
                if VOICE_DISABLE_STREAMING_FALLBACK:
                    await websocket.send_json(
                        {"type": "error", "message": f"Streaming TTS failed: {e}"}
                    )
                    return
                logger.info(
                    "WebSocket synthesize: falling back to chunked TTS synthesis"
                )
                await handle_chunked_synthesis(websocket, provider)
        else:
            if VOICE_DISABLE_STREAMING_FALLBACK:
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "Provider doesn't support streaming TTS",
                    }
                )
                return
            logger.info(
                "WebSocket synthesize: using chunked TTS (provider doesn't support streaming)"
            )
            await handle_chunked_synthesis(websocket, provider)

    except WebSocketDisconnect:
        logger.debug("WebSocket synthesize: client disconnected")
    except RuntimeError as e:
        # Starlette surfaces disconnects as RuntimeError in some paths.
        if "disconnect" in str(e).lower():
            logger.debug("WebSocket synthesize: client disconnected")
        else:
            logger.error(f"WebSocket synthesize: runtime error: {e}")
    except Exception as e:
        error_str = str(e).lower()
        if "disconnect" in error_str or "websocket.close" in error_str:
            logger.debug("WebSocket synthesize: client disconnected")
        else:
            logger.error(f"WebSocket synthesize: unhandled error: {e}", exc_info=True)
            try:
                await websocket.send_json({"type": "error", "message": str(e)})
            except Exception:
                pass
    finally:
        # Best-effort teardown: close the synthesizer first, then the socket.
        if streaming_synthesizer:
            try:
                await streaming_synthesizer.close()
            except Exception:
                pass
        try:
            await websocket.close()
        except Exception:
            pass
        logger.info("WebSocket synthesize: connection closed")
|
||||
0
backend/onyx/voice/__init__.py
Normal file
0
backend/onyx/voice/__init__.py
Normal file
70
backend/onyx/voice/factory.py
Normal file
70
backend/onyx/voice/factory.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
|
||||
def get_voice_provider(provider: VoiceProvider) -> VoiceProviderInterface:
    """
    Factory function to get the appropriate voice provider implementation.

    Args:
        provider: VoiceProvider model instance (can be from DB or constructed temporarily)

    Returns:
        VoiceProviderInterface implementation

    Raises:
        ValueError: If provider_type is not supported
    """
    kind = provider.provider_type.lower()

    # The api_key may be a SensitiveValue (from the DB) that must be unmasked,
    # or a plain string (from a temporarily-constructed model), or None.
    resolved_key = provider.api_key
    if resolved_key is not None and hasattr(resolved_key, "get_value"):
        resolved_key = resolved_key.get_value(apply_mask=False)

    # Arguments shared by every concrete provider constructor.
    common_kwargs = dict(
        api_key=resolved_key,
        api_base=provider.api_base,
        stt_model=provider.stt_model,
        tts_model=provider.tts_model,
        default_voice=provider.default_voice,
    )

    # Provider modules are imported lazily so unused SDKs are never loaded.
    if kind == "openai":
        from onyx.voice.providers.openai import OpenAIVoiceProvider

        return OpenAIVoiceProvider(**common_kwargs)

    if kind == "azure":
        from onyx.voice.providers.azure import AzureVoiceProvider

        return AzureVoiceProvider(
            custom_config=provider.custom_config or {}, **common_kwargs
        )

    if kind == "elevenlabs":
        from onyx.voice.providers.elevenlabs import ElevenLabsVoiceProvider

        return ElevenLabsVoiceProvider(**common_kwargs)

    raise ValueError(f"Unsupported voice provider type: {kind}")
|
||||
175
backend/onyx/voice/interface.py
Normal file
175
backend/onyx/voice/interface.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
# Lightweight value object passed from streaming transcribers to the
# websocket handlers; VAD = voice activity detection.
@dataclass
class TranscriptResult:
    """Result from streaming transcription."""

    text: str
    """The accumulated transcript text."""

    is_vad_end: bool = False
    """True if VAD detected end of speech (silence). Use for auto-send."""
|
||||
|
||||
|
||||
class StreamingTranscriberProtocol(Protocol):
    """Protocol for streaming transcription sessions.

    Structural (duck-typed) interface: any object with these methods can be
    used as a streaming transcriber without subclassing.
    """

    async def send_audio(self, chunk: bytes) -> None:
        """Send an audio chunk for transcription."""
        ...

    async def receive_transcript(self) -> TranscriptResult | None:
        """
        Receive next transcript update.

        Returns:
            TranscriptResult with accumulated text and VAD status, or None when stream ends.
        """
        ...

    async def close(self) -> str:
        """Close the session and return final transcript."""
        ...

    def reset_transcript(self) -> None:
        """Reset accumulated transcript. Call after auto-send to start fresh."""
        ...
|
||||
|
||||
|
||||
class StreamingSynthesizerProtocol(Protocol):
    """Protocol for streaming TTS sessions (real-time text-to-speech).

    Structural (duck-typed) interface; implementations stream audio back as
    text is sent, with ``None`` from receive_audio marking end-of-stream.
    """

    async def connect(self) -> None:
        """Establish connection to TTS provider."""
        ...

    async def send_text(self, text: str) -> None:
        """Send text to be synthesized."""
        ...

    async def receive_audio(self) -> bytes | None:
        """
        Receive next audio chunk.

        Returns:
            Audio bytes, or None when stream ends.
        """
        ...

    async def flush(self) -> None:
        """Signal end of text input and wait for pending audio."""
        ...

    async def close(self) -> None:
        """Close the session."""
        ...
|
||||
|
||||
|
||||
class VoiceProviderInterface(ABC):
    """Abstract base class for voice providers (STT and TTS).

    Concrete providers must implement batch transcribe/synthesize and the
    model/voice listing methods; streaming support is opt-in via the
    ``supports_streaming_*`` flags and ``create_streaming_*`` factories.
    """

    @abstractmethod
    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        """
        Convert audio to text (Speech-to-Text).

        Args:
            audio_data: Raw audio bytes
            audio_format: Audio format (e.g., "webm", "wav", "mp3")

        Returns:
            Transcribed text
        """

    @abstractmethod
    def synthesize_stream(
        self, text: str, voice: str | None = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio stream (Text-to-Speech).

        Streams audio chunks progressively for lower latency playback.

        Args:
            text: Text to convert to speech
            voice: Voice identifier (e.g., "alloy", "echo"), or None for default
            speed: Playback speed multiplier (0.25 to 4.0)

        Yields:
            Audio data chunks
        """

    @abstractmethod
    def get_available_voices(self) -> list[dict[str, str]]:
        """
        Get list of available voices for this provider.

        Returns:
            List of voice dictionaries with 'id' and 'name' keys
        """

    @abstractmethod
    def get_available_stt_models(self) -> list[dict[str, str]]:
        """
        Get list of available STT models for this provider.

        Returns:
            List of model dictionaries with 'id' and 'name' keys
        """

    @abstractmethod
    def get_available_tts_models(self) -> list[dict[str, str]]:
        """
        Get list of available TTS models for this provider.

        Returns:
            List of model dictionaries with 'id' and 'name' keys
        """

    # Default: no streaming support; subclasses override to opt in.
    def supports_streaming_stt(self) -> bool:
        """Returns True if this provider supports streaming STT."""
        return False

    def supports_streaming_tts(self) -> bool:
        """Returns True if this provider supports real-time streaming TTS."""
        return False

    async def create_streaming_transcriber(
        self, audio_format: str = "webm"
    ) -> StreamingTranscriberProtocol:
        """
        Create a streaming transcription session.

        Args:
            audio_format: Audio format being sent (e.g., "webm", "pcm16")

        Returns:
            A streaming transcriber that can send audio chunks and receive transcripts

        Raises:
            NotImplementedError: If streaming STT is not supported
        """
        raise NotImplementedError("Streaming STT not supported by this provider")

    async def create_streaming_synthesizer(
        self, voice: str | None = None, speed: float = 1.0
    ) -> "StreamingSynthesizerProtocol":
        """
        Create a streaming TTS session for real-time audio synthesis.

        Args:
            voice: Voice identifier
            speed: Playback speed multiplier

        Returns:
            A streaming synthesizer that can send text and receive audio chunks

        Raises:
            NotImplementedError: If streaming TTS is not supported
        """
        raise NotImplementedError("Streaming TTS not supported by this provider")
|
||||
0
backend/onyx/voice/providers/__init__.py
Normal file
0
backend/onyx/voice/providers/__init__.py
Normal file
471
backend/onyx/voice/providers/azure.py
Normal file
471
backend/onyx/voice/providers/azure.py
Normal file
@@ -0,0 +1,471 @@
|
||||
import asyncio
|
||||
import io
|
||||
import re
|
||||
import struct
|
||||
import wave
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from xml.sax.saxutils import escape
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
# Common Azure Neural voices
# Each entry: "id" is the Azure voice short name passed to the Speech API;
# "name" is the human-readable label shown in the UI.
AZURE_VOICES = [
    {"id": "en-US-JennyNeural", "name": "Jenny (en-US, Female)"},
    {"id": "en-US-GuyNeural", "name": "Guy (en-US, Male)"},
    {"id": "en-US-AriaNeural", "name": "Aria (en-US, Female)"},
    {"id": "en-US-DavisNeural", "name": "Davis (en-US, Male)"},
    {"id": "en-US-AmberNeural", "name": "Amber (en-US, Female)"},
    {"id": "en-US-AnaNeural", "name": "Ana (en-US, Female)"},
    {"id": "en-US-BrandonNeural", "name": "Brandon (en-US, Male)"},
    {"id": "en-US-ChristopherNeural", "name": "Christopher (en-US, Male)"},
    {"id": "en-US-CoraNeural", "name": "Cora (en-US, Female)"},
    {"id": "en-GB-SoniaNeural", "name": "Sonia (en-GB, Female)"},
    {"id": "en-GB-RyanNeural", "name": "Ryan (en-GB, Male)"},
]
|
||||
|
||||
|
||||
class AzureStreamingTranscriber(StreamingTranscriberProtocol):
    """Streaming transcription using Azure Speech SDK.

    Audio is pushed into an SDK PushAudioInputStream; recognition callbacks
    fire on an SDK-owned thread and forward results into an asyncio queue via
    call_soon_threadsafe.
    """

    def __init__(
        self,
        api_key: str,
        region: str,
        input_sample_rate: int = 24000,
        target_sample_rate: int = 16000,
    ):
        self.api_key = api_key
        self.region = region
        # Incoming PCM16 rate from the browser; Azure stream is configured
        # for 16 kHz, so chunks are resampled before being pushed.
        self.input_sample_rate = input_sample_rate
        self.target_sample_rate = target_sample_rate
        self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
        self._accumulated_transcript = ""
        self._recognizer: Any = None
        self._audio_stream: Any = None
        self._closed = False
        # Event loop captured in connect(); needed for thread-safe queue puts.
        self._loop: asyncio.AbstractEventLoop | None = None

    async def connect(self) -> None:
        """Initialize Azure Speech recognizer with push stream."""
        try:
            import azure.cognitiveservices.speech as speechsdk  # type: ignore[import-not-found]
        except ImportError as e:
            raise RuntimeError(
                "Azure Speech SDK is required for streaming STT. "
                "Install `azure-cognitiveservices-speech`."
            ) from e

        self._loop = asyncio.get_running_loop()

        speech_config = speechsdk.SpeechConfig(
            subscription=self.api_key,
            region=self.region,
        )

        # Stream format must match what send_audio pushes: 16 kHz mono PCM16.
        audio_format = speechsdk.audio.AudioStreamFormat(
            samples_per_second=16000,
            bits_per_sample=16,
            channels=1,
        )
        self._audio_stream = speechsdk.audio.PushAudioInputStream(audio_format)
        audio_config = speechsdk.audio.AudioConfig(stream=self._audio_stream)

        self._recognizer = speechsdk.SpeechRecognizer(
            speech_config=speech_config,
            audio_config=audio_config,
        )

        # Captured by the callbacks below (which run on an SDK thread).
        transcriber = self

        def on_recognizing(evt: Any) -> None:
            # Interim hypothesis: report accumulated + partial text, but do
            # not commit it to _accumulated_transcript yet.
            if evt.result.text and transcriber._loop and not transcriber._closed:
                full_text = transcriber._accumulated_transcript
                if full_text:
                    full_text += " " + evt.result.text
                else:
                    full_text = evt.result.text
                transcriber._loop.call_soon_threadsafe(
                    transcriber._transcript_queue.put_nowait,
                    TranscriptResult(text=full_text, is_vad_end=False),
                )

        def on_recognized(evt: Any) -> None:
            # Final result for an utterance: commit it and flag is_vad_end
            # so the caller can auto-send.
            if evt.result.text and transcriber._loop and not transcriber._closed:
                if transcriber._accumulated_transcript:
                    transcriber._accumulated_transcript += " " + evt.result.text
                else:
                    transcriber._accumulated_transcript = evt.result.text
                transcriber._loop.call_soon_threadsafe(
                    transcriber._transcript_queue.put_nowait,
                    TranscriptResult(
                        text=transcriber._accumulated_transcript, is_vad_end=True
                    ),
                )

        self._recognizer.recognizing.connect(on_recognizing)
        self._recognizer.recognized.connect(on_recognized)
        self._recognizer.start_continuous_recognition_async()

    async def send_audio(self, chunk: bytes) -> None:
        """Send audio chunk to Azure."""
        if self._audio_stream and not self._closed:
            self._audio_stream.write(self._resample_pcm16(chunk))

    def _resample_pcm16(self, data: bytes) -> bytes:
        """Resample PCM16 audio from input_sample_rate to target_sample_rate.

        Pure-Python linear interpolation over little-endian 16-bit samples;
        returns the input unchanged when the rates already match.
        """
        if self.input_sample_rate == self.target_sample_rate:
            return data

        num_samples = len(data) // 2
        if num_samples == 0:
            return b""

        samples = list(struct.unpack(f"<{num_samples}h", data))
        ratio = self.input_sample_rate / self.target_sample_rate
        new_length = int(num_samples / ratio)

        resampled: list[int] = []
        for i in range(new_length):
            # Linear interpolation between the two nearest source samples.
            src_idx = i * ratio
            idx_floor = int(src_idx)
            idx_ceil = min(idx_floor + 1, num_samples - 1)
            frac = src_idx - idx_floor
            sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
            # Clamp to the int16 range before repacking.
            sample = max(-32768, min(32767, sample))
            resampled.append(sample)

        return struct.pack(f"<{len(resampled)}h", *resampled)

    async def receive_transcript(self) -> TranscriptResult | None:
        """Receive next transcript.

        Polls with a short timeout so the caller's loop stays responsive;
        an empty non-final result means "nothing new yet".
        """
        try:
            return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return TranscriptResult(text="", is_vad_end=False)

    async def close(self) -> str:
        """Stop recognition and return final transcript."""
        self._closed = True
        if self._recognizer:
            # NOTE(review): the async stop is not awaited/joined (.get());
            # presumably shutdown completing in the background is acceptable
            # here — confirm against the Azure SDK docs.
            self._recognizer.stop_continuous_recognition_async()
        if self._audio_stream:
            self._audio_stream.close()
        self._loop = None
        return self._accumulated_transcript

    def reset_transcript(self) -> None:
        """Reset accumulated transcript."""
        self._accumulated_transcript = ""
|
||||
|
||||
|
||||
class AzureStreamingSynthesizer(StreamingSynthesizerProtocol):
    """Real-time streaming TTS using Azure Speech SDK.

    Synthesis callbacks run on an SDK-owned thread and forward MP3 chunks
    into an asyncio queue; ``None`` in the queue marks end of synthesis.
    """

    def __init__(self, api_key: str, region: str, voice: str = "en-US-JennyNeural"):
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()
        self.api_key = api_key
        self.region = region
        self.voice = voice
        # None sentinel in the queue signals synthesis completion.
        self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
        self._synthesizer: Any = None
        self._closed = False
        # Event loop captured in connect(); needed for thread-safe queue puts.
        self._loop: asyncio.AbstractEventLoop | None = None

    async def connect(self) -> None:
        """Initialize Azure Speech synthesizer with push stream."""
        try:
            import azure.cognitiveservices.speech as speechsdk
        except ImportError as e:
            raise RuntimeError(
                "Azure Speech SDK is required for streaming TTS. "
                "Install `azure-cognitiveservices-speech`."
            ) from e

        self._logger.info("AzureStreamingSynthesizer: connecting")

        # Store the event loop for thread-safe queue operations
        self._loop = asyncio.get_running_loop()

        speech_config = speechsdk.SpeechConfig(
            subscription=self.api_key,
            region=self.region,
        )
        speech_config.speech_synthesis_voice_name = self.voice
        # Use MP3 format for streaming - compatible with MediaSource Extensions
        speech_config.set_speech_synthesis_output_format(
            speechsdk.SpeechSynthesisOutputFormat.Audio16Khz64KBitRateMonoMp3
        )

        # Create synthesizer with pull audio output stream
        self._synthesizer = speechsdk.SpeechSynthesizer(
            speech_config=speech_config,
            audio_config=None,  # We'll manually handle audio
        )

        # Connect to synthesis events
        self._synthesizer.synthesizing.connect(self._on_synthesizing)
        self._synthesizer.synthesis_completed.connect(self._on_completed)

        self._logger.info("AzureStreamingSynthesizer: connected")

    def _on_synthesizing(self, evt: Any) -> None:
        """Called when audio chunk is available (runs in Azure SDK thread)."""
        if evt.result.audio_data and self._loop and not self._closed:
            # Thread-safe way to put item in async queue
            self._loop.call_soon_threadsafe(
                self._audio_queue.put_nowait, evt.result.audio_data
            )

    def _on_completed(self, _evt: Any) -> None:
        """Called when synthesis is complete (runs in Azure SDK thread)."""
        if self._loop and not self._closed:
            # Push the None sentinel so receive_audio reports end-of-stream.
            self._loop.call_soon_threadsafe(self._audio_queue.put_nowait, None)

    async def send_text(self, text: str) -> None:
        """Send text to be synthesized."""
        if self._synthesizer and not self._closed:
            # Start synthesis asynchronously
            # NOTE(review): the returned future is discarded (fire-and-forget);
            # failures surface only via the absence of audio — confirm intended.
            self._synthesizer.speak_text_async(text)

    async def receive_audio(self) -> bytes | None:
        """Receive next audio chunk.

        Short-timeout poll: empty bytes mean "no audio yet", None means done.
        """
        try:
            return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return b""  # No audio yet, but not done

    async def flush(self) -> None:
        """Signal end of text input - wait for pending audio."""
        # Azure SDK handles flushing automatically

    async def close(self) -> None:
        """Close the session."""
        self._closed = True
        if self._synthesizer:
            # Disconnect callbacks so the SDK thread stops touching the queue.
            self._synthesizer.synthesis_completed.disconnect_all()
            self._synthesizer.synthesizing.disconnect_all()
        self._loop = None
|
||||
|
||||
|
||||
class AzureVoiceProvider(VoiceProviderInterface):
    """Azure Speech Services voice provider.

    Provides batch STT via the Azure Speech REST API, streaming STT/TTS via
    the Azure Speech SDK (`AzureStreamingTranscriber` /
    `AzureStreamingSynthesizer`), and HTTP streaming TTS via SSML POST.
    """

    def __init__(
        self,
        api_key: str | None,
        api_base: str | None,
        custom_config: dict[str, Any],
        stt_model: str | None = None,
        tts_model: str | None = None,
        default_voice: str | None = None,
    ):
        """Store configuration and resolve the Azure speech region.

        The region is taken from ``custom_config["speech_region"]`` when
        present, otherwise parsed out of the endpoint URI, otherwise left
        empty (API calls will then fail fast with a ValueError).
        """
        self.api_key = api_key
        self.api_base = api_base
        self.custom_config = custom_config
        # Region resolution order: explicit config > parsed from URI > "".
        self.speech_region = (
            custom_config.get("speech_region")
            or self._extract_speech_region_from_uri(api_base)
            or ""
        )
        self.stt_model = stt_model
        self.tts_model = tts_model
        # Fall back to a widely available Azure neural voice.
        self.default_voice = default_voice or "en-US-JennyNeural"

    @staticmethod
    def _extract_speech_region_from_uri(uri: str | None) -> str | None:
        """Extract Azure speech region from endpoint URI.

        Returns the first hostname label for known Azure endpoint shapes,
        or None when the URI is empty or matches no known pattern.
        """
        if not uri:
            return None
        # Accepted examples:
        # - https://eastus.tts.speech.microsoft.com/cognitiveservices/v1
        # - https://eastus.stt.speech.microsoft.com/speech/recognition/...
        # - https://westus.api.cognitive.microsoft.com/
        # - https://<resource>.cognitiveservices.azure.com/ (fallback to first label)
        patterns = [
            r"https?://([^.]+)\.(?:tts|stt)\.speech\.microsoft\.com",
            r"https?://([^.]+)\.api\.cognitive\.microsoft\.com",
            r"https?://([^.]+)\.cognitiveservices\.azure\.com",
        ]
        for pattern in patterns:
            match = re.search(pattern, uri)
            if match:
                return match.group(1)
        return None

    @staticmethod
    def _pcm16_to_wav(pcm_data: bytes, sample_rate: int = 24000) -> bytes:
        """Wrap raw PCM16 mono bytes into a WAV container.

        Args:
            pcm_data: Raw little-endian 16-bit mono PCM samples.
            sample_rate: Sample rate to record in the WAV header.

        Returns:
            Complete WAV file bytes (44-byte header + the PCM payload).
        """
        buffer = io.BytesIO()
        with wave.open(buffer, "wb") as wav_file:
            wav_file.setnchannels(1)  # mono
            wav_file.setsampwidth(2)  # 16-bit samples
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(pcm_data)
        return buffer.getvalue()

    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        """Transcribe an audio clip with the Azure Speech REST API.

        Args:
            audio_data: Raw audio bytes in ``audio_format``.
            audio_format: Source format, e.g. "wav", "webm", or "pcm"
                (raw PCM16 is wrapped into a WAV container first).

        Returns:
            The recognized text, or "" when recognition did not succeed.

        Raises:
            ValueError: If the API key or speech region is missing.
            RuntimeError: If Azure returns a non-200 response.
        """
        if not self.api_key:
            raise ValueError("Azure API key required for STT")
        if not self.speech_region:
            raise ValueError("Azure speech region required for STT")

        normalized_format = audio_format.lower()
        payload = audio_data
        content_type = f"audio/{normalized_format}"

        # WebSocket chunked fallback sends raw PCM16 bytes.
        if normalized_format in {"pcm", "pcm16", "raw"}:
            # Azure's REST endpoint needs a container, so wrap the raw
            # samples in a WAV header (24 kHz assumed — matches the
            # frontend capture rate used elsewhere in this file).
            payload = self._pcm16_to_wav(audio_data, sample_rate=24000)
            content_type = "audio/wav"
        elif normalized_format in {"wav", "wave"}:
            content_type = "audio/wav"
        elif normalized_format == "webm":
            content_type = "audio/webm; codecs=opus"

        url = (
            f"https://{self.speech_region}.stt.speech.microsoft.com/"
            "speech/recognition/conversation/cognitiveservices/v1"
        )
        # "detailed" format returns an NBest list with Display text.
        params = {"language": "en-US", "format": "detailed"}
        headers = {
            "Ocp-Apim-Subscription-Key": self.api_key,
            "Content-Type": content_type,
            "Accept": "application/json",
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(
                url, params=params, headers=headers, data=payload
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(f"Azure STT failed: {error_text}")
                result = await response.json()

        if result.get("RecognitionStatus") != "Success":
            return ""
        # Prefer the top NBest Display text; fall back to DisplayText.
        nbest = result.get("NBest") or []
        if nbest and isinstance(nbest, list):
            display = nbest[0].get("Display")
            if isinstance(display, str):
                return display
        display_text = result.get("DisplayText", "")
        return display_text if isinstance(display_text, str) else ""

    async def synthesize_stream(
        self, text: str, voice: str | None = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio using Azure TTS with streaming.

        Args:
            text: Text to convert to speech
            voice: Voice name (defaults to provider's default voice)
            speed: Playback speed multiplier (0.5 to 2.0)

        Yields:
            Audio data chunks (mp3 format)

        Raises:
            ValueError: If the API key or speech region is missing.
            RuntimeError: If Azure returns a non-200 response.
        """
        if not self.api_key:
            raise ValueError("Azure API key required for TTS")

        if not self.speech_region:
            raise ValueError("Azure speech region required for TTS")

        voice_name = voice or self.default_voice

        # Clamp speed to valid range and convert to rate format
        speed = max(0.5, min(2.0, speed))
        rate = f"{int((speed - 1) * 100):+d}%"  # e.g., 1.0 -> "+0%", 1.5 -> "+50%"

        # Build SSML with escaped text to prevent injection
        escaped_text = escape(text)
        ssml = f"""<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>
<voice name='{voice_name}'>
<prosody rate='{rate}'>{escaped_text}</prosody>
</voice>
</speak>"""

        url = f"https://{self.speech_region}.tts.speech.microsoft.com/cognitiveservices/v1"

        headers = {
            "Ocp-Apim-Subscription-Key": self.api_key,
            "Content-Type": "application/ssml+xml",
            "X-Microsoft-OutputFormat": "audio-16khz-128kbitrate-mono-mp3",
            "User-Agent": "Onyx",
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, data=ssml) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(f"Azure TTS failed: {error_text}")

                # Use 8192 byte chunks for smoother streaming
                async for chunk in response.content.iter_chunked(8192):
                    if chunk:
                        yield chunk

    def get_available_voices(self) -> list[dict[str, str]]:
        """Return common Azure Neural voices."""
        # Copy so callers cannot mutate the module-level list.
        return AZURE_VOICES.copy()

    def get_available_stt_models(self) -> list[dict[str, str]]:
        """Return selectable STT models (Azure exposes a single service)."""
        return [
            {"id": "default", "name": "Azure Speech Recognition"},
        ]

    def get_available_tts_models(self) -> list[dict[str, str]]:
        """Return selectable TTS models (Azure exposes a single tier here)."""
        return [
            {"id": "neural", "name": "Neural TTS"},
        ]

    def supports_streaming_stt(self) -> bool:
        """Azure supports streaming STT via Speech SDK."""
        return True

    def supports_streaming_tts(self) -> bool:
        """Azure supports real-time streaming TTS via Speech SDK."""
        return True

    async def create_streaming_transcriber(
        self, _audio_format: str = "webm"
    ) -> AzureStreamingTranscriber:
        """Create a streaming transcription session.

        The audio format argument is ignored: the SDK path consumes raw
        PCM16 at 24 kHz (frontend rate) resampled to 16 kHz.

        Raises:
            ValueError: If the API key or speech region is missing.
        """
        if not self.api_key:
            raise ValueError("API key required for streaming transcription")
        if not self.speech_region:
            raise ValueError("Speech region required for Azure streaming transcription")
        transcriber = AzureStreamingTranscriber(
            api_key=self.api_key,
            region=self.speech_region,
            input_sample_rate=24000,
            target_sample_rate=16000,
        )
        await transcriber.connect()
        return transcriber

    async def create_streaming_synthesizer(
        self, voice: str | None = None, speed: float = 1.0
    ) -> AzureStreamingSynthesizer:
        """Create a streaming TTS session.

        Args:
            voice: Azure voice name; falls back to the provider default.
            speed: Accepted for interface compatibility but ignored.

        Raises:
            ValueError: If the API key or speech region is missing.
        """
        _ = speed  # Azure SDK streaming path does not currently support runtime speed control.
        if not self.api_key:
            raise ValueError("API key required for streaming TTS")
        if not self.speech_region:
            raise ValueError("Speech region required for Azure streaming TTS")
        synthesizer = AzureStreamingSynthesizer(
            api_key=self.api_key,
            region=self.speech_region,
            voice=voice or self.default_voice or "en-US-JennyNeural",
        )
        await synthesizer.connect()
        return synthesizer
|
||||
751
backend/onyx/voice/providers/elevenlabs.py
Normal file
751
backend/onyx/voice/providers/elevenlabs.py
Normal file
@@ -0,0 +1,751 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
# Common ElevenLabs voices
# Static list of voice-ID/name pairs returned by
# ElevenLabsVoiceProvider.get_available_voices(). The full voice library is
# account-specific and would require an API call to enumerate, so this serves
# as a known-good default set.
ELEVENLABS_VOICES = [
    {"id": "21m00Tcm4TlvDq8ikWAM", "name": "Rachel"},
    {"id": "AZnzlk1XvdvUeBnXmlld", "name": "Domi"},
    {"id": "EXAVITQu4vr4xnSDxMaL", "name": "Bella"},
    {"id": "ErXwobaYiN019PkySvjV", "name": "Antoni"},
    {"id": "MF3mGyEYCl7XYWbV9V6O", "name": "Elli"},
    {"id": "TxGEqnHWrfWFTfGW9XjX", "name": "Josh"},
    {"id": "VR6AewLTigWG4xSOukaG", "name": "Arnold"},
    {"id": "pNInz6obpgDQGcFmaJgB", "name": "Adam"},
    {"id": "yoZ06aMxZJJ28mfd3POQ", "name": "Sam"},
]
|
||||
|
||||
|
||||
class ElevenLabsStreamingTranscriber(StreamingTranscriberProtocol):
    """Streaming transcription session using ElevenLabs Scribe Realtime API.

    Lifecycle: construct -> ``connect()`` -> repeatedly ``send_audio()`` /
    ``receive_transcript()`` -> ``close()``. A background task reads the
    WebSocket and pushes ``TranscriptResult`` items onto an internal queue;
    a ``None`` item on the queue signals end-of-stream.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "scribe_v2_realtime",
        input_sample_rate: int = 24000,  # What frontend sends
        target_sample_rate: int = 16000,  # What ElevenLabs expects
        language_code: str = "en",
    ):
        """Store session settings; no network I/O happens until connect().

        Args:
            api_key: ElevenLabs API key.
            model: Realtime STT model ID.
            input_sample_rate: Sample rate of PCM16 chunks fed to send_audio.
            target_sample_rate: Sample rate expected by the API; chunks are
                resampled from input to target before sending.
            language_code: BCP-47-style language hint for recognition.
        """
        # Import logger first
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()

        self._logger.info(
            f"ElevenLabsStreamingTranscriber: initializing with model {model}"
        )
        self.api_key = api_key
        self.model = model
        self.input_sample_rate = input_sample_rate
        self.target_sample_rate = target_sample_rate
        self.language_code = language_code
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._session: aiohttp.ClientSession | None = None
        # Queue of transcripts from the receive loop; None = end of stream.
        self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
        # Most recent transcript text, returned by close().
        self._final_transcript = ""
        self._receive_task: asyncio.Task | None = None
        self._closed = False

    async def connect(self) -> None:
        """Establish WebSocket connection to ElevenLabs.

        Raises whatever ``ws_connect`` raises on failure, after closing the
        client session. On success, starts the background receive loop.
        """
        self._logger.info(
            "ElevenLabsStreamingTranscriber: connecting to ElevenLabs API"
        )
        self._session = aiohttp.ClientSession()

        # VAD is configured via query parameters
        # commit_strategy=vad enables automatic transcript commit on silence detection
        url = (
            f"wss://api.elevenlabs.io/v1/speech-to-text/realtime"
            f"?model_id={self.model}"
            f"&sample_rate={self.target_sample_rate}"
            f"&language_code={self.language_code}"
            f"&commit_strategy=vad"
            f"&vad_silence_threshold_secs=1.0"
            f"&vad_threshold=0.4"
            f"&min_speech_duration_ms=100"
            f"&min_silence_duration_ms=300"
        )
        self._logger.info(
            f"ElevenLabsStreamingTranscriber: connecting to {url} "
            f"(input={self.input_sample_rate}Hz, target={self.target_sample_rate}Hz)"
        )

        try:
            self._ws = await self._session.ws_connect(
                url,
                headers={"xi-api-key": self.api_key},
            )
            self._logger.info(
                f"ElevenLabsStreamingTranscriber: connected successfully, "
                f"ws.closed={self._ws.closed}, close_code={self._ws.close_code}"
            )
        except Exception as e:
            self._logger.error(
                f"ElevenLabsStreamingTranscriber: failed to connect: {e}"
            )
            # Clean up the session so it doesn't leak on connection failure.
            if self._session:
                await self._session.close()
            raise

        # Start receiving transcripts in background
        self._receive_task = asyncio.create_task(self._receive_loop())

    async def _receive_loop(self) -> None:
        """Background task to receive transcripts from WebSocket.

        Translates ElevenLabs message types into TranscriptResult items on
        the queue: partial_transcript -> is_vad_end=False,
        committed_transcript / utterance_end -> is_vad_end=True. Always
        enqueues a final ``None`` sentinel when the loop ends.
        """
        self._logger.info("ElevenLabsStreamingTranscriber: receive loop started")
        if not self._ws:
            self._logger.warning(
                "ElevenLabsStreamingTranscriber: no WebSocket connection"
            )
            return

        try:
            async for msg in self._ws:
                self._logger.debug(
                    f"ElevenLabsStreamingTranscriber: raw message type: {msg.type}"
                )
                if msg.type == aiohttp.WSMsgType.TEXT:
                    try:
                        data = json.loads(msg.data)
                    except json.JSONDecodeError:
                        self._logger.error(
                            f"ElevenLabsStreamingTranscriber: failed to parse JSON: {msg.data[:200]}"
                        )
                        continue

                    # ElevenLabs uses message_type field
                    msg_type = data.get("message_type", data.get("type", ""))  # type: ignore[possibly-undefined]
                    self._logger.info(
                        f"ElevenLabsStreamingTranscriber: received message_type: '{msg_type}', data keys: {list(data.keys())}"
                    )
                    # Check for error in various formats
                    if "error" in data or msg_type == "error":
                        error_msg = data.get("error", data.get("message", data))
                        self._logger.error(
                            f"ElevenLabsStreamingTranscriber: API error: {error_msg}"
                        )
                        continue

                    # Handle different message types from ElevenLabs Scribe API
                    if msg_type == "session_started":
                        # Session started successfully
                        self._logger.info(
                            f"ElevenLabsStreamingTranscriber: session started, "
                            f"id={data.get('session_id')}, config={data.get('config')}"
                        )
                    elif msg_type == "partial_transcript":
                        # Partial transcript (interim result)
                        text = data.get("text", "")
                        if text:
                            self._logger.info(
                                f"ElevenLabsStreamingTranscriber: partial_transcript: {text[:50]}..."
                            )
                            self._final_transcript = text
                            await self._transcript_queue.put(
                                TranscriptResult(text=text, is_vad_end=False)
                            )
                    elif msg_type == "committed_transcript":
                        # Final/committed transcript (VAD detected end of utterance)
                        text = data.get("text", "")
                        if text:
                            self._logger.info(
                                f"ElevenLabsStreamingTranscriber: committed_transcript: {text[:50]}..."
                            )
                            self._final_transcript = text
                            await self._transcript_queue.put(
                                TranscriptResult(text=text, is_vad_end=True)
                            )
                    elif msg_type == "utterance_end":
                        # VAD detected end of speech
                        text = data.get("text", "") or self._final_transcript
                        if text:
                            self._logger.info(
                                f"ElevenLabsStreamingTranscriber: utterance_end: {text[:50]}..."
                            )
                            self._final_transcript = text
                            await self._transcript_queue.put(
                                TranscriptResult(text=text, is_vad_end=True)
                            )
                    elif msg_type == "session_ended":
                        self._logger.info(
                            "ElevenLabsStreamingTranscriber: session ended"
                        )
                        break
                    else:
                        # Log unhandled message types with full data for debugging
                        self._logger.warning(
                            f"ElevenLabsStreamingTranscriber: unhandled message_type: {msg_type}, full data: {data}"
                        )
                elif msg.type == aiohttp.WSMsgType.BINARY:
                    self._logger.debug(
                        f"ElevenLabsStreamingTranscriber: received binary message: {len(msg.data)} bytes"
                    )
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    close_code = self._ws.close_code if self._ws else "N/A"
                    self._logger.info(
                        "ElevenLabsStreamingTranscriber: WebSocket closed by "
                        f"server, close_code={close_code}"
                    )
                    break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    self._logger.error(
                        f"ElevenLabsStreamingTranscriber: WebSocket error: {self._ws.exception() if self._ws else 'N/A'}"
                    )
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSE:
                    self._logger.info(
                        f"ElevenLabsStreamingTranscriber: WebSocket CLOSE frame received, data={msg.data}, extra={msg.extra}"
                    )
                    break
        except Exception as e:
            self._logger.error(
                f"ElevenLabsStreamingTranscriber: error in receive loop: {e}",
                exc_info=True,
            )
        finally:
            close_code = self._ws.close_code if self._ws else "N/A"
            self._logger.info(
                f"ElevenLabsStreamingTranscriber: receive loop ended, close_code={close_code}"
            )
            await self._transcript_queue.put(None)  # Signal end

    def _resample_pcm16(self, data: bytes) -> bytes:
        """Resample PCM16 audio from input_sample_rate to target_sample_rate.

        Uses simple linear interpolation between neighboring samples; results
        are clamped to the int16 range. Returns the input unchanged when the
        two rates are equal. Assumes little-endian mono PCM16 input.
        """
        import struct

        if self.input_sample_rate == self.target_sample_rate:
            return data

        # Parse int16 samples
        num_samples = len(data) // 2
        samples = list(struct.unpack(f"<{num_samples}h", data))

        # Calculate resampling ratio
        ratio = self.input_sample_rate / self.target_sample_rate
        new_length = int(num_samples / ratio)

        # Linear interpolation resampling
        resampled = []
        for i in range(new_length):
            src_idx = i * ratio
            idx_floor = int(src_idx)
            idx_ceil = min(idx_floor + 1, num_samples - 1)
            frac = src_idx - idx_floor
            sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
            # Clamp to int16 range
            sample = max(-32768, min(32767, sample))
            resampled.append(sample)

        return struct.pack(f"<{len(resampled)}h", *resampled)

    async def send_audio(self, chunk: bytes) -> None:
        """Send an audio chunk for transcription.

        Silently drops the chunk (with a warning) when the session is not
        connected, already closed, or the socket has been closed server-side.
        Errors during send are logged rather than raised.
        """
        if not self._ws:
            self._logger.warning("send_audio: no WebSocket connection")
            return
        if self._closed:
            self._logger.warning("send_audio: transcriber is closed")
            return
        if self._ws.closed:
            self._logger.warning(
                f"send_audio: WebSocket is closed, close_code={self._ws.close_code}"
            )
            return

        try:
            # Resample from input rate (24kHz) to target rate (16kHz)
            resampled = self._resample_pcm16(chunk)
            # ElevenLabs expects input_audio_chunk message format with audio_base_64
            audio_b64 = base64.b64encode(resampled).decode("utf-8")
            message = {
                "message_type": "input_audio_chunk",
                "audio_base_64": audio_b64,
                "sample_rate": self.target_sample_rate,
            }
            self._logger.info(
                f"send_audio: {len(chunk)} bytes -> {len(resampled)} bytes (resampled) -> {len(audio_b64)} chars base64"
            )
            await self._ws.send_str(json.dumps(message))
            self._logger.info("send_audio: message sent successfully")
        except Exception as e:
            self._logger.error(f"send_audio: failed to send: {e}", exc_info=True)

    async def receive_transcript(self) -> TranscriptResult | None:
        """Receive next transcript. Returns None when done.

        Polls the internal queue with a 100 ms timeout; on timeout returns
        an empty TranscriptResult so the caller's loop can keep spinning
        without blocking indefinitely.
        """
        try:
            return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return TranscriptResult(
                text="", is_vad_end=False
            )  # No transcript yet, but not done

    async def close(self) -> str:
        """Close the session and return final transcript.

        Closes the WebSocket, cancels the receive loop, and closes the
        client session. Idempotent-safe: each step is guarded.
        """
        self._logger.info("ElevenLabsStreamingTranscriber: closing session")
        self._closed = True
        if self._ws and not self._ws.closed:
            try:
                # Just close the WebSocket - ElevenLabs Scribe doesn't need a special end message
                self._logger.info(
                    "ElevenLabsStreamingTranscriber: closing WebSocket connection"
                )
                await self._ws.close()
            except Exception as e:
                self._logger.debug(f"Error closing WebSocket: {e}")
        if self._receive_task and not self._receive_task.done():
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass
        if self._session and not self._session.closed:
            await self._session.close()
        return self._final_transcript

    def reset_transcript(self) -> None:
        """Reset accumulated transcript. Call after auto-send to start fresh."""
        self._final_transcript = ""
|
||||
|
||||
|
||||
class ElevenLabsStreamingSynthesizer(StreamingSynthesizerProtocol):
    """Real-time streaming TTS using ElevenLabs WebSocket API.

    Uses ElevenLabs' stream-input WebSocket which processes text as one
    continuous stream and returns audio in order.

    Lifecycle: construct -> ``connect()`` -> ``send_text()`` as text
    arrives -> ``flush()`` to force out remaining audio -> ``close()``.
    A background task drains the WebSocket into an internal queue; a
    ``None`` item on the queue signals end of audio.
    """

    def __init__(
        self,
        api_key: str,
        voice_id: str,
        model_id: str = "eleven_multilingual_v2",
        output_format: str = "mp3_44100_64",
    ):
        """Store session settings; no network I/O happens until connect()."""
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()
        self.api_key = api_key
        self.voice_id = voice_id
        self.model_id = model_id
        self.output_format = output_format
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._session: aiohttp.ClientSession | None = None
        # Audio chunks from the receive loop; None = end of stream.
        self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
        self._receive_task: asyncio.Task | None = None
        self._closed = False

    async def connect(self) -> None:
        """Establish WebSocket connection to ElevenLabs TTS.

        Sends the initial configuration message (voice settings plus a
        chunk-length schedule tuned for low-latency streaming) and starts
        the background audio receive loop.
        """
        self._logger.info("ElevenLabsStreamingSynthesizer: connecting")
        self._session = aiohttp.ClientSession()

        # WebSocket URL for streaming input TTS with output format for streaming compatibility
        # Using mp3_44100_64 for good quality with smaller chunks for real-time playback
        url = (
            f"wss://api.elevenlabs.io/v1/text-to-speech/{self.voice_id}/stream-input"
            f"?model_id={self.model_id}&output_format={self.output_format}"
        )

        self._ws = await self._session.ws_connect(
            url,
            headers={"xi-api-key": self.api_key},
        )

        # Send initial configuration with generation settings optimized for streaming
        await self._ws.send_str(
            json.dumps(
                {
                    "text": " ",  # Initial space to start the stream
                    "voice_settings": {
                        "stability": 0.5,
                        "similarity_boost": 0.75,
                    },
                    "generation_config": {
                        "chunk_length_schedule": [
                            120,
                            160,
                            250,
                            290,
                        ],  # Optimized chunk sizes for streaming
                    },
                    "xi_api_key": self.api_key,
                }
            )
        )

        # Start receiving audio in background
        self._receive_task = asyncio.create_task(self._receive_loop())
        self._logger.info("ElevenLabsStreamingSynthesizer: connected")

    async def _receive_loop(self) -> None:
        """Background task to receive audio chunks from WebSocket.

        Audio is returned in order as one continuous stream.

        Decodes base64 "audio" payloads from TEXT frames (BINARY frames are
        enqueued as-is), and enqueues a ``None`` sentinel when the server
        sends isFinal or the socket closes/errors.
        """
        if not self._ws:
            return

        chunk_count = 0
        total_bytes = 0
        try:
            async for msg in self._ws:
                if self._closed:
                    self._logger.info(
                        "ElevenLabsStreamingSynthesizer: closed flag set, stopping "
                        "receive loop"
                    )
                    break
                if msg.type == aiohttp.WSMsgType.TEXT:
                    data = json.loads(msg.data)
                    # Process audio if present
                    if "audio" in data and data["audio"]:
                        audio_bytes = base64.b64decode(data["audio"])
                        chunk_count += 1
                        total_bytes += len(audio_bytes)
                        await self._audio_queue.put(audio_bytes)

                    # Check isFinal separately - a message can have both audio AND isFinal
                    if "isFinal" in data:
                        self._logger.info(
                            f"ElevenLabsStreamingSynthesizer: received isFinal={data['isFinal']}, "
                            f"chunks so far: {chunk_count}, bytes: {total_bytes}"
                        )
                        if data.get("isFinal"):
                            self._logger.info(
                                "ElevenLabsStreamingSynthesizer: isFinal=true, signaling end of audio"
                            )
                            await self._audio_queue.put(None)

                    # Check for errors
                    if "error" in data or data.get("type") == "error":
                        self._logger.error(
                            f"ElevenLabsStreamingSynthesizer: received error: {data}"
                        )
                elif msg.type == aiohttp.WSMsgType.BINARY:
                    chunk_count += 1
                    total_bytes += len(msg.data)
                    await self._audio_queue.put(msg.data)
                elif msg.type in (
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.ERROR,
                ):
                    self._logger.info(
                        f"ElevenLabsStreamingSynthesizer: WebSocket closed/error, type={msg.type}"
                    )
                    break
        except Exception as e:
            self._logger.error(f"ElevenLabsStreamingSynthesizer receive error: {e}")
        finally:
            self._logger.info(
                f"ElevenLabsStreamingSynthesizer: receive loop ended, {chunk_count} chunks, {total_bytes} bytes"
            )
            await self._audio_queue.put(None)  # Signal end of stream

    async def send_text(self, text: str) -> None:
        """Send text to be synthesized.

        ElevenLabs processes text as a continuous stream and returns
        audio in order. We let ElevenLabs handle buffering via chunk_length_schedule
        and only force generation when flush() is called at the end.

        Args:
            text: Text to synthesize
        """
        if self._ws and not self._closed and text.strip():
            self._logger.info(
                f"ElevenLabsStreamingSynthesizer: sending text ({len(text)} chars): '{text}'"
            )
            # Let ElevenLabs buffer and auto-generate based on chunk_length_schedule
            # Don't trigger generation here - wait for flush() at the end
            await self._ws.send_str(
                json.dumps(
                    {
                        "text": text + " ",  # Space for natural speech flow
                    }
                )
            )
            self._logger.info("ElevenLabsStreamingSynthesizer: text sent successfully")
        else:
            self._logger.warning(
                f"ElevenLabsStreamingSynthesizer: skipping send_text - "
                f"ws={self._ws is not None}, closed={self._closed}, text='{text[:30] if text else ''}'"
            )

    async def receive_audio(self) -> bytes | None:
        """Receive next audio chunk.

        Polls the internal queue with a 100 ms timeout; on timeout returns
        b"" so the caller can distinguish "nothing yet" from the ``None``
        end-of-stream sentinel.
        """
        try:
            return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return b""  # No audio yet, but not done

    async def flush(self) -> None:
        """Signal end of text input. ElevenLabs will generate remaining audio and close."""
        if self._ws and not self._closed:
            # Send empty string to signal end of input
            # ElevenLabs will generate any remaining buffered text,
            # send all audio chunks, send isFinal, then close the connection
            self._logger.info(
                "ElevenLabsStreamingSynthesizer: sending end-of-input (empty string)"
            )
            await self._ws.send_str(json.dumps({"text": ""}))
            self._logger.info("ElevenLabsStreamingSynthesizer: end-of-input sent")
        else:
            self._logger.warning(
                f"ElevenLabsStreamingSynthesizer: skipping flush - "
                f"ws={self._ws is not None}, closed={self._closed}"
            )

    async def close(self) -> None:
        """Close the session.

        Closes the WebSocket, cancels the receive loop, and closes the
        client session.
        """
        self._closed = True
        if self._ws:
            await self._ws.close()
        if self._receive_task:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass
        if self._session:
            await self._session.close()
|
||||
|
||||
|
||||
# Valid ElevenLabs model IDs
# ElevenLabsVoiceProvider.__init__ validates configured model IDs against
# these sets and falls back to safe defaults when the value is unknown.
ELEVENLABS_STT_MODELS = {"scribe_v1", "scribe_v2_realtime"}
ELEVENLABS_TTS_MODELS = {
    "eleven_multilingual_v2",
    "eleven_turbo_v2_5",
    "eleven_monolingual_v1",
    "eleven_flash_v2_5",
    "eleven_flash_v2",
}
|
||||
|
||||
|
||||
class ElevenLabsVoiceProvider(VoiceProviderInterface):
    """ElevenLabs voice provider.

    Supports batch transcription (Scribe), streaming transcription
    (Scribe Realtime over WebSocket), HTTP streaming TTS, and real-time
    WebSocket TTS.
    """

    def __init__(
        self,
        api_key: str | None,
        api_base: str | None = None,
        stt_model: str | None = None,
        tts_model: str | None = None,
        default_voice: str | None = None,
    ):
        """Store configuration, validating model IDs against the known sets.

        Args:
            api_key: ElevenLabs API key (required for any API call).
            api_base: Optional API base URL override.
            stt_model: STT model ID; unknown/None falls back to "scribe_v1".
            tts_model: TTS model ID; unknown/None falls back to
                "eleven_multilingual_v2".
            default_voice: Default voice ID; None falls back to Rachel at
                call time.
        """
        self.api_key = api_key
        self.api_base = api_base or "https://api.elevenlabs.io"
        # Validate and default models - silently fall back to known-good
        # defaults rather than failing at construction time.
        self.stt_model = (
            stt_model if stt_model in ELEVENLABS_STT_MODELS else "scribe_v1"
        )
        self.tts_model = (
            tts_model
            if tts_model in ELEVENLABS_TTS_MODELS
            else "eleven_multilingual_v2"
        )
        self.default_voice = default_voice

    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        """
        Transcribe audio using ElevenLabs Speech-to-Text API.

        Args:
            audio_data: Raw audio bytes
            audio_format: Format of the audio (e.g., 'webm', 'mp3', 'wav')

        Returns:
            Transcribed text

        Raises:
            ValueError: If no API key is configured.
            RuntimeError: If the API returns a non-200 response.
        """
        if not self.api_key:
            raise ValueError("ElevenLabs API key required for transcription")

        from onyx.utils.logger import setup_logger

        logger = setup_logger()

        url = f"{self.api_base}/v1/speech-to-text"

        # Map common formats to MIME types
        mime_types = {
            "webm": "audio/webm",
            "mp3": "audio/mpeg",
            "wav": "audio/wav",
            "ogg": "audio/ogg",
            "flac": "audio/flac",
            "m4a": "audio/mp4",
        }
        mime_type = mime_types.get(audio_format.lower(), f"audio/{audio_format}")

        headers = {
            "xi-api-key": self.api_key,
        }

        # ElevenLabs expects multipart form data
        form_data = aiohttp.FormData()
        form_data.add_field(
            "audio",
            audio_data,
            filename=f"audio.{audio_format}",
            content_type=mime_type,
        )
        # Batch STT only supports scribe_v1 — scribe_v2_realtime is
        # WebSocket-only — so always use scribe_v1 here regardless of the
        # configured streaming model. (The previous conditional expression
        # here always evaluated to "scribe_v1" anyway.)
        form_data.add_field("model_id", "scribe_v1")

        logger.info(
            f"ElevenLabs transcribe: sending {len(audio_data)} bytes, format={audio_format}"
        )

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, data=form_data) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"ElevenLabs transcribe failed: {error_text}")
                    raise RuntimeError(f"ElevenLabs transcription failed: {error_text}")

                result = await response.json()
                text = result.get("text", "")
                logger.info(f"ElevenLabs transcribe: got result: {text[:50]}...")
                return text

    async def synthesize_stream(
        self, text: str, voice: str | None = None, _speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio using ElevenLabs TTS with streaming.

        Args:
            text: Text to convert to speech
            voice: Voice ID (defaults to provider's default voice or Rachel)
            _speed: Playback speed multiplier — not supported by this
                endpoint and therefore ignored (kept for interface parity)

        Yields:
            Audio data chunks (mp3 format)

        Raises:
            ValueError: If no API key is configured.
            RuntimeError: If the API returns a non-200 response.
        """
        from onyx.utils.logger import setup_logger

        logger = setup_logger()

        if not self.api_key:
            raise ValueError("ElevenLabs API key required for TTS")

        voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM"  # Rachel

        url = f"{self.api_base}/v1/text-to-speech/{voice_id}/stream"

        logger.info(
            f"ElevenLabs TTS: starting synthesis, text='{text[:50]}...', "
            f"voice={voice_id}, model={self.tts_model}"
        )

        headers = {
            "xi-api-key": self.api_key,
            "Content-Type": "application/json",
            "Accept": "audio/mpeg",
        }

        payload = {
            "text": text,
            "model_id": self.tts_model,
            "voice_settings": {
                "stability": 0.5,
                "similarity_boost": 0.75,
            },
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, json=payload) as response:
                logger.info(
                    f"ElevenLabs TTS: got response status={response.status}, "
                    f"content-type={response.headers.get('content-type')}"
                )
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"ElevenLabs TTS failed: {error_text}")
                    raise RuntimeError(f"ElevenLabs TTS failed: {error_text}")

                # Use 8192 byte chunks for smoother streaming
                chunk_count = 0
                total_bytes = 0
                async for chunk in response.content.iter_chunked(8192):
                    if chunk:
                        chunk_count += 1
                        total_bytes += len(chunk)
                        yield chunk
                logger.info(
                    f"ElevenLabs TTS: streaming complete, {chunk_count} chunks, "
                    f"{total_bytes} total bytes"
                )

    def get_available_voices(self) -> list[dict[str, str]]:
        """Return common ElevenLabs voices."""
        # Copy so callers cannot mutate the module-level list.
        return ELEVENLABS_VOICES.copy()

    def get_available_stt_models(self) -> list[dict[str, str]]:
        """Return selectable STT models (streaming and batch variants)."""
        return [
            {"id": "scribe_v2_realtime", "name": "Scribe v2 Realtime (Streaming)"},
            {"id": "scribe_v1", "name": "Scribe v1 (Batch)"},
        ]

    def get_available_tts_models(self) -> list[dict[str, str]]:
        """Return selectable TTS models.

        Kept in sync with ELEVENLABS_TTS_MODELS so that every selectable
        model passes __init__ validation (the flash models were previously
        accepted by validation but not offered here).
        """
        return [
            {"id": "eleven_multilingual_v2", "name": "Multilingual v2"},
            {"id": "eleven_turbo_v2_5", "name": "Turbo v2.5"},
            {"id": "eleven_monolingual_v1", "name": "Monolingual v1"},
            {"id": "eleven_flash_v2_5", "name": "Flash v2.5"},
            {"id": "eleven_flash_v2", "name": "Flash v2"},
        ]

    def supports_streaming_stt(self) -> bool:
        """ElevenLabs supports streaming via Scribe Realtime API."""
        return True

    def supports_streaming_tts(self) -> bool:
        """ElevenLabs supports real-time streaming TTS via WebSocket."""
        return True

    async def create_streaming_transcriber(
        self, _audio_format: str = "webm"
    ) -> ElevenLabsStreamingTranscriber:
        """Create a streaming transcription session.

        The audio format argument is ignored: the realtime path consumes raw
        PCM16 at 24 kHz (frontend rate), resampled to 16 kHz by the
        transcriber.

        Raises:
            ValueError: If no API key is configured.
        """
        if not self.api_key:
            raise ValueError("API key required for streaming transcription")
        # ElevenLabs realtime STT requires scribe_v2_realtime model
        # Frontend sends PCM16 at 24kHz, but ElevenLabs expects 16kHz
        # The transcriber will resample automatically
        transcriber = ElevenLabsStreamingTranscriber(
            api_key=self.api_key,
            model="scribe_v2_realtime",
            input_sample_rate=24000,  # What frontend sends
            target_sample_rate=16000,  # What ElevenLabs expects
            language_code="en",
        )
        await transcriber.connect()
        return transcriber

    async def create_streaming_synthesizer(
        self, voice: str | None = None, _speed: float = 1.0
    ) -> ElevenLabsStreamingSynthesizer:
        """Create a streaming TTS session.

        Args:
            voice: Voice ID; falls back to the default voice, then Rachel.
            _speed: Accepted for interface compatibility but ignored.

        Raises:
            ValueError: If no API key is configured.
        """
        if not self.api_key:
            raise ValueError("API key required for streaming TTS")
        voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM"
        synthesizer = ElevenLabsStreamingSynthesizer(
            api_key=self.api_key,
            voice_id=voice_id,
            model_id=self.tts_model,
            # Use mp3_44100_64 for streaming - good balance of quality and chunk size
            output_format="mp3_44100_64",
        )
        await synthesizer.connect()
        return synthesizer
|
||||
567
backend/onyx/voice/providers/openai.py
Normal file
567
backend/onyx/voice/providers/openai.py
Normal file
@@ -0,0 +1,567 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
class OpenAIStreamingTranscriber(StreamingTranscriberProtocol):
|
||||
"""Streaming transcription using OpenAI Realtime API."""
|
||||
|
||||
def __init__(self, api_key: str, model: str = "whisper-1"):
|
||||
# Import logger first
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
self._logger = setup_logger()
|
||||
|
||||
self._logger.info(
|
||||
f"OpenAIStreamingTranscriber: initializing with model {model}"
|
||||
)
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
|
||||
self._current_turn_transcript = "" # Transcript for current VAD turn
|
||||
self._accumulated_transcript = "" # Accumulated across all turns
|
||||
self._receive_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish WebSocket connection to OpenAI Realtime API."""
|
||||
self._session = aiohttp.ClientSession()
|
||||
|
||||
# OpenAI Realtime transcription endpoint
|
||||
url = "wss://api.openai.com/v1/realtime?intent=transcription"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"OpenAI-Beta": "realtime=v1",
|
||||
}
|
||||
|
||||
try:
|
||||
self._ws = await self._session.ws_connect(url, headers=headers)
|
||||
self._logger.info("Connected to OpenAI Realtime API")
|
||||
except Exception as e:
|
||||
self._logger.error(f"Failed to connect to OpenAI Realtime API: {e}")
|
||||
raise
|
||||
|
||||
# Configure the session for transcription
|
||||
# Enable server-side VAD (Voice Activity Detection) for automatic speech detection
|
||||
config_message = {
|
||||
"type": "transcription_session.update",
|
||||
"session": {
|
||||
"input_audio_format": "pcm16", # 16-bit PCM at 24kHz mono
|
||||
"input_audio_transcription": {
|
||||
"model": self.model,
|
||||
},
|
||||
"turn_detection": {
|
||||
"type": "server_vad",
|
||||
"threshold": 0.5,
|
||||
"prefix_padding_ms": 300,
|
||||
"silence_duration_ms": 500,
|
||||
},
|
||||
},
|
||||
}
|
||||
await self._ws.send_str(json.dumps(config_message))
|
||||
self._logger.info(f"Sent config for model: {self.model} with server VAD")
|
||||
|
||||
# Start receiving transcripts
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
"""Background task to receive transcripts."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
data = json.loads(msg.data)
|
||||
msg_type = data.get("type", "")
|
||||
self._logger.debug(f"Received message type: {msg_type}")
|
||||
|
||||
# Handle errors
|
||||
if msg_type == "error":
|
||||
error = data.get("error", {})
|
||||
self._logger.error(f"OpenAI error: {error}")
|
||||
continue
|
||||
|
||||
# Handle VAD events
|
||||
if msg_type == "input_audio_buffer.speech_started":
|
||||
self._logger.info("OpenAI: Speech started")
|
||||
# Reset current turn transcript for new speech
|
||||
self._current_turn_transcript = ""
|
||||
continue
|
||||
elif msg_type == "input_audio_buffer.speech_stopped":
|
||||
self._logger.info(
|
||||
"OpenAI: Speech stopped (VAD detected silence)"
|
||||
)
|
||||
continue
|
||||
elif msg_type == "input_audio_buffer.committed":
|
||||
self._logger.info("OpenAI: Audio buffer committed")
|
||||
continue
|
||||
|
||||
# Handle transcription events
|
||||
if msg_type == "conversation.item.input_audio_transcription.delta":
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
self._logger.info(f"OpenAI: Transcription delta: {delta}")
|
||||
self._current_turn_transcript += delta
|
||||
# Show accumulated + current turn transcript
|
||||
full_transcript = self._accumulated_transcript
|
||||
if full_transcript and self._current_turn_transcript:
|
||||
full_transcript += " "
|
||||
full_transcript += self._current_turn_transcript
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(text=full_transcript, is_vad_end=False)
|
||||
)
|
||||
elif (
|
||||
msg_type
|
||||
== "conversation.item.input_audio_transcription.completed"
|
||||
):
|
||||
transcript = data.get("transcript", "")
|
||||
if transcript:
|
||||
self._logger.info(
|
||||
f"OpenAI: Transcription completed (VAD turn end): {transcript[:50]}..."
|
||||
)
|
||||
# This is the final transcript for this VAD turn
|
||||
self._current_turn_transcript = transcript
|
||||
# Accumulate this turn's transcript
|
||||
if self._accumulated_transcript:
|
||||
self._accumulated_transcript += " " + transcript
|
||||
else:
|
||||
self._accumulated_transcript = transcript
|
||||
# Send with is_vad_end=True to trigger auto-send
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(
|
||||
text=self._accumulated_transcript,
|
||||
is_vad_end=True,
|
||||
)
|
||||
)
|
||||
elif msg_type not in (
|
||||
"transcription_session.created",
|
||||
"transcription_session.updated",
|
||||
"conversation.item.created",
|
||||
):
|
||||
# Log any other message types we might be missing
|
||||
self._logger.info(
|
||||
f"OpenAI: Unhandled message type '{msg_type}': {data}"
|
||||
)
|
||||
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
self._logger.error(f"WebSocket error: {self._ws.exception()}")
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
self._logger.info("WebSocket closed by server")
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error in receive loop: {e}")
|
||||
finally:
|
||||
await self._transcript_queue.put(None)
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send audio chunk to OpenAI."""
|
||||
if self._ws and not self._closed:
|
||||
# OpenAI expects base64-encoded PCM16 audio at 24kHz mono
|
||||
# PCM16 at 24kHz: 24000 samples/sec * 2 bytes/sample = 48000 bytes/sec
|
||||
# So chunk_bytes / 48000 = duration in seconds
|
||||
duration_ms = (len(chunk) / 48000) * 1000
|
||||
self._logger.debug(
|
||||
f"Sending {len(chunk)} bytes ({duration_ms:.1f}ms) of audio to OpenAI. "
|
||||
f"First 10 bytes: {chunk[:10].hex() if len(chunk) >= 10 else chunk.hex()}"
|
||||
)
|
||||
message = {
|
||||
"type": "input_audio_buffer.append",
|
||||
"audio": base64.b64encode(chunk).decode("utf-8"),
|
||||
}
|
||||
await self._ws.send_str(json.dumps(message))
|
||||
|
||||
def reset_transcript(self) -> None:
|
||||
"""Reset accumulated transcript. Call after auto-send to start fresh."""
|
||||
self._logger.info("OpenAI: Resetting accumulated transcript")
|
||||
self._accumulated_transcript = ""
|
||||
self._current_turn_transcript = ""
|
||||
|
||||
async def receive_transcript(self) -> TranscriptResult | None:
|
||||
"""Receive next transcript."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return TranscriptResult(text="", is_vad_end=False)
|
||||
|
||||
async def close(self) -> str:
|
||||
"""Close session and return final transcript."""
|
||||
self._closed = True
|
||||
if self._ws:
|
||||
# With server VAD, the buffer is auto-committed when speech stops.
|
||||
# But we should still commit any remaining audio and wait for transcription.
|
||||
try:
|
||||
await self._ws.send_str(
|
||||
json.dumps({"type": "input_audio_buffer.commit"})
|
||||
)
|
||||
except Exception as e:
|
||||
self._logger.debug(f"Error sending commit (may be expected): {e}")
|
||||
|
||||
# Wait for transcription to arrive (up to 5 seconds)
|
||||
self._logger.info("Waiting for transcription to complete...")
|
||||
for _ in range(50): # 50 * 100ms = 5 seconds max
|
||||
await asyncio.sleep(0.1)
|
||||
if self._accumulated_transcript:
|
||||
self._logger.info(
|
||||
f"Got final transcript: {self._accumulated_transcript[:50]}..."
|
||||
)
|
||||
break
|
||||
else:
|
||||
self._logger.warning("Timed out waiting for transcription")
|
||||
|
||||
await self._ws.close()
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
return self._accumulated_transcript
|
||||
|
||||
|
||||
# OpenAI available voices for TTS
|
||||
OPENAI_VOICES = [
|
||||
{"id": "alloy", "name": "Alloy"},
|
||||
{"id": "echo", "name": "Echo"},
|
||||
{"id": "fable", "name": "Fable"},
|
||||
{"id": "onyx", "name": "Onyx"},
|
||||
{"id": "nova", "name": "Nova"},
|
||||
{"id": "shimmer", "name": "Shimmer"},
|
||||
]
|
||||
|
||||
# OpenAI available STT models (all support streaming via Realtime API)
|
||||
OPENAI_STT_MODELS = [
|
||||
{"id": "whisper-1", "name": "Whisper v1"},
|
||||
{"id": "gpt-4o-transcribe", "name": "GPT-4o Transcribe"},
|
||||
{"id": "gpt-4o-mini-transcribe", "name": "GPT-4o Mini Transcribe"},
|
||||
]
|
||||
|
||||
# OpenAI available TTS models
|
||||
OPENAI_TTS_MODELS = [
|
||||
{"id": "tts-1", "name": "TTS-1 (Standard)"},
|
||||
{"id": "tts-1-hd", "name": "TTS-1 HD (High Quality)"},
|
||||
]
|
||||
|
||||
|
||||
def _create_wav_header(
|
||||
data_length: int,
|
||||
sample_rate: int = 24000,
|
||||
channels: int = 1,
|
||||
bits_per_sample: int = 16,
|
||||
) -> bytes:
|
||||
"""Create a WAV file header for PCM audio data."""
|
||||
import struct
|
||||
|
||||
byte_rate = sample_rate * channels * bits_per_sample // 8
|
||||
block_align = channels * bits_per_sample // 8
|
||||
|
||||
# WAV header is 44 bytes
|
||||
header = struct.pack(
|
||||
"<4sI4s4sIHHIIHH4sI",
|
||||
b"RIFF", # ChunkID
|
||||
36 + data_length, # ChunkSize
|
||||
b"WAVE", # Format
|
||||
b"fmt ", # Subchunk1ID
|
||||
16, # Subchunk1Size (PCM)
|
||||
1, # AudioFormat (1 = PCM)
|
||||
channels, # NumChannels
|
||||
sample_rate, # SampleRate
|
||||
byte_rate, # ByteRate
|
||||
block_align, # BlockAlign
|
||||
bits_per_sample, # BitsPerSample
|
||||
b"data", # Subchunk2ID
|
||||
data_length, # Subchunk2Size
|
||||
)
|
||||
return header
|
||||
|
||||
|
||||
class OpenAIStreamingSynthesizer(StreamingSynthesizerProtocol):
|
||||
"""Streaming TTS using OpenAI HTTP TTS API with streaming responses."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
voice: str = "alloy",
|
||||
model: str = "tts-1",
|
||||
speed: float = 1.0,
|
||||
):
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
self._logger = setup_logger()
|
||||
self.api_key = api_key
|
||||
self.voice = voice
|
||||
self.model = model
|
||||
self.speed = max(0.25, min(4.0, speed))
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._text_queue: asyncio.Queue[str | None] = asyncio.Queue()
|
||||
self._synthesis_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
self._flushed = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize HTTP session for TTS requests."""
|
||||
self._logger.info("OpenAIStreamingSynthesizer: connecting")
|
||||
self._session = aiohttp.ClientSession()
|
||||
# Start background task to process text queue
|
||||
self._synthesis_task = asyncio.create_task(self._process_text_queue())
|
||||
self._logger.info("OpenAIStreamingSynthesizer: connected")
|
||||
|
||||
async def _process_text_queue(self) -> None:
|
||||
"""Background task to process queued text for synthesis."""
|
||||
while not self._closed:
|
||||
try:
|
||||
text = await asyncio.wait_for(self._text_queue.get(), timeout=0.1)
|
||||
if text is None:
|
||||
break
|
||||
await self._synthesize_text(text)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error processing text queue: {e}")
|
||||
|
||||
async def _synthesize_text(self, text: str) -> None:
|
||||
"""Make HTTP TTS request and stream audio to queue."""
|
||||
if not self._session or self._closed:
|
||||
return
|
||||
|
||||
url = "https://api.openai.com/v1/audio/speech"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"voice": self.voice,
|
||||
"input": text,
|
||||
"speed": self.speed,
|
||||
"response_format": "mp3",
|
||||
}
|
||||
|
||||
try:
|
||||
async with self._session.post(
|
||||
url, headers=headers, json=payload
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
self._logger.error(f"OpenAI TTS error: {error_text}")
|
||||
return
|
||||
|
||||
# Use 8192 byte chunks for smoother streaming
|
||||
# (larger chunks = more complete MP3 frames, better playback)
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
if self._closed:
|
||||
break
|
||||
if chunk:
|
||||
await self._audio_queue.put(chunk)
|
||||
except Exception as e:
|
||||
self._logger.error(f"OpenAIStreamingSynthesizer synthesis error: {e}")
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Queue text to be synthesized via HTTP streaming."""
|
||||
if not text.strip() or self._closed:
|
||||
return
|
||||
await self._text_queue.put(text)
|
||||
|
||||
async def receive_audio(self) -> bytes | None:
|
||||
"""Receive next audio chunk (MP3 format)."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return b"" # No audio yet, but not done
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Signal end of text input - wait for synthesis to complete."""
|
||||
if self._flushed:
|
||||
return
|
||||
self._flushed = True
|
||||
|
||||
# Signal end of text input
|
||||
await self._text_queue.put(None)
|
||||
|
||||
# Wait for synthesis task to complete processing all text
|
||||
if self._synthesis_task and not self._synthesis_task.done():
|
||||
try:
|
||||
await asyncio.wait_for(self._synthesis_task, timeout=60.0)
|
||||
except asyncio.TimeoutError:
|
||||
self._logger.warning("OpenAIStreamingSynthesizer: flush timeout")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Signal end of audio stream
|
||||
await self._audio_queue.put(None)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the session."""
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
|
||||
# Signal end of queues only if flush wasn't already called
|
||||
if not self._flushed:
|
||||
await self._text_queue.put(None)
|
||||
await self._audio_queue.put(None)
|
||||
|
||||
if self._synthesis_task and not self._synthesis_task.done():
|
||||
self._synthesis_task.cancel()
|
||||
try:
|
||||
await self._synthesis_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
|
||||
|
||||
class OpenAIVoiceProvider(VoiceProviderInterface):
|
||||
"""OpenAI voice provider using Whisper for STT and TTS API for speech synthesis."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None,
|
||||
api_base: str | None = None,
|
||||
stt_model: str | None = None,
|
||||
tts_model: str | None = None,
|
||||
default_voice: str | None = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.stt_model = stt_model or "whisper-1"
|
||||
self.tts_model = tts_model or "tts-1"
|
||||
self.default_voice = default_voice or "alloy"
|
||||
|
||||
self._client: "AsyncOpenAI | None" = None
|
||||
|
||||
def _get_client(self) -> "AsyncOpenAI":
|
||||
if self._client is None:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.api_base,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
|
||||
"""
|
||||
Transcribe audio using OpenAI Whisper.
|
||||
|
||||
Args:
|
||||
audio_data: Raw audio bytes
|
||||
audio_format: Audio format (e.g., "webm", "wav", "mp3")
|
||||
|
||||
Returns:
|
||||
Transcribed text
|
||||
"""
|
||||
client = self._get_client()
|
||||
|
||||
# Create a file-like object from the audio bytes
|
||||
audio_file = io.BytesIO(audio_data)
|
||||
audio_file.name = f"audio.{audio_format}"
|
||||
|
||||
response = await client.audio.transcriptions.create(
|
||||
model=self.stt_model,
|
||||
file=audio_file,
|
||||
)
|
||||
|
||||
return response.text
|
||||
|
||||
async def synthesize_stream(
|
||||
self, text: str, voice: str | None = None, speed: float = 1.0
|
||||
) -> AsyncIterator[bytes]:
|
||||
"""
|
||||
Convert text to audio using OpenAI TTS with streaming.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech
|
||||
voice: Voice identifier (defaults to provider's default voice)
|
||||
speed: Playback speed multiplier (0.25 to 4.0)
|
||||
|
||||
Yields:
|
||||
Audio data chunks (mp3 format)
|
||||
"""
|
||||
client = self._get_client()
|
||||
|
||||
# Clamp speed to valid range
|
||||
speed = max(0.25, min(4.0, speed))
|
||||
|
||||
# Use with_streaming_response for proper async streaming
|
||||
# Using 8192 byte chunks for better streaming performance
|
||||
# (larger chunks = fewer round-trips, more complete MP3 frames)
|
||||
async with client.audio.speech.with_streaming_response.create(
|
||||
model=self.tts_model,
|
||||
voice=voice or self.default_voice,
|
||||
input=text,
|
||||
speed=speed,
|
||||
response_format="mp3",
|
||||
) as response:
|
||||
async for chunk in response.iter_bytes(chunk_size=8192):
|
||||
yield chunk
|
||||
|
||||
def get_available_voices(self) -> list[dict[str, str]]:
|
||||
"""Get available OpenAI TTS voices."""
|
||||
return OPENAI_VOICES.copy()
|
||||
|
||||
def get_available_stt_models(self) -> list[dict[str, str]]:
|
||||
"""Get available OpenAI STT models."""
|
||||
return OPENAI_STT_MODELS.copy()
|
||||
|
||||
def get_available_tts_models(self) -> list[dict[str, str]]:
|
||||
"""Get available OpenAI TTS models."""
|
||||
return OPENAI_TTS_MODELS.copy()
|
||||
|
||||
def supports_streaming_stt(self) -> bool:
|
||||
"""OpenAI supports streaming via Realtime API for all STT models."""
|
||||
return True
|
||||
|
||||
def supports_streaming_tts(self) -> bool:
|
||||
"""OpenAI supports real-time streaming TTS via Realtime API."""
|
||||
return True
|
||||
|
||||
async def create_streaming_transcriber(
|
||||
self, _audio_format: str = "webm"
|
||||
) -> OpenAIStreamingTranscriber:
|
||||
"""Create a streaming transcription session using Realtime API."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming transcription")
|
||||
transcriber = OpenAIStreamingTranscriber(
|
||||
api_key=self.api_key,
|
||||
model=self.stt_model,
|
||||
)
|
||||
await transcriber.connect()
|
||||
return transcriber
|
||||
|
||||
async def create_streaming_synthesizer(
|
||||
self, voice: str | None = None, speed: float = 1.0
|
||||
) -> OpenAIStreamingSynthesizer:
|
||||
"""Create a streaming TTS session using HTTP streaming API."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming TTS")
|
||||
synthesizer = OpenAIStreamingSynthesizer(
|
||||
api_key=self.api_key,
|
||||
voice=voice or self.default_voice or "alloy",
|
||||
model=self.tts_model or "tts-1",
|
||||
speed=speed,
|
||||
)
|
||||
await synthesizer.connect()
|
||||
return synthesizer
|
||||
20
web/lib/opal/src/icons/audio.tsx
Normal file
20
web/lib/opal/src/icons/audio.tsx
Normal file
@@ -0,0 +1,20 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgAudio = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 32 32"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M4 20V12M10 28V4M22 22V10M28 18V14M16 20V12"
|
||||
strokeWidth={2.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgAudio;
|
||||
@@ -17,6 +17,7 @@ export { default as SvgArrowUpDown } from "@opal/icons/arrow-up-down";
|
||||
export { default as SvgArrowUpDot } from "@opal/icons/arrow-up-dot";
|
||||
export { default as SvgArrowUpRight } from "@opal/icons/arrow-up-right";
|
||||
export { default as SvgArrowWallRight } from "@opal/icons/arrow-wall-right";
|
||||
export { default as SvgAudio } from "@opal/icons/audio";
|
||||
export { default as SvgAudioEqSmall } from "@opal/icons/audio-eq-small";
|
||||
export { default as SvgAws } from "@opal/icons/aws";
|
||||
export { default as SvgAzure } from "@opal/icons/azure";
|
||||
@@ -105,6 +106,8 @@ export { default as SvgLogOut } from "@opal/icons/log-out";
|
||||
export { default as SvgMaximize2 } from "@opal/icons/maximize-2";
|
||||
export { default as SvgMcp } from "@opal/icons/mcp";
|
||||
export { default as SvgMenu } from "@opal/icons/menu";
|
||||
export { default as SvgMicrophone } from "@opal/icons/microphone";
|
||||
export { default as SvgMicrophoneOff } from "@opal/icons/microphone-off";
|
||||
export { default as SvgMinus } from "@opal/icons/minus";
|
||||
export { default as SvgMinusCircle } from "@opal/icons/minus-circle";
|
||||
export { default as SvgMoon } from "@opal/icons/moon";
|
||||
|
||||
29
web/lib/opal/src/icons/microphone-off.tsx
Normal file
29
web/lib/opal/src/icons/microphone-off.tsx
Normal file
@@ -0,0 +1,29 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgMicrophoneOff = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
{/* Microphone body */}
|
||||
<path
|
||||
d="M12.5 7V7.5C12.5 9.98528 10.4853 12 8 12M3.5 7V7.5C3.5 9.98528 5.51472 12 8 12M8 12V14.5M8 14.5H5M8 14.5H11M8 9.5C6.89543 9.5 6 8.60457 6 7.5V3.5C6 2.39543 6.89543 1.5 8 1.5C9.10457 1.5 10 2.39543 10 3.5V7.5C10 8.60457 9.10457 9.5 8 9.5Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
{/* Diagonal slash */}
|
||||
<path
|
||||
d="M2 2L14 14"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgMicrophoneOff;
|
||||
21
web/lib/opal/src/icons/microphone.tsx
Normal file
21
web/lib/opal/src/icons/microphone.tsx
Normal file
@@ -0,0 +1,21 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgMicrophone = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M12.5 7V7.5C12.5 9.98528 10.4853 12 8 12M3.5 7V7.5C3.5 9.98528 5.51472 12 8 12M8 12V14.5M8 14.5H5M8 14.5H11M8 9.5C6.89543 9.5 6 8.60457 6 7.5V3.5C6 2.39543 6.89543 1.5 8 1.5C9.10457 1.5 10 2.39543 10 3.5V7.5C10 8.60457 9.10457 9.5 8 9.5Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgMicrophone;
|
||||
@@ -58,7 +58,7 @@ const nextConfig = {
|
||||
{
|
||||
key: "Permissions-Policy",
|
||||
value:
|
||||
"accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), cross-origin-isolated=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), navigation-override=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()",
|
||||
"accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), cross-origin-isolated=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(self), midi=(), navigation-override=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()",
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
8
web/package-lock.json
generated
8
web/package-lock.json
generated
@@ -105,6 +105,7 @@
|
||||
"@types/node": "18.15.11",
|
||||
"@types/react": "19.2.10",
|
||||
"@types/react-dom": "19.2.3",
|
||||
"@types/sbd": "^1.0.5",
|
||||
"@types/stats.js": "^0.17.4",
|
||||
"@types/uuid": "^9.0.8",
|
||||
"@typescript/native-preview": "7.0.0-dev.20251222.1",
|
||||
@@ -5543,6 +5544,13 @@
|
||||
"integrity": "sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@types/sbd": {
|
||||
"version": "1.0.5",
|
||||
"resolved": "https://registry.npmjs.org/@types/sbd/-/sbd-1.0.5.tgz",
|
||||
"integrity": "sha512-60PxBBWhg0C3yb5bTP+wwWYGTKMcuB0S6mTEa1sedMC79tYY0Ei7YjU4qsWzGn++lWscLQde16SnElJrf5/aTw==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@types/stack-utils": {
|
||||
"version": "2.0.3",
|
||||
"dev": true,
|
||||
|
||||
@@ -121,6 +121,7 @@
|
||||
"@types/node": "18.15.11",
|
||||
"@types/react": "19.2.10",
|
||||
"@types/react-dom": "19.2.3",
|
||||
"@types/sbd": "^1.0.5",
|
||||
"@types/stats.js": "^0.17.4",
|
||||
"@types/uuid": "^9.0.8",
|
||||
"@typescript/native-preview": "7.0.0-dev.20251222.1",
|
||||
|
||||
4
web/public/ElevenLabs.svg
Normal file
4
web/public/ElevenLabs.svg
Normal file
@@ -0,0 +1,4 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M10.5 2H13V14H10.5V2Z" fill="black"/>
|
||||
<path d="M3 2H5.5V14H3V2Z" fill="black"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 192 B |
@@ -0,0 +1,466 @@
|
||||
"use client";
|
||||
|
||||
import Image from "next/image";
|
||||
import { FunctionComponent, useMemo, useState, useEffect } from "react";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
|
||||
import { Vertical, Horizontal } from "@/layouts/input-layouts";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { SvgArrowExchange, SvgOnyxLogo } from "@opal/icons";
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
interface VoiceOption {
|
||||
value: string;
|
||||
label: string;
|
||||
}
|
||||
|
||||
interface LLMProviderView {
|
||||
id: number;
|
||||
name: string;
|
||||
provider: string;
|
||||
api_key: string | null;
|
||||
}
|
||||
|
||||
interface ApiKeyOption {
|
||||
value: string;
|
||||
label: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
interface VoiceProviderView {
|
||||
id: number;
|
||||
name: string;
|
||||
provider_type: string;
|
||||
is_default_stt: boolean;
|
||||
is_default_tts: boolean;
|
||||
stt_model: string | null;
|
||||
tts_model: string | null;
|
||||
default_voice: string | null;
|
||||
has_api_key: boolean;
|
||||
target_uri: string | null;
|
||||
}
|
||||
|
||||
interface VoiceProviderSetupModalProps {
|
||||
providerType: string;
|
||||
existingProvider: VoiceProviderView | null;
|
||||
mode: "stt" | "tts";
|
||||
defaultModelId?: string | null;
|
||||
onClose: () => void;
|
||||
onSuccess: () => void;
|
||||
}
|
||||
|
||||
const PROVIDER_LABELS: Record<string, string> = {
|
||||
openai: "OpenAI",
|
||||
azure: "Azure Speech Services",
|
||||
elevenlabs: "ElevenLabs",
|
||||
};
|
||||
|
||||
const PROVIDER_API_KEY_URLS: Record<string, string> = {
|
||||
openai: "https://platform.openai.com/api-keys",
|
||||
azure: "https://portal.azure.com/",
|
||||
elevenlabs: "https://elevenlabs.io/app/settings/api-keys",
|
||||
};
|
||||
|
||||
const PROVIDER_LOGO_URLS: Record<string, string> = {
|
||||
openai: "/Openai.svg",
|
||||
azure: "/Azure.png",
|
||||
elevenlabs: "/ElevenLabs.svg",
|
||||
};
|
||||
|
||||
const PROVIDER_DOCS_URLS: Record<string, string> = {
|
||||
openai: "https://platform.openai.com/docs/guides/text-to-speech",
|
||||
azure: "https://learn.microsoft.com/en-us/azure/ai-services/speech-service/",
|
||||
elevenlabs: "https://elevenlabs.io/docs",
|
||||
};
|
||||
|
||||
const OPENAI_STT_MODELS = [{ id: "whisper-1", name: "Whisper v1" }];
|
||||
|
||||
const OPENAI_TTS_MODELS = [
|
||||
{ id: "tts-1", name: "TTS-1" },
|
||||
{ id: "tts-1-hd", name: "TTS-1 HD" },
|
||||
];
|
||||
|
||||
// Map model IDs from cards to actual API model IDs
|
||||
const MODEL_ID_MAP: Record<string, string> = {
|
||||
"tts-1": "tts-1",
|
||||
"tts-1-hd": "tts-1-hd",
|
||||
whisper: "whisper-1",
|
||||
};
|
||||
|
||||
export default function VoiceProviderSetupModal({
|
||||
providerType,
|
||||
existingProvider,
|
||||
mode,
|
||||
defaultModelId,
|
||||
onClose,
|
||||
onSuccess,
|
||||
}: VoiceProviderSetupModalProps) {
|
||||
// Map the card model ID to the actual API model ID
|
||||
// Prioritize defaultModelId (from the clicked card) over stored value
|
||||
const initialTtsModel = defaultModelId
|
||||
? MODEL_ID_MAP[defaultModelId] ?? "tts-1"
|
||||
: existingProvider?.tts_model ?? "tts-1";
|
||||
|
||||
const [apiKey, setApiKey] = useState("");
|
||||
const [apiKeyChanged, setApiKeyChanged] = useState(false);
|
||||
const [targetUri, setTargetUri] = useState(
|
||||
existingProvider?.target_uri ?? ""
|
||||
);
|
||||
const [selectedLlmProviderId, setSelectedLlmProviderId] = useState<
|
||||
number | null
|
||||
>(null);
|
||||
const [sttModel, setSttModel] = useState(
|
||||
existingProvider?.stt_model ?? "whisper-1"
|
||||
);
|
||||
const [ttsModel, setTtsModel] = useState(initialTtsModel);
|
||||
const [defaultVoice, setDefaultVoice] = useState(
|
||||
existingProvider?.default_voice ?? ""
|
||||
);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
|
||||
// Dynamic voices fetched from backend
|
||||
const [voiceOptions, setVoiceOptions] = useState<VoiceOption[]>([]);
|
||||
const [isLoadingVoices, setIsLoadingVoices] = useState(false);
|
||||
|
||||
// Existing OpenAI LLM providers for API key reuse
|
||||
const [existingApiKeyOptions, setExistingApiKeyOptions] = useState<
|
||||
ApiKeyOption[]
|
||||
>([]);
|
||||
const [llmProviderMap, setLlmProviderMap] = useState<Map<string, number>>(
|
||||
new Map()
|
||||
);
|
||||
|
||||
// Fetch existing OpenAI LLM providers (for API key reuse)
|
||||
useEffect(() => {
|
||||
if (providerType !== "openai") return;
|
||||
|
||||
fetch("/api/admin/llm/provider")
|
||||
.then((res) => res.json())
|
||||
.then((data: LLMProviderView[]) => {
|
||||
const openaiProviders = data.filter(
|
||||
(p) => p.provider === "openai" && p.api_key
|
||||
);
|
||||
const options: ApiKeyOption[] = openaiProviders.map((p) => ({
|
||||
value: p.api_key!,
|
||||
label: p.api_key!,
|
||||
description: `Used for LLM provider ${p.name}`,
|
||||
}));
|
||||
setExistingApiKeyOptions(options);
|
||||
|
||||
// Map masked API keys to provider IDs for lookup on selection
|
||||
const providerMap = new Map<string, number>();
|
||||
openaiProviders.forEach((p) => {
|
||||
if (p.api_key) {
|
||||
providerMap.set(p.api_key, p.id);
|
||||
}
|
||||
});
|
||||
setLlmProviderMap(providerMap);
|
||||
})
|
||||
.catch(() => {
|
||||
setExistingApiKeyOptions([]);
|
||||
});
|
||||
}, [providerType]);
|
||||
|
||||
// Fetch voices on mount (works without API key for ElevenLabs/OpenAI)
|
||||
useEffect(() => {
|
||||
setIsLoadingVoices(true);
|
||||
fetch(`/api/admin/voice/voices?provider_type=${providerType}`)
|
||||
.then((res) => res.json())
|
||||
.then((data: Array<{ id: string; name: string }>) => {
|
||||
const options = data.map((v) => ({ value: v.id, label: v.name }));
|
||||
setVoiceOptions(options);
|
||||
// Set default voice to first option if not already set
|
||||
const firstOption = options[0];
|
||||
if (firstOption) {
|
||||
setDefaultVoice((prev) => prev || firstOption.value);
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
setVoiceOptions([]);
|
||||
})
|
||||
.finally(() => {
|
||||
setIsLoadingVoices(false);
|
||||
});
|
||||
}, [providerType]);
|
||||
|
||||
const isEditing = !!existingProvider;
|
||||
const label = PROVIDER_LABELS[providerType] ?? providerType;
|
||||
|
||||
// Create a logo arrangement component for the modal header
|
||||
const LogoArrangement: FunctionComponent<IconProps> = useMemo(() => {
|
||||
const Component: FunctionComponent<IconProps> = () => (
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="flex items-center justify-center size-7 shrink-0 overflow-clip">
|
||||
<Image
|
||||
src={PROVIDER_LOGO_URLS[providerType] ?? "/Openai.svg"}
|
||||
alt={`${label} logo`}
|
||||
width={24}
|
||||
height={24}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-center size-4 p-0.5 shrink-0">
|
||||
<SvgArrowExchange className="size-3 text-text-04" />
|
||||
</div>
|
||||
<div className="flex items-center justify-center size-7 p-0.5 shrink-0 overflow-clip">
|
||||
<SvgOnyxLogo size={24} className="text-text-04 shrink-0" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
return Component;
|
||||
}, [providerType, label]);
|
||||
|
||||
const handleSubmit = async () => {
|
||||
if (!isEditing && !apiKey && !selectedLlmProviderId) {
|
||||
toast.error("API key is required");
|
||||
return;
|
||||
}
|
||||
|
||||
if (providerType === "azure" && !isEditing && !targetUri) {
|
||||
toast.error("Target URI is required");
|
||||
return;
|
||||
}
|
||||
|
||||
setIsSubmitting(true);
|
||||
try {
|
||||
// Test the connection first (skip if reusing LLM provider key - it's already validated)
|
||||
if (!selectedLlmProviderId) {
|
||||
const testResponse = await fetch("/api/admin/voice/providers/test", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider_type: providerType,
|
||||
api_key: apiKeyChanged ? apiKey : undefined,
|
||||
target_uri: targetUri || undefined,
|
||||
use_stored_key: isEditing && !apiKeyChanged,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!testResponse.ok) {
|
||||
const data = await testResponse.json();
|
||||
toast.error(data.detail || "Connection test failed");
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Save the provider
|
||||
const response = await fetch("/api/admin/voice/providers", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
id: existingProvider?.id,
|
||||
name: label,
|
||||
provider_type: providerType,
|
||||
api_key: selectedLlmProviderId
|
||||
? undefined
|
||||
: apiKeyChanged
|
||||
? apiKey
|
||||
: undefined,
|
||||
api_key_changed: selectedLlmProviderId ? false : apiKeyChanged,
|
||||
target_uri: targetUri || undefined,
|
||||
llm_provider_id: selectedLlmProviderId,
|
||||
stt_model: sttModel,
|
||||
tts_model: ttsModel,
|
||||
default_voice: defaultVoice,
|
||||
activate_stt: mode === "stt",
|
||||
activate_tts: mode === "tts",
|
||||
}),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
toast.success(isEditing ? "Provider updated" : "Provider connected");
|
||||
onSuccess();
|
||||
} else {
|
||||
const data = await response.json();
|
||||
toast.error(data.detail || "Failed to save provider");
|
||||
}
|
||||
} catch {
|
||||
toast.error("Failed to save provider");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal open onOpenChange={(isOpen) => !isOpen && onClose()}>
|
||||
<Modal.Content width="sm">
|
||||
<Modal.Header
|
||||
icon={LogoArrangement}
|
||||
title={isEditing ? `Edit ${label}` : `Set up ${label}`}
|
||||
description={`Connect to ${label} and set up your voice models.`}
|
||||
onClose={onClose}
|
||||
/>
|
||||
<Modal.Body>
|
||||
<Section gap={1} alignItems="stretch">
|
||||
<Vertical
|
||||
title="API Key"
|
||||
subDescription={
|
||||
isEditing ? (
|
||||
"Leave blank to keep existing key"
|
||||
) : (
|
||||
<>
|
||||
Paste your{" "}
|
||||
<a
|
||||
href={PROVIDER_API_KEY_URLS[providerType]}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="underline"
|
||||
>
|
||||
API key
|
||||
</a>{" "}
|
||||
from {label} to access your models.
|
||||
</>
|
||||
)
|
||||
}
|
||||
nonInteractive
|
||||
>
|
||||
{providerType === "openai" && existingApiKeyOptions.length > 0 ? (
|
||||
<InputComboBox
|
||||
placeholder={isEditing ? "••••••••" : "Enter API key"}
|
||||
value={apiKey}
|
||||
onChange={(e) => {
|
||||
setApiKey(e.target.value);
|
||||
setApiKeyChanged(true);
|
||||
setSelectedLlmProviderId(null);
|
||||
}}
|
||||
onValueChange={(value) => {
|
||||
setApiKey(value);
|
||||
// Check if this is an existing key
|
||||
const llmProviderId = llmProviderMap.get(value);
|
||||
if (llmProviderId) {
|
||||
setSelectedLlmProviderId(llmProviderId);
|
||||
setApiKeyChanged(false);
|
||||
} else {
|
||||
setSelectedLlmProviderId(null);
|
||||
setApiKeyChanged(true);
|
||||
}
|
||||
}}
|
||||
options={existingApiKeyOptions}
|
||||
separatorLabel="Reuse OpenAI API Keys"
|
||||
strict={false}
|
||||
showAddPrefix
|
||||
/>
|
||||
) : (
|
||||
<InputTypeIn
|
||||
type="password"
|
||||
placeholder={isEditing ? "••••••••" : "Enter API key"}
|
||||
value={apiKey}
|
||||
onChange={(e) => {
|
||||
setApiKey(e.target.value);
|
||||
setApiKeyChanged(true);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Vertical>
|
||||
|
||||
{providerType === "azure" && (
|
||||
<Vertical
|
||||
title="Target URI"
|
||||
subDescription={
|
||||
<>
|
||||
Paste the endpoint shown in{" "}
|
||||
<a
|
||||
href="https://portal.azure.com/"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="underline"
|
||||
>
|
||||
Azure Portal (Keys and Endpoint)
|
||||
</a>
|
||||
. Onyx extracts the speech region from this URL. Examples:
|
||||
https://westus.api.cognitive.microsoft.com/ or
|
||||
https://westus.tts.speech.microsoft.com/.
|
||||
</>
|
||||
}
|
||||
nonInteractive
|
||||
>
|
||||
<InputTypeIn
|
||||
placeholder={
|
||||
isEditing
|
||||
? "Leave blank to keep existing"
|
||||
: "https://<region>.api.cognitive.microsoft.com/"
|
||||
}
|
||||
value={targetUri}
|
||||
onChange={(e) => setTargetUri(e.target.value)}
|
||||
/>
|
||||
</Vertical>
|
||||
)}
|
||||
|
||||
{providerType === "openai" && mode === "stt" && (
|
||||
<Horizontal title="STT Model" center nonInteractive>
|
||||
<InputSelect value={sttModel} onValueChange={setSttModel}>
|
||||
<InputSelect.Trigger />
|
||||
<InputSelect.Content>
|
||||
{OPENAI_STT_MODELS.map((model) => (
|
||||
<InputSelect.Item key={model.id} value={model.id}>
|
||||
{model.name}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Horizontal>
|
||||
)}
|
||||
|
||||
{providerType === "openai" && mode === "tts" && (
|
||||
<Vertical
|
||||
title="Default Model"
|
||||
subDescription="This model will be used by Onyx by default for text-to-speech."
|
||||
nonInteractive
|
||||
>
|
||||
<InputSelect value={ttsModel} onValueChange={setTtsModel}>
|
||||
<InputSelect.Trigger />
|
||||
<InputSelect.Content>
|
||||
{OPENAI_TTS_MODELS.map((model) => (
|
||||
<InputSelect.Item key={model.id} value={model.id}>
|
||||
{model.name}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Vertical>
|
||||
)}
|
||||
|
||||
{mode === "tts" && (
|
||||
<Vertical
|
||||
title="Voice"
|
||||
subDescription="This voice will be used for spoken responses."
|
||||
nonInteractive
|
||||
>
|
||||
<InputSelect
|
||||
value={defaultVoice}
|
||||
onValueChange={setDefaultVoice}
|
||||
disabled={isLoadingVoices}
|
||||
>
|
||||
<InputSelect.Trigger
|
||||
placeholder={
|
||||
isLoadingVoices ? "Loading voices..." : "Select voice"
|
||||
}
|
||||
/>
|
||||
<InputSelect.Content>
|
||||
{voiceOptions.map((voice) => (
|
||||
<InputSelect.Item key={voice.value} value={voice.value}>
|
||||
{voice.label}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Vertical>
|
||||
)}
|
||||
</Section>
|
||||
</Modal.Body>
|
||||
<Modal.Footer>
|
||||
<Button secondary onClick={onClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={handleSubmit} disabled={isSubmitting}>
|
||||
{isSubmitting ? "Connecting..." : isEditing ? "Save" : "Connect"}
|
||||
</Button>
|
||||
</Modal.Footer>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
633
web/src/app/admin/configuration/voice/page.tsx
Normal file
633
web/src/app/admin/configuration/voice/page.tsx
Normal file
@@ -0,0 +1,633 @@
|
||||
"use client";
|
||||
|
||||
import Image from "next/image";
|
||||
import { useMemo, useState } from "react";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { InfoIcon } from "@/components/icons/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher, FetchError } from "@/lib/fetcher";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
SvgArrowExchange,
|
||||
SvgArrowRightCircle,
|
||||
SvgAudio,
|
||||
SvgCheckSquare,
|
||||
SvgEdit,
|
||||
SvgMicrophone,
|
||||
SvgX,
|
||||
} from "@opal/icons";
|
||||
import VoiceProviderSetupModal from "./VoiceProviderSetupModal";
|
||||
|
||||
// Base REST endpoint for voice-provider CRUD and activation calls.
const VOICE_PROVIDERS_URL = "/api/admin/voice/providers";
|
||||
|
||||
// Provider record as returned by GET /api/admin/voice/providers.
// NOTE(review): this duplicates the VoiceProviderView interface declared in
// VoiceProviderSetupModal.tsx — consider moving it to a shared types module.
interface VoiceProviderView {
  id: number;
  name: string;
  // "openai" | "azure" | "elevenlabs" per the card definitions below.
  provider_type: string;
  // Whether this provider is the active default for speech-to-text.
  is_default_stt: boolean;
  // Whether this provider is the active default for text-to-speech.
  is_default_tts: boolean;
  stt_model: string | null;
  tts_model: string | null;
  default_voice: string | null;
  // True once credentials are stored; used as the "configured" check.
  has_api_key: boolean;
  // Azure endpoint URI, null for other providers.
  target_uri: string | null;
}
|
||||
|
||||
// One selectable model card rendered in the STT/TTS sections.
interface ModelDetails {
  // Card identifier; for TTS this is matched against provider.tts_model.
  id: string;
  label: string;
  subtitle: string;
  // Path under /public; a microphone icon is used when absent.
  logoSrc?: string;
  // Provider type this model belongs to (joins against providersByType).
  providerType: string;
}
|
||||
|
||||
// Grouping of TTS model cards under a provider heading.
interface ProviderGroup {
  providerType: string;
  providerLabel: string;
  logoSrc?: string;
  models: ModelDetails[];
}
|
||||
|
||||
// STT Models - individual cards (rendered flat, one card per provider).
const STT_MODELS: ModelDetails[] = [
  {
    id: "whisper",
    label: "Whisper",
    subtitle: "OpenAI's general purpose speech recognition model.",
    logoSrc: "/Openai.svg",
    providerType: "openai",
  },
  {
    id: "azure-speech-stt",
    label: "Azure Speech",
    subtitle: "Speech to text in Microsoft Foundry Tools.",
    logoSrc: "/Azure.png",
    providerType: "azure",
  },
  {
    id: "elevenlabs-stt",
    // NOTE(review): "ElevenAPI" reads like a truncation of an ElevenLabs
    // product name — confirm the intended label.
    label: "ElevenAPI",
    subtitle: "ElevenLabs Speech to Text API.",
    logoSrc: "/ElevenLabs.svg",
    providerType: "elevenlabs",
  },
];
|
||||
|
||||
// TTS Models - grouped by provider (each group gets a heading and its own
// list of model cards).
const TTS_PROVIDER_GROUPS: ProviderGroup[] = [
  {
    providerType: "openai",
    providerLabel: "OpenAI",
    logoSrc: "/Openai.svg",
    models: [
      {
        id: "tts-1",
        label: "TTS-1",
        subtitle: "OpenAI's text-to-speech model optimized for speed.",
        logoSrc: "/Openai.svg",
        providerType: "openai",
      },
      {
        id: "tts-1-hd",
        label: "TTS-1 HD",
        subtitle: "OpenAI's text-to-speech model optimized for quality.",
        logoSrc: "/Openai.svg",
        providerType: "openai",
      },
    ],
  },
  {
    providerType: "azure",
    providerLabel: "Azure",
    logoSrc: "/Azure.png",
    models: [
      {
        id: "azure-speech-tts",
        label: "Azure Speech",
        subtitle: "Text to speech in Microsoft Foundry Tools.",
        logoSrc: "/Azure.png",
        providerType: "azure",
      },
    ],
  },
  {
    providerType: "elevenlabs",
    providerLabel: "ElevenLabs",
    logoSrc: "/ElevenLabs.svg",
    models: [
      {
        id: "elevenlabs-tts",
        label: "ElevenAPI",
        subtitle: "ElevenLabs Text to Speech API.",
        logoSrc: "/ElevenLabs.svg",
        providerType: "elevenlabs",
      },
    ],
  },
];
|
||||
|
||||
// Props for HoverIconButton: standard Button props plus externally-managed
// hover state (the parent tracks which button is hovered by key).
interface HoverIconButtonProps extends React.ComponentProps<typeof Button> {
  isHovered: boolean;
  onMouseEnter: () => void;
  onMouseLeave: () => void;
  children: React.ReactNode;
}
|
||||
|
||||
function HoverIconButton({
|
||||
isHovered,
|
||||
onMouseEnter,
|
||||
onMouseLeave,
|
||||
children,
|
||||
...buttonProps
|
||||
}: HoverIconButtonProps) {
|
||||
return (
|
||||
<div onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
|
||||
<Button {...buttonProps} rightIcon={isHovered ? SvgX : SvgCheckSquare}>
|
||||
{children}
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Which provider role an action targets: speech-to-text or text-to-speech.
type ProviderMode = "stt" | "tts";
|
||||
|
||||
export default function VoiceConfigurationPage() {
|
||||
const [modalOpen, setModalOpen] = useState(false);
|
||||
const [selectedProvider, setSelectedProvider] = useState<string | null>(null);
|
||||
const [editingProvider, setEditingProvider] =
|
||||
useState<VoiceProviderView | null>(null);
|
||||
const [modalMode, setModalMode] = useState<ProviderMode>("stt");
|
||||
const [selectedModelId, setSelectedModelId] = useState<string | null>(null);
|
||||
const [sttActivationError, setSTTActivationError] = useState<string | null>(
|
||||
null
|
||||
);
|
||||
const [ttsActivationError, setTTSActivationError] = useState<string | null>(
|
||||
null
|
||||
);
|
||||
const [hoveredButtonKey, setHoveredButtonKey] = useState<string | null>(null);
|
||||
|
||||
const {
|
||||
data: providers,
|
||||
error,
|
||||
isLoading,
|
||||
mutate,
|
||||
} = useSWR<VoiceProviderView[]>(VOICE_PROVIDERS_URL, errorHandlingFetcher);
|
||||
|
||||
const handleConnect = (
|
||||
providerType: string,
|
||||
mode: ProviderMode,
|
||||
modelId?: string
|
||||
) => {
|
||||
setSelectedProvider(providerType);
|
||||
setEditingProvider(null);
|
||||
setModalMode(mode);
|
||||
setSelectedModelId(modelId ?? null);
|
||||
setModalOpen(true);
|
||||
setSTTActivationError(null);
|
||||
setTTSActivationError(null);
|
||||
};
|
||||
|
||||
const handleEdit = (
|
||||
provider: VoiceProviderView,
|
||||
mode: ProviderMode,
|
||||
modelId?: string
|
||||
) => {
|
||||
setSelectedProvider(provider.provider_type);
|
||||
setEditingProvider(provider);
|
||||
setModalMode(mode);
|
||||
setSelectedModelId(modelId ?? null);
|
||||
setModalOpen(true);
|
||||
};
|
||||
|
||||
const handleSetDefault = async (
|
||||
providerId: number,
|
||||
mode: ProviderMode,
|
||||
modelId?: string
|
||||
) => {
|
||||
const setError =
|
||||
mode === "stt" ? setSTTActivationError : setTTSActivationError;
|
||||
setError(null);
|
||||
try {
|
||||
const url = new URL(
|
||||
`${VOICE_PROVIDERS_URL}/${providerId}/activate-${mode}`,
|
||||
window.location.origin
|
||||
);
|
||||
if (mode === "tts" && modelId) {
|
||||
url.searchParams.set("tts_model", modelId);
|
||||
}
|
||||
const response = await fetch(url.toString(), { method: "POST" });
|
||||
if (!response.ok) {
|
||||
const errorBody = await response.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: `Failed to set provider as default ${mode.toUpperCase()}.`
|
||||
);
|
||||
}
|
||||
await mutate();
|
||||
} catch (err) {
|
||||
const message =
|
||||
err instanceof Error ? err.message : "Unexpected error occurred.";
|
||||
setError(message);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDeactivate = async (providerId: number, mode: ProviderMode) => {
|
||||
const setError =
|
||||
mode === "stt" ? setSTTActivationError : setTTSActivationError;
|
||||
setError(null);
|
||||
try {
|
||||
const response = await fetch(
|
||||
`${VOICE_PROVIDERS_URL}/${providerId}/deactivate-${mode}`,
|
||||
{ method: "POST" }
|
||||
);
|
||||
if (!response.ok) {
|
||||
const errorBody = await response.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: `Failed to deactivate ${mode.toUpperCase()} provider.`
|
||||
);
|
||||
}
|
||||
await mutate();
|
||||
} catch (err) {
|
||||
const message =
|
||||
err instanceof Error ? err.message : "Unexpected error occurred.";
|
||||
setError(message);
|
||||
}
|
||||
};
|
||||
|
||||
const handleModalClose = () => {
|
||||
setModalOpen(false);
|
||||
setSelectedProvider(null);
|
||||
setEditingProvider(null);
|
||||
setSelectedModelId(null);
|
||||
};
|
||||
|
||||
const handleModalSuccess = () => {
|
||||
mutate();
|
||||
handleModalClose();
|
||||
};
|
||||
|
||||
const isProviderConfigured = (provider?: VoiceProviderView): boolean => {
|
||||
return !!provider?.has_api_key;
|
||||
};
|
||||
|
||||
// Map provider types to their configured provider data
|
||||
const providersByType = useMemo(() => {
|
||||
return new Map((providers ?? []).map((p) => [p.provider_type, p] as const));
|
||||
}, [providers]);
|
||||
|
||||
const hasActiveSTTProvider =
|
||||
providers?.some((p) => p.is_default_stt) ?? false;
|
||||
const hasActiveTTSProvider =
|
||||
providers?.some((p) => p.is_default_tts) ?? false;
|
||||
|
||||
const renderLogo = ({
|
||||
logoSrc,
|
||||
alt,
|
||||
size = 16,
|
||||
}: {
|
||||
logoSrc?: string;
|
||||
alt: string;
|
||||
size?: number;
|
||||
}) => {
|
||||
const containerSizeClass = size === 24 ? "size-7" : "size-5";
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center px-0.5 py-0 shrink-0 overflow-clip",
|
||||
containerSizeClass
|
||||
)}
|
||||
>
|
||||
{logoSrc ? (
|
||||
<Image src={logoSrc} alt={alt} width={size} height={size} />
|
||||
) : (
|
||||
<SvgMicrophone size={size} className="text-text-02" />
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const renderModelCard = ({
|
||||
model,
|
||||
mode,
|
||||
}: {
|
||||
model: ModelDetails;
|
||||
mode: ProviderMode;
|
||||
}) => {
|
||||
const provider = providersByType.get(model.providerType);
|
||||
const isConfigured = isProviderConfigured(provider);
|
||||
// For TTS, also check that this specific model is the default (not just the provider)
|
||||
const isActive =
|
||||
mode === "stt"
|
||||
? provider?.is_default_stt
|
||||
: provider?.is_default_tts && provider?.tts_model === model.id;
|
||||
const isHighlighted = isActive ?? false;
|
||||
const providerId = provider?.id;
|
||||
|
||||
const buttonState = (() => {
|
||||
if (!provider || !isConfigured) {
|
||||
return {
|
||||
label: "Connect",
|
||||
disabled: false,
|
||||
icon: "arrow" as const,
|
||||
onClick: () => handleConnect(model.providerType, mode, model.id),
|
||||
};
|
||||
}
|
||||
|
||||
if (isActive) {
|
||||
return {
|
||||
label: "Current Default",
|
||||
disabled: false,
|
||||
icon: "check" as const,
|
||||
onClick: providerId
|
||||
? () => handleDeactivate(providerId, mode)
|
||||
: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
label: "Set as Default",
|
||||
disabled: false,
|
||||
icon: "arrow-circle" as const,
|
||||
onClick: providerId
|
||||
? () => handleSetDefault(providerId, mode, model.id)
|
||||
: undefined,
|
||||
};
|
||||
})();
|
||||
|
||||
const buttonKey = `${mode}-${model.id}`;
|
||||
const isButtonHovered = hoveredButtonKey === buttonKey;
|
||||
const isCardClickable =
|
||||
buttonState.icon === "arrow" &&
|
||||
typeof buttonState.onClick === "function" &&
|
||||
!buttonState.disabled;
|
||||
|
||||
const handleCardClick = () => {
|
||||
if (isCardClickable) {
|
||||
buttonState.onClick?.();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
key={`${mode}-${model.id}`}
|
||||
onClick={isCardClickable ? handleCardClick : undefined}
|
||||
className={cn(
|
||||
"flex items-start justify-between gap-3 rounded-16 border p-2 bg-background-neutral-01",
|
||||
isHighlighted ? "border-action-link-05" : "border-border-01",
|
||||
isCardClickable &&
|
||||
"cursor-pointer hover:bg-background-tint-01 transition-colors"
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-1 items-start gap-1 p-2">
|
||||
{renderLogo({
|
||||
logoSrc: model.logoSrc,
|
||||
alt: `${model.label} logo`,
|
||||
size: 16,
|
||||
})}
|
||||
<div className="flex flex-col">
|
||||
<Text as="p" mainUiAction text04>
|
||||
{model.label}
|
||||
</Text>
|
||||
<Text as="p" secondaryBody text03>
|
||||
{model.subtitle}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{isConfigured && (
|
||||
<OpalButton
|
||||
icon={SvgEdit}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
if (provider) handleEdit(provider, mode, model.id);
|
||||
}}
|
||||
aria-label={`Edit ${model.label}`}
|
||||
/>
|
||||
)}
|
||||
{buttonState.icon === "check" ? (
|
||||
<HoverIconButton
|
||||
isHovered={isButtonHovered}
|
||||
onMouseEnter={() => setHoveredButtonKey(buttonKey)}
|
||||
onMouseLeave={() => setHoveredButtonKey(null)}
|
||||
action={true}
|
||||
tertiary
|
||||
disabled={buttonState.disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
>
|
||||
{buttonState.label}
|
||||
</HoverIconButton>
|
||||
) : (
|
||||
<Button
|
||||
action={false}
|
||||
tertiary
|
||||
disabled={buttonState.disabled || !buttonState.onClick}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
rightIcon={
|
||||
buttonState.icon === "arrow"
|
||||
? SvgArrowExchange
|
||||
: buttonState.icon === "arrow-circle"
|
||||
? SvgArrowRightCircle
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{buttonState.label}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
if (error) {
|
||||
const message = error?.message || "Unable to load voice configuration.";
|
||||
const detail =
|
||||
error instanceof FetchError && typeof error.info?.detail === "string"
|
||||
? error.info.detail
|
||||
: undefined;
|
||||
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
title="Voice"
|
||||
icon={SvgMicrophone}
|
||||
includeDivider={false}
|
||||
/>
|
||||
<Callout type="danger" title="Failed to load voice settings">
|
||||
{message}
|
||||
{detail && (
|
||||
<Text as="p" className="mt-2 text-text-03" mainContentBody text03>
|
||||
{detail}
|
||||
</Text>
|
||||
)}
|
||||
</Callout>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
title="Voice"
|
||||
icon={SvgMicrophone}
|
||||
includeDivider={false}
|
||||
/>
|
||||
<div className="mt-8">
|
||||
<ThreeDotsLoader />
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle icon={SvgAudio} title="Voice" />
|
||||
<div className="pt-4 pb-4">
|
||||
<Text as="p" secondaryBody text03>
|
||||
Speech to text (STT) and text to speech (TTS) capabilities.
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<div className="flex w-full flex-col gap-8 pb-6">
|
||||
{/* Speech-to-Text Section */}
|
||||
<div className="flex w-full max-w-[960px] flex-col gap-3">
|
||||
<div className="flex flex-col">
|
||||
<Text as="p" mainContentEmphasis text04>
|
||||
Speech to Text
|
||||
</Text>
|
||||
<Text as="p" secondaryBody text03>
|
||||
Select a model to transcribe speech to text in chats.
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
{sttActivationError && (
|
||||
<Callout type="danger" title="Unable to update STT provider">
|
||||
{sttActivationError}
|
||||
</Callout>
|
||||
)}
|
||||
|
||||
{!hasActiveSTTProvider && (
|
||||
<div
|
||||
className="flex items-start rounded-16 border p-2"
|
||||
style={{
|
||||
backgroundColor: "var(--status-info-00)",
|
||||
borderColor: "var(--status-info-02)",
|
||||
}}
|
||||
>
|
||||
<div className="flex items-start gap-1 p-2">
|
||||
<div
|
||||
className="flex size-5 items-center justify-center rounded-full p-0.5"
|
||||
style={{
|
||||
backgroundColor: "var(--status-info-01)",
|
||||
}}
|
||||
>
|
||||
<div style={{ color: "var(--status-text-info-05)" }}>
|
||||
<InfoIcon size={16} />
|
||||
</div>
|
||||
</div>
|
||||
<Text as="p" className="flex-1 px-0.5" mainUiBody text04>
|
||||
Connect a speech to text provider to use in chat.
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex flex-col gap-2">
|
||||
{STT_MODELS.map((model) => renderModelCard({ model, mode: "stt" }))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Text-to-Speech Section */}
|
||||
<div className="flex w-full max-w-[960px] flex-col gap-3">
|
||||
<div className="flex flex-col">
|
||||
<Text as="p" mainContentEmphasis text04>
|
||||
Text to Speech
|
||||
</Text>
|
||||
<Text as="p" secondaryBody text03>
|
||||
Select a model to speak out chat responses.
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
{ttsActivationError && (
|
||||
<Callout type="danger" title="Unable to update TTS provider">
|
||||
{ttsActivationError}
|
||||
</Callout>
|
||||
)}
|
||||
|
||||
{!hasActiveTTSProvider && (
|
||||
<div
|
||||
className="flex items-start rounded-16 border p-2"
|
||||
style={{
|
||||
backgroundColor: "var(--status-info-00)",
|
||||
borderColor: "var(--status-info-02)",
|
||||
}}
|
||||
>
|
||||
<div className="flex items-start gap-1 p-2">
|
||||
<div
|
||||
className="flex size-5 items-center justify-center rounded-full p-0.5"
|
||||
style={{
|
||||
backgroundColor: "var(--status-info-01)",
|
||||
}}
|
||||
>
|
||||
<div style={{ color: "var(--status-text-info-05)" }}>
|
||||
<InfoIcon size={16} />
|
||||
</div>
|
||||
</div>
|
||||
<Text as="p" className="flex-1 px-0.5" mainUiBody text04>
|
||||
Connect a text to speech provider to use in chat.
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex flex-col gap-4">
|
||||
{TTS_PROVIDER_GROUPS.map((group) => (
|
||||
<div key={group.providerType} className="flex flex-col gap-2">
|
||||
<Text as="p" secondaryBody text03 className="px-0.5">
|
||||
{group.providerLabel}
|
||||
</Text>
|
||||
<div className="flex flex-col gap-2">
|
||||
{group.models.map((model) =>
|
||||
renderModelCard({ model, mode: "tts" })
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{modalOpen && selectedProvider && (
|
||||
<VoiceProviderSetupModal
|
||||
providerType={selectedProvider}
|
||||
existingProvider={editingProvider}
|
||||
mode={modalMode}
|
||||
defaultModelId={selectedModelId}
|
||||
onClose={handleModalClose}
|
||||
onSuccess={handleModalSuccess}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
import { SvgArrowUpRight, SvgUserSync } from "@opal/icons";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import { Button } from "@opal/components";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Link from "next/link";
|
||||
import { ADMIN_PATHS } from "@/lib/admin-routes";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stats cell
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Props for a single statistic cell; value null renders an em dash.
interface StatCellProps {
  value: number | null;
  label: string;
}
|
||||
|
||||
function StatCell({ value, label }: StatCellProps) {
|
||||
const display = value === null ? "—" : value.toLocaleString();
|
||||
|
||||
return (
|
||||
<Section alignItems="start" gap={0.25} width="fit" padding={0.5}>
|
||||
<Text as="p" headingH3 text05>
|
||||
{display}
|
||||
</Text>
|
||||
<Text as="p" mainUiMuted text03>
|
||||
{label}
|
||||
</Text>
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SCIM card
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function ScimCard() {
|
||||
return (
|
||||
<Card gap={0.5} padding={0.75}>
|
||||
<ContentAction
|
||||
icon={SvgUserSync}
|
||||
title="SCIM Sync"
|
||||
description="Users are synced from your identity provider."
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
paddingVariant="fit"
|
||||
rightChildren={
|
||||
<Link href={ADMIN_PATHS.SCIM}>
|
||||
<Button prominence="tertiary" rightIcon={SvgArrowUpRight} size="sm">
|
||||
Manage
|
||||
</Button>
|
||||
</Link>
|
||||
}
|
||||
/>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stats bar — layout varies by SCIM status
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface StatsBarProps {
|
||||
activeUsers: number | null;
|
||||
pendingInvites: number | null;
|
||||
requests: number | null;
|
||||
showScim: boolean;
|
||||
}
|
||||
|
||||
export default function StatsBar({
|
||||
activeUsers,
|
||||
pendingInvites,
|
||||
requests,
|
||||
showScim,
|
||||
}: StatsBarProps) {
|
||||
if (showScim) {
|
||||
// With SCIM: one card containing all 3 stats (dividers) + separate SCIM card
|
||||
return (
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="start"
|
||||
alignItems="stretch"
|
||||
gap={0.5}
|
||||
>
|
||||
<Card padding={0}>
|
||||
<Section
|
||||
flexDirection="row"
|
||||
alignItems="stretch"
|
||||
gap={0}
|
||||
width="fit"
|
||||
height="auto"
|
||||
>
|
||||
<StatCell value={activeUsers} label="active users" />
|
||||
<Separator orientation="vertical" noPadding />
|
||||
<StatCell value={pendingInvites} label="pending invites" />
|
||||
{requests !== null && (
|
||||
<>
|
||||
<Separator orientation="vertical" noPadding />
|
||||
<StatCell value={requests} label="requests to join" />
|
||||
</>
|
||||
)}
|
||||
</Section>
|
||||
</Card>
|
||||
|
||||
<ScimCard />
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
|
||||
// Without SCIM: 3 separate cards
|
||||
return (
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="start"
|
||||
alignItems="stretch"
|
||||
gap={0.5}
|
||||
>
|
||||
<Card padding={0.5}>
|
||||
<StatCell value={activeUsers} label="active users" />
|
||||
</Card>
|
||||
<Card padding={0.5}>
|
||||
<StatCell value={pendingInvites} label="pending invites" />
|
||||
</Card>
|
||||
{requests !== null && (
|
||||
<Card padding={0.5}>
|
||||
<StatCell value={requests} label="requests to join" />
|
||||
</Card>
|
||||
)}
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
|
||||
import { SvgUser, SvgUserPlus } from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { useScimToken } from "@/hooks/useScimToken";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import type { InvitedUserSnapshot } from "@/lib/types";
|
||||
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
|
||||
import StatsBar from "./StatsBar";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface PaginatedResponse {
|
||||
items: unknown[];
|
||||
total_items: number;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Users page content
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function UsersContent() {
|
||||
const isEe = usePaidEnterpriseFeaturesEnabled();
|
||||
|
||||
const { data: scimToken } = useScimToken();
|
||||
const showScim = isEe && !!scimToken;
|
||||
|
||||
// Active user count — lightweight fetch (page_size=1 to minimize payload)
|
||||
const { data: activeData } = useSWR<PaginatedResponse>(
|
||||
"/api/manage/users/accepted?page_num=0&page_size=1",
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
const { data: invitedUsers } = useSWR<InvitedUserSnapshot[]>(
|
||||
"/api/manage/users/invited",
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
const { data: pendingUsers } = useSWR<InvitedUserSnapshot[]>(
|
||||
NEXT_PUBLIC_CLOUD_ENABLED ? "/api/tenants/users/pending" : null,
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
const activeCount = activeData?.total_items ?? null;
|
||||
const invitedCount = invitedUsers?.length ?? null;
|
||||
const pendingCount = pendingUsers?.length ?? null;
|
||||
|
||||
return (
|
||||
<>
|
||||
<StatsBar
|
||||
activeUsers={activeCount}
|
||||
pendingInvites={invitedCount}
|
||||
requests={pendingCount}
|
||||
showScim={showScim}
|
||||
/>
|
||||
|
||||
{/* Table and filters will be added in subsequent PRs */}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Page
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export default function Page() {
|
||||
// TODO (ENG-3806): Wire up invite modal in a future PR
|
||||
const [_showInviteModal, setShowInviteModal] = useState(false);
|
||||
|
||||
return (
|
||||
<SettingsLayouts.Root width="lg">
|
||||
<SettingsLayouts.Header
|
||||
title="Users & Requests"
|
||||
icon={SvgUser}
|
||||
rightChildren={
|
||||
<Button icon={SvgUserPlus} onClick={() => setShowInviteModal(true)}>
|
||||
Invite Users
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
<UsersContent />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
@@ -3,6 +3,7 @@ import type { Route } from "next";
|
||||
import { unstable_noStore as noStore } from "next/cache";
|
||||
import { requireAuth } from "@/lib/auth/requireAuth";
|
||||
import { ProjectsProvider } from "@/providers/ProjectsContext";
|
||||
import { VoiceModeProvider } from "@/providers/VoiceModeProvider";
|
||||
import AppSidebar from "@/sections/sidebar/AppSidebar";
|
||||
|
||||
export interface LayoutProps {
|
||||
@@ -21,10 +22,12 @@ export default async function Layout({ children }: LayoutProps) {
|
||||
|
||||
return (
|
||||
<ProjectsProvider>
|
||||
<div className="flex flex-row w-full h-full">
|
||||
<AppSidebar />
|
||||
{children}
|
||||
</div>
|
||||
<VoiceModeProvider>
|
||||
<div className="flex flex-row w-full h-full">
|
||||
<AppSidebar />
|
||||
{children}
|
||||
</div>
|
||||
</VoiceModeProvider>
|
||||
</ProjectsProvider>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import React, { useRef, RefObject, useMemo } from "react";
|
||||
import React, { useRef, RefObject, useMemo, useEffect } from "react";
|
||||
import { Packet, StopReason } from "@/app/app/services/streamingModels";
|
||||
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
|
||||
import { FeedbackType } from "@/app/app/interfaces";
|
||||
@@ -14,6 +14,9 @@ import { LlmDescriptor, LlmManager } from "@/lib/hooks";
|
||||
import { Message } from "@/app/app/interfaces";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { AgentTimeline } from "@/app/app/message/messageComponents/timeline/AgentTimeline";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import { getTextContent } from "@/app/app/services/packetUtils";
|
||||
import { removeThinkingTokens } from "@/app/app/services/thinkingTokens";
|
||||
|
||||
// Type for the regeneration factory function passed from ChatUI
|
||||
export type RegenerationFactory = (regenerationRequest: {
|
||||
@@ -73,6 +76,7 @@ function arePropsEqual(
|
||||
|
||||
const AgentMessage = React.memo(function AgentMessage({
|
||||
rawPackets,
|
||||
packetCount,
|
||||
chatState,
|
||||
nodeId,
|
||||
messageId,
|
||||
@@ -158,6 +162,46 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
onMessageSelection,
|
||||
});
|
||||
|
||||
// Streaming TTS integration
|
||||
const { streamTTS, resetTTS } = useVoiceMode();
|
||||
const ttsCompletedRef = useRef(false);
|
||||
const streamTTSRef = useRef(streamTTS);
|
||||
|
||||
// Keep streamTTS ref in sync without triggering effect re-runs
|
||||
useEffect(() => {
|
||||
streamTTSRef.current = streamTTS;
|
||||
}, [streamTTS]);
|
||||
|
||||
// Stream TTS as text content arrives - only for messages still streaming
|
||||
// Uses ref for streamTTS to avoid re-triggering when its identity changes
|
||||
// Note: packetCount is used instead of rawPackets because the array is mutated in place
|
||||
useEffect(() => {
|
||||
// Skip if we've already finished TTS for this message
|
||||
if (ttsCompletedRef.current) return;
|
||||
|
||||
const textContent = removeThinkingTokens(getTextContent(rawPackets));
|
||||
if (typeof textContent === "string" && textContent.length > 0) {
|
||||
streamTTSRef.current(textContent, isComplete);
|
||||
|
||||
// Mark as completed once the message is done streaming
|
||||
if (isComplete) {
|
||||
ttsCompletedRef.current = true;
|
||||
}
|
||||
}
|
||||
}, [packetCount, isComplete, rawPackets]); // packetCount triggers on new packets since rawPackets is mutated in place
|
||||
|
||||
// Reset TTS completed flag when nodeId changes (new message)
|
||||
useEffect(() => {
|
||||
ttsCompletedRef.current = false;
|
||||
}, [nodeId]);
|
||||
|
||||
// Reset TTS when component unmounts or nodeId changes
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
resetTTS();
|
||||
};
|
||||
}, [nodeId, resetTTS]);
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex flex-col gap-3"
|
||||
|
||||
@@ -29,6 +29,7 @@ import FeedbackModal, {
|
||||
FeedbackModalProps,
|
||||
} from "@/sections/modals/FeedbackModal";
|
||||
import { Button } from "@opal/components";
|
||||
import TTSButton from "./TTSButton";
|
||||
|
||||
// Wrapper component for SourceTag in toolbar to handle memoization
|
||||
const SourcesTagWrapper = React.memo(function SourcesTagWrapper({
|
||||
@@ -268,6 +269,9 @@ export default function MessageToolbar({
|
||||
}
|
||||
data-testid="AgentMessage/dislike-button"
|
||||
/>
|
||||
<TTSButton
|
||||
text={removeThinkingTokens(getTextContent(rawPackets)) as string}
|
||||
/>
|
||||
|
||||
{onRegenerate &&
|
||||
messageId !== undefined &&
|
||||
|
||||
84
web/src/app/app/message/messageComponents/TTSButton.tsx
Normal file
84
web/src/app/app/message/messageComponents/TTSButton.tsx
Normal file
@@ -0,0 +1,84 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect } from "react";
|
||||
import { SvgPlayCircle, SvgStop } from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
import { useVoicePlayback } from "@/hooks/useVoicePlayback";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
|
||||
interface TTSButtonProps {
|
||||
text: string;
|
||||
voice?: string;
|
||||
speed?: number;
|
||||
}
|
||||
|
||||
function TTSButton({ text, voice, speed }: TTSButtonProps) {
|
||||
const { isPlaying, isLoading, error, play, pause, stop } = useVoicePlayback();
|
||||
const { isTTSPlaying, isTTSLoading, stopTTS } = useVoiceMode();
|
||||
|
||||
const isGlobalTTSActive = isTTSPlaying || isTTSLoading;
|
||||
const isButtonPlaying = isGlobalTTSActive || isPlaying;
|
||||
const isButtonLoading = !isGlobalTTSActive && isLoading;
|
||||
|
||||
const handleClick = useCallback(async () => {
|
||||
if (isGlobalTTSActive) {
|
||||
// Stop auto-playback voice mode stream from the toolbar button.
|
||||
stopTTS({ manual: true });
|
||||
stop();
|
||||
} else if (isPlaying) {
|
||||
pause();
|
||||
} else if (isButtonLoading) {
|
||||
stop();
|
||||
} else {
|
||||
try {
|
||||
// Ensure no voice-mode stream is active before starting manual playback.
|
||||
stopTTS();
|
||||
await play(text, voice, speed);
|
||||
} catch {
|
||||
toast.error("Could not play audio");
|
||||
}
|
||||
}
|
||||
}, [
|
||||
isGlobalTTSActive,
|
||||
isPlaying,
|
||||
isButtonLoading,
|
||||
text,
|
||||
voice,
|
||||
speed,
|
||||
play,
|
||||
pause,
|
||||
stop,
|
||||
stopTTS,
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
if (error) {
|
||||
toast.error(error);
|
||||
}
|
||||
}, [error]);
|
||||
|
||||
const icon = isButtonLoading
|
||||
? SimpleLoader
|
||||
: isButtonPlaying
|
||||
? SvgStop
|
||||
: SvgPlayCircle;
|
||||
|
||||
const tooltip = isButtonPlaying
|
||||
? "Stop playback"
|
||||
: isButtonLoading
|
||||
? "Loading..."
|
||||
: "Read aloud";
|
||||
|
||||
return (
|
||||
<Button
|
||||
icon={icon}
|
||||
onClick={handleClick}
|
||||
tooltip={tooltip}
|
||||
data-testid="AgentMessage/tts-button"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export default TTSButton;
|
||||
@@ -566,6 +566,21 @@ textarea {
|
||||
animation: fadeIn 0.2s ease-out forwards;
|
||||
}
|
||||
|
||||
/* Recording waveform animation */
|
||||
@keyframes waveform {
|
||||
0%,
|
||||
100% {
|
||||
transform: scaleY(0.3);
|
||||
}
|
||||
50% {
|
||||
transform: scaleY(1);
|
||||
}
|
||||
}
|
||||
|
||||
.animate-waveform {
|
||||
animation: waveform 0.8s ease-in-out infinite;
|
||||
}
|
||||
|
||||
.container {
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
@@ -31,7 +31,6 @@ const SETTINGS_LAYOUT_PREFIXES = [
|
||||
ADMIN_PATHS.LLM_MODELS,
|
||||
ADMIN_PATHS.AGENTS,
|
||||
ADMIN_PATHS.USERS,
|
||||
ADMIN_PATHS.USERS_V2,
|
||||
ADMIN_PATHS.TOKEN_RATE_LIMITS,
|
||||
ADMIN_PATHS.SEARCH_SETTINGS,
|
||||
ADMIN_PATHS.DOCUMENT_PROCESSING,
|
||||
|
||||
122
web/src/hooks/useVoicePlayback.ts
Normal file
122
web/src/hooks/useVoicePlayback.ts
Normal file
@@ -0,0 +1,122 @@
|
||||
import { useState, useRef, useCallback } from "react";
|
||||
|
||||
export interface UseVoicePlaybackReturn {
|
||||
isPlaying: boolean;
|
||||
isLoading: boolean;
|
||||
error: string | null;
|
||||
play: (text: string, voice?: string, speed?: number) => Promise<void>;
|
||||
pause: () => void;
|
||||
stop: () => void;
|
||||
}
|
||||
|
||||
export function useVoicePlayback(): UseVoicePlaybackReturn {
|
||||
const [isPlaying, setIsPlaying] = useState(false);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const audioRef = useRef<HTMLAudioElement | null>(null);
|
||||
const audioUrlRef = useRef<string | null>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
|
||||
const stop = useCallback(() => {
|
||||
// Revoke object URL to prevent memory leak
|
||||
if (audioUrlRef.current) {
|
||||
URL.revokeObjectURL(audioUrlRef.current);
|
||||
audioUrlRef.current = null;
|
||||
}
|
||||
if (audioRef.current) {
|
||||
audioRef.current.pause();
|
||||
audioRef.current.src = "";
|
||||
audioRef.current = null;
|
||||
}
|
||||
if (abortControllerRef.current) {
|
||||
abortControllerRef.current.abort();
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
setIsPlaying(false);
|
||||
setIsLoading(false);
|
||||
}, []);
|
||||
|
||||
const pause = useCallback(() => {
|
||||
if (audioRef.current && isPlaying) {
|
||||
audioRef.current.pause();
|
||||
setIsPlaying(false);
|
||||
}
|
||||
}, [isPlaying]);
|
||||
|
||||
const play = useCallback(
|
||||
async (text: string, voice?: string, speed?: number) => {
|
||||
// Stop any existing playback
|
||||
stop();
|
||||
setError(null);
|
||||
setIsLoading(true);
|
||||
|
||||
try {
|
||||
abortControllerRef.current = new AbortController();
|
||||
|
||||
const params = new URLSearchParams();
|
||||
params.set("text", text);
|
||||
if (voice) params.set("voice", voice);
|
||||
if (speed !== undefined) params.set("speed", speed.toString());
|
||||
|
||||
const response = await fetch(`/api/voice/synthesize?${params}`, {
|
||||
method: "POST",
|
||||
signal: abortControllerRef.current.signal,
|
||||
credentials: "include",
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json();
|
||||
throw new Error(errorData.detail || "Speech synthesis failed");
|
||||
}
|
||||
|
||||
const audioBlob = await response.blob();
|
||||
const audioUrl = URL.createObjectURL(audioBlob);
|
||||
audioUrlRef.current = audioUrl;
|
||||
|
||||
const audio = new Audio(audioUrl);
|
||||
audioRef.current = audio;
|
||||
|
||||
audio.onended = () => {
|
||||
setIsPlaying(false);
|
||||
if (audioUrlRef.current) {
|
||||
URL.revokeObjectURL(audioUrlRef.current);
|
||||
audioUrlRef.current = null;
|
||||
}
|
||||
};
|
||||
|
||||
audio.onerror = () => {
|
||||
setError("Audio playback failed");
|
||||
setIsPlaying(false);
|
||||
if (audioUrlRef.current) {
|
||||
URL.revokeObjectURL(audioUrlRef.current);
|
||||
audioUrlRef.current = null;
|
||||
}
|
||||
};
|
||||
|
||||
setIsLoading(false);
|
||||
setIsPlaying(true);
|
||||
await audio.play();
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
// Request was cancelled, not an error
|
||||
return;
|
||||
}
|
||||
const message =
|
||||
err instanceof Error ? err.message : "Speech synthesis failed";
|
||||
setError(message);
|
||||
setIsLoading(false);
|
||||
}
|
||||
},
|
||||
[stop]
|
||||
);
|
||||
|
||||
return {
|
||||
isPlaying,
|
||||
isLoading,
|
||||
error,
|
||||
play,
|
||||
pause,
|
||||
stop,
|
||||
};
|
||||
}
|
||||
462
web/src/hooks/useVoiceRecorder.ts
Normal file
462
web/src/hooks/useVoiceRecorder.ts
Normal file
@@ -0,0 +1,462 @@
|
||||
import { useState, useRef, useCallback, useEffect } from "react";
|
||||
|
||||
// Target format for OpenAI Realtime API
|
||||
const TARGET_SAMPLE_RATE = 24000;
|
||||
const CHUNK_INTERVAL_MS = 250;
|
||||
const SILENCE_TIMEOUT_MS = 5000; // Stop recording if no speech detected for 5 seconds
|
||||
|
||||
interface TranscriptMessage {
|
||||
type: "transcript" | "error";
|
||||
text?: string;
|
||||
message?: string;
|
||||
is_final?: boolean;
|
||||
}
|
||||
|
||||
export interface UseVoiceRecorderOptions {
|
||||
/** Called when VAD detects silence and final transcript is received */
|
||||
onFinalTranscript?: (text: string) => void;
|
||||
/** If true, automatically stop recording when VAD detects silence */
|
||||
autoStopOnSilence?: boolean;
|
||||
}
|
||||
|
||||
export interface UseVoiceRecorderReturn {
|
||||
isRecording: boolean;
|
||||
isProcessing: boolean;
|
||||
isMuted: boolean;
|
||||
error: string | null;
|
||||
liveTranscript: string;
|
||||
startRecording: () => Promise<void>;
|
||||
stopRecording: () => Promise<string | null>;
|
||||
setMuted: (muted: boolean) => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Encapsulates all browser resources for a voice recording session.
|
||||
* Manages WebSocket, Web Audio API, and audio buffering.
|
||||
*/
|
||||
class VoiceRecorderSession {
|
||||
// Browser resources
|
||||
private websocket: WebSocket | null = null;
|
||||
private audioContext: AudioContext | null = null;
|
||||
private scriptNode: ScriptProcessorNode | null = null;
|
||||
private sourceNode: MediaStreamAudioSourceNode | null = null;
|
||||
private mediaStream: MediaStream | null = null;
|
||||
private sendInterval: NodeJS.Timeout | null = null;
|
||||
|
||||
// State
|
||||
private audioBuffer: Float32Array[] = [];
|
||||
private transcript = "";
|
||||
private stopResolver: ((text: string | null) => void) | null = null;
|
||||
private isActive = false;
|
||||
private silenceTimeout: NodeJS.Timeout | null = null;
|
||||
|
||||
// Callbacks to update React state
|
||||
private onTranscriptChange: (text: string) => void;
|
||||
private onFinalTranscript: ((text: string) => void) | null;
|
||||
private onError: (error: string) => void;
|
||||
private onSilenceTimeout: (() => void) | null;
|
||||
private onVADStop: (() => void) | null;
|
||||
private autoStopOnSilence: boolean;
|
||||
|
||||
constructor(
|
||||
onTranscriptChange: (text: string) => void,
|
||||
onFinalTranscript: ((text: string) => void) | null,
|
||||
onError: (error: string) => void,
|
||||
onSilenceTimeout?: () => void,
|
||||
autoStopOnSilence?: boolean,
|
||||
onVADStop?: () => void
|
||||
) {
|
||||
this.onTranscriptChange = onTranscriptChange;
|
||||
this.onFinalTranscript = onFinalTranscript;
|
||||
this.onError = onError;
|
||||
this.onSilenceTimeout = onSilenceTimeout || null;
|
||||
this.autoStopOnSilence = autoStopOnSilence ?? false;
|
||||
this.onVADStop = onVADStop || null;
|
||||
}
|
||||
|
||||
get recording(): boolean {
|
||||
return this.isActive;
|
||||
}
|
||||
|
||||
get currentTranscript(): string {
|
||||
return this.transcript;
|
||||
}
|
||||
|
||||
setMuted(muted: boolean): void {
|
||||
if (this.mediaStream) {
|
||||
this.mediaStream.getAudioTracks().forEach((track) => {
|
||||
track.enabled = !muted;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async start(): Promise<void> {
|
||||
if (this.isActive) return;
|
||||
|
||||
this.cleanup();
|
||||
this.transcript = "";
|
||||
this.audioBuffer = [];
|
||||
|
||||
// Get microphone
|
||||
this.mediaStream = await navigator.mediaDevices.getUserMedia({
|
||||
audio: {
|
||||
channelCount: 1,
|
||||
sampleRate: { ideal: TARGET_SAMPLE_RATE },
|
||||
echoCancellation: true,
|
||||
noiseSuppression: true,
|
||||
},
|
||||
});
|
||||
|
||||
// Get WS token and connect WebSocket
|
||||
const wsUrl = await this.getWebSocketUrl();
|
||||
this.websocket = new WebSocket(wsUrl);
|
||||
this.websocket.onmessage = this.handleMessage;
|
||||
this.websocket.onerror = () => this.onError("Connection failed");
|
||||
this.websocket.onclose = () => {
|
||||
if (this.stopResolver) {
|
||||
this.stopResolver(this.transcript || null);
|
||||
this.stopResolver = null;
|
||||
}
|
||||
};
|
||||
|
||||
await this.waitForConnection();
|
||||
|
||||
// Set up audio capture
|
||||
this.audioContext = new AudioContext({ sampleRate: TARGET_SAMPLE_RATE });
|
||||
this.sourceNode = this.audioContext.createMediaStreamSource(
|
||||
this.mediaStream
|
||||
);
|
||||
this.scriptNode = this.audioContext.createScriptProcessor(4096, 1, 1);
|
||||
|
||||
this.scriptNode.onaudioprocess = (event) => {
|
||||
const inputData = event.inputBuffer.getChannelData(0);
|
||||
this.audioBuffer.push(new Float32Array(inputData));
|
||||
};
|
||||
|
||||
this.sourceNode.connect(this.scriptNode);
|
||||
this.scriptNode.connect(this.audioContext.destination);
|
||||
|
||||
// Start sending audio chunks
|
||||
this.sendInterval = setInterval(
|
||||
() => this.sendAudioBuffer(),
|
||||
CHUNK_INTERVAL_MS
|
||||
);
|
||||
this.isActive = true;
|
||||
}
|
||||
|
||||
async stop(): Promise<string | null> {
|
||||
if (!this.isActive) return this.transcript || null;
|
||||
|
||||
// Stop audio capture
|
||||
if (this.sendInterval) {
|
||||
clearInterval(this.sendInterval);
|
||||
this.sendInterval = null;
|
||||
}
|
||||
if (this.scriptNode) {
|
||||
this.scriptNode.disconnect();
|
||||
this.scriptNode = null;
|
||||
}
|
||||
if (this.sourceNode) {
|
||||
this.sourceNode.disconnect();
|
||||
this.sourceNode = null;
|
||||
}
|
||||
if (this.audioContext) {
|
||||
this.audioContext.close();
|
||||
this.audioContext = null;
|
||||
}
|
||||
if (this.mediaStream) {
|
||||
this.mediaStream.getTracks().forEach((track) => track.stop());
|
||||
this.mediaStream = null;
|
||||
}
|
||||
|
||||
this.audioBuffer = [];
|
||||
this.isActive = false;
|
||||
|
||||
// Get final transcript from server
|
||||
if (this.websocket?.readyState === WebSocket.OPEN) {
|
||||
return new Promise((resolve) => {
|
||||
this.stopResolver = resolve;
|
||||
this.websocket!.send(JSON.stringify({ type: "end" }));
|
||||
|
||||
// Timeout fallback
|
||||
setTimeout(() => {
|
||||
if (this.stopResolver) {
|
||||
this.stopResolver(this.transcript || null);
|
||||
this.stopResolver = null;
|
||||
}
|
||||
}, 3000);
|
||||
});
|
||||
}
|
||||
|
||||
return this.transcript || null;
|
||||
}
|
||||
|
||||
cleanup(): void {
|
||||
if (this.sendInterval) clearInterval(this.sendInterval);
|
||||
if (this.scriptNode) this.scriptNode.disconnect();
|
||||
if (this.sourceNode) this.sourceNode.disconnect();
|
||||
if (this.audioContext) this.audioContext.close();
|
||||
if (this.mediaStream) this.mediaStream.getTracks().forEach((t) => t.stop());
|
||||
if (this.websocket) this.websocket.close();
|
||||
|
||||
this.sendInterval = null;
|
||||
this.scriptNode = null;
|
||||
this.sourceNode = null;
|
||||
this.audioContext = null;
|
||||
this.mediaStream = null;
|
||||
this.websocket = null;
|
||||
this.isActive = false;
|
||||
}
|
||||
|
||||
private async getWebSocketUrl(): Promise<string> {
|
||||
// Fetch short-lived WS token
|
||||
const tokenResponse = await fetch("/api/voice/ws-token", {
|
||||
method: "POST",
|
||||
credentials: "include",
|
||||
});
|
||||
if (!tokenResponse.ok) {
|
||||
throw new Error("Failed to get WebSocket authentication token");
|
||||
}
|
||||
const { token } = await tokenResponse.json();
|
||||
|
||||
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
|
||||
const isDev = window.location.port === "3000";
|
||||
const host = isDev ? "localhost:8080" : window.location.host;
|
||||
const path = isDev
|
||||
? "/voice/transcribe/stream"
|
||||
: "/api/voice/transcribe/stream";
|
||||
return `${protocol}//${host}${path}?token=${encodeURIComponent(token)}`;
|
||||
}
|
||||
|
||||
private waitForConnection(): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (!this.websocket) return reject(new Error("No WebSocket"));
|
||||
|
||||
const timeout = setTimeout(
|
||||
() => reject(new Error("Connection timeout")),
|
||||
5000
|
||||
);
|
||||
|
||||
this.websocket.onopen = () => {
|
||||
clearTimeout(timeout);
|
||||
resolve();
|
||||
};
|
||||
this.websocket.onerror = () => {
|
||||
clearTimeout(timeout);
|
||||
reject(new Error("Connection failed"));
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
private handleMessage = (event: MessageEvent): void => {
|
||||
try {
|
||||
const data: TranscriptMessage = JSON.parse(event.data);
|
||||
|
||||
if (data.type === "transcript") {
|
||||
if (data.text) {
|
||||
this.transcript = data.text;
|
||||
this.onTranscriptChange(data.text);
|
||||
}
|
||||
|
||||
if (data.is_final && data.text) {
|
||||
// VAD detected silence - trigger callback
|
||||
if (this.onFinalTranscript) {
|
||||
this.onFinalTranscript(data.text);
|
||||
}
|
||||
|
||||
// Auto-stop recording if enabled
|
||||
if (this.autoStopOnSilence) {
|
||||
// Trigger stop callback to update React state
|
||||
if (this.onVADStop) {
|
||||
this.onVADStop();
|
||||
}
|
||||
} else {
|
||||
// If not auto-stopping, reset for next utterance
|
||||
this.transcript = "";
|
||||
this.onTranscriptChange("");
|
||||
this.resetBackendTranscript();
|
||||
}
|
||||
|
||||
// Resolve stop promise if waiting
|
||||
if (this.stopResolver) {
|
||||
this.stopResolver(data.text);
|
||||
this.stopResolver = null;
|
||||
}
|
||||
}
|
||||
} else if (data.type === "error") {
|
||||
this.onError(data.message || "Transcription error");
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("Failed to parse transcript message:", e);
|
||||
}
|
||||
};
|
||||
|
||||
private resetBackendTranscript(): void {
|
||||
if (this.websocket?.readyState === WebSocket.OPEN) {
|
||||
this.websocket.send(JSON.stringify({ type: "reset" }));
|
||||
}
|
||||
}
|
||||
|
||||
private sendAudioBuffer(): void {
|
||||
if (
|
||||
!this.websocket ||
|
||||
this.websocket.readyState !== WebSocket.OPEN ||
|
||||
!this.audioContext ||
|
||||
this.audioBuffer.length === 0
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Concatenate buffered chunks
|
||||
const totalLength = this.audioBuffer.reduce(
|
||||
(sum, chunk) => sum + chunk.length,
|
||||
0
|
||||
);
|
||||
|
||||
// Prevent buffer overflow
|
||||
if (totalLength > this.audioContext.sampleRate * 0.5 * 2) {
|
||||
this.audioBuffer = this.audioBuffer.slice(-10);
|
||||
return;
|
||||
}
|
||||
|
||||
const concatenated = new Float32Array(totalLength);
|
||||
let offset = 0;
|
||||
for (const chunk of this.audioBuffer) {
|
||||
concatenated.set(chunk, offset);
|
||||
offset += chunk.length;
|
||||
}
|
||||
this.audioBuffer = [];
|
||||
|
||||
// Resample and convert to PCM16
|
||||
const resampled = this.resampleAudio(
|
||||
concatenated,
|
||||
this.audioContext.sampleRate
|
||||
);
|
||||
const pcm16 = this.float32ToInt16(resampled);
|
||||
|
||||
this.websocket.send(pcm16.buffer);
|
||||
}
|
||||
|
||||
private resampleAudio(input: Float32Array, inputRate: number): Float32Array {
|
||||
if (inputRate === TARGET_SAMPLE_RATE) return input;
|
||||
|
||||
const ratio = inputRate / TARGET_SAMPLE_RATE;
|
||||
const outputLength = Math.round(input.length / ratio);
|
||||
const output = new Float32Array(outputLength);
|
||||
|
||||
for (let i = 0; i < outputLength; i++) {
|
||||
const srcIndex = i * ratio;
|
||||
const floor = Math.floor(srcIndex);
|
||||
const ceil = Math.min(floor + 1, input.length - 1);
|
||||
const fraction = srcIndex - floor;
|
||||
output[i] = input[floor]! * (1 - fraction) + input[ceil]! * fraction;
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
private float32ToInt16(float32: Float32Array): Int16Array {
|
||||
const int16 = new Int16Array(float32.length);
|
||||
for (let i = 0; i < float32.length; i++) {
|
||||
const s = Math.max(-1, Math.min(1, float32[i]!));
|
||||
int16[i] = s < 0 ? s * 0x8000 : s * 0x7fff;
|
||||
}
|
||||
return int16;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook for voice recording with streaming transcription.
|
||||
*/
|
||||
export function useVoiceRecorder(
|
||||
options?: UseVoiceRecorderOptions
|
||||
): UseVoiceRecorderReturn {
|
||||
const [isRecording, setIsRecording] = useState(false);
|
||||
const [isProcessing, setIsProcessing] = useState(false);
|
||||
const [isMuted, setIsMutedState] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [liveTranscript, setLiveTranscript] = useState("");
|
||||
|
||||
const sessionRef = useRef<VoiceRecorderSession | null>(null);
|
||||
const onFinalTranscriptRef = useRef(options?.onFinalTranscript);
|
||||
const autoStopOnSilenceRef = useRef(options?.autoStopOnSilence ?? true); // Default to true
|
||||
|
||||
// Keep callback ref in sync
|
||||
useEffect(() => {
|
||||
onFinalTranscriptRef.current = options?.onFinalTranscript;
|
||||
autoStopOnSilenceRef.current = options?.autoStopOnSilence ?? true;
|
||||
}, [options?.onFinalTranscript, options?.autoStopOnSilence]);
|
||||
|
||||
// Cleanup on unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
sessionRef.current?.cleanup();
|
||||
};
|
||||
}, []);
|
||||
|
||||
const startRecording = useCallback(async () => {
|
||||
if (sessionRef.current?.recording) return;
|
||||
|
||||
setError(null);
|
||||
setLiveTranscript("");
|
||||
|
||||
// Create VAD stop handler that will stop the session
|
||||
const handleVADStop = () => {
|
||||
if (sessionRef.current) {
|
||||
sessionRef.current.stop().then(() => {
|
||||
setIsRecording(false);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
sessionRef.current = new VoiceRecorderSession(
|
||||
setLiveTranscript,
|
||||
(text) => onFinalTranscriptRef.current?.(text),
|
||||
setError,
|
||||
undefined, // onSilenceTimeout
|
||||
autoStopOnSilenceRef.current,
|
||||
handleVADStop
|
||||
);
|
||||
|
||||
try {
|
||||
await sessionRef.current.start();
|
||||
setIsRecording(true);
|
||||
} catch (err) {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to start recording"
|
||||
);
|
||||
throw err;
|
||||
}
|
||||
}, []);
|
||||
|
||||
const stopRecording = useCallback(async (): Promise<string | null> => {
|
||||
if (!sessionRef.current) return null;
|
||||
|
||||
setIsProcessing(true);
|
||||
|
||||
try {
|
||||
const transcript = await sessionRef.current.stop();
|
||||
return transcript;
|
||||
} finally {
|
||||
setIsRecording(false);
|
||||
setIsProcessing(false);
|
||||
setIsMutedState(false); // Reset mute state when recording stops
|
||||
}
|
||||
}, []);
|
||||
|
||||
const setMuted = useCallback((muted: boolean) => {
|
||||
setIsMutedState(muted);
|
||||
sessionRef.current?.setMuted(muted);
|
||||
}, []);
|
||||
|
||||
return {
|
||||
isRecording,
|
||||
isProcessing,
|
||||
isMuted,
|
||||
error,
|
||||
liveTranscript,
|
||||
startRecording,
|
||||
stopRecording,
|
||||
setMuted,
|
||||
};
|
||||
}
|
||||
149
web/src/hooks/useWebSocket.ts
Normal file
149
web/src/hooks/useWebSocket.ts
Normal file
@@ -0,0 +1,149 @@
|
||||
import { useState, useRef, useCallback, useEffect } from "react";
|
||||
|
||||
export type WebSocketStatus =
|
||||
| "connecting"
|
||||
| "connected"
|
||||
| "disconnected"
|
||||
| "error";
|
||||
|
||||
export interface UseWebSocketOptions<T> {
|
||||
/** URL to connect to */
|
||||
url: string;
|
||||
/** Called when a message is received */
|
||||
onMessage?: (data: T) => void;
|
||||
/** Called when connection opens */
|
||||
onOpen?: () => void;
|
||||
/** Called when connection closes */
|
||||
onClose?: () => void;
|
||||
/** Called on error */
|
||||
onError?: (error: Event) => void;
|
||||
/** Auto-connect on mount */
|
||||
autoConnect?: boolean;
|
||||
}
|
||||
|
||||
export interface UseWebSocketReturn<T> {
|
||||
/** Current connection status */
|
||||
status: WebSocketStatus;
|
||||
/** Send JSON data */
|
||||
sendJson: (data: T) => void;
|
||||
/** Send binary data */
|
||||
sendBinary: (data: Blob | ArrayBuffer) => void;
|
||||
/** Connect to WebSocket */
|
||||
connect: () => Promise<void>;
|
||||
/** Disconnect from WebSocket */
|
||||
disconnect: () => void;
|
||||
}
|
||||
|
||||
export function useWebSocket<TReceive = unknown, TSend = unknown>({
|
||||
url,
|
||||
onMessage,
|
||||
onOpen,
|
||||
onClose,
|
||||
onError,
|
||||
autoConnect = false,
|
||||
}: UseWebSocketOptions<TReceive>): UseWebSocketReturn<TSend> {
|
||||
const [status, setStatus] = useState<WebSocketStatus>("disconnected");
|
||||
const wsRef = useRef<WebSocket | null>(null);
|
||||
const onMessageRef = useRef(onMessage);
|
||||
const onOpenRef = useRef(onOpen);
|
||||
const onCloseRef = useRef(onClose);
|
||||
const onErrorRef = useRef(onError);
|
||||
|
||||
// Keep refs updated
|
||||
useEffect(() => {
|
||||
onMessageRef.current = onMessage;
|
||||
onOpenRef.current = onOpen;
|
||||
onCloseRef.current = onClose;
|
||||
onErrorRef.current = onError;
|
||||
}, [onMessage, onOpen, onClose, onError]);
|
||||
|
||||
const connect = useCallback(async (): Promise<void> => {
|
||||
if (
|
||||
wsRef.current?.readyState === WebSocket.OPEN ||
|
||||
wsRef.current?.readyState === WebSocket.CONNECTING
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
setStatus("connecting");
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
const ws = new WebSocket(url);
|
||||
wsRef.current = ws;
|
||||
|
||||
const timeout = setTimeout(() => {
|
||||
ws.close();
|
||||
reject(new Error("WebSocket connection timeout"));
|
||||
}, 10000);
|
||||
|
||||
ws.onopen = () => {
|
||||
clearTimeout(timeout);
|
||||
setStatus("connected");
|
||||
onOpenRef.current?.();
|
||||
resolve();
|
||||
};
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
try {
|
||||
const data = JSON.parse(event.data) as TReceive;
|
||||
onMessageRef.current?.(data);
|
||||
} catch {
|
||||
// Non-JSON message, ignore or handle differently
|
||||
}
|
||||
};
|
||||
|
||||
ws.onclose = () => {
|
||||
setStatus("disconnected");
|
||||
onCloseRef.current?.();
|
||||
wsRef.current = null;
|
||||
};
|
||||
|
||||
ws.onerror = (error) => {
|
||||
clearTimeout(timeout);
|
||||
setStatus("error");
|
||||
onErrorRef.current?.(error);
|
||||
reject(new Error("WebSocket connection failed"));
|
||||
};
|
||||
});
|
||||
}, [url]);
|
||||
|
||||
const disconnect = useCallback(() => {
|
||||
if (wsRef.current) {
|
||||
wsRef.current.close();
|
||||
wsRef.current = null;
|
||||
}
|
||||
setStatus("disconnected");
|
||||
}, []);
|
||||
|
||||
const sendJson = useCallback((data: TSend) => {
|
||||
if (wsRef.current?.readyState === WebSocket.OPEN) {
|
||||
wsRef.current.send(JSON.stringify(data));
|
||||
}
|
||||
}, []);
|
||||
|
||||
const sendBinary = useCallback((data: Blob | ArrayBuffer) => {
|
||||
if (wsRef.current?.readyState === WebSocket.OPEN) {
|
||||
wsRef.current.send(data);
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Auto-connect if enabled
|
||||
useEffect(() => {
|
||||
if (autoConnect) {
|
||||
connect().catch(() => {
|
||||
// Error handled via onError callback
|
||||
});
|
||||
}
|
||||
return () => {
|
||||
disconnect();
|
||||
};
|
||||
}, [autoConnect, connect, disconnect]);
|
||||
|
||||
return {
|
||||
status,
|
||||
sendJson,
|
||||
sendBinary,
|
||||
connect,
|
||||
disconnect,
|
||||
};
|
||||
}
|
||||
@@ -86,7 +86,7 @@ function SettingsRoot({ width = "md", ...props }: SettingsRootProps) {
|
||||
return (
|
||||
<div
|
||||
id="page-wrapper-scroll-container"
|
||||
className="w-full h-full flex flex-col items-center overflow-y-auto pt-10"
|
||||
className="w-full h-full flex flex-col items-center overflow-y-auto"
|
||||
>
|
||||
{/* WARNING: The id="page-wrapper-scroll-container" above is used by SettingsHeader
|
||||
to detect scroll position and show/hide the scroll shadow.
|
||||
|
||||
@@ -58,7 +58,6 @@ export const ADMIN_PATHS = {
|
||||
DOCUMENT_PROCESSING: "/admin/configuration/document-processing",
|
||||
KNOWLEDGE_GRAPH: "/admin/kg",
|
||||
USERS: "/admin/users",
|
||||
USERS_V2: "/admin/users2",
|
||||
API_KEYS: "/admin/api-key",
|
||||
TOKEN_RATE_LIMITS: "/admin/token-rate-limits",
|
||||
USAGE: "/admin/performance/usage",
|
||||
@@ -191,11 +190,6 @@ export const ADMIN_ROUTE_CONFIG: Record<string, AdminRouteConfig> = {
|
||||
title: "Manage Users",
|
||||
sidebarLabel: "Users",
|
||||
},
|
||||
[ADMIN_PATHS.USERS_V2]: {
|
||||
icon: SvgUser,
|
||||
title: "Users & Requests",
|
||||
sidebarLabel: "Users v2",
|
||||
},
|
||||
[ADMIN_PATHS.API_KEYS]: {
|
||||
icon: SvgKey,
|
||||
title: "API Keys",
|
||||
|
||||
37
web/src/lib/sentenceDetector.ts
Normal file
37
web/src/lib/sentenceDetector.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
/**
|
||||
* Sentence detection for streaming TTS using 'sbd' library.
|
||||
*/
|
||||
|
||||
import sbd from "sbd";
|
||||
|
||||
/**
|
||||
* Split text into sentences. Returns complete sentences and remaining buffer.
|
||||
*/
|
||||
export function detectSentences(
|
||||
text: string,
|
||||
isComplete: boolean = false
|
||||
): { sentences: string[]; buffer: string } {
|
||||
const sentences = sbd.sentences(text, {
|
||||
newline_boundaries: true,
|
||||
html_boundaries: false,
|
||||
sanitize: false,
|
||||
});
|
||||
|
||||
if (sentences.length === 0) {
|
||||
return { sentences: [], buffer: text };
|
||||
}
|
||||
|
||||
// Check if text ends with sentence-ending punctuation
|
||||
const endsWithPunctuation = /[.!?]["']?\s*$/.test(text.trim());
|
||||
|
||||
if (isComplete || endsWithPunctuation) {
|
||||
// All sentences are complete
|
||||
return { sentences: sentences.map((s) => s.trim()), buffer: "" };
|
||||
}
|
||||
|
||||
// Last sentence might be incomplete - keep it in buffer
|
||||
const complete = sentences.slice(0, -1).map((s) => s.trim());
|
||||
const buffer = sentences[sentences.length - 1] ?? "";
|
||||
|
||||
return { sentences: complete, buffer };
|
||||
}
|
||||
588
web/src/lib/streamingTTS.ts
Normal file
588
web/src/lib/streamingTTS.ts
Normal file
@@ -0,0 +1,588 @@
|
||||
/**
|
||||
* Real-time streaming TTS using HTTP streaming with MediaSource Extensions.
|
||||
* Plays audio chunks as they arrive for smooth, low-latency playback.
|
||||
*/
|
||||
|
||||
/**
|
||||
* HTTPStreamingTTSPlayer - Uses HTTP streaming with MediaSource Extensions
|
||||
* for smooth, gapless audio playback. This is the recommended approach for
|
||||
* real-time TTS as it properly handles MP3 frame boundaries.
|
||||
*/
|
||||
export class HTTPStreamingTTSPlayer {
|
||||
private mediaSource: MediaSource | null = null;
|
||||
private sourceBuffer: SourceBuffer | null = null;
|
||||
private audioElement: HTMLAudioElement | null = null;
|
||||
private pendingChunks: Uint8Array[] = [];
|
||||
private isAppending: boolean = false;
|
||||
private isPlaying: boolean = false;
|
||||
private streamComplete: boolean = false;
|
||||
private onPlayingChange?: (playing: boolean) => void;
|
||||
private onError?: (error: string) => void;
|
||||
private abortController: AbortController | null = null;
|
||||
|
||||
constructor(options?: {
|
||||
onPlayingChange?: (playing: boolean) => void;
|
||||
onError?: (error: string) => void;
|
||||
}) {
|
||||
this.onPlayingChange = options?.onPlayingChange;
|
||||
this.onError = options?.onError;
|
||||
}
|
||||
|
||||
private getAPIUrl(): string {
|
||||
// Always go through the frontend proxy to ensure cookies are sent correctly
|
||||
// The Next.js proxy at /api/* forwards to the backend
|
||||
return "/api/voice/synthesize";
|
||||
}
|
||||
|
||||
/**
|
||||
* Speak text using HTTP streaming with real-time playback.
|
||||
* Audio begins playing as soon as the first chunks arrive.
|
||||
*/
|
||||
async speak(
|
||||
text: string,
|
||||
voice?: string,
|
||||
speed: number = 1.0
|
||||
): Promise<void> {
|
||||
// Cleanup any previous playback
|
||||
this.cleanup();
|
||||
|
||||
// Create abort controller for this request
|
||||
this.abortController = new AbortController();
|
||||
|
||||
// Build URL with query params
|
||||
const params = new URLSearchParams();
|
||||
params.set("text", text);
|
||||
if (voice) params.set("voice", voice);
|
||||
params.set("speed", speed.toString());
|
||||
|
||||
const url = `${this.getAPIUrl()}?${params}`;
|
||||
|
||||
// Check if MediaSource is supported
|
||||
if (!window.MediaSource || !MediaSource.isTypeSupported("audio/mpeg")) {
|
||||
// Fallback to simple buffered playback
|
||||
return this.fallbackSpeak(url);
|
||||
}
|
||||
|
||||
// Create MediaSource and audio element
|
||||
this.mediaSource = new MediaSource();
|
||||
this.audioElement = new Audio();
|
||||
this.audioElement.src = URL.createObjectURL(this.mediaSource);
|
||||
|
||||
// Set up audio element event handlers
|
||||
this.audioElement.onplay = () => {
|
||||
if (!this.isPlaying) {
|
||||
this.isPlaying = true;
|
||||
this.onPlayingChange?.(true);
|
||||
}
|
||||
};
|
||||
|
||||
this.audioElement.onended = () => {
|
||||
this.isPlaying = false;
|
||||
this.onPlayingChange?.(false);
|
||||
};
|
||||
|
||||
this.audioElement.onerror = () => {
|
||||
this.onError?.("Audio playback error");
|
||||
this.isPlaying = false;
|
||||
this.onPlayingChange?.(false);
|
||||
};
|
||||
|
||||
// Wait for MediaSource to be ready
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
if (!this.mediaSource) {
|
||||
reject(new Error("MediaSource not initialized"));
|
||||
return;
|
||||
}
|
||||
|
||||
this.mediaSource.onsourceopen = () => {
|
||||
try {
|
||||
// Create SourceBuffer for MP3
|
||||
this.sourceBuffer = this.mediaSource!.addSourceBuffer("audio/mpeg");
|
||||
this.sourceBuffer.mode = "sequence";
|
||||
|
||||
this.sourceBuffer.onupdateend = () => {
|
||||
this.isAppending = false;
|
||||
this.processNextChunk();
|
||||
};
|
||||
|
||||
resolve();
|
||||
} catch (err) {
|
||||
reject(err);
|
||||
}
|
||||
};
|
||||
|
||||
// MediaSource doesn't have onerror in all browsers, use onsourceclose as fallback
|
||||
this.mediaSource.onsourceclose = () => {
|
||||
if (this.mediaSource?.readyState === "closed") {
|
||||
reject(new Error("MediaSource closed unexpectedly"));
|
||||
}
|
||||
};
|
||||
});
|
||||
|
||||
// Start fetching and streaming audio
|
||||
try {
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
signal: this.abortController.signal,
|
||||
credentials: "include", // Include cookies for authentication
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(
|
||||
`TTS request failed: ${response.status} - ${errorText}`
|
||||
);
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
// Start playback as soon as we have some data
|
||||
let firstChunk = true;
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
|
||||
if (done) {
|
||||
this.streamComplete = true;
|
||||
// End the stream when all chunks are appended
|
||||
this.finalizeStream();
|
||||
break;
|
||||
}
|
||||
|
||||
if (value) {
|
||||
this.pendingChunks.push(value);
|
||||
this.processNextChunk();
|
||||
|
||||
// Start playback after first chunk
|
||||
if (firstChunk && this.audioElement) {
|
||||
firstChunk = false;
|
||||
// Small delay to buffer a bit before starting
|
||||
setTimeout(() => {
|
||||
this.audioElement?.play().catch(() => {
|
||||
// Ignore playback start errors
|
||||
});
|
||||
}, 100);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
return;
|
||||
}
|
||||
this.onError?.(err instanceof Error ? err.message : "TTS error");
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process next chunk from the queue.
|
||||
*/
|
||||
private processNextChunk(): void {
|
||||
if (
|
||||
this.isAppending ||
|
||||
this.pendingChunks.length === 0 ||
|
||||
!this.sourceBuffer ||
|
||||
this.sourceBuffer.updating
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
const chunk = this.pendingChunks.shift();
|
||||
if (chunk) {
|
||||
this.isAppending = true;
|
||||
try {
|
||||
// Use ArrayBuffer directly for better TypeScript compatibility
|
||||
const buffer = chunk.buffer.slice(
|
||||
chunk.byteOffset,
|
||||
chunk.byteOffset + chunk.byteLength
|
||||
) as ArrayBuffer;
|
||||
this.sourceBuffer.appendBuffer(buffer);
|
||||
} catch {
|
||||
this.isAppending = false;
|
||||
// Try next chunk
|
||||
this.processNextChunk();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Finalize the stream when all data has been received.
|
||||
*/
|
||||
private finalizeStream(): void {
|
||||
if (this.pendingChunks.length > 0 || this.isAppending) {
|
||||
// Wait for remaining chunks to be appended
|
||||
setTimeout(() => this.finalizeStream(), 50);
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
this.mediaSource &&
|
||||
this.mediaSource.readyState === "open" &&
|
||||
this.sourceBuffer &&
|
||||
!this.sourceBuffer.updating
|
||||
) {
|
||||
try {
|
||||
this.mediaSource.endOfStream();
|
||||
} catch {
|
||||
// Ignore errors when ending stream
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fallback for browsers that don't support MediaSource Extensions.
|
||||
* Buffers all audio before playing.
|
||||
*/
|
||||
private async fallbackSpeak(url: string): Promise<void> {
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
signal: this.abortController?.signal,
|
||||
credentials: "include", // Include cookies for authentication
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`TTS request failed: ${response.status} - ${errorText}`);
|
||||
}
|
||||
|
||||
const audioData = await response.arrayBuffer();
|
||||
|
||||
const blob = new Blob([audioData], { type: "audio/mpeg" });
|
||||
const audioUrl = URL.createObjectURL(blob);
|
||||
|
||||
this.audioElement = new Audio(audioUrl);
|
||||
|
||||
this.audioElement.onplay = () => {
|
||||
this.isPlaying = true;
|
||||
this.onPlayingChange?.(true);
|
||||
};
|
||||
|
||||
this.audioElement.onended = () => {
|
||||
this.isPlaying = false;
|
||||
this.onPlayingChange?.(false);
|
||||
URL.revokeObjectURL(audioUrl);
|
||||
};
|
||||
|
||||
this.audioElement.onerror = () => {
|
||||
this.onError?.("Audio playback error");
|
||||
};
|
||||
|
||||
await this.audioElement.play();
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop playback and cleanup resources.
|
||||
*/
|
||||
stop(): void {
|
||||
// Abort any ongoing request
|
||||
if (this.abortController) {
|
||||
this.abortController.abort();
|
||||
this.abortController = null;
|
||||
}
|
||||
|
||||
this.cleanup();
|
||||
}
|
||||
|
||||
/**
|
||||
* Cleanup all resources.
|
||||
*/
|
||||
private cleanup(): void {
|
||||
// Stop and cleanup audio element
|
||||
if (this.audioElement) {
|
||||
this.audioElement.pause();
|
||||
this.audioElement.src = "";
|
||||
this.audioElement = null;
|
||||
}
|
||||
|
||||
// Cleanup MediaSource
|
||||
if (this.mediaSource && this.mediaSource.readyState === "open") {
|
||||
try {
|
||||
if (this.sourceBuffer) {
|
||||
this.mediaSource.removeSourceBuffer(this.sourceBuffer);
|
||||
}
|
||||
this.mediaSource.endOfStream();
|
||||
} catch {
|
||||
// Ignore cleanup errors
|
||||
}
|
||||
}
|
||||
|
||||
this.mediaSource = null;
|
||||
this.sourceBuffer = null;
|
||||
this.pendingChunks = [];
|
||||
this.isAppending = false;
|
||||
this.streamComplete = false;
|
||||
|
||||
if (this.isPlaying) {
|
||||
this.isPlaying = false;
|
||||
this.onPlayingChange?.(false);
|
||||
}
|
||||
}
|
||||
|
||||
get playing(): boolean {
|
||||
return this.isPlaying;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* WebSocketStreamingTTSPlayer - Uses WebSocket for bidirectional streaming.
|
||||
* Useful for scenarios where you want to stream text in and get audio out
|
||||
* incrementally (e.g., as LLM generates text).
|
||||
*/
|
||||
export class WebSocketStreamingTTSPlayer {
|
||||
private websocket: WebSocket | null = null;
|
||||
private mediaSource: MediaSource | null = null;
|
||||
private sourceBuffer: SourceBuffer | null = null;
|
||||
private audioElement: HTMLAudioElement | null = null;
|
||||
private pendingChunks: Uint8Array[] = [];
|
||||
private isAppending: boolean = false;
|
||||
private isPlaying: boolean = false;
|
||||
private onPlayingChange?: (playing: boolean) => void;
|
||||
private onError?: (error: string) => void;
|
||||
private hasStartedPlayback: boolean = false;
|
||||
|
||||
constructor(options?: {
|
||||
onPlayingChange?: (playing: boolean) => void;
|
||||
onError?: (error: string) => void;
|
||||
}) {
|
||||
this.onPlayingChange = options?.onPlayingChange;
|
||||
this.onError = options?.onError;
|
||||
}
|
||||
|
||||
private async getWebSocketUrl(): Promise<string> {
|
||||
// Fetch short-lived WS token
|
||||
const tokenResponse = await fetch("/api/voice/ws-token", {
|
||||
method: "POST",
|
||||
credentials: "include",
|
||||
});
|
||||
if (!tokenResponse.ok) {
|
||||
throw new Error("Failed to get WebSocket authentication token");
|
||||
}
|
||||
const { token } = await tokenResponse.json();
|
||||
|
||||
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
|
||||
const isDev = window.location.port === "3000";
|
||||
const host = isDev ? "localhost:8080" : window.location.host;
|
||||
const path = isDev
|
||||
? "/voice/synthesize/stream"
|
||||
: "/api/voice/synthesize/stream";
|
||||
return `${protocol}//${host}${path}?token=${encodeURIComponent(token)}`;
|
||||
}
|
||||
|
||||
async connect(voice?: string, speed?: number): Promise<void> {
|
||||
// Cleanup any previous connection
|
||||
this.cleanup();
|
||||
|
||||
// Check MediaSource support
|
||||
if (!window.MediaSource || !MediaSource.isTypeSupported("audio/mpeg")) {
|
||||
throw new Error("MediaSource Extensions not supported");
|
||||
}
|
||||
|
||||
// Create MediaSource and audio element
|
||||
this.mediaSource = new MediaSource();
|
||||
this.audioElement = new Audio();
|
||||
this.audioElement.src = URL.createObjectURL(this.mediaSource);
|
||||
|
||||
this.audioElement.onplay = () => {
|
||||
if (!this.isPlaying) {
|
||||
this.isPlaying = true;
|
||||
this.onPlayingChange?.(true);
|
||||
}
|
||||
};
|
||||
|
||||
this.audioElement.onended = () => {
|
||||
this.isPlaying = false;
|
||||
this.onPlayingChange?.(false);
|
||||
};
|
||||
|
||||
// Wait for MediaSource to be ready
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
this.mediaSource!.onsourceopen = () => {
|
||||
try {
|
||||
this.sourceBuffer = this.mediaSource!.addSourceBuffer("audio/mpeg");
|
||||
this.sourceBuffer.mode = "sequence";
|
||||
this.sourceBuffer.onupdateend = () => {
|
||||
this.isAppending = false;
|
||||
this.processNextChunk();
|
||||
};
|
||||
resolve();
|
||||
} catch (err) {
|
||||
reject(err);
|
||||
}
|
||||
};
|
||||
});
|
||||
|
||||
// Connect WebSocket
|
||||
const url = await this.getWebSocketUrl();
|
||||
return new Promise((resolve, reject) => {
|
||||
this.websocket = new WebSocket(url);
|
||||
|
||||
this.websocket.onopen = () => {
|
||||
// Send initial config
|
||||
this.websocket?.send(
|
||||
JSON.stringify({
|
||||
type: "config",
|
||||
voice: voice,
|
||||
speed: speed || 1.0,
|
||||
})
|
||||
);
|
||||
resolve();
|
||||
};
|
||||
|
||||
this.websocket.onerror = () => {
|
||||
reject(new Error("WebSocket connection failed"));
|
||||
};
|
||||
|
||||
this.websocket.onmessage = async (event) => {
|
||||
if (event.data instanceof Blob) {
|
||||
// Audio chunk received
|
||||
const arrayBuffer = await event.data.arrayBuffer();
|
||||
this.pendingChunks.push(new Uint8Array(arrayBuffer));
|
||||
this.processNextChunk();
|
||||
|
||||
// Start playback after first chunk
|
||||
if (!this.hasStartedPlayback && this.audioElement) {
|
||||
this.hasStartedPlayback = true;
|
||||
setTimeout(() => {
|
||||
this.audioElement?.play().catch(() => {
|
||||
// Ignore playback errors
|
||||
});
|
||||
}, 100);
|
||||
}
|
||||
} else {
|
||||
// JSON message
|
||||
try {
|
||||
const data = JSON.parse(event.data);
|
||||
if (data.type === "audio_done") {
|
||||
this.finalizeStream();
|
||||
} else if (data.type === "error") {
|
||||
this.onError?.(data.message);
|
||||
}
|
||||
} catch {
|
||||
// Ignore parse errors
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
this.websocket.onclose = () => {
|
||||
this.finalizeStream();
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
private processNextChunk(): void {
|
||||
if (
|
||||
this.isAppending ||
|
||||
this.pendingChunks.length === 0 ||
|
||||
!this.sourceBuffer ||
|
||||
this.sourceBuffer.updating
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
const chunk = this.pendingChunks.shift();
|
||||
if (chunk) {
|
||||
this.isAppending = true;
|
||||
try {
|
||||
// Use ArrayBuffer directly for better TypeScript compatibility
|
||||
const buffer = chunk.buffer.slice(
|
||||
chunk.byteOffset,
|
||||
chunk.byteOffset + chunk.byteLength
|
||||
) as ArrayBuffer;
|
||||
this.sourceBuffer.appendBuffer(buffer);
|
||||
} catch {
|
||||
this.isAppending = false;
|
||||
this.processNextChunk();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private finalizeStream(): void {
|
||||
if (this.pendingChunks.length > 0 || this.isAppending) {
|
||||
setTimeout(() => this.finalizeStream(), 50);
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
this.mediaSource &&
|
||||
this.mediaSource.readyState === "open" &&
|
||||
this.sourceBuffer &&
|
||||
!this.sourceBuffer.updating
|
||||
) {
|
||||
try {
|
||||
this.mediaSource.endOfStream();
|
||||
} catch {
|
||||
// Ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async speak(text: string): Promise<void> {
|
||||
if (!this.websocket || this.websocket.readyState !== WebSocket.OPEN) {
|
||||
throw new Error("WebSocket not connected");
|
||||
}
|
||||
|
||||
this.websocket.send(
|
||||
JSON.stringify({
|
||||
type: "synthesize",
|
||||
text: text,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
stop(): void {
|
||||
this.cleanup();
|
||||
}
|
||||
|
||||
disconnect(): void {
|
||||
if (this.websocket && this.websocket.readyState === WebSocket.OPEN) {
|
||||
this.websocket.send(JSON.stringify({ type: "end" }));
|
||||
this.websocket.close();
|
||||
}
|
||||
this.cleanup();
|
||||
}
|
||||
|
||||
private cleanup(): void {
|
||||
if (this.websocket) {
|
||||
this.websocket.close();
|
||||
this.websocket = null;
|
||||
}
|
||||
|
||||
if (this.audioElement) {
|
||||
this.audioElement.pause();
|
||||
this.audioElement.src = "";
|
||||
this.audioElement = null;
|
||||
}
|
||||
|
||||
if (this.mediaSource && this.mediaSource.readyState === "open") {
|
||||
try {
|
||||
if (this.sourceBuffer) {
|
||||
this.mediaSource.removeSourceBuffer(this.sourceBuffer);
|
||||
}
|
||||
this.mediaSource.endOfStream();
|
||||
} catch {
|
||||
// Ignore
|
||||
}
|
||||
}
|
||||
|
||||
this.mediaSource = null;
|
||||
this.sourceBuffer = null;
|
||||
this.pendingChunks = [];
|
||||
this.isAppending = false;
|
||||
this.hasStartedPlayback = false;
|
||||
|
||||
if (this.isPlaying) {
|
||||
this.isPlaying = false;
|
||||
this.onPlayingChange?.(false);
|
||||
}
|
||||
}
|
||||
|
||||
get playing(): boolean {
|
||||
return this.isPlaying;
|
||||
}
|
||||
}
|
||||
|
||||
// Export the HTTP player as the default/recommended option
|
||||
export { HTTPStreamingTTSPlayer as StreamingTTSPlayer };
|
||||
@@ -32,6 +32,11 @@ interface UserPreferences {
|
||||
theme_preference: ThemePreference | null;
|
||||
chat_background: string | null;
|
||||
default_app_mode: "AUTO" | "CHAT" | "SEARCH";
|
||||
// Voice preferences
|
||||
voice_auto_send?: boolean;
|
||||
voice_auto_playback?: boolean;
|
||||
voice_playback_speed?: number;
|
||||
preferred_voice?: string;
|
||||
}
|
||||
|
||||
export interface MemoryItem {
|
||||
|
||||
@@ -46,6 +46,12 @@ interface UserContextType {
|
||||
updateUserChatBackground: (chatBackground: string | null) => Promise<void>;
|
||||
updateUserDefaultModel: (defaultModel: string | null) => Promise<void>;
|
||||
updateUserDefaultAppMode: (mode: "CHAT" | "SEARCH") => Promise<void>;
|
||||
updateUserVoiceSettings: (settings: {
|
||||
auto_send?: boolean;
|
||||
auto_playback?: boolean;
|
||||
playback_speed?: number;
|
||||
preferred_voice?: string;
|
||||
}) => Promise<void>;
|
||||
}
|
||||
|
||||
const UserContext = createContext<UserContextType | undefined>(undefined);
|
||||
@@ -460,6 +466,54 @@ export function UserProvider({
|
||||
}
|
||||
};
|
||||
|
||||
const updateUserVoiceSettings = async (settings: {
|
||||
auto_send?: boolean;
|
||||
auto_playback?: boolean;
|
||||
playback_speed?: number;
|
||||
preferred_voice?: string;
|
||||
}) => {
|
||||
try {
|
||||
setUpToDateUser((prevUser) => {
|
||||
if (prevUser) {
|
||||
return {
|
||||
...prevUser,
|
||||
preferences: {
|
||||
...prevUser.preferences,
|
||||
voice_auto_send:
|
||||
settings.auto_send ?? prevUser.preferences.voice_auto_send,
|
||||
voice_auto_playback:
|
||||
settings.auto_playback ??
|
||||
prevUser.preferences.voice_auto_playback,
|
||||
voice_playback_speed:
|
||||
settings.playback_speed ??
|
||||
prevUser.preferences.voice_playback_speed,
|
||||
preferred_voice:
|
||||
settings.preferred_voice ??
|
||||
prevUser.preferences.preferred_voice,
|
||||
},
|
||||
};
|
||||
}
|
||||
return prevUser;
|
||||
});
|
||||
|
||||
const response = await fetch("/api/voice/settings", {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(settings),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
await refreshUser();
|
||||
throw new Error("Failed to update voice settings");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error updating voice settings:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
const refreshUser = async () => {
|
||||
await fetchUser();
|
||||
};
|
||||
@@ -478,6 +532,7 @@ export function UserProvider({
|
||||
updateUserChatBackground,
|
||||
updateUserDefaultModel,
|
||||
updateUserDefaultAppMode,
|
||||
updateUserVoiceSettings,
|
||||
toggleAgentPinnedStatus,
|
||||
isAdmin: upToDateUser?.role === UserRole.ADMIN,
|
||||
// Curator status applies for either global or basic curator
|
||||
|
||||
771
web/src/providers/VoiceModeProvider.tsx
Normal file
771
web/src/providers/VoiceModeProvider.tsx
Normal file
@@ -0,0 +1,771 @@
|
||||
"use client";
|
||||
|
||||
import React, {
|
||||
createContext,
|
||||
useContext,
|
||||
useState,
|
||||
useCallback,
|
||||
useRef,
|
||||
useEffect,
|
||||
} from "react";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
|
||||
interface VoiceModeContextType {
|
||||
/** Whether TTS audio is currently playing */
|
||||
isTTSPlaying: boolean;
|
||||
/** Whether TTS is loading/generating audio */
|
||||
isTTSLoading: boolean;
|
||||
/** Text that has been spoken so far (for synced display) */
|
||||
spokenText: string;
|
||||
/** Stream text for TTS - speaks sentences as they complete */
|
||||
streamTTS: (text: string, isComplete?: boolean) => void;
|
||||
/** Stop TTS playback */
|
||||
stopTTS: (options?: { manual?: boolean }) => void;
|
||||
/** Increments when TTS is manually stopped by the user */
|
||||
manualStopCount: number;
|
||||
/** Reset state for new message */
|
||||
resetTTS: () => void;
|
||||
}
|
||||
|
||||
const VoiceModeContext = createContext<VoiceModeContextType | null>(null);
|
||||
|
||||
/**
|
||||
* Clean text for TTS - remove markdown formatting
|
||||
*/
|
||||
function cleanTextForTTS(text: string): string {
|
||||
return text
|
||||
.replace(/\*\*/g, "") // Remove bold markers
|
||||
.replace(/\*/g, "") // Remove italic markers
|
||||
.replace(/`{1,3}/g, "") // Remove code markers
|
||||
.replace(/#{1,6}\s*/g, "") // Remove headers
|
||||
.replace(/\[([^\]]+)\]\([^)]+\)/g, "$1") // Convert links to just text
|
||||
.replace(/\n+/g, " ") // Replace newlines with spaces
|
||||
.replace(/\s+/g, " ") // Normalize whitespace
|
||||
.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the next natural chunk boundary in text.
|
||||
* Prefers sentence endings for natural speech rhythm.
|
||||
*/
|
||||
function findChunkBoundary(text: string): number {
|
||||
// Look for sentence endings (. ! ?) - these are natural speech breaks
|
||||
const sentenceRegex = /[.!?](?:\s|$)/g;
|
||||
let match;
|
||||
let lastSentenceEnd = -1;
|
||||
|
||||
while ((match = sentenceRegex.exec(text)) !== null) {
|
||||
const endPos = match.index + 1;
|
||||
if (endPos >= 10) {
|
||||
lastSentenceEnd = endPos;
|
||||
if (endPos >= 30) return endPos;
|
||||
}
|
||||
}
|
||||
|
||||
if (lastSentenceEnd > 0) return lastSentenceEnd;
|
||||
|
||||
// Only break at clauses for very long text (150+ chars)
|
||||
if (text.length >= 150) {
|
||||
const clauseRegex = /[,;:]\s/g;
|
||||
while ((match = clauseRegex.exec(text)) !== null) {
|
||||
const endPos = match.index + 1;
|
||||
if (endPos >= 70) return endPos;
|
||||
}
|
||||
}
|
||||
|
||||
// Break at word boundary for extremely long text (200+ chars)
|
||||
if (text.length >= 200) {
|
||||
const spaceIndex = text.lastIndexOf(" ", 120);
|
||||
if (spaceIndex > 80) return spaceIndex;
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
/**
 * Context provider for voice mode (streaming text-to-speech playback).
 *
 * Streams assistant text to a TTS backend over a WebSocket and plays the
 * returned MP3 audio incrementally through a MediaSource-backed <audio>
 * element. Exposes play/loading state, the text spoken so far, and
 * stream/stop/reset controls via VoiceModeContext.
 *
 * NOTE(review): relies on `cleanTextForTTS` and `findChunkBoundary` defined
 * elsewhere in this file; their exact semantics are assumed, not verified here.
 */
export function VoiceModeProvider({ children }: { children: React.ReactNode }) {
  const { user } = useUser();
  // User voice preferences, with conservative defaults (playback off, 1x speed).
  const autoPlayback = user?.preferences?.voice_auto_playback ?? false;
  const playbackSpeed = user?.preferences?.voice_playback_speed ?? 1.0;
  const preferredVoice = user?.preferences?.preferred_voice;

  const [isTTSPlaying, setIsTTSPlaying] = useState(false);
  const [isTTSLoading, setIsTTSLoading] = useState(false);
  const [spokenText, setSpokenText] = useState("");
  const [manualStopCount, setManualStopCount] = useState(0);

  // WebSocket and audio state
  const wsRef = useRef<WebSocket | null>(null);
  const mediaSourceRef = useRef<MediaSource | null>(null);
  const sourceBufferRef = useRef<SourceBuffer | null>(null);
  const audioElementRef = useRef<HTMLAudioElement | null>(null);
  // Object URL created for the MediaSource; must be revoked on teardown.
  const audioUrlRef = useRef<string | null>(null);
  // Audio chunks received but not yet appended to the SourceBuffer.
  const pendingChunksRef = useRef<Uint8Array[]>([]);
  const isAppendingRef = useRef(false);
  const isPlayingRef = useRef(false);
  const hasStartedPlaybackRef = useRef(false);

  // Text tracking
  // Index into the cleaned text up to which chunks have been sent to TTS.
  const committedPositionRef = useRef(0);
  const lastRawTextRef = useRef("");
  // Text queued while the WebSocket is still connecting.
  const pendingTextRef = useRef<string[]>([]);
  const isConnectingRef = useRef(false);

  // Timers
  const flushTimerRef = useRef<NodeJS.Timeout | null>(null);
  const fastStartTimerRef = useRef<NodeJS.Timeout | null>(null);
  const loadingTimeoutRef = useRef<NodeJS.Timeout | null>(null);
  const endCheckIntervalRef = useRef<NodeJS.Timeout | null>(null);
  const hasSpokenFirstChunkRef = useRef(false);
  const hasSignaledEndRef = useRef(false);
  const streamEndedRef = useRef(false);

  // Process next chunk from the pending queue.
  // Appends at most one chunk per call; subsequent chunks are driven by the
  // SourceBuffer's onupdateend handler (wired in initMediaSource).
  const processNextChunk = useCallback(() => {
    if (
      isAppendingRef.current ||
      pendingChunksRef.current.length === 0 ||
      !sourceBufferRef.current ||
      sourceBufferRef.current.updating
    ) {
      return;
    }

    const chunk = pendingChunksRef.current.shift();
    if (chunk) {
      isAppendingRef.current = true;
      try {
        // Copy the exact byte span: the Uint8Array may be a view into a
        // larger ArrayBuffer, so slice by byteOffset/byteLength.
        const buffer = chunk.buffer.slice(
          chunk.byteOffset,
          chunk.byteOffset + chunk.byteLength
        ) as ArrayBuffer;
        sourceBufferRef.current.appendBuffer(buffer);
      } catch {
        // appendBuffer failed (e.g. buffer full/removed); drop this chunk
        // and try the next one so the queue does not stall.
        isAppendingRef.current = false;
        processNextChunk();
      }
    }
  }, []);

  // Finalize the media stream when done.
  // Waits (by polling) until all pending chunks are appended, then ends the
  // MediaSource and watches for actual playback completion.
  const finalizeStream = useCallback(() => {
    if (pendingChunksRef.current.length > 0 || isAppendingRef.current) {
      setTimeout(() => finalizeStream(), 50);
      return;
    }

    streamEndedRef.current = true;

    if (
      mediaSourceRef.current &&
      mediaSourceRef.current.readyState === "open" &&
      sourceBufferRef.current &&
      !sourceBufferRef.current.updating
    ) {
      try {
        mediaSourceRef.current.endOfStream();
      } catch {
        // Ignore errors when ending stream
      }
    }

    // Clear any existing end check interval
    if (endCheckIntervalRef.current) {
      clearInterval(endCheckIntervalRef.current);
      endCheckIntervalRef.current = null;
    }

    // More aggressive end detection: check every 200ms if audio has ended
    // This handles cases where onended event doesn't fire with MediaSource
    endCheckIntervalRef.current = setInterval(() => {
      const audioEl = audioElementRef.current;

      // If audio element is gone or stream was reset, clean up
      if (!audioEl || !streamEndedRef.current) {
        if (endCheckIntervalRef.current) {
          clearInterval(endCheckIntervalRef.current);
          endCheckIntervalRef.current = null;
        }
        return;
      }

      // Check if audio has ended (either via ended property or by reaching duration)
      const hasEnded =
        audioEl.ended ||
        audioEl.paused ||
        (audioEl.duration > 0 &&
          audioEl.currentTime >= audioEl.duration - 0.1) ||
        audioEl.readyState === 0;

      if (hasEnded && isPlayingRef.current) {
        isPlayingRef.current = false;
        setIsTTSPlaying(false);
        if (endCheckIntervalRef.current) {
          clearInterval(endCheckIntervalRef.current);
          endCheckIntervalRef.current = null;
        }
      }
    }, 200);

    // Fallback: if audio doesn't finish playing within 10s after stream ends,
    // reset the playing state to prevent mic button from being stuck disabled
    setTimeout(() => {
      if (endCheckIntervalRef.current) {
        clearInterval(endCheckIntervalRef.current);
        endCheckIntervalRef.current = null;
      }
      if (isPlayingRef.current) {
        isPlayingRef.current = false;
        setIsTTSPlaying(false);
      }
    }, 10000);
  }, []);

  // Initialize MediaSource for streaming audio.
  // Returns false when the browser cannot stream "audio/mpeg" via MediaSource;
  // resolves once the SourceBuffer is attached and chunk chaining is wired up.
  const initMediaSource = useCallback(async () => {
    // Check if MediaSource is supported
    if (!window.MediaSource || !MediaSource.isTypeSupported("audio/mpeg")) {
      return false;
    }

    // Create MediaSource and audio element
    mediaSourceRef.current = new MediaSource();
    audioElementRef.current = new Audio();
    audioUrlRef.current = URL.createObjectURL(mediaSourceRef.current);
    audioElementRef.current.src = audioUrlRef.current;

    audioElementRef.current.onplay = () => {
      if (!isPlayingRef.current) {
        isPlayingRef.current = true;
        setIsTTSPlaying(true);
      }
    };

    audioElementRef.current.onended = () => {
      isPlayingRef.current = false;
      setIsTTSPlaying(false);
    };

    audioElementRef.current.onerror = () => {
      isPlayingRef.current = false;
      setIsTTSPlaying(false);
    };

    // Wait for MediaSource to be ready
    await new Promise<void>((resolve, reject) => {
      if (!mediaSourceRef.current) {
        reject(new Error("MediaSource not initialized"));
        return;
      }

      mediaSourceRef.current.onsourceopen = () => {
        try {
          sourceBufferRef.current =
            mediaSourceRef.current!.addSourceBuffer("audio/mpeg");
          // "sequence" mode: chunks play in append order regardless of
          // embedded timestamps.
          sourceBufferRef.current.mode = "sequence";

          sourceBufferRef.current.onupdateend = () => {
            isAppendingRef.current = false;
            processNextChunk();
          };

          resolve();
        } catch (err) {
          reject(err);
        }
      };

      mediaSourceRef.current.onsourceclose = () => {
        if (mediaSourceRef.current?.readyState === "closed") {
          reject(new Error("MediaSource closed unexpectedly"));
        }
      };
    });

    return true;
  }, [processNextChunk]);

  // Handle incoming audio data from WebSocket.
  // Queues the chunk and, on the first chunk, schedules playback start after
  // a short buffering delay.
  const handleAudioData = useCallback(
    async (data: ArrayBuffer) => {
      pendingChunksRef.current.push(new Uint8Array(data));
      processNextChunk();

      // Start playback after first chunk
      if (!hasStartedPlaybackRef.current && audioElementRef.current) {
        // Small delay to buffer a bit before starting
        setTimeout(() => {
          const audioEl = audioElementRef.current;
          if (!audioEl || hasStartedPlaybackRef.current) {
            return;
          }

          audioEl
            .play()
            .then(() => {
              hasStartedPlaybackRef.current = true;
            })
            .catch(() => {
              // Keep hasStartedPlaybackRef as false so we retry on next audio chunk.
            });
        }, 100);
      }
    },
    [processNextChunk]
  );

  // Get WebSocket URL for TTS with authentication token.
  // Throws when the token endpoint fails; callers handle via their try/catch.
  const getWebSocketUrl = useCallback(async () => {
    // Fetch short-lived WS token
    const tokenResponse = await fetch("/api/voice/ws-token", {
      method: "POST",
      credentials: "include",
    });
    if (!tokenResponse.ok) {
      throw new Error("Failed to get WebSocket authentication token");
    }
    const { token } = await tokenResponse.json();

    const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
    // Port 3000 is treated as the local dev frontend; the API server is
    // assumed to be on localhost:8080 in that case.
    const isDev = window.location.port === "3000";
    const host = isDev ? "localhost:8080" : window.location.host;
    const path = isDev
      ? "/voice/synthesize/stream"
      : "/api/voice/synthesize/stream";
    return `${protocol}//${host}${path}?token=${encodeURIComponent(token)}`;
  }, []);

  // Connect to WebSocket TTS.
  // Sets up MediaSource first, then opens the socket, sends the voice/speed
  // config, and flushes any text queued while disconnected.
  const connectWebSocket = useCallback(async () => {
    // Skip if already connected, connecting, or in the process of connecting
    if (
      wsRef.current?.readyState === WebSocket.OPEN ||
      wsRef.current?.readyState === WebSocket.CONNECTING ||
      isConnectingRef.current
    ) {
      return;
    }

    // Set connecting flag to prevent concurrent connection attempts
    isConnectingRef.current = true;

    try {
      // Initialize MediaSource first
      const initialized = await initMediaSource();
      if (!initialized) {
        isConnectingRef.current = false;
        return;
      }

      // Get WebSocket URL with auth token
      const wsUrl = await getWebSocketUrl();

      const ws = new WebSocket(wsUrl);

      ws.onopen = () => {
        isConnectingRef.current = false;
        // Send initial config
        ws.send(
          JSON.stringify({
            type: "config",
            voice: preferredVoice,
            speed: playbackSpeed,
          })
        );

        // Send any pending text
        for (const text of pendingTextRef.current) {
          ws.send(JSON.stringify({ type: "synthesize", text }));
        }
        pendingTextRef.current = [];
      };

      ws.onmessage = async (event) => {
        // Binary frames are audio; text frames are JSON control messages.
        if (event.data instanceof Blob) {
          const arrayBuffer = await event.data.arrayBuffer();
          handleAudioData(arrayBuffer);
        } else if (typeof event.data === "string") {
          try {
            const msg = JSON.parse(event.data);
            if (msg.type === "audio_done") {
              if (loadingTimeoutRef.current) {
                clearTimeout(loadingTimeoutRef.current);
                loadingTimeoutRef.current = null;
              }
              setIsTTSLoading(false);
              finalizeStream();
            }
          } catch {
            // Ignore parse errors
          }
        }
      };

      ws.onerror = () => {
        isConnectingRef.current = false;
        setIsTTSLoading(false);
      };

      ws.onclose = () => {
        wsRef.current = null;
        isConnectingRef.current = false;
        setIsTTSLoading(false);
        finalizeStream();
      };

      wsRef.current = ws;
    } catch {
      // Token fetch or MediaSource init failed; allow a later retry.
      isConnectingRef.current = false;
    }
  }, [
    preferredVoice,
    playbackSpeed,
    handleAudioData,
    getWebSocketUrl,
    initMediaSource,
    finalizeStream,
  ]);

  // Send text to TTS via WebSocket.
  // Queues the text (and triggers a connect) when the socket is not open yet.
  const sendTextToTTS = useCallback(
    (text: string) => {
      if (!text.trim()) return;

      setIsTTSLoading(true);
      setSpokenText((prev) => (prev ? prev + " " + text : text));

      // Set a timeout to reset loading state if TTS doesn't complete
      if (loadingTimeoutRef.current) {
        clearTimeout(loadingTimeoutRef.current);
      }
      loadingTimeoutRef.current = setTimeout(() => {
        setIsTTSLoading(false);
        setIsTTSPlaying(false);
      }, 60000);

      if (wsRef.current?.readyState === WebSocket.OPEN) {
        wsRef.current.send(JSON.stringify({ type: "synthesize", text }));
      } else {
        pendingTextRef.current.push(text);
        connectWebSocket();
      }
    },
    [connectWebSocket]
  );

  // Incrementally stream assistant text into TTS as it arrives.
  // `text` is the full text so far (not a delta); committedPositionRef tracks
  // how much of the cleaned text has already been sent. When `isComplete` is
  // true the remaining tail is flushed and an "end" control message is sent.
  const streamTTS = useCallback(
    (text: string, isComplete: boolean = false) => {
      if (!autoPlayback) {
        return;
      }

      // Skip if text hasn't changed
      if (text === lastRawTextRef.current && !isComplete) return;
      lastRawTextRef.current = text;

      // Clear pending timers
      if (flushTimerRef.current) {
        clearTimeout(flushTimerRef.current);
        flushTimerRef.current = null;
      }
      if (fastStartTimerRef.current) {
        clearTimeout(fastStartTimerRef.current);
        fastStartTimerRef.current = null;
      }

      // Clean the full text
      const cleanedText = cleanTextForTTS(text);
      const uncommittedText = cleanedText.slice(committedPositionRef.current);

      // On completion, we must still signal "end" even if there's no new text.
      // Otherwise ElevenLabs waits for more input and eventually times out.
      if (uncommittedText.length === 0) {
        if (isComplete && !hasSignaledEndRef.current) {
          hasSignaledEndRef.current = true;

          if (wsRef.current?.readyState === WebSocket.OPEN) {
            wsRef.current.send(JSON.stringify({ type: "end" }));
          } else {
            // Socket not open yet: poll until it is open and all queued text
            // has been flushed before signaling end.
            const sendEnd = () => {
              if (wsRef.current?.readyState === WebSocket.OPEN) {
                if (pendingTextRef.current.length === 0) {
                  wsRef.current.send(JSON.stringify({ type: "end" }));
                } else {
                  setTimeout(sendEnd, 100);
                }
              } else if (wsRef.current?.readyState === WebSocket.CONNECTING) {
                setTimeout(sendEnd, 100);
              }
            };
            setTimeout(sendEnd, 100);
          }
        }
        return;
      }

      // Find chunk boundaries and send immediately
      let remaining = uncommittedText;
      let offset = 0;

      while (remaining.length > 0) {
        const boundaryIndex = findChunkBoundary(remaining);

        if (boundaryIndex > 0) {
          const chunkText = remaining.slice(0, boundaryIndex).trim();
          if (chunkText.length > 0) {
            sendTextToTTS(chunkText);
            hasSpokenFirstChunkRef.current = true;
          }
          // NOTE(review): offset advances by boundaryIndex but `remaining`
          // is also trimmed, so committedPositionRef can lag the true
          // position by leading-whitespace amounts — confirm intended.
          offset += boundaryIndex;
          remaining = remaining.slice(boundaryIndex).trim();
        } else {
          break;
        }
      }

      committedPositionRef.current += offset;

      // Handle remaining text when stream is complete
      if (isComplete && remaining.trim().length > 0) {
        sendTextToTTS(remaining.trim());
        committedPositionRef.current = cleanedText.length;
        hasSpokenFirstChunkRef.current = true;
      }

      // When streaming is complete, signal end to flush remaining audio
      if (isComplete && !hasSignaledEndRef.current) {
        hasSignaledEndRef.current = true;

        if (wsRef.current?.readyState === WebSocket.OPEN) {
          wsRef.current.send(JSON.stringify({ type: "end" }));
        } else {
          // Same deferred-end polling as the empty-text path above.
          const sendEnd = () => {
            if (wsRef.current?.readyState === WebSocket.OPEN) {
              if (pendingTextRef.current.length === 0) {
                wsRef.current.send(JSON.stringify({ type: "end" }));
              } else {
                setTimeout(sendEnd, 100);
              }
            } else if (wsRef.current?.readyState === WebSocket.CONNECTING) {
              setTimeout(sendEnd, 100);
            }
          };
          setTimeout(sendEnd, 100);
        }
      }

      const currentUncommitted = cleanedText
        .slice(committedPositionRef.current)
        .trim();

      // Fast start: if we haven't spoken yet and have 20+ chars, send after 200ms
      if (
        !hasSpokenFirstChunkRef.current &&
        currentUncommitted.length >= 20 &&
        !isComplete
      ) {
        fastStartTimerRef.current = setTimeout(() => {
          if (hasSpokenFirstChunkRef.current) return;

          const nowCleaned = cleanTextForTTS(lastRawTextRef.current);
          const nowUncommitted = nowCleaned
            .slice(committedPositionRef.current)
            .trim();

          if (nowUncommitted.length >= 20) {
            // Find a reasonable break point
            let breakPoint = nowUncommitted.length;
            const spaceIdx = nowUncommitted.lastIndexOf(" ", 50);
            if (spaceIdx >= 15) breakPoint = spaceIdx;

            const chunk = nowUncommitted.slice(0, breakPoint).trim();
            if (chunk.length > 0) {
              sendTextToTTS(chunk);
              committedPositionRef.current += breakPoint;
              hasSpokenFirstChunkRef.current = true;
            }
          }
        }, 200);
      }

      // Flush timer for text ending with punctuation
      if (
        currentUncommitted.length > 0 &&
        !isComplete &&
        /[.!?]$/.test(currentUncommitted)
      ) {
        flushTimerRef.current = setTimeout(() => {
          const nowCleaned = cleanTextForTTS(lastRawTextRef.current);
          const nowUncommitted = nowCleaned
            .slice(committedPositionRef.current)
            .trim();

          if (nowUncommitted.length > 0) {
            sendTextToTTS(nowUncommitted);
            committedPositionRef.current = nowCleaned.length;
            hasSpokenFirstChunkRef.current = true;
          }
        }, 250);
      }
    },
    [autoPlayback, sendTextToTTS]
  );

  // Stop playback and tear down all audio/WebSocket resources.
  // When `options.manual` is set, bumps manualStopCount so consumers can
  // react to user-initiated stops.
  const stopTTS = useCallback((options?: { manual?: boolean }) => {
    // Clear timers
    if (flushTimerRef.current) {
      clearTimeout(flushTimerRef.current);
      flushTimerRef.current = null;
    }
    if (fastStartTimerRef.current) {
      clearTimeout(fastStartTimerRef.current);
      fastStartTimerRef.current = null;
    }
    if (loadingTimeoutRef.current) {
      clearTimeout(loadingTimeoutRef.current);
      loadingTimeoutRef.current = null;
    }
    if (endCheckIntervalRef.current) {
      clearInterval(endCheckIntervalRef.current);
      endCheckIntervalRef.current = null;
    }

    // Revoke blob URL to prevent memory leak
    if (audioUrlRef.current) {
      URL.revokeObjectURL(audioUrlRef.current);
      audioUrlRef.current = null;
    }

    // Stop audio element
    if (audioElementRef.current) {
      audioElementRef.current.pause();
      audioElementRef.current.src = "";
      audioElementRef.current = null;
    }

    // Cleanup MediaSource
    if (
      mediaSourceRef.current &&
      mediaSourceRef.current.readyState === "open"
    ) {
      try {
        if (sourceBufferRef.current) {
          mediaSourceRef.current.removeSourceBuffer(sourceBufferRef.current);
        }
        mediaSourceRef.current.endOfStream();
      } catch {
        // Ignore cleanup errors
      }
    }

    // Reset all streaming bookkeeping refs to their idle values.
    mediaSourceRef.current = null;
    sourceBufferRef.current = null;
    pendingChunksRef.current = [];
    isAppendingRef.current = false;
    hasStartedPlaybackRef.current = false;
    pendingTextRef.current = [];
    isPlayingRef.current = false;
    hasSignaledEndRef.current = false;
    isConnectingRef.current = false;
    streamEndedRef.current = false;

    // Close WebSocket
    if (wsRef.current) {
      try {
        wsRef.current.send(JSON.stringify({ type: "end" }));
        wsRef.current.close();
      } catch {
        // Ignore
      }
      wsRef.current = null;
    }

    setIsTTSPlaying(false);
    setIsTTSLoading(false);
    if (options?.manual) {
      setManualStopCount((count) => count + 1);
    }
  }, []);

  // Full reset: stop playback and clear all text-tracking state so a new
  // message starts from position zero.
  const resetTTS = useCallback(() => {
    stopTTS();
    committedPositionRef.current = 0;
    lastRawTextRef.current = "";
    hasSpokenFirstChunkRef.current = false;
    hasSignaledEndRef.current = false;
    setSpokenText("");
  }, [stopTTS]);

  // Reset TTS state when voice auto-playback is disabled
  // This prevents the mic button from being stuck disabled
  const prevAutoPlaybackRef = useRef(autoPlayback);
  useEffect(() => {
    if (prevAutoPlaybackRef.current && !autoPlayback) {
      // Auto-playback was just disabled, clean up TTS state
      resetTTS();
    }
    prevAutoPlaybackRef.current = autoPlayback;
  }, [autoPlayback, resetTTS]);

  // Cleanup on unmount
  useEffect(() => {
    return () => {
      if (flushTimerRef.current) clearTimeout(flushTimerRef.current);
      if (fastStartTimerRef.current) clearTimeout(fastStartTimerRef.current);
      if (loadingTimeoutRef.current) clearTimeout(loadingTimeoutRef.current);
      if (endCheckIntervalRef.current)
        clearInterval(endCheckIntervalRef.current);
      if (audioUrlRef.current) {
        URL.revokeObjectURL(audioUrlRef.current);
      }
      if (wsRef.current) {
        try {
          wsRef.current.close();
        } catch {
          // Ignore
        }
      }
      if (audioElementRef.current) {
        try {
          audioElementRef.current.pause();
          audioElementRef.current.src = "";
        } catch {
          // Ignore
        }
      }
      if (
        mediaSourceRef.current &&
        mediaSourceRef.current.readyState === "open"
      ) {
        try {
          mediaSourceRef.current.endOfStream();
        } catch {
          // Ignore
        }
      }
    };
  }, []);

  return (
    <VoiceModeContext.Provider
      value={{
        isTTSPlaying,
        isTTSLoading,
        spokenText,
        streamTTS,
        stopTTS,
        manualStopCount,
        resetTTS,
      }}
    >
      {children}
    </VoiceModeContext.Provider>
  );
}
|
||||
|
||||
export function useVoiceMode(): VoiceModeContextType {
|
||||
const context = useContext(VoiceModeContext);
|
||||
if (!context) {
|
||||
throw new Error("useVoiceMode must be used within VoiceModeProvider");
|
||||
}
|
||||
return context;
|
||||
}
|
||||
@@ -129,6 +129,7 @@ const InputComboBox = ({
|
||||
leftSearchIcon = false,
|
||||
rightSection,
|
||||
separatorLabel = "Other options",
|
||||
showAddPrefix = false,
|
||||
showOtherOptions = false,
|
||||
...rest
|
||||
}: WithoutStyles<InputComboBoxProps>) => {
|
||||
@@ -156,14 +157,11 @@ const InputComboBox = ({
|
||||
const visibleUnmatchedOptions =
|
||||
hasSearchTerm && showOtherOptions ? unmatchedOptions : [];
|
||||
|
||||
// Whether to show the create option (only when no partial matches)
|
||||
const showCreateOption =
|
||||
!strict &&
|
||||
hasSearchTerm &&
|
||||
inputValue.trim() !== "" &&
|
||||
matchedOptions.length === 0;
|
||||
// Whether to show the create option (always show when typing in non-strict mode)
|
||||
const showCreateOption = !strict && hasSearchTerm && inputValue.trim() !== "";
|
||||
|
||||
// Combined list for keyboard navigation (includes create option when shown)
|
||||
// Only show matched options when searching (hide unmatched)
|
||||
const allVisibleOptions = useMemo(() => {
|
||||
const baseOptions = [...matchedOptions, ...visibleUnmatchedOptions];
|
||||
if (showCreateOption) {
|
||||
@@ -440,6 +438,7 @@ const InputComboBox = ({
|
||||
inputValue={inputValue}
|
||||
allowCreate={!strict}
|
||||
showCreateOption={showCreateOption}
|
||||
showAddPrefix={showAddPrefix}
|
||||
/>
|
||||
</>
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ interface ComboBoxDropdownProps {
|
||||
allowCreate: boolean;
|
||||
/** Whether to show create option (pre-computed by parent) */
|
||||
showCreateOption: boolean;
|
||||
/** Show "Add" prefix in create option */
|
||||
showAddPrefix: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -58,6 +60,7 @@ export const ComboBoxDropdown = forwardRef<
|
||||
inputValue,
|
||||
allowCreate,
|
||||
showCreateOption,
|
||||
showAddPrefix,
|
||||
},
|
||||
ref
|
||||
) => {
|
||||
@@ -132,6 +135,7 @@ export const ComboBoxDropdown = forwardRef<
|
||||
inputValue={inputValue}
|
||||
allowCreate={allowCreate}
|
||||
showCreateOption={showCreateOption}
|
||||
showAddPrefix={showAddPrefix}
|
||||
/>
|
||||
</div>,
|
||||
document.body
|
||||
|
||||
@@ -24,6 +24,8 @@ interface OptionsListProps {
|
||||
allowCreate: boolean;
|
||||
/** Whether to show create option (pre-computed by parent) */
|
||||
showCreateOption: boolean;
|
||||
/** Show "Add" prefix in create option */
|
||||
showAddPrefix: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -45,6 +47,7 @@ export const OptionsList: React.FC<OptionsListProps> = ({
|
||||
inputValue,
|
||||
allowCreate,
|
||||
showCreateOption,
|
||||
showAddPrefix,
|
||||
}) => {
|
||||
// Index offset for other options when create option is shown
|
||||
const indexOffset = showCreateOption ? 1 : 0;
|
||||
@@ -70,7 +73,7 @@ export const OptionsList: React.FC<OptionsListProps> = ({
|
||||
data-index={0}
|
||||
role="option"
|
||||
aria-selected={false}
|
||||
aria-label={`Create "${inputValue}"`}
|
||||
aria-label={`${showAddPrefix ? "Add" : "Create"} "${inputValue}"`}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onSelect({ value: inputValue, label: inputValue });
|
||||
@@ -81,19 +84,48 @@ export const OptionsList: React.FC<OptionsListProps> = ({
|
||||
onMouseEnter={() => onMouseEnter(0)}
|
||||
onMouseMove={onMouseMove}
|
||||
className={cn(
|
||||
"px-3 py-2 cursor-pointer transition-colors",
|
||||
"cursor-pointer transition-colors",
|
||||
"flex items-center justify-between rounded-08",
|
||||
highlightedIndex === 0 && "bg-background-tint-02",
|
||||
"hover:bg-background-tint-02"
|
||||
"hover:bg-background-tint-02",
|
||||
showAddPrefix ? "px-1.5 py-1.5" : "px-3 py-2"
|
||||
)}
|
||||
>
|
||||
<span className="font-main-ui-action text-text-04 truncate min-w-0">
|
||||
{inputValue}
|
||||
<span
|
||||
className={cn(
|
||||
"font-main-ui-action truncate min-w-0",
|
||||
showAddPrefix ? "px-1" : ""
|
||||
)}
|
||||
>
|
||||
{showAddPrefix ? (
|
||||
<>
|
||||
<span className="text-text-03">Add</span>
|
||||
<span className="text-text-04">{` ${inputValue}`}</span>
|
||||
</>
|
||||
) : (
|
||||
<span className="text-text-04">{inputValue}</span>
|
||||
)}
|
||||
</span>
|
||||
<SvgPlus className="w-4 h-4 text-text-03 flex-shrink-0 ml-2" />
|
||||
<SvgPlus
|
||||
className={cn(
|
||||
"w-4 h-4 flex-shrink-0",
|
||||
showAddPrefix ? "text-text-04 mx-1" : "text-text-03 ml-2"
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Separator - show when there are options to display */}
|
||||
{separatorLabel &&
|
||||
(matchedOptions.length > 0 ||
|
||||
(!hasSearchTerm && unmatchedOptions.length > 0)) && (
|
||||
<div className="px-3 py-1">
|
||||
<Text as="p" text03 secondaryBody>
|
||||
{separatorLabel}
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Matched/Filtered Options */}
|
||||
{matchedOptions.map((option, idx) => {
|
||||
const globalIndex = idx + indexOffset;
|
||||
@@ -116,37 +148,27 @@ export const OptionsList: React.FC<OptionsListProps> = ({
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Separator - only show if there are unmatched options and a search term */}
|
||||
{hasSearchTerm && unmatchedOptions.length > 0 && (
|
||||
<div className="px-3 py-2 pt-3">
|
||||
<div className="border-t border-border-01 pt-2">
|
||||
<Text as="p" text04 secondaryBody className="text-text-02">
|
||||
{separatorLabel}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Unmatched Options */}
|
||||
{unmatchedOptions.map((option, idx) => {
|
||||
const globalIndex = matchedOptions.length + idx + indexOffset;
|
||||
const isExact = isExactMatch(option);
|
||||
return (
|
||||
<OptionItem
|
||||
key={option.value}
|
||||
option={option}
|
||||
index={globalIndex}
|
||||
fieldId={fieldId}
|
||||
isHighlighted={globalIndex === highlightedIndex}
|
||||
isSelected={value === option.value}
|
||||
isExact={isExact}
|
||||
onSelect={onSelect}
|
||||
onMouseEnter={onMouseEnter}
|
||||
onMouseMove={onMouseMove}
|
||||
searchTerm={inputValue}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
{/* Unmatched Options - only show when NOT searching */}
|
||||
{!hasSearchTerm &&
|
||||
unmatchedOptions.map((option, idx) => {
|
||||
const globalIndex = matchedOptions.length + idx + indexOffset;
|
||||
const isExact = isExactMatch(option);
|
||||
return (
|
||||
<OptionItem
|
||||
key={option.value}
|
||||
option={option}
|
||||
index={globalIndex}
|
||||
fieldId={fieldId}
|
||||
isHighlighted={globalIndex === highlightedIndex}
|
||||
isSelected={value === option.value}
|
||||
isExact={isExact}
|
||||
onSelect={onSelect}
|
||||
onMouseEnter={onMouseEnter}
|
||||
onMouseMove={onMouseMove}
|
||||
searchTerm={inputValue}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -40,6 +40,8 @@ export interface InputComboBoxProps
|
||||
rightSection?: React.ReactNode;
|
||||
/** Label for the separator between matched and unmatched options */
|
||||
separatorLabel?: string;
|
||||
/** Show "Add" prefix in create option (e.g., "Add [value]") */
|
||||
showAddPrefix?: boolean;
|
||||
/**
|
||||
* When true, keep non-matching options visible under a separator while searching.
|
||||
* Defaults to false so search results are strictly filtered.
|
||||
|
||||
@@ -748,6 +748,7 @@ function ChatPreferencesSettings() {
|
||||
updateUserShortcuts,
|
||||
updateUserDefaultModel,
|
||||
updateUserDefaultAppMode,
|
||||
updateUserVoiceSettings,
|
||||
} = useUser();
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const settings = useSettingsContext();
|
||||
@@ -933,6 +934,64 @@ function ChatPreferencesSettings() {
|
||||
{user?.preferences?.shortcut_enabled && <PromptShortcuts />}
|
||||
</Card>
|
||||
</Section>
|
||||
|
||||
<Section gap={0.75}>
|
||||
<Content
|
||||
title="Voice"
|
||||
sizePreset="main-content"
|
||||
variant="section"
|
||||
widthVariant="full"
|
||||
/>
|
||||
<Card>
|
||||
<InputLayouts.Horizontal
|
||||
title="Auto-Send"
|
||||
description="Automatically send voice input when recording stops."
|
||||
>
|
||||
<Switch
|
||||
checked={user?.preferences.voice_auto_send ?? false}
|
||||
onCheckedChange={(checked) => {
|
||||
void updateUserVoiceSettings({ auto_send: checked });
|
||||
}}
|
||||
/>
|
||||
</InputLayouts.Horizontal>
|
||||
|
||||
<InputLayouts.Horizontal
|
||||
title="Auto-Playback"
|
||||
description="Automatically play voice responses."
|
||||
>
|
||||
<Switch
|
||||
checked={user?.preferences.voice_auto_playback ?? false}
|
||||
onCheckedChange={(checked) => {
|
||||
void updateUserVoiceSettings({ auto_playback: checked });
|
||||
}}
|
||||
/>
|
||||
</InputLayouts.Horizontal>
|
||||
|
||||
<InputLayouts.Horizontal
|
||||
title="Playback Speed"
|
||||
description="Adjust the speed of voice playback."
|
||||
>
|
||||
<div className="flex items-center gap-3">
|
||||
<input
|
||||
type="range"
|
||||
min="0.5"
|
||||
max="2"
|
||||
step="0.1"
|
||||
value={user?.preferences.voice_playback_speed ?? 1}
|
||||
onChange={(e) => {
|
||||
void updateUserVoiceSettings({
|
||||
playback_speed: parseFloat(e.target.value),
|
||||
});
|
||||
}}
|
||||
className="w-24 h-2 rounded-lg appearance-none cursor-pointer bg-background-neutral-02"
|
||||
/>
|
||||
<span className="text-sm text-text-02 w-10">
|
||||
{(user?.preferences.voice_playback_speed ?? 1).toFixed(1)}x
|
||||
</span>
|
||||
</div>
|
||||
</InputLayouts.Horizontal>
|
||||
</Card>
|
||||
</Section>
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -56,6 +56,9 @@ import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import MicrophoneButton from "@/sections/input/MicrophoneButton";
|
||||
import RecordingWaveform from "@/sections/input/RecordingWaveform";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
|
||||
const MIN_INPUT_HEIGHT = 44;
|
||||
const MAX_INPUT_HEIGHT = 200;
|
||||
@@ -114,6 +117,12 @@ const AppInputBar = React.memo(
|
||||
}: AppInputBarProps) => {
|
||||
// Internal message state - kept local to avoid parent re-renders on every keystroke
|
||||
const [message, setMessage] = useState(initialMessage);
|
||||
const [isRecording, setIsRecording] = useState(false);
|
||||
const [isMuted, setIsMuted] = useState(false);
|
||||
const stopRecordingRef = useRef<(() => Promise<string | null>) | null>(
|
||||
null
|
||||
);
|
||||
const setMutedRef = useRef<((muted: boolean) => void) | null>(null);
|
||||
const textAreaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const textAreaWrapperRef = useRef<HTMLDivElement>(null);
|
||||
const filesWrapperRef = useRef<HTMLDivElement>(null);
|
||||
@@ -121,6 +130,16 @@ const AppInputBar = React.memo(
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const { user } = useUser();
|
||||
const { isClassifying, classification } = useQueryController();
|
||||
const { stopTTS } = useVoiceMode();
|
||||
|
||||
// Wrapper for onSubmit that stops TTS first to prevent overlapping voices
|
||||
const handleSubmit = useCallback(
|
||||
(text: string) => {
|
||||
stopTTS();
|
||||
onSubmit(text);
|
||||
},
|
||||
[stopTTS, onSubmit]
|
||||
);
|
||||
|
||||
// Expose reset and focus methods to parent via ref
|
||||
React.useImperativeHandle(ref, () => ({
|
||||
@@ -543,6 +562,21 @@ const AppInputBar = React.memo(
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
<MicrophoneButton
|
||||
onTranscription={(text) => setMessage(text)}
|
||||
disabled={disabled || chatState === "streaming"}
|
||||
autoSend={user?.preferences?.voice_auto_send ?? false}
|
||||
autoListen={user?.preferences?.voice_auto_playback ?? false}
|
||||
chatState={chatState}
|
||||
onRecordingChange={setIsRecording}
|
||||
stopRecordingRef={stopRecordingRef}
|
||||
onRecordingStart={() => setMessage("")}
|
||||
onAutoSend={(text) => {
|
||||
handleSubmit(text);
|
||||
}}
|
||||
onMuteChange={setIsMuted}
|
||||
setMutedRef={setMutedRef}
|
||||
/>
|
||||
<Button
|
||||
id="onyx-chat-input-send-button"
|
||||
icon={
|
||||
@@ -575,7 +609,7 @@ const AppInputBar = React.memo(
|
||||
ref={containerRef}
|
||||
id="onyx-chat-input"
|
||||
className={cn(
|
||||
"w-full flex flex-col shadow-01 bg-background-neutral-00 rounded-16"
|
||||
"w-full flex flex-col shadow-01 bg-background-neutral-00 rounded-16 relative"
|
||||
// # Note (from @raunakab):
|
||||
//
|
||||
// `shadow-01` extends ~14px below the element (2px offset + 12px blur).
|
||||
@@ -639,9 +673,11 @@ const AppInputBar = React.memo(
|
||||
style={{ scrollbarWidth: "thin" }}
|
||||
aria-multiline={true}
|
||||
placeholder={
|
||||
isSearchMode
|
||||
? "Search connected sources"
|
||||
: "How can I help you today?"
|
||||
isRecording
|
||||
? "Listening..."
|
||||
: isSearchMode
|
||||
? "Search connected sources"
|
||||
: "How can I help you today?"
|
||||
}
|
||||
value={message}
|
||||
onKeyDown={(event) => {
|
||||
@@ -658,7 +694,7 @@ const AppInputBar = React.memo(
|
||||
!isClassifying &&
|
||||
!hasUploadingFiles
|
||||
) {
|
||||
onSubmit(message);
|
||||
handleSubmit(message);
|
||||
}
|
||||
}
|
||||
}}
|
||||
@@ -722,7 +758,7 @@ const AppInputBar = React.memo(
|
||||
if (chatState == "streaming") {
|
||||
stopGenerating();
|
||||
} else if (message) {
|
||||
onSubmit(message);
|
||||
handleSubmit(message);
|
||||
}
|
||||
}}
|
||||
prominence="tertiary"
|
||||
@@ -733,6 +769,28 @@ const AppInputBar = React.memo(
|
||||
</div>
|
||||
|
||||
{chatControls}
|
||||
|
||||
{/* Recording waveform - position depends on session state:
|
||||
- Fresh chat (input centered): bar floats BELOW input
|
||||
- Subsequent turns (input at bottom): bar floats ABOVE input */}
|
||||
{isRecording && (
|
||||
<div
|
||||
className={cn(
|
||||
"absolute left-0 right-0 px-1.5",
|
||||
appFocus.isNewSession()
|
||||
? "-bottom-[46px]" // Fresh chat: below input
|
||||
: "-top-[46px]" // Subsequent turns: above input
|
||||
)}
|
||||
>
|
||||
<RecordingWaveform
|
||||
isRecording={isRecording}
|
||||
isMuted={isMuted}
|
||||
onMuteToggle={() => {
|
||||
setMutedRef.current?.(!isMuted);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</Disabled>
|
||||
);
|
||||
|
||||
224
web/src/sections/input/MicrophoneButton.tsx
Normal file
224
web/src/sections/input/MicrophoneButton.tsx
Normal file
@@ -0,0 +1,224 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useRef } from "react";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgMicrophone } from "@opal/icons";
|
||||
import { useVoiceRecorder } from "@/hooks/useVoiceRecorder";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { ChatState } from "@/app/app/interfaces";
|
||||
|
||||
interface MicrophoneButtonProps {
|
||||
onTranscription: (text: string) => void;
|
||||
disabled?: boolean;
|
||||
autoSend?: boolean;
|
||||
/** Called with transcribed text when autoSend is enabled */
|
||||
onAutoSend?: (text: string) => void;
|
||||
/**
|
||||
* Internal prop: auto-start listening when TTS finishes or chat response completes.
|
||||
* Tied to voice_auto_playback user preference.
|
||||
* Enables conversation flow: speak → AI responds → auto-listen again.
|
||||
* Note: autoSend is separate - it controls whether message auto-submits after recording.
|
||||
*/
|
||||
autoListen?: boolean;
|
||||
/** Current chat state - used to detect when response streaming finishes */
|
||||
chatState?: ChatState;
|
||||
/** Called when recording state changes */
|
||||
onRecordingChange?: (isRecording: boolean) => void;
|
||||
/** Ref to expose stop recording function to parent */
|
||||
stopRecordingRef?: React.MutableRefObject<
|
||||
(() => Promise<string | null>) | null
|
||||
>;
|
||||
/** Called when recording starts to clear input */
|
||||
onRecordingStart?: () => void;
|
||||
/** Called when mute state changes */
|
||||
onMuteChange?: (isMuted: boolean) => void;
|
||||
/** Ref to expose setMuted function to parent */
|
||||
setMutedRef?: React.MutableRefObject<((muted: boolean) => void) | null>;
|
||||
}
|
||||
|
||||
function MicrophoneButton({
|
||||
onTranscription,
|
||||
disabled = false,
|
||||
autoSend = false,
|
||||
onAutoSend,
|
||||
autoListen = false,
|
||||
chatState,
|
||||
onRecordingChange,
|
||||
stopRecordingRef,
|
||||
onRecordingStart,
|
||||
onMuteChange,
|
||||
setMutedRef,
|
||||
}: MicrophoneButtonProps) {
|
||||
const { isTTSPlaying, isTTSLoading, manualStopCount } = useVoiceMode();
|
||||
|
||||
// Refs for tracking state across renders
|
||||
const wasTTSPlayingRef = useRef(false);
|
||||
const manualStopRequestedRef = useRef(false);
|
||||
const lastHandledManualStopCountRef = useRef(manualStopCount);
|
||||
const autoListenCooldownTimerRef = useRef<NodeJS.Timeout | null>(null);
|
||||
|
||||
// Handler for VAD-triggered auto-send (when server detects silence)
|
||||
const handleFinalTranscript = useCallback(
|
||||
(text: string) => {
|
||||
onTranscription(text);
|
||||
const isManualStop = manualStopRequestedRef.current;
|
||||
// Only auto-send if chat is ready for input (not streaming)
|
||||
if (!isManualStop && autoSend && onAutoSend && chatState === "input") {
|
||||
onAutoSend(text);
|
||||
}
|
||||
},
|
||||
[onTranscription, autoSend, onAutoSend, chatState]
|
||||
);
|
||||
|
||||
const {
|
||||
isRecording,
|
||||
isProcessing,
|
||||
isMuted,
|
||||
error,
|
||||
liveTranscript,
|
||||
startRecording,
|
||||
stopRecording,
|
||||
setMuted,
|
||||
} = useVoiceRecorder({ onFinalTranscript: handleFinalTranscript });
|
||||
|
||||
// Expose stopRecording to parent
|
||||
useEffect(() => {
|
||||
if (stopRecordingRef) {
|
||||
stopRecordingRef.current = stopRecording;
|
||||
}
|
||||
}, [stopRecording, stopRecordingRef]);
|
||||
|
||||
// Expose setMuted to parent
|
||||
useEffect(() => {
|
||||
if (setMutedRef) {
|
||||
setMutedRef.current = setMuted;
|
||||
}
|
||||
}, [setMuted, setMutedRef]);
|
||||
|
||||
// Notify parent when mute state changes
|
||||
useEffect(() => {
|
||||
onMuteChange?.(isMuted);
|
||||
}, [isMuted, onMuteChange]);
|
||||
|
||||
// Notify parent when recording state changes
|
||||
useEffect(() => {
|
||||
onRecordingChange?.(isRecording);
|
||||
}, [isRecording, onRecordingChange]);
|
||||
|
||||
// Update input with live transcript as user speaks
|
||||
useEffect(() => {
|
||||
if (isRecording && liveTranscript) {
|
||||
onTranscription(liveTranscript);
|
||||
}
|
||||
}, [isRecording, liveTranscript, onTranscription]);
|
||||
|
||||
const handleClick = useCallback(async () => {
|
||||
if (isRecording) {
|
||||
// When recording, clicking the mic button stops recording
|
||||
manualStopRequestedRef.current = true;
|
||||
try {
|
||||
await stopRecording();
|
||||
} finally {
|
||||
manualStopRequestedRef.current = false;
|
||||
}
|
||||
} else {
|
||||
try {
|
||||
// Clear input before starting recording
|
||||
onRecordingStart?.();
|
||||
await startRecording();
|
||||
} catch {
|
||||
toast.error("Could not access microphone");
|
||||
}
|
||||
}
|
||||
}, [isRecording, startRecording, stopRecording, onRecordingStart]);
|
||||
|
||||
// Auto-start listening shortly after TTS finishes (only if autoListen is enabled).
|
||||
// Small cooldown reduces playback bleed being re-captured by the microphone.
|
||||
useEffect(() => {
|
||||
if (autoListenCooldownTimerRef.current) {
|
||||
clearTimeout(autoListenCooldownTimerRef.current);
|
||||
autoListenCooldownTimerRef.current = null;
|
||||
}
|
||||
|
||||
const stoppedManually =
|
||||
manualStopCount !== lastHandledManualStopCountRef.current;
|
||||
|
||||
if (
|
||||
wasTTSPlayingRef.current &&
|
||||
!isTTSPlaying &&
|
||||
!isTTSLoading &&
|
||||
autoListen &&
|
||||
!disabled &&
|
||||
!isRecording &&
|
||||
!stoppedManually
|
||||
) {
|
||||
autoListenCooldownTimerRef.current = setTimeout(() => {
|
||||
autoListenCooldownTimerRef.current = null;
|
||||
if (
|
||||
!autoListen ||
|
||||
disabled ||
|
||||
isRecording ||
|
||||
isTTSPlaying ||
|
||||
isTTSLoading
|
||||
) {
|
||||
return;
|
||||
}
|
||||
startRecording().catch(() => {
|
||||
// Silently ignore auto-start failures
|
||||
});
|
||||
}, 400);
|
||||
}
|
||||
|
||||
if (stoppedManually) {
|
||||
lastHandledManualStopCountRef.current = manualStopCount;
|
||||
}
|
||||
|
||||
wasTTSPlayingRef.current = isTTSPlaying || isTTSLoading;
|
||||
}, [
|
||||
isTTSPlaying,
|
||||
isTTSLoading,
|
||||
autoListen,
|
||||
disabled,
|
||||
isRecording,
|
||||
startRecording,
|
||||
manualStopCount,
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (autoListenCooldownTimerRef.current) {
|
||||
clearTimeout(autoListenCooldownTimerRef.current);
|
||||
autoListenCooldownTimerRef.current = null;
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (error) {
|
||||
toast.error(error);
|
||||
}
|
||||
}, [error]);
|
||||
|
||||
// Icon: show loader when processing, otherwise mic
|
||||
const icon = isProcessing ? SimpleLoader : SvgMicrophone;
|
||||
|
||||
// Disable when processing or TTS is playing (don't want to pick up TTS audio)
|
||||
const isDisabled = disabled || isProcessing || isTTSPlaying || isTTSLoading;
|
||||
|
||||
// Recording = darkened (primary), not recording = light (tertiary)
|
||||
const prominence = isRecording ? "primary" : "tertiary";
|
||||
|
||||
return (
|
||||
<Button
|
||||
icon={icon}
|
||||
disabled={isDisabled}
|
||||
onClick={handleClick}
|
||||
aria-label={isRecording ? "Stop recording" : "Start recording"}
|
||||
prominence={prominence}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export default MicrophoneButton;
|
||||
97
web/src/sections/input/RecordingWaveform.tsx
Normal file
97
web/src/sections/input/RecordingWaveform.tsx
Normal file
@@ -0,0 +1,97 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useState, useMemo } from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Button } from "@opal/components";
|
||||
import { SvgMicrophone, SvgMicrophoneOff } from "@opal/icons";
|
||||
|
||||
interface RecordingWaveformProps {
|
||||
isRecording: boolean;
|
||||
isMuted?: boolean;
|
||||
onMuteToggle?: () => void;
|
||||
}
|
||||
|
||||
function RecordingWaveform({
|
||||
isRecording,
|
||||
isMuted = false,
|
||||
onMuteToggle,
|
||||
}: RecordingWaveformProps) {
|
||||
const [elapsedSeconds, setElapsedSeconds] = useState(0);
|
||||
|
||||
// Reset and start timer when recording starts
|
||||
useEffect(() => {
|
||||
if (!isRecording) {
|
||||
setElapsedSeconds(0);
|
||||
return;
|
||||
}
|
||||
|
||||
const interval = setInterval(() => {
|
||||
setElapsedSeconds((prev) => prev + 1);
|
||||
}, 1000);
|
||||
|
||||
return () => clearInterval(interval);
|
||||
}, [isRecording]);
|
||||
|
||||
// Format time as MM:SS
|
||||
const formattedTime = useMemo(() => {
|
||||
const minutes = Math.floor(elapsedSeconds / 60);
|
||||
const seconds = elapsedSeconds % 60;
|
||||
return `${minutes.toString().padStart(2, "0")}:${seconds
|
||||
.toString()
|
||||
.padStart(2, "0")}`;
|
||||
}, [elapsedSeconds]);
|
||||
|
||||
// Generate random bar heights for waveform animation
|
||||
const bars = useMemo(() => {
|
||||
return Array.from({ length: 50 }, (_, i) => ({
|
||||
id: i,
|
||||
// Create a wave pattern with some randomness
|
||||
baseHeight: Math.sin(i * 0.3) * 6 + 8,
|
||||
delay: i * 0.02,
|
||||
}));
|
||||
}, []);
|
||||
|
||||
if (!isRecording) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex items-center gap-3 px-3 py-2 bg-background-tint-00 rounded-12 min-h-[32px]">
|
||||
{/* Waveform visualization */}
|
||||
<div className="flex-1 flex items-center justify-center gap-[2px] h-4 overflow-hidden">
|
||||
{bars.map((bar) => (
|
||||
<div
|
||||
key={bar.id}
|
||||
className={cn(
|
||||
"w-[2px] bg-text-03 rounded-full",
|
||||
!isMuted && "animate-waveform"
|
||||
)}
|
||||
style={{
|
||||
// When muted, show flat bars (2px height), otherwise animate with base height
|
||||
height: isMuted ? "2px" : `${bar.baseHeight}px`,
|
||||
animationDelay: isMuted ? undefined : `${bar.delay}s`,
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Timer */}
|
||||
<span className="font-mono text-xs text-text-03 tabular-nums shrink-0">
|
||||
{formattedTime}
|
||||
</span>
|
||||
|
||||
{/* Mute button */}
|
||||
{onMuteToggle && (
|
||||
<Button
|
||||
icon={isMuted ? SvgMicrophoneOff : SvgMicrophone}
|
||||
onClick={onMuteToggle}
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
aria-label={isMuted ? "Unmute microphone" : "Mute microphone"}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default RecordingWaveform;
|
||||
@@ -19,7 +19,35 @@ import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidE
|
||||
import { CombinedSettings } from "@/interfaces/settings";
|
||||
import SidebarTab from "@/refresh-components/buttons/SidebarTab";
|
||||
import SidebarBody from "@/sections/sidebar/SidebarBody";
|
||||
import { SvgArrowUpCircle } from "@opal/icons";
|
||||
import {
|
||||
SvgActions,
|
||||
SvgActivity,
|
||||
SvgArrowUpCircle,
|
||||
SvgBarChart,
|
||||
SvgCpu,
|
||||
SvgFileText,
|
||||
SvgFolder,
|
||||
SvgGlobe,
|
||||
SvgArrowExchange,
|
||||
SvgImage,
|
||||
SvgKey,
|
||||
SvgOnyxLogo,
|
||||
SvgOnyxOctagon,
|
||||
SvgSearch,
|
||||
SvgServer,
|
||||
SvgSettings,
|
||||
SvgShield,
|
||||
SvgThumbsUp,
|
||||
SvgUploadCloud,
|
||||
SvgUser,
|
||||
SvgUsers,
|
||||
SvgZoomIn,
|
||||
SvgPaintBrush,
|
||||
SvgDiscordMono,
|
||||
SvgWallet,
|
||||
SvgAudio,
|
||||
} from "@opal/icons";
|
||||
import SvgMcp from "@opal/icons/mcp";
|
||||
import { ADMIN_PATHS, sidebarItem } from "@/lib/admin-routes";
|
||||
import UserAvatarPopover from "@/sections/sidebar/UserAvatarPopover";
|
||||
|
||||
@@ -105,6 +133,11 @@ const collections = (
|
||||
sidebarItem(ADMIN_PATHS.LLM_MODELS),
|
||||
sidebarItem(ADMIN_PATHS.WEB_SEARCH),
|
||||
sidebarItem(ADMIN_PATHS.IMAGE_GENERATION),
|
||||
{
|
||||
name: "Voice",
|
||||
icon: SvgAudio,
|
||||
link: "/admin/configuration/voice",
|
||||
},
|
||||
sidebarItem(ADMIN_PATHS.CODE_INTERPRETER),
|
||||
...(!enableCloud && vectorDbEnabled
|
||||
? [
|
||||
@@ -131,11 +164,7 @@ const collections = (
|
||||
? [
|
||||
{
|
||||
name: "Permissions",
|
||||
items: [
|
||||
// TODO: Uncomment once Users v2 page is complete
|
||||
// sidebarItem(ADMIN_PATHS.USERS_V2),
|
||||
sidebarItem(ADMIN_PATHS.SCIM),
|
||||
],
|
||||
items: [sidebarItem(ADMIN_PATHS.SCIM)],
|
||||
},
|
||||
]
|
||||
: []),
|
||||
|
||||
Reference in New Issue
Block a user