mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-12 11:12:40 +00:00
Compare commits
44 Commits
nikg/fix-s
...
voice-mode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc5bba094c | ||
|
|
a9028e5ae8 | ||
|
|
4fdbef4185 | ||
|
|
9cac41bb6b | ||
|
|
1fce5b6bf5 | ||
|
|
6b035b5908 | ||
|
|
764394d5cf | ||
|
|
2217d0ab48 | ||
|
|
e6e05681a3 | ||
|
|
feef46f9f3 | ||
|
|
243514f601 | ||
|
|
72eb5c5626 | ||
|
|
eb4f806a44 | ||
|
|
b4ab51e307 | ||
|
|
a321abf8a6 | ||
|
|
9573274039 | ||
|
|
4f7b1332e2 | ||
|
|
0dac14abd4 | ||
|
|
aa0eac8ae8 | ||
|
|
8fead4dfbf | ||
|
|
ac4b49a7f9 | ||
|
|
fc22232f14 | ||
|
|
9ddd44bf56 | ||
|
|
8587911cf6 | ||
|
|
d7300d50d7 | ||
|
|
cc950a2da2 | ||
|
|
8d6640159a | ||
|
|
bba77749c3 | ||
|
|
3e9a66c8ff | ||
|
|
548b9d9e0e | ||
|
|
0d3967baee | ||
|
|
6ed806eebb | ||
|
|
3b6a35b2c4 | ||
|
|
62e612f85f | ||
|
|
b375b7f0ff | ||
|
|
c158ae2622 | ||
|
|
698494626f | ||
|
|
93cefe7ef0 | ||
|
|
8a326c4089 | ||
|
|
0c5410f429 | ||
|
|
0b05b9b235 | ||
|
|
59d8a988bd | ||
|
|
6d08cfb25a | ||
|
|
53a5ee2a6e |
@@ -1,43 +0,0 @@
|
||||
"""add timestamps to user table
|
||||
|
||||
Revision ID: 27fb147a843f
|
||||
Revises: b5c4d7e8f9a1
|
||||
Create Date: 2026-03-08 17:18:40.828644
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "27fb147a843f"
|
||||
down_revision = "b5c4d7e8f9a1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "updated_at")
|
||||
op.drop_column("user", "created_at")
|
||||
@@ -0,0 +1,117 @@
|
||||
"""add_voice_provider_and_user_voice_prefs
|
||||
|
||||
Revision ID: 93a2e195e25c
|
||||
Revises: b5c4d7e8f9a1
|
||||
Create Date: 2026-02-23 15:16:39.507304
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import column
|
||||
from sqlalchemy import true
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "93a2e195e25c"
|
||||
down_revision = "b5c4d7e8f9a1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create voice_provider table
|
||||
op.create_table(
|
||||
"voice_provider",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("name", sa.String(), unique=True, nullable=False),
|
||||
sa.Column("provider_type", sa.String(), nullable=False),
|
||||
sa.Column("api_key", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("api_base", sa.String(), nullable=True),
|
||||
sa.Column("custom_config", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("stt_model", sa.String(), nullable=True),
|
||||
sa.Column("tts_model", sa.String(), nullable=True),
|
||||
sa.Column("default_voice", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"is_default_stt", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.Column(
|
||||
"is_default_tts", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.Column("deleted", sa.Boolean(), nullable=False, server_default="false"),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Add partial unique indexes to enforce only one default STT/TTS provider
|
||||
op.create_index(
|
||||
"ix_voice_provider_one_default_stt",
|
||||
"voice_provider",
|
||||
["is_default_stt"],
|
||||
unique=True,
|
||||
postgresql_where=column("is_default_stt") == true(),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_voice_provider_one_default_tts",
|
||||
"voice_provider",
|
||||
["is_default_tts"],
|
||||
unique=True,
|
||||
postgresql_where=column("is_default_tts") == true(),
|
||||
)
|
||||
|
||||
# Add voice preference columns to user table
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_auto_send",
|
||||
sa.Boolean(),
|
||||
default=False,
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_auto_playback",
|
||||
sa.Boolean(),
|
||||
default=False,
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_playback_speed",
|
||||
sa.Float(),
|
||||
default=1.0,
|
||||
nullable=False,
|
||||
server_default="1.0",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove user voice preference columns
|
||||
op.drop_column("user", "voice_playback_speed")
|
||||
op.drop_column("user", "voice_auto_playback")
|
||||
op.drop_column("user", "voice_auto_send")
|
||||
|
||||
op.drop_index("ix_voice_provider_one_default_tts", table_name="voice_provider")
|
||||
op.drop_index("ix_voice_provider_one_default_stt", table_name="voice_provider")
|
||||
|
||||
# Drop voice_provider table
|
||||
op.drop_table("voice_provider")
|
||||
@@ -29,6 +29,7 @@ from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi import WebSocket
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi_users import BaseUserManager
|
||||
@@ -121,6 +122,7 @@ from onyx.db.models import User
|
||||
from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import retrieve_ws_token_data
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -1612,6 +1614,102 @@ async def current_admin_user(user: User = Depends(current_user)) -> User:
|
||||
return user
|
||||
|
||||
|
||||
async def _get_user_from_token_data(token_data: dict) -> User | None:
|
||||
"""Shared logic: token data dict → User object.
|
||||
|
||||
Args:
|
||||
token_data: Decoded token data containing 'sub' (user ID).
|
||||
|
||||
Returns:
|
||||
User object if found and active, None otherwise.
|
||||
"""
|
||||
user_id = token_data.get("sub")
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
async with get_async_session_context_manager() as async_db_session:
|
||||
user = await async_db_session.get(User, user_uuid)
|
||||
if user is None or not user.is_active:
|
||||
return None
|
||||
return user
|
||||
|
||||
|
||||
async def current_user_from_websocket(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(..., description="WebSocket authentication token"),
|
||||
) -> User:
|
||||
"""
|
||||
WebSocket authentication dependency using query parameter.
|
||||
|
||||
Validates the WS token from query param and returns the User.
|
||||
Raises BasicAuthenticationError if authentication fails.
|
||||
|
||||
The token must be obtained from POST /voice/ws-token before connecting.
|
||||
Tokens are single-use and expire after 60 seconds.
|
||||
|
||||
Usage:
|
||||
1. POST /voice/ws-token -> {"token": "xxx"}
|
||||
2. Connect to ws://host/path?token=xxx
|
||||
|
||||
This applies the same auth checks as current_user() for HTTP endpoints.
|
||||
"""
|
||||
# Check Origin header to prevent Cross-Site WebSocket Hijacking (CSWSH)
|
||||
# Browsers always send Origin on WebSocket connections
|
||||
origin = websocket.headers.get("origin")
|
||||
expected_origin = WEB_DOMAIN.rstrip("/")
|
||||
if not origin:
|
||||
logger.warning("WS auth: missing Origin header")
|
||||
raise BasicAuthenticationError(detail="Access denied. Missing origin.")
|
||||
|
||||
actual_origin = origin.rstrip("/")
|
||||
if actual_origin != expected_origin:
|
||||
logger.warning(
|
||||
f"WS auth: origin mismatch. Expected {expected_origin}, got {actual_origin}"
|
||||
)
|
||||
raise BasicAuthenticationError(detail="Access denied. Invalid origin.")
|
||||
|
||||
# Validate WS token in Redis (single-use, deleted after retrieval)
|
||||
try:
|
||||
token_data = await retrieve_ws_token_data(token)
|
||||
if token_data is None:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. Invalid or expired authentication token."
|
||||
)
|
||||
except BasicAuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"WS auth: error during token validation: {e}")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Authentication verification failed."
|
||||
) from e
|
||||
|
||||
# Get user from token data
|
||||
user = await _get_user_from_token_data(token_data)
|
||||
if user is None:
|
||||
logger.warning(f"WS auth: user not found for id={token_data.get('sub')}")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User not found or inactive."
|
||||
)
|
||||
|
||||
# Apply same checks as HTTP auth (verification, OIDC expiry, role)
|
||||
user = await double_check_user(user)
|
||||
|
||||
# Block LIMITED users (same as current_user)
|
||||
if user.role == UserRole.LIMITED:
|
||||
logger.warning(f"WS auth: user {user.email} has LIMITED role")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
|
||||
)
|
||||
|
||||
logger.debug(f"WS auth: authenticated {user.email}")
|
||||
return user
|
||||
|
||||
|
||||
def get_default_admin_user_emails_() -> list[str]:
|
||||
# No default seeding available for Onyx MIT
|
||||
return []
|
||||
|
||||
@@ -339,20 +339,15 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
TIMESTAMPAware(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
default_model: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
# organized in typical structured fashion
|
||||
# formatted as `displayName__provider__modelName`
|
||||
|
||||
# Voice preferences
|
||||
voice_auto_send: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
voice_auto_playback: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
voice_playback_speed: Mapped[float] = mapped_column(Float, default=1.0)
|
||||
|
||||
# relationships
|
||||
credentials: Mapped[list["Credential"]] = relationship(
|
||||
"Credential", back_populates="user"
|
||||
@@ -3065,6 +3060,65 @@ class ImageGenerationConfig(Base):
|
||||
)
|
||||
|
||||
|
||||
class VoiceProvider(Base):
|
||||
"""Configuration for voice services (STT and TTS)."""
|
||||
|
||||
__tablename__ = "voice_provider"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
provider_type: Mapped[str] = mapped_column(
|
||||
String
|
||||
) # "openai", "azure", "elevenlabs"
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
api_base: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
custom_config: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
|
||||
# Model/voice configuration
|
||||
stt_model: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
) # e.g., "whisper-1"
|
||||
tts_model: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
) # e.g., "tts-1", "tts-1-hd"
|
||||
default_voice: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
) # e.g., "alloy", "echo"
|
||||
|
||||
# STT and TTS can use different providers - only one provider per type
|
||||
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
# Enforce only one default STT provider and one default TTS provider at DB level
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_voice_provider_one_default_stt",
|
||||
"is_default_stt",
|
||||
unique=True,
|
||||
postgresql_where=(is_default_stt == True), # noqa: E712
|
||||
),
|
||||
Index(
|
||||
"ix_voice_provider_one_default_tts",
|
||||
"is_default_tts",
|
||||
unique=True,
|
||||
postgresql_where=(is_default_tts == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(Base):
|
||||
__tablename__ = "embedding_provider"
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
|
||||
@@ -174,21 +173,6 @@ def _get_accepted_user_where_clause(
|
||||
return where_clause
|
||||
|
||||
|
||||
def get_all_accepted_users(
|
||||
db_session: Session,
|
||||
include_external: bool = False,
|
||||
) -> Sequence[User]:
|
||||
"""Returns all accepted users without pagination.
|
||||
Uses the same filtering as the paginated endpoint but without
|
||||
search, role, or active filters."""
|
||||
stmt = select(User)
|
||||
where_clause = _get_accepted_user_where_clause(
|
||||
include_external=include_external,
|
||||
)
|
||||
stmt = stmt.where(*where_clause).order_by(User.email)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def get_page_of_filtered_users(
|
||||
db_session: Session,
|
||||
page_size: int,
|
||||
@@ -374,28 +358,3 @@ def delete_user_from_db(
|
||||
# NOTE: edge case may exist with race conditions
|
||||
# with this `invited user` scheme generally.
|
||||
remove_user_from_invited_users(user_to_delete.email)
|
||||
|
||||
|
||||
def batch_get_user_groups(
|
||||
db_session: Session,
|
||||
user_ids: list[UUID],
|
||||
) -> dict[UUID, list[tuple[int, str]]]:
|
||||
"""Fetch group memberships for a batch of users in a single query.
|
||||
Returns a mapping of user_id -> list of (group_id, group_name) tuples."""
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
rows = db_session.execute(
|
||||
select(
|
||||
User__UserGroup.user_id,
|
||||
UserGroup.id,
|
||||
UserGroup.name,
|
||||
)
|
||||
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
|
||||
.where(User__UserGroup.user_id.in_(user_ids))
|
||||
).all()
|
||||
|
||||
result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids}
|
||||
for user_id, group_id, group_name in rows:
|
||||
result[user_id].append((group_id, group_name))
|
||||
return result
|
||||
|
||||
248
backend/onyx/db/voice.py
Normal file
248
backend/onyx/db/voice.py
Normal file
@@ -0,0 +1,248 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
MIN_VOICE_PLAYBACK_SPEED = 0.5
|
||||
MAX_VOICE_PLAYBACK_SPEED = 2.0
|
||||
|
||||
|
||||
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
|
||||
"""Fetch all voice providers."""
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
.order_by(VoiceProvider.name)
|
||||
).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_id(
|
||||
db_session: Session, provider_id: int, include_deleted: bool = False
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by ID."""
|
||||
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(VoiceProvider.deleted.is_(False))
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default STT provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_stt.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default TTS provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_tts.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_type(
|
||||
db_session: Session, provider_type: str
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by type."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.provider_type == provider_type)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
)
|
||||
|
||||
|
||||
def upsert_voice_provider(
|
||||
*,
|
||||
db_session: Session,
|
||||
provider_id: int | None,
|
||||
name: str,
|
||||
provider_type: str,
|
||||
api_key: str | None,
|
||||
api_key_changed: bool,
|
||||
api_base: str | None = None,
|
||||
custom_config: dict[str, Any] | None = None,
|
||||
stt_model: str | None = None,
|
||||
tts_model: str | None = None,
|
||||
default_voice: str | None = None,
|
||||
activate_stt: bool = False,
|
||||
activate_tts: bool = False,
|
||||
) -> VoiceProvider:
|
||||
"""Create or update a voice provider."""
|
||||
provider: VoiceProvider | None = None
|
||||
|
||||
if provider_id is not None:
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"No voice provider with id {provider_id} exists.",
|
||||
)
|
||||
else:
|
||||
provider = VoiceProvider()
|
||||
db_session.add(provider)
|
||||
|
||||
# Apply updates
|
||||
provider.name = name
|
||||
provider.provider_type = provider_type
|
||||
provider.api_base = api_base
|
||||
provider.custom_config = custom_config
|
||||
provider.stt_model = stt_model
|
||||
provider.tts_model = tts_model
|
||||
provider.default_voice = default_voice
|
||||
|
||||
# Only update API key if explicitly changed or if provider has no key
|
||||
if api_key_changed or provider.api_key is None:
|
||||
provider.api_key = api_key # type: ignore[assignment]
|
||||
|
||||
db_session.flush()
|
||||
|
||||
if activate_stt:
|
||||
set_default_stt_provider(db_session=db_session, provider_id=provider.id)
|
||||
if activate_tts:
|
||||
set_default_tts_provider(db_session=db_session, provider_id=provider.id)
|
||||
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
|
||||
"""Soft-delete a voice provider by ID."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider:
|
||||
provider.deleted = True
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def set_default_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
|
||||
"""Set a voice provider as the default STT provider."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"No voice provider with id {provider_id} exists.",
|
||||
)
|
||||
|
||||
# Deactivate all other STT providers
|
||||
db_session.execute(
|
||||
update(VoiceProvider)
|
||||
.where(
|
||||
VoiceProvider.is_default_stt.is_(True),
|
||||
VoiceProvider.id != provider_id,
|
||||
)
|
||||
.values(is_default_stt=False)
|
||||
)
|
||||
|
||||
# Activate this provider
|
||||
provider.is_default_stt = True
|
||||
|
||||
db_session.flush()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
def set_default_tts_provider(
|
||||
*, db_session: Session, provider_id: int, tts_model: str | None = None
|
||||
) -> VoiceProvider:
|
||||
"""Set a voice provider as the default TTS provider."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"No voice provider with id {provider_id} exists.",
|
||||
)
|
||||
|
||||
# Deactivate all other TTS providers
|
||||
db_session.execute(
|
||||
update(VoiceProvider)
|
||||
.where(
|
||||
VoiceProvider.is_default_tts.is_(True),
|
||||
VoiceProvider.id != provider_id,
|
||||
)
|
||||
.values(is_default_tts=False)
|
||||
)
|
||||
|
||||
# Activate this provider
|
||||
provider.is_default_tts = True
|
||||
|
||||
# Update the TTS model if specified
|
||||
if tts_model is not None:
|
||||
provider.tts_model = tts_model
|
||||
|
||||
db_session.flush()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
def deactivate_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
|
||||
"""Remove the default STT status from a voice provider."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"No voice provider with id {provider_id} exists.",
|
||||
)
|
||||
|
||||
provider.is_default_stt = False
|
||||
|
||||
db_session.flush()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
def deactivate_tts_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
|
||||
"""Remove the default TTS status from a voice provider."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"No voice provider with id {provider_id} exists.",
|
||||
)
|
||||
|
||||
provider.is_default_tts = False
|
||||
|
||||
db_session.flush()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
# User voice preferences
|
||||
|
||||
|
||||
def update_user_voice_settings(
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
auto_send: bool | None = None,
|
||||
auto_playback: bool | None = None,
|
||||
playback_speed: float | None = None,
|
||||
) -> None:
|
||||
"""Update user's voice settings.
|
||||
|
||||
For all fields, None means "don't update this field".
|
||||
"""
|
||||
values: dict[str, bool | float] = {}
|
||||
|
||||
if auto_send is not None:
|
||||
values["voice_auto_send"] = auto_send
|
||||
if auto_playback is not None:
|
||||
values["voice_auto_playback"] = auto_playback
|
||||
if playback_speed is not None:
|
||||
values["voice_playback_speed"] = max(
|
||||
MIN_VOICE_PLAYBACK_SPEED, min(MAX_VOICE_PLAYBACK_SPEED, playback_speed)
|
||||
)
|
||||
|
||||
if values:
|
||||
db_session.execute(update(User).where(User.id == user_id).values(**values)) # type: ignore[arg-type]
|
||||
db_session.flush()
|
||||
@@ -66,6 +66,11 @@ class OnyxErrorCode(Enum):
|
||||
RATE_LIMITED = ("RATE_LIMITED", 429)
|
||||
SEAT_LIMIT_EXCEEDED = ("SEAT_LIMIT_EXCEEDED", 402)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Payload (413)
|
||||
# ------------------------------------------------------------------
|
||||
PAYLOAD_TOO_LARGE = ("PAYLOAD_TOO_LARGE", 413)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connector / Credential Errors (400-range)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -3782,6 +3782,16 @@
|
||||
"display_name": "Claude Sonnet 3.5",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"vertex_ai/claude-3-5-sonnet-v2": {
|
||||
"display_name": "Claude Sonnet 3.5",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "v2"
|
||||
},
|
||||
"vertex_ai/claude-3-5-sonnet-v2@20241022": {
|
||||
"display_name": "Claude Sonnet 3.5 v2",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022"
|
||||
},
|
||||
"vertex_ai/claude-3-5-sonnet@20240620": {
|
||||
"display_name": "Claude Sonnet 3.5",
|
||||
"model_vendor": "anthropic",
|
||||
|
||||
@@ -119,6 +119,9 @@ from onyx.server.manage.opensearch_migration.api import (
|
||||
from onyx.server.manage.search_settings import router as search_settings_router
|
||||
from onyx.server.manage.slack_bot import router as slack_bot_management_router
|
||||
from onyx.server.manage.users import router as user_router
|
||||
from onyx.server.manage.voice.api import admin_router as voice_admin_router
|
||||
from onyx.server.manage.voice.user_api import router as voice_router
|
||||
from onyx.server.manage.voice.websocket_api import router as voice_websocket_router
|
||||
from onyx.server.manage.web_search.api import (
|
||||
admin_router as web_search_admin_router,
|
||||
)
|
||||
@@ -497,6 +500,9 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, embedding_router)
|
||||
include_router_with_global_prefix_prepended(application, web_search_router)
|
||||
include_router_with_global_prefix_prepended(application, web_search_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_websocket_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, opensearch_migration_admin_router
|
||||
)
|
||||
|
||||
@@ -92,27 +92,8 @@ def _split_text(text: str, limit: int = 3000) -> list[str]:
|
||||
split_at = limit
|
||||
|
||||
chunk = text[:split_at]
|
||||
|
||||
# If splitting inside an unclosed code fence, try to back up and
|
||||
# split before the opening ``` so the code block stays in one piece.
|
||||
open_fences = chunk.count("```")
|
||||
if open_fences % 2 == 1:
|
||||
last_fence = chunk.rfind("```")
|
||||
# Find a newline before the fence to get a clean break
|
||||
split_before = text.rfind("\n", 0, last_fence)
|
||||
if split_before > 0:
|
||||
# Back up to before the code block
|
||||
chunk = text[:split_before]
|
||||
text = text[split_before:].lstrip()
|
||||
else:
|
||||
# Code block itself exceeds the limit — no choice but to
|
||||
# split inside it. Close the fence here, reopen in the next.
|
||||
chunk += "\n```"
|
||||
text = "```\n" + text[split_at:].lstrip()
|
||||
else:
|
||||
text = text[split_at:].lstrip()
|
||||
|
||||
chunks.append(chunk)
|
||||
text = text[split_at:].lstrip() # Remove leading spaces from the next chunk
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
import re
|
||||
from enum import Enum
|
||||
|
||||
# Matches Slack channel references like <#C097NBWMY8Y> or <#C097NBWMY8Y|channel-name>
|
||||
SLACK_CHANNEL_REF_PATTERN = re.compile(r"<#([A-Z0-9]+)(?:\|([^>]+))?>")
|
||||
|
||||
LIKE_BLOCK_ACTION_ID = "feedback-like"
|
||||
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
|
||||
SHOW_EVERYONE_ACTION_ID = "show-everyone"
|
||||
|
||||
@@ -18,18 +18,15 @@ from onyx.configs.onyxbot_configs import ONYX_BOT_DISPLAY_ERROR_MSGS
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_RETRIES
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.onyxbot.slack.blocks import build_slack_response_blocks
|
||||
from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN
|
||||
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
|
||||
from onyx.onyxbot.slack.models import SlackMessageInfo
|
||||
from onyx.onyxbot.slack.models import ThreadMessage
|
||||
from onyx.onyxbot.slack.utils import get_channel_from_id
|
||||
from onyx.onyxbot.slack.utils import get_channel_name_from_id
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import SlackRateLimiter
|
||||
@@ -44,51 +41,6 @@ srl = SlackRateLimiter()
|
||||
RT = TypeVar("RT") # return type
|
||||
|
||||
|
||||
def resolve_channel_references(
|
||||
message: str,
|
||||
client: WebClient,
|
||||
logger: OnyxLoggingAdapter,
|
||||
) -> tuple[str, list[Tag]]:
|
||||
"""Parse Slack channel references from a message, resolve IDs to names,
|
||||
replace the raw markup with readable #channel-name, and return channel tags
|
||||
for search filtering."""
|
||||
tags: list[Tag] = []
|
||||
channel_matches = SLACK_CHANNEL_REF_PATTERN.findall(message)
|
||||
seen_channel_ids: set[str] = set()
|
||||
|
||||
for channel_id, channel_name_from_markup in channel_matches:
|
||||
if channel_id in seen_channel_ids:
|
||||
continue
|
||||
seen_channel_ids.add(channel_id)
|
||||
|
||||
channel_name = channel_name_from_markup or None
|
||||
|
||||
if not channel_name:
|
||||
try:
|
||||
channel_info = get_channel_from_id(client=client, channel_id=channel_id)
|
||||
channel_name = channel_info.get("name") or None
|
||||
except Exception:
|
||||
logger.warning(f"Failed to resolve channel name for ID: {channel_id}")
|
||||
|
||||
if not channel_name:
|
||||
continue
|
||||
|
||||
# Replace raw Slack markup with readable channel name
|
||||
if channel_name_from_markup:
|
||||
message = message.replace(
|
||||
f"<#{channel_id}|{channel_name_from_markup}>",
|
||||
f"#{channel_name}",
|
||||
)
|
||||
else:
|
||||
message = message.replace(
|
||||
f"<#{channel_id}>",
|
||||
f"#{channel_name}",
|
||||
)
|
||||
tags.append(Tag(tag_key="Channel", tag_value=channel_name))
|
||||
|
||||
return message, tags
|
||||
|
||||
|
||||
def rate_limits(
|
||||
client: WebClient, channel: str, thread_ts: Optional[str]
|
||||
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
|
||||
@@ -205,20 +157,6 @@ def handle_regular_answer(
|
||||
user_message = messages[-1]
|
||||
history_messages = messages[:-1]
|
||||
|
||||
# Resolve any <#CHANNEL_ID> references in the user message to readable
|
||||
# channel names and extract channel tags for search filtering
|
||||
resolved_message, channel_tags = resolve_channel_references(
|
||||
message=user_message.message,
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
user_message = ThreadMessage(
|
||||
message=resolved_message,
|
||||
sender=user_message.sender,
|
||||
role=user_message.role,
|
||||
)
|
||||
|
||||
channel_name, _ = get_channel_name_from_id(
|
||||
client=client,
|
||||
channel_id=channel,
|
||||
@@ -269,7 +207,6 @@ def handle_regular_answer(
|
||||
source_type=None,
|
||||
document_set=document_set_names,
|
||||
time_cutoff=None,
|
||||
tags=channel_tags if channel_tags else None,
|
||||
)
|
||||
|
||||
new_message_request = SendMessageRequest(
|
||||
@@ -294,16 +231,6 @@ def handle_regular_answer(
|
||||
slack_context_str=slack_context_str,
|
||||
)
|
||||
|
||||
# If a channel filter was applied but no results were found, override
|
||||
# the LLM response to avoid hallucinated answers about unindexed channels
|
||||
if channel_tags and not answer.citation_info and not answer.top_documents:
|
||||
channel_names = ", ".join(f"#{tag.tag_value}" for tag in channel_tags)
|
||||
answer.answer = (
|
||||
f"No indexed data found for {channel_names}. "
|
||||
"This channel may not be indexed, or there may be no messages "
|
||||
"matching your query within it."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unable to process message - did not successfully answer "
|
||||
@@ -358,7 +285,6 @@ def handle_regular_answer(
|
||||
only_respond_if_citations
|
||||
and not answer.citation_info
|
||||
and not message_info.bypass_filters
|
||||
and not channel_tags
|
||||
):
|
||||
logger.error(
|
||||
f"Unable to find citations to answer: '{answer.answer}' - not answering!"
|
||||
|
||||
@@ -419,12 +419,15 @@ async def get_async_redis_connection() -> aioredis.Redis:
|
||||
return _async_redis_connection
|
||||
|
||||
|
||||
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
|
||||
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
if not token:
|
||||
logger.debug("No auth token cookie found")
|
||||
return None
|
||||
async def retrieve_auth_token_data(token: str) -> dict | None:
|
||||
"""Validate auth token against Redis and return token data.
|
||||
|
||||
Args:
|
||||
token: The raw authentication token string.
|
||||
|
||||
Returns:
|
||||
Token data dict if valid, None if invalid/expired.
|
||||
"""
|
||||
try:
|
||||
redis = await get_async_redis_connection()
|
||||
redis_key = REDIS_AUTH_KEY_PREFIX + token
|
||||
@@ -439,13 +442,97 @@ async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
|
||||
logger.error("Error decoding token data from Redis")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
|
||||
logger.error(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
|
||||
raise ValueError(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
|
||||
|
||||
|
||||
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
|
||||
"""Validate auth token from request cookie. Wrapper for backwards compatibility."""
|
||||
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
if not token:
|
||||
logger.debug("No auth token cookie found")
|
||||
return None
|
||||
return await retrieve_auth_token_data(token)
|
||||
|
||||
|
||||
# WebSocket token prefix (separate from regular auth tokens)
|
||||
REDIS_WS_TOKEN_PREFIX = "ws_token:"
|
||||
# WebSocket tokens expire after 60 seconds
|
||||
WS_TOKEN_TTL_SECONDS = 60
|
||||
# Rate limit: max tokens per user per window
|
||||
WS_TOKEN_RATE_LIMIT_MAX = 10
|
||||
WS_TOKEN_RATE_LIMIT_WINDOW_SECONDS = 60
|
||||
REDIS_WS_TOKEN_RATE_LIMIT_PREFIX = "ws_token_rate:"
|
||||
|
||||
|
||||
class WsTokenRateLimitExceeded(Exception):
|
||||
"""Raised when a user exceeds the WS token generation rate limit."""
|
||||
|
||||
|
||||
async def store_ws_token(token: str, user_id: str) -> None:
|
||||
"""Store a short-lived WebSocket authentication token in Redis.
|
||||
|
||||
Args:
|
||||
token: The generated WS token.
|
||||
user_id: The user ID to associate with this token.
|
||||
|
||||
Raises:
|
||||
WsTokenRateLimitExceeded: If the user has exceeded the rate limit.
|
||||
"""
|
||||
redis = await get_async_redis_connection()
|
||||
|
||||
# Atomically increment and check rate limit to avoid TOCTOU races
|
||||
rate_limit_key = REDIS_WS_TOKEN_RATE_LIMIT_PREFIX + user_id
|
||||
pipe = redis.pipeline()
|
||||
pipe.incr(rate_limit_key)
|
||||
pipe.expire(rate_limit_key, WS_TOKEN_RATE_LIMIT_WINDOW_SECONDS)
|
||||
results = await pipe.execute()
|
||||
new_count = results[0]
|
||||
|
||||
if new_count > WS_TOKEN_RATE_LIMIT_MAX:
|
||||
# Over limit — decrement back since we won't use this slot
|
||||
await redis.decr(rate_limit_key)
|
||||
logger.warning(f"WS token rate limit exceeded for user {user_id}")
|
||||
raise WsTokenRateLimitExceeded(
|
||||
f"Rate limit exceeded. Maximum {WS_TOKEN_RATE_LIMIT_MAX} tokens per minute."
|
||||
)
|
||||
|
||||
# Store the actual token
|
||||
redis_key = REDIS_WS_TOKEN_PREFIX + token
|
||||
token_data = json.dumps({"sub": user_id})
|
||||
await redis.set(redis_key, token_data, ex=WS_TOKEN_TTL_SECONDS)
|
||||
|
||||
|
||||
async def retrieve_ws_token_data(token: str) -> dict | None:
|
||||
"""Validate a WebSocket token and return the token data.
|
||||
|
||||
This uses GETDEL for atomic get-and-delete to prevent race conditions
|
||||
where the same token could be used twice.
|
||||
|
||||
Args:
|
||||
token: The WS token to validate.
|
||||
|
||||
Returns:
|
||||
Token data dict with 'sub' (user ID) if valid, None if invalid/expired.
|
||||
"""
|
||||
try:
|
||||
redis = await get_async_redis_connection()
|
||||
redis_key = REDIS_WS_TOKEN_PREFIX + token
|
||||
|
||||
# Atomic get-and-delete to prevent race conditions (Redis 6.2+)
|
||||
token_data_str = await redis.getdel(redis_key)
|
||||
|
||||
if not token_data_str:
|
||||
return None
|
||||
|
||||
return json.loads(token_data_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Error decoding WS token data from Redis")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in retrieve_ws_token_data: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def redis_lock_dump(lock: RedisLock, r: Redis) -> None:
|
||||
# diagnostic logging for lock errors
|
||||
|
||||
@@ -9,6 +9,7 @@ from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import current_user_from_websocket
|
||||
from onyx.auth.users import current_user_with_expired_token
|
||||
from onyx.configs.app_configs import APP_API_PREFIX
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
@@ -129,6 +130,7 @@ def check_router_auth(
|
||||
or depends_fn == current_curator_or_admin_user
|
||||
or depends_fn == current_user_with_expired_token
|
||||
or depends_fn == current_chat_accessible_user
|
||||
or depends_fn == current_user_from_websocket
|
||||
or depends_fn == control_plane_dep
|
||||
or depends_fn == current_cloud_superuser
|
||||
or depends_fn == verify_scim_token
|
||||
|
||||
@@ -85,6 +85,11 @@ class UserPreferences(BaseModel):
|
||||
chat_background: str | None = None
|
||||
default_app_mode: DefaultAppMode = DefaultAppMode.CHAT
|
||||
|
||||
# Voice preferences
|
||||
voice_auto_send: bool | None = None
|
||||
voice_auto_playback: bool | None = None
|
||||
voice_playback_speed: float | None = None
|
||||
|
||||
# controls which tools are enabled for the user for a specific assistant
|
||||
assistant_specific_configs: UserSpecificAssistantPreferences | None = None
|
||||
|
||||
@@ -164,6 +169,9 @@ class UserInfo(BaseModel):
|
||||
theme_preference=user.theme_preference,
|
||||
chat_background=user.chat_background,
|
||||
default_app_mode=user.default_app_mode,
|
||||
voice_auto_send=user.voice_auto_send,
|
||||
voice_auto_playback=user.voice_auto_playback,
|
||||
voice_playback_speed=user.voice_playback_speed,
|
||||
assistant_specific_configs=assistant_specific_configs,
|
||||
)
|
||||
),
|
||||
@@ -240,6 +248,12 @@ class ChatBackgroundRequest(BaseModel):
|
||||
chat_background: str | None
|
||||
|
||||
|
||||
class VoiceSettingsUpdateRequest(BaseModel):
|
||||
auto_send: bool | None = None
|
||||
auto_playback: bool | None = None
|
||||
playback_speed: float | None = Field(default=None, ge=0.5, le=2.0)
|
||||
|
||||
|
||||
class PersonalizationUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
role: str | None = None
|
||||
|
||||
@@ -67,9 +67,7 @@ from onyx.db.user_preferences import update_user_role
|
||||
from onyx.db.user_preferences import update_user_shortcut_enabled
|
||||
from onyx.db.user_preferences import update_user_temperature_override_enabled
|
||||
from onyx.db.user_preferences import update_user_theme_preference
|
||||
from onyx.db.users import batch_get_user_groups
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_all_accepted_users
|
||||
from onyx.db.users import get_all_users
|
||||
from onyx.db.users import get_page_of_filtered_users
|
||||
from onyx.db.users import get_total_filtered_users_count
|
||||
@@ -100,7 +98,6 @@ from onyx.server.manage.models import UserSpecificAssistantPreferences
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import InvitedUserSnapshot
|
||||
from onyx.server.models import MinimalUserSnapshot
|
||||
from onyx.server.models import UserGroupInfo
|
||||
from onyx.server.usage_limits import is_tenant_on_trial_fn
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -206,51 +203,14 @@ def list_accepted_users(
|
||||
total_items=0,
|
||||
)
|
||||
|
||||
user_ids = [user.id for user in filtered_accepted_users]
|
||||
groups_by_user = batch_get_user_groups(db_session, user_ids)
|
||||
|
||||
return PaginatedReturn(
|
||||
items=[
|
||||
FullUserSnapshot.from_user_model(
|
||||
user,
|
||||
groups=[
|
||||
UserGroupInfo(id=gid, name=gname)
|
||||
for gid, gname in groups_by_user.get(user.id, [])
|
||||
],
|
||||
)
|
||||
for user in filtered_accepted_users
|
||||
FullUserSnapshot.from_user_model(user) for user in filtered_accepted_users
|
||||
],
|
||||
total_items=total_accepted_users_count,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/manage/users/accepted/all", tags=PUBLIC_API_TAGS)
|
||||
def list_all_accepted_users(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[FullUserSnapshot]:
|
||||
"""Returns all accepted users without pagination.
|
||||
Used by the admin Users page for client-side filtering/sorting."""
|
||||
users = get_all_accepted_users(db_session=db_session)
|
||||
|
||||
if not users:
|
||||
return []
|
||||
|
||||
user_ids = [user.id for user in users]
|
||||
groups_by_user = batch_get_user_groups(db_session, user_ids)
|
||||
|
||||
return [
|
||||
FullUserSnapshot.from_user_model(
|
||||
user,
|
||||
groups=[
|
||||
UserGroupInfo(id=gid, name=gname)
|
||||
for gid, gname in groups_by_user.get(user.id, [])
|
||||
],
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
|
||||
|
||||
@router.get("/manage/users/invited", tags=PUBLIC_API_TAGS)
|
||||
def list_invited_users(
|
||||
_: User = Depends(current_admin_user),
|
||||
@@ -309,10 +269,24 @@ def list_all_users(
|
||||
if accepted_page is None or invited_page is None or slack_users_page is None:
|
||||
return AllUsersResponse(
|
||||
accepted=[
|
||||
FullUserSnapshot.from_user_model(user) for user in accepted_users
|
||||
FullUserSnapshot(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
)
|
||||
for user in accepted_users
|
||||
],
|
||||
slack_users=[
|
||||
FullUserSnapshot.from_user_model(user) for user in slack_users
|
||||
FullUserSnapshot(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
)
|
||||
for user in slack_users
|
||||
],
|
||||
invited=[InvitedUserSnapshot(email=email) for email in invited_emails],
|
||||
accepted_pages=1,
|
||||
@@ -322,10 +296,26 @@ def list_all_users(
|
||||
|
||||
# Otherwise, return paginated results
|
||||
return AllUsersResponse(
|
||||
accepted=[FullUserSnapshot.from_user_model(user) for user in accepted_users][
|
||||
accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE
|
||||
],
|
||||
slack_users=[FullUserSnapshot.from_user_model(user) for user in slack_users][
|
||||
accepted=[
|
||||
FullUserSnapshot(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
)
|
||||
for user in accepted_users
|
||||
][accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE],
|
||||
slack_users=[
|
||||
FullUserSnapshot(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
)
|
||||
for user in slack_users
|
||||
][
|
||||
slack_users_page
|
||||
* USERS_PAGE_SIZE : (slack_users_page + 1)
|
||||
* USERS_PAGE_SIZE
|
||||
|
||||
0
backend/onyx/server/manage/voice/__init__.py
Normal file
0
backend/onyx/server/manage/voice/__init__.py
Normal file
315
backend/onyx/server/manage/voice/api.py
Normal file
315
backend/onyx/server/manage/voice/api.py
Normal file
@@ -0,0 +1,315 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.db.voice import deactivate_stt_provider
|
||||
from onyx.db.voice import deactivate_tts_provider
|
||||
from onyx.db.voice import delete_voice_provider
|
||||
from onyx.db.voice import fetch_voice_provider_by_id
|
||||
from onyx.db.voice import fetch_voice_provider_by_type
|
||||
from onyx.db.voice import fetch_voice_providers
|
||||
from onyx.db.voice import set_default_stt_provider
|
||||
from onyx.db.voice import set_default_tts_provider
|
||||
from onyx.db.voice import upsert_voice_provider
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.voice.models import VoiceOption
|
||||
from onyx.server.manage.voice.models import VoiceProviderTestRequest
|
||||
from onyx.server.manage.voice.models import VoiceProviderUpdateSuccess
|
||||
from onyx.server.manage.voice.models import VoiceProviderUpsertRequest
|
||||
from onyx.server.manage.voice.models import VoiceProviderView
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.url import SSRFException
|
||||
from onyx.utils.url import validate_outbound_http_url
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/voice")
|
||||
|
||||
|
||||
def _validate_voice_api_base(provider_type: str, api_base: str | None) -> str | None:
|
||||
"""Validate and normalize provider api_base / target URI."""
|
||||
if api_base is None:
|
||||
return None
|
||||
|
||||
allow_private_network = provider_type.lower() == "azure"
|
||||
try:
|
||||
return validate_outbound_http_url(
|
||||
api_base, allow_private_network=allow_private_network
|
||||
)
|
||||
except (ValueError, SSRFException) as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid target URI: {str(e)}",
|
||||
) from e
|
||||
|
||||
|
||||
def _provider_to_view(provider: VoiceProvider) -> VoiceProviderView:
|
||||
"""Convert a VoiceProvider model to a VoiceProviderView."""
|
||||
return VoiceProviderView(
|
||||
id=provider.id,
|
||||
name=provider.name,
|
||||
provider_type=provider.provider_type,
|
||||
is_default_stt=provider.is_default_stt,
|
||||
is_default_tts=provider.is_default_tts,
|
||||
stt_model=provider.stt_model,
|
||||
tts_model=provider.tts_model,
|
||||
default_voice=provider.default_voice,
|
||||
has_api_key=bool(provider.api_key),
|
||||
target_uri=provider.api_base, # api_base stores the target URI for Azure
|
||||
)
|
||||
|
||||
|
||||
@admin_router.get("/providers")
|
||||
def list_voice_providers(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[VoiceProviderView]:
|
||||
"""List all configured voice providers."""
|
||||
providers = fetch_voice_providers(db_session)
|
||||
return [_provider_to_view(provider) for provider in providers]
|
||||
|
||||
|
||||
@admin_router.post("/providers")
|
||||
async def upsert_voice_provider_endpoint(
|
||||
request: VoiceProviderUpsertRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderView:
|
||||
"""Create or update a voice provider."""
|
||||
api_key = request.api_key
|
||||
api_key_changed = request.api_key_changed
|
||||
|
||||
# If llm_provider_id is specified, copy the API key from that LLM provider
|
||||
if request.llm_provider_id is not None:
|
||||
llm_provider = db_session.get(LLMProviderModel, request.llm_provider_id)
|
||||
if llm_provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"LLM provider with id {request.llm_provider_id} not found.",
|
||||
)
|
||||
if llm_provider.api_key is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Selected LLM provider has no API key configured.",
|
||||
)
|
||||
api_key = llm_provider.api_key.get_value(apply_mask=False)
|
||||
api_key_changed = True
|
||||
|
||||
# Use target_uri if provided, otherwise fall back to api_base
|
||||
api_base = _validate_voice_api_base(
|
||||
request.provider_type, request.target_uri or request.api_base
|
||||
)
|
||||
|
||||
provider = upsert_voice_provider(
|
||||
db_session=db_session,
|
||||
provider_id=request.id,
|
||||
name=request.name,
|
||||
provider_type=request.provider_type,
|
||||
api_key=api_key,
|
||||
api_key_changed=api_key_changed,
|
||||
api_base=api_base,
|
||||
custom_config=request.custom_config,
|
||||
stt_model=request.stt_model,
|
||||
tts_model=request.tts_model,
|
||||
default_voice=request.default_voice,
|
||||
activate_stt=request.activate_stt,
|
||||
activate_tts=request.activate_tts,
|
||||
)
|
||||
|
||||
# Validate credentials before committing - rollback on failure
|
||||
try:
|
||||
voice_provider = get_voice_provider(provider)
|
||||
await voice_provider.validate_credentials()
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.error(f"Voice provider credential validation failed on save: {e}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Connection test failed. Please verify your API key and settings.",
|
||||
) from e
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return _provider_to_view(provider)
|
||||
|
||||
|
||||
@admin_router.delete(
|
||||
"/providers/{provider_id}", status_code=204, response_class=Response
|
||||
)
|
||||
def delete_voice_provider_endpoint(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Delete a voice provider."""
|
||||
delete_voice_provider(db_session, provider_id)
|
||||
db_session.commit()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/activate-stt")
|
||||
def activate_stt_provider_endpoint(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderView:
|
||||
"""Set a voice provider as the default STT provider."""
|
||||
provider = set_default_stt_provider(db_session=db_session, provider_id=provider_id)
|
||||
db_session.commit()
|
||||
return _provider_to_view(provider)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/deactivate-stt")
|
||||
def deactivate_stt_provider_endpoint(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderUpdateSuccess:
|
||||
"""Remove the default STT status from a voice provider."""
|
||||
deactivate_stt_provider(db_session=db_session, provider_id=provider_id)
|
||||
db_session.commit()
|
||||
return VoiceProviderUpdateSuccess()
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/activate-tts")
|
||||
def activate_tts_provider_endpoint(
|
||||
provider_id: int,
|
||||
tts_model: str | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderView:
|
||||
"""Set a voice provider as the default TTS provider."""
|
||||
provider = set_default_tts_provider(
|
||||
db_session=db_session, provider_id=provider_id, tts_model=tts_model
|
||||
)
|
||||
db_session.commit()
|
||||
return _provider_to_view(provider)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/deactivate-tts")
|
||||
def deactivate_tts_provider_endpoint(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderUpdateSuccess:
|
||||
"""Remove the default TTS status from a voice provider."""
|
||||
deactivate_tts_provider(db_session=db_session, provider_id=provider_id)
|
||||
db_session.commit()
|
||||
return VoiceProviderUpdateSuccess()
|
||||
|
||||
|
||||
@admin_router.post("/providers/test")
|
||||
async def test_voice_provider(
|
||||
request: VoiceProviderTestRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderUpdateSuccess:
|
||||
"""Test a voice provider connection by making a real API call."""
|
||||
api_key = request.api_key
|
||||
|
||||
if request.use_stored_key:
|
||||
existing_provider = fetch_voice_provider_by_type(
|
||||
db_session, request.provider_type
|
||||
)
|
||||
if existing_provider is None or not existing_provider.api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No stored API key found for this provider type.",
|
||||
)
|
||||
api_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
|
||||
if not api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API key is required. Either provide api_key or set use_stored_key to true.",
|
||||
)
|
||||
|
||||
# Use target_uri if provided, otherwise fall back to api_base
|
||||
api_base = _validate_voice_api_base(
|
||||
request.provider_type, request.target_uri or request.api_base
|
||||
)
|
||||
|
||||
# Create a temporary VoiceProvider for testing (not saved to DB)
|
||||
temp_provider = VoiceProvider(
|
||||
name="__test__",
|
||||
provider_type=request.provider_type,
|
||||
api_base=api_base,
|
||||
custom_config=request.custom_config or {},
|
||||
)
|
||||
temp_provider.api_key = api_key # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(temp_provider)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
|
||||
|
||||
# Validate credentials with a real API call
|
||||
try:
|
||||
await provider.validate_credentials()
|
||||
except OnyxError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Voice provider connection test failed: {e}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Connection test failed. Please verify your API key and settings.",
|
||||
) from e
|
||||
|
||||
logger.info(f"Voice provider test succeeded for {request.provider_type}.")
|
||||
return VoiceProviderUpdateSuccess()
|
||||
|
||||
|
||||
@admin_router.get("/providers/{provider_id}/voices")
|
||||
def get_provider_voices(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[VoiceOption]:
|
||||
"""Get available voices for a provider."""
|
||||
provider_db = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider_db is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Voice provider not found.")
|
||||
|
||||
if not provider_db.api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR, "Provider has no API key configured."
|
||||
)
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
|
||||
|
||||
return [VoiceOption(**voice) for voice in provider.get_available_voices()]
|
||||
|
||||
|
||||
@admin_router.get("/voices")
|
||||
def get_voices_by_type(
|
||||
provider_type: str,
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> list[VoiceOption]:
|
||||
"""Get available voices for a provider type.
|
||||
|
||||
For providers like ElevenLabs and OpenAI, this fetches voices
|
||||
without requiring an existing provider configuration.
|
||||
"""
|
||||
# Create a temporary VoiceProvider to get static voice list
|
||||
temp_provider = VoiceProvider(
|
||||
name="__temp__",
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(temp_provider)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
|
||||
|
||||
return [VoiceOption(**voice) for voice in provider.get_available_voices()]
|
||||
95
backend/onyx/server/manage/voice/models.py
Normal file
95
backend/onyx/server/manage/voice/models.py
Normal file
@@ -0,0 +1,95 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class VoiceProviderView(BaseModel):
|
||||
"""Response model for voice provider listing."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
provider_type: str # "openai", "azure", "elevenlabs"
|
||||
is_default_stt: bool
|
||||
is_default_tts: bool
|
||||
stt_model: str | None
|
||||
tts_model: str | None
|
||||
default_voice: str | None
|
||||
has_api_key: bool = Field(
|
||||
default=False,
|
||||
description="Indicates whether an API key is stored for this provider.",
|
||||
)
|
||||
target_uri: str | None = Field(
|
||||
default=None,
|
||||
description="Target URI for Azure Speech Services.",
|
||||
)
|
||||
|
||||
|
||||
class VoiceProviderUpdateSuccess(BaseModel):
|
||||
"""Simple status response for voice provider actions."""
|
||||
|
||||
status: str = "ok"
|
||||
|
||||
|
||||
class VoiceOption(BaseModel):
|
||||
"""Voice option returned by voice providers."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class VoiceProviderUpsertRequest(BaseModel):
|
||||
"""Request model for creating or updating a voice provider."""
|
||||
|
||||
id: int | None = Field(default=None, description="Existing provider ID to update.")
|
||||
name: str
|
||||
provider_type: str # "openai", "azure", "elevenlabs"
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for the provider.",
|
||||
)
|
||||
api_key_changed: bool = Field(
|
||||
default=False,
|
||||
description="Set to true when providing a new API key for an existing provider.",
|
||||
)
|
||||
llm_provider_id: int | None = Field(
|
||||
default=None,
|
||||
description="If set, copies the API key from the specified LLM provider.",
|
||||
)
|
||||
api_base: str | None = None
|
||||
target_uri: str | None = Field(
|
||||
default=None,
|
||||
description="Target URI for Azure Speech Services (maps to api_base).",
|
||||
)
|
||||
custom_config: dict[str, Any] | None = None
|
||||
stt_model: str | None = None
|
||||
tts_model: str | None = None
|
||||
default_voice: str | None = None
|
||||
activate_stt: bool = Field(
|
||||
default=False,
|
||||
description="If true, sets this provider as the default STT provider after upsert.",
|
||||
)
|
||||
activate_tts: bool = Field(
|
||||
default=False,
|
||||
description="If true, sets this provider as the default TTS provider after upsert.",
|
||||
)
|
||||
|
||||
|
||||
class VoiceProviderTestRequest(BaseModel):
|
||||
"""Request model for testing a voice provider connection."""
|
||||
|
||||
provider_type: str
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for testing. If not provided, use_stored_key must be true.",
|
||||
)
|
||||
use_stored_key: bool = Field(
|
||||
default=False,
|
||||
description="If true, use the stored API key for this provider type.",
|
||||
)
|
||||
api_base: str | None = None
|
||||
target_uri: str | None = Field(
|
||||
default=None,
|
||||
description="Target URI for Azure Speech Services (maps to api_base).",
|
||||
)
|
||||
custom_config: dict[str, Any] | None = None
|
||||
250
backend/onyx/server/manage/voice/user_api.py
Normal file
250
backend/onyx/server/manage/voice/user_api.py
Normal file
@@ -0,0 +1,250 @@
|
||||
import secrets
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import User
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.db.voice import update_user_voice_settings
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.redis.redis_pool import store_ws_token
|
||||
from onyx.redis.redis_pool import WsTokenRateLimitExceeded
|
||||
from onyx.server.manage.models import VoiceSettingsUpdateRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/voice")
|
||||
|
||||
# Max audio file size: 25MB (Whisper limit)
|
||||
MAX_AUDIO_SIZE = 25 * 1024 * 1024
|
||||
# Chunk size for streaming uploads (8KB)
|
||||
UPLOAD_READ_CHUNK_SIZE = 8192
|
||||
|
||||
|
||||
class VoiceStatusResponse(BaseModel):
|
||||
stt_enabled: bool
|
||||
tts_enabled: bool
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
def get_voice_status(
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceStatusResponse:
|
||||
"""Check whether STT and TTS providers are configured and ready."""
|
||||
stt_provider = fetch_default_stt_provider(db_session)
|
||||
tts_provider = fetch_default_tts_provider(db_session)
|
||||
return VoiceStatusResponse(
|
||||
stt_enabled=stt_provider is not None and stt_provider.api_key is not None,
|
||||
tts_enabled=tts_provider is not None and tts_provider.api_key is not None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/transcribe")
|
||||
async def transcribe_audio(
|
||||
audio: UploadFile = File(...),
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Transcribe audio to text using the default STT provider."""
|
||||
provider_db = fetch_default_stt_provider(db_session)
|
||||
if provider_db is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No speech-to-text provider configured. Please contact your administrator.",
|
||||
)
|
||||
|
||||
if not provider_db.api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Voice provider API key not configured.",
|
||||
)
|
||||
|
||||
# Read in chunks to enforce size limit during streaming (prevents OOM attacks)
|
||||
chunks: list[bytes] = []
|
||||
total = 0
|
||||
while chunk := await audio.read(UPLOAD_READ_CHUNK_SIZE):
|
||||
total += len(chunk)
|
||||
if total > MAX_AUDIO_SIZE:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.PAYLOAD_TOO_LARGE,
|
||||
f"Audio file too large. Maximum size is {MAX_AUDIO_SIZE // (1024 * 1024)}MB.",
|
||||
)
|
||||
chunks.append(chunk)
|
||||
audio_data = b"".join(chunks)
|
||||
|
||||
# Extract format from filename
|
||||
filename = audio.filename or "audio.webm"
|
||||
audio_format = filename.rsplit(".", 1)[-1] if "." in filename else "webm"
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
|
||||
|
||||
try:
|
||||
text = await provider.transcribe(audio_data, audio_format)
|
||||
return {"text": text}
|
||||
except NotImplementedError as exc:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_IMPLEMENTED,
|
||||
f"Speech-to-text not implemented for {provider_db.provider_type}.",
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
logger.error(f"Transcription failed: {exc}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Transcription failed. Please try again.",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post("/synthesize")
|
||||
async def synthesize_speech(
|
||||
text: str | None = Query(
|
||||
default=None, description="Text to synthesize", max_length=4096
|
||||
),
|
||||
voice: str | None = Query(default=None, description="Voice ID to use"),
|
||||
speed: float | None = Query(
|
||||
default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
|
||||
),
|
||||
user: User = Depends(current_user),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Synthesize text to speech using the default TTS provider.
|
||||
|
||||
Accepts parameters via query string for streaming compatibility.
|
||||
"""
|
||||
logger.info(
|
||||
f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
|
||||
)
|
||||
|
||||
if not text:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Text is required")
|
||||
|
||||
# Use short-lived session to fetch provider config, then release connection
|
||||
# before starting the long-running streaming response
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
provider_db = fetch_default_tts_provider(db_session)
|
||||
if provider_db is None:
|
||||
logger.error("No TTS provider configured")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No text-to-speech provider configured. Please contact your administrator.",
|
||||
)
|
||||
|
||||
if not provider_db.api_key:
|
||||
logger.error("TTS provider has no API key")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Voice provider API key not configured.",
|
||||
)
|
||||
|
||||
# Use request voice or provider default
|
||||
final_voice = voice or provider_db.default_voice
|
||||
# Use explicit None checks to avoid falsy float issues (0.0 would be skipped with `or`)
|
||||
final_speed = (
|
||||
speed
|
||||
if speed is not None
|
||||
else (
|
||||
user.voice_playback_speed
|
||||
if user.voice_playback_speed is not None
|
||||
else 1.0
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"TTS using provider: {provider_db.provider_type}, voice: {final_voice}, speed: {final_speed}"
|
||||
)
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
except ValueError as exc:
|
||||
logger.error(f"Failed to get voice provider: {exc}")
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
|
||||
|
||||
# Session is now closed - streaming response won't hold DB connection
|
||||
async def audio_stream() -> AsyncIterator[bytes]:
|
||||
try:
|
||||
chunk_count = 0
|
||||
async for chunk in provider.synthesize_stream(
|
||||
text=text, voice=final_voice, speed=final_speed
|
||||
):
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
|
||||
except NotImplementedError as exc:
|
||||
logger.error(f"TTS not implemented: {exc}")
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Synthesis failed: {exc}")
|
||||
raise
|
||||
|
||||
return StreamingResponse(
|
||||
audio_stream(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Content-Disposition": "inline; filename=speech.mp3",
|
||||
# Allow streaming by not setting content-length
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/settings")
|
||||
def update_voice_settings(
|
||||
request: VoiceSettingsUpdateRequest,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Update user's voice settings."""
|
||||
update_user_voice_settings(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
auto_send=request.auto_send,
|
||||
auto_playback=request.auto_playback,
|
||||
playback_speed=request.playback_speed,
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
class WSTokenResponse(BaseModel):
    """Short-lived token used to authenticate voice WebSocket connections."""

    # Opaque URL-safe token; passed as a query parameter on WS connect.
    token: str
|
||||
|
||||
|
||||
@router.post("/ws-token")
|
||||
async def get_ws_token(
|
||||
user: User = Depends(current_user),
|
||||
) -> WSTokenResponse:
|
||||
"""
|
||||
Generate a short-lived token for WebSocket authentication.
|
||||
|
||||
This token should be passed as a query parameter when connecting
|
||||
to voice WebSocket endpoints (e.g., /voice/transcribe/stream?token=xxx).
|
||||
|
||||
The token expires after 60 seconds and is single-use.
|
||||
Rate limited to 10 tokens per minute per user.
|
||||
"""
|
||||
token = secrets.token_urlsafe(32)
|
||||
try:
|
||||
await store_ws_token(token, str(user.id))
|
||||
except WsTokenRateLimitExceeded:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.RATE_LIMITED,
|
||||
"Too many token requests. Please wait before requesting another.",
|
||||
)
|
||||
return WSTokenResponse(token=token)
|
||||
860
backend/onyx/server/manage/voice/websocket_api.py
Normal file
860
backend/onyx/server/manage/voice/websocket_api.py
Normal file
@@ -0,0 +1,860 @@
|
||||
"""WebSocket API for streaming speech-to-text and text-to-speech."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import WebSocket
|
||||
from fastapi import WebSocketDisconnect
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user_from_websocket
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
from onyx.db.models import User
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
|
||||
logger = setup_logger()

# Mounted under /voice alongside the REST endpoints.
router = APIRouter(prefix="/voice")


# Transcribe every ~0.5 seconds of audio (webm/opus is ~2-4KB/s, so ~1-2KB per 0.5s)
MIN_CHUNK_BYTES = 1500
# Escape hatch: when "true", refuse to fall back to chunked (batch) mode if a
# provider's native streaming path fails or is unsupported.
VOICE_DISABLE_STREAMING_FALLBACK = (
    os.environ.get("VOICE_DISABLE_STREAMING_FALLBACK", "").lower() == "true"
)

# WebSocket size limits to prevent memory exhaustion attacks
WS_MAX_MESSAGE_SIZE = 64 * 1024  # 64KB per message (OWASP recommendation)
WS_MAX_TOTAL_BYTES = 25 * 1024 * 1024  # 25MB total per connection (matches REST API)
WS_MAX_TEXT_MESSAGE_SIZE = 16 * 1024  # 16KB for text/JSON messages
WS_MAX_TTS_TEXT_LENGTH = 4096  # Max text length per synthesize call (matches REST API)
|
||||
|
||||
|
||||
class ChunkedTranscriber:
    """Fallback transcriber for providers without streaming support.

    Buffers incoming audio twice: a rolling chunk buffer that is transcribed
    (and cleared) every MIN_CHUNK_BYTES for incremental feedback, and the
    full recording, which is re-transcribed once at the end for accuracy.
    """

    def __init__(self, provider: Any, audio_format: str = "webm"):
        self.provider = provider
        self.audio_format = audio_format
        self.chunk_buffer = io.BytesIO()
        self.full_audio = io.BytesIO()
        self.chunk_bytes = 0
        self.transcripts: list[str] = []

    def _reset_chunk_buffer(self) -> None:
        """Discard the rolling chunk buffer after a transcription attempt."""
        self.chunk_buffer = io.BytesIO()
        self.chunk_bytes = 0

    async def add_chunk(self, chunk: bytes) -> str | None:
        """Add audio chunk. Returns transcript if enough audio accumulated."""
        for buf in (self.chunk_buffer, self.full_audio):
            buf.write(chunk)
        self.chunk_bytes += len(chunk)

        if self.chunk_bytes < MIN_CHUNK_BYTES:
            return None
        return await self._transcribe_chunk()

    async def _transcribe_chunk(self) -> str | None:
        """Transcribe current chunk and append to running transcript."""
        pending = self.chunk_buffer.getvalue()
        if not pending:
            return None

        try:
            piece = await self.provider.transcribe(pending, self.audio_format)
        except Exception as e:
            # Errors on an interim chunk are non-fatal: log, drop the
            # buffered audio, and keep going.
            logger.error(f"Transcription error: {e}")
            self._reset_chunk_buffer()
            return None

        self._reset_chunk_buffer()
        cleaned = piece.strip() if piece else ""
        if not cleaned:
            return None
        self.transcripts.append(cleaned)
        return " ".join(self.transcripts)

    async def flush(self) -> str:
        """Get final transcript from full audio for best accuracy."""
        everything = self.full_audio.getvalue()
        if everything:
            try:
                final = await self.provider.transcribe(
                    everything, self.audio_format
                )
                if final and final.strip():
                    return final.strip()
            except Exception as e:
                logger.error(f"Final transcription error: {e}")
        # Fall back to the concatenated interim transcripts.
        return " ".join(self.transcripts)
|
||||
|
||||
|
||||
async def handle_streaming_transcription(
    websocket: WebSocket,
    transcriber: StreamingTranscriberProtocol,
) -> None:
    """Handle transcription using native streaming API.

    Pumps binary audio from the client into the transcriber while a
    background task relays interim transcripts back. Ends when the client
    disconnects, sends {"type": "end"}, or exceeds a size limit.
    """
    logger.info("Streaming transcription: starting handler")
    last_transcript = ""
    chunk_count = 0
    total_bytes = 0

    async def receive_transcripts() -> None:
        """Background task to receive and send transcripts."""
        nonlocal last_transcript
        logger.info("Streaming transcription: starting transcript receiver")
        while True:
            result: TranscriptResult | None = await transcriber.receive_transcript()
            if result is None:  # End of stream
                logger.info("Streaming transcription: transcript stream ended")
                break
            # Send if text changed OR if VAD detected end of speech (for auto-send trigger)
            if result.text and (result.text != last_transcript or result.is_vad_end):
                last_transcript = result.text
                logger.debug(
                    f"Streaming transcription: got transcript: {result.text[:50]}... "
                    f"(is_vad_end={result.is_vad_end})"
                )
                await websocket.send_json(
                    {
                        "type": "transcript",
                        "text": result.text,
                        "is_final": result.is_vad_end,
                    }
                )

    # Start receiving transcripts in background
    receive_task = asyncio.create_task(receive_transcripts())

    try:
        while True:
            message = await websocket.receive()
            msg_type = message.get("type", "unknown")

            if msg_type == "websocket.disconnect":
                logger.info(
                    f"Streaming transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
                )
                break

            if "bytes" in message:
                chunk_size = len(message["bytes"])

                # Enforce per-message size limit
                if chunk_size > WS_MAX_MESSAGE_SIZE:
                    logger.warning(
                        f"Streaming transcription: message too large ({chunk_size} bytes)"
                    )
                    await websocket.send_json(
                        {"type": "error", "message": "Message too large"}
                    )
                    break

                # Enforce total connection size limit
                if total_bytes + chunk_size > WS_MAX_TOTAL_BYTES:
                    logger.warning(
                        f"Streaming transcription: total size limit exceeded ({total_bytes + chunk_size} bytes)"
                    )
                    await websocket.send_json(
                        {"type": "error", "message": "Total size limit exceeded"}
                    )
                    break

                chunk_count += 1
                total_bytes += chunk_size
                logger.debug(
                    f"Streaming transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
                )
                await transcriber.send_audio(message["bytes"])

            elif "text" in message:
                # Control messages are JSON; malformed JSON is logged and ignored.
                try:
                    data = json.loads(message["text"])
                    logger.debug(
                        f"Streaming transcription: received text message: {data}"
                    )
                    if data.get("type") == "end":
                        logger.info(
                            "Streaming transcription: end signal received, closing transcriber"
                        )
                        # Closing the transcriber drains/finalizes; the receiver
                        # task is cancelled so it doesn't race the final send.
                        final_transcript = await transcriber.close()
                        receive_task.cancel()
                        logger.info(
                            "Streaming transcription: final transcript: "
                            f"{final_transcript[:100] if final_transcript else '(empty)'}..."
                        )
                        await websocket.send_json(
                            {
                                "type": "transcript",
                                "text": final_transcript,
                                "is_final": True,
                            }
                        )
                        break
                    elif data.get("type") == "reset":
                        # Reset accumulated transcript after auto-send
                        logger.info(
                            "Streaming transcription: reset signal received, clearing transcript"
                        )
                        transcriber.reset_transcript()
                except json.JSONDecodeError:
                    logger.warning(
                        f"Streaming transcription: failed to parse JSON: {message.get('text', '')[:100]}"
                    )
    except Exception as e:
        logger.error(f"Streaming transcription: error: {e}", exc_info=True)
        raise
    finally:
        # Always cancel and await the receiver so no task outlives the handler.
        receive_task.cancel()
        try:
            await receive_task
        except asyncio.CancelledError:
            pass
        logger.info(
            f"Streaming transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
        )
|
||||
|
||||
|
||||
async def handle_chunked_transcription(
    websocket: WebSocket,
    transcriber: ChunkedTranscriber,
) -> None:
    """Handle transcription using chunked batch API.

    Feeds binary audio into the ChunkedTranscriber, relaying interim
    transcripts as they become available, and sends one final transcript
    (is_final=True) when the client signals {"type": "end"}.
    """
    logger.info("Chunked transcription: starting handler")
    chunk_count = 0
    total_bytes = 0

    while True:
        message = await websocket.receive()
        msg_type = message.get("type", "unknown")

        if msg_type == "websocket.disconnect":
            logger.info(
                f"Chunked transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
            )
            break

        if "bytes" in message:
            chunk_size = len(message["bytes"])

            # Enforce per-message size limit
            if chunk_size > WS_MAX_MESSAGE_SIZE:
                logger.warning(
                    f"Chunked transcription: message too large ({chunk_size} bytes)"
                )
                await websocket.send_json(
                    {"type": "error", "message": "Message too large"}
                )
                break

            # Enforce total connection size limit
            if total_bytes + chunk_size > WS_MAX_TOTAL_BYTES:
                logger.warning(
                    f"Chunked transcription: total size limit exceeded ({total_bytes + chunk_size} bytes)"
                )
                await websocket.send_json(
                    {"type": "error", "message": "Total size limit exceeded"}
                )
                break

            chunk_count += 1
            total_bytes += chunk_size
            logger.debug(
                f"Chunked transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
            )

            # add_chunk returns a running transcript only once enough audio
            # has accumulated; otherwise None and we just keep buffering.
            transcript = await transcriber.add_chunk(message["bytes"])
            if transcript:
                logger.debug(
                    f"Chunked transcription: got transcript: {transcript[:50]}..."
                )
                await websocket.send_json(
                    {
                        "type": "transcript",
                        "text": transcript,
                        "is_final": False,
                    }
                )

        elif "text" in message:
            try:
                data = json.loads(message["text"])
                logger.debug(f"Chunked transcription: received text message: {data}")
                if data.get("type") == "end":
                    logger.info("Chunked transcription: end signal received, flushing")
                    # flush() re-transcribes the full recording for accuracy.
                    final_transcript = await transcriber.flush()
                    logger.info(
                        f"Chunked transcription: final transcript: {final_transcript[:100] if final_transcript else '(empty)'}..."
                    )
                    await websocket.send_json(
                        {
                            "type": "transcript",
                            "text": final_transcript,
                            "is_final": True,
                        }
                    )
                    break
            except json.JSONDecodeError:
                logger.warning(
                    f"Chunked transcription: failed to parse JSON: {message.get('text', '')[:100]}"
                )

    logger.info(
        f"Chunked transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
    )
|
||||
|
||||
|
||||
@router.websocket("/transcribe/stream")
|
||||
async def websocket_transcribe(
|
||||
websocket: WebSocket,
|
||||
_user: User = Depends(current_user_from_websocket),
|
||||
) -> None:
|
||||
"""
|
||||
WebSocket endpoint for streaming speech-to-text.
|
||||
|
||||
Protocol:
|
||||
- Client sends binary audio chunks
|
||||
- Server sends JSON: {"type": "transcript", "text": "...", "is_final": false}
|
||||
- Client sends JSON {"type": "end"} to signal end
|
||||
- Server responds with final transcript and closes
|
||||
|
||||
Authentication:
|
||||
Requires `token` query parameter (e.g., /voice/transcribe/stream?token=xxx).
|
||||
Applies same auth checks as HTTP endpoints (verification, role checks).
|
||||
"""
|
||||
logger.info("WebSocket transcribe: connection request received (authenticated)")
|
||||
|
||||
try:
|
||||
await websocket.accept()
|
||||
logger.info("WebSocket transcribe: connection accepted")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket transcribe: failed to accept connection: {e}")
|
||||
return
|
||||
|
||||
streaming_transcriber = None
|
||||
provider = None
|
||||
|
||||
try:
|
||||
# Get STT provider
|
||||
logger.info("WebSocket transcribe: fetching STT provider from database")
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
provider_db = fetch_default_stt_provider(db_session)
|
||||
if provider_db is None:
|
||||
logger.warning(
|
||||
"WebSocket transcribe: no default STT provider configured"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "No speech-to-text provider configured",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
if not provider_db.api_key:
|
||||
logger.warning("WebSocket transcribe: STT provider has no API key")
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Speech-to-text provider has no API key configured",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"WebSocket transcribe: creating voice provider: {provider_db.provider_type}"
|
||||
)
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
logger.info(
|
||||
f"WebSocket transcribe: voice provider created, streaming supported: {provider.supports_streaming_stt()}"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"WebSocket transcribe: failed to create voice provider: {e}"
|
||||
)
|
||||
await websocket.send_json({"type": "error", "message": str(e)})
|
||||
return
|
||||
|
||||
# Use native streaming if provider supports it
|
||||
if provider.supports_streaming_stt():
|
||||
logger.info("WebSocket transcribe: using native streaming STT")
|
||||
try:
|
||||
streaming_transcriber = await provider.create_streaming_transcriber()
|
||||
logger.info(
|
||||
"WebSocket transcribe: streaming transcriber created successfully"
|
||||
)
|
||||
await handle_streaming_transcription(websocket, streaming_transcriber)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"WebSocket transcribe: failed to create streaming transcriber: {e}"
|
||||
)
|
||||
if VOICE_DISABLE_STREAMING_FALLBACK:
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": f"Streaming STT failed: {e}"}
|
||||
)
|
||||
return
|
||||
logger.info("WebSocket transcribe: falling back to chunked STT")
|
||||
# Browser stream provides raw PCM16 chunks over WebSocket.
|
||||
chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
|
||||
await handle_chunked_transcription(websocket, chunked_transcriber)
|
||||
else:
|
||||
# Fall back to chunked transcription
|
||||
if VOICE_DISABLE_STREAMING_FALLBACK:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Provider doesn't support streaming STT",
|
||||
}
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"WebSocket transcribe: using chunked STT (provider doesn't support streaming)"
|
||||
)
|
||||
chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
|
||||
await handle_chunked_transcription(websocket, chunked_transcriber)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.debug("WebSocket transcribe: client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket transcribe: unhandled error: {e}", exc_info=True)
|
||||
try:
|
||||
# Send generic error to avoid leaking sensitive details
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "An unexpected error occurred"}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
if streaming_transcriber:
|
||||
try:
|
||||
await streaming_transcriber.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("WebSocket transcribe: connection closed")
|
||||
|
||||
|
||||
async def handle_streaming_synthesis(
    websocket: WebSocket,
    synthesizer: StreamingSynthesizerProtocol,
) -> None:
    """Handle TTS using native streaming API.

    Forwards incoming text chunks to the synthesizer while a background task
    relays synthesized audio bytes back to the client, finishing with a JSON
    {"type": "audio_done"} marker. The receiver task is started lazily on the
    first text chunk so audio can begin before the full input arrives.
    """
    logger.info("Streaming synthesis: starting handler")

    async def send_audio() -> None:
        """Background task to send audio chunks to client."""
        chunk_count = 0
        total_bytes = 0
        try:
            while True:
                audio_chunk = await synthesizer.receive_audio()
                # None signals end-of-stream from the synthesizer.
                if audio_chunk is None:
                    logger.info(
                        f"Streaming synthesis: audio stream ended, sent {chunk_count} chunks, {total_bytes} bytes"
                    )
                    try:
                        await websocket.send_json({"type": "audio_done"})
                        logger.info("Streaming synthesis: sent audio_done to client")
                    except Exception as e:
                        logger.warning(
                            f"Streaming synthesis: failed to send audio_done: {e}"
                        )
                    break
                if audio_chunk:  # Skip empty chunks
                    chunk_count += 1
                    total_bytes += len(audio_chunk)
                    try:
                        await websocket.send_bytes(audio_chunk)
                    except Exception as e:
                        # Client likely went away mid-stream; stop sending.
                        logger.warning(
                            f"Streaming synthesis: failed to send chunk: {e}"
                        )
                        break
        except asyncio.CancelledError:
            logger.info(
                f"Streaming synthesis: send_audio cancelled after {chunk_count} chunks"
            )
        except Exception as e:
            logger.error(f"Streaming synthesis: send_audio error: {e}")

    send_task: asyncio.Task | None = None
    disconnected = False

    try:
        while not disconnected:
            try:
                message = await websocket.receive()
            except WebSocketDisconnect:
                logger.info("Streaming synthesis: client disconnected")
                break

            msg_type = message.get("type", "unknown")  # type: ignore[possibly-undefined]

            if msg_type == "websocket.disconnect":
                logger.info("Streaming synthesis: client disconnected")
                disconnected = True
                break

            if "text" in message:
                # Enforce text message size limit
                msg_size = len(message["text"])
                if msg_size > WS_MAX_TEXT_MESSAGE_SIZE:
                    logger.warning(
                        f"Streaming synthesis: text message too large ({msg_size} bytes)"
                    )
                    await websocket.send_json(
                        {"type": "error", "message": "Message too large"}
                    )
                    break

                try:
                    data = json.loads(message["text"])

                    if data.get("type") == "synthesize":
                        text = data.get("text", "")
                        # Enforce per-text size limit
                        if len(text) > WS_MAX_TTS_TEXT_LENGTH:
                            logger.warning(
                                f"Streaming synthesis: text too long ({len(text)} chars)"
                            )
                            await websocket.send_json(
                                {"type": "error", "message": "Text too long"}
                            )
                            continue
                        if text:
                            # Start audio receiver on first text chunk so playback
                            # can begin before the full assistant response completes.
                            if send_task is None:
                                send_task = asyncio.create_task(send_audio())
                            logger.debug(
                                f"Streaming synthesis: forwarding text chunk ({len(text)} chars)"
                            )
                            await synthesizer.send_text(text)

                    elif data.get("type") == "end":
                        logger.info("Streaming synthesis: end signal received")

                        # Ensure receiver is active even if no prior text chunks arrived.
                        if send_task is None:
                            send_task = asyncio.create_task(send_audio())

                        # Signal end of input
                        if hasattr(synthesizer, "flush"):
                            await synthesizer.flush()

                        # Wait for all audio to be sent
                        logger.info(
                            "Streaming synthesis: waiting for audio stream to complete"
                        )
                        try:
                            await asyncio.wait_for(send_task, timeout=60.0)
                        except asyncio.TimeoutError:
                            logger.warning(
                                "Streaming synthesis: timeout waiting for audio"
                            )
                        break

                except json.JSONDecodeError:
                    logger.warning(
                        f"Streaming synthesis: failed to parse JSON: {message.get('text', '')[:100]}"
                    )

    except WebSocketDisconnect:
        logger.debug("Streaming synthesis: client disconnected during synthesis")
    except Exception as e:
        logger.error(f"Streaming synthesis: error: {e}", exc_info=True)
    finally:
        # Give a still-running sender a grace period, then cancel it so no
        # task outlives the handler.
        if send_task and not send_task.done():
            logger.info("Streaming synthesis: waiting for send_task to finish")
            try:
                await asyncio.wait_for(send_task, timeout=30.0)
            except asyncio.TimeoutError:
                logger.warning("Streaming synthesis: timeout waiting for send_task")
                send_task.cancel()
                try:
                    await send_task
                except asyncio.CancelledError:
                    pass
            except asyncio.CancelledError:
                pass
        logger.info("Streaming synthesis: handler finished")
|
||||
|
||||
|
||||
async def handle_chunked_synthesis(
    websocket: WebSocket,
    provider: Any,
    first_message: MutableMapping[str, Any] | None = None,
) -> None:
    """Fallback TTS handler using provider.synthesize_stream.

    Buffers all "synthesize" text chunks and performs a single synthesis
    pass when the client sends {"type": "end"}, then emits audio bytes
    followed by {"type": "audio_done"}.

    Args:
        websocket: The WebSocket connection
        provider: Voice provider instance
        first_message: Optional first message already received (used when falling
            back from streaming mode, where the first message was already consumed)
    """
    logger.info("Chunked synthesis: starting handler")
    text_buffer: list[str] = []
    voice: str | None = None
    speed = 1.0

    # Process pre-received message if provided
    pending_message = first_message

    try:
        while True:
            if pending_message is not None:
                message = pending_message
                pending_message = None
            else:
                message = await websocket.receive()
            msg_type = message.get("type", "unknown")

            if msg_type == "websocket.disconnect":
                logger.info("Chunked synthesis: client disconnected")
                break

            # Binary frames are not part of this protocol; ignore them.
            if "text" not in message:
                continue

            # Enforce text message size limit
            msg_size = len(message["text"])
            if msg_size > WS_MAX_TEXT_MESSAGE_SIZE:
                logger.warning(
                    f"Chunked synthesis: text message too large ({msg_size} bytes)"
                )
                await websocket.send_json(
                    {"type": "error", "message": "Message too large"}
                )
                break

            try:
                data = json.loads(message["text"])
            except json.JSONDecodeError:
                logger.warning(
                    "Chunked synthesis: failed to parse JSON: "
                    f"{message.get('text', '')[:100]}"
                )
                continue

            msg_data_type = data.get("type")  # type: ignore[possibly-undefined]
            if msg_data_type == "synthesize":
                text = data.get("text", "")
                # Enforce per-text size limit
                if len(text) > WS_MAX_TTS_TEXT_LENGTH:
                    logger.warning(
                        f"Chunked synthesis: text too long ({len(text)} chars)"
                    )
                    await websocket.send_json(
                        {"type": "error", "message": "Text too long"}
                    )
                    continue
                if text:
                    text_buffer.append(text)
                    logger.debug(
                        f"Chunked synthesis: buffered text ({len(text)} chars), "
                        f"total buffered: {len(text_buffer)} chunks"
                    )
                # Voice/speed may be updated on any synthesize message;
                # the last values seen win.
                if isinstance(data.get("voice"), str) and data["voice"]:
                    voice = data["voice"]
                if isinstance(data.get("speed"), (int, float)):
                    speed = float(data["speed"])
            elif msg_data_type == "end":
                logger.info("Chunked synthesis: end signal received")
                full_text = " ".join(text_buffer).strip()
                if not full_text:
                    await websocket.send_json({"type": "audio_done"})
                    logger.info("Chunked synthesis: no text, sent audio_done")
                    break

                chunk_count = 0
                total_bytes = 0
                logger.info(
                    f"Chunked synthesis: sending full text ({len(full_text)} chars)"
                )
                async for audio_chunk in provider.synthesize_stream(
                    full_text, voice=voice, speed=speed
                ):
                    if not audio_chunk:
                        continue
                    chunk_count += 1
                    total_bytes += len(audio_chunk)
                    await websocket.send_bytes(audio_chunk)
                await websocket.send_json({"type": "audio_done"})
                logger.info(
                    f"Chunked synthesis: sent audio_done after {chunk_count} chunks, {total_bytes} bytes"
                )
                break
    except WebSocketDisconnect:
        logger.debug("Chunked synthesis: client disconnected")
    except Exception as e:
        logger.error(f"Chunked synthesis: error: {e}", exc_info=True)
        raise
    finally:
        logger.info("Chunked synthesis: handler finished")
|
||||
|
||||
|
||||
@router.websocket("/synthesize/stream")
async def websocket_synthesize(
    websocket: WebSocket,
    _user: User = Depends(current_user_from_websocket),
) -> None:
    """
    WebSocket endpoint for streaming text-to-speech.

    Protocol:
    - Client sends JSON: {"type": "synthesize", "text": "...", "voice": "...", "speed": 1.0}
    - Server sends binary audio chunks
    - Server sends JSON: {"type": "audio_done"} when synthesis completes
    - Client sends JSON {"type": "end"} to close connection

    Authentication:
        Requires `token` query parameter (e.g., /voice/synthesize/stream?token=xxx).
        Applies same auth checks as HTTP endpoints (verification, role checks).
    """
    logger.info("WebSocket synthesize: connection request received (authenticated)")

    # Accept first so any subsequent error can be reported over the socket.
    try:
        await websocket.accept()
        logger.info("WebSocket synthesize: connection accepted")
    except Exception as e:
        logger.error(f"WebSocket synthesize: failed to accept connection: {e}")
        return

    # Tracked at function scope so the finally-block can always clean them up.
    streaming_synthesizer: StreamingSynthesizerProtocol | None = None
    provider = None

    try:
        # Get TTS provider
        logger.info("WebSocket synthesize: fetching TTS provider from database")
        engine = get_sqlalchemy_engine()
        with Session(engine) as db_session:
            provider_db = fetch_default_tts_provider(db_session)
            if provider_db is None:
                logger.warning(
                    "WebSocket synthesize: no default TTS provider configured"
                )
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "No text-to-speech provider configured",
                    }
                )
                return

            if not provider_db.api_key:
                logger.warning("WebSocket synthesize: TTS provider has no API key")
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "Text-to-speech provider has no API key configured",
                    }
                )
                return

            logger.info(
                f"WebSocket synthesize: creating voice provider: {provider_db.provider_type}"
            )
            # Provider construction happens inside the session because it may
            # read lazy attributes (e.g. the SensitiveValue api_key) off the row.
            try:
                provider = get_voice_provider(provider_db)
                logger.info(
                    f"WebSocket synthesize: voice provider created, streaming TTS supported: {provider.supports_streaming_tts()}"
                )
            except ValueError as e:
                logger.error(
                    f"WebSocket synthesize: failed to create voice provider: {e}"
                )
                await websocket.send_json({"type": "error", "message": str(e)})
                return

        # Use native streaming if provider supports it
        if provider.supports_streaming_tts():
            logger.info("WebSocket synthesize: using native streaming TTS")
            message = None  # Initialize to avoid UnboundLocalError in except block
            try:
                # Wait for initial config message with voice/speed
                message = await websocket.receive()
                voice = None
                speed = 1.0
                if "text" in message:
                    try:
                        data = json.loads(message["text"])
                        voice = data.get("voice")
                        speed = data.get("speed", 1.0)
                    except json.JSONDecodeError:
                        # Non-JSON first frame: fall back to default voice/speed.
                        pass

                streaming_synthesizer = await provider.create_streaming_synthesizer(
                    voice=voice, speed=speed
                )
                logger.info(
                    "WebSocket synthesize: streaming synthesizer created successfully"
                )
                await handle_streaming_synthesis(websocket, streaming_synthesizer)
            except Exception as e:
                # NOTE(review): this broad except also catches failures raised
                # mid-stream by handle_streaming_synthesis, not only synthesizer
                # creation — the fallback below re-handles the session chunked.
                logger.error(
                    f"WebSocket synthesize: failed to create streaming synthesizer: {e}"
                )
                if VOICE_DISABLE_STREAMING_FALLBACK:
                    await websocket.send_json(
                        {"type": "error", "message": f"Streaming TTS failed: {e}"}
                    )
                    return
                logger.info(
                    "WebSocket synthesize: falling back to chunked TTS synthesis"
                )
                # Pass the first message so it's not lost in the fallback
                await handle_chunked_synthesis(
                    websocket, provider, first_message=message
                )
        else:
            if VOICE_DISABLE_STREAMING_FALLBACK:
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "Provider doesn't support streaming TTS",
                    }
                )
                return
            logger.info(
                "WebSocket synthesize: using chunked TTS (provider doesn't support streaming)"
            )
            await handle_chunked_synthesis(websocket, provider)

    except WebSocketDisconnect:
        logger.debug("WebSocket synthesize: client disconnected")
    except Exception as e:
        logger.error(f"WebSocket synthesize: unhandled error: {e}", exc_info=True)
        try:
            # Send generic error to avoid leaking sensitive details
            await websocket.send_json(
                {"type": "error", "message": "An unexpected error occurred"}
            )
        except Exception:
            # Socket may already be gone; nothing more to report.
            pass
    finally:
        # Best-effort teardown: synthesizer first, then the socket itself.
        if streaming_synthesizer:
            try:
                await streaming_synthesizer.close()
            except Exception:
                pass
        try:
            await websocket.close()
        except Exception:
            pass
        logger.info("WebSocket synthesize: connection closed")
|
||||
@@ -1,4 +1,3 @@
|
||||
import datetime
|
||||
from typing import Generic
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
@@ -32,38 +31,21 @@ class MinimalUserSnapshot(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class UserGroupInfo(BaseModel):
    """Minimal (id, name) projection of a user group for user-listing responses."""

    id: int
    name: str
|
||||
|
||||
|
||||
class FullUserSnapshot(BaseModel):
    """Full admin-facing snapshot of a user account.

    Mirrors the columns read off the ``User`` ORM model plus the user's
    group memberships (as :class:`UserGroupInfo` entries).
    """

    id: UUID
    email: str
    role: UserRole
    is_active: bool
    password_configured: bool
    personal_name: str | None
    created_at: datetime.datetime
    updated_at: datetime.datetime
    groups: list[UserGroupInfo]

    @classmethod
    def from_user_model(
        cls,
        user: User,
        groups: list[UserGroupInfo] | None = None,
    ) -> "FullUserSnapshot":
        """Build a snapshot from a ``User`` ORM row.

        The pasted source contained two conflicting ``from_user_model``
        signatures (diff residue); this keeps the groups-aware one, which is
        backward compatible — omitting ``groups`` yields an empty list.

        Args:
            user: The ORM user row to snapshot.
            groups: Optional pre-fetched group memberships; defaults to [].
        """
        return cls(
            id=user.id,
            email=user.email,
            role=user.role,
            is_active=user.is_active,
            password_configured=user.password_configured,
            personal_name=user.personal_name,
            created_at=user.created_at,
            updated_at=user.updated_at,
            groups=groups or [],
        )
|
||||
|
||||
|
||||
|
||||
@@ -140,6 +140,44 @@ def _validate_and_resolve_url(url: str) -> tuple[str, str, int]:
|
||||
return validated_ip, hostname, port
|
||||
|
||||
|
||||
def validate_outbound_http_url(url: str, *, allow_private_network: bool = False) -> str:
    """
    Validate a URL that will be used by backend outbound HTTP calls.

    Runs a sequence of guard checks: non-empty, http(s)-only scheme, hostname
    present, no embedded credentials, hostname not on the blocklist, and —
    unless ``allow_private_network`` is set — DNS resolution to a public
    address via ``_validate_and_resolve_url``.

    Returns:
        A normalized URL string with surrounding whitespace removed.

    Raises:
        ValueError: If URL is malformed.
        SSRFException: If URL fails SSRF checks.
    """
    cleaned = url.strip()
    if not cleaned:
        raise ValueError("URL cannot be empty")

    parsed = urlparse(cleaned)

    # Reject anything that is not plain http(s) (file://, gopher://, etc.).
    if parsed.scheme not in ("http", "https"):
        raise SSRFException(
            f"Invalid URL scheme '{parsed.scheme}'. Only http and https are allowed."
        )

    if not parsed.hostname:
        raise ValueError("URL must contain a hostname")

    # user:pass@host URLs are a classic SSRF/credential-smuggling vector.
    if parsed.username or parsed.password:
        raise SSRFException("URLs with embedded credentials are not allowed.")

    # Blocklist comparison is case-insensitive; the error echoes the original casing.
    if parsed.hostname.lower() in BLOCKED_HOSTNAMES:
        raise SSRFException(f"Access to hostname '{parsed.hostname}' is not allowed.")

    # Resolve and verify the target IP unless private networks are explicitly allowed.
    if not allow_private_network:
        _validate_and_resolve_url(cleaned)

    return cleaned
|
||||
|
||||
|
||||
MAX_REDIRECTS = 10
|
||||
|
||||
|
||||
|
||||
0
backend/onyx/voice/__init__.py
Normal file
0
backend/onyx/voice/__init__.py
Normal file
70
backend/onyx/voice/factory.py
Normal file
70
backend/onyx/voice/factory.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
|
||||
def get_voice_provider(provider: VoiceProvider) -> VoiceProviderInterface:
    """
    Factory function to get the appropriate voice provider implementation.

    Args:
        provider: VoiceProvider model instance (can be from DB or constructed temporarily)

    Returns:
        VoiceProviderInterface implementation

    Raises:
        ValueError: If provider_type is not supported
    """
    kind = provider.provider_type.lower()

    # Unwrap the API key: DB rows carry a SensitiveValue, temporary models a plain str.
    raw_key = provider.api_key
    if raw_key is not None and hasattr(raw_key, "get_value"):
        resolved_key = raw_key.get_value(apply_mask=False)
    else:
        resolved_key = raw_key  # type: ignore[assignment]

    # Constructor arguments shared by every provider implementation.
    shared_kwargs = dict(
        api_key=resolved_key,
        api_base=provider.api_base,
        stt_model=provider.stt_model,
        tts_model=provider.tts_model,
        default_voice=provider.default_voice,
    )

    # Imports stay inside each branch so only the selected provider module loads.
    if kind == "openai":
        from onyx.voice.providers.openai import OpenAIVoiceProvider

        return OpenAIVoiceProvider(**shared_kwargs)

    if kind == "azure":
        from onyx.voice.providers.azure import AzureVoiceProvider

        return AzureVoiceProvider(
            custom_config=provider.custom_config or {},
            **shared_kwargs,
        )

    if kind == "elevenlabs":
        from onyx.voice.providers.elevenlabs import ElevenLabsVoiceProvider

        return ElevenLabsVoiceProvider(**shared_kwargs)

    raise ValueError(f"Unsupported voice provider type: {kind}")
|
||||
182
backend/onyx/voice/interface.py
Normal file
182
backend/onyx/voice/interface.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TranscriptResult(BaseModel):
    """Result from streaming transcription.

    Carries the transcript accumulated so far plus a voice-activity-detection
    flag so callers can decide when to auto-send the utterance.
    """

    text: str
    """The accumulated transcript text."""

    is_vad_end: bool = False
    """True if VAD detected end of speech (silence). Use for auto-send."""
|
||||
|
||||
|
||||
class StreamingTranscriberProtocol(Protocol):
    """Protocol for streaming transcription sessions.

    Structural (duck-typed) interface: implementations push audio in via
    ``send_audio`` and pull incremental results via ``receive_transcript``.
    """

    async def send_audio(self, chunk: bytes) -> None:
        """Send an audio chunk for transcription."""
        ...

    async def receive_transcript(self) -> TranscriptResult | None:
        """
        Receive next transcript update.

        Returns:
            TranscriptResult with accumulated text and VAD status, or None when stream ends.
        """
        ...

    async def close(self) -> str:
        """Close the session and return final transcript."""
        ...

    def reset_transcript(self) -> None:
        """Reset accumulated transcript. Call after auto-send to start fresh."""
        ...
|
||||
|
||||
|
||||
class StreamingSynthesizerProtocol(Protocol):
    """Protocol for streaming TTS sessions (real-time text-to-speech).

    Structural (duck-typed) interface: text goes in via ``send_text`` and
    audio bytes come back via ``receive_audio`` until the stream ends.
    """

    async def connect(self) -> None:
        """Establish connection to TTS provider."""
        ...

    async def send_text(self, text: str) -> None:
        """Send text to be synthesized."""
        ...

    async def receive_audio(self) -> bytes | None:
        """
        Receive next audio chunk.

        Returns:
            Audio bytes, or None when stream ends.
        """
        ...

    async def flush(self) -> None:
        """Signal end of text input and wait for pending audio."""
        ...

    async def close(self) -> None:
        """Close the session."""
        ...
|
||||
|
||||
|
||||
class VoiceProviderInterface(ABC):
    """Abstract base class for voice providers (STT and TTS).

    Concrete providers must implement the batch transcribe/synthesize and
    metadata methods; streaming support is opt-in via the ``supports_*``
    predicates and the ``create_streaming_*`` factories (which raise
    NotImplementedError by default).
    """

    @abstractmethod
    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        """
        Convert audio to text (Speech-to-Text).

        Args:
            audio_data: Raw audio bytes
            audio_format: Audio format (e.g., "webm", "wav", "mp3")

        Returns:
            Transcribed text
        """

    @abstractmethod
    def synthesize_stream(
        self, text: str, voice: str | None = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio stream (Text-to-Speech).

        Streams audio chunks progressively for lower latency playback.

        Declared as a sync method returning an AsyncIterator so implementations
        can be written as ``async def`` generators.

        Args:
            text: Text to convert to speech
            voice: Voice identifier (e.g., "alloy", "echo"), or None for default
            speed: Playback speed multiplier (0.25 to 4.0)

        Yields:
            Audio data chunks
        """

    @abstractmethod
    async def validate_credentials(self) -> None:
        """
        Validate that the provider credentials are correct by making a
        lightweight API call. Raises on failure.
        """

    @abstractmethod
    def get_available_voices(self) -> list[dict[str, str]]:
        """
        Get list of available voices for this provider.

        Returns:
            List of voice dictionaries with 'id' and 'name' keys
        """

    @abstractmethod
    def get_available_stt_models(self) -> list[dict[str, str]]:
        """
        Get list of available STT models for this provider.

        Returns:
            List of model dictionaries with 'id' and 'name' keys
        """

    @abstractmethod
    def get_available_tts_models(self) -> list[dict[str, str]]:
        """
        Get list of available TTS models for this provider.

        Returns:
            List of model dictionaries with 'id' and 'name' keys
        """

    def supports_streaming_stt(self) -> bool:
        """Returns True if this provider supports streaming STT."""
        return False

    def supports_streaming_tts(self) -> bool:
        """Returns True if this provider supports real-time streaming TTS."""
        return False

    async def create_streaming_transcriber(
        self, audio_format: str = "webm"
    ) -> StreamingTranscriberProtocol:
        """
        Create a streaming transcription session.

        Args:
            audio_format: Audio format being sent (e.g., "webm", "pcm16")

        Returns:
            A streaming transcriber that can send audio chunks and receive transcripts

        Raises:
            NotImplementedError: If streaming STT is not supported
        """
        raise NotImplementedError("Streaming STT not supported by this provider")

    async def create_streaming_synthesizer(
        self, voice: str | None = None, speed: float = 1.0
    ) -> "StreamingSynthesizerProtocol":
        """
        Create a streaming TTS session for real-time audio synthesis.

        Args:
            voice: Voice identifier
            speed: Playback speed multiplier

        Returns:
            A streaming synthesizer that can send text and receive audio chunks

        Raises:
            NotImplementedError: If streaming TTS is not supported
        """
        raise NotImplementedError("Streaming TTS not supported by this provider")
|
||||
0
backend/onyx/voice/providers/__init__.py
Normal file
0
backend/onyx/voice/providers/__init__.py
Normal file
626
backend/onyx/voice/providers/azure.py
Normal file
626
backend/onyx/voice/providers/azure.py
Normal file
@@ -0,0 +1,626 @@
|
||||
"""Azure Speech Services voice provider for STT and TTS.
|
||||
|
||||
Azure supports:
|
||||
- **STT**: Batch transcription via REST API (audio/wav POST) and real-time
|
||||
streaming via the Azure Speech SDK (push audio stream with continuous
|
||||
recognition). The SDK handles VAD natively through its recognizing/recognized
|
||||
events.
|
||||
- **TTS**: SSML-based synthesis via REST API (streaming response) and real-time
|
||||
synthesis via the Speech SDK. Text is escaped with ``xml.sax.saxutils.escape``
|
||||
and attributes with ``quoteattr`` to prevent SSML injection.
|
||||
|
||||
Both modes support Azure cloud endpoints (region-based URLs) and self-hosted
|
||||
Speech containers (custom endpoint URLs). The ``speech_region`` is validated to
|
||||
contain only ``[a-z0-9-]`` to prevent URL injection.
|
||||
|
||||
The Azure Speech SDK (``azure-cognitiveservices-speech``) is an optional C
|
||||
extension dependency — it is imported lazily inside streaming methods so the
|
||||
provider can still be instantiated and used for REST-based operations without it.
|
||||
|
||||
See https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/
|
||||
for API reference.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import re
|
||||
import struct
|
||||
import wave
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
from xml.sax.saxutils import escape
|
||||
from xml.sax.saxutils import quoteattr
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
# SSML namespace — W3C standard for Speech Synthesis Markup Language.
|
||||
# This is a fixed W3C specification and will not change.
|
||||
SSML_NAMESPACE = "http://www.w3.org/2001/10/synthesis"
|
||||
|
||||
# Common Azure Neural voices
|
||||
AZURE_VOICES = [
|
||||
{"id": "en-US-JennyNeural", "name": "Jenny (en-US, Female)"},
|
||||
{"id": "en-US-GuyNeural", "name": "Guy (en-US, Male)"},
|
||||
{"id": "en-US-AriaNeural", "name": "Aria (en-US, Female)"},
|
||||
{"id": "en-US-DavisNeural", "name": "Davis (en-US, Male)"},
|
||||
{"id": "en-US-AmberNeural", "name": "Amber (en-US, Female)"},
|
||||
{"id": "en-US-AnaNeural", "name": "Ana (en-US, Female)"},
|
||||
{"id": "en-US-BrandonNeural", "name": "Brandon (en-US, Male)"},
|
||||
{"id": "en-US-ChristopherNeural", "name": "Christopher (en-US, Male)"},
|
||||
{"id": "en-US-CoraNeural", "name": "Cora (en-US, Female)"},
|
||||
{"id": "en-GB-SoniaNeural", "name": "Sonia (en-GB, Female)"},
|
||||
{"id": "en-GB-RyanNeural", "name": "Ryan (en-GB, Male)"},
|
||||
]
|
||||
|
||||
|
||||
class AzureStreamingTranscriber(StreamingTranscriberProtocol):
    """Streaming transcription using Azure Speech SDK.

    Audio is pushed into an SDK PushAudioInputStream; the SDK fires
    recognizing/recognized callbacks on its own worker thread, which are
    marshaled back onto the asyncio loop via ``call_soon_threadsafe``.
    """

    def __init__(
        self,
        api_key: str,
        region: str | None = None,
        endpoint: str | None = None,
        input_sample_rate: int = 24000,
        target_sample_rate: int = 16000,
    ):
        self.api_key = api_key
        self.region = region
        self.endpoint = endpoint
        # Incoming PCM rate from the client vs. the rate Azure is fed (16 kHz).
        self.input_sample_rate = input_sample_rate
        self.target_sample_rate = target_sample_rate
        self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
        self._accumulated_transcript = ""
        # SDK objects are typed Any: the SDK import is lazy (see connect()).
        self._recognizer: Any = None
        self._audio_stream: Any = None
        self._closed = False
        self._loop: asyncio.AbstractEventLoop | None = None

    async def connect(self) -> None:
        """Initialize Azure Speech recognizer with push stream."""
        # Lazy import: azure-cognitiveservices-speech is an optional dependency.
        try:
            import azure.cognitiveservices.speech as speechsdk  # type: ignore
        except ImportError as e:
            raise RuntimeError(
                "Azure Speech SDK is required for streaming STT. "
                "Install `azure-cognitiveservices-speech`."
            ) from e

        # Remember the loop so SDK-thread callbacks can post into the queue.
        self._loop = asyncio.get_running_loop()

        # Use endpoint for self-hosted containers, region for Azure cloud
        if self.endpoint:
            speech_config = speechsdk.SpeechConfig(
                subscription=self.api_key,
                endpoint=self.endpoint,
            )
        else:
            speech_config = speechsdk.SpeechConfig(
                subscription=self.api_key,
                region=self.region,
            )

        # PCM16 mono @ 16 kHz — matches target_sample_rate used by the resampler.
        audio_format = speechsdk.audio.AudioStreamFormat(
            samples_per_second=16000,
            bits_per_sample=16,
            channels=1,
        )
        self._audio_stream = speechsdk.audio.PushAudioInputStream(audio_format)
        audio_config = speechsdk.audio.AudioConfig(stream=self._audio_stream)

        self._recognizer = speechsdk.SpeechRecognizer(
            speech_config=speech_config,
            audio_config=audio_config,
        )

        # Bound in the closure instead of `self` for clarity inside callbacks.
        transcriber = self

        def on_recognizing(evt: Any) -> None:
            # Interim hypothesis (runs on SDK thread): report accumulated +
            # in-progress text, but do NOT commit it to the accumulator.
            if evt.result.text and transcriber._loop and not transcriber._closed:
                full_text = transcriber._accumulated_transcript
                if full_text:
                    full_text += " " + evt.result.text
                else:
                    full_text = evt.result.text
                transcriber._loop.call_soon_threadsafe(
                    transcriber._transcript_queue.put_nowait,
                    TranscriptResult(text=full_text, is_vad_end=False),
                )

        def on_recognized(evt: Any) -> None:
            # Final utterance result (runs on SDK thread): commit to the
            # accumulator and flag VAD end so the caller can auto-send.
            if evt.result.text and transcriber._loop and not transcriber._closed:
                if transcriber._accumulated_transcript:
                    transcriber._accumulated_transcript += " " + evt.result.text
                else:
                    transcriber._accumulated_transcript = evt.result.text
                transcriber._loop.call_soon_threadsafe(
                    transcriber._transcript_queue.put_nowait,
                    TranscriptResult(
                        text=transcriber._accumulated_transcript, is_vad_end=True
                    ),
                )

        self._recognizer.recognizing.connect(on_recognizing)
        self._recognizer.recognized.connect(on_recognized)
        self._recognizer.start_continuous_recognition_async()

    async def send_audio(self, chunk: bytes) -> None:
        """Send audio chunk to Azure.

        Assumes `chunk` is raw PCM16 little-endian mono at input_sample_rate
        (implied by the struct "<h" unpacking in the resampler).
        """
        if self._audio_stream and not self._closed:
            self._audio_stream.write(self._resample_pcm16(chunk))

    def _resample_pcm16(self, data: bytes) -> bytes:
        """Resample PCM16 audio from input_sample_rate to target_sample_rate.

        Uses linear interpolation between neighboring samples; output is
        clamped to the int16 range.
        """
        if self.input_sample_rate == self.target_sample_rate:
            return data

        num_samples = len(data) // 2
        if num_samples == 0:
            return b""

        samples = list(struct.unpack(f"<{num_samples}h", data))
        ratio = self.input_sample_rate / self.target_sample_rate
        new_length = int(num_samples / ratio)

        resampled: list[int] = []
        for i in range(new_length):
            # Fractional source position for output sample i.
            src_idx = i * ratio
            idx_floor = int(src_idx)
            idx_ceil = min(idx_floor + 1, num_samples - 1)
            frac = src_idx - idx_floor
            sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
            # Clamp to int16 so struct.pack cannot overflow.
            sample = max(-32768, min(32767, sample))
            resampled.append(sample)

        return struct.pack(f"<{len(resampled)}h", *resampled)

    async def receive_transcript(self) -> TranscriptResult | None:
        """Receive next transcript.

        Polls with a 100 ms timeout; on timeout an empty, non-VAD result is
        returned so the caller's loop keeps ticking instead of blocking.
        """
        try:
            return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return TranscriptResult(text="", is_vad_end=False)

    async def close(self) -> str:
        """Stop recognition and return final transcript.

        NOTE(review): stop_continuous_recognition_async returns an SDK future
        that is not awaited here — shutdown is fire-and-forget; confirm that
        is intended.
        """
        self._closed = True
        if self._recognizer:
            self._recognizer.stop_continuous_recognition_async()
        if self._audio_stream:
            self._audio_stream.close()
        self._loop = None
        return self._accumulated_transcript

    def reset_transcript(self) -> None:
        """Reset accumulated transcript."""
        self._accumulated_transcript = ""
|
||||
|
||||
|
||||
class AzureStreamingSynthesizer(StreamingSynthesizerProtocol):
    """Real-time streaming TTS using Azure Speech SDK.

    Text is wrapped in SSML (escaped to prevent injection) and synthesized by
    the SDK; audio chunks arrive on the SDK's worker thread and are marshaled
    onto the asyncio loop via ``call_soon_threadsafe``.
    """

    def __init__(
        self,
        api_key: str,
        region: str | None = None,
        endpoint: str | None = None,
        voice: str = "en-US-JennyNeural",
        speed: float = 1.0,
    ):
        self._logger = setup_logger()
        self.api_key = api_key
        self.region = region
        self.endpoint = endpoint
        self.voice = voice
        # Clamped to 0.5–2.0 here (narrower than the 0.25–4.0 interface range).
        self.speed = max(0.5, min(2.0, speed))
        # None sentinel in the queue marks end-of-synthesis.
        self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
        # SDK object typed Any: the SDK import is lazy (see connect()).
        self._synthesizer: Any = None
        self._closed = False
        self._loop: asyncio.AbstractEventLoop | None = None

    async def connect(self) -> None:
        """Initialize Azure Speech synthesizer with push stream."""
        # Lazy import: azure-cognitiveservices-speech is an optional dependency.
        try:
            import azure.cognitiveservices.speech as speechsdk
        except ImportError as e:
            raise RuntimeError(
                "Azure Speech SDK is required for streaming TTS. "
                "Install `azure-cognitiveservices-speech`."
            ) from e

        self._logger.info("AzureStreamingSynthesizer: connecting")

        # Store the event loop for thread-safe queue operations
        self._loop = asyncio.get_running_loop()

        # Use endpoint for self-hosted containers, region for Azure cloud
        if self.endpoint:
            speech_config = speechsdk.SpeechConfig(
                subscription=self.api_key,
                endpoint=self.endpoint,
            )
        else:
            speech_config = speechsdk.SpeechConfig(
                subscription=self.api_key,
                region=self.region,
            )
        speech_config.speech_synthesis_voice_name = self.voice
        # Use MP3 format for streaming - compatible with MediaSource Extensions
        speech_config.set_speech_synthesis_output_format(
            speechsdk.SpeechSynthesisOutputFormat.Audio16Khz64KBitRateMonoMp3
        )

        # Create synthesizer with pull audio output stream
        self._synthesizer = speechsdk.SpeechSynthesizer(
            speech_config=speech_config,
            audio_config=None,  # We'll manually handle audio
        )

        # Connect to synthesis events
        self._synthesizer.synthesizing.connect(self._on_synthesizing)
        self._synthesizer.synthesis_completed.connect(self._on_completed)

        self._logger.info("AzureStreamingSynthesizer: connected")

    def _on_synthesizing(self, evt: Any) -> None:
        """Called when audio chunk is available (runs in Azure SDK thread)."""
        if evt.result.audio_data and self._loop and not self._closed:
            # Thread-safe way to put item in async queue
            self._loop.call_soon_threadsafe(
                self._audio_queue.put_nowait, evt.result.audio_data
            )

    def _on_completed(self, _evt: Any) -> None:
        """Called when synthesis is complete (runs in Azure SDK thread)."""
        # Queue the None sentinel so receive_audio can observe end-of-stream.
        if self._loop and not self._closed:
            self._loop.call_soon_threadsafe(self._audio_queue.put_nowait, None)

    async def send_text(self, text: str) -> None:
        """Send text to be synthesized using SSML for prosody control.

        `text` is XML-escaped and the voice name attribute-quoted to prevent
        SSML injection.
        """
        if self._synthesizer and not self._closed:
            # Build SSML with prosody for speed control
            # e.g. speed 1.25 -> rate '+25%', speed 0.8 -> rate '-20%'.
            rate = f"{int((self.speed - 1) * 100):+d}%"
            escaped_text = escape(text)
            ssml = f"""<speak version='1.0' xmlns='{SSML_NAMESPACE}' xml:lang='en-US'>
<voice name={quoteattr(self.voice)}>
<prosody rate='{rate}'>{escaped_text}</prosody>
</voice>
</speak>"""
            # Use speak_ssml_async for SSML support (includes speed/prosody)
            self._synthesizer.speak_ssml_async(ssml)

    async def receive_audio(self) -> bytes | None:
        """Receive next audio chunk.

        Polls with a 100 ms timeout; an empty bytes value means "no audio
        yet", while None (the queued sentinel) means synthesis finished.
        """
        try:
            return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return b""  # No audio yet, but not done

    async def flush(self) -> None:
        """Signal end of text input - wait for pending audio."""
        # Azure SDK handles flushing automatically

    async def close(self) -> None:
        """Close the session.

        Marks the session closed (callbacks check this flag) and detaches all
        SDK event handlers so no further queue writes occur.
        """
        self._closed = True
        if self._synthesizer:
            self._synthesizer.synthesis_completed.disconnect_all()
            self._synthesizer.synthesizing.disconnect_all()
        self._loop = None
|
||||
|
||||
|
||||
class AzureVoiceProvider(VoiceProviderInterface):
|
||||
"""Azure Speech Services voice provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None,
|
||||
api_base: str | None,
|
||||
custom_config: dict[str, Any],
|
||||
stt_model: str | None = None,
|
||||
tts_model: str | None = None,
|
||||
default_voice: str | None = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.custom_config = custom_config
|
||||
raw_speech_region = (
|
||||
custom_config.get("speech_region")
|
||||
or self._extract_speech_region_from_uri(api_base)
|
||||
or ""
|
||||
)
|
||||
self.speech_region = self._validate_speech_region(raw_speech_region)
|
||||
self.stt_model = stt_model
|
||||
self.tts_model = tts_model
|
||||
self.default_voice = default_voice or "en-US-JennyNeural"
|
||||
|
||||
@staticmethod
|
||||
def _is_azure_cloud_url(uri: str | None) -> bool:
|
||||
"""Check if URI is an Azure cloud endpoint (vs custom/self-hosted)."""
|
||||
if not uri:
|
||||
return False
|
||||
try:
|
||||
hostname = (urlparse(uri).hostname or "").lower()
|
||||
except ValueError:
|
||||
return False
|
||||
return hostname.endswith(
|
||||
(
|
||||
".speech.microsoft.com",
|
||||
".api.cognitive.microsoft.com",
|
||||
".cognitiveservices.azure.com",
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_speech_region_from_uri(uri: str | None) -> str | None:
|
||||
"""Extract Azure speech region from endpoint URI.
|
||||
|
||||
Note: Custom domains (*.cognitiveservices.azure.com) contain the resource
|
||||
name, not the region. For custom domains, the region must be specified
|
||||
explicitly via custom_config["speech_region"].
|
||||
"""
|
||||
if not uri:
|
||||
return None
|
||||
# Accepted examples:
|
||||
# - https://eastus.tts.speech.microsoft.com/cognitiveservices/v1
|
||||
# - https://eastus.stt.speech.microsoft.com/speech/recognition/...
|
||||
# - https://westus.api.cognitive.microsoft.com/
|
||||
#
|
||||
# NOT supported (requires explicit speech_region config):
|
||||
# - https://<resource>.cognitiveservices.azure.com/ (resource name != region)
|
||||
try:
|
||||
hostname = (urlparse(uri).hostname or "").lower()
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
stt_tts_match = re.match(
|
||||
r"^([a-z0-9-]+)\.(?:tts|stt)\.speech\.microsoft\.com$", hostname
|
||||
)
|
||||
if stt_tts_match:
|
||||
return stt_tts_match.group(1)
|
||||
|
||||
api_match = re.match(
|
||||
r"^([a-z0-9-]+)\.api\.cognitive\.microsoft\.com$", hostname
|
||||
)
|
||||
if api_match:
|
||||
return api_match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _validate_speech_region(speech_region: str) -> str:
|
||||
normalized_region = speech_region.strip().lower()
|
||||
if not normalized_region:
|
||||
return ""
|
||||
if not re.fullmatch(r"[a-z0-9-]+", normalized_region):
|
||||
raise ValueError(
|
||||
"Invalid Azure speech_region. Use lowercase letters, digits, and hyphens only."
|
||||
)
|
||||
return normalized_region
|
||||
|
||||
def _get_stt_url(self) -> str:
|
||||
"""Get the STT endpoint URL (auto-detects cloud vs self-hosted)."""
|
||||
if self.api_base and not self._is_azure_cloud_url(self.api_base):
|
||||
# Self-hosted container endpoint
|
||||
return f"{self.api_base.rstrip('/')}/speech/recognition/conversation/cognitiveservices/v1"
|
||||
# Azure cloud endpoint
|
||||
return (
|
||||
f"https://{self.speech_region}.stt.speech.microsoft.com/"
|
||||
"speech/recognition/conversation/cognitiveservices/v1"
|
||||
)
|
||||
|
||||
def _get_tts_url(self) -> str:
|
||||
"""Get the TTS endpoint URL (auto-detects cloud vs self-hosted)."""
|
||||
if self.api_base and not self._is_azure_cloud_url(self.api_base):
|
||||
# Self-hosted container endpoint
|
||||
return f"{self.api_base.rstrip('/')}/cognitiveservices/v1"
|
||||
# Azure cloud endpoint
|
||||
return f"https://{self.speech_region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
|
||||
def _is_self_hosted(self) -> bool:
|
||||
"""Check if using self-hosted container vs Azure cloud."""
|
||||
return bool(self.api_base and not self._is_azure_cloud_url(self.api_base))
|
||||
|
||||
@staticmethod
|
||||
def _pcm16_to_wav(pcm_data: bytes, sample_rate: int = 24000) -> bytes:
|
||||
"""Wrap raw PCM16 mono bytes into a WAV container."""
|
||||
buffer = io.BytesIO()
|
||||
with wave.open(buffer, "wb") as wav_file:
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setframerate(sample_rate)
|
||||
wav_file.writeframes(pcm_data)
|
||||
return buffer.getvalue()
|
||||
|
||||
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure API key required for STT")
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError("Azure speech region required for STT (cloud mode)")
|
||||
|
||||
normalized_format = audio_format.lower()
|
||||
payload = audio_data
|
||||
content_type = f"audio/{normalized_format}"
|
||||
|
||||
# WebSocket chunked fallback sends raw PCM16 bytes.
|
||||
if normalized_format in {"pcm", "pcm16", "raw"}:
|
||||
payload = self._pcm16_to_wav(audio_data, sample_rate=24000)
|
||||
content_type = "audio/wav"
|
||||
elif normalized_format in {"wav", "wave"}:
|
||||
content_type = "audio/wav"
|
||||
elif normalized_format == "webm":
|
||||
content_type = "audio/webm; codecs=opus"
|
||||
|
||||
url = self._get_stt_url()
|
||||
params = {"language": "en-US", "format": "detailed"}
|
||||
headers = {
|
||||
"Ocp-Apim-Subscription-Key": self.api_key,
|
||||
"Content-Type": content_type,
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url, params=params, headers=headers, data=payload
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(f"Azure STT failed: {error_text}")
|
||||
result = await response.json()
|
||||
|
||||
if result.get("RecognitionStatus") != "Success":
|
||||
return ""
|
||||
nbest = result.get("NBest") or []
|
||||
if nbest and isinstance(nbest, list):
|
||||
display = nbest[0].get("Display")
|
||||
if isinstance(display, str):
|
||||
return display
|
||||
display_text = result.get("DisplayText", "")
|
||||
return display_text if isinstance(display_text, str) else ""
|
||||
|
||||
    async def synthesize_stream(
        self, text: str, voice: str | None = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio using Azure TTS with streaming.

        Args:
            text: Text to convert to speech
            voice: Voice name (defaults to provider's default voice)
            speed: Playback speed multiplier (0.5 to 2.0; clamped below)

        Yields:
            Audio data chunks (mp3 format)

        Raises:
            ValueError: If the API key (or, in cloud mode, the region) is missing.
            RuntimeError: If the Azure request returns a non-200 status.
        """
        if not self.api_key:
            raise ValueError("Azure API key required for TTS")

        if not self._is_self_hosted() and not self.speech_region:
            raise ValueError("Azure speech region required for TTS (cloud mode)")

        # NOTE(review): may be None if neither argument nor default_voice is set
        # — presumably Azure then rejects the SSML; confirm desired behavior.
        voice_name = voice or self.default_voice

        # Clamp speed to valid range and convert to SSML prosody rate format
        speed = max(0.5, min(2.0, speed))
        rate = f"{int((speed - 1) * 100):+d}%"  # e.g., 1.0 -> "+0%", 1.5 -> "+50%"

        # Build SSML with escaped text and quoted attributes to prevent injection
        escaped_text = escape(text)
        ssml = f"""<speak version='1.0' xmlns='{SSML_NAMESPACE}' xml:lang='en-US'>
    <voice name={quoteattr(voice_name)}>
        <prosody rate='{rate}'>{escaped_text}</prosody>
    </voice>
</speak>"""

        url = self._get_tts_url()

        headers = {
            "Ocp-Apim-Subscription-Key": self.api_key,
            "Content-Type": "application/ssml+xml",
            "X-Microsoft-OutputFormat": "audio-16khz-128kbitrate-mono-mp3",
            "User-Agent": "Onyx",
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, data=ssml) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(f"Azure TTS failed: {error_text}")

                # Use 8192 byte chunks for smoother streaming
                async for chunk in response.content.iter_chunked(8192):
                    if chunk:
                        yield chunk
|
||||
|
||||
async def validate_credentials(self) -> None:
|
||||
"""Validate Azure credentials by listing available voices."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure API key required")
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError("Azure speech region required (cloud mode)")
|
||||
|
||||
url = f"https://{self.speech_region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
|
||||
if self._is_self_hosted():
|
||||
url = f"{(self.api_base or '').rstrip('/')}/cognitiveservices/voices/list"
|
||||
|
||||
headers = {"Ocp-Apim-Subscription-Key": self.api_key}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(
|
||||
f"Azure credential validation failed: {error_text}"
|
||||
)
|
||||
|
||||
def get_available_voices(self) -> list[dict[str, str]]:
|
||||
"""Return common Azure Neural voices."""
|
||||
return AZURE_VOICES.copy()
|
||||
|
||||
def get_available_stt_models(self) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"id": "default", "name": "Azure Speech Recognition"},
|
||||
]
|
||||
|
||||
def get_available_tts_models(self) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"id": "neural", "name": "Neural TTS"},
|
||||
]
|
||||
|
||||
def supports_streaming_stt(self) -> bool:
|
||||
"""Azure supports streaming STT via Speech SDK."""
|
||||
return True
|
||||
|
||||
def supports_streaming_tts(self) -> bool:
|
||||
"""Azure supports real-time streaming TTS via Speech SDK."""
|
||||
return True
|
||||
|
||||
async def create_streaming_transcriber(
|
||||
self, _audio_format: str = "webm"
|
||||
) -> AzureStreamingTranscriber:
|
||||
"""Create a streaming transcription session."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming transcription")
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError(
|
||||
"Speech region required for Azure streaming transcription (cloud mode)"
|
||||
)
|
||||
|
||||
# Use endpoint for self-hosted, region for cloud
|
||||
transcriber = AzureStreamingTranscriber(
|
||||
api_key=self.api_key,
|
||||
region=self.speech_region if not self._is_self_hosted() else None,
|
||||
endpoint=self.api_base if self._is_self_hosted() else None,
|
||||
input_sample_rate=24000,
|
||||
target_sample_rate=16000,
|
||||
)
|
||||
await transcriber.connect()
|
||||
return transcriber
|
||||
|
||||
async def create_streaming_synthesizer(
|
||||
self, voice: str | None = None, speed: float = 1.0
|
||||
) -> AzureStreamingSynthesizer:
|
||||
"""Create a streaming TTS session."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming TTS")
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError(
|
||||
"Speech region required for Azure streaming TTS (cloud mode)"
|
||||
)
|
||||
|
||||
# Use endpoint for self-hosted, region for cloud
|
||||
synthesizer = AzureStreamingSynthesizer(
|
||||
api_key=self.api_key,
|
||||
region=self.speech_region if not self._is_self_hosted() else None,
|
||||
endpoint=self.api_base if self._is_self_hosted() else None,
|
||||
voice=voice or self.default_voice or "en-US-JennyNeural",
|
||||
speed=speed,
|
||||
)
|
||||
await synthesizer.connect()
|
||||
return synthesizer
|
||||
858
backend/onyx/voice/providers/elevenlabs.py
Normal file
858
backend/onyx/voice/providers/elevenlabs.py
Normal file
@@ -0,0 +1,858 @@
|
||||
"""ElevenLabs voice provider for STT and TTS.
|
||||
|
||||
ElevenLabs supports:
|
||||
- **STT**: Scribe API (batch via REST, streaming via WebSocket with Scribe v2 Realtime).
|
||||
The streaming endpoint sends base64-encoded PCM16 audio chunks and receives JSON
|
||||
transcript messages (partial_transcript, committed_transcript, utterance_end).
|
||||
- **TTS**: Text-to-speech via REST streaming and WebSocket stream-input.
|
||||
The WebSocket variant accepts incremental text chunks and returns audio in order,
|
||||
enabling low-latency playback before the full text is available.
|
||||
|
||||
See https://elevenlabs.io/docs for API reference.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
# Default ElevenLabs API base URL (overridable via the provider's api_base).
DEFAULT_ELEVENLABS_API_BASE = "https://api.elevenlabs.io"

# Default sample rates for STT streaming
DEFAULT_INPUT_SAMPLE_RATE = 24000  # What the browser frontend sends
DEFAULT_TARGET_SAMPLE_RATE = 16000  # What ElevenLabs Scribe expects

# Default streaming TTS output format
# NOTE(review): ElevenLabsStreamingSynthesizer.__init__ hard-codes this same
# literal as its default instead of referencing this constant — confirm intent.
DEFAULT_TTS_OUTPUT_FORMAT = "mp3_44100_64"

# Default TTS voice settings
DEFAULT_VOICE_STABILITY = 0.5
DEFAULT_VOICE_SIMILARITY_BOOST = 0.75

# Chunk length schedule for streaming TTS (optimized for real-time playback)
DEFAULT_CHUNK_LENGTH_SCHEDULE = [120, 160, 250, 290]

# Default STT streaming VAD configuration (passed as query params on connect)
DEFAULT_VAD_SILENCE_THRESHOLD_SECS = 1.0
DEFAULT_VAD_THRESHOLD = 0.4
DEFAULT_MIN_SPEECH_DURATION_MS = 100
DEFAULT_MIN_SILENCE_DURATION_MS = 300
|
||||
|
||||
|
||||
class ElevenLabsSTTMessageType(StrEnum):
    """Message types from ElevenLabs Scribe Realtime STT API."""

    # Server acknowledges the session; carries session_id and config.
    SESSION_STARTED = "session_started"
    # Interim transcript, updated as more audio is processed.
    PARTIAL_TRANSCRIPT = "partial_transcript"
    # Final transcript for the current utterance (VAD detected its end).
    COMMITTED_TRANSCRIPT = "committed_transcript"
    # VAD detected end of speech; may or may not carry text.
    UTTERANCE_END = "utterance_end"
    # Server is closing the session.
    SESSION_ENDED = "session_ended"
    ERROR = "error"
|
||||
|
||||
|
||||
class ElevenLabsTTSMessageType(StrEnum):
    """Message types from ElevenLabs stream-input TTS API."""

    # Payload contains a base64-encoded audio chunk.
    AUDIO = "audio"
    ERROR = "error"
|
||||
|
||||
|
||||
def _http_to_ws_url(http_url: str) -> str:
|
||||
"""Convert http(s) URL to ws(s) URL for WebSocket connections."""
|
||||
if http_url.startswith("https://"):
|
||||
return "wss://" + http_url[8:]
|
||||
elif http_url.startswith("http://"):
|
||||
return "ws://" + http_url[7:]
|
||||
return http_url
|
||||
|
||||
|
||||
# Common ElevenLabs voices (id = ElevenLabs voice_id, name = display name).
ELEVENLABS_VOICES = [
    {"id": "21m00Tcm4TlvDq8ikWAM", "name": "Rachel"},
    {"id": "AZnzlk1XvdvUeBnXmlld", "name": "Domi"},
    {"id": "EXAVITQu4vr4xnSDxMaL", "name": "Bella"},
    {"id": "ErXwobaYiN019PkySvjV", "name": "Antoni"},
    {"id": "MF3mGyEYCl7XYWbV9V6O", "name": "Elli"},
    {"id": "TxGEqnHWrfWFTfGW9XjX", "name": "Josh"},
    {"id": "VR6AewLTigWG4xSOukaG", "name": "Arnold"},
    {"id": "pNInz6obpgDQGcFmaJgB", "name": "Adam"},
    {"id": "yoZ06aMxZJJ28mfd3POQ", "name": "Sam"},
]
|
||||
|
||||
|
||||
class ElevenLabsStreamingTranscriber(StreamingTranscriberProtocol):
    """Streaming transcription session using ElevenLabs Scribe Realtime API.

    Lifecycle: construct -> ``connect()`` -> repeated ``send_audio()`` /
    ``receive_transcript()`` -> ``close()``. A background task drains the
    WebSocket and pushes TranscriptResult items onto an internal queue;
    ``None`` on the queue signals end of stream.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "scribe_v2_realtime",
        input_sample_rate: int = DEFAULT_INPUT_SAMPLE_RATE,
        target_sample_rate: int = DEFAULT_TARGET_SAMPLE_RATE,
        language_code: str = "en",
        api_base: str | None = None,
    ):
        # Import logger first (local import; presumably avoids an import
        # cycle at module load — TODO confirm)
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()

        self._logger.info(
            f"ElevenLabsStreamingTranscriber: initializing with model {model}"
        )
        self.api_key = api_key
        self.model = model
        self.input_sample_rate = input_sample_rate
        self.target_sample_rate = target_sample_rate
        self.language_code = language_code
        self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._session: aiohttp.ClientSession | None = None
        # Queue of transcripts produced by _receive_loop; None marks the end.
        self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
        # Most recent transcript text; returned by close().
        self._final_transcript = ""
        self._receive_task: asyncio.Task | None = None
        self._closed = False

    async def connect(self) -> None:
        """Establish WebSocket connection to ElevenLabs and start the receive loop."""
        self._logger.info(
            "ElevenLabsStreamingTranscriber: connecting to ElevenLabs API"
        )
        self._session = aiohttp.ClientSession()

        # VAD is configured via query parameters.
        # commit_strategy=vad enables automatic transcript commit on silence detection.
        # These params are part of the ElevenLabs Scribe Realtime API contract:
        # https://elevenlabs.io/docs/api-reference/speech-to-text/realtime
        ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
        url = (
            f"{ws_base}/v1/speech-to-text/realtime"
            f"?model_id={self.model}"
            f"&sample_rate={self.target_sample_rate}"
            f"&language_code={self.language_code}"
            f"&commit_strategy=vad"
            f"&vad_silence_threshold_secs={DEFAULT_VAD_SILENCE_THRESHOLD_SECS}"
            f"&vad_threshold={DEFAULT_VAD_THRESHOLD}"
            f"&min_speech_duration_ms={DEFAULT_MIN_SPEECH_DURATION_MS}"
            f"&min_silence_duration_ms={DEFAULT_MIN_SILENCE_DURATION_MS}"
        )
        self._logger.info(
            f"ElevenLabsStreamingTranscriber: connecting to {url} "
            f"(input={self.input_sample_rate}Hz, target={self.target_sample_rate}Hz)"
        )

        try:
            self._ws = await self._session.ws_connect(
                url,
                headers={"xi-api-key": self.api_key},
            )
            self._logger.info(
                f"ElevenLabsStreamingTranscriber: connected successfully, "
                f"ws.closed={self._ws.closed}, close_code={self._ws.close_code}"
            )
        except Exception as e:
            # Clean up the session we just opened before propagating.
            self._logger.error(
                f"ElevenLabsStreamingTranscriber: failed to connect: {e}"
            )
            if self._session:
                await self._session.close()
            raise

        # Start receiving transcripts in background
        self._receive_task = asyncio.create_task(self._receive_loop())

    async def _receive_loop(self) -> None:
        """Background task to receive transcripts from WebSocket.

        Parses each TEXT frame as JSON, dispatches on its message_type, and
        pushes TranscriptResult items onto the queue. Always enqueues a final
        ``None`` sentinel when the loop ends, for any reason.
        """
        self._logger.info("ElevenLabsStreamingTranscriber: receive loop started")
        if not self._ws:
            self._logger.warning(
                "ElevenLabsStreamingTranscriber: no WebSocket connection"
            )
            return

        try:
            async for msg in self._ws:
                self._logger.debug(
                    f"ElevenLabsStreamingTranscriber: raw message type: {msg.type}"
                )
                if msg.type == aiohttp.WSMsgType.TEXT:
                    parsed_data: Any = None
                    data: dict[str, Any]
                    try:
                        parsed_data = json.loads(msg.data)
                    except json.JSONDecodeError:
                        self._logger.error(
                            f"ElevenLabsStreamingTranscriber: failed to parse JSON: {msg.data[:200]}"
                        )
                        continue
                    if not isinstance(parsed_data, dict):
                        self._logger.error(
                            "ElevenLabsStreamingTranscriber: expected object JSON payload"
                        )
                        continue
                    data = parsed_data

                    # ElevenLabs uses message_type field - fail fast if missing
                    # (some payloads use "type" instead; both are accepted below)
                    if "message_type" not in data and "type" not in data:
                        self._logger.error(
                            f"ElevenLabsStreamingTranscriber: malformed packet missing 'message_type' field: {data}"
                        )
                        continue
                    msg_type = data.get("message_type", data.get("type", ""))
                    self._logger.info(
                        f"ElevenLabsStreamingTranscriber: received message_type: '{msg_type}', data keys: {list(data.keys())}"
                    )
                    # Check for error in various formats; errors are logged and
                    # skipped rather than terminating the loop.
                    if "error" in data or msg_type == ElevenLabsSTTMessageType.ERROR:
                        error_msg = data.get("error", data.get("message", data))
                        self._logger.error(
                            f"ElevenLabsStreamingTranscriber: API error: {error_msg}"
                        )
                        continue

                    # Handle message types from ElevenLabs Scribe Realtime API.
                    # See https://elevenlabs.io/docs/api-reference/speech-to-text/realtime
                    if msg_type == ElevenLabsSTTMessageType.SESSION_STARTED:
                        self._logger.info(
                            f"ElevenLabsStreamingTranscriber: session started, "
                            f"id={data.get('session_id')}, config={data.get('config')}"
                        )
                    elif msg_type == ElevenLabsSTTMessageType.PARTIAL_TRANSCRIPT:
                        # Interim result — updated as more audio is processed
                        text = data.get("text", "")
                        if text:
                            self._logger.info(
                                f"ElevenLabsStreamingTranscriber: partial_transcript: {text[:50]}..."
                            )
                            self._final_transcript = text
                            await self._transcript_queue.put(
                                TranscriptResult(text=text, is_vad_end=False)
                            )
                    elif msg_type == ElevenLabsSTTMessageType.COMMITTED_TRANSCRIPT:
                        # Final transcript for the current utterance (VAD detected end)
                        text = data.get("text", "")
                        if text:
                            self._logger.info(
                                f"ElevenLabsStreamingTranscriber: committed_transcript: {text[:50]}..."
                            )
                            self._final_transcript = text
                            await self._transcript_queue.put(
                                TranscriptResult(text=text, is_vad_end=True)
                            )
                    elif msg_type == ElevenLabsSTTMessageType.UTTERANCE_END:
                        # VAD detected end of speech (may carry text or be empty;
                        # fall back to the last transcript seen)
                        text = data.get("text", "") or self._final_transcript
                        if text:
                            self._logger.info(
                                f"ElevenLabsStreamingTranscriber: utterance_end: {text[:50]}..."
                            )
                            self._final_transcript = text
                            await self._transcript_queue.put(
                                TranscriptResult(text=text, is_vad_end=True)
                            )
                    elif msg_type == ElevenLabsSTTMessageType.SESSION_ENDED:
                        self._logger.info(
                            "ElevenLabsStreamingTranscriber: session ended"
                        )
                        break
                    else:
                        # Log unhandled message types with full data for debugging
                        self._logger.warning(
                            f"ElevenLabsStreamingTranscriber: unhandled message_type: {msg_type}, full data: {data}"
                        )
                elif msg.type == aiohttp.WSMsgType.BINARY:
                    # Binary frames are not expected from this API; log and ignore.
                    self._logger.debug(
                        f"ElevenLabsStreamingTranscriber: received binary message: {len(msg.data)} bytes"
                    )
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    close_code = self._ws.close_code if self._ws else "N/A"
                    self._logger.info(
                        "ElevenLabsStreamingTranscriber: WebSocket closed by "
                        f"server, close_code={close_code}"
                    )
                    break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    self._logger.error(
                        f"ElevenLabsStreamingTranscriber: WebSocket error: {self._ws.exception() if self._ws else 'N/A'}"
                    )
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSE:
                    self._logger.info(
                        f"ElevenLabsStreamingTranscriber: WebSocket CLOSE frame received, data={msg.data}, extra={msg.extra}"
                    )
                    break
        except Exception as e:
            self._logger.error(
                f"ElevenLabsStreamingTranscriber: error in receive loop: {e}",
                exc_info=True,
            )
        finally:
            close_code = self._ws.close_code if self._ws else "N/A"
            self._logger.info(
                f"ElevenLabsStreamingTranscriber: receive loop ended, close_code={close_code}"
            )
            await self._transcript_queue.put(None)  # Signal end

    def _resample_pcm16(self, data: bytes) -> bytes:
        """Resample PCM16 audio from input_sample_rate to target_sample_rate.

        Uses simple linear interpolation; returns the input unchanged when the
        rates already match. Assumes little-endian mono 16-bit samples.
        """
        import struct

        if self.input_sample_rate == self.target_sample_rate:
            return data

        # Parse int16 samples
        num_samples = len(data) // 2
        samples = list(struct.unpack(f"<{num_samples}h", data))

        # Calculate resampling ratio
        ratio = self.input_sample_rate / self.target_sample_rate
        new_length = int(num_samples / ratio)

        # Linear interpolation resampling
        resampled = []
        for i in range(new_length):
            src_idx = i * ratio
            idx_floor = int(src_idx)
            idx_ceil = min(idx_floor + 1, num_samples - 1)
            frac = src_idx - idx_floor
            sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
            # Clamp to int16 range
            sample = max(-32768, min(32767, sample))
            resampled.append(sample)

        return struct.pack(f"<{len(resampled)}h", *resampled)

    async def send_audio(self, chunk: bytes) -> None:
        """Send a raw PCM16 audio chunk for transcription.

        Resamples to the target rate, base64-encodes, and forwards as an
        ``input_audio_chunk`` message. No-ops (with a warning) when the
        connection is absent or already closed.
        """
        if not self._ws:
            self._logger.warning("send_audio: no WebSocket connection")
            return
        if self._closed:
            self._logger.warning("send_audio: transcriber is closed")
            return
        if self._ws.closed:
            self._logger.warning(
                f"send_audio: WebSocket is closed, close_code={self._ws.close_code}"
            )
            return

        try:
            # Resample from input rate (24kHz) to target rate (16kHz)
            resampled = self._resample_pcm16(chunk)
            # ElevenLabs expects input_audio_chunk message format with audio_base_64
            audio_b64 = base64.b64encode(resampled).decode("utf-8")
            message = {
                "message_type": "input_audio_chunk",
                "audio_base_64": audio_b64,
                "sample_rate": self.target_sample_rate,
            }
            self._logger.info(
                f"send_audio: {len(chunk)} bytes -> {len(resampled)} bytes (resampled) -> {len(audio_b64)} chars base64"
            )
            await self._ws.send_str(json.dumps(message))
            self._logger.info("send_audio: message sent successfully")
        except Exception as e:
            self._logger.error(f"send_audio: failed to send: {e}", exc_info=True)
            raise

    async def receive_transcript(self) -> TranscriptResult | None:
        """Receive next transcript. Returns None when done.

        Polls the queue with a short timeout; on timeout returns an empty,
        non-final TranscriptResult so callers can keep their loop responsive.
        """
        try:
            return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return TranscriptResult(
                text="", is_vad_end=False
            )  # No transcript yet, but not done

    async def close(self) -> str:
        """Close the session and return the last transcript seen."""
        self._logger.info("ElevenLabsStreamingTranscriber: closing session")
        self._closed = True
        if self._ws and not self._ws.closed:
            try:
                # Just close the WebSocket - ElevenLabs Scribe doesn't need a special end message
                self._logger.info(
                    "ElevenLabsStreamingTranscriber: closing WebSocket connection"
                )
                await self._ws.close()
            except Exception as e:
                self._logger.debug(f"Error closing WebSocket: {e}")
        if self._receive_task and not self._receive_task.done():
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass
        if self._session and not self._session.closed:
            await self._session.close()
        return self._final_transcript

    def reset_transcript(self) -> None:
        """Reset accumulated transcript. Call after auto-send to start fresh."""
        self._final_transcript = ""
|
||||
|
||||
|
||||
class ElevenLabsStreamingSynthesizer(StreamingSynthesizerProtocol):
|
||||
"""Real-time streaming TTS using ElevenLabs WebSocket API.
|
||||
|
||||
Uses ElevenLabs' stream-input WebSocket which processes text as one
|
||||
continuous stream and returns audio in order.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
model_id: str = "eleven_multilingual_v2",
|
||||
output_format: str = "mp3_44100_64",
|
||||
api_base: str | None = None,
|
||||
speed: float = 1.0,
|
||||
):
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
self._logger = setup_logger()
|
||||
self.api_key = api_key
|
||||
self.voice_id = voice_id
|
||||
self.model_id = model_id
|
||||
self.output_format = output_format
|
||||
self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
|
||||
self.speed = speed
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._receive_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish WebSocket connection to ElevenLabs TTS."""
|
||||
self._logger.info("ElevenLabsStreamingSynthesizer: connecting")
|
||||
self._session = aiohttp.ClientSession()
|
||||
|
||||
# WebSocket URL for streaming input TTS with output format for streaming compatibility
|
||||
# Using mp3_44100_64 for good quality with smaller chunks for real-time playback
|
||||
ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
|
||||
url = (
|
||||
f"{ws_base}/v1/text-to-speech/{self.voice_id}/stream-input"
|
||||
f"?model_id={self.model_id}&output_format={self.output_format}"
|
||||
)
|
||||
|
||||
self._ws = await self._session.ws_connect(
|
||||
url,
|
||||
headers={"xi-api-key": self.api_key},
|
||||
)
|
||||
|
||||
# Send initial configuration with generation settings optimized for streaming.
|
||||
# Note: API key is sent via header only (not in body to avoid log exposure).
|
||||
# See https://elevenlabs.io/docs/api-reference/text-to-speech/stream-input
|
||||
await self._ws.send_str(
|
||||
json.dumps(
|
||||
{
|
||||
"text": " ", # Initial space to start the stream
|
||||
"voice_settings": {
|
||||
"stability": DEFAULT_VOICE_STABILITY,
|
||||
"similarity_boost": DEFAULT_VOICE_SIMILARITY_BOOST,
|
||||
"speed": self.speed,
|
||||
},
|
||||
"generation_config": {
|
||||
"chunk_length_schedule": DEFAULT_CHUNK_LENGTH_SCHEDULE,
|
||||
},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Start receiving audio in background
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
self._logger.info("ElevenLabsStreamingSynthesizer: connected")
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
"""Background task to receive audio chunks from WebSocket.
|
||||
|
||||
Audio is returned in order as one continuous stream.
|
||||
"""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
if self._closed:
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingSynthesizer: closed flag set, stopping "
|
||||
"receive loop"
|
||||
)
|
||||
break
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
data = json.loads(msg.data)
|
||||
# Process audio if present
|
||||
if "audio" in data and data["audio"]:
|
||||
audio_bytes = base64.b64decode(data["audio"])
|
||||
chunk_count += 1
|
||||
total_bytes += len(audio_bytes)
|
||||
await self._audio_queue.put(audio_bytes)
|
||||
|
||||
# Check isFinal separately - a message can have both audio AND isFinal
|
||||
if "isFinal" in data:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingSynthesizer: received isFinal={data['isFinal']}, "
|
||||
f"chunks so far: {chunk_count}, bytes: {total_bytes}"
|
||||
)
|
||||
if data.get("isFinal"):
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingSynthesizer: isFinal=true, signaling end of audio"
|
||||
)
|
||||
await self._audio_queue.put(None)
|
||||
|
||||
# Check for errors
|
||||
if "error" in data or data.get("type") == "error":
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingSynthesizer: received error: {data}"
|
||||
)
|
||||
elif msg.type == aiohttp.WSMsgType.BINARY:
|
||||
chunk_count += 1
|
||||
total_bytes += len(msg.data)
|
||||
await self._audio_queue.put(msg.data)
|
||||
elif msg.type in (
|
||||
aiohttp.WSMsgType.CLOSE,
|
||||
aiohttp.WSMsgType.ERROR,
|
||||
):
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingSynthesizer: WebSocket closed/error, type={msg.type}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.error(f"ElevenLabsStreamingSynthesizer receive error: {e}")
|
||||
finally:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingSynthesizer: receive loop ended, {chunk_count} chunks, {total_bytes} bytes"
|
||||
)
|
||||
await self._audio_queue.put(None) # Signal end of stream
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Send text to be synthesized.
|
||||
|
||||
ElevenLabs processes text as a continuous stream and returns
|
||||
audio in order. We let ElevenLabs handle buffering via chunk_length_schedule
|
||||
and only force generation when flush() is called at the end.
|
||||
|
||||
Args:
|
||||
text: Text to synthesize
|
||||
"""
|
||||
if self._ws and not self._closed and text.strip():
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingSynthesizer: sending text ({len(text)} chars): '{text}'"
|
||||
)
|
||||
# Let ElevenLabs buffer and auto-generate based on chunk_length_schedule
|
||||
# Don't trigger generation here - wait for flush() at the end
|
||||
await self._ws.send_str(
|
||||
json.dumps(
|
||||
{
|
||||
"text": text + " ", # Space for natural speech flow
|
||||
}
|
||||
)
|
||||
)
|
||||
self._logger.info("ElevenLabsStreamingSynthesizer: text sent successfully")
|
||||
else:
|
||||
self._logger.warning(
|
||||
f"ElevenLabsStreamingSynthesizer: skipping send_text - "
|
||||
f"ws={self._ws is not None}, closed={self._closed}, text='{text[:30] if text else ''}'"
|
||||
)
|
||||
|
||||
async def receive_audio(self) -> bytes | None:
|
||||
"""Receive next audio chunk."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return b"" # No audio yet, but not done
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Signal end of text input. ElevenLabs will generate remaining audio and close."""
|
||||
if self._ws and not self._closed:
|
||||
# Send empty string to signal end of input
|
||||
# ElevenLabs will generate any remaining buffered text,
|
||||
# send all audio chunks, send isFinal, then close the connection
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingSynthesizer: sending end-of-input (empty string)"
|
||||
)
|
||||
await self._ws.send_str(json.dumps({"text": ""}))
|
||||
self._logger.info("ElevenLabsStreamingSynthesizer: end-of-input sent")
|
||||
else:
|
||||
self._logger.warning(
|
||||
f"ElevenLabsStreamingSynthesizer: skipping flush - "
|
||||
f"ws={self._ws is not None}, closed={self._closed}"
|
||||
)
|
||||
|
||||
async def close(self) -> None:
    """Tear down the websocket, the receive task, and the HTTP session."""
    # Mark closed first so concurrent send_text/flush calls bail out.
    self._closed = True

    ws = self._ws
    if ws:
        await ws.close()

    task = self._receive_task
    if task:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    session = self._session
    if session:
        await session.close()
|
||||
|
||||
|
||||
# Valid ElevenLabs model IDs.
# STT: scribe_v1 is the batch model, scribe_v2_realtime the streaming
# (WebSocket) variant — see get_available_stt_models below.
ELEVENLABS_STT_MODELS = {"scribe_v1", "scribe_v2_realtime"}
# TTS model IDs accepted by the /v1/text-to-speech endpoints; anything
# outside this set is replaced with a default in the provider constructor.
ELEVENLABS_TTS_MODELS = {
    "eleven_multilingual_v2",
    "eleven_turbo_v2_5",
    "eleven_monolingual_v1",
    "eleven_flash_v2_5",
    "eleven_flash_v2",
}
|
||||
|
||||
|
||||
class ElevenLabsVoiceProvider(VoiceProviderInterface):
    """ElevenLabs voice provider.

    Supports batch STT (POST /v1/speech-to-text), streaming STT via the
    Scribe realtime WebSocket, HTTP-streaming TTS
    (POST /v1/text-to-speech/{voice_id}/stream), and WebSocket streaming
    TTS via ElevenLabsStreamingSynthesizer.
    """

    def __init__(
        self,
        api_key: str | None,
        api_base: str | None = None,
        stt_model: str | None = None,
        tts_model: str | None = None,
        default_voice: str | None = None,
    ):
        self.api_key = api_key
        self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
        # Validate and default models - silently fall back to a known-good
        # ElevenLabs model ID when the configured value is not recognized.
        self.stt_model = (
            stt_model if stt_model in ELEVENLABS_STT_MODELS else "scribe_v1"
        )
        self.tts_model = (
            tts_model
            if tts_model in ELEVENLABS_TTS_MODELS
            else "eleven_multilingual_v2"
        )
        self.default_voice = default_voice

    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        """
        Transcribe audio using ElevenLabs Speech-to-Text API.

        Args:
            audio_data: Raw audio bytes
            audio_format: Format of the audio (e.g., 'webm', 'mp3', 'wav')

        Returns:
            Transcribed text

        Raises:
            ValueError: If no API key is configured.
            RuntimeError: If the ElevenLabs API returns a non-200 response.
        """
        if not self.api_key:
            raise ValueError("ElevenLabs API key required for transcription")

        from onyx.utils.logger import setup_logger

        logger = setup_logger()

        url = f"{self.api_base}/v1/speech-to-text"

        # Map common formats to MIME types
        mime_types = {
            "webm": "audio/webm",
            "mp3": "audio/mpeg",
            "wav": "audio/wav",
            "ogg": "audio/ogg",
            "flac": "audio/flac",
            "m4a": "audio/mp4",
        }
        mime_type = mime_types.get(audio_format.lower(), f"audio/{audio_format}")

        headers = {
            "xi-api-key": self.api_key,
        }

        # ElevenLabs expects multipart form data with the audio under the
        # "file" field (per the speech-to-text API reference).
        form_data = aiohttp.FormData()
        form_data.add_field(
            "file",
            audio_data,
            filename=f"audio.{audio_format}",
            content_type=mime_type,
        )
        # For batch STT, use scribe_v1 (not the realtime model)
        batch_model = (
            self.stt_model if self.stt_model in ("scribe_v1",) else "scribe_v1"
        )
        form_data.add_field("model_id", batch_model)

        logger.info(
            f"ElevenLabs transcribe: sending {len(audio_data)} bytes, format={audio_format}"
        )

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, data=form_data) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"ElevenLabs transcribe failed: {error_text}")
                    raise RuntimeError(f"ElevenLabs transcription failed: {error_text}")

                result = await response.json()
                text = result.get("text", "")
                logger.info(f"ElevenLabs transcribe: got result: {text[:50]}...")
                return text

    async def synthesize_stream(
        self, text: str, voice: str | None = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio using ElevenLabs TTS with streaming.

        Args:
            text: Text to convert to speech
            voice: Voice ID (defaults to provider's default voice or Rachel)
            speed: Playback speed multiplier

        Yields:
            Audio data chunks (mp3 format)

        Raises:
            ValueError: If no API key is configured.
            RuntimeError: If the ElevenLabs API returns a non-200 response.
        """
        from onyx.utils.logger import setup_logger

        logger = setup_logger()

        if not self.api_key:
            raise ValueError("ElevenLabs API key required for TTS")

        voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM"  # Rachel

        url = f"{self.api_base}/v1/text-to-speech/{voice_id}/stream"

        logger.info(
            f"ElevenLabs TTS: starting synthesis, text='{text[:50]}...', "
            f"voice={voice_id}, model={self.tts_model}, speed={speed}"
        )

        headers = {
            "xi-api-key": self.api_key,
            "Content-Type": "application/json",
            "Accept": "audio/mpeg",
        }

        payload = {
            "text": text,
            "model_id": self.tts_model,
            "voice_settings": {
                "stability": DEFAULT_VOICE_STABILITY,
                "similarity_boost": DEFAULT_VOICE_SIMILARITY_BOOST,
                "speed": speed,
            },
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, json=payload) as response:
                logger.info(
                    f"ElevenLabs TTS: got response status={response.status}, "
                    f"content-type={response.headers.get('content-type')}"
                )
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"ElevenLabs TTS failed: {error_text}")
                    raise RuntimeError(f"ElevenLabs TTS failed: {error_text}")

                # Use 8192 byte chunks for smoother streaming
                chunk_count = 0
                total_bytes = 0
                async for chunk in response.content.iter_chunked(8192):
                    if chunk:
                        chunk_count += 1
                        total_bytes += len(chunk)
                        yield chunk
                logger.info(
                    f"ElevenLabs TTS: streaming complete, {chunk_count} chunks, "
                    f"{total_bytes} total bytes"
                )

    async def validate_credentials(self) -> None:
        """Validate ElevenLabs API key by fetching user info.

        Raises:
            ValueError: If no API key is configured.
            RuntimeError: If the /v1/user call returns a non-200 response.
        """
        if not self.api_key:
            raise ValueError("ElevenLabs API key required")

        headers = {"xi-api-key": self.api_key}
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"{self.api_base}/v1/user", headers=headers
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(
                        f"ElevenLabs credential validation failed: {error_text}"
                    )

    def get_available_voices(self) -> list[dict[str, str]]:
        """Return common ElevenLabs voices."""
        return ELEVENLABS_VOICES.copy()

    def get_available_stt_models(self) -> list[dict[str, str]]:
        """Return the selectable ElevenLabs STT models."""
        return [
            {"id": "scribe_v2_realtime", "name": "Scribe v2 Realtime (Streaming)"},
            {"id": "scribe_v1", "name": "Scribe v1 (Batch)"},
        ]

    def get_available_tts_models(self) -> list[dict[str, str]]:
        """Return the selectable ElevenLabs TTS models.

        Kept in sync with ELEVENLABS_TTS_MODELS so that every model ID the
        constructor accepts is also selectable.
        """
        return [
            {"id": "eleven_multilingual_v2", "name": "Multilingual v2"},
            {"id": "eleven_turbo_v2_5", "name": "Turbo v2.5"},
            {"id": "eleven_flash_v2_5", "name": "Flash v2.5"},
            {"id": "eleven_flash_v2", "name": "Flash v2"},
            {"id": "eleven_monolingual_v1", "name": "Monolingual v1"},
        ]

    def supports_streaming_stt(self) -> bool:
        """ElevenLabs supports streaming via Scribe Realtime API."""
        return True

    def supports_streaming_tts(self) -> bool:
        """ElevenLabs supports real-time streaming TTS via WebSocket."""
        return True

    async def create_streaming_transcriber(
        self, _audio_format: str = "webm"
    ) -> ElevenLabsStreamingTranscriber:
        """Create and connect a streaming transcription session."""
        if not self.api_key:
            raise ValueError("API key required for streaming transcription")
        # ElevenLabs realtime STT requires scribe_v2_realtime model.
        # Frontend sends PCM16 at DEFAULT_INPUT_SAMPLE_RATE (24kHz),
        # but ElevenLabs expects DEFAULT_TARGET_SAMPLE_RATE (16kHz).
        # The transcriber resamples automatically.
        transcriber = ElevenLabsStreamingTranscriber(
            api_key=self.api_key,
            model="scribe_v2_realtime",
            input_sample_rate=DEFAULT_INPUT_SAMPLE_RATE,
            target_sample_rate=DEFAULT_TARGET_SAMPLE_RATE,
            language_code="en",
            api_base=self.api_base,
        )
        await transcriber.connect()
        return transcriber

    async def create_streaming_synthesizer(
        self, voice: str | None = None, speed: float = 1.0
    ) -> ElevenLabsStreamingSynthesizer:
        """Create and connect a streaming TTS session."""
        if not self.api_key:
            raise ValueError("API key required for streaming TTS")
        voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM"
        synthesizer = ElevenLabsStreamingSynthesizer(
            api_key=self.api_key,
            voice_id=voice_id,
            model_id=self.tts_model,
            output_format=DEFAULT_TTS_OUTPUT_FORMAT,
            api_base=self.api_base,
            speed=speed,
        )
        await synthesizer.connect()
        return synthesizer
|
||||
626
backend/onyx/voice/providers/openai.py
Normal file
626
backend/onyx/voice/providers/openai.py
Normal file
@@ -0,0 +1,626 @@
|
||||
"""OpenAI voice provider for STT and TTS.
|
||||
|
||||
OpenAI supports:
|
||||
- **STT**: Whisper (batch transcription via REST) and Realtime API (streaming
|
||||
transcription via WebSocket with server-side VAD). Audio is sent as base64-encoded
|
||||
PCM16 at 24kHz mono. The Realtime API returns transcript deltas and completed
|
||||
transcription events per VAD-detected utterance.
|
||||
- **TTS**: HTTP streaming endpoint that returns audio chunks progressively.
|
||||
Supported models: tts-1 (standard) and tts-1-hd (high quality).
|
||||
|
||||
See https://platform.openai.com/docs for API reference.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Default OpenAI API base URL
|
||||
DEFAULT_OPENAI_API_BASE = "https://api.openai.com"
|
||||
|
||||
|
||||
class OpenAIRealtimeMessageType(StrEnum):
    """Message types from OpenAI Realtime transcription API.

    Values match the "type" field of JSON events received over the
    Realtime WebSocket; the transcriber's receive loop dispatches on them.
    """

    ERROR = "error"
    # Server-side VAD (voice activity detection) events.
    SPEECH_STARTED = "input_audio_buffer.speech_started"
    SPEECH_STOPPED = "input_audio_buffer.speech_stopped"
    BUFFER_COMMITTED = "input_audio_buffer.committed"
    # Incremental and final transcription results.
    TRANSCRIPTION_DELTA = "conversation.item.input_audio_transcription.delta"
    TRANSCRIPTION_COMPLETED = "conversation.item.input_audio_transcription.completed"
    # Session lifecycle events (acknowledged but otherwise ignored).
    SESSION_CREATED = "transcription_session.created"
    SESSION_UPDATED = "transcription_session.updated"
    ITEM_CREATED = "conversation.item.created"
|
||||
|
||||
|
||||
def _http_to_ws_url(http_url: str) -> str:
|
||||
"""Convert http(s) URL to ws(s) URL for WebSocket connections."""
|
||||
if http_url.startswith("https://"):
|
||||
return "wss://" + http_url[8:]
|
||||
elif http_url.startswith("http://"):
|
||||
return "ws://" + http_url[7:]
|
||||
return http_url
|
||||
|
||||
|
||||
class OpenAIStreamingTranscriber(StreamingTranscriberProtocol):
    """Streaming transcription using OpenAI Realtime API.

    Audio is sent as base64-encoded PCM16 at 24kHz mono over a WebSocket;
    server-side VAD splits the stream into turns. Transcript deltas and
    per-turn completions are pushed onto an internal queue that callers
    drain via receive_transcript().
    """

    def __init__(
        self,
        api_key: str,
        model: str = "whisper-1",
        api_base: str | None = None,
    ):
        # Import logger first
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()

        self._logger.info(
            f"OpenAIStreamingTranscriber: initializing with model {model}"
        )
        self.api_key = api_key
        self.model = model
        self.api_base = api_base or DEFAULT_OPENAI_API_BASE
        # WebSocket and session are created in connect(), not here.
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._session: aiohttp.ClientSession | None = None
        # None in the queue marks end-of-stream (put by the receive loop).
        self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
        self._current_turn_transcript = ""  # Transcript for current VAD turn
        self._accumulated_transcript = ""  # Accumulated across all turns
        self._receive_task: asyncio.Task | None = None
        self._closed = False

    async def connect(self) -> None:
        """Establish WebSocket connection to OpenAI Realtime API.

        Raises whatever ws_connect raises on failure (after logging).
        """
        self._session = aiohttp.ClientSession()

        # OpenAI Realtime transcription endpoint
        ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
        url = f"{ws_base}/v1/realtime?intent=transcription"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "OpenAI-Beta": "realtime=v1",
        }

        try:
            self._ws = await self._session.ws_connect(url, headers=headers)
            self._logger.info("Connected to OpenAI Realtime API")
        except Exception as e:
            self._logger.error(f"Failed to connect to OpenAI Realtime API: {e}")
            raise

        # Configure the session for transcription
        # Enable server-side VAD (Voice Activity Detection) for automatic speech detection
        config_message = {
            "type": "transcription_session.update",
            "session": {
                "input_audio_format": "pcm16",  # 16-bit PCM at 24kHz mono
                "input_audio_transcription": {
                    "model": self.model,
                },
                "turn_detection": {
                    "type": "server_vad",
                    "threshold": 0.5,
                    "prefix_padding_ms": 300,
                    "silence_duration_ms": 500,
                },
            },
        }
        await self._ws.send_str(json.dumps(config_message))
        self._logger.info(f"Sent config for model: {self.model} with server VAD")

        # Start receiving transcripts
        self._receive_task = asyncio.create_task(self._receive_loop())

    async def _receive_loop(self) -> None:
        """Background task to receive transcripts.

        Dispatches on the event "type" field; always enqueues a final None
        sentinel so consumers can detect end-of-stream.
        """
        if not self._ws:
            return

        try:
            async for msg in self._ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    data = json.loads(msg.data)
                    msg_type = data.get("type", "")
                    self._logger.debug(f"Received message type: {msg_type}")

                    # Handle errors
                    if msg_type == OpenAIRealtimeMessageType.ERROR:
                        error = data.get("error", {})
                        self._logger.error(f"OpenAI error: {error}")
                        continue

                    # Handle VAD events
                    if msg_type == OpenAIRealtimeMessageType.SPEECH_STARTED:
                        self._logger.info("OpenAI: Speech started")
                        # Reset current turn transcript for new speech
                        self._current_turn_transcript = ""
                        continue
                    elif msg_type == OpenAIRealtimeMessageType.SPEECH_STOPPED:
                        self._logger.info(
                            "OpenAI: Speech stopped (VAD detected silence)"
                        )
                        continue
                    elif msg_type == OpenAIRealtimeMessageType.BUFFER_COMMITTED:
                        self._logger.info("OpenAI: Audio buffer committed")
                        continue

                    # Handle transcription events
                    if msg_type == OpenAIRealtimeMessageType.TRANSCRIPTION_DELTA:
                        delta = data.get("delta", "")
                        if delta:
                            self._logger.info(f"OpenAI: Transcription delta: {delta}")
                            self._current_turn_transcript += delta
                            # Show accumulated + current turn transcript
                            full_transcript = self._accumulated_transcript
                            if full_transcript and self._current_turn_transcript:
                                full_transcript += " "
                            full_transcript += self._current_turn_transcript
                            await self._transcript_queue.put(
                                TranscriptResult(text=full_transcript, is_vad_end=False)
                            )
                    elif msg_type == OpenAIRealtimeMessageType.TRANSCRIPTION_COMPLETED:
                        transcript = data.get("transcript", "")
                        if transcript:
                            self._logger.info(
                                f"OpenAI: Transcription completed (VAD turn end): {transcript[:50]}..."
                            )
                            # This is the final transcript for this VAD turn
                            self._current_turn_transcript = transcript
                            # Accumulate this turn's transcript
                            if self._accumulated_transcript:
                                self._accumulated_transcript += " " + transcript
                            else:
                                self._accumulated_transcript = transcript
                            # Send with is_vad_end=True to trigger auto-send
                            await self._transcript_queue.put(
                                TranscriptResult(
                                    text=self._accumulated_transcript,
                                    is_vad_end=True,
                                )
                            )
                    elif msg_type not in (
                        OpenAIRealtimeMessageType.SESSION_CREATED,
                        OpenAIRealtimeMessageType.SESSION_UPDATED,
                        OpenAIRealtimeMessageType.ITEM_CREATED,
                    ):
                        # Log any other message types we might be missing
                        self._logger.info(
                            f"OpenAI: Unhandled message type '{msg_type}': {data}"
                        )

                elif msg.type == aiohttp.WSMsgType.ERROR:
                    self._logger.error(f"WebSocket error: {self._ws.exception()}")
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    self._logger.info("WebSocket closed by server")
                    break
        except Exception as e:
            self._logger.error(f"Error in receive loop: {e}")
        finally:
            # End-of-stream sentinel for receive_transcript() consumers.
            await self._transcript_queue.put(None)

    async def send_audio(self, chunk: bytes) -> None:
        """Send audio chunk to OpenAI.

        chunk is expected to be raw PCM16 at 24kHz mono (per the session
        config sent in connect()); it is base64-encoded before sending.
        """
        if self._ws and not self._closed:
            # OpenAI expects base64-encoded PCM16 audio at 24kHz mono
            # PCM16 at 24kHz: 24000 samples/sec * 2 bytes/sample = 48000 bytes/sec
            # So chunk_bytes / 48000 = duration in seconds
            duration_ms = (len(chunk) / 48000) * 1000
            self._logger.debug(
                f"Sending {len(chunk)} bytes ({duration_ms:.1f}ms) of audio to OpenAI. "
                f"First 10 bytes: {chunk[:10].hex() if len(chunk) >= 10 else chunk.hex()}"
            )
            message = {
                "type": "input_audio_buffer.append",
                "audio": base64.b64encode(chunk).decode("utf-8"),
            }
            await self._ws.send_str(json.dumps(message))

    def reset_transcript(self) -> None:
        """Reset accumulated transcript. Call after auto-send to start fresh."""
        self._logger.info("OpenAI: Resetting accumulated transcript")
        self._accumulated_transcript = ""
        self._current_turn_transcript = ""

    async def receive_transcript(self) -> TranscriptResult | None:
        """Receive next transcript.

        Polls the queue for up to 100ms; on timeout an empty, non-final
        result is returned so callers can keep polling. None means the
        receive loop has finished.
        """
        try:
            return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return TranscriptResult(text="", is_vad_end=False)

    async def close(self) -> str:
        """Close session and return final transcript.

        Commits any remaining buffered audio, then polls (up to 5s) for a
        late transcription before tearing down the WebSocket, receive
        task, and HTTP session.
        """
        self._closed = True
        if self._ws:
            # With server VAD, the buffer is auto-committed when speech stops.
            # But we should still commit any remaining audio and wait for transcription.
            try:
                await self._ws.send_str(
                    json.dumps({"type": "input_audio_buffer.commit"})
                )
            except Exception as e:
                self._logger.debug(f"Error sending commit (may be expected): {e}")

            # Wait for *new* transcription to arrive (up to 5 seconds)
            self._logger.info("Waiting for transcription to complete...")
            transcript_before_commit = self._accumulated_transcript
            for _ in range(50):  # 50 * 100ms = 5 seconds max
                await asyncio.sleep(0.1)
                if self._accumulated_transcript != transcript_before_commit:
                    self._logger.info(
                        f"Got final transcript: {self._accumulated_transcript[:50]}..."
                    )
                    break
            else:
                self._logger.warning("Timed out waiting for transcription")

            await self._ws.close()
        if self._receive_task:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass
        if self._session:
            await self._session.close()
        return self._accumulated_transcript
|
||||
|
||||
|
||||
# OpenAI available voices for TTS ("id" is the API value, "name" the
# human-readable label surfaced by get_available_voices()).
OPENAI_VOICES = [
    {"id": "alloy", "name": "Alloy"},
    {"id": "echo", "name": "Echo"},
    {"id": "fable", "name": "Fable"},
    {"id": "onyx", "name": "Onyx"},
    {"id": "nova", "name": "Nova"},
    {"id": "shimmer", "name": "Shimmer"},
]

# OpenAI available STT models (all support streaming via Realtime API)
OPENAI_STT_MODELS = [
    {"id": "whisper-1", "name": "Whisper v1"},
    {"id": "gpt-4o-transcribe", "name": "GPT-4o Transcribe"},
    {"id": "gpt-4o-mini-transcribe", "name": "GPT-4o Mini Transcribe"},
]

# OpenAI available TTS models
OPENAI_TTS_MODELS = [
    {"id": "tts-1", "name": "TTS-1 (Standard)"},
    {"id": "tts-1-hd", "name": "TTS-1 HD (High Quality)"},
]
|
||||
|
||||
|
||||
def _create_wav_header(
|
||||
data_length: int,
|
||||
sample_rate: int = 24000,
|
||||
channels: int = 1,
|
||||
bits_per_sample: int = 16,
|
||||
) -> bytes:
|
||||
"""Create a WAV file header for PCM audio data."""
|
||||
import struct
|
||||
|
||||
byte_rate = sample_rate * channels * bits_per_sample // 8
|
||||
block_align = channels * bits_per_sample // 8
|
||||
|
||||
# WAV header is 44 bytes
|
||||
header = struct.pack(
|
||||
"<4sI4s4sIHHIIHH4sI",
|
||||
b"RIFF", # ChunkID
|
||||
36 + data_length, # ChunkSize
|
||||
b"WAVE", # Format
|
||||
b"fmt ", # Subchunk1ID
|
||||
16, # Subchunk1Size (PCM)
|
||||
1, # AudioFormat (1 = PCM)
|
||||
channels, # NumChannels
|
||||
sample_rate, # SampleRate
|
||||
byte_rate, # ByteRate
|
||||
block_align, # BlockAlign
|
||||
bits_per_sample, # BitsPerSample
|
||||
b"data", # Subchunk2ID
|
||||
data_length, # Subchunk2Size
|
||||
)
|
||||
return header
|
||||
|
||||
|
||||
class OpenAIStreamingSynthesizer(StreamingSynthesizerProtocol):
    """Streaming TTS using OpenAI HTTP TTS API with streaming responses.

    Text is queued and a background task issues one HTTP request per
    queued string, streaming the MP3 response into an audio queue that
    callers drain via receive_audio().
    """

    def __init__(
        self,
        api_key: str,
        voice: str = "alloy",
        model: str = "tts-1",
        speed: float = 1.0,
        api_base: str | None = None,
    ):
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()
        self.api_key = api_key
        self.voice = voice
        self.model = model
        # Clamp to the range the OpenAI speech API accepts.
        self.speed = max(0.25, min(4.0, speed))
        self.api_base = api_base or DEFAULT_OPENAI_API_BASE
        self._session: aiohttp.ClientSession | None = None
        # None in either queue is the end-of-stream sentinel.
        self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
        self._text_queue: asyncio.Queue[str | None] = asyncio.Queue()
        self._synthesis_task: asyncio.Task | None = None
        self._closed = False
        self._flushed = False

    async def connect(self) -> None:
        """Initialize HTTP session for TTS requests."""
        self._logger.info("OpenAIStreamingSynthesizer: connecting")
        self._session = aiohttp.ClientSession()
        # Start background task to process text queue
        self._synthesis_task = asyncio.create_task(self._process_text_queue())
        self._logger.info("OpenAIStreamingSynthesizer: connected")

    async def _process_text_queue(self) -> None:
        """Background task to process queued text for synthesis.

        Runs until close() sets _closed, a None sentinel is dequeued, or
        the task is cancelled; synthesis errors are logged, not raised.
        """
        while not self._closed:
            try:
                text = await asyncio.wait_for(self._text_queue.get(), timeout=0.1)
                if text is None:
                    break
                await self._synthesize_text(text)
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break
            except Exception as e:
                self._logger.error(f"Error processing text queue: {e}")

    async def _synthesize_text(self, text: str) -> None:
        """Make HTTP TTS request and stream audio to queue.

        Non-200 responses and transport errors are logged and swallowed
        so one failed request does not kill the background task.
        """
        if not self._session or self._closed:
            return

        url = f"{self.api_base.rstrip('/')}/v1/audio/speech"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "voice": self.voice,
            "input": text,
            "speed": self.speed,
            "response_format": "mp3",
        }

        try:
            async with self._session.post(
                url, headers=headers, json=payload
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    self._logger.error(f"OpenAI TTS error: {error_text}")
                    return

                # Use 8192 byte chunks for smoother streaming
                # (larger chunks = more complete MP3 frames, better playback)
                async for chunk in response.content.iter_chunked(8192):
                    if self._closed:
                        break
                    if chunk:
                        await self._audio_queue.put(chunk)
        except Exception as e:
            self._logger.error(f"OpenAIStreamingSynthesizer synthesis error: {e}")

    async def send_text(self, text: str) -> None:
        """Queue text to be synthesized via HTTP streaming.

        Blank text and text sent after close() are silently ignored.
        """
        if not text.strip() or self._closed:
            return
        await self._text_queue.put(text)

    async def receive_audio(self) -> bytes | None:
        """Receive next audio chunk (MP3 format).

        Empty bytes means "nothing yet, keep polling"; None (queued by
        flush/close) marks end of the audio stream.
        """
        try:
            return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return b""  # No audio yet, but not done

    async def flush(self) -> None:
        """Signal end of text input - wait for synthesis to complete.

        Idempotent: subsequent calls return immediately. Waits up to 60s
        for the background task to drain the text queue, cancelling it on
        timeout, then enqueues the audio end-of-stream sentinel.
        """
        if self._flushed:
            return
        self._flushed = True

        # Signal end of text input
        await self._text_queue.put(None)

        # Wait for synthesis task to complete processing all text
        if self._synthesis_task and not self._synthesis_task.done():
            try:
                await asyncio.wait_for(self._synthesis_task, timeout=60.0)
            except asyncio.TimeoutError:
                self._logger.warning("OpenAIStreamingSynthesizer: flush timeout")
                self._synthesis_task.cancel()
                try:
                    await self._synthesis_task
                except asyncio.CancelledError:
                    pass
            except asyncio.CancelledError:
                pass

        # Signal end of audio stream
        await self._audio_queue.put(None)

    async def close(self) -> None:
        """Close the session.

        Idempotent. Puts end-of-stream sentinels (unless flush() already
        did), cancels the background task, and closes the HTTP session.
        """
        if self._closed:
            return
        self._closed = True

        # Signal end of queues only if flush wasn't already called
        if not self._flushed:
            await self._text_queue.put(None)
            await self._audio_queue.put(None)

        if self._synthesis_task and not self._synthesis_task.done():
            self._synthesis_task.cancel()
            try:
                await self._synthesis_task
            except asyncio.CancelledError:
                pass

        if self._session:
            await self._session.close()
|
||||
|
||||
|
||||
class OpenAIVoiceProvider(VoiceProviderInterface):
    """OpenAI voice provider using Whisper for STT and TTS API for speech synthesis.

    Batch STT and TTS go through the official SDK client; streaming STT
    uses the Realtime WebSocket API and streaming TTS the HTTP streaming
    endpoint (see the streaming classes above).
    """

    def __init__(
        self,
        api_key: str | None,
        api_base: str | None = None,
        stt_model: str | None = None,
        tts_model: str | None = None,
        default_voice: str | None = None,
    ):
        self.api_key = api_key
        self.api_base = api_base
        self.stt_model = stt_model or "whisper-1"
        self.tts_model = tts_model or "tts-1"
        self.default_voice = default_voice or "alloy"

        # SDK client is created lazily in _get_client().
        self._client: "AsyncOpenAI | None" = None

    def _get_client(self) -> "AsyncOpenAI":
        # Lazily construct (and cache) the SDK client so module import
        # does not require the openai package or credentials.
        if self._client is None:
            from openai import AsyncOpenAI

            self._client = AsyncOpenAI(
                api_key=self.api_key,
                base_url=self.api_base,
            )
        return self._client

    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        """
        Transcribe audio using OpenAI Whisper.

        Args:
            audio_data: Raw audio bytes
            audio_format: Audio format (e.g., "webm", "wav", "mp3")

        Returns:
            Transcribed text
        """
        client = self._get_client()

        # Create a file-like object from the audio bytes
        audio_file = io.BytesIO(audio_data)
        # The SDK derives the upload format from the file name suffix.
        audio_file.name = f"audio.{audio_format}"

        response = await client.audio.transcriptions.create(
            model=self.stt_model,
            file=audio_file,
        )

        return response.text

    async def synthesize_stream(
        self, text: str, voice: str | None = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio using OpenAI TTS with streaming.

        Args:
            text: Text to convert to speech
            voice: Voice identifier (defaults to provider's default voice)
            speed: Playback speed multiplier (0.25 to 4.0)

        Yields:
            Audio data chunks (mp3 format)
        """
        client = self._get_client()

        # Clamp speed to valid range
        speed = max(0.25, min(4.0, speed))

        # Use with_streaming_response for proper async streaming
        # Using 8192 byte chunks for better streaming performance
        # (larger chunks = fewer round-trips, more complete MP3 frames)
        async with client.audio.speech.with_streaming_response.create(
            model=self.tts_model,
            voice=voice or self.default_voice,
            input=text,
            speed=speed,
            response_format="mp3",
        ) as response:
            async for chunk in response.iter_bytes(chunk_size=8192):
                yield chunk

    async def validate_credentials(self) -> None:
        """Validate OpenAI API key by listing models."""
        client = self._get_client()
        await client.models.list()

    def get_available_voices(self) -> list[dict[str, str]]:
        """Get available OpenAI TTS voices."""
        return OPENAI_VOICES.copy()

    def get_available_stt_models(self) -> list[dict[str, str]]:
        """Get available OpenAI STT models."""
        return OPENAI_STT_MODELS.copy()

    def get_available_tts_models(self) -> list[dict[str, str]]:
        """Get available OpenAI TTS models."""
        return OPENAI_TTS_MODELS.copy()

    def supports_streaming_stt(self) -> bool:
        """OpenAI supports streaming via Realtime API for all STT models."""
        return True

    def supports_streaming_tts(self) -> bool:
        """OpenAI supports streaming TTS (served via the HTTP streaming endpoint)."""
        return True

    async def create_streaming_transcriber(
        self, _audio_format: str = "webm"
    ) -> OpenAIStreamingTranscriber:
        """Create and connect a streaming transcription session using Realtime API.

        Raises:
            ValueError: If no API key is configured.
        """
        if not self.api_key:
            raise ValueError("API key required for streaming transcription")
        transcriber = OpenAIStreamingTranscriber(
            api_key=self.api_key,
            model=self.stt_model,
            api_base=self.api_base,
        )
        await transcriber.connect()
        return transcriber

    async def create_streaming_synthesizer(
        self, voice: str | None = None, speed: float = 1.0
    ) -> OpenAIStreamingSynthesizer:
        """Create and connect a streaming TTS session using HTTP streaming API.

        Raises:
            ValueError: If no API key is configured.
        """
        if not self.api_key:
            raise ValueError("API key required for streaming TTS")
        synthesizer = OpenAIStreamingSynthesizer(
            api_key=self.api_key,
            voice=voice or self.default_voice or "alloy",
            model=self.tts_model or "tts-1",
            speed=speed,
            api_base=self.api_base,
        )
        await synthesizer.connect()
        return synthesizer
|
||||
@@ -67,6 +67,8 @@ attrs==25.4.0
|
||||
# zeep
|
||||
authlib==1.6.7
|
||||
# via fastmcp
|
||||
azure-cognitiveservices-speech==1.38.0
|
||||
# via onyx
|
||||
babel==2.17.0
|
||||
# via courlan
|
||||
backoff==2.2.1
|
||||
@@ -1020,7 +1022,7 @@ toolz==1.1.0
|
||||
# dask
|
||||
# distributed
|
||||
# partd
|
||||
tornado==6.5.5
|
||||
tornado==6.5.2
|
||||
# via distributed
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
|
||||
@@ -466,7 +466,7 @@ tokenizers==0.21.4
|
||||
# via
|
||||
# cohere
|
||||
# litellm
|
||||
tornado==6.5.5
|
||||
tornado==6.5.2
|
||||
# via
|
||||
# ipykernel
|
||||
# jupyter-client
|
||||
|
||||
507
backend/tests/unit/onyx/db/test_voice.py
Normal file
507
backend/tests/unit/onyx/db/test_voice.py
Normal file
@@ -0,0 +1,507 @@
|
||||
"""Unit tests for onyx.db.voice module."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.db.voice import deactivate_stt_provider
|
||||
from onyx.db.voice import deactivate_tts_provider
|
||||
from onyx.db.voice import delete_voice_provider
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.db.voice import fetch_voice_provider_by_id
|
||||
from onyx.db.voice import fetch_voice_provider_by_type
|
||||
from onyx.db.voice import fetch_voice_providers
|
||||
from onyx.db.voice import MAX_VOICE_PLAYBACK_SPEED
|
||||
from onyx.db.voice import MIN_VOICE_PLAYBACK_SPEED
|
||||
from onyx.db.voice import set_default_stt_provider
|
||||
from onyx.db.voice import set_default_tts_provider
|
||||
from onyx.db.voice import update_user_voice_settings
|
||||
from onyx.db.voice import upsert_voice_provider
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
def _make_voice_provider(
|
||||
id: int = 1,
|
||||
name: str = "Test Provider",
|
||||
provider_type: str = "openai",
|
||||
is_default_stt: bool = False,
|
||||
is_default_tts: bool = False,
|
||||
) -> VoiceProvider:
|
||||
"""Create a VoiceProvider instance for testing."""
|
||||
provider = VoiceProvider()
|
||||
provider.id = id
|
||||
provider.name = name
|
||||
provider.provider_type = provider_type
|
||||
provider.is_default_stt = is_default_stt
|
||||
provider.is_default_tts = is_default_tts
|
||||
provider.api_key = None
|
||||
provider.api_base = None
|
||||
provider.custom_config = None
|
||||
provider.stt_model = None
|
||||
provider.tts_model = None
|
||||
provider.default_voice = None
|
||||
return provider
|
||||
|
||||
|
||||
class TestFetchVoiceProviders:
|
||||
"""Tests for fetch_voice_providers."""
|
||||
|
||||
def test_returns_all_providers(self, mock_db_session: MagicMock) -> None:
|
||||
providers = [
|
||||
_make_voice_provider(id=1, name="Provider A"),
|
||||
_make_voice_provider(id=2, name="Provider B"),
|
||||
]
|
||||
mock_db_session.scalars.return_value.all.return_value = providers
|
||||
|
||||
result = fetch_voice_providers(mock_db_session)
|
||||
|
||||
assert result == providers
|
||||
mock_db_session.scalars.assert_called_once()
|
||||
|
||||
def test_returns_empty_list_when_no_providers(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
|
||||
result = fetch_voice_providers(mock_db_session)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestFetchVoiceProviderById:
|
||||
"""Tests for fetch_voice_provider_by_id."""
|
||||
|
||||
def test_returns_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
result = fetch_voice_provider_by_id(mock_db_session, 1)
|
||||
|
||||
assert result is provider
|
||||
mock_db_session.scalar.assert_called_once()
|
||||
|
||||
def test_returns_none_when_not_found(self, mock_db_session: MagicMock) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
result = fetch_voice_provider_by_id(mock_db_session, 999)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestFetchDefaultProviders:
|
||||
"""Tests for fetch_default_stt_provider and fetch_default_tts_provider."""
|
||||
|
||||
def test_fetch_default_stt_provider_returns_provider(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1, is_default_stt=True)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
result = fetch_default_stt_provider(mock_db_session)
|
||||
|
||||
assert result is provider
|
||||
|
||||
def test_fetch_default_stt_provider_returns_none_when_no_default(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
result = fetch_default_stt_provider(mock_db_session)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_fetch_default_tts_provider_returns_provider(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1, is_default_tts=True)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
result = fetch_default_tts_provider(mock_db_session)
|
||||
|
||||
assert result is provider
|
||||
|
||||
def test_fetch_default_tts_provider_returns_none_when_no_default(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
result = fetch_default_tts_provider(mock_db_session)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestFetchVoiceProviderByType:
|
||||
"""Tests for fetch_voice_provider_by_type."""
|
||||
|
||||
def test_returns_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
provider = _make_voice_provider(id=1, provider_type="openai")
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
result = fetch_voice_provider_by_type(mock_db_session, "openai")
|
||||
|
||||
assert result is provider
|
||||
|
||||
def test_returns_none_when_not_found(self, mock_db_session: MagicMock) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
result = fetch_voice_provider_by_type(mock_db_session, "nonexistent")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestUpsertVoiceProvider:
|
||||
"""Tests for upsert_voice_provider."""
|
||||
|
||||
def test_creates_new_provider_when_no_id(self, mock_db_session: MagicMock) -> None:
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=None,
|
||||
name="New Provider",
|
||||
provider_type="openai",
|
||||
api_key="test-key",
|
||||
api_key_changed=True,
|
||||
)
|
||||
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.flush.assert_called()
|
||||
added_obj = mock_db_session.add.call_args[0][0]
|
||||
assert added_obj.name == "New Provider"
|
||||
assert added_obj.provider_type == "openai"
|
||||
|
||||
def test_updates_existing_provider(self, mock_db_session: MagicMock) -> None:
|
||||
existing_provider = _make_voice_provider(id=1, name="Old Name")
|
||||
mock_db_session.scalar.return_value = existing_provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=1,
|
||||
name="Updated Name",
|
||||
provider_type="elevenlabs",
|
||||
api_key="new-key",
|
||||
api_key_changed=True,
|
||||
)
|
||||
|
||||
mock_db_session.add.assert_not_called()
|
||||
assert existing_provider.name == "Updated Name"
|
||||
assert existing_provider.provider_type == "elevenlabs"
|
||||
|
||||
def test_raises_when_provider_not_found(self, mock_db_session: MagicMock) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=999,
|
||||
name="Test",
|
||||
provider_type="openai",
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
)
|
||||
|
||||
assert "No voice provider with id 999" in str(exc_info.value)
|
||||
|
||||
def test_does_not_update_api_key_when_not_changed(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
existing_provider = _make_voice_provider(id=1)
|
||||
existing_provider.api_key = "original-key" # type: ignore[assignment]
|
||||
original_api_key = existing_provider.api_key
|
||||
mock_db_session.scalar.return_value = existing_provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=1,
|
||||
name="Test",
|
||||
provider_type="openai",
|
||||
api_key="new-key",
|
||||
api_key_changed=False,
|
||||
)
|
||||
|
||||
# api_key should remain unchanged (same object reference)
|
||||
assert existing_provider.api_key is original_api_key
|
||||
|
||||
def test_activates_stt_when_requested(self, mock_db_session: MagicMock) -> None:
|
||||
existing_provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = existing_provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
mock_db_session.execute.return_value = None
|
||||
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=1,
|
||||
name="Test",
|
||||
provider_type="openai",
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
activate_stt=True,
|
||||
)
|
||||
|
||||
assert existing_provider.is_default_stt is True
|
||||
|
||||
def test_activates_tts_when_requested(self, mock_db_session: MagicMock) -> None:
|
||||
existing_provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = existing_provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
mock_db_session.execute.return_value = None
|
||||
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=1,
|
||||
name="Test",
|
||||
provider_type="openai",
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
activate_tts=True,
|
||||
)
|
||||
|
||||
assert existing_provider.is_default_tts is True
|
||||
|
||||
|
||||
class TestDeleteVoiceProvider:
|
||||
"""Tests for delete_voice_provider."""
|
||||
|
||||
def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
delete_voice_provider(mock_db_session, 1)
|
||||
|
||||
assert provider.deleted is True
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_does_nothing_when_provider_not_found(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
delete_voice_provider(mock_db_session, 999)
|
||||
|
||||
mock_db_session.flush.assert_not_called()
|
||||
|
||||
|
||||
class TestSetDefaultProviders:
|
||||
"""Tests for set_default_stt_provider and set_default_tts_provider."""
|
||||
|
||||
def test_set_default_stt_provider_deactivates_others(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
mock_db_session.execute.return_value = None
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
result = set_default_stt_provider(db_session=mock_db_session, provider_id=1)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
assert result.is_default_stt is True
|
||||
|
||||
def test_set_default_stt_provider_raises_when_not_found(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
set_default_stt_provider(db_session=mock_db_session, provider_id=999)
|
||||
|
||||
assert "No voice provider with id 999" in str(exc_info.value)
|
||||
|
||||
def test_set_default_tts_provider_deactivates_others(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
mock_db_session.execute.return_value = None
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
result = set_default_tts_provider(db_session=mock_db_session, provider_id=1)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
assert result.is_default_tts is True
|
||||
|
||||
def test_set_default_tts_provider_updates_model_when_provided(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
mock_db_session.execute.return_value = None
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
result = set_default_tts_provider(
|
||||
db_session=mock_db_session, provider_id=1, tts_model="tts-1-hd"
|
||||
)
|
||||
|
||||
assert result.tts_model == "tts-1-hd"
|
||||
|
||||
def test_set_default_tts_provider_raises_when_not_found(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
set_default_tts_provider(db_session=mock_db_session, provider_id=999)
|
||||
|
||||
assert "No voice provider with id 999" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestDeactivateProviders:
|
||||
"""Tests for deactivate_stt_provider and deactivate_tts_provider."""
|
||||
|
||||
def test_deactivate_stt_provider_sets_false(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1, is_default_stt=True)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
result = deactivate_stt_provider(db_session=mock_db_session, provider_id=1)
|
||||
|
||||
assert result.is_default_stt is False
|
||||
|
||||
def test_deactivate_stt_provider_raises_when_not_found(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
deactivate_stt_provider(db_session=mock_db_session, provider_id=999)
|
||||
|
||||
assert "No voice provider with id 999" in str(exc_info.value)
|
||||
|
||||
def test_deactivate_tts_provider_sets_false(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1, is_default_tts=True)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
result = deactivate_tts_provider(db_session=mock_db_session, provider_id=1)
|
||||
|
||||
assert result.is_default_tts is False
|
||||
|
||||
def test_deactivate_tts_provider_raises_when_not_found(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
deactivate_tts_provider(db_session=mock_db_session, provider_id=999)
|
||||
|
||||
assert "No voice provider with id 999" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestUpdateUserVoiceSettings:
|
||||
"""Tests for update_user_voice_settings."""
|
||||
|
||||
def test_updates_auto_send(self, mock_db_session: MagicMock) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id, auto_send=True)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_updates_auto_playback(self, mock_db_session: MagicMock) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id, auto_playback=True)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_updates_playback_speed_within_range(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id, playback_speed=1.5)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
def test_clamps_playback_speed_to_min(self, mock_db_session: MagicMock) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id, playback_speed=0.1)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
stmt = mock_db_session.execute.call_args[0][0]
|
||||
compiled = stmt.compile(compile_kwargs={"literal_binds": True})
|
||||
assert str(MIN_VOICE_PLAYBACK_SPEED) in str(compiled)
|
||||
|
||||
def test_clamps_playback_speed_to_max(self, mock_db_session: MagicMock) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id, playback_speed=5.0)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
stmt = mock_db_session.execute.call_args[0][0]
|
||||
compiled = stmt.compile(compile_kwargs={"literal_binds": True})
|
||||
assert str(MAX_VOICE_PLAYBACK_SPEED) in str(compiled)
|
||||
|
||||
def test_updates_multiple_settings(self, mock_db_session: MagicMock) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(
|
||||
mock_db_session,
|
||||
user_id,
|
||||
auto_send=True,
|
||||
auto_playback=False,
|
||||
playback_speed=1.25,
|
||||
)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_does_nothing_when_no_settings_provided(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id)
|
||||
|
||||
mock_db_session.execute.assert_not_called()
|
||||
mock_db_session.flush.assert_not_called()
|
||||
|
||||
|
||||
class TestSpeedClampingLogic:
|
||||
"""Tests for the speed clamping constants and logic."""
|
||||
|
||||
def test_min_speed_constant(self) -> None:
|
||||
assert MIN_VOICE_PLAYBACK_SPEED == 0.5
|
||||
|
||||
def test_max_speed_constant(self) -> None:
|
||||
assert MAX_VOICE_PLAYBACK_SPEED == 2.0
|
||||
|
||||
def test_clamping_formula(self) -> None:
|
||||
"""Verify the clamping formula used in update_user_voice_settings."""
|
||||
test_cases = [
|
||||
(0.1, MIN_VOICE_PLAYBACK_SPEED),
|
||||
(0.5, 0.5),
|
||||
(1.0, 1.0),
|
||||
(1.5, 1.5),
|
||||
(2.0, 2.0),
|
||||
(3.0, MAX_VOICE_PLAYBACK_SPEED),
|
||||
]
|
||||
for speed, expected in test_cases:
|
||||
clamped = max(
|
||||
MIN_VOICE_PLAYBACK_SPEED, min(MAX_VOICE_PLAYBACK_SPEED, speed)
|
||||
)
|
||||
assert (
|
||||
clamped == expected
|
||||
), f"speed={speed} expected={expected} got={clamped}"
|
||||
@@ -1,204 +0,0 @@
|
||||
"""Tests for Slack channel reference resolution and tag filtering
|
||||
in handle_regular_answer.py."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN
|
||||
from onyx.onyxbot.slack.handlers.handle_regular_answer import resolve_channel_references
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _mock_client_with_channels(
|
||||
channel_map: dict[str, str],
|
||||
) -> MagicMock:
|
||||
"""Return a mock WebClient where conversations_info resolves IDs to names."""
|
||||
client = MagicMock()
|
||||
|
||||
def _conversations_info(channel: str) -> MagicMock:
|
||||
if channel in channel_map:
|
||||
resp = MagicMock()
|
||||
resp.validate = MagicMock()
|
||||
resp.__getitem__ = lambda _self, key: {
|
||||
"channel": {
|
||||
"name": channel_map[channel],
|
||||
"is_im": False,
|
||||
"is_mpim": False,
|
||||
}
|
||||
}[key]
|
||||
return resp
|
||||
raise SlackApiError("channel_not_found", response=MagicMock())
|
||||
|
||||
client.conversations_info = _conversations_info
|
||||
return client
|
||||
|
||||
|
||||
def _mock_logger() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SLACK_CHANNEL_REF_PATTERN regex tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSlackChannelRefPattern:
|
||||
def test_matches_bare_channel_id(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("<#C097NBWMY8Y>")
|
||||
assert matches == [("C097NBWMY8Y", "")]
|
||||
|
||||
def test_matches_channel_id_with_name(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("<#C097NBWMY8Y|eng-infra>")
|
||||
assert matches == [("C097NBWMY8Y", "eng-infra")]
|
||||
|
||||
def test_matches_multiple_channels(self) -> None:
|
||||
msg = "compare <#C111AAA> and <#C222BBB|general>"
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall(msg)
|
||||
assert len(matches) == 2
|
||||
assert ("C111AAA", "") in matches
|
||||
assert ("C222BBB", "general") in matches
|
||||
|
||||
def test_no_match_on_plain_text(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("no channels here")
|
||||
assert matches == []
|
||||
|
||||
def test_no_match_on_user_mention(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("<@U12345>")
|
||||
assert matches == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_channel_references tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveChannelReferences:
|
||||
def test_resolves_bare_channel_id_via_api(self) -> None:
|
||||
client = _mock_client_with_channels({"C097NBWMY8Y": "eng-infra"})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="summary of <#C097NBWMY8Y> this week",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert message == "summary of #eng-infra this week"
|
||||
assert len(tags) == 1
|
||||
assert tags[0] == Tag(tag_key="Channel", tag_value="eng-infra")
|
||||
|
||||
def test_uses_name_from_pipe_format_without_api_call(self) -> None:
|
||||
client = MagicMock()
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="check <#C097NBWMY8Y|eng-infra> for updates",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert message == "check #eng-infra for updates"
|
||||
assert tags == [Tag(tag_key="Channel", tag_value="eng-infra")]
|
||||
# Should NOT have called the API since name was in the markup
|
||||
client.conversations_info.assert_not_called()
|
||||
|
||||
def test_multiple_channels(self) -> None:
|
||||
client = _mock_client_with_channels(
|
||||
{
|
||||
"C111AAA": "eng-infra",
|
||||
"C222BBB": "eng-general",
|
||||
}
|
||||
)
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="compare <#C111AAA> and <#C222BBB>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert "#eng-infra" in message
|
||||
assert "#eng-general" in message
|
||||
assert "<#" not in message
|
||||
assert len(tags) == 2
|
||||
tag_values = {t.tag_value for t in tags}
|
||||
assert tag_values == {"eng-infra", "eng-general"}
|
||||
|
||||
def test_no_channel_references_returns_unchanged(self) -> None:
|
||||
client = MagicMock()
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="just a normal message with no channels",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert message == "just a normal message with no channels"
|
||||
assert tags == []
|
||||
|
||||
def test_api_failure_skips_channel_gracefully(self) -> None:
|
||||
# Client that fails for all channel lookups
|
||||
client = _mock_client_with_channels({})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="check <#CBADID123>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
# Message should remain unchanged for the failed channel
|
||||
assert "<#CBADID123>" in message
|
||||
assert tags == []
|
||||
logger.warning.assert_called_once()
|
||||
|
||||
def test_partial_failure_resolves_what_it_can(self) -> None:
|
||||
# Only one of two channels resolves
|
||||
client = _mock_client_with_channels({"C111AAA": "eng-infra"})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="compare <#C111AAA> and <#CBADID123>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert "#eng-infra" in message
|
||||
assert "<#CBADID123>" in message # failed one stays raw
|
||||
assert len(tags) == 1
|
||||
assert tags[0].tag_value == "eng-infra"
|
||||
|
||||
def test_duplicate_channel_produces_single_tag(self) -> None:
|
||||
client = _mock_client_with_channels({"C111AAA": "eng-infra"})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="summarize <#C111AAA> and compare with <#C111AAA>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert message == "summarize #eng-infra and compare with #eng-infra"
|
||||
assert len(tags) == 1
|
||||
assert tags[0].tag_value == "eng-infra"
|
||||
|
||||
def test_mixed_pipe_and_bare_formats(self) -> None:
|
||||
client = _mock_client_with_channels({"C222BBB": "random"})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="see <#C111AAA|eng-infra> and <#C222BBB>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert "#eng-infra" in message
|
||||
assert "#random" in message
|
||||
assert len(tags) == 2
|
||||
@@ -7,7 +7,6 @@ import timeago # type: ignore
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.onyxbot.slack.blocks import _build_documents_blocks
|
||||
from onyx.onyxbot.slack.blocks import _split_text
|
||||
|
||||
|
||||
def _make_saved_doc(updated_at: datetime | None) -> SavedSearchDoc:
|
||||
@@ -70,78 +69,3 @@ def test_build_documents_blocks_formats_naive_timestamp(
|
||||
formatted_timestamp: datetime = captured["doc"]
|
||||
expected_timestamp: datetime = naive_timestamp.replace(tzinfo=pytz.utc)
|
||||
assert formatted_timestamp == expected_timestamp
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _split_text tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSplitText:
|
||||
def test_short_text_returns_single_chunk(self) -> None:
|
||||
result = _split_text("hello world", limit=100)
|
||||
assert result == ["hello world"]
|
||||
|
||||
def test_splits_at_space_boundary(self) -> None:
|
||||
text = "aaa bbb ccc ddd"
|
||||
result = _split_text(text, limit=8)
|
||||
assert len(result) >= 2
|
||||
|
||||
def test_code_block_not_split_when_fits(self) -> None:
|
||||
text = "before ```code here``` after"
|
||||
result = _split_text(text, limit=100)
|
||||
assert result == [text]
|
||||
|
||||
def test_code_block_split_backs_up_before_fence(self) -> None:
|
||||
# Build text where the split point falls inside a code block,
|
||||
# but the code block itself fits within the limit. The split
|
||||
# should back up to before the opening ``` so the block stays intact.
|
||||
before = "some intro text here " * 5 + "\n" # ~105 chars
|
||||
code_content = "x " * 20 # ~40 chars of code
|
||||
text = f"{before}```\n{code_content}\n```\nafter"
|
||||
# limit=120 means the initial split lands inside the code block
|
||||
# but the code block (~50 chars) fits in the next chunk
|
||||
result = _split_text(text, limit=120)
|
||||
|
||||
assert len(result) >= 2
|
||||
# Every chunk must have balanced code fences (0 or 2)
|
||||
for chunk in result:
|
||||
fence_count = chunk.count("```")
|
||||
assert (
|
||||
fence_count % 2 == 0
|
||||
), f"Unbalanced code fences in chunk: {chunk[:80]}..."
|
||||
# The code block should be fully contained in one chunk
|
||||
code_chunks = [c for c in result if "```" in c]
|
||||
assert len(code_chunks) == 1, "Code block should not be split across chunks"
|
||||
|
||||
def test_no_code_fences_splits_normally(self) -> None:
|
||||
text = "word " * 100 # 500 chars
|
||||
result = _split_text(text, limit=100)
|
||||
assert len(result) >= 5
|
||||
for chunk in result:
|
||||
fence_count = chunk.count("```")
|
||||
assert fence_count == 0
|
||||
|
||||
def test_code_block_exceeding_limit_falls_back_to_close_reopen(self) -> None:
|
||||
# When the code block itself is bigger than the limit, we can't
|
||||
# avoid splitting inside it — verify fences are still balanced.
|
||||
code_content = "x " * 100 # ~200 chars
|
||||
text = f"```\n{code_content}\n```"
|
||||
result = _split_text(text, limit=80)
|
||||
|
||||
assert len(result) >= 2
|
||||
for chunk in result:
|
||||
fence_count = chunk.count("```")
|
||||
assert (
|
||||
fence_count % 2 == 0
|
||||
), f"Unbalanced code fences in chunk: {chunk[:80]}..."
|
||||
|
||||
def test_all_content_preserved_after_split(self) -> None:
|
||||
text = "intro paragraph and more text here\n```\nprint('hello')\n```\nconclusion here"
|
||||
result = _split_text(text, limit=50)
|
||||
|
||||
# Key content should appear somewhere across the chunks
|
||||
joined = " ".join(result)
|
||||
assert "intro" in joined
|
||||
assert "print('hello')" in joined
|
||||
assert "conclusion" in joined
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
import pytest
|
||||
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.voice.api import _validate_voice_api_base
|
||||
|
||||
|
||||
def test_validate_voice_api_base_blocks_private_for_non_azure() -> None:
|
||||
with pytest.raises(OnyxError, match="Invalid target URI"):
|
||||
_validate_voice_api_base("openai", "http://127.0.0.1:11434")
|
||||
|
||||
|
||||
def test_validate_voice_api_base_allows_private_for_azure() -> None:
|
||||
validated = _validate_voice_api_base("azure", "http://127.0.0.1:5000")
|
||||
assert validated == "http://127.0.0.1:5000"
|
||||
|
||||
|
||||
def test_validate_voice_api_base_blocks_metadata_for_azure() -> None:
|
||||
with pytest.raises(OnyxError, match="Invalid target URI"):
|
||||
_validate_voice_api_base("azure", "http://metadata.google.internal/")
|
||||
|
||||
|
||||
def test_validate_voice_api_base_returns_none_for_none() -> None:
|
||||
assert _validate_voice_api_base("openai", None) is None
|
||||
@@ -1,54 +0,0 @@
|
||||
import datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import UserGroupInfo
|
||||
|
||||
|
||||
def _mock_user(
|
||||
personal_name: str | None = "Test User",
|
||||
created_at: datetime.datetime | None = None,
|
||||
updated_at: datetime.datetime | None = None,
|
||||
) -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = uuid4()
|
||||
user.email = "test@example.com"
|
||||
user.role = UserRole.BASIC
|
||||
user.is_active = True
|
||||
user.password_configured = True
|
||||
user.personal_name = personal_name
|
||||
user.created_at = created_at or datetime.datetime(
|
||||
2025, 1, 1, tzinfo=datetime.timezone.utc
|
||||
)
|
||||
user.updated_at = updated_at or datetime.datetime(
|
||||
2025, 6, 15, tzinfo=datetime.timezone.utc
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def test_from_user_model_includes_new_fields() -> None:
|
||||
user = _mock_user(personal_name="Alice")
|
||||
groups = [UserGroupInfo(id=1, name="Engineering")]
|
||||
|
||||
snapshot = FullUserSnapshot.from_user_model(user, groups=groups)
|
||||
|
||||
assert snapshot.personal_name == "Alice"
|
||||
assert snapshot.created_at == user.created_at
|
||||
assert snapshot.updated_at == user.updated_at
|
||||
assert snapshot.groups == groups
|
||||
|
||||
|
||||
def test_from_user_model_defaults_groups_to_empty() -> None:
|
||||
user = _mock_user()
|
||||
snapshot = FullUserSnapshot.from_user_model(user)
|
||||
|
||||
assert snapshot.groups == []
|
||||
|
||||
|
||||
def test_from_user_model_personal_name_none() -> None:
|
||||
user = _mock_user(personal_name=None)
|
||||
snapshot = FullUserSnapshot.from_user_model(user)
|
||||
|
||||
assert snapshot.personal_name is None
|
||||
@@ -14,6 +14,7 @@ from onyx.utils.url import _is_ip_private_or_reserved
|
||||
from onyx.utils.url import _validate_and_resolve_url
|
||||
from onyx.utils.url import ssrf_safe_get
|
||||
from onyx.utils.url import SSRFException
|
||||
from onyx.utils.url import validate_outbound_http_url
|
||||
|
||||
|
||||
class TestIsIpPrivateOrReserved:
|
||||
@@ -305,3 +306,22 @@ class TestSsrfSafeGet:
|
||||
|
||||
call_args = mock_get.call_args
|
||||
assert call_args[1]["timeout"] == (5, 15)
|
||||
|
||||
|
||||
class TestValidateOutboundHttpUrl:
|
||||
def test_rejects_private_ip_by_default(self) -> None:
|
||||
with pytest.raises(SSRFException, match="internal/private IP"):
|
||||
validate_outbound_http_url("http://127.0.0.1:8000")
|
||||
|
||||
def test_allows_private_ip_when_explicitly_enabled(self) -> None:
|
||||
validated_url = validate_outbound_http_url(
|
||||
"http://127.0.0.1:8000", allow_private_network=True
|
||||
)
|
||||
assert validated_url == "http://127.0.0.1:8000"
|
||||
|
||||
def test_blocks_metadata_hostname_when_private_is_enabled(self) -> None:
|
||||
with pytest.raises(SSRFException, match="not allowed"):
|
||||
validate_outbound_http_url(
|
||||
"http://metadata.google.internal/latest",
|
||||
allow_private_network=True,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
import pytest
|
||||
|
||||
from onyx.voice.providers.azure import AzureVoiceProvider
|
||||
|
||||
|
||||
def test_azure_provider_extracts_region_from_target_uri() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key",
|
||||
api_base="https://westus.api.cognitive.microsoft.com/",
|
||||
custom_config={},
|
||||
)
|
||||
assert provider.speech_region == "westus"
|
||||
|
||||
|
||||
def test_azure_provider_normalizes_uppercase_region() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key",
|
||||
api_base=None,
|
||||
custom_config={"speech_region": "WestUS2"},
|
||||
)
|
||||
assert provider.speech_region == "westus2"
|
||||
|
||||
|
||||
def test_azure_provider_rejects_invalid_speech_region() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid Azure speech_region"):
|
||||
AzureVoiceProvider(
|
||||
api_key="key",
|
||||
api_base=None,
|
||||
custom_config={"speech_region": "westus/../../etc"},
|
||||
)
|
||||
194
backend/tests/unit/onyx/voice/providers/test_azure_ssml.py
Normal file
194
backend/tests/unit/onyx/voice/providers/test_azure_ssml.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import io
|
||||
import struct
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.voice.providers.azure import AzureVoiceProvider
|
||||
|
||||
|
||||
# --- _is_azure_cloud_url ---
|
||||
|
||||
|
||||
def test_is_azure_cloud_url_speech_microsoft() -> None:
|
||||
assert AzureVoiceProvider._is_azure_cloud_url(
|
||||
"https://eastus.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
)
|
||||
|
||||
|
||||
def test_is_azure_cloud_url_cognitive_microsoft() -> None:
|
||||
assert AzureVoiceProvider._is_azure_cloud_url(
|
||||
"https://westus.api.cognitive.microsoft.com/"
|
||||
)
|
||||
|
||||
|
||||
def test_is_azure_cloud_url_rejects_custom_host() -> None:
|
||||
assert not AzureVoiceProvider._is_azure_cloud_url("https://my-custom-host.com/")
|
||||
|
||||
|
||||
def test_is_azure_cloud_url_rejects_none() -> None:
|
||||
assert not AzureVoiceProvider._is_azure_cloud_url(None)
|
||||
|
||||
|
||||
# --- _extract_speech_region_from_uri ---
|
||||
|
||||
|
||||
def test_extract_region_from_tts_url() -> None:
|
||||
assert (
|
||||
AzureVoiceProvider._extract_speech_region_from_uri(
|
||||
"https://eastus.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
)
|
||||
== "eastus"
|
||||
)
|
||||
|
||||
|
||||
def test_extract_region_from_cognitive_api_url() -> None:
|
||||
assert (
|
||||
AzureVoiceProvider._extract_speech_region_from_uri(
|
||||
"https://eastus.api.cognitive.microsoft.com/"
|
||||
)
|
||||
== "eastus"
|
||||
)
|
||||
|
||||
|
||||
def test_extract_region_returns_none_for_custom_domain() -> None:
|
||||
"""Custom domains use resource name, not region — must use speech_region config."""
|
||||
assert (
|
||||
AzureVoiceProvider._extract_speech_region_from_uri(
|
||||
"https://myresource.cognitiveservices.azure.com/"
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_extract_region_returns_none_for_none() -> None:
|
||||
assert AzureVoiceProvider._extract_speech_region_from_uri(None) is None
|
||||
|
||||
|
||||
# --- _validate_speech_region ---
|
||||
|
||||
|
||||
def test_validate_region_normalizes_to_lowercase() -> None:
|
||||
assert AzureVoiceProvider._validate_speech_region("WestUS2") == "westus2"
|
||||
|
||||
|
||||
def test_validate_region_accepts_hyphens() -> None:
|
||||
assert AzureVoiceProvider._validate_speech_region("us-east-1") == "us-east-1"
|
||||
|
||||
|
||||
def test_validate_region_rejects_path_traversal() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid Azure speech_region"):
|
||||
AzureVoiceProvider._validate_speech_region("westus/../../etc")
|
||||
|
||||
|
||||
def test_validate_region_rejects_dots() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid Azure speech_region"):
|
||||
AzureVoiceProvider._validate_speech_region("west.us")
|
||||
|
||||
|
||||
# --- _pcm16_to_wav ---
|
||||
|
||||
|
||||
def test_pcm16_to_wav_produces_valid_wav() -> None:
|
||||
samples = [32767, -32768, 0, 1234]
|
||||
pcm_data = struct.pack(f"<{len(samples)}h", *samples)
|
||||
wav_bytes = AzureVoiceProvider._pcm16_to_wav(pcm_data, sample_rate=16000)
|
||||
|
||||
with wave.open(io.BytesIO(wav_bytes), "rb") as wav_file:
|
||||
assert wav_file.getnchannels() == 1
|
||||
assert wav_file.getsampwidth() == 2
|
||||
assert wav_file.getframerate() == 16000
|
||||
frames = wav_file.readframes(4)
|
||||
recovered = struct.unpack(f"<{len(samples)}h", frames)
|
||||
assert list(recovered) == samples
|
||||
|
||||
|
||||
# --- URL Construction ---
|
||||
|
||||
|
||||
def test_get_tts_url_cloud() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key", api_base=None, custom_config={"speech_region": "eastus"}
|
||||
)
|
||||
assert (
|
||||
provider._get_tts_url()
|
||||
== "https://eastus.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
)
|
||||
|
||||
|
||||
def test_get_stt_url_cloud() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key", api_base=None, custom_config={"speech_region": "westus2"}
|
||||
)
|
||||
assert "westus2.stt.speech.microsoft.com" in provider._get_stt_url()
|
||||
|
||||
|
||||
def test_get_tts_url_self_hosted() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key", api_base="http://localhost:5000", custom_config={}
|
||||
)
|
||||
assert provider._get_tts_url() == "http://localhost:5000/cognitiveservices/v1"
|
||||
|
||||
|
||||
def test_get_tts_url_self_hosted_strips_trailing_slash() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key", api_base="http://localhost:5000/", custom_config={}
|
||||
)
|
||||
assert provider._get_tts_url() == "http://localhost:5000/cognitiveservices/v1"
|
||||
|
||||
|
||||
# --- _is_self_hosted ---
|
||||
|
||||
|
||||
def test_is_self_hosted_true_for_custom_endpoint() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key", api_base="http://localhost:5000", custom_config={}
|
||||
)
|
||||
assert provider._is_self_hosted() is True
|
||||
|
||||
|
||||
def test_is_self_hosted_false_for_azure_cloud() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key",
|
||||
api_base="https://eastus.api.cognitive.microsoft.com/",
|
||||
custom_config={},
|
||||
)
|
||||
assert provider._is_self_hosted() is False
|
||||
|
||||
|
||||
# --- Resampling ---
|
||||
|
||||
|
||||
def test_resample_pcm16_passthrough() -> None:
|
||||
from onyx.voice.providers.azure import AzureStreamingTranscriber
|
||||
|
||||
t = AzureStreamingTranscriber.__new__(AzureStreamingTranscriber)
|
||||
t.input_sample_rate = 16000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
data = struct.pack("<4h", 100, 200, 300, 400)
|
||||
assert t._resample_pcm16(data) == data
|
||||
|
||||
|
||||
def test_resample_pcm16_downsamples() -> None:
|
||||
from onyx.voice.providers.azure import AzureStreamingTranscriber
|
||||
|
||||
t = AzureStreamingTranscriber.__new__(AzureStreamingTranscriber)
|
||||
t.input_sample_rate = 24000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
input_samples = [1000, 2000, 3000, 4000, 5000, 6000]
|
||||
data = struct.pack(f"<{len(input_samples)}h", *input_samples)
|
||||
|
||||
result = t._resample_pcm16(data)
|
||||
assert len(result) // 2 == 4
|
||||
|
||||
|
||||
def test_resample_pcm16_empty_data() -> None:
|
||||
from onyx.voice.providers.azure import AzureStreamingTranscriber
|
||||
|
||||
t = AzureStreamingTranscriber.__new__(AzureStreamingTranscriber)
|
||||
t.input_sample_rate = 24000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
assert t._resample_pcm16(b"") == b""
|
||||
@@ -0,0 +1,117 @@
|
||||
import struct
|
||||
|
||||
from onyx.voice.providers.elevenlabs import _http_to_ws_url
|
||||
from onyx.voice.providers.elevenlabs import DEFAULT_ELEVENLABS_API_BASE
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsSTTMessageType
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsVoiceProvider
|
||||
|
||||
|
||||
# --- _http_to_ws_url ---
|
||||
|
||||
|
||||
def test_http_to_ws_url_converts_https_to_wss() -> None:
|
||||
assert _http_to_ws_url("https://api.elevenlabs.io") == "wss://api.elevenlabs.io"
|
||||
|
||||
|
||||
def test_http_to_ws_url_converts_http_to_ws() -> None:
|
||||
assert _http_to_ws_url("http://localhost:8080") == "ws://localhost:8080"
|
||||
|
||||
|
||||
def test_http_to_ws_url_passes_through_other_schemes() -> None:
|
||||
assert _http_to_ws_url("wss://already.ws") == "wss://already.ws"
|
||||
|
||||
|
||||
def test_http_to_ws_url_preserves_path() -> None:
|
||||
assert (
|
||||
_http_to_ws_url("https://api.elevenlabs.io/v1/tts")
|
||||
== "wss://api.elevenlabs.io/v1/tts"
|
||||
)
|
||||
|
||||
|
||||
# --- StrEnum comparison ---
|
||||
|
||||
|
||||
def test_stt_message_type_compares_as_string() -> None:
|
||||
"""StrEnum members should work in string comparisons (e.g. from JSON)."""
|
||||
assert str(ElevenLabsSTTMessageType.COMMITTED_TRANSCRIPT) == "committed_transcript"
|
||||
assert isinstance(ElevenLabsSTTMessageType.ERROR, str)
|
||||
|
||||
|
||||
# --- Resampling ---
|
||||
|
||||
|
||||
def test_resample_pcm16_passthrough_when_same_rate() -> None:
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsStreamingTranscriber
|
||||
|
||||
t = ElevenLabsStreamingTranscriber.__new__(ElevenLabsStreamingTranscriber)
|
||||
t.input_sample_rate = 16000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
data = struct.pack("<4h", 100, 200, 300, 400)
|
||||
assert t._resample_pcm16(data) == data
|
||||
|
||||
|
||||
def test_resample_pcm16_downsamples() -> None:
|
||||
"""24kHz -> 16kHz should produce fewer samples (ratio 3:2)."""
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsStreamingTranscriber
|
||||
|
||||
t = ElevenLabsStreamingTranscriber.__new__(ElevenLabsStreamingTranscriber)
|
||||
t.input_sample_rate = 24000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
input_samples = [1000, 2000, 3000, 4000, 5000, 6000]
|
||||
data = struct.pack(f"<{len(input_samples)}h", *input_samples)
|
||||
|
||||
result = t._resample_pcm16(data)
|
||||
output_samples = struct.unpack(f"<{len(result) // 2}h", result)
|
||||
|
||||
assert len(output_samples) == 4
|
||||
|
||||
|
||||
def test_resample_pcm16_clamps_to_int16_range() -> None:
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsStreamingTranscriber
|
||||
|
||||
t = ElevenLabsStreamingTranscriber.__new__(ElevenLabsStreamingTranscriber)
|
||||
t.input_sample_rate = 24000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
input_samples = [32767, -32768, 32767, -32768, 32767, -32768]
|
||||
data = struct.pack(f"<{len(input_samples)}h", *input_samples)
|
||||
|
||||
result = t._resample_pcm16(data)
|
||||
output_samples = struct.unpack(f"<{len(result) // 2}h", result)
|
||||
for s in output_samples:
|
||||
assert -32768 <= s <= 32767
|
||||
|
||||
|
||||
# --- Provider Model Defaulting ---
|
||||
|
||||
|
||||
def test_provider_defaults_invalid_stt_model() -> None:
|
||||
provider = ElevenLabsVoiceProvider(api_key="test", stt_model="invalid_model")
|
||||
assert provider.stt_model == "scribe_v1"
|
||||
|
||||
|
||||
def test_provider_defaults_invalid_tts_model() -> None:
|
||||
provider = ElevenLabsVoiceProvider(api_key="test", tts_model="invalid_model")
|
||||
assert provider.tts_model == "eleven_multilingual_v2"
|
||||
|
||||
|
||||
def test_provider_accepts_valid_models() -> None:
|
||||
provider = ElevenLabsVoiceProvider(
|
||||
api_key="test", stt_model="scribe_v2_realtime", tts_model="eleven_turbo_v2_5"
|
||||
)
|
||||
assert provider.stt_model == "scribe_v2_realtime"
|
||||
assert provider.tts_model == "eleven_turbo_v2_5"
|
||||
|
||||
|
||||
def test_provider_defaults_api_base() -> None:
|
||||
provider = ElevenLabsVoiceProvider(api_key="test")
|
||||
assert provider.api_base == DEFAULT_ELEVENLABS_API_BASE
|
||||
|
||||
|
||||
def test_provider_get_available_voices_returns_copy() -> None:
|
||||
provider = ElevenLabsVoiceProvider(api_key="test")
|
||||
voices = provider.get_available_voices()
|
||||
voices.clear()
|
||||
assert len(provider.get_available_voices()) > 0
|
||||
@@ -0,0 +1,97 @@
|
||||
import io
|
||||
import struct
|
||||
import wave
|
||||
|
||||
from onyx.voice.providers.openai import _create_wav_header
|
||||
from onyx.voice.providers.openai import _http_to_ws_url
|
||||
from onyx.voice.providers.openai import OpenAIRealtimeMessageType
|
||||
from onyx.voice.providers.openai import OpenAIVoiceProvider
|
||||
|
||||
|
||||
# --- _http_to_ws_url ---
|
||||
|
||||
|
||||
def test_http_to_ws_url_converts_https_to_wss() -> None:
|
||||
assert _http_to_ws_url("https://api.openai.com") == "wss://api.openai.com"
|
||||
|
||||
|
||||
def test_http_to_ws_url_converts_http_to_ws() -> None:
|
||||
assert _http_to_ws_url("http://localhost:9090") == "ws://localhost:9090"
|
||||
|
||||
|
||||
def test_http_to_ws_url_passes_through_ws() -> None:
|
||||
assert _http_to_ws_url("wss://already.ws") == "wss://already.ws"
|
||||
|
||||
|
||||
# --- StrEnum comparison ---
|
||||
|
||||
|
||||
def test_realtime_message_type_compares_as_string() -> None:
|
||||
assert str(OpenAIRealtimeMessageType.ERROR) == "error"
|
||||
assert (
|
||||
str(OpenAIRealtimeMessageType.TRANSCRIPTION_DELTA)
|
||||
== "conversation.item.input_audio_transcription.delta"
|
||||
)
|
||||
assert isinstance(OpenAIRealtimeMessageType.ERROR, str)
|
||||
|
||||
|
||||
# --- _create_wav_header ---
|
||||
|
||||
|
||||
def test_wav_header_is_44_bytes() -> None:
|
||||
assert len(_create_wav_header(1000)) == 44
|
||||
|
||||
|
||||
def test_wav_header_chunk_size_matches_data_length() -> None:
|
||||
data_length = 2000
|
||||
header = _create_wav_header(data_length)
|
||||
chunk_size = struct.unpack_from("<I", header, 4)[0]
|
||||
assert chunk_size == 36 + data_length
|
||||
|
||||
|
||||
def test_wav_header_byte_rate() -> None:
|
||||
header = _create_wav_header(100, sample_rate=24000, channels=1, bits_per_sample=16)
|
||||
byte_rate = struct.unpack_from("<I", header, 28)[0]
|
||||
assert byte_rate == 24000 * 1 * 16 // 8
|
||||
|
||||
|
||||
def test_wav_header_produces_valid_wav() -> None:
|
||||
"""Header + PCM data should parse as valid WAV."""
|
||||
data_length = 100
|
||||
pcm_data = b"\x00" * data_length
|
||||
header = _create_wav_header(data_length, sample_rate=24000)
|
||||
|
||||
with wave.open(io.BytesIO(header + pcm_data), "rb") as wav_file:
|
||||
assert wav_file.getnchannels() == 1
|
||||
assert wav_file.getsampwidth() == 2
|
||||
assert wav_file.getframerate() == 24000
|
||||
assert wav_file.getnframes() == data_length // 2
|
||||
|
||||
|
||||
# --- Provider Defaults ---
|
||||
|
||||
|
||||
def test_provider_default_models() -> None:
|
||||
provider = OpenAIVoiceProvider(api_key="test")
|
||||
assert provider.stt_model == "whisper-1"
|
||||
assert provider.tts_model == "tts-1"
|
||||
assert provider.default_voice == "alloy"
|
||||
|
||||
|
||||
def test_provider_custom_models() -> None:
|
||||
provider = OpenAIVoiceProvider(
|
||||
api_key="test",
|
||||
stt_model="gpt-4o-transcribe",
|
||||
tts_model="tts-1-hd",
|
||||
default_voice="nova",
|
||||
)
|
||||
assert provider.stt_model == "gpt-4o-transcribe"
|
||||
assert provider.tts_model == "tts-1-hd"
|
||||
assert provider.default_voice == "nova"
|
||||
|
||||
|
||||
def test_provider_get_available_voices_returns_copy() -> None:
|
||||
provider = OpenAIVoiceProvider(api_key="test")
|
||||
voices = provider.get_available_voices()
|
||||
voices.clear()
|
||||
assert len(provider.get_available_voices()) > 0
|
||||
@@ -35,6 +35,7 @@ backend = [
|
||||
"alembic==1.10.4",
|
||||
"asyncpg==0.30.0",
|
||||
"atlassian-python-api==3.41.16",
|
||||
"azure-cognitiveservices-speech==1.38.0",
|
||||
"beautifulsoup4==4.12.3",
|
||||
"boto3==1.39.11",
|
||||
"boto3-stubs[s3]==1.39.11",
|
||||
|
||||
39
uv.lock
generated
39
uv.lock
generated
@@ -463,6 +463,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/00/3ed12264094ec91f534fae429945efbaa9f8c666f3aa7061cc3b2a26a0cd/authlib-1.6.7-py2.py3-none-any.whl", hash = "sha256:c637340d9a02789d2efa1d003a7437d10d3e565237bcb5fcbc6c134c7b95bab0", size = 244115, upload-time = "2026-02-06T14:04:12.141Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "azure-cognitiveservices-speech"
|
||||
version = "1.38.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/85/f4/4571c42cb00f8af317d5431f594b4ece1fbe59ab59f106947fea8e90cf89/azure_cognitiveservices_speech-1.38.0-py3-none-macosx_10_14_x86_64.whl", hash = "sha256:18dce915ab032711f687abb3297dd19176b9cbea562b322ee6fa7365ef4a5091", size = 6775838, upload-time = "2024-06-11T03:08:35.202Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/86/22/0ca2c59a573119950cad1f53531fec9872fc38810c405a4e1827f3d13a8e/azure_cognitiveservices_speech-1.38.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:9dd0800fbc4a8438c6dfd5747a658251914fe2d205a29e9b46158cadac6ab381", size = 6687975, upload-time = "2024-06-11T03:08:38.797Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/96/5436c09de3af3a9aefaa8cc00533c3a0f5d17aef5bbe017c17f0a30ad66e/azure_cognitiveservices_speech-1.38.0-py3-none-manylinux1_x86_64.whl", hash = "sha256:1c344e8a6faadb063cea451f0301e13b44d9724e1242337039bff601e81e6f86", size = 40022287, upload-time = "2024-06-11T03:08:16.777Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a9/2d/ba20d05ff77ec9870cd489e6e7a474ba7fe820524bcf6fd202025e0c11cf/azure_cognitiveservices_speech-1.38.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1e002595a749471efeac3a54c80097946570b76c13049760b97a4b881d9d24af", size = 39788653, upload-time = "2024-06-11T03:08:30.405Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0c/21/25f8c37fb6868db4346ca977c287ede9e87f609885d932653243c9ed5f63/azure_cognitiveservices_speech-1.38.0-py3-none-win32.whl", hash = "sha256:16a530e6c646eb49ea0bc05cb45a9d28b99e4b67613f6c3a6c54e26e6bf65241", size = 1428364, upload-time = "2024-06-11T03:08:03.965Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/05/a6414a3481c5ee30c4f32742abe055e5f3ce4ff69e936089d86ece354067/azure_cognitiveservices_speech-1.38.0-py3-none-win_amd64.whl", hash = "sha256:1d38d8c056fb3f513a9ff27ab4e77fd08ca487f8788cc7a6df772c1ab2c97b54", size = 1539297, upload-time = "2024-06-11T03:08:01.304Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "babel"
|
||||
version = "2.17.0"
|
||||
@@ -4227,6 +4240,7 @@ backend = [
|
||||
{ name = "asana" },
|
||||
{ name = "asyncpg" },
|
||||
{ name = "atlassian-python-api" },
|
||||
{ name = "azure-cognitiveservices-speech" },
|
||||
{ name = "beautifulsoup4" },
|
||||
{ name = "boto3" },
|
||||
{ name = "boto3-stubs", extra = ["s3"] },
|
||||
@@ -4381,6 +4395,7 @@ requires-dist = [
|
||||
{ name = "asana", marker = "extra == 'backend'", specifier = "==5.0.8" },
|
||||
{ name = "asyncpg", marker = "extra == 'backend'", specifier = "==0.30.0" },
|
||||
{ name = "atlassian-python-api", marker = "extra == 'backend'", specifier = "==3.41.16" },
|
||||
{ name = "azure-cognitiveservices-speech", marker = "extra == 'backend'", specifier = "==1.38.0" },
|
||||
{ name = "beautifulsoup4", marker = "extra == 'backend'", specifier = "==4.12.3" },
|
||||
{ name = "black", marker = "extra == 'dev'", specifier = "==25.1.0" },
|
||||
{ name = "boto3", marker = "extra == 'backend'", specifier = "==1.39.11" },
|
||||
@@ -7233,19 +7248,21 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "tornado"
|
||||
version = "6.5.5"
|
||||
version = "6.5.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f8/f1/3173dfa4a18db4a9b03e5d55325559dab51ee653763bb8745a75af491286/tornado-6.5.5.tar.gz", hash = "sha256:192b8f3ea91bd7f1f50c06955416ed76c6b72f96779b962f07f911b91e8d30e9", size = 516006, upload-time = "2026-03-10T21:31:02.067Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/09/ce/1eb500eae19f4648281bb2186927bb062d2438c2e5093d1360391afd2f90/tornado-6.5.2.tar.gz", hash = "sha256:ab53c8f9a0fa351e2c0741284e06c7a45da86afb544133201c5cc8578eb076a0", size = 510821, upload-time = "2025-08-08T18:27:00.78Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/59/8c/77f5097695f4dd8255ecbd08b2a1ed8ba8b953d337804dd7080f199e12bf/tornado-6.5.5-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:487dc9cc380e29f58c7ab88f9e27cdeef04b2140862e5076a66fb6bb68bb1bfa", size = 445983, upload-time = "2026-03-10T21:30:44.28Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/5e/7625b76cd10f98f1516c36ce0346de62061156352353ef2da44e5c21523c/tornado-6.5.5-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:65a7f1d46d4bb41df1ac99f5fcb685fb25c7e61613742d5108b010975a9a6521", size = 444246, upload-time = "2026-03-10T21:30:46.571Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/04/7b5705d5b3c0fab088f434f9c83edac1573830ca49ccf29fb83bf7178eec/tornado-6.5.5-cp39-abi3-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:e74c92e8e65086b338fd56333fb9a68b9f6f2fe7ad532645a290a464bcf46be5", size = 447229, upload-time = "2026-03-10T21:30:48.273Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/34/01/74e034a30ef59afb4097ef8659515e96a39d910b712a89af76f5e4e1f93c/tornado-6.5.5-cp39-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:435319e9e340276428bbdb4e7fa732c2d399386d1de5686cb331ec8eee754f07", size = 448192, upload-time = "2026-03-10T21:30:51.22Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/be/00/fe9e02c5a96429fce1a1d15a517f5d8444f9c412e0bb9eadfbe3b0fc55bf/tornado-6.5.5-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:3f54aa540bdbfee7b9eb268ead60e7d199de5021facd276819c193c0fb28ea4e", size = 448039, upload-time = "2026-03-10T21:30:53.52Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/9e/656ee4cec0398b1d18d0f1eb6372c41c6b889722641d84948351ae19556d/tornado-6.5.5-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:36abed1754faeb80fbd6e64db2758091e1320f6bba74a4cf8c09cd18ccce8aca", size = 447445, upload-time = "2026-03-10T21:30:55.541Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5a/76/4921c00511f88af86a33de770d64141170f1cfd9c00311aea689949e274e/tornado-6.5.5-cp39-abi3-win32.whl", hash = "sha256:dd3eafaaeec1c7f2f8fdcd5f964e8907ad788fe8a5a32c4426fbbdda621223b7", size = 448582, upload-time = "2026-03-10T21:30:57.142Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/23/f6c6112a04d28eed765e374435fb1a9198f73e1ec4b4024184f21faeb1ad/tornado-6.5.5-cp39-abi3-win_amd64.whl", hash = "sha256:6443a794ba961a9f619b1ae926a2e900ac20c34483eea67be4ed8f1e58d3ef7b", size = 448990, upload-time = "2026-03-10T21:30:58.857Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/c8/876602cbc96469911f0939f703453c1157b0c826ecb05bdd32e023397d4e/tornado-6.5.5-cp39-abi3-win_arm64.whl", hash = "sha256:2c9a876e094109333f888539ddb2de4361743e5d21eece20688e3e351e4990a6", size = 448016, upload-time = "2026-03-10T21:31:00.43Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/48/6a7529df2c9cc12efd2e8f5dd219516184d703b34c06786809670df5b3bd/tornado-6.5.2-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:2436822940d37cde62771cff8774f4f00b3c8024fe482e16ca8387b8a2724db6", size = 442563, upload-time = "2025-08-08T18:26:42.945Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f2/b5/9b575a0ed3e50b00c40b08cbce82eb618229091d09f6d14bce80fc01cb0b/tornado-6.5.2-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:583a52c7aa94ee046854ba81d9ebb6c81ec0fd30386d96f7640c96dad45a03ef", size = 440729, upload-time = "2025-08-08T18:26:44.473Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1b/4e/619174f52b120efcf23633c817fd3fed867c30bff785e2cd5a53a70e483c/tornado-6.5.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0fe179f28d597deab2842b86ed4060deec7388f1fd9c1b4a41adf8af058907e", size = 444295, upload-time = "2025-08-08T18:26:46.021Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/fa/87b41709552bbd393c85dd18e4e3499dcd8983f66e7972926db8d96aa065/tornado-6.5.2-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b186e85d1e3536d69583d2298423744740986018e393d0321df7340e71898882", size = 443644, upload-time = "2025-08-08T18:26:47.625Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/41/fb15f06e33d7430ca89420283a8762a4e6b8025b800ea51796ab5e6d9559/tornado-6.5.2-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e792706668c87709709c18b353da1f7662317b563ff69f00bab83595940c7108", size = 443878, upload-time = "2025-08-08T18:26:50.599Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/11/92/fe6d57da897776ad2e01e279170ea8ae726755b045fe5ac73b75357a5a3f/tornado-6.5.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:06ceb1300fd70cb20e43b1ad8aaee0266e69e7ced38fa910ad2e03285009ce7c", size = 444549, upload-time = "2025-08-08T18:26:51.864Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9b/02/c8f4f6c9204526daf3d760f4aa555a7a33ad0e60843eac025ccfd6ff4a93/tornado-6.5.2-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:74db443e0f5251be86cbf37929f84d8c20c27a355dd452a5cfa2aada0d001ec4", size = 443973, upload-time = "2025-08-08T18:26:53.625Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ae/2d/f5f5707b655ce2317190183868cd0f6822a1121b4baeae509ceb9590d0bd/tornado-6.5.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b5e735ab2889d7ed33b32a459cac490eda71a1ba6857b0118de476ab6c366c04", size = 443954, upload-time = "2025-08-08T18:26:55.072Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e8/59/593bd0f40f7355806bf6573b47b8c22f8e1374c9b6fd03114bd6b7a3dcfd/tornado-6.5.2-cp39-abi3-win32.whl", hash = "sha256:c6f29e94d9b37a95013bb669616352ddb82e3bfe8326fccee50583caebc8a5f0", size = 445023, upload-time = "2025-08-08T18:26:56.677Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/2a/f609b420c2f564a748a2d80ebfb2ee02a73ca80223af712fca591386cafb/tornado-6.5.2-cp39-abi3-win_amd64.whl", hash = "sha256:e56a5af51cc30dd2cae649429af65ca2f6571da29504a07995175df14c18f35f", size = 445427, upload-time = "2025-08-08T18:26:57.91Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5e/4f/e1f65e8f8c76d73658b33d33b81eed4322fb5085350e4328d5c956f0c8f9/tornado-6.5.2-cp39-abi3-win_arm64.whl", hash = "sha256:d6c33dc3672e3a1f3618eb63b7ef4683a7688e7b9e6e8f0d9aa5726360a004af", size = 444456, upload-time = "2025-08-08T18:26:59.207Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
20
web/lib/opal/src/icons/audio.tsx
Normal file
20
web/lib/opal/src/icons/audio.tsx
Normal file
@@ -0,0 +1,20 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgAudio = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M2 10V6M5 14V2M11 11V5M14 9V7M8 10V6"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgAudio;
|
||||
@@ -17,6 +17,7 @@ export { default as SvgArrowUpDown } from "@opal/icons/arrow-up-down";
|
||||
export { default as SvgArrowUpDot } from "@opal/icons/arrow-up-dot";
|
||||
export { default as SvgArrowUpRight } from "@opal/icons/arrow-up-right";
|
||||
export { default as SvgArrowWallRight } from "@opal/icons/arrow-wall-right";
|
||||
export { default as SvgAudio } from "@opal/icons/audio";
|
||||
export { default as SvgAudioEqSmall } from "@opal/icons/audio-eq-small";
|
||||
export { default as SvgAws } from "@opal/icons/aws";
|
||||
export { default as SvgAzure } from "@opal/icons/azure";
|
||||
@@ -106,6 +107,8 @@ export { default as SvgLogOut } from "@opal/icons/log-out";
|
||||
export { default as SvgMaximize2 } from "@opal/icons/maximize-2";
|
||||
export { default as SvgMcp } from "@opal/icons/mcp";
|
||||
export { default as SvgMenu } from "@opal/icons/menu";
|
||||
export { default as SvgMicrophone } from "@opal/icons/microphone";
|
||||
export { default as SvgMicrophoneOff } from "@opal/icons/microphone-off";
|
||||
export { default as SvgMinus } from "@opal/icons/minus";
|
||||
export { default as SvgMinusCircle } from "@opal/icons/minus-circle";
|
||||
export { default as SvgMoon } from "@opal/icons/moon";
|
||||
@@ -176,6 +179,8 @@ export { default as SvgUserManage } from "@opal/icons/user-manage";
|
||||
export { default as SvgUserPlus } from "@opal/icons/user-plus";
|
||||
export { default as SvgUserSync } from "@opal/icons/user-sync";
|
||||
export { default as SvgUsers } from "@opal/icons/users";
|
||||
export { default as SvgVolume } from "@opal/icons/volume";
|
||||
export { default as SvgVolumeOff } from "@opal/icons/volume-off";
|
||||
export { default as SvgWallet } from "@opal/icons/wallet";
|
||||
export { default as SvgWorkflow } from "@opal/icons/workflow";
|
||||
export { default as SvgX } from "@opal/icons/x";
|
||||
|
||||
29
web/lib/opal/src/icons/microphone-off.tsx
Normal file
29
web/lib/opal/src/icons/microphone-off.tsx
Normal file
@@ -0,0 +1,29 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgMicrophoneOff = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
{/* Microphone body */}
|
||||
<path
|
||||
d="M12.5 7V7.5C12.5 9.98528 10.4853 12 8 12M3.5 7V7.5C3.5 9.98528 5.51472 12 8 12M8 12V14.5M8 14.5H5M8 14.5H11M8 9.5C6.89543 9.5 6 8.60457 6 7.5V3.5C6 2.39543 6.89543 1.5 8 1.5C9.10457 1.5 10 2.39543 10 3.5V7.5C10 8.60457 9.10457 9.5 8 9.5Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
{/* Diagonal slash */}
|
||||
<path
|
||||
d="M2 2L14 14"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgMicrophoneOff;
|
||||
21
web/lib/opal/src/icons/microphone.tsx
Normal file
21
web/lib/opal/src/icons/microphone.tsx
Normal file
@@ -0,0 +1,21 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgMicrophone = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M12.5 7V7.5C12.5 9.98528 10.4853 12 8 12M3.5 7V7.5C3.5 9.98528 5.51472 12 8 12M8 12V14.5M8 14.5H5M8 14.5H11M8 9.5C6.89543 9.5 6 8.60457 6 7.5V3.5C6 2.39543 6.89543 1.5 8 1.5C9.10457 1.5 10 2.39543 10 3.5V7.5C10 8.60457 9.10457 9.5 8 9.5Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgMicrophone;
|
||||
26
web/lib/opal/src/icons/volume-off.tsx
Normal file
26
web/lib/opal/src/icons/volume-off.tsx
Normal file
@@ -0,0 +1,26 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgVolumeOff = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M2 6V10H5L9 13V3L5 6H2Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M14 6L11 9M11 6L14 9"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgVolumeOff;
|
||||
26
web/lib/opal/src/icons/volume.tsx
Normal file
26
web/lib/opal/src/icons/volume.tsx
Normal file
@@ -0,0 +1,26 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgVolume = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M2 6V10H5L9 13V3L5 6H2Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M11.5 5.5C12.3 6.3 12.8 7.4 12.8 8.5C12.8 9.6 12.3 10.7 11.5 11.5"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgVolume;
|
||||
@@ -59,7 +59,7 @@ const nextConfig = {
|
||||
{
|
||||
key: "Permissions-Policy",
|
||||
value:
|
||||
"accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), cross-origin-isolated=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), navigation-override=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()",
|
||||
"accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), cross-origin-isolated=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(self), midi=(), navigation-override=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()",
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
4
web/public/ElevenLabs.svg
Normal file
4
web/public/ElevenLabs.svg
Normal file
@@ -0,0 +1,4 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M10.5 2H13V14H10.5V2Z" fill="currentColor"/>
|
||||
<path d="M3 2H5.5V14H3V2Z" fill="currentColor"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 206 B |
4
web/public/ElevenLabsDark.svg
Normal file
4
web/public/ElevenLabsDark.svg
Normal file
@@ -0,0 +1,4 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M10.5 2H13V14H10.5V2Z" fill="white"/>
|
||||
<path d="M3 2H5.5V14H3V2Z" fill="white"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 192 B |
@@ -0,0 +1,507 @@
|
||||
"use client";
|
||||
|
||||
import Image from "next/image";
|
||||
import { FunctionComponent, useState, useEffect } from "react";
|
||||
import {
|
||||
AzureIcon,
|
||||
ElevenLabsIcon,
|
||||
OpenAIIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
|
||||
import { Vertical, Horizontal } from "@/layouts/input-layouts";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { SvgArrowExchange, SvgOnyxLogo } from "@opal/icons";
|
||||
import type { IconProps } from "@opal/types";
|
||||
import { VoiceProviderView } from "@/hooks/useVoiceProviders";
|
||||
import {
|
||||
testVoiceProvider,
|
||||
upsertVoiceProvider,
|
||||
fetchVoicesByType,
|
||||
fetchLLMProviders,
|
||||
} from "@/lib/admin/voice/svc";
|
||||
|
||||
interface VoiceOption {
|
||||
value: string;
|
||||
label: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
interface LLMProviderView {
|
||||
id: number;
|
||||
name: string;
|
||||
provider: string;
|
||||
api_key: string | null;
|
||||
}
|
||||
|
||||
interface ApiKeyOption {
|
||||
value: string;
|
||||
label: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
interface VoiceProviderSetupModalProps {
|
||||
providerType: string;
|
||||
existingProvider: VoiceProviderView | null;
|
||||
mode: "stt" | "tts";
|
||||
defaultModelId?: string | null;
|
||||
onClose: () => void;
|
||||
onSuccess: () => void;
|
||||
}
|
||||
|
||||
const PROVIDER_LABELS: Record<string, string> = {
|
||||
openai: "OpenAI",
|
||||
azure: "Azure Speech Services",
|
||||
elevenlabs: "ElevenLabs",
|
||||
};
|
||||
|
||||
const PROVIDER_API_KEY_URLS: Record<string, string> = {
|
||||
openai: "https://platform.openai.com/api-keys",
|
||||
azure: "https://portal.azure.com/",
|
||||
elevenlabs: "https://elevenlabs.io/app/settings/api-keys",
|
||||
};
|
||||
|
||||
const PROVIDER_LOGO_URLS: Record<string, string> = {
|
||||
openai: "/Openai.svg",
|
||||
azure: "/Azure.png",
|
||||
elevenlabs: "/ElevenLabs.svg",
|
||||
};
|
||||
|
||||
const PROVIDER_DOCS_URLS: Record<string, string> = {
|
||||
openai: "https://platform.openai.com/docs/guides/text-to-speech",
|
||||
azure: "https://learn.microsoft.com/en-us/azure/ai-services/speech-service/",
|
||||
elevenlabs: "https://elevenlabs.io/docs",
|
||||
};
|
||||
|
||||
const PROVIDER_VOICE_DOCS_URLS: Record<string, { url: string; label: string }> =
|
||||
{
|
||||
openai: {
|
||||
url: "https://platform.openai.com/docs/guides/text-to-speech#voice-options",
|
||||
label: "OpenAI",
|
||||
},
|
||||
azure: {
|
||||
url: "https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts",
|
||||
label: "Azure",
|
||||
},
|
||||
elevenlabs: {
|
||||
url: "https://elevenlabs.io/docs/voices/premade-voices",
|
||||
label: "ElevenLabs",
|
||||
},
|
||||
};
|
||||
|
||||
const OPENAI_STT_MODELS = [{ id: "whisper-1", name: "Whisper v1" }];
|
||||
|
||||
const OPENAI_TTS_MODELS = [
|
||||
{ id: "tts-1", name: "TTS-1" },
|
||||
{ id: "tts-1-hd", name: "TTS-1 HD" },
|
||||
];
|
||||
|
||||
// Map model IDs from cards to actual API model IDs
|
||||
const MODEL_ID_MAP: Record<string, string> = {
|
||||
"tts-1": "tts-1",
|
||||
"tts-1-hd": "tts-1-hd",
|
||||
whisper: "whisper-1",
|
||||
};
|
||||
|
||||
export default function VoiceProviderSetupModal({
|
||||
providerType,
|
||||
existingProvider,
|
||||
mode,
|
||||
defaultModelId,
|
||||
onClose,
|
||||
onSuccess,
|
||||
}: VoiceProviderSetupModalProps) {
|
||||
// Map the card model ID to the actual API model ID
|
||||
// Prioritize defaultModelId (from the clicked card) over stored value
|
||||
const initialTtsModel = defaultModelId
|
||||
? MODEL_ID_MAP[defaultModelId] ?? "tts-1"
|
||||
: existingProvider?.tts_model ?? "tts-1";
|
||||
|
||||
const [apiKey, setApiKey] = useState("");
|
||||
const [apiKeyChanged, setApiKeyChanged] = useState(false);
|
||||
const [targetUri, setTargetUri] = useState(
|
||||
existingProvider?.target_uri ?? ""
|
||||
);
|
||||
const [selectedLlmProviderId, setSelectedLlmProviderId] = useState<
|
||||
number | null
|
||||
>(null);
|
||||
const [sttModel, setSttModel] = useState(
|
||||
existingProvider?.stt_model ?? "whisper-1"
|
||||
);
|
||||
const [ttsModel, setTtsModel] = useState(initialTtsModel);
|
||||
const [defaultVoice, setDefaultVoice] = useState(
|
||||
existingProvider?.default_voice ?? ""
|
||||
);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
|
||||
// Dynamic voices fetched from backend
|
||||
const [voiceOptions, setVoiceOptions] = useState<VoiceOption[]>([]);
|
||||
const [isLoadingVoices, setIsLoadingVoices] = useState(false);
|
||||
|
||||
// Existing OpenAI LLM providers for API key reuse
|
||||
const [existingApiKeyOptions, setExistingApiKeyOptions] = useState<
|
||||
ApiKeyOption[]
|
||||
>([]);
|
||||
const [llmProviderMap, setLlmProviderMap] = useState<Map<string, number>>(
|
||||
new Map()
|
||||
);
|
||||
|
||||
// Fetch existing OpenAI LLM providers (for API key reuse)
|
||||
useEffect(() => {
|
||||
if (providerType !== "openai") return;
|
||||
|
||||
fetchLLMProviders()
|
||||
.then((res) => res.json())
|
||||
.then((data: LLMProviderView[]) => {
|
||||
const openaiProviders = data.filter(
|
||||
(p) => p.provider === "openai" && p.api_key
|
||||
);
|
||||
const options: ApiKeyOption[] = openaiProviders.map((p) => ({
|
||||
value: p.api_key!,
|
||||
label: p.api_key!,
|
||||
description: `Used for LLM provider ${p.name}`,
|
||||
}));
|
||||
setExistingApiKeyOptions(options);
|
||||
|
||||
// Map masked API keys to provider IDs for lookup on selection
|
||||
const providerMap = new Map<string, number>();
|
||||
openaiProviders.forEach((p) => {
|
||||
if (p.api_key) {
|
||||
providerMap.set(p.api_key, p.id);
|
||||
}
|
||||
});
|
||||
setLlmProviderMap(providerMap);
|
||||
})
|
||||
.catch(() => {
|
||||
setExistingApiKeyOptions([]);
|
||||
});
|
||||
}, [providerType]);
|
||||
|
||||
// Fetch voices on mount (works without API key for ElevenLabs/OpenAI)
|
||||
useEffect(() => {
|
||||
setIsLoadingVoices(true);
|
||||
fetchVoicesByType(providerType)
|
||||
.then((res) => res.json())
|
||||
.then((data: Array<{ id: string; name: string }>) => {
|
||||
const options = data.map((v) => ({
|
||||
value: v.id,
|
||||
label: v.name,
|
||||
description: v.id,
|
||||
}));
|
||||
setVoiceOptions(options);
|
||||
// Set default voice to first option if not already set,
|
||||
// or if current value doesn't exist in the new options
|
||||
setDefaultVoice((prev) => {
|
||||
if (!prev) return options[0]?.value ?? "";
|
||||
const existsInOptions = options.some((opt) => opt.value === prev);
|
||||
return existsInOptions ? prev : options[0]?.value ?? "";
|
||||
});
|
||||
})
|
||||
.catch(() => {
|
||||
setVoiceOptions([]);
|
||||
})
|
||||
.finally(() => {
|
||||
setIsLoadingVoices(false);
|
||||
});
|
||||
}, [providerType]);
|
||||
|
||||
const isEditing = !!existingProvider;
|
||||
const label = PROVIDER_LABELS[providerType] ?? providerType;
|
||||
|
||||
// Logo arrangement component for the modal header
|
||||
// No useMemo needed - providerType and label are stable props
|
||||
const LogoArrangement: FunctionComponent<IconProps> = () => (
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="flex items-center justify-center size-7 shrink-0 overflow-clip">
|
||||
{providerType === "openai" ? (
|
||||
<OpenAIIcon size={24} />
|
||||
) : providerType === "azure" ? (
|
||||
<AzureIcon size={24} />
|
||||
) : providerType === "elevenlabs" ? (
|
||||
<ElevenLabsIcon size={24} />
|
||||
) : (
|
||||
<Image
|
||||
src={PROVIDER_LOGO_URLS[providerType] ?? "/Openai.svg"}
|
||||
alt={`${label} logo`}
|
||||
width={24}
|
||||
height={24}
|
||||
className="object-contain"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center justify-center size-4 shrink-0">
|
||||
<SvgArrowExchange className="size-3 text-text-04" />
|
||||
</div>
|
||||
<div className="flex items-center justify-center size-7 p-0.5 shrink-0 overflow-clip">
|
||||
<SvgOnyxLogo size={24} className="text-text-04 shrink-0" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
const handleSubmit = async () => {
|
||||
// API key required for new providers, or when explicitly changed during edit
|
||||
if (!selectedLlmProviderId) {
|
||||
if (!isEditing && !apiKey) {
|
||||
toast.error("API key is required");
|
||||
return;
|
||||
}
|
||||
if (isEditing && apiKeyChanged && !apiKey) {
|
||||
toast.error(
|
||||
"API key cannot be empty. Leave blank to keep existing key."
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (providerType === "azure" && !isEditing && !targetUri) {
|
||||
toast.error("Target URI is required");
|
||||
return;
|
||||
}
|
||||
|
||||
setIsSubmitting(true);
|
||||
try {
|
||||
// Test the connection first (skip if reusing LLM provider key - it's already validated)
|
||||
if (!selectedLlmProviderId) {
|
||||
const testResponse = await testVoiceProvider({
|
||||
provider_type: providerType,
|
||||
api_key: apiKeyChanged ? apiKey : undefined,
|
||||
target_uri: targetUri || undefined,
|
||||
use_stored_key: isEditing && !apiKeyChanged,
|
||||
});
|
||||
|
||||
if (!testResponse.ok) {
|
||||
const data = await testResponse.json();
|
||||
toast.error(data.detail || "Connection test failed");
|
||||
setIsSubmitting(false);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Save the provider
|
||||
const response = await upsertVoiceProvider({
|
||||
id: existingProvider?.id,
|
||||
name: label,
|
||||
provider_type: providerType,
|
||||
api_key: selectedLlmProviderId
|
||||
? undefined
|
||||
: apiKeyChanged
|
||||
? apiKey
|
||||
: undefined,
|
||||
api_key_changed: selectedLlmProviderId ? false : apiKeyChanged,
|
||||
target_uri: targetUri || undefined,
|
||||
llm_provider_id: selectedLlmProviderId,
|
||||
stt_model: sttModel,
|
||||
tts_model: ttsModel,
|
||||
default_voice: defaultVoice,
|
||||
activate_stt: mode === "stt",
|
||||
activate_tts: mode === "tts",
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
toast.success(isEditing ? "Provider updated" : "Provider connected");
|
||||
onSuccess();
|
||||
} else {
|
||||
const data = await response.json();
|
||||
toast.error(data.detail || "Failed to save provider");
|
||||
}
|
||||
} catch {
|
||||
toast.error("Failed to save provider");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal open onOpenChange={(isOpen) => !isOpen && onClose()}>
|
||||
<Modal.Content width="sm">
|
||||
<Modal.Header
|
||||
icon={LogoArrangement}
|
||||
title={isEditing ? `Edit ${label}` : `Set up ${label}`}
|
||||
description={`Connect to ${label} and set up your voice models.`}
|
||||
onClose={onClose}
|
||||
/>
|
||||
<Modal.Body>
|
||||
<Section gap={1} alignItems="stretch">
|
||||
<Vertical
|
||||
title="API Key"
|
||||
subDescription={
|
||||
isEditing ? (
|
||||
"Leave blank to keep existing key"
|
||||
) : (
|
||||
<>
|
||||
Paste your{" "}
|
||||
<a
|
||||
href={PROVIDER_API_KEY_URLS[providerType]}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="underline"
|
||||
>
|
||||
API key
|
||||
</a>{" "}
|
||||
from {label} to access your models.
|
||||
</>
|
||||
)
|
||||
}
|
||||
nonInteractive
|
||||
>
|
||||
{providerType === "openai" && existingApiKeyOptions.length > 0 ? (
|
||||
<InputComboBox
|
||||
placeholder={isEditing ? "••••••••" : "Enter API key"}
|
||||
value={apiKey}
|
||||
onChange={(e) => {
|
||||
setApiKey(e.target.value);
|
||||
setApiKeyChanged(true);
|
||||
setSelectedLlmProviderId(null);
|
||||
}}
|
||||
onValueChange={(value) => {
|
||||
setApiKey(value);
|
||||
// Check if this is an existing key
|
||||
const llmProviderId = llmProviderMap.get(value);
|
||||
if (llmProviderId) {
|
||||
setSelectedLlmProviderId(llmProviderId);
|
||||
setApiKeyChanged(false);
|
||||
} else {
|
||||
setSelectedLlmProviderId(null);
|
||||
setApiKeyChanged(true);
|
||||
}
|
||||
}}
|
||||
options={existingApiKeyOptions}
|
||||
separatorLabel="Reuse OpenAI API Keys"
|
||||
strict={false}
|
||||
showAddPrefix
|
||||
/>
|
||||
) : (
|
||||
<InputTypeIn
|
||||
type="password"
|
||||
placeholder={isEditing ? "••••••••" : "Enter API key"}
|
||||
value={apiKey}
|
||||
onChange={(e) => {
|
||||
setApiKey(e.target.value);
|
||||
setApiKeyChanged(true);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Vertical>
|
||||
|
||||
{providerType === "azure" && (
|
||||
<Vertical
|
||||
title="Target URI"
|
||||
subDescription={
|
||||
<>
|
||||
Paste the endpoint shown in{" "}
|
||||
<a
|
||||
href="https://portal.azure.com/"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="underline"
|
||||
>
|
||||
Azure Portal (Keys and Endpoint)
|
||||
</a>
|
||||
. Onyx extracts the speech region from this URL. Examples:
|
||||
https://westus.api.cognitive.microsoft.com/ or
|
||||
https://westus.tts.speech.microsoft.com/.
|
||||
</>
|
||||
}
|
||||
nonInteractive
|
||||
>
|
||||
<InputTypeIn
|
||||
placeholder={
|
||||
isEditing
|
||||
? "Leave blank to keep existing"
|
||||
: "https://<region>.api.cognitive.microsoft.com/"
|
||||
}
|
||||
value={targetUri}
|
||||
onChange={(e) => setTargetUri(e.target.value)}
|
||||
/>
|
||||
</Vertical>
|
||||
)}
|
||||
|
||||
{providerType === "openai" && mode === "stt" && (
|
||||
<Horizontal title="STT Model" center nonInteractive>
|
||||
<InputSelect value={sttModel} onValueChange={setSttModel}>
|
||||
<InputSelect.Trigger />
|
||||
<InputSelect.Content>
|
||||
{OPENAI_STT_MODELS.map((model) => (
|
||||
<InputSelect.Item key={model.id} value={model.id}>
|
||||
{model.name}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Horizontal>
|
||||
)}
|
||||
|
||||
{providerType === "openai" && mode === "tts" && (
|
||||
<Vertical
|
||||
title="Default Model"
|
||||
subDescription="This model will be used by Onyx by default for text-to-speech."
|
||||
nonInteractive
|
||||
>
|
||||
<InputSelect value={ttsModel} onValueChange={setTtsModel}>
|
||||
<InputSelect.Trigger />
|
||||
<InputSelect.Content>
|
||||
{OPENAI_TTS_MODELS.map((model) => (
|
||||
<InputSelect.Item key={model.id} value={model.id}>
|
||||
{model.name}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Vertical>
|
||||
)}
|
||||
|
||||
{mode === "tts" && (
|
||||
<Vertical
|
||||
title="Voice"
|
||||
subDescription={
|
||||
<>
|
||||
This voice will be used for spoken responses. See full list
|
||||
of supported languages and voices at{" "}
|
||||
<a
|
||||
href={
|
||||
PROVIDER_VOICE_DOCS_URLS[providerType]?.url ??
|
||||
PROVIDER_DOCS_URLS[providerType]
|
||||
}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="underline"
|
||||
>
|
||||
{PROVIDER_VOICE_DOCS_URLS[providerType]?.label ?? label}
|
||||
</a>
|
||||
.
|
||||
</>
|
||||
}
|
||||
nonInteractive
|
||||
>
|
||||
<InputComboBox
|
||||
value={defaultVoice}
|
||||
onValueChange={setDefaultVoice}
|
||||
options={voiceOptions}
|
||||
placeholder={
|
||||
isLoadingVoices
|
||||
? "Loading voices..."
|
||||
: "Select a voice or enter voice ID"
|
||||
}
|
||||
disabled={isLoadingVoices}
|
||||
strict={false}
|
||||
/>
|
||||
</Vertical>
|
||||
)}
|
||||
</Section>
|
||||
</Modal.Body>
|
||||
<Modal.Footer>
|
||||
<Button secondary onClick={onClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button onClick={handleSubmit} disabled={isSubmitting}>
|
||||
{isSubmitting ? "Connecting..." : isEditing ? "Save" : "Connect"}
|
||||
</Button>
|
||||
</Modal.Footer>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
630
web/src/app/admin/configuration/voice/page.tsx
Normal file
630
web/src/app/admin/configuration/voice/page.tsx
Normal file
@@ -0,0 +1,630 @@
|
||||
"use client";
|
||||
|
||||
import Image from "next/image";
|
||||
import { useMemo, useState } from "react";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import {
|
||||
AzureIcon,
|
||||
ElevenLabsIcon,
|
||||
InfoIcon,
|
||||
OpenAIIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { FetchError } from "@/lib/fetcher";
|
||||
import {
|
||||
useVoiceProviders,
|
||||
VoiceProviderView,
|
||||
} from "@/hooks/useVoiceProviders";
|
||||
import {
|
||||
activateVoiceProvider,
|
||||
deactivateVoiceProvider,
|
||||
} from "@/lib/admin/voice/svc";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
SvgArrowExchange,
|
||||
SvgArrowRightCircle,
|
||||
SvgAudio,
|
||||
SvgCheckSquare,
|
||||
SvgEdit,
|
||||
SvgMicrophone,
|
||||
SvgX,
|
||||
} from "@opal/icons";
|
||||
import VoiceProviderSetupModal from "./VoiceProviderSetupModal";
|
||||
|
||||
interface ModelDetails {
|
||||
id: string;
|
||||
label: string;
|
||||
subtitle: string;
|
||||
logoSrc?: string;
|
||||
providerType: string;
|
||||
}
|
||||
|
||||
interface ProviderGroup {
|
||||
providerType: string;
|
||||
providerLabel: string;
|
||||
logoSrc?: string;
|
||||
models: ModelDetails[];
|
||||
}
|
||||
|
||||
// STT Models - individual cards
|
||||
const STT_MODELS: ModelDetails[] = [
|
||||
{
|
||||
id: "whisper",
|
||||
label: "Whisper",
|
||||
subtitle: "OpenAI's general purpose speech recognition model.",
|
||||
logoSrc: "/Openai.svg",
|
||||
providerType: "openai",
|
||||
},
|
||||
{
|
||||
id: "azure-speech-stt",
|
||||
label: "Azure Speech",
|
||||
subtitle: "Speech to text in Microsoft Foundry Tools.",
|
||||
logoSrc: "/Azure.png",
|
||||
providerType: "azure",
|
||||
},
|
||||
{
|
||||
id: "elevenlabs-stt",
|
||||
label: "ElevenAPI",
|
||||
subtitle: "ElevenLabs Speech to Text API.",
|
||||
logoSrc: "/ElevenLabs.svg",
|
||||
providerType: "elevenlabs",
|
||||
},
|
||||
];
|
||||
|
||||
// TTS Models - grouped by provider
|
||||
const TTS_PROVIDER_GROUPS: ProviderGroup[] = [
|
||||
{
|
||||
providerType: "openai",
|
||||
providerLabel: "OpenAI",
|
||||
logoSrc: "/Openai.svg",
|
||||
models: [
|
||||
{
|
||||
id: "tts-1",
|
||||
label: "TTS-1",
|
||||
subtitle: "OpenAI's text-to-speech model optimized for speed.",
|
||||
logoSrc: "/Openai.svg",
|
||||
providerType: "openai",
|
||||
},
|
||||
{
|
||||
id: "tts-1-hd",
|
||||
label: "TTS-1 HD",
|
||||
subtitle: "OpenAI's text-to-speech model optimized for quality.",
|
||||
logoSrc: "/Openai.svg",
|
||||
providerType: "openai",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
providerType: "azure",
|
||||
providerLabel: "Azure",
|
||||
logoSrc: "/Azure.png",
|
||||
models: [
|
||||
{
|
||||
id: "azure-speech-tts",
|
||||
label: "Azure Speech",
|
||||
subtitle: "Text to speech in Microsoft Foundry Tools.",
|
||||
logoSrc: "/Azure.png",
|
||||
providerType: "azure",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
providerType: "elevenlabs",
|
||||
providerLabel: "ElevenLabs",
|
||||
logoSrc: "/ElevenLabs.svg",
|
||||
models: [
|
||||
{
|
||||
id: "elevenlabs-tts",
|
||||
label: "ElevenAPI",
|
||||
subtitle: "ElevenLabs Text to Speech API.",
|
||||
logoSrc: "/ElevenLabs.svg",
|
||||
providerType: "elevenlabs",
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
interface HoverIconButtonProps extends React.ComponentProps<typeof Button> {
|
||||
isHovered: boolean;
|
||||
onMouseEnter: () => void;
|
||||
onMouseLeave: () => void;
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
function HoverIconButton({
|
||||
isHovered,
|
||||
onMouseEnter,
|
||||
onMouseLeave,
|
||||
children,
|
||||
...buttonProps
|
||||
}: HoverIconButtonProps) {
|
||||
return (
|
||||
<div onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
|
||||
<Button {...buttonProps} rightIcon={isHovered ? SvgX : SvgCheckSquare}>
|
||||
{children}
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
type ProviderMode = "stt" | "tts";
|
||||
|
||||
export default function VoiceConfigurationPage() {
|
||||
const [modalOpen, setModalOpen] = useState(false);
|
||||
const [selectedProvider, setSelectedProvider] = useState<string | null>(null);
|
||||
const [editingProvider, setEditingProvider] =
|
||||
useState<VoiceProviderView | null>(null);
|
||||
const [modalMode, setModalMode] = useState<ProviderMode>("stt");
|
||||
const [selectedModelId, setSelectedModelId] = useState<string | null>(null);
|
||||
const [sttActivationError, setSTTActivationError] = useState<string | null>(
|
||||
null
|
||||
);
|
||||
const [ttsActivationError, setTTSActivationError] = useState<string | null>(
|
||||
null
|
||||
);
|
||||
const [hoveredButtonKey, setHoveredButtonKey] = useState<string | null>(null);
|
||||
|
||||
const { providers, error, isLoading, refresh: mutate } = useVoiceProviders();
|
||||
|
||||
const handleConnect = (
|
||||
providerType: string,
|
||||
mode: ProviderMode,
|
||||
modelId?: string
|
||||
) => {
|
||||
setSelectedProvider(providerType);
|
||||
setEditingProvider(null);
|
||||
setModalMode(mode);
|
||||
setSelectedModelId(modelId ?? null);
|
||||
setModalOpen(true);
|
||||
setSTTActivationError(null);
|
||||
setTTSActivationError(null);
|
||||
};
|
||||
|
||||
const handleEdit = (
|
||||
provider: VoiceProviderView,
|
||||
mode: ProviderMode,
|
||||
modelId?: string
|
||||
) => {
|
||||
setSelectedProvider(provider.provider_type);
|
||||
setEditingProvider(provider);
|
||||
setModalMode(mode);
|
||||
setSelectedModelId(modelId ?? null);
|
||||
setModalOpen(true);
|
||||
};
|
||||
|
||||
const handleSetDefault = async (
|
||||
providerId: number,
|
||||
mode: ProviderMode,
|
||||
modelId?: string
|
||||
) => {
|
||||
const setError =
|
||||
mode === "stt" ? setSTTActivationError : setTTSActivationError;
|
||||
setError(null);
|
||||
try {
|
||||
const response = await activateVoiceProvider(providerId, mode, modelId);
|
||||
if (!response.ok) {
|
||||
const errorBody = await response.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: `Failed to set provider as default ${mode.toUpperCase()}.`
|
||||
);
|
||||
}
|
||||
await mutate();
|
||||
} catch (err) {
|
||||
const message =
|
||||
err instanceof Error ? err.message : "Unexpected error occurred.";
|
||||
setError(message);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDeactivate = async (providerId: number, mode: ProviderMode) => {
|
||||
const setError =
|
||||
mode === "stt" ? setSTTActivationError : setTTSActivationError;
|
||||
setError(null);
|
||||
try {
|
||||
const response = await deactivateVoiceProvider(providerId, mode);
|
||||
if (!response.ok) {
|
||||
const errorBody = await response.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: `Failed to deactivate ${mode.toUpperCase()} provider.`
|
||||
);
|
||||
}
|
||||
await mutate();
|
||||
} catch (err) {
|
||||
const message =
|
||||
err instanceof Error ? err.message : "Unexpected error occurred.";
|
||||
setError(message);
|
||||
}
|
||||
};
|
||||
|
||||
const handleModalClose = () => {
|
||||
setModalOpen(false);
|
||||
setSelectedProvider(null);
|
||||
setEditingProvider(null);
|
||||
setSelectedModelId(null);
|
||||
};
|
||||
|
||||
const handleModalSuccess = () => {
|
||||
mutate();
|
||||
handleModalClose();
|
||||
};
|
||||
|
||||
const isProviderConfigured = (provider?: VoiceProviderView): boolean => {
|
||||
return !!provider?.has_api_key;
|
||||
};
|
||||
|
||||
// Map provider types to their configured provider data
|
||||
const providersByType = useMemo(() => {
|
||||
return new Map((providers ?? []).map((p) => [p.provider_type, p] as const));
|
||||
}, [providers]);
|
||||
|
||||
const hasActiveSTTProvider =
|
||||
providers?.some((p) => p.is_default_stt) ?? false;
|
||||
const hasActiveTTSProvider =
|
||||
providers?.some((p) => p.is_default_tts) ?? false;
|
||||
|
||||
const renderLogo = ({
|
||||
logoSrc,
|
||||
providerType,
|
||||
alt,
|
||||
size = 16,
|
||||
}: {
|
||||
logoSrc?: string;
|
||||
providerType: string;
|
||||
alt: string;
|
||||
size?: number;
|
||||
}) => {
|
||||
const containerSizeClass = size === 24 ? "size-7" : "size-5";
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center px-0.5 py-0 shrink-0 overflow-clip",
|
||||
containerSizeClass
|
||||
)}
|
||||
>
|
||||
{providerType === "openai" ? (
|
||||
<OpenAIIcon size={size} />
|
||||
) : providerType === "azure" ? (
|
||||
<AzureIcon size={size} />
|
||||
) : providerType === "elevenlabs" ? (
|
||||
<ElevenLabsIcon size={size} />
|
||||
) : logoSrc ? (
|
||||
<Image
|
||||
src={logoSrc}
|
||||
alt={alt}
|
||||
width={size}
|
||||
height={size}
|
||||
className="object-contain"
|
||||
/>
|
||||
) : (
|
||||
<SvgMicrophone size={size} className="text-text-02" />
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const renderModelCard = ({
|
||||
model,
|
||||
mode,
|
||||
}: {
|
||||
model: ModelDetails;
|
||||
mode: ProviderMode;
|
||||
}) => {
|
||||
const provider = providersByType.get(model.providerType);
|
||||
const isConfigured = isProviderConfigured(provider);
|
||||
// For TTS, also check that this specific model is the default (not just the provider)
|
||||
const isActive =
|
||||
mode === "stt"
|
||||
? provider?.is_default_stt
|
||||
: provider?.is_default_tts && provider?.tts_model === model.id;
|
||||
const isHighlighted = isActive ?? false;
|
||||
const providerId = provider?.id;
|
||||
|
||||
const buttonState = (() => {
|
||||
if (!provider || !isConfigured) {
|
||||
return {
|
||||
label: "Connect",
|
||||
disabled: false,
|
||||
icon: "arrow" as const,
|
||||
onClick: () => handleConnect(model.providerType, mode, model.id),
|
||||
};
|
||||
}
|
||||
|
||||
if (isActive) {
|
||||
return {
|
||||
label: "Current Default",
|
||||
disabled: false,
|
||||
icon: "check" as const,
|
||||
onClick: providerId
|
||||
? () => handleDeactivate(providerId, mode)
|
||||
: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
label: "Set as Default",
|
||||
disabled: false,
|
||||
icon: "arrow-circle" as const,
|
||||
onClick: providerId
|
||||
? () => handleSetDefault(providerId, mode, model.id)
|
||||
: undefined,
|
||||
};
|
||||
})();
|
||||
|
||||
const buttonKey = `${mode}-${model.id}`;
|
||||
const isButtonHovered = hoveredButtonKey === buttonKey;
|
||||
const isCardClickable =
|
||||
buttonState.icon === "arrow" &&
|
||||
typeof buttonState.onClick === "function" &&
|
||||
!buttonState.disabled;
|
||||
|
||||
const handleCardClick = () => {
|
||||
if (isCardClickable) {
|
||||
buttonState.onClick?.();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
key={`${mode}-${model.id}`}
|
||||
onClick={isCardClickable ? handleCardClick : undefined}
|
||||
className={cn(
|
||||
"flex items-start justify-between gap-4 rounded-16 border p-2 bg-background-neutral-01",
|
||||
isHighlighted ? "border-action-link-05" : "border-border-01",
|
||||
isCardClickable &&
|
||||
"cursor-pointer hover:bg-background-tint-01 transition-colors"
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-1 items-start gap-2.5 p-2">
|
||||
{renderLogo({
|
||||
logoSrc: model.logoSrc,
|
||||
providerType: model.providerType,
|
||||
alt: `${model.label} logo`,
|
||||
size: 16,
|
||||
})}
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<Text as="p" mainUiAction text04>
|
||||
{model.label}
|
||||
</Text>
|
||||
<Text as="p" secondaryBody text03>
|
||||
{model.subtitle}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-1.5 self-center">
|
||||
{isConfigured && (
|
||||
<OpalButton
|
||||
icon={SvgEdit}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
if (provider) handleEdit(provider, mode, model.id);
|
||||
}}
|
||||
aria-label={`Edit ${model.label}`}
|
||||
/>
|
||||
)}
|
||||
{buttonState.icon === "check" ? (
|
||||
<HoverIconButton
|
||||
isHovered={isButtonHovered}
|
||||
onMouseEnter={() => setHoveredButtonKey(buttonKey)}
|
||||
onMouseLeave={() => setHoveredButtonKey(null)}
|
||||
action={true}
|
||||
tertiary
|
||||
disabled={buttonState.disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
>
|
||||
{buttonState.label}
|
||||
</HoverIconButton>
|
||||
) : (
|
||||
<Button
|
||||
action={false}
|
||||
tertiary
|
||||
disabled={buttonState.disabled || !buttonState.onClick}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
rightIcon={
|
||||
buttonState.icon === "arrow"
|
||||
? SvgArrowExchange
|
||||
: buttonState.icon === "arrow-circle"
|
||||
? SvgArrowRightCircle
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{buttonState.label}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
if (error) {
|
||||
const message = error?.message || "Unable to load voice configuration.";
|
||||
const detail =
|
||||
error instanceof FetchError && typeof error.info?.detail === "string"
|
||||
? error.info.detail
|
||||
: undefined;
|
||||
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
title="Voice"
|
||||
icon={SvgMicrophone}
|
||||
includeDivider={false}
|
||||
/>
|
||||
<Callout type="danger" title="Failed to load voice settings">
|
||||
{message}
|
||||
{detail && (
|
||||
<Text as="p" className="mt-2 text-text-03" mainContentBody text03>
|
||||
{detail}
|
||||
</Text>
|
||||
)}
|
||||
</Callout>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
title="Voice"
|
||||
icon={SvgMicrophone}
|
||||
includeDivider={false}
|
||||
/>
|
||||
<div className="mt-8">
|
||||
<ThreeDotsLoader />
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle icon={SvgAudio} title="Voice" />
|
||||
<div className="pt-4 pb-4">
|
||||
<Text as="p" secondaryBody text03>
|
||||
Speech to text (STT) and text to speech (TTS) capabilities.
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<div className="flex w-full flex-col gap-8 pb-6">
|
||||
{/* Speech-to-Text Section */}
|
||||
<div className="flex w-full max-w-[960px] flex-col gap-3">
|
||||
<div className="flex flex-col">
|
||||
<Text as="p" mainContentEmphasis text04>
|
||||
Speech to Text
|
||||
</Text>
|
||||
<Text as="p" secondaryBody text03>
|
||||
Select a model to transcribe speech to text in chats.
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
{sttActivationError && (
|
||||
<Callout type="danger" title="Unable to update STT provider">
|
||||
{sttActivationError}
|
||||
</Callout>
|
||||
)}
|
||||
|
||||
{!hasActiveSTTProvider && (
|
||||
<div
|
||||
className="flex items-start rounded-16 border p-2"
|
||||
style={{
|
||||
backgroundColor: "var(--status-info-00)",
|
||||
borderColor: "var(--status-info-02)",
|
||||
}}
|
||||
>
|
||||
<div className="flex items-start gap-1 p-2">
|
||||
<div
|
||||
className="flex size-5 items-center justify-center rounded-full p-0.5"
|
||||
style={{
|
||||
backgroundColor: "var(--status-info-01)",
|
||||
}}
|
||||
>
|
||||
<div style={{ color: "var(--status-text-info-05)" }}>
|
||||
<InfoIcon size={16} />
|
||||
</div>
|
||||
</div>
|
||||
<Text as="p" className="flex-1 px-0.5" mainUiBody text04>
|
||||
Connect a speech to text provider to use in chat.
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex flex-col gap-2">
|
||||
{STT_MODELS.map((model) => renderModelCard({ model, mode: "stt" }))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Text-to-Speech Section */}
|
||||
<div className="flex w-full max-w-[960px] flex-col gap-3">
|
||||
<div className="flex flex-col">
|
||||
<Text as="p" mainContentEmphasis text04>
|
||||
Text to Speech
|
||||
</Text>
|
||||
<Text as="p" secondaryBody text03>
|
||||
Select a model to speak out chat responses.
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
{ttsActivationError && (
|
||||
<Callout type="danger" title="Unable to update TTS provider">
|
||||
{ttsActivationError}
|
||||
</Callout>
|
||||
)}
|
||||
|
||||
{!hasActiveTTSProvider && (
|
||||
<div
|
||||
className="flex items-start rounded-16 border p-2"
|
||||
style={{
|
||||
backgroundColor: "var(--status-info-00)",
|
||||
borderColor: "var(--status-info-02)",
|
||||
}}
|
||||
>
|
||||
<div className="flex items-start gap-1 p-2">
|
||||
<div
|
||||
className="flex size-5 items-center justify-center rounded-full p-0.5"
|
||||
style={{
|
||||
backgroundColor: "var(--status-info-01)",
|
||||
}}
|
||||
>
|
||||
<div style={{ color: "var(--status-text-info-05)" }}>
|
||||
<InfoIcon size={16} />
|
||||
</div>
|
||||
</div>
|
||||
<Text as="p" className="flex-1 px-0.5" mainUiBody text04>
|
||||
Connect a text to speech provider to use in chat.
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex flex-col gap-4">
|
||||
{TTS_PROVIDER_GROUPS.map((group) => (
|
||||
<div key={group.providerType} className="flex flex-col gap-2">
|
||||
<Text as="p" secondaryBody text03 className="px-0.5">
|
||||
{group.providerLabel}
|
||||
</Text>
|
||||
<div className="flex flex-col gap-2">
|
||||
{group.models.map((model) =>
|
||||
renderModelCard({ model, mode: "tts" })
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{modalOpen && selectedProvider && (
|
||||
<VoiceProviderSetupModal
|
||||
providerType={selectedProvider}
|
||||
existingProvider={editingProvider}
|
||||
mode={modalMode}
|
||||
defaultModelId={selectedModelId}
|
||||
onClose={handleModalClose}
|
||||
onSuccess={handleModalSuccess}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -3,6 +3,7 @@ import type { Route } from "next";
|
||||
import { unstable_noStore as noStore } from "next/cache";
|
||||
import { requireAuth } from "@/lib/auth/requireAuth";
|
||||
import { ProjectsProvider } from "@/providers/ProjectsContext";
|
||||
import { VoiceModeProvider } from "@/providers/VoiceModeProvider";
|
||||
import AppSidebar from "@/sections/sidebar/AppSidebar";
|
||||
|
||||
export interface LayoutProps {
|
||||
@@ -21,10 +22,15 @@ export default async function Layout({ children }: LayoutProps) {
|
||||
|
||||
return (
|
||||
<ProjectsProvider>
|
||||
<div className="flex flex-row w-full h-full">
|
||||
<AppSidebar />
|
||||
{children}
|
||||
</div>
|
||||
{/* VoiceModeProvider wraps the full app layout so TTS playback state
|
||||
persists across page navigations (e.g., sidebar clicks during playback).
|
||||
It only activates WebSocket connections when TTS is actually triggered. */}
|
||||
<VoiceModeProvider>
|
||||
<div className="flex flex-row w-full h-full">
|
||||
<AppSidebar />
|
||||
{children}
|
||||
</div>
|
||||
</VoiceModeProvider>
|
||||
</ProjectsProvider>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
"use client";
|
||||
|
||||
import React, { useRef, RefObject, useMemo } from "react";
|
||||
import React, {
|
||||
useRef,
|
||||
RefObject,
|
||||
useMemo,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
} from "react";
|
||||
import { Packet, StopReason } from "@/app/app/services/streamingModels";
|
||||
import CustomToolAuthCard from "@/app/app/message/messageComponents/CustomToolAuthCard";
|
||||
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
|
||||
@@ -16,6 +22,9 @@ import { LlmDescriptor, LlmManager } from "@/lib/hooks";
|
||||
import { Message } from "@/app/app/interfaces";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { AgentTimeline } from "@/app/app/message/messageComponents/timeline/AgentTimeline";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import { getTextContent } from "@/app/app/services/packetUtils";
|
||||
import { removeThinkingTokens } from "@/app/app/services/thinkingTokens";
|
||||
|
||||
// Type for the regeneration factory function passed from ChatUI
|
||||
export type RegenerationFactory = (regenerationRequest: {
|
||||
@@ -75,6 +84,7 @@ function arePropsEqual(
|
||||
|
||||
const AgentMessage = React.memo(function AgentMessage({
|
||||
rawPackets,
|
||||
packetCount,
|
||||
chatState,
|
||||
nodeId,
|
||||
messageId,
|
||||
@@ -162,6 +172,59 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
onMessageSelection,
|
||||
});
|
||||
|
||||
// Streaming TTS integration
|
||||
const { streamTTS, resetTTS, stopTTS } = useVoiceMode();
|
||||
const ttsCompletedRef = useRef(false);
|
||||
const streamTTSRef = useRef(streamTTS);
|
||||
|
||||
// Keep streamTTS ref in sync without triggering effect re-runs
|
||||
useEffect(() => {
|
||||
streamTTSRef.current = streamTTS;
|
||||
}, [streamTTS]);
|
||||
|
||||
// Stream TTS as text content arrives - only for messages still streaming
|
||||
// Uses ref for streamTTS to avoid re-triggering when its identity changes
|
||||
// Note: packetCount is used instead of rawPackets because the array is mutated in place
|
||||
useLayoutEffect(() => {
|
||||
// Skip if we've already finished TTS for this message
|
||||
if (ttsCompletedRef.current) return;
|
||||
|
||||
// If user cancelled generation, do not send more text to TTS.
|
||||
if (stopPacketSeen && stopReason === StopReason.USER_CANCELLED) {
|
||||
ttsCompletedRef.current = true;
|
||||
return;
|
||||
}
|
||||
|
||||
const textContent = removeThinkingTokens(getTextContent(rawPackets));
|
||||
if (typeof textContent === "string" && textContent.length > 0) {
|
||||
streamTTSRef.current(textContent, isComplete, nodeId);
|
||||
|
||||
// Mark as completed once the message is done streaming
|
||||
if (isComplete) {
|
||||
ttsCompletedRef.current = true;
|
||||
}
|
||||
}
|
||||
}, [packetCount, isComplete, rawPackets, nodeId, stopPacketSeen, stopReason]); // packetCount triggers on new packets since rawPackets is mutated in place
|
||||
|
||||
// Stop TTS immediately when user cancels generation.
|
||||
useEffect(() => {
|
||||
if (stopPacketSeen && stopReason === StopReason.USER_CANCELLED) {
|
||||
stopTTS({ manual: true });
|
||||
}
|
||||
}, [stopPacketSeen, stopReason, stopTTS]);
|
||||
|
||||
// Reset TTS completed flag when nodeId changes (new message)
|
||||
useEffect(() => {
|
||||
ttsCompletedRef.current = false;
|
||||
}, [nodeId]);
|
||||
|
||||
// Reset TTS when component unmounts or nodeId changes
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
resetTTS();
|
||||
};
|
||||
}, [nodeId, resetTTS]);
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex flex-col gap-3"
|
||||
@@ -208,6 +271,8 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
key={`${displayGroup.turn_index}-${displayGroup.tab_index}`}
|
||||
packets={displayGroup.packets}
|
||||
chatState={effectiveChatState}
|
||||
messageNodeId={nodeId}
|
||||
hasTimelineThinking={pacedTurnGroups.length > 0 || hasSteps}
|
||||
onComplete={() => {
|
||||
// Only mark complete on the last display group
|
||||
// Hook handles the finalAnswerComing check internally
|
||||
|
||||
@@ -29,6 +29,9 @@ import FeedbackModal, {
|
||||
FeedbackModalProps,
|
||||
} from "@/sections/modals/FeedbackModal";
|
||||
import { Button, SelectButton } from "@opal/components";
|
||||
import TTSButton from "./TTSButton";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import { useVoiceStatus } from "@/hooks/useVoiceStatus";
|
||||
|
||||
// Wrapper component for SourceTag in toolbar to handle memoization
|
||||
const SourcesTagWrapper = React.memo(function SourcesTagWrapper({
|
||||
@@ -144,6 +147,14 @@ export default function MessageToolbar({
|
||||
(state) => state.updateCurrentSelectedNodeForDocDisplay
|
||||
);
|
||||
|
||||
// Voice mode - hide toolbar during TTS playback for this message
|
||||
const { isTTSPlaying, activeMessageNodeId, isAwaitingAutoPlaybackStart } =
|
||||
useVoiceMode();
|
||||
const { ttsEnabled } = useVoiceStatus();
|
||||
const isTTSActiveForThisMessage =
|
||||
(isTTSPlaying || isAwaitingAutoPlaybackStart) &&
|
||||
activeMessageNodeId === nodeId;
|
||||
|
||||
// Feedback modal state and handlers
|
||||
const { handleFeedbackChange } = useFeedbackController();
|
||||
const modal = useCreateModal();
|
||||
@@ -204,6 +215,11 @@ export default function MessageToolbar({
|
||||
[messageId, currentFeedback, handleFeedbackChange, modal]
|
||||
);
|
||||
|
||||
// Hide toolbar while TTS is playing for this message
|
||||
if (isTTSActiveForThisMessage) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<modal.Provider>
|
||||
@@ -268,6 +284,13 @@ export default function MessageToolbar({
|
||||
}
|
||||
data-testid="AgentMessage/dislike-button"
|
||||
/>
|
||||
{ttsEnabled && (
|
||||
<TTSButton
|
||||
text={
|
||||
removeThinkingTokens(getTextContent(rawPackets)) as string
|
||||
}
|
||||
/>
|
||||
)}
|
||||
|
||||
{onRegenerate &&
|
||||
messageId !== undefined &&
|
||||
|
||||
90
web/src/app/app/message/messageComponents/TTSButton.tsx
Normal file
90
web/src/app/app/message/messageComponents/TTSButton.tsx
Normal file
@@ -0,0 +1,90 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect } from "react";
|
||||
import { SvgPlayCircle, SvgStop } from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
import { useVoicePlayback } from "@/hooks/useVoicePlayback";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
|
||||
interface TTSButtonProps {
|
||||
text: string;
|
||||
voice?: string;
|
||||
speed?: number;
|
||||
}
|
||||
|
||||
function TTSButton({ text, voice, speed }: TTSButtonProps) {
|
||||
const { isPlaying, isLoading, error, play, pause, stop } = useVoicePlayback();
|
||||
const { isTTSPlaying, isTTSLoading, isAwaitingAutoPlaybackStart, stopTTS } =
|
||||
useVoiceMode();
|
||||
|
||||
const isGlobalTTSActive =
|
||||
isTTSPlaying || isTTSLoading || isAwaitingAutoPlaybackStart;
|
||||
const isButtonPlaying = isGlobalTTSActive || isPlaying;
|
||||
const isButtonLoading = !isGlobalTTSActive && isLoading;
|
||||
|
||||
const handleClick = useCallback(async () => {
|
||||
if (isGlobalTTSActive) {
|
||||
// Stop auto-playback voice mode stream from the toolbar button.
|
||||
stopTTS({ manual: true });
|
||||
stop();
|
||||
} else if (isPlaying) {
|
||||
pause();
|
||||
} else if (isButtonLoading) {
|
||||
stop();
|
||||
} else {
|
||||
try {
|
||||
// Ensure no voice-mode stream is active before starting manual playback.
|
||||
stopTTS();
|
||||
await play(text, voice, speed);
|
||||
} catch (err) {
|
||||
console.error("TTS playback failed:", err);
|
||||
toast.error("Could not play audio");
|
||||
}
|
||||
}
|
||||
}, [
|
||||
isGlobalTTSActive,
|
||||
isPlaying,
|
||||
isButtonLoading,
|
||||
text,
|
||||
voice,
|
||||
speed,
|
||||
play,
|
||||
pause,
|
||||
stop,
|
||||
stopTTS,
|
||||
]);
|
||||
|
||||
// Surface streaming voice playback errors to the user via toast
|
||||
useEffect(() => {
|
||||
if (error) {
|
||||
console.error("Voice playback error:", error);
|
||||
toast.error(error);
|
||||
}
|
||||
}, [error]);
|
||||
|
||||
const icon = isButtonLoading
|
||||
? SimpleLoader
|
||||
: isButtonPlaying
|
||||
? SvgStop
|
||||
: SvgPlayCircle;
|
||||
|
||||
const tooltip = isButtonPlaying
|
||||
? "Stop playback"
|
||||
: isButtonLoading
|
||||
? "Loading..."
|
||||
: "Read aloud";
|
||||
|
||||
return (
|
||||
<Button
|
||||
icon={icon}
|
||||
onClick={handleClick}
|
||||
prominence="tertiary"
|
||||
tooltip={tooltip}
|
||||
data-testid="AgentMessage/tts-button"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export default TTSButton;
|
||||
@@ -67,6 +67,10 @@ export type MessageRenderer<
|
||||
> = React.ComponentType<{
|
||||
packets: T[];
|
||||
state: S;
|
||||
/** Node id for the message currently being rendered */
|
||||
messageNodeId?: number;
|
||||
/** True when timeline/thinking UI is already shown above this text block */
|
||||
hasTimelineThinking?: boolean;
|
||||
onComplete: () => void;
|
||||
renderType: RenderType;
|
||||
animate: boolean;
|
||||
|
||||
@@ -166,6 +166,8 @@ function MixedContentHandler({
|
||||
chatPackets,
|
||||
imagePackets,
|
||||
chatState,
|
||||
messageNodeId,
|
||||
hasTimelineThinking,
|
||||
onComplete,
|
||||
animate,
|
||||
stopPacketSeen,
|
||||
@@ -175,6 +177,8 @@ function MixedContentHandler({
|
||||
chatPackets: Packet[];
|
||||
imagePackets: Packet[];
|
||||
chatState: FullChatState;
|
||||
messageNodeId?: number;
|
||||
hasTimelineThinking?: boolean;
|
||||
onComplete: () => void;
|
||||
animate: boolean;
|
||||
stopPacketSeen: boolean;
|
||||
@@ -185,6 +189,8 @@ function MixedContentHandler({
|
||||
<MessageTextRenderer
|
||||
packets={chatPackets as ChatPacket[]}
|
||||
state={chatState}
|
||||
messageNodeId={messageNodeId}
|
||||
hasTimelineThinking={hasTimelineThinking}
|
||||
onComplete={() => {}}
|
||||
animate={animate}
|
||||
renderType={RenderType.FULL}
|
||||
@@ -212,6 +218,8 @@ function MixedContentHandler({
|
||||
interface RendererComponentProps {
|
||||
packets: Packet[];
|
||||
chatState: FullChatState;
|
||||
messageNodeId?: number;
|
||||
hasTimelineThinking?: boolean;
|
||||
onComplete: () => void;
|
||||
animate: boolean;
|
||||
stopPacketSeen: boolean;
|
||||
@@ -229,7 +237,8 @@ function areRendererPropsEqual(
|
||||
prev.stopPacketSeen === next.stopPacketSeen &&
|
||||
prev.stopReason === next.stopReason &&
|
||||
prev.animate === next.animate &&
|
||||
prev.chatState.agent?.id === next.chatState.agent?.id
|
||||
prev.chatState.agent?.id === next.chatState.agent?.id &&
|
||||
prev.messageNodeId === next.messageNodeId
|
||||
// Skip: onComplete, children (function refs), chatState (memoized upstream)
|
||||
);
|
||||
}
|
||||
@@ -238,6 +247,8 @@ function areRendererPropsEqual(
|
||||
export const RendererComponent = memo(function RendererComponent({
|
||||
packets,
|
||||
chatState,
|
||||
messageNodeId,
|
||||
hasTimelineThinking,
|
||||
onComplete,
|
||||
animate,
|
||||
stopPacketSeen,
|
||||
@@ -272,6 +283,8 @@ export const RendererComponent = memo(function RendererComponent({
|
||||
chatPackets={chatPackets}
|
||||
imagePackets={imagePackets}
|
||||
chatState={chatState}
|
||||
messageNodeId={messageNodeId}
|
||||
hasTimelineThinking={hasTimelineThinking}
|
||||
onComplete={onComplete}
|
||||
animate={animate}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
@@ -292,6 +305,8 @@ export const RendererComponent = memo(function RendererComponent({
|
||||
<RendererFn
|
||||
packets={packets as any}
|
||||
state={chatState}
|
||||
messageNodeId={messageNodeId}
|
||||
hasTimelineThinking={hasTimelineThinking}
|
||||
onComplete={onComplete}
|
||||
animate={animate}
|
||||
renderType={RenderType.FULL}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useEffect, useMemo, useState } from "react";
|
||||
import React, { useEffect, useMemo, useRef, useState } from "react";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
import {
|
||||
@@ -10,6 +10,55 @@ import { MessageRenderer, FullChatState } from "../interfaces";
|
||||
import { isFinalAnswerComplete } from "../../../services/packetUtils";
|
||||
import { useMarkdownRenderer } from "../markdownUtils";
|
||||
import { BlinkingBar } from "../../BlinkingBar";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
|
||||
/**
|
||||
* Maps a cleaned character position to the corresponding position in markdown text.
|
||||
* This allows progressive reveal to work with markdown formatting.
|
||||
*/
|
||||
function getRevealPosition(markdown: string, cleanChars: number): number {
|
||||
// Skip patterns that don't contribute to visible character count
|
||||
const skipChars = new Set(["*", "`", "#"]);
|
||||
let cleanIndex = 0;
|
||||
let mdIndex = 0;
|
||||
|
||||
while (cleanIndex < cleanChars && mdIndex < markdown.length) {
|
||||
const char = markdown[mdIndex];
|
||||
|
||||
// Skip markdown formatting characters
|
||||
if (char !== undefined && skipChars.has(char)) {
|
||||
mdIndex++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle link syntax [text](url) - skip the (url) part but count the text
|
||||
if (
|
||||
char === "]" &&
|
||||
mdIndex + 1 < markdown.length &&
|
||||
markdown[mdIndex + 1] === "("
|
||||
) {
|
||||
const closeIdx = markdown.indexOf(")", mdIndex + 2);
|
||||
if (closeIdx > 0) {
|
||||
mdIndex = closeIdx + 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
cleanIndex++;
|
||||
mdIndex++;
|
||||
}
|
||||
|
||||
// Extend to word boundary to avoid cutting mid-word
|
||||
while (
|
||||
mdIndex < markdown.length &&
|
||||
markdown[mdIndex] !== " " &&
|
||||
markdown[mdIndex] !== "\n"
|
||||
) {
|
||||
mdIndex++;
|
||||
}
|
||||
|
||||
return mdIndex;
|
||||
}
|
||||
|
||||
// Control the rate of packet streaming (packets per second)
|
||||
const PACKET_DELAY_MS = 10;
|
||||
@@ -20,6 +69,8 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
> = ({
|
||||
packets,
|
||||
state,
|
||||
messageNodeId,
|
||||
hasTimelineThinking,
|
||||
onComplete,
|
||||
renderType,
|
||||
animate,
|
||||
@@ -36,6 +87,17 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
|
||||
const [displayedPacketCount, setDisplayedPacketCount] =
|
||||
useState(initialPacketCount);
|
||||
const lastStableSyncedContentRef = useRef("");
|
||||
const lastVisibleContentRef = useRef("");
|
||||
|
||||
// Get voice mode context for progressive text reveal synced with audio
|
||||
const {
|
||||
revealedCharCount,
|
||||
autoPlayback,
|
||||
isAudioSyncActive,
|
||||
activeMessageNodeId,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
} = useVoiceMode();
|
||||
|
||||
// Get the full content from all packets
|
||||
const fullContent = packets
|
||||
@@ -50,6 +112,11 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
})
|
||||
.join("");
|
||||
|
||||
const shouldUseAutoPlaybackSync =
|
||||
autoPlayback &&
|
||||
typeof messageNodeId === "number" &&
|
||||
activeMessageNodeId === messageNodeId;
|
||||
|
||||
// Animation effect - gradually increase displayed packets at controlled rate
|
||||
useEffect(() => {
|
||||
if (!animate) {
|
||||
@@ -93,13 +160,37 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
}
|
||||
}, [packets, onComplete, animate, displayedPacketCount]);
|
||||
|
||||
// Get content based on displayed packet count
|
||||
const content = useMemo(() => {
|
||||
// Get content based on displayed packet count or audio progress
|
||||
const computedContent = useMemo(() => {
|
||||
// Hold response in "thinking" state only while autoplay startup is pending.
|
||||
if (shouldUseAutoPlaybackSync && isAwaitingAutoPlaybackStart) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Sync text with audio only for the message currently being spoken.
|
||||
if (shouldUseAutoPlaybackSync && isAudioSyncActive) {
|
||||
const MIN_REVEAL_CHARS = 12;
|
||||
if (revealedCharCount < MIN_REVEAL_CHARS) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Reveal text progressively based on audio progress
|
||||
const revealPos = getRevealPosition(fullContent, revealedCharCount);
|
||||
return fullContent.slice(0, Math.max(revealPos, 0));
|
||||
}
|
||||
|
||||
// During an active synced turn, if sync temporarily drops, keep current reveal
|
||||
// instead of jumping to full content or blanking.
|
||||
if (shouldUseAutoPlaybackSync && !stopPacketSeen) {
|
||||
return lastStableSyncedContentRef.current;
|
||||
}
|
||||
|
||||
// Standard behavior when auto-playback is off
|
||||
if (!animate || displayedPacketCount === -1) {
|
||||
return fullContent; // Show all content
|
||||
}
|
||||
|
||||
// Only show content from packets up to displayedPacketCount
|
||||
// Packet-based reveal (when auto-playback is disabled)
|
||||
return packets
|
||||
.slice(0, displayedPacketCount)
|
||||
.map((packet) => {
|
||||
@@ -112,31 +203,109 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
return "";
|
||||
})
|
||||
.join("");
|
||||
}, [animate, displayedPacketCount, fullContent, packets]);
|
||||
}, [
|
||||
animate,
|
||||
displayedPacketCount,
|
||||
fullContent,
|
||||
packets,
|
||||
revealedCharCount,
|
||||
autoPlayback,
|
||||
isAudioSyncActive,
|
||||
activeMessageNodeId,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
messageNodeId,
|
||||
shouldUseAutoPlaybackSync,
|
||||
stopPacketSeen,
|
||||
]);
|
||||
|
||||
// Keep synced text monotonic: once visible, never regress or disappear between chunks.
|
||||
const content = useMemo(() => {
|
||||
const wasUserCancelled = stopReason === StopReason.USER_CANCELLED;
|
||||
|
||||
// On user cancel, freeze at exactly what was already visible.
|
||||
if (wasUserCancelled) {
|
||||
return lastVisibleContentRef.current;
|
||||
}
|
||||
|
||||
if (!shouldUseAutoPlaybackSync) {
|
||||
return computedContent;
|
||||
}
|
||||
|
||||
if (computedContent.length === 0) {
|
||||
return lastStableSyncedContentRef.current;
|
||||
}
|
||||
|
||||
const last = lastStableSyncedContentRef.current;
|
||||
if (computedContent.startsWith(last)) {
|
||||
return computedContent;
|
||||
}
|
||||
|
||||
// If content shape changed unexpectedly mid-stream, prefer the stable version
|
||||
// to avoid flicker/dumps.
|
||||
if (!stopPacketSeen || wasUserCancelled) {
|
||||
return last;
|
||||
}
|
||||
|
||||
// For normal completed responses, allow final full content.
|
||||
return computedContent;
|
||||
}, [computedContent, shouldUseAutoPlaybackSync, stopPacketSeen, stopReason]);
|
||||
|
||||
// Sync the stable ref outside of useMemo to avoid side effects during render.
|
||||
useEffect(() => {
|
||||
if (stopReason === StopReason.USER_CANCELLED) {
|
||||
return;
|
||||
}
|
||||
if (!shouldUseAutoPlaybackSync) {
|
||||
lastStableSyncedContentRef.current = "";
|
||||
} else if (content.length > 0) {
|
||||
lastStableSyncedContentRef.current = content;
|
||||
}
|
||||
}, [content, shouldUseAutoPlaybackSync, stopReason]);
|
||||
|
||||
// Track last actually rendered content so cancel can freeze without dumping buffered text.
|
||||
useEffect(() => {
|
||||
if (content.length > 0) {
|
||||
lastVisibleContentRef.current = content;
|
||||
}
|
||||
}, [content]);
|
||||
|
||||
const shouldShowThinkingPlaceholder =
|
||||
shouldUseAutoPlaybackSync &&
|
||||
isAwaitingAutoPlaybackStart &&
|
||||
!hasTimelineThinking &&
|
||||
!stopPacketSeen;
|
||||
|
||||
const shouldShowSpeechWarmupIndicator =
|
||||
shouldUseAutoPlaybackSync &&
|
||||
!isAwaitingAutoPlaybackStart &&
|
||||
content.length === 0 &&
|
||||
fullContent.length > 0 &&
|
||||
!hasTimelineThinking &&
|
||||
!stopPacketSeen;
|
||||
|
||||
const shouldShowCursor =
|
||||
content.length > 0 &&
|
||||
(!stopPacketSeen ||
|
||||
(shouldUseAutoPlaybackSync && content.length < fullContent.length));
|
||||
|
||||
const { renderedContent } = useMarkdownRenderer(
|
||||
// the [*]() is a hack to show a blinking dot when the packet is not complete
|
||||
stopPacketSeen ? content : content + " [*]() ",
|
||||
shouldShowCursor ? content + " [*]() " : content,
|
||||
state,
|
||||
"font-main-content-body"
|
||||
);
|
||||
|
||||
const wasUserCancelled = stopReason === StopReason.USER_CANCELLED;
|
||||
|
||||
return children([
|
||||
{
|
||||
icon: null,
|
||||
status: null,
|
||||
content:
|
||||
content.length > 0 || packets.length > 0 ? (
|
||||
<>
|
||||
{renderedContent}
|
||||
{wasUserCancelled && (
|
||||
<Text as="p" secondaryBody text04>
|
||||
User has stopped generation
|
||||
</Text>
|
||||
)}
|
||||
</>
|
||||
shouldShowThinkingPlaceholder || shouldShowSpeechWarmupIndicator ? (
|
||||
<Text as="span" secondaryBody text04 className="italic">
|
||||
Thinking
|
||||
</Text>
|
||||
) : content.length > 0 ? (
|
||||
<>{renderedContent}</>
|
||||
) : (
|
||||
<BlinkingBar addMargin />
|
||||
),
|
||||
|
||||
@@ -566,6 +566,21 @@ textarea {
|
||||
animation: fadeIn 0.2s ease-out forwards;
|
||||
}
|
||||
|
||||
/* Recording waveform animation */
|
||||
@keyframes waveform {
|
||||
0%,
|
||||
100% {
|
||||
transform: scaleY(0.3);
|
||||
}
|
||||
50% {
|
||||
transform: scaleY(1);
|
||||
}
|
||||
}
|
||||
|
||||
.animate-waveform {
|
||||
animation: waveform 0.8s ease-in-out infinite;
|
||||
}
|
||||
|
||||
.container {
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { ProjectsProvider } from "@/providers/ProjectsContext";
|
||||
import { VoiceModeProvider } from "@/providers/VoiceModeProvider";
|
||||
|
||||
export interface LayoutProps {
|
||||
children: React.ReactNode;
|
||||
@@ -11,5 +12,9 @@ export interface LayoutProps {
|
||||
* Sidebar and chrome are handled by sub-layouts / individual pages.
|
||||
*/
|
||||
export default function Layout({ children }: LayoutProps) {
|
||||
return <ProjectsProvider>{children}</ProjectsProvider>;
|
||||
return (
|
||||
<ProjectsProvider>
|
||||
<VoiceModeProvider>{children}</VoiceModeProvider>
|
||||
</ProjectsProvider>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -39,6 +39,8 @@ import document360Icon from "@public/Document360.png";
|
||||
import dropboxIcon from "@public/Dropbox.png";
|
||||
import drupalwikiIcon from "@public/DrupalWiki.png";
|
||||
import egnyteIcon from "@public/Egnyte.png";
|
||||
import elevenLabsDarkSVG from "@public/ElevenLabsDark.svg";
|
||||
import elevenLabsSVG from "@public/ElevenLabs.svg";
|
||||
import firefliesIcon from "@public/Fireflies.png";
|
||||
import freshdeskIcon from "@public/Freshdesk.png";
|
||||
import geminiSVG from "@public/Gemini.svg";
|
||||
@@ -843,6 +845,9 @@ export const Document360Icon = createLogoIcon(document360Icon);
|
||||
export const DropboxIcon = createLogoIcon(dropboxIcon);
|
||||
export const DrupalWikiIcon = createLogoIcon(drupalwikiIcon);
|
||||
export const EgnyteIcon = createLogoIcon(egnyteIcon);
|
||||
export const ElevenLabsIcon = createLogoIcon(elevenLabsSVG, {
|
||||
darkSrc: elevenLabsDarkSVG,
|
||||
});
|
||||
export const FirefliesIcon = createLogoIcon(firefliesIcon);
|
||||
export const FreshdeskIcon = createLogoIcon(freshdeskIcon);
|
||||
export const GeminiIcon = createLogoIcon(geminiSVG);
|
||||
|
||||
206
web/src/components/voice/Waveform.tsx
Normal file
206
web/src/components/voice/Waveform.tsx
Normal file
@@ -0,0 +1,206 @@
|
||||
"use client";

import { useEffect, useState, useMemo, useRef } from "react";
import { cn } from "@/lib/utils";
import { formatElapsedTime } from "@/lib/dateUtils";
import { Button } from "@opal/components";
import {
  SvgMicrophone,
  SvgMicrophoneOff,
  SvgVolume,
  SvgVolumeOff,
} from "@opal/icons";

// Recording waveform constants
const RECORDING_BAR_COUNT = 120;
const MIN_BAR_HEIGHT = 2;
const MAX_BAR_HEIGHT = 16;

// Speaking waveform constants
const SPEAKING_BAR_COUNT = 28;

interface WaveformProps {
  /** Visual style and behavior variant */
  variant: "speaking" | "recording";
  /** Whether the waveform is actively animating */
  isActive: boolean;
  /** Whether audio is muted */
  isMuted?: boolean;
  /** Current microphone audio level (0-1), only used for recording variant */
  audioLevel?: number;
  /** Callback when mute button is clicked */
  onMuteToggle?: () => void;
}

/**
 * Animated audio waveform with two variants:
 * - "speaking": a fixed set of CSS-animated bars (assistant TTS output),
 *   with an optional speaker mute button.
 * - "recording": a scrolling bar strip driven by real microphone levels
 *   (via `audioLevel`), plus an elapsed-time counter and mic mute button.
 * Renders nothing while `isActive` is false.
 */
function Waveform({
  variant,
  isActive,
  isMuted = false,
  audioLevel = 0,
  onMuteToggle,
}: WaveformProps) {
  // ─── Recording variant state ───────────────────────────────────────────────
  const [elapsedSeconds, setElapsedSeconds] = useState(0);
  const [barHeights, setBarHeights] = useState<number[]>(
    () => new Array(RECORDING_BAR_COUNT).fill(MIN_BAR_HEIGHT) as number[]
  );
  const animationRef = useRef<number | null>(null);
  const lastPushTimeRef = useRef(0);
  // Mirror the latest audioLevel into a ref so the requestAnimationFrame loop
  // can read it without being listed as an effect dependency (which would
  // restart the animation on every level change).
  const audioLevelRef = useRef(audioLevel);
  audioLevelRef.current = audioLevel;

  // ─── Speaking variant bars ─────────────────────────────────────────────────
  // Computed once per mount; each bar gets a fixed base height and a staggered
  // animation delay so the CSS animation reads as a wave.
  const speakingBars = useMemo(() => {
    return Array.from({ length: SPEAKING_BAR_COUNT }, (_, i) => ({
      id: i,
      // Create a natural wave pattern with height variation
      baseHeight: Math.sin(i * 0.4) * 5 + 8,
      delay: i * 0.025,
    }));
  }, []);

  // ─── Recording: Timer effect ───────────────────────────────────────────────
  // Counts whole seconds while actively recording; resets to 0 when inactive.
  useEffect(() => {
    if (variant !== "recording") return;

    if (!isActive) {
      setElapsedSeconds(0);
      return;
    }

    const interval = setInterval(() => {
      setElapsedSeconds((prev) => prev + 1);
    }, 1000);

    return () => clearInterval(interval);
  }, [variant, isActive]);

  // ─── Recording: Audio level visualization effect ───────────────────────────
  // Runs a rAF loop that shifts a new bar in (and the oldest out) at ~20fps,
  // with the bar's height proportional to the current mic level.
  useEffect(() => {
    if (variant !== "recording") return;

    if (!isActive) {
      // Reset the strip to flat bars when recording stops.
      setBarHeights(
        new Array(RECORDING_BAR_COUNT).fill(MIN_BAR_HEIGHT) as number[]
      );
      lastPushTimeRef.current = 0;
      return;
    }

    const updateBars = (timestamp: number) => {
      // Push a new bar roughly every 50ms (~20fps scrolling)
      if (timestamp - lastPushTimeRef.current >= 50) {
        lastPushTimeRef.current = timestamp;
        const level = isMuted ? 0 : audioLevelRef.current;
        const height =
          MIN_BAR_HEIGHT + level * (MAX_BAR_HEIGHT - MIN_BAR_HEIGHT);

        setBarHeights((prev) => {
          const next = prev.slice(1);
          next.push(height);
          return next;
        });
      }

      animationRef.current = requestAnimationFrame(updateBars);
    };

    animationRef.current = requestAnimationFrame(updateBars);

    return () => {
      if (animationRef.current) {
        cancelAnimationFrame(animationRef.current);
        animationRef.current = null;
      }
    };
  }, [variant, isActive, isMuted]);

  const formattedTime = useMemo(
    () => formatElapsedTime(elapsedSeconds),
    [elapsedSeconds]
  );

  if (!isActive) {
    return null;
  }

  // ─── Speaking variant render ───────────────────────────────────────────────
  if (variant === "speaking") {
    return (
      <div className="flex items-center gap-0.5 p-1.5 bg-background-tint-00 rounded-16 shadow-01">
        {/* Waveform container */}
        <div className="flex items-center p-1 bg-background-tint-00 rounded-12 max-w-[144px] min-h-[32px]">
          <div className="flex items-center p-1">
            {/* Waveform bars */}
            <div className="flex items-center justify-center gap-[2px] h-4 w-[120px] overflow-hidden">
              {speakingBars.map((bar) => (
                <div
                  key={bar.id}
                  className={cn(
                    "w-[3px] rounded-full",
                    isMuted ? "bg-text-03" : "bg-theme-blue-05",
                    !isMuted && "animate-waveform"
                  )}
                  style={{
                    height: isMuted ? "2px" : `${bar.baseHeight}px`,
                    animationDelay: isMuted ? undefined : `${bar.delay}s`,
                  }}
                />
              ))}
            </div>
          </div>
        </div>

        {/* Divider */}
        <div className="w-0.5 self-stretch bg-border-02" />

        {/* Volume button */}
        {onMuteToggle && (
          <div className="flex items-center p-1 bg-background-tint-00 rounded-12">
            <Button
              icon={isMuted ? SvgVolumeOff : SvgVolume}
              onClick={onMuteToggle}
              prominence="tertiary"
              size="sm"
              tooltip={isMuted ? "Unmute" : "Mute"}
            />
          </div>
        )}
      </div>
    );
  }

  // ─── Recording variant render ──────────────────────────────────────────────
  return (
    <div className="flex items-center gap-3 px-3 py-2 bg-background-tint-00 rounded-12 min-h-[32px]">
      {/* Waveform visualization driven by real audio levels */}
      <div className="flex-1 flex items-center justify-between h-4 overflow-hidden">
        {barHeights.map((height, i) => (
          <div
            key={i}
            className="w-[1.5px] bg-text-03 rounded-full shrink-0 transition-[height] duration-75"
            style={{ height: `${height}px` }}
          />
        ))}
      </div>

      {/* Timer */}
      <span className="font-mono text-xs text-text-03 tabular-nums shrink-0">
        {formattedTime}
      </span>

      {/* Mute button */}
      {onMuteToggle && (
        <Button
          icon={isMuted ? SvgMicrophoneOff : SvgMicrophone}
          onClick={onMuteToggle}
          prominence="tertiary"
          size="sm"
          aria-label={isMuted ? "Unmute microphone" : "Mute microphone"}
        />
      )}
    </div>
  );
}

export default Waveform;
|
||||
107
web/src/hooks/useVoicePlayback.ts
Normal file
107
web/src/hooks/useVoicePlayback.ts
Normal file
@@ -0,0 +1,107 @@
|
||||
import { useState, useRef, useCallback, useEffect } from "react";
import { StreamingTTSPlayer } from "@/lib/streamingTTS";
import { useVoiceMode } from "@/providers/VoiceModeProvider";

export interface UseVoicePlaybackReturn {
  // True while audio is audibly playing.
  isPlaying: boolean;
  // True between `play()` being called and audio actually starting.
  isLoading: boolean;
  // Last playback/synthesis error message, or null.
  error: string | null;
  play: (text: string, voice?: string, speed?: number) => Promise<void>;
  pause: () => void;
  stop: () => void;
}

/**
 * Manual text-to-speech playback hook backed by StreamingTTSPlayer.
 * Wires the player into the global VoiceModeProvider so the rest of the UI
 * knows when manual TTS is playing and so global mute toggles reach the
 * active player.
 */
export function useVoicePlayback(): UseVoicePlaybackReturn {
  const [isPlaying, setIsPlaying] = useState(false);
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState<string | null>(null);

  const playerRef = useRef<StreamingTTSPlayer | null>(null);
  // Set while we are intentionally tearing a player down, so errors emitted
  // by the dying player don't surface to the UI.
  const suppressPlayerErrorsRef = useRef(false);
  const { setManualTTSPlaying, isTTSMuted, registerManualTTSMuteHandler } =
    useVoiceMode();

  // Let the global voice-mode mute toggle control whichever player is live;
  // unregister on unmount.
  useEffect(() => {
    registerManualTTSMuteHandler((muted) => {
      playerRef.current?.setMuted(muted);
    });
    return () => {
      registerManualTTSMuteHandler(null);
    };
  }, [registerManualTTSMuteHandler]);

  const stop = useCallback(() => {
    // Errors fired by the stopping player are expected; suppress them.
    suppressPlayerErrorsRef.current = true;
    if (playerRef.current) {
      playerRef.current.stop();
      playerRef.current = null;
    }
    setManualTTSPlaying(false);
    setError(null);
    setIsPlaying(false);
    setIsLoading(false);
  }, [setManualTTSPlaying]);

  const pause = useCallback(() => {
    // Streaming player currently supports stop/resume via restart, not true pause.
    stop();
  }, [stop]);

  const play = useCallback(
    async (text: string, voice?: string, speed?: number) => {
      // Stop any existing playback
      stop();
      // Re-enable error reporting for the new player (stop() suppressed it).
      suppressPlayerErrorsRef.current = false;
      setError(null);
      setIsLoading(true);

      try {
        const player = new StreamingTTSPlayer({
          onPlayingChange: (playing) => {
            setIsPlaying(playing);
            setManualTTSPlaying(playing);
            if (playing) {
              // First audible audio means loading is done.
              setIsLoading(false);
            }
          },
          onError: (playbackError) => {
            if (suppressPlayerErrorsRef.current) {
              return;
            }
            console.error("Voice playback error:", playbackError);
            setManualTTSPlaying(false);
            setError(playbackError);
            setIsLoading(false);
            setIsPlaying(false);
          },
        });
        playerRef.current = player;
        // Honor the global mute state from the moment playback starts.
        player.setMuted(isTTSMuted);

        await player.speak(text, voice, speed);
        setIsLoading(false);
      } catch (err) {
        if (err instanceof Error && err.name === "AbortError") {
          // Request was cancelled, not an error
          return;
        }
        const message =
          err instanceof Error ? err.message : "Speech synthesis failed";
        setError(message);
        setIsLoading(false);
        setIsPlaying(false);
        setManualTTSPlaying(false);
      }
    },
    [isTTSMuted, setManualTTSPlaying, stop]
  );

  return {
    isPlaying,
    isLoading,
    error,
    play,
    pause,
    stop,
  };
}
|
||||
35
web/src/hooks/useVoiceProviders.ts
Normal file
35
web/src/hooks/useVoiceProviders.ts
Normal file
@@ -0,0 +1,35 @@
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
|
||||
export interface VoiceProviderView {
|
||||
id: number;
|
||||
name: string;
|
||||
provider_type: string;
|
||||
is_default_stt: boolean;
|
||||
is_default_tts: boolean;
|
||||
stt_model: string | null;
|
||||
tts_model: string | null;
|
||||
default_voice: string | null;
|
||||
has_api_key: boolean;
|
||||
target_uri: string | null;
|
||||
}
|
||||
|
||||
const VOICE_PROVIDERS_URL = "/api/admin/voice/providers";
|
||||
|
||||
export function useVoiceProviders() {
|
||||
const { data, error, isLoading, mutate } = useSWR<VoiceProviderView[]>(
|
||||
VOICE_PROVIDERS_URL,
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
);
|
||||
|
||||
return {
|
||||
providers: data ?? [],
|
||||
isLoading,
|
||||
error,
|
||||
refresh: mutate,
|
||||
};
|
||||
}
|
||||
500
web/src/hooks/useVoiceRecorder.ts
Normal file
500
web/src/hooks/useVoiceRecorder.ts
Normal file
@@ -0,0 +1,500 @@
|
||||
import { useState, useRef, useCallback, useEffect } from "react";
|
||||
|
||||
// Target format for OpenAI Realtime API
const TARGET_SAMPLE_RATE = 24000;
// How often buffered microphone audio is flushed over the WebSocket.
const CHUNK_INTERVAL_MS = 250;

// Shape of JSON messages received from the transcription WebSocket.
interface TranscriptMessage {
  type: "transcript" | "error";
  // Transcript text (for type "transcript").
  text?: string;
  // Error description (for type "error").
  message?: string;
  // True when server-side VAD has finalized the current utterance.
  is_final?: boolean;
}

export interface UseVoiceRecorderOptions {
  /** Called when VAD detects silence and final transcript is received */
  onFinalTranscript?: (text: string) => void;
  /** If true, automatically stop recording when VAD detects silence */
  autoStopOnSilence?: boolean;
}

export interface UseVoiceRecorderReturn {
  // True while the microphone session is live.
  isRecording: boolean;
  // True while stopRecording() is waiting for the final transcript.
  isProcessing: boolean;
  // True when the mic track is disabled (muted).
  isMuted: boolean;
  error: string | null;
  // Latest (possibly partial) transcript streamed from the server.
  liveTranscript: string;
  /** Current microphone audio level (0-1, RMS-based) */
  audioLevel: number;
  startRecording: () => Promise<void>;
  // Resolves with the final transcript, or null if none was produced.
  stopRecording: () => Promise<string | null>;
  setMuted: (muted: boolean) => void;
}
|
||||
|
||||
/**
 * Encapsulates all browser resources for a voice recording session.
 * Manages WebSocket, Web Audio API, and audio buffering.
 *
 * Lifecycle: construct → start() → (audio streams in CHUNK_INTERVAL_MS
 * batches) → stop() resolves with the final transcript → cleanup().
 */
class VoiceRecorderSession {
  // Browser resources
  private websocket: WebSocket | null = null;
  private audioContext: AudioContext | null = null;
  // NOTE(review): ScriptProcessorNode is deprecated in favor of
  // AudioWorkletNode — works today, but worth migrating eventually.
  private scriptNode: ScriptProcessorNode | null = null;
  private sourceNode: MediaStreamAudioSourceNode | null = null;
  private mediaStream: MediaStream | null = null;
  private sendInterval: NodeJS.Timeout | null = null;

  // State
  // Raw Float32 chunks captured since the last flush to the WebSocket.
  private audioBuffer: Float32Array[] = [];
  private transcript = "";
  // Pending resolver for the promise returned by stop(); resolved when the
  // server sends a final transcript, the socket closes, or a timeout fires.
  private stopResolver: ((text: string | null) => void) | null = null;
  private isActive = false;

  // Callbacks to update React state
  private onTranscriptChange: (text: string) => void;
  private onFinalTranscript: ((text: string) => void) | null;
  private onError: (error: string) => void;
  private onAudioLevel: (level: number) => void;
  private onSilenceTimeout: (() => void) | null;
  private onVADStop: (() => void) | null;
  private autoStopOnSilence: boolean;

  constructor(
    onTranscriptChange: (text: string) => void,
    onFinalTranscript: ((text: string) => void) | null,
    onError: (error: string) => void,
    onAudioLevel: (level: number) => void,
    onSilenceTimeout?: () => void,
    autoStopOnSilence?: boolean,
    onVADStop?: () => void
  ) {
    this.onTranscriptChange = onTranscriptChange;
    this.onFinalTranscript = onFinalTranscript;
    this.onError = onError;
    this.onAudioLevel = onAudioLevel;
    this.onSilenceTimeout = onSilenceTimeout || null;
    this.autoStopOnSilence = autoStopOnSilence ?? false;
    this.onVADStop = onVADStop || null;
  }

  /** Whether the session is actively capturing audio. */
  get recording(): boolean {
    return this.isActive;
  }

  /** Latest transcript text received from the server. */
  get currentTranscript(): string {
    return this.transcript;
  }

  /** Enable/disable the microphone track (muting keeps the session alive). */
  setMuted(muted: boolean): void {
    if (this.mediaStream) {
      this.mediaStream.getAudioTracks().forEach((track) => {
        track.enabled = !muted;
      });
    }
  }

  /**
   * Acquire the microphone, open the transcription WebSocket, and start
   * streaming PCM16 audio chunks. Throws on mic/WS failures.
   */
  async start(): Promise<void> {
    if (this.isActive) return;

    this.cleanup();
    this.transcript = "";
    this.audioBuffer = [];

    // Get microphone
    this.mediaStream = await navigator.mediaDevices.getUserMedia({
      audio: {
        channelCount: 1,
        // "ideal" lets the browser fall back to its native rate; we resample
        // to TARGET_SAMPLE_RATE before sending in that case.
        sampleRate: { ideal: TARGET_SAMPLE_RATE },
        echoCancellation: true,
        noiseSuppression: true,
      },
    });

    // Get WS token and connect WebSocket
    const wsUrl = await this.getWebSocketUrl();
    this.websocket = new WebSocket(wsUrl);
    this.websocket.onmessage = this.handleMessage;
    this.websocket.onerror = () => this.onError("Connection failed");
    this.websocket.onclose = () => {
      // If stop() is waiting on a final transcript, a server-side close
      // resolves it with whatever we have.
      if (this.stopResolver) {
        this.stopResolver(this.transcript || null);
        this.stopResolver = null;
      }
    };

    await this.waitForConnection();

    // Restore error handler after connection (waitForConnection overwrites it)
    this.websocket.onerror = () => this.onError("Connection failed");

    // Set up audio capture
    this.audioContext = new AudioContext({ sampleRate: TARGET_SAMPLE_RATE });
    this.sourceNode = this.audioContext.createMediaStreamSource(
      this.mediaStream
    );
    this.scriptNode = this.audioContext.createScriptProcessor(4096, 1, 1);

    this.scriptNode.onaudioprocess = (event) => {
      const inputData = event.inputBuffer.getChannelData(0);
      // Copy — the underlying buffer is reused by the audio pipeline.
      this.audioBuffer.push(new Float32Array(inputData));

      // Compute RMS audio level (0-1) for waveform visualization
      let sum = 0;
      for (let i = 0; i < inputData.length; i++) {
        sum += inputData[i]! * inputData[i]!;
      }
      const rms = Math.sqrt(sum / inputData.length);
      // Scale RMS to a more visible range (raw RMS is usually very small)
      this.onAudioLevel(Math.min(1, rms * 5));
    };

    this.sourceNode.connect(this.scriptNode);
    // ScriptProcessorNode must be connected to a destination to fire events.
    this.scriptNode.connect(this.audioContext.destination);

    // Start sending audio chunks
    this.sendInterval = setInterval(
      () => this.sendAudioBuffer(),
      CHUNK_INTERVAL_MS
    );
    this.isActive = true;
  }

  /**
   * Stop capturing and ask the server for the final transcript.
   * Resolves with the transcript (or null), falling back after 3s if the
   * server never answers. The WebSocket itself is left for the server to
   * close (or for cleanup()).
   */
  async stop(): Promise<string | null> {
    if (!this.isActive) return this.transcript || null;

    // Stop audio capture
    if (this.sendInterval) {
      clearInterval(this.sendInterval);
      this.sendInterval = null;
    }
    if (this.scriptNode) {
      this.scriptNode.disconnect();
      this.scriptNode = null;
    }
    if (this.sourceNode) {
      this.sourceNode.disconnect();
      this.sourceNode = null;
    }
    if (this.audioContext) {
      this.audioContext.close();
      this.audioContext = null;
    }
    if (this.mediaStream) {
      this.mediaStream.getTracks().forEach((track) => track.stop());
      this.mediaStream = null;
    }

    this.audioBuffer = [];
    this.isActive = false;

    // Get final transcript from server
    if (this.websocket?.readyState === WebSocket.OPEN) {
      return new Promise((resolve) => {
        this.stopResolver = resolve;
        this.websocket!.send(JSON.stringify({ type: "end" }));

        // Timeout fallback
        setTimeout(() => {
          if (this.stopResolver) {
            this.stopResolver(this.transcript || null);
            this.stopResolver = null;
          }
        }, 3000);
      });
    }

    return this.transcript || null;
  }

  /**
   * Synchronously release every browser resource. Safe to call repeatedly
   * and in any state (e.g. on React unmount or failed start()).
   */
  cleanup(): void {
    if (this.sendInterval) clearInterval(this.sendInterval);
    if (this.scriptNode) this.scriptNode.disconnect();
    if (this.sourceNode) this.sourceNode.disconnect();
    if (this.audioContext) this.audioContext.close();
    if (this.mediaStream) this.mediaStream.getTracks().forEach((t) => t.stop());
    if (this.websocket) this.websocket.close();

    this.sendInterval = null;
    this.scriptNode = null;
    this.sourceNode = null;
    this.audioContext = null;
    this.mediaStream = null;
    this.websocket = null;
    this.isActive = false;
  }

  /** Build the transcription WebSocket URL, including a short-lived auth token. */
  private async getWebSocketUrl(): Promise<string> {
    // Fetch short-lived WS token
    const tokenResponse = await fetch("/api/voice/ws-token", {
      method: "POST",
      credentials: "include",
    });
    if (!tokenResponse.ok) {
      throw new Error("Failed to get WebSocket authentication token");
    }
    const { token } = await tokenResponse.json();

    const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
    // Port 3000 is assumed to be the dev frontend; the API server is then
    // reached directly on 8080 without the /api prefix.
    const isDev = window.location.port === "3000";
    const host = isDev ? "localhost:8080" : window.location.host;
    const path = isDev
      ? "/voice/transcribe/stream"
      : "/api/voice/transcribe/stream";
    return `${protocol}//${host}${path}?token=${encodeURIComponent(token)}`;
  }

  /** Resolve once the WebSocket opens, reject on error or after 5s. */
  private waitForConnection(): Promise<void> {
    return new Promise((resolve, reject) => {
      if (!this.websocket) return reject(new Error("No WebSocket"));

      const timeout = setTimeout(
        () => reject(new Error("Connection timeout")),
        5000
      );

      this.websocket.onopen = () => {
        clearTimeout(timeout);
        resolve();
      };
      // Temporarily replaces the session error handler; start() restores it.
      this.websocket.onerror = () => {
        clearTimeout(timeout);
        reject(new Error("Connection failed"));
      };
    });
  }

  /** Handle transcript/error JSON frames from the server. */
  private handleMessage = (event: MessageEvent): void => {
    try {
      const data: TranscriptMessage = JSON.parse(event.data);

      if (data.type === "transcript") {
        if (data.text) {
          this.transcript = data.text;
          this.onTranscriptChange(data.text);
        }

        if (data.is_final && data.text) {
          // VAD detected silence - trigger callback
          if (this.onFinalTranscript) {
            this.onFinalTranscript(data.text);
          }

          // Auto-stop recording if enabled
          if (this.autoStopOnSilence) {
            // Trigger stop callback to update React state
            if (this.onVADStop) {
              this.onVADStop();
            }
          } else {
            // If not auto-stopping, reset for next utterance
            this.transcript = "";
            this.onTranscriptChange("");
            this.resetBackendTranscript();
          }

          // Resolve stop promise if waiting
          if (this.stopResolver) {
            this.stopResolver(data.text);
            this.stopResolver = null;
          }
        }
      } else if (data.type === "error") {
        this.onError(data.message || "Transcription error");
      }
    } catch (e) {
      console.error("Failed to parse transcript message:", e);
    }
  };

  /** Tell the server to discard its accumulated transcript state. */
  private resetBackendTranscript(): void {
    if (this.websocket?.readyState === WebSocket.OPEN) {
      this.websocket.send(JSON.stringify({ type: "reset" }));
    }
  }

  /** Flush buffered mic audio to the server as resampled PCM16. */
  private sendAudioBuffer(): void {
    if (
      !this.websocket ||
      this.websocket.readyState !== WebSocket.OPEN ||
      !this.audioContext ||
      this.audioBuffer.length === 0
    ) {
      return;
    }

    // Concatenate buffered chunks
    const totalLength = this.audioBuffer.reduce(
      (sum, chunk) => sum + chunk.length,
      0
    );

    // Prevent buffer overflow
    // NOTE(review): threshold is sampleRate * 0.5 * 2 samples (~1s of audio);
    // confirm the intended cap — the `* 2` factor looks like a byte/sample mixup.
    if (totalLength > this.audioContext.sampleRate * 0.5 * 2) {
      this.audioBuffer = this.audioBuffer.slice(-10);
      return;
    }

    const concatenated = new Float32Array(totalLength);
    let offset = 0;
    for (const chunk of this.audioBuffer) {
      concatenated.set(chunk, offset);
      offset += chunk.length;
    }
    this.audioBuffer = [];

    // Resample and convert to PCM16
    const resampled = this.resampleAudio(
      concatenated,
      this.audioContext.sampleRate
    );
    const pcm16 = this.float32ToInt16(resampled);

    this.websocket.send(pcm16.buffer);
  }

  /** Linear-interpolation resample from inputRate to TARGET_SAMPLE_RATE. */
  private resampleAudio(input: Float32Array, inputRate: number): Float32Array {
    if (inputRate === TARGET_SAMPLE_RATE) return input;

    const ratio = inputRate / TARGET_SAMPLE_RATE;
    const outputLength = Math.round(input.length / ratio);
    const output = new Float32Array(outputLength);

    for (let i = 0; i < outputLength; i++) {
      const srcIndex = i * ratio;
      const floor = Math.floor(srcIndex);
      const ceil = Math.min(floor + 1, input.length - 1);
      const fraction = srcIndex - floor;
      output[i] = input[floor]! * (1 - fraction) + input[ceil]! * fraction;
    }

    return output;
  }

  /** Convert [-1, 1] float samples to signed 16-bit PCM. */
  private float32ToInt16(float32: Float32Array): Int16Array {
    const int16 = new Int16Array(float32.length);
    for (let i = 0; i < float32.length; i++) {
      const s = Math.max(-1, Math.min(1, float32[i]!));
      int16[i] = s < 0 ? s * 0x8000 : s * 0x7fff;
    }
    return int16;
  }
}
|
||||
|
||||
/**
 * Hook for voice recording with streaming transcription.
 *
 * Wraps a VoiceRecorderSession, exposing React state for recording status,
 * live transcript, mic level, and mute. Session-identity checks
 * (`sessionRef.current === currentSession`) guard against a stale session's
 * async completion clobbering state after a newer session has started.
 */
export function useVoiceRecorder(
  options?: UseVoiceRecorderOptions
): UseVoiceRecorderReturn {
  const [isRecording, setIsRecording] = useState(false);
  const [isProcessing, setIsProcessing] = useState(false);
  const [isMuted, setIsMutedState] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const [liveTranscript, setLiveTranscript] = useState("");
  const [audioLevel, setAudioLevel] = useState(0);

  const sessionRef = useRef<VoiceRecorderSession | null>(null);
  const onFinalTranscriptRef = useRef(options?.onFinalTranscript);
  const autoStopOnSilenceRef = useRef(options?.autoStopOnSilence ?? true); // Default to true

  // Keep callback ref in sync
  useEffect(() => {
    onFinalTranscriptRef.current = options?.onFinalTranscript;
    autoStopOnSilenceRef.current = options?.autoStopOnSilence ?? true;
  }, [options?.onFinalTranscript, options?.autoStopOnSilence]);

  // Cleanup on unmount
  useEffect(() => {
    return () => {
      sessionRef.current?.cleanup();
    };
  }, []);

  /** Begin a new recording session; rethrows start failures after cleanup. */
  const startRecording = useCallback(async () => {
    if (sessionRef.current?.recording) return;

    setError(null);
    setLiveTranscript("");

    // Clear any stale, inactive session before starting a new one.
    if (sessionRef.current && !sessionRef.current.recording) {
      sessionRef.current.cleanup();
      sessionRef.current = null;
    }

    // Create VAD stop handler that will stop the session
    const currentSession = new VoiceRecorderSession(
      setLiveTranscript,
      // Read through the ref so the latest caller-provided callback is used.
      (text) => onFinalTranscriptRef.current?.(text),
      setError,
      setAudioLevel,
      undefined, // onSilenceTimeout
      autoStopOnSilenceRef.current,
      () => {
        // Stop only this session instance, and only clear recording state if it
        // is still the active session when stop resolves.
        currentSession.stop().then(() => {
          if (sessionRef.current === currentSession) {
            setIsRecording(false);
            setIsMutedState(false);
            sessionRef.current = null;
          }
        });
      }
    );
    sessionRef.current = currentSession;

    try {
      await currentSession.start();
      if (sessionRef.current === currentSession) {
        setIsRecording(true);
      }
    } catch (err) {
      currentSession.cleanup();
      setError(
        err instanceof Error ? err.message : "Failed to start recording"
      );
      if (sessionRef.current === currentSession) {
        sessionRef.current = null;
      }
      throw err;
    }
  }, []);

  /** Stop the active session and resolve with the final transcript (or null). */
  const stopRecording = useCallback(async (): Promise<string | null> => {
    if (!sessionRef.current) return null;
    const currentSession = sessionRef.current;

    setIsProcessing(true);

    try {
      const transcript = await currentSession.stop();
      return transcript;
    } finally {
      // Only clear state if this is still the active session.
      if (sessionRef.current === currentSession) {
        setIsRecording(false);
        setIsMutedState(false); // Reset mute state when recording stops
        sessionRef.current = null;
      }
      setIsProcessing(false);
    }
  }, []);

  /** Toggle the mic track without ending the session. */
  const setMuted = useCallback((muted: boolean) => {
    setIsMutedState(muted);
    sessionRef.current?.setMuted(muted);
  }, []);

  return {
    isRecording,
    isProcessing,
    isMuted,
    error,
    liveTranscript,
    audioLevel,
    startRecording,
    stopRecording,
    setMuted,
  };
}
|
||||
25
web/src/hooks/useVoiceStatus.ts
Normal file
25
web/src/hooks/useVoiceStatus.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
|
||||
interface VoiceStatus {
|
||||
stt_enabled: boolean;
|
||||
tts_enabled: boolean;
|
||||
}
|
||||
|
||||
export function useVoiceStatus() {
|
||||
const { data, error, isLoading } = useSWR<VoiceStatus>(
|
||||
"/api/voice/status",
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
);
|
||||
|
||||
return {
|
||||
sttEnabled: data?.stt_enabled ?? false,
|
||||
ttsEnabled: data?.tts_enabled ?? false,
|
||||
isLoading,
|
||||
error,
|
||||
};
|
||||
}
|
||||
150
web/src/hooks/useWebSocket.ts
Normal file
150
web/src/hooks/useWebSocket.ts
Normal file
@@ -0,0 +1,150 @@
|
||||
import { useState, useRef, useCallback, useEffect } from "react";
|
||||
|
||||
export type WebSocketStatus =
|
||||
| "connecting"
|
||||
| "connected"
|
||||
| "disconnected"
|
||||
| "error";
|
||||
|
||||
export interface UseWebSocketOptions<T> {
|
||||
/** URL to connect to */
|
||||
url: string;
|
||||
/** Called when a message is received */
|
||||
onMessage?: (data: T) => void;
|
||||
/** Called when connection opens */
|
||||
onOpen?: () => void;
|
||||
/** Called when connection closes */
|
||||
onClose?: () => void;
|
||||
/** Called on error */
|
||||
onError?: (error: Event) => void;
|
||||
/** Auto-connect on mount */
|
||||
autoConnect?: boolean;
|
||||
}
|
||||
|
||||
export interface UseWebSocketReturn<T> {
|
||||
/** Current connection status */
|
||||
status: WebSocketStatus;
|
||||
/** Send JSON data */
|
||||
sendJson: (data: T) => void;
|
||||
/** Send binary data */
|
||||
sendBinary: (data: Blob | ArrayBuffer) => void;
|
||||
/** Connect to WebSocket */
|
||||
connect: () => Promise<void>;
|
||||
/** Disconnect from WebSocket */
|
||||
disconnect: () => void;
|
||||
}
|
||||
|
||||
export function useWebSocket<TReceive = unknown, TSend = unknown>({
|
||||
url,
|
||||
onMessage,
|
||||
onOpen,
|
||||
onClose,
|
||||
onError,
|
||||
autoConnect = false,
|
||||
}: UseWebSocketOptions<TReceive>): UseWebSocketReturn<TSend> {
|
||||
const [status, setStatus] = useState<WebSocketStatus>("disconnected");
|
||||
const wsRef = useRef<WebSocket | null>(null);
|
||||
const onMessageRef = useRef(onMessage);
|
||||
const onOpenRef = useRef(onOpen);
|
||||
const onCloseRef = useRef(onClose);
|
||||
const onErrorRef = useRef(onError);
|
||||
|
||||
// Keep refs updated
|
||||
useEffect(() => {
|
||||
onMessageRef.current = onMessage;
|
||||
onOpenRef.current = onOpen;
|
||||
onCloseRef.current = onClose;
|
||||
onErrorRef.current = onError;
|
||||
}, [onMessage, onOpen, onClose, onError]);
|
||||
|
||||
const connect = useCallback(async (): Promise<void> => {
|
||||
if (
|
||||
wsRef.current?.readyState === WebSocket.OPEN ||
|
||||
wsRef.current?.readyState === WebSocket.CONNECTING
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
setStatus("connecting");
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
const ws = new WebSocket(url);
|
||||
wsRef.current = ws;
|
||||
|
||||
const timeout = setTimeout(() => {
|
||||
ws.close();
|
||||
reject(new Error("WebSocket connection timeout"));
|
||||
}, 10000);
|
||||
|
||||
ws.onopen = () => {
|
||||
clearTimeout(timeout);
|
||||
setStatus("connected");
|
||||
onOpenRef.current?.();
|
||||
resolve();
|
||||
};
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
try {
|
||||
const data = JSON.parse(event.data) as TReceive;
|
||||
onMessageRef.current?.(data);
|
||||
} catch {
|
||||
// Non-JSON message, ignore or handle differently
|
||||
}
|
||||
};
|
||||
|
||||
ws.onclose = () => {
|
||||
clearTimeout(timeout);
|
||||
setStatus("disconnected");
|
||||
onCloseRef.current?.();
|
||||
wsRef.current = null;
|
||||
};
|
||||
|
||||
ws.onerror = (error) => {
|
||||
clearTimeout(timeout);
|
||||
setStatus("error");
|
||||
onErrorRef.current?.(error);
|
||||
reject(new Error("WebSocket connection failed"));
|
||||
};
|
||||
});
|
||||
}, [url]);
|
||||
|
||||
const disconnect = useCallback(() => {
|
||||
if (wsRef.current) {
|
||||
wsRef.current.close();
|
||||
wsRef.current = null;
|
||||
}
|
||||
setStatus("disconnected");
|
||||
}, []);
|
||||
|
||||
const sendJson = useCallback((data: TSend) => {
|
||||
if (wsRef.current?.readyState === WebSocket.OPEN) {
|
||||
wsRef.current.send(JSON.stringify(data));
|
||||
}
|
||||
}, []);
|
||||
|
||||
const sendBinary = useCallback((data: Blob | ArrayBuffer) => {
|
||||
if (wsRef.current?.readyState === WebSocket.OPEN) {
|
||||
wsRef.current.send(data);
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Auto-connect if enabled. The cleanup always disconnects, so the socket
// is torn down on unmount and whenever `url` changes (a new `connect`
// identity re-runs this effect).
useEffect(() => {
  if (autoConnect) {
    connect().catch(() => {
      // Error handled via onError callback
    });
  }
  return () => {
    disconnect();
  };
}, [autoConnect, connect, disconnect]);
|
||||
|
||||
// Public API of the hook: current connection status plus imperative
// controls for connecting, disconnecting, and sending messages.
return {
  status,
  sendJson,
  sendBinary,
  connect,
  disconnect,
};
}
|
||||
58
web/src/lib/admin/voice/svc.ts
Normal file
58
web/src/lib/admin/voice/svc.ts
Normal file
@@ -0,0 +1,58 @@
|
||||
const VOICE_PROVIDERS_URL = "/api/admin/voice/providers";
|
||||
|
||||
export async function activateVoiceProvider(
|
||||
providerId: number,
|
||||
mode: "stt" | "tts",
|
||||
ttsModel?: string
|
||||
): Promise<Response> {
|
||||
const url = new URL(
|
||||
`${VOICE_PROVIDERS_URL}/${providerId}/activate-${mode}`,
|
||||
window.location.origin
|
||||
);
|
||||
if (mode === "tts" && ttsModel) {
|
||||
url.searchParams.set("tts_model", ttsModel);
|
||||
}
|
||||
return fetch(url.toString(), { method: "POST" });
|
||||
}
|
||||
|
||||
export async function deactivateVoiceProvider(
|
||||
providerId: number,
|
||||
mode: "stt" | "tts"
|
||||
): Promise<Response> {
|
||||
return fetch(`${VOICE_PROVIDERS_URL}/${providerId}/deactivate-${mode}`, {
|
||||
method: "POST",
|
||||
});
|
||||
}
|
||||
|
||||
export async function testVoiceProvider(request: {
|
||||
provider_type: string;
|
||||
api_key?: string;
|
||||
target_uri?: string;
|
||||
use_stored_key?: boolean;
|
||||
}): Promise<Response> {
|
||||
return fetch(`${VOICE_PROVIDERS_URL}/test`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(request),
|
||||
});
|
||||
}
|
||||
|
||||
export async function upsertVoiceProvider(
|
||||
request: Record<string, unknown>
|
||||
): Promise<Response> {
|
||||
return fetch(VOICE_PROVIDERS_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(request),
|
||||
});
|
||||
}
|
||||
|
||||
export async function fetchVoicesByType(
|
||||
providerType: string
|
||||
): Promise<Response> {
|
||||
return fetch(`/api/admin/voice/voices?provider_type=${providerType}`);
|
||||
}
|
||||
|
||||
export async function fetchLLMProviders(): Promise<Response> {
|
||||
return fetch("/api/admin/llm/provider");
|
||||
}
|
||||
@@ -151,6 +151,17 @@ export function formatMmDdYyyy(d: string): string {
|
||||
return `${date.getMonth() + 1}/${date.getDate()}/${date.getFullYear()}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Format a duration in seconds as MM:SS (e.g. 65 → "01:05").
|
||||
*/
|
||||
export function formatElapsedTime(totalSeconds: number): string {
|
||||
const minutes = Math.floor(totalSeconds / 60);
|
||||
const seconds = totalSeconds % 60;
|
||||
return `${minutes.toString().padStart(2, "0")}:${seconds
|
||||
.toString()
|
||||
.padStart(2, "0")}`;
|
||||
}
|
||||
|
||||
export const getFormattedDateTime = (date: Date | null) => {
|
||||
if (!date) return null;
|
||||
|
||||
|
||||
614
web/src/lib/streamingTTS.ts
Normal file
614
web/src/lib/streamingTTS.ts
Normal file
@@ -0,0 +1,614 @@
|
||||
/**
|
||||
* Real-time streaming TTS using HTTP streaming with MediaSource Extensions.
|
||||
* Plays audio chunks as they arrive for smooth, low-latency playback.
|
||||
*/
|
||||
|
||||
/**
|
||||
* HTTPStreamingTTSPlayer - Uses HTTP streaming with MediaSource Extensions
|
||||
* for smooth, gapless audio playback. This is the recommended approach for
|
||||
* real-time TTS as it properly handles MP3 frame boundaries.
|
||||
*/
|
||||
export class HTTPStreamingTTSPlayer {
  // Active MediaSource backing the <audio> element; null while idle.
  private mediaSource: MediaSource | null = null;
  // Object URL created for the MediaSource — must be revoked on cleanup.
  private mediaSourceUrl: string | null = null;
  private sourceBuffer: SourceBuffer | null = null;
  private audioElement: HTMLAudioElement | null = null;
  // Network chunks received but not yet appended to the SourceBuffer.
  private pendingChunks: Uint8Array[] = [];
  // True while an appendBuffer() is in flight (SourceBuffer allows one op at a time).
  private isAppending: boolean = false;
  private isPlaying: boolean = false;
  // Set once the HTTP response body has been fully read.
  private streamComplete: boolean = false;
  private onPlayingChange?: (playing: boolean) => void;
  private onError?: (error: string) => void;
  // Aborts the in-flight fetch when stop() is called.
  private abortController: AbortController | null = null;
  private isMuted: boolean = false;

  constructor(options?: {
    onPlayingChange?: (playing: boolean) => void;
    onError?: (error: string) => void;
  }) {
    this.onPlayingChange = options?.onPlayingChange;
    this.onError = options?.onError;
  }

  private getAPIUrl(): string {
    // Always go through the frontend proxy to ensure cookies are sent correctly
    // The Next.js proxy at /api/* forwards to the backend
    return "/api/voice/synthesize";
  }

  /**
   * Speak text using HTTP streaming with real-time playback.
   * Audio begins playing as soon as the first chunks arrive.
   */
  async speak(
    text: string,
    voice?: string,
    speed: number = 1.0
  ): Promise<void> {
    // Cleanup any previous playback
    this.cleanup();

    // Create abort controller for this request
    this.abortController = new AbortController();

    // Build URL with query params
    const params = new URLSearchParams();
    params.set("text", text);
    if (voice) params.set("voice", voice);
    params.set("speed", speed.toString());

    const url = `${this.getAPIUrl()}?${params}`;

    // Check if MediaSource is supported
    if (!window.MediaSource || !MediaSource.isTypeSupported("audio/mpeg")) {
      // Fallback to simple buffered playback
      return this.fallbackSpeak(url);
    }

    // Create MediaSource and audio element
    this.mediaSource = new MediaSource();
    this.audioElement = new Audio();
    this.mediaSourceUrl = URL.createObjectURL(this.mediaSource);
    this.audioElement.src = this.mediaSourceUrl;
    this.audioElement.muted = this.isMuted;

    // Set up audio element event handlers
    this.audioElement.onplay = () => {
      if (!this.isPlaying) {
        this.isPlaying = true;
        this.onPlayingChange?.(true);
      }
    };

    this.audioElement.onended = () => {
      this.isPlaying = false;
      this.onPlayingChange?.(false);
    };

    this.audioElement.onerror = () => {
      this.onError?.("Audio playback error");
      this.isPlaying = false;
      this.onPlayingChange?.(false);
    };

    // Wait for MediaSource to be ready (sourceopen fires once the audio
    // element attaches to the object URL).
    await new Promise<void>((resolve, reject) => {
      if (!this.mediaSource) {
        reject(new Error("MediaSource not initialized"));
        return;
      }

      this.mediaSource.onsourceopen = () => {
        try {
          // Create SourceBuffer for MP3
          this.sourceBuffer = this.mediaSource!.addSourceBuffer("audio/mpeg");
          // "sequence" mode: appended chunks play back-to-back regardless
          // of embedded timestamps.
          this.sourceBuffer.mode = "sequence";

          this.sourceBuffer.onupdateend = () => {
            this.isAppending = false;
            this.processNextChunk();
          };

          resolve();
        } catch (err) {
          reject(err);
        }
      };

      // MediaSource doesn't have onerror in all browsers, use onsourceclose as fallback
      this.mediaSource.onsourceclose = () => {
        if (this.mediaSource?.readyState === "closed") {
          reject(new Error("MediaSource closed unexpectedly"));
        }
      };
    });

    // Start fetching and streaming audio
    try {
      const response = await fetch(url, {
        method: "POST",
        signal: this.abortController.signal,
        credentials: "include", // Include cookies for authentication
      });

      if (!response.ok) {
        const errorText = await response.text();
        throw new Error(
          `TTS request failed: ${response.status} - ${errorText}`
        );
      }

      const reader = response.body?.getReader();
      if (!reader) {
        throw new Error("No response body");
      }

      // Start playback as soon as we have some data
      let firstChunk = true;

      while (true) {
        const { done, value } = await reader.read();

        if (done) {
          this.streamComplete = true;
          // End the stream when all chunks are appended
          this.finalizeStream();
          break;
        }

        if (value) {
          this.pendingChunks.push(value);
          this.processNextChunk();

          // Start playback after first chunk
          if (firstChunk && this.audioElement) {
            firstChunk = false;
            // Small delay to buffer a bit before starting
            setTimeout(() => {
              this.audioElement?.play().catch(() => {
                // Ignore playback start errors
              });
            }, 100);
          }
        }
      }
    } catch (err) {
      // A user-initiated stop() aborts the fetch; that is not an error.
      if (err instanceof Error && err.name === "AbortError") {
        return;
      }
      this.onError?.(err instanceof Error ? err.message : "TTS error");
      throw err;
    }
  }

  /**
   * Process next chunk from the queue. Only one appendBuffer() may be in
   * flight at a time; onupdateend re-invokes this to drain the queue.
   */
  private processNextChunk(): void {
    if (
      this.isAppending ||
      this.pendingChunks.length === 0 ||
      !this.sourceBuffer ||
      this.sourceBuffer.updating
    ) {
      return;
    }

    const chunk = this.pendingChunks.shift();
    if (chunk) {
      this.isAppending = true;
      try {
        // Use ArrayBuffer directly for better TypeScript compatibility
        const buffer = chunk.buffer.slice(
          chunk.byteOffset,
          chunk.byteOffset + chunk.byteLength
        ) as ArrayBuffer;
        this.sourceBuffer.appendBuffer(buffer);
      } catch {
        this.isAppending = false;
        // Try next chunk
        this.processNextChunk();
      }
    }
  }

  /**
   * Finalize the stream when all data has been received.
   * Polls every 50ms until the append queue drains, then signals
   * endOfStream() so the audio element can fire "ended".
   */
  private finalizeStream(): void {
    if (this.pendingChunks.length > 0 || this.isAppending) {
      // Wait for remaining chunks to be appended
      setTimeout(() => this.finalizeStream(), 50);
      return;
    }

    if (
      this.mediaSource &&
      this.mediaSource.readyState === "open" &&
      this.sourceBuffer &&
      !this.sourceBuffer.updating
    ) {
      try {
        this.mediaSource.endOfStream();
      } catch {
        // Ignore errors when ending stream
      }
    }
  }

  /**
   * Fallback for browsers that don't support MediaSource Extensions.
   * Buffers all audio before playing.
   */
  private async fallbackSpeak(url: string): Promise<void> {
    const response = await fetch(url, {
      method: "POST",
      signal: this.abortController?.signal,
      credentials: "include", // Include cookies for authentication
    });

    if (!response.ok) {
      const errorText = await response.text();
      throw new Error(`TTS request failed: ${response.status} - ${errorText}`);
    }

    const audioData = await response.arrayBuffer();

    const blob = new Blob([audioData], { type: "audio/mpeg" });
    const audioUrl = URL.createObjectURL(blob);

    this.audioElement = new Audio(audioUrl);
    this.audioElement.muted = this.isMuted;

    this.audioElement.onplay = () => {
      this.isPlaying = true;
      this.onPlayingChange?.(true);
    };

    this.audioElement.onended = () => {
      this.isPlaying = false;
      this.onPlayingChange?.(false);
      // Release the blob URL once playback finishes.
      URL.revokeObjectURL(audioUrl);
    };

    this.audioElement.onerror = () => {
      this.onError?.("Audio playback error");
    };

    await this.audioElement.play();
  }

  /**
   * Stop playback and cleanup resources.
   */
  stop(): void {
    // Abort any ongoing request
    if (this.abortController) {
      this.abortController.abort();
      this.abortController = null;
    }

    this.cleanup();
  }

  /** Mute/unmute; applies immediately and persists for future speak() calls. */
  setMuted(muted: boolean): void {
    this.isMuted = muted;
    if (this.audioElement) {
      this.audioElement.muted = muted;
    }
  }

  /**
   * Cleanup all resources.
   */
  private cleanup(): void {
    // Revoke Object URL to prevent memory leak
    if (this.mediaSourceUrl) {
      URL.revokeObjectURL(this.mediaSourceUrl);
      this.mediaSourceUrl = null;
    }

    // Stop and cleanup audio element
    if (this.audioElement) {
      this.audioElement.pause();
      this.audioElement.src = "";
      this.audioElement = null;
    }

    // Cleanup MediaSource
    if (this.mediaSource && this.mediaSource.readyState === "open") {
      try {
        if (this.sourceBuffer) {
          this.mediaSource.removeSourceBuffer(this.sourceBuffer);
        }
        this.mediaSource.endOfStream();
      } catch {
        // Ignore cleanup errors
      }
    }

    this.mediaSource = null;
    this.sourceBuffer = null;
    this.pendingChunks = [];
    this.isAppending = false;
    this.streamComplete = false;

    // Notify listeners if we were still playing when torn down.
    if (this.isPlaying) {
      this.isPlaying = false;
      this.onPlayingChange?.(false);
    }
  }

  /** Whether audio is currently playing. */
  get playing(): boolean {
    return this.isPlaying;
  }
}
|
||||
|
||||
/**
|
||||
* WebSocketStreamingTTSPlayer - Uses WebSocket for bidirectional streaming.
|
||||
* Useful for scenarios where you want to stream text in and get audio out
|
||||
* incrementally (e.g., as LLM generates text).
|
||||
*/
|
||||
export class WebSocketStreamingTTSPlayer {
  private websocket: WebSocket | null = null;
  private mediaSource: MediaSource | null = null;
  // Object URL for the MediaSource; revoked on cleanup to avoid leaks.
  private mediaSourceUrl: string | null = null;
  private sourceBuffer: SourceBuffer | null = null;
  private audioElement: HTMLAudioElement | null = null;
  // Audio chunks received over the socket but not yet appended.
  private pendingChunks: Uint8Array[] = [];
  // True while an appendBuffer() is in flight.
  private isAppending: boolean = false;
  private isPlaying: boolean = false;
  private onPlayingChange?: (playing: boolean) => void;
  private onError?: (error: string) => void;
  // Ensures play() is only kicked off once, on the first audio chunk.
  private hasStartedPlayback: boolean = false;

  constructor(options?: {
    onPlayingChange?: (playing: boolean) => void;
    onError?: (error: string) => void;
  }) {
    this.onPlayingChange = options?.onPlayingChange;
    this.onError = options?.onError;
  }

  /**
   * Build the authenticated WebSocket URL.
   * Fetches a short-lived token first because browsers cannot attach
   * custom auth headers to WebSocket handshakes.
   */
  private async getWebSocketUrl(): Promise<string> {
    // Fetch short-lived WS token
    const tokenResponse = await fetch("/api/voice/ws-token", {
      method: "POST",
      credentials: "include",
    });
    if (!tokenResponse.ok) {
      throw new Error("Failed to get WebSocket authentication token");
    }
    const { token } = await tokenResponse.json();

    const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
    // Port 3000 is taken to mean the local dev frontend, which talks to a
    // backend on localhost:8080 — NOTE(review): dev-only heuristic.
    const isDev = window.location.port === "3000";
    const host = isDev ? "localhost:8080" : window.location.host;
    const path = isDev
      ? "/voice/synthesize/stream"
      : "/api/voice/synthesize/stream";
    return `${protocol}//${host}${path}?token=${encodeURIComponent(token)}`;
  }

  /**
   * Open the socket and set up MediaSource playback.
   * Resolves once the socket is open and the config message is sent.
   */
  async connect(voice?: string, speed?: number): Promise<void> {
    // Cleanup any previous connection
    this.cleanup();

    // Check MediaSource support
    if (!window.MediaSource || !MediaSource.isTypeSupported("audio/mpeg")) {
      throw new Error("MediaSource Extensions not supported");
    }

    // Create MediaSource and audio element
    this.mediaSource = new MediaSource();
    this.audioElement = new Audio();
    this.mediaSourceUrl = URL.createObjectURL(this.mediaSource);
    this.audioElement.src = this.mediaSourceUrl;

    this.audioElement.onplay = () => {
      if (!this.isPlaying) {
        this.isPlaying = true;
        this.onPlayingChange?.(true);
      }
    };

    this.audioElement.onended = () => {
      this.isPlaying = false;
      this.onPlayingChange?.(false);
    };

    // Wait for MediaSource to be ready
    await new Promise<void>((resolve, reject) => {
      this.mediaSource!.onsourceopen = () => {
        try {
          this.sourceBuffer = this.mediaSource!.addSourceBuffer("audio/mpeg");
          // "sequence" mode plays appended chunks back-to-back.
          this.sourceBuffer.mode = "sequence";
          this.sourceBuffer.onupdateend = () => {
            this.isAppending = false;
            this.processNextChunk();
          };
          resolve();
        } catch (err) {
          reject(err);
        }
      };
    });

    // Connect WebSocket
    const url = await this.getWebSocketUrl();
    return new Promise((resolve, reject) => {
      this.websocket = new WebSocket(url);

      this.websocket.onopen = () => {
        // Send initial config
        this.websocket?.send(
          JSON.stringify({
            type: "config",
            voice: voice,
            speed: speed || 1.0,
          })
        );
        resolve();
      };

      this.websocket.onerror = () => {
        reject(new Error("WebSocket connection failed"));
      };

      this.websocket.onmessage = async (event) => {
        if (event.data instanceof Blob) {
          // Audio chunk received
          const arrayBuffer = await event.data.arrayBuffer();
          this.pendingChunks.push(new Uint8Array(arrayBuffer));
          this.processNextChunk();

          // Start playback after first chunk
          if (!this.hasStartedPlayback && this.audioElement) {
            this.hasStartedPlayback = true;
            // Small delay to accumulate a little buffer first.
            setTimeout(() => {
              this.audioElement?.play().catch(() => {
                // Ignore playback errors
              });
            }, 100);
          }
        } else {
          // JSON message
          try {
            const data = JSON.parse(event.data);
            if (data.type === "audio_done") {
              this.finalizeStream();
            } else if (data.type === "error") {
              this.onError?.(data.message);
            }
          } catch {
            // Ignore parse errors
          }
        }
      };

      this.websocket.onclose = () => {
        this.finalizeStream();
      };
    });
  }

  // Drain one queued chunk into the SourceBuffer; onupdateend re-invokes
  // this until the queue is empty. Only one append may be in flight.
  private processNextChunk(): void {
    if (
      this.isAppending ||
      this.pendingChunks.length === 0 ||
      !this.sourceBuffer ||
      this.sourceBuffer.updating
    ) {
      return;
    }

    const chunk = this.pendingChunks.shift();
    if (chunk) {
      this.isAppending = true;
      try {
        // Use ArrayBuffer directly for better TypeScript compatibility
        const buffer = chunk.buffer.slice(
          chunk.byteOffset,
          chunk.byteOffset + chunk.byteLength
        ) as ArrayBuffer;
        this.sourceBuffer.appendBuffer(buffer);
      } catch {
        this.isAppending = false;
        this.processNextChunk();
      }
    }
  }

  // Poll every 50ms until the append queue drains, then end the stream so
  // the audio element can reach "ended".
  private finalizeStream(): void {
    if (this.pendingChunks.length > 0 || this.isAppending) {
      setTimeout(() => this.finalizeStream(), 50);
      return;
    }

    if (
      this.mediaSource &&
      this.mediaSource.readyState === "open" &&
      this.sourceBuffer &&
      !this.sourceBuffer.updating
    ) {
      try {
        this.mediaSource.endOfStream();
      } catch {
        // Ignore
      }
    }
  }

  /**
   * Request synthesis of `text` over the open socket.
   * Throws if connect() has not completed successfully.
   */
  async speak(text: string): Promise<void> {
    if (!this.websocket || this.websocket.readyState !== WebSocket.OPEN) {
      throw new Error("WebSocket not connected");
    }

    this.websocket.send(
      JSON.stringify({
        type: "synthesize",
        text: text,
      })
    );
  }

  /** Stop playback and release all resources (socket included). */
  stop(): void {
    this.cleanup();
  }

  /** Gracefully signal end-of-session to the server, then tear down. */
  disconnect(): void {
    if (this.websocket && this.websocket.readyState === WebSocket.OPEN) {
      this.websocket.send(JSON.stringify({ type: "end" }));
      this.websocket.close();
    }
    this.cleanup();
  }

  // Release socket, audio element, MediaSource, and queued chunks, and
  // notify listeners if playback was interrupted.
  private cleanup(): void {
    if (this.websocket) {
      this.websocket.close();
      this.websocket = null;
    }

    // Revoke Object URL to prevent memory leak
    if (this.mediaSourceUrl) {
      URL.revokeObjectURL(this.mediaSourceUrl);
      this.mediaSourceUrl = null;
    }

    if (this.audioElement) {
      this.audioElement.pause();
      this.audioElement.src = "";
      this.audioElement = null;
    }

    if (this.mediaSource && this.mediaSource.readyState === "open") {
      try {
        if (this.sourceBuffer) {
          this.mediaSource.removeSourceBuffer(this.sourceBuffer);
        }
        this.mediaSource.endOfStream();
      } catch {
        // Ignore
      }
    }

    this.mediaSource = null;
    this.sourceBuffer = null;
    this.pendingChunks = [];
    this.isAppending = false;
    this.hasStartedPlayback = false;

    if (this.isPlaying) {
      this.isPlaying = false;
      this.onPlayingChange?.(false);
    }
  }

  /** Whether audio is currently playing. */
  get playing(): boolean {
    return this.isPlaying;
  }
}
|
||||
|
||||
// Export the HTTP player as the default/recommended option
|
||||
export { HTTPStreamingTTSPlayer as StreamingTTSPlayer };
|
||||
@@ -32,6 +32,10 @@ interface UserPreferences {
|
||||
theme_preference: ThemePreference | null;
|
||||
chat_background: string | null;
|
||||
default_app_mode: "AUTO" | "CHAT" | "SEARCH";
|
||||
// Voice preferences
|
||||
voice_auto_send?: boolean;
|
||||
voice_auto_playback?: boolean;
|
||||
voice_playback_speed?: number;
|
||||
}
|
||||
|
||||
export interface MemoryItem {
|
||||
|
||||
@@ -46,6 +46,11 @@ interface UserContextType {
|
||||
updateUserChatBackground: (chatBackground: string | null) => Promise<void>;
|
||||
updateUserDefaultModel: (defaultModel: string | null) => Promise<void>;
|
||||
updateUserDefaultAppMode: (mode: "CHAT" | "SEARCH") => Promise<void>;
|
||||
updateUserVoiceSettings: (settings: {
|
||||
auto_send?: boolean;
|
||||
auto_playback?: boolean;
|
||||
playback_speed?: number;
|
||||
}) => Promise<void>;
|
||||
}
|
||||
|
||||
const UserContext = createContext<UserContextType | undefined>(undefined);
|
||||
@@ -460,6 +465,50 @@ export function UserProvider({
|
||||
}
|
||||
};
|
||||
|
||||
const updateUserVoiceSettings = async (settings: {
|
||||
auto_send?: boolean;
|
||||
auto_playback?: boolean;
|
||||
playback_speed?: number;
|
||||
}) => {
|
||||
try {
|
||||
setUpToDateUser((prevUser) => {
|
||||
if (prevUser) {
|
||||
return {
|
||||
...prevUser,
|
||||
preferences: {
|
||||
...prevUser.preferences,
|
||||
voice_auto_send:
|
||||
settings.auto_send ?? prevUser.preferences.voice_auto_send,
|
||||
voice_auto_playback:
|
||||
settings.auto_playback ??
|
||||
prevUser.preferences.voice_auto_playback,
|
||||
voice_playback_speed:
|
||||
settings.playback_speed ??
|
||||
prevUser.preferences.voice_playback_speed,
|
||||
},
|
||||
};
|
||||
}
|
||||
return prevUser;
|
||||
});
|
||||
|
||||
const response = await fetch("/api/voice/settings", {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(settings),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
await refreshUser();
|
||||
throw new Error("Failed to update voice settings");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error updating voice settings:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
const refreshUser = async () => {
|
||||
await fetchUser();
|
||||
};
|
||||
@@ -478,6 +527,7 @@ export function UserProvider({
|
||||
updateUserChatBackground,
|
||||
updateUserDefaultModel,
|
||||
updateUserDefaultAppMode,
|
||||
updateUserVoiceSettings,
|
||||
toggleAgentPinnedStatus,
|
||||
isAdmin: upToDateUser?.role === UserRole.ADMIN,
|
||||
// Curator status applies for either global or basic curator
|
||||
|
||||
1063
web/src/providers/VoiceModeProvider.tsx
Normal file
1063
web/src/providers/VoiceModeProvider.tsx
Normal file
File diff suppressed because it is too large
Load Diff
@@ -210,10 +210,15 @@ describe("InputComboBox", () => {
|
||||
|
||||
await user.type(input, "app");
|
||||
|
||||
// Search should only show matching options by default
|
||||
// In non-strict mode, searching shows:
|
||||
// 1) a create option for the current input and
|
||||
// 2) matched options.
|
||||
const options = screen.getAllByRole("option");
|
||||
expect(options.length).toBe(1);
|
||||
expect(options[0]!.textContent).toBe("Apple");
|
||||
expect(options.length).toBe(2);
|
||||
expect(screen.getByLabelText('Create "app"')).toBeInTheDocument();
|
||||
expect(
|
||||
options.some((option) => option.textContent?.includes("Apple"))
|
||||
).toBe(true);
|
||||
expect(screen.queryByText("Banana")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
|
||||
@@ -130,6 +130,7 @@ const InputComboBox = ({
|
||||
leftSearchIcon = false,
|
||||
rightSection,
|
||||
separatorLabel = "Other options",
|
||||
showAddPrefix = false,
|
||||
showOtherOptions = false,
|
||||
...rest
|
||||
}: WithoutStyles<InputComboBoxProps>) => {
|
||||
@@ -157,14 +158,11 @@ const InputComboBox = ({
|
||||
const visibleUnmatchedOptions =
|
||||
hasSearchTerm && showOtherOptions ? unmatchedOptions : [];
|
||||
|
||||
// Whether to show the create option (only when no partial matches)
|
||||
const showCreateOption =
|
||||
!strict &&
|
||||
hasSearchTerm &&
|
||||
inputValue.trim() !== "" &&
|
||||
matchedOptions.length === 0;
|
||||
// Whether to show the create option (always show when typing in non-strict mode)
|
||||
const showCreateOption = !strict && hasSearchTerm && inputValue.trim() !== "";
|
||||
|
||||
// Combined list for keyboard navigation (includes create option when shown)
|
||||
// Only show matched options when searching (hide unmatched)
|
||||
const allVisibleOptions = useMemo(() => {
|
||||
const baseOptions = [...matchedOptions, ...visibleUnmatchedOptions];
|
||||
if (showCreateOption) {
|
||||
@@ -450,6 +448,7 @@ const InputComboBox = ({
|
||||
inputValue={inputValue}
|
||||
allowCreate={!strict}
|
||||
showCreateOption={showCreateOption}
|
||||
showAddPrefix={showAddPrefix}
|
||||
/>
|
||||
</>
|
||||
|
||||
|
||||
@@ -27,6 +27,8 @@ interface ComboBoxDropdownProps {
|
||||
allowCreate: boolean;
|
||||
/** Whether to show create option (pre-computed by parent) */
|
||||
showCreateOption: boolean;
|
||||
/** Show "Add" prefix in create option */
|
||||
showAddPrefix: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -58,6 +60,7 @@ export const ComboBoxDropdown = forwardRef<
|
||||
inputValue,
|
||||
allowCreate,
|
||||
showCreateOption,
|
||||
showAddPrefix,
|
||||
},
|
||||
ref
|
||||
) => {
|
||||
@@ -132,6 +135,7 @@ export const ComboBoxDropdown = forwardRef<
|
||||
inputValue={inputValue}
|
||||
allowCreate={allowCreate}
|
||||
showCreateOption={showCreateOption}
|
||||
showAddPrefix={showAddPrefix}
|
||||
/>
|
||||
</div>,
|
||||
document.body
|
||||
|
||||
@@ -24,6 +24,8 @@ interface OptionsListProps {
|
||||
allowCreate: boolean;
|
||||
/** Whether to show create option (pre-computed by parent) */
|
||||
showCreateOption: boolean;
|
||||
/** Show "Add" prefix in create option */
|
||||
showAddPrefix: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -45,6 +47,7 @@ export const OptionsList: React.FC<OptionsListProps> = ({
|
||||
inputValue,
|
||||
allowCreate,
|
||||
showCreateOption,
|
||||
showAddPrefix,
|
||||
}) => {
|
||||
// Index offset for other options when create option is shown
|
||||
const indexOffset = showCreateOption ? 1 : 0;
|
||||
@@ -70,7 +73,7 @@ export const OptionsList: React.FC<OptionsListProps> = ({
|
||||
data-index={0}
|
||||
role="option"
|
||||
aria-selected={false}
|
||||
aria-label={`Create "${inputValue}"`}
|
||||
aria-label={`${showAddPrefix ? "Add" : "Create"} "${inputValue}"`}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onSelect({ value: inputValue, label: inputValue });
|
||||
@@ -81,19 +84,48 @@ export const OptionsList: React.FC<OptionsListProps> = ({
|
||||
onMouseEnter={() => onMouseEnter(0)}
|
||||
onMouseMove={onMouseMove}
|
||||
className={cn(
|
||||
"px-3 py-2 cursor-pointer transition-colors",
|
||||
"cursor-pointer transition-colors",
|
||||
"flex items-center justify-between rounded-08",
|
||||
highlightedIndex === 0 && "bg-background-tint-02",
|
||||
"hover:bg-background-tint-02"
|
||||
"hover:bg-background-tint-02",
|
||||
showAddPrefix ? "px-1.5 py-1.5" : "px-3 py-2"
|
||||
)}
|
||||
>
|
||||
<span className="font-main-ui-action text-text-04 truncate min-w-0">
|
||||
{inputValue}
|
||||
<span
|
||||
className={cn(
|
||||
"font-main-ui-action truncate min-w-0",
|
||||
showAddPrefix ? "px-1" : ""
|
||||
)}
|
||||
>
|
||||
{showAddPrefix ? (
|
||||
<>
|
||||
<span className="text-text-03">Add</span>
|
||||
<span className="text-text-04">{` ${inputValue}`}</span>
|
||||
</>
|
||||
) : (
|
||||
<span className="text-text-04">{inputValue}</span>
|
||||
)}
|
||||
</span>
|
||||
<SvgPlus className="w-4 h-4 text-text-03 flex-shrink-0 ml-2" />
|
||||
<SvgPlus
|
||||
className={cn(
|
||||
"w-4 h-4 flex-shrink-0",
|
||||
showAddPrefix ? "text-text-04 mx-1" : "text-text-03 ml-2"
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Separator - show when there are options to display */}
|
||||
{separatorLabel &&
|
||||
(matchedOptions.length > 0 ||
|
||||
(!hasSearchTerm && unmatchedOptions.length > 0)) && (
|
||||
<div className="px-3 py-1">
|
||||
<Text as="p" text03 secondaryBody>
|
||||
{separatorLabel}
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Matched/Filtered Options */}
|
||||
{matchedOptions.map((option, idx) => {
|
||||
const globalIndex = idx + indexOffset;
|
||||
@@ -116,37 +148,27 @@ export const OptionsList: React.FC<OptionsListProps> = ({
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Separator - only show if there are unmatched options and a search term */}
|
||||
{hasSearchTerm && unmatchedOptions.length > 0 && (
|
||||
<div className="px-3 py-2 pt-3">
|
||||
<div className="border-t border-border-01 pt-2">
|
||||
<Text as="p" text04 secondaryBody className="text-text-02">
|
||||
{separatorLabel}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Unmatched Options */}
|
||||
{unmatchedOptions.map((option, idx) => {
|
||||
const globalIndex = matchedOptions.length + idx + indexOffset;
|
||||
const isExact = isExactMatch(option);
|
||||
return (
|
||||
<OptionItem
|
||||
key={option.value}
|
||||
option={option}
|
||||
index={globalIndex}
|
||||
fieldId={fieldId}
|
||||
isHighlighted={globalIndex === highlightedIndex}
|
||||
isSelected={value === option.value}
|
||||
isExact={isExact}
|
||||
onSelect={onSelect}
|
||||
onMouseEnter={onMouseEnter}
|
||||
onMouseMove={onMouseMove}
|
||||
searchTerm={inputValue}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
{/* Unmatched Options - only show when NOT searching */}
|
||||
{!hasSearchTerm &&
|
||||
unmatchedOptions.map((option, idx) => {
|
||||
const globalIndex = matchedOptions.length + idx + indexOffset;
|
||||
const isExact = isExactMatch(option);
|
||||
return (
|
||||
<OptionItem
|
||||
key={option.value}
|
||||
option={option}
|
||||
index={globalIndex}
|
||||
fieldId={fieldId}
|
||||
isHighlighted={globalIndex === highlightedIndex}
|
||||
isSelected={value === option.value}
|
||||
isExact={isExact}
|
||||
onSelect={onSelect}
|
||||
onMouseEnter={onMouseEnter}
|
||||
onMouseMove={onMouseMove}
|
||||
searchTerm={inputValue}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useState, useEffect, useCallback, useMemo, RefObject } from "react";
|
||||
import { useState, useEffect, useCallback, useMemo, useRef } from "react";
|
||||
import { ComboBoxOption } from "./types";
|
||||
|
||||
// =============================================================================
|
||||
@@ -19,6 +19,7 @@ export function useComboBoxState({ value, options }: UseComboBoxStateProps) {
|
||||
const [inputValue, setInputValue] = useState(value);
|
||||
const [highlightedIndex, setHighlightedIndex] = useState(-1);
|
||||
const [isKeyboardNav, setIsKeyboardNav] = useState(false);
|
||||
const prevIsOpenRef = useRef(false);
|
||||
|
||||
// Sync inputValue with the external value prop.
|
||||
// When the dropdown is closed, always reflect the controlled value.
|
||||
|
||||
@@ -40,6 +40,8 @@ export interface InputComboBoxProps
|
||||
rightSection?: React.ReactNode;
|
||||
/** Label for the separator between matched and unmatched options */
|
||||
separatorLabel?: string;
|
||||
/** Show "Add" prefix in create option (e.g., "Add [value]") */
|
||||
showAddPrefix?: boolean;
|
||||
/**
|
||||
* When true, keep non-matching options visible under a separator while searching.
|
||||
* Defaults to false so search results are strictly filtered.
|
||||
|
||||
@@ -751,6 +751,7 @@ function ChatPreferencesSettings() {
|
||||
updateUserShortcuts,
|
||||
updateUserDefaultModel,
|
||||
updateUserDefaultAppMode,
|
||||
updateUserVoiceSettings,
|
||||
} = useUser();
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const settings = useSettingsContext();
|
||||
@@ -767,6 +768,43 @@ function ChatPreferencesSettings() {
|
||||
onSuccess: () => toast.success("Preferences saved"),
|
||||
onError: () => toast.error("Failed to save preferences"),
|
||||
});
|
||||
const [draftVoicePlaybackSpeed, setDraftVoicePlaybackSpeed] = useState(
|
||||
user?.preferences.voice_playback_speed ?? 1
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
setDraftVoicePlaybackSpeed(user?.preferences.voice_playback_speed ?? 1);
|
||||
}, [user?.preferences.voice_playback_speed]);
|
||||
|
||||
const saveVoiceSettings = useCallback(
|
||||
async (settings: {
|
||||
auto_send?: boolean;
|
||||
auto_playback?: boolean;
|
||||
playback_speed?: number;
|
||||
}) => {
|
||||
try {
|
||||
await updateUserVoiceSettings(settings);
|
||||
toast.success("Preferences saved");
|
||||
} catch {
|
||||
toast.error("Failed to save preferences");
|
||||
}
|
||||
},
|
||||
[updateUserVoiceSettings]
|
||||
);
|
||||
|
||||
const commitVoicePlaybackSpeed = useCallback(() => {
|
||||
const currentSpeed = user?.preferences.voice_playback_speed ?? 1;
|
||||
if (Math.abs(currentSpeed - draftVoicePlaybackSpeed) < 0.001) {
|
||||
return;
|
||||
}
|
||||
void saveVoiceSettings({
|
||||
playback_speed: draftVoicePlaybackSpeed,
|
||||
});
|
||||
}, [
|
||||
draftVoicePlaybackSpeed,
|
||||
saveVoiceSettings,
|
||||
user?.preferences.voice_playback_speed,
|
||||
]);
|
||||
|
||||
// Wrapper to save memories and return success/failure
|
||||
const handleSaveMemories = useCallback(
|
||||
@@ -936,6 +974,69 @@ function ChatPreferencesSettings() {
|
||||
{user?.preferences?.shortcut_enabled && <PromptShortcuts />}
|
||||
</Card>
|
||||
</Section>
|
||||
|
||||
<Section gap={0.75}>
|
||||
<Content
|
||||
title="Voice"
|
||||
sizePreset="main-content"
|
||||
variant="section"
|
||||
widthVariant="full"
|
||||
/>
|
||||
<Card>
|
||||
<InputLayouts.Horizontal
|
||||
title="Auto-Send"
|
||||
description="Automatically send voice input when recording stops."
|
||||
>
|
||||
<Switch
|
||||
checked={user?.preferences.voice_auto_send ?? false}
|
||||
onCheckedChange={(checked) => {
|
||||
void saveVoiceSettings({ auto_send: checked });
|
||||
}}
|
||||
/>
|
||||
</InputLayouts.Horizontal>
|
||||
|
||||
<InputLayouts.Horizontal
|
||||
title="Auto-Playback"
|
||||
description="Automatically play voice responses."
|
||||
>
|
||||
<Switch
|
||||
checked={user?.preferences.voice_auto_playback ?? false}
|
||||
onCheckedChange={(checked) => {
|
||||
void saveVoiceSettings({ auto_playback: checked });
|
||||
}}
|
||||
/>
|
||||
</InputLayouts.Horizontal>
|
||||
|
||||
<InputLayouts.Horizontal
|
||||
title="Playback Speed"
|
||||
description="Adjust the speed of voice playback."
|
||||
>
|
||||
<div className="flex items-center gap-3">
|
||||
<input
|
||||
type="range"
|
||||
min="0.5"
|
||||
max="2"
|
||||
step="0.1"
|
||||
value={draftVoicePlaybackSpeed}
|
||||
onChange={(e) => {
|
||||
setDraftVoicePlaybackSpeed(parseFloat(e.target.value));
|
||||
}}
|
||||
onMouseUp={commitVoicePlaybackSpeed}
|
||||
onTouchEnd={commitVoicePlaybackSpeed}
|
||||
onKeyUp={(e) => {
|
||||
if (e.key === "ArrowLeft" || e.key === "ArrowRight") {
|
||||
commitVoicePlaybackSpeed();
|
||||
}
|
||||
}}
|
||||
className="w-24 h-2 rounded-lg appearance-none cursor-pointer bg-background-neutral-02"
|
||||
/>
|
||||
<span className="text-sm text-text-02 w-10">
|
||||
{draftVoicePlaybackSpeed.toFixed(1)}x
|
||||
</span>
|
||||
</div>
|
||||
</InputLayouts.Horizontal>
|
||||
</Card>
|
||||
</Section>
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -55,6 +55,10 @@ import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import MicrophoneButton from "@/sections/input/MicrophoneButton";
|
||||
import Waveform from "@/components/voice/Waveform";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import { useVoiceStatus } from "@/hooks/useVoiceStatus";
|
||||
|
||||
const MIN_INPUT_HEIGHT = 44;
|
||||
const MAX_INPUT_HEIGHT = 200;
|
||||
@@ -113,6 +117,14 @@ const AppInputBar = React.memo(
|
||||
}: AppInputBarProps) => {
|
||||
// Internal message state - kept local to avoid parent re-renders on every keystroke
|
||||
const [message, setMessage] = useState(initialMessage);
|
||||
const [isRecording, setIsRecording] = useState(false);
|
||||
const [recordingCycleCount, setRecordingCycleCount] = useState(0);
|
||||
const [isMuted, setIsMuted] = useState(false);
|
||||
const [audioLevel, setAudioLevel] = useState(0);
|
||||
const stopRecordingRef = useRef<(() => Promise<string | null>) | null>(
|
||||
null
|
||||
);
|
||||
const setMutedRef = useRef<((muted: boolean) => void) | null>(null);
|
||||
const textAreaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const textAreaWrapperRef = useRef<HTMLDivElement>(null);
|
||||
const filesWrapperRef = useRef<HTMLDivElement>(null);
|
||||
@@ -123,6 +135,38 @@ const AppInputBar = React.memo(
|
||||
const isClassifying = state.phase === "classifying";
|
||||
const isSearchActive =
|
||||
state.phase === "searching" || state.phase === "search-results";
|
||||
const {
|
||||
stopTTS,
|
||||
isTTSPlaying,
|
||||
isManualTTSPlaying,
|
||||
isTTSLoading,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
isTTSMuted,
|
||||
toggleTTSMute,
|
||||
} = useVoiceMode();
|
||||
const { sttEnabled } = useVoiceStatus();
|
||||
const isVoicePlaybackActive =
|
||||
isTTSPlaying || isTTSLoading || isAwaitingAutoPlaybackStart;
|
||||
const isVoicePlaybackControllable = isVoicePlaybackActive && !isRecording;
|
||||
const isTTSActuallySpeaking = isTTSPlaying || isManualTTSPlaying;
|
||||
|
||||
const handleRecordingChange = useCallback((nextIsRecording: boolean) => {
|
||||
setIsRecording((prevIsRecording) => {
|
||||
if (!prevIsRecording && nextIsRecording) {
|
||||
setRecordingCycleCount((count) => count + 1);
|
||||
}
|
||||
return nextIsRecording;
|
||||
});
|
||||
}, []);
|
||||
|
||||
// Wrapper for onSubmit that stops TTS first to prevent overlapping voices
|
||||
const handleSubmit = useCallback(
|
||||
(text: string) => {
|
||||
stopTTS();
|
||||
onSubmit(text);
|
||||
},
|
||||
[stopTTS, onSubmit]
|
||||
);
|
||||
|
||||
// Expose reset and focus methods to parent via ref
|
||||
React.useImperativeHandle(ref, () => ({
|
||||
@@ -143,9 +187,14 @@ const AppInputBar = React.memo(
|
||||
}
|
||||
}, [initialMessage]);
|
||||
const appFocus = useAppFocus();
|
||||
const isNewSession = appFocus.isNewSession();
|
||||
const appMode = state.phase === "idle" ? state.appMode : undefined;
|
||||
const isSearchMode =
|
||||
(appFocus.isNewSession() && appMode === "search") || isSearchActive;
|
||||
(isNewSession && appMode === "search") || isSearchActive;
|
||||
const shouldShowRecordingWaveformBelow =
|
||||
isRecording &&
|
||||
!isVoicePlaybackActive &&
|
||||
(isNewSession || recordingCycleCount === 1);
|
||||
|
||||
const { forcedToolIds, setForcedToolIds } = useForcedTools();
|
||||
const { currentMessageFiles, setCurrentMessageFiles, currentProjectId } =
|
||||
@@ -558,9 +607,35 @@ const AppInputBar = React.memo(
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
{sttEnabled && (
|
||||
<MicrophoneButton
|
||||
onTranscription={(text) => setMessage(text)}
|
||||
disabled={disabled || chatState === "streaming"}
|
||||
autoSend={user?.preferences?.voice_auto_send ?? false}
|
||||
autoListen={user?.preferences?.voice_auto_playback ?? false}
|
||||
isNewSession={isNewSession}
|
||||
chatState={chatState}
|
||||
onRecordingChange={handleRecordingChange}
|
||||
stopRecordingRef={stopRecordingRef}
|
||||
onRecordingStart={() => setMessage("")}
|
||||
onAutoSend={(text) => {
|
||||
// Guard against empty transcription
|
||||
if (text.trim()) {
|
||||
handleSubmit(text);
|
||||
setMessage("");
|
||||
}
|
||||
}}
|
||||
onMuteChange={setIsMuted}
|
||||
setMutedRef={setMutedRef}
|
||||
onAudioLevel={setAudioLevel}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Disabled
|
||||
disabled={
|
||||
(chatState === "input" && !message) ||
|
||||
(chatState === "input" &&
|
||||
!isVoicePlaybackControllable &&
|
||||
!message) ||
|
||||
hasUploadingFiles ||
|
||||
isClassifying
|
||||
}
|
||||
@@ -570,13 +645,16 @@ const AppInputBar = React.memo(
|
||||
icon={
|
||||
isClassifying
|
||||
? SimpleLoader
|
||||
: chatState === "input"
|
||||
? SvgArrowUp
|
||||
: SvgStop
|
||||
: chatState === "streaming" || isVoicePlaybackControllable
|
||||
? SvgStop
|
||||
: SvgArrowUp
|
||||
}
|
||||
onClick={() => {
|
||||
if (chatState == "streaming") {
|
||||
stopTTS({ manual: true });
|
||||
stopGenerating();
|
||||
} else if (isVoicePlaybackControllable) {
|
||||
stopTTS({ manual: true });
|
||||
} else if (message) {
|
||||
onSubmit(message);
|
||||
}
|
||||
@@ -606,6 +684,32 @@ const AppInputBar = React.memo(
|
||||
// modes. See the corresponding note there for details.
|
||||
)}
|
||||
>
|
||||
{/* Voice waveform above input */}
|
||||
{isTTSActuallySpeaking ? (
|
||||
<div className="flex justify-start px-1">
|
||||
<Waveform
|
||||
variant="speaking"
|
||||
isActive={isTTSActuallySpeaking}
|
||||
isMuted={isTTSMuted}
|
||||
onMuteToggle={toggleTTSMute}
|
||||
/>
|
||||
</div>
|
||||
) : isRecording &&
|
||||
!isVoicePlaybackActive &&
|
||||
!shouldShowRecordingWaveformBelow ? (
|
||||
<div className="px-1">
|
||||
<Waveform
|
||||
variant="recording"
|
||||
isActive={isRecording}
|
||||
isMuted={isMuted}
|
||||
audioLevel={audioLevel}
|
||||
onMuteToggle={() => {
|
||||
setMutedRef.current?.(!isMuted);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
{/* Attached Files */}
|
||||
<div
|
||||
ref={filesWrapperRef}
|
||||
@@ -657,9 +761,13 @@ const AppInputBar = React.memo(
|
||||
style={{ scrollbarWidth: "thin" }}
|
||||
aria-multiline={true}
|
||||
placeholder={
|
||||
isSearchMode
|
||||
? "Search connected sources"
|
||||
: "How can I help you today?"
|
||||
isRecording
|
||||
? "Listening..."
|
||||
: isVoicePlaybackActive
|
||||
? "Onyx is speaking..."
|
||||
: isSearchMode
|
||||
? "Search connected sources"
|
||||
: "How can I help you today?"
|
||||
}
|
||||
value={message}
|
||||
onKeyDown={(event) => {
|
||||
@@ -676,7 +784,7 @@ const AppInputBar = React.memo(
|
||||
!isClassifying &&
|
||||
!hasUploadingFiles
|
||||
) {
|
||||
onSubmit(message);
|
||||
handleSubmit(message);
|
||||
}
|
||||
}
|
||||
}}
|
||||
@@ -743,7 +851,7 @@ const AppInputBar = React.memo(
|
||||
if (chatState == "streaming") {
|
||||
stopGenerating();
|
||||
} else if (message) {
|
||||
onSubmit(message);
|
||||
handleSubmit(message);
|
||||
}
|
||||
}}
|
||||
prominence="tertiary"
|
||||
@@ -755,6 +863,21 @@ const AppInputBar = React.memo(
|
||||
</div>
|
||||
|
||||
{chatControls}
|
||||
|
||||
{/* First recording cycle waveform below input */}
|
||||
{shouldShowRecordingWaveformBelow && (
|
||||
<div className="px-1">
|
||||
<Waveform
|
||||
variant="recording"
|
||||
isActive={isRecording}
|
||||
isMuted={isMuted}
|
||||
audioLevel={audioLevel}
|
||||
onMuteToggle={() => {
|
||||
setMutedRef.current?.(!isMuted);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</Disabled>
|
||||
);
|
||||
|
||||
297
web/src/sections/input/MicrophoneButton.tsx
Normal file
297
web/src/sections/input/MicrophoneButton.tsx
Normal file
@@ -0,0 +1,297 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useRef } from "react";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { SvgMicrophone } from "@opal/icons";
|
||||
import { useVoiceRecorder } from "@/hooks/useVoiceRecorder";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { ChatState } from "@/app/app/interfaces";
|
||||
|
||||
/**
 * Props for the push-to-talk microphone button used in the chat input bar.
 * Callbacks are optional; refs let the parent imperatively stop recording
 * or toggle mute without re-rendering this component.
 */
interface MicrophoneButtonProps {
  /** Called with transcribed text (live partial updates while recording, and the final transcript). */
  onTranscription: (text: string) => void;
  /** Disables the button (e.g. while a response is streaming). */
  disabled?: boolean;
  /** When true, the transcript is auto-submitted after recording stops (see onAutoSend). */
  autoSend?: boolean;
  /** Called with transcribed text when autoSend is enabled */
  onAutoSend?: (text: string) => void;
  /**
   * Internal prop: auto-start listening when TTS finishes or chat response completes.
   * Tied to voice_auto_playback user preference.
   * Enables conversation flow: speak → AI responds → auto-listen again.
   * Note: autoSend is separate - it controls whether message auto-submits after recording.
   */
  autoListen?: boolean;
  /** Current chat state - used to detect when response streaming finishes */
  chatState?: ChatState;
  /** Called when recording state changes */
  onRecordingChange?: (isRecording: boolean) => void;
  /** Ref to expose stop recording function to parent */
  stopRecordingRef?: React.MutableRefObject<
    (() => Promise<string | null>) | null
  >;
  /** Called when recording starts to clear input */
  onRecordingStart?: () => void;
  /** Called when mute state changes */
  onMuteChange?: (isMuted: boolean) => void;
  /** Ref to expose setMuted function to parent */
  setMutedRef?: React.MutableRefObject<((muted: boolean) => void) | null>;
  /** Called with current microphone audio level (0-1) for waveform visualization */
  onAudioLevel?: (level: number) => void;
  /** Whether current chat is a new session (used to reset auto-listen arming) */
  isNewSession?: boolean;
}
|
||||
|
||||
/**
 * Microphone toggle button for voice input.
 *
 * Responsibilities:
 * - Start/stop recording via useVoiceRecorder and forward transcripts to the parent.
 * - Optionally auto-send the final transcript (autoSend) when the chat is idle.
 * - Optionally re-arm listening after TTS playback finishes (autoListen), with a
 *   short cooldown so the tail of the playback is not re-captured by the mic.
 * - Expose stopRecording/setMuted to the parent through mutable refs so the
 *   surrounding input bar can control recording without prop drilling callbacks.
 */
function MicrophoneButton({
  onTranscription,
  disabled = false,
  autoSend = false,
  onAutoSend,
  autoListen = false,
  chatState,
  onRecordingChange,
  stopRecordingRef,
  onRecordingStart,
  onMuteChange,
  setMutedRef,
  onAudioLevel,
  isNewSession = false,
}: MicrophoneButtonProps) {
  const {
    isTTSPlaying,
    isTTSLoading,
    isAwaitingAutoPlaybackStart,
    manualStopCount,
  } = useVoiceMode();

  // Refs for tracking state across renders
  // Track whether TTS was actually playing audio (not just loading)
  const wasTTSActuallyPlayingRef = useRef(false);
  // Set while a click-triggered stop is in flight so the VAD final-transcript
  // handler knows not to double auto-send (handleClick sends it itself).
  const manualStopRequestedRef = useRef(false);
  // Last manualStopCount we reacted to; a mismatch means the user manually
  // stopped TTS since the previous run, which suppresses auto-listen once.
  const lastHandledManualStopCountRef = useRef(manualStopCount);
  // Pending cooldown timer before auto-listen restarts recording.
  const autoListenCooldownTimerRef = useRef<NodeJS.Timeout | null>(null);
  // Auto-listen is only armed after the user has pressed the mic once this session.
  const hasManualRecordStartRef = useRef(false);

  // Handler for VAD (Voice Activity Detection) triggered auto-send.
  // VAD runs server-side in the STT provider and detects when the user stops speaking.
  const handleFinalTranscript = useCallback(
    (text: string) => {
      onTranscription(text);
      const isManualStop = manualStopRequestedRef.current;
      // Only auto-send if chat is ready for input (not streaming)
      if (!isManualStop && autoSend && onAutoSend && chatState === "input") {
        onAutoSend(text);
      }
    },
    [onTranscription, autoSend, onAutoSend, chatState]
  );

  const {
    isRecording,
    isProcessing,
    isMuted,
    error,
    liveTranscript,
    audioLevel,
    startRecording,
    stopRecording,
    setMuted,
  } = useVoiceRecorder({ onFinalTranscript: handleFinalTranscript });

  // Expose stopRecording to parent
  useEffect(() => {
    if (stopRecordingRef) {
      stopRecordingRef.current = stopRecording;
    }
  }, [stopRecording, stopRecordingRef]);

  // Expose setMuted to parent
  useEffect(() => {
    if (setMutedRef) {
      setMutedRef.current = setMuted;
    }
  }, [setMuted, setMutedRef]);

  // Notify parent when mute state changes
  useEffect(() => {
    onMuteChange?.(isMuted);
  }, [isMuted, onMuteChange]);

  // Forward audio level to parent for waveform visualization
  useEffect(() => {
    onAudioLevel?.(audioLevel);
  }, [audioLevel, onAudioLevel]);

  // Notify parent when recording state changes
  useEffect(() => {
    onRecordingChange?.(isRecording);
  }, [isRecording, onRecordingChange]);

  // Update input with live transcript as user speaks
  useEffect(() => {
    if (isRecording && liveTranscript) {
      onTranscription(liveTranscript);
    }
  }, [isRecording, liveTranscript, onTranscription]);

  // Toggle handler: stops + (optionally) auto-sends when recording,
  // otherwise clears the input and starts a new recording.
  const handleClick = useCallback(async () => {
    if (isRecording) {
      // When recording, clicking the mic button stops recording
      manualStopRequestedRef.current = true;
      try {
        const finalTranscript = await stopRecording();
        if (finalTranscript) {
          onTranscription(finalTranscript);
        }
        // Manual-stop path performs its own auto-send (the VAD handler is
        // suppressed via manualStopRequestedRef); skip empty transcripts.
        if (
          autoSend &&
          onAutoSend &&
          chatState === "input" &&
          finalTranscript?.trim()
        ) {
          onAutoSend(finalTranscript);
        }
      } finally {
        manualStopRequestedRef.current = false;
      }
    } else {
      try {
        // Clear input before starting recording
        onRecordingStart?.();
        await startRecording();
        // Arm auto-listen only after first manual mic start in this session.
        hasManualRecordStartRef.current = true;
      } catch (err) {
        console.error("Microphone access failed:", err);
        toast.error("Could not access microphone");
      }
    }
  }, [
    isRecording,
    startRecording,
    stopRecording,
    onRecordingStart,
    onTranscription,
    autoSend,
    onAutoSend,
    chatState,
  ]);

  // Auto-start listening shortly after TTS finishes (only if autoListen is enabled).
  // Small cooldown reduces playback bleed being re-captured by the microphone.
  // IMPORTANT: Only trigger auto-listen if TTS was actually playing audio,
  // not just loading. This prevents auto-listen from triggering when TTS fails.
  useEffect(() => {
    // Any dependency change cancels a pending cooldown; the timer is re-armed
    // below only if the trigger conditions still hold.
    if (autoListenCooldownTimerRef.current) {
      clearTimeout(autoListenCooldownTimerRef.current);
      autoListenCooldownTimerRef.current = null;
    }

    const stoppedManually =
      manualStopCount !== lastHandledManualStopCountRef.current;

    // Only trigger auto-listen if TTS was actually playing (not just loading)
    if (
      wasTTSActuallyPlayingRef.current &&
      !isTTSPlaying &&
      !isTTSLoading &&
      !isAwaitingAutoPlaybackStart &&
      autoListen &&
      hasManualRecordStartRef.current &&
      !disabled &&
      !isRecording &&
      !stoppedManually
    ) {
      autoListenCooldownTimerRef.current = setTimeout(() => {
        autoListenCooldownTimerRef.current = null;
        // Re-check conditions at fire time; state may have changed during the cooldown.
        if (
          !autoListen ||
          disabled ||
          isRecording ||
          isTTSPlaying ||
          isTTSLoading ||
          isAwaitingAutoPlaybackStart
        ) {
          return;
        }
        startRecording().catch((err) => {
          console.error("Auto-start microphone failed:", err);
          toast.error("Could not auto-start microphone");
        });
      }, 400);
    }

    // Acknowledge the manual stop so it only suppresses auto-listen once.
    if (stoppedManually) {
      lastHandledManualStopCountRef.current = manualStopCount;
    }

    // Only track actual playback - not loading states
    // This ensures auto-listen only triggers after audio actually played
    if (isTTSPlaying) {
      wasTTSActuallyPlayingRef.current = true;
    } else if (!isTTSPlaying && !isTTSLoading && !isAwaitingAutoPlaybackStart) {
      // Reset when TTS is completely done
      wasTTSActuallyPlayingRef.current = false;
    }
  }, [
    isTTSPlaying,
    isTTSLoading,
    isAwaitingAutoPlaybackStart,
    autoListen,
    disabled,
    isRecording,
    startRecording,
    manualStopCount,
  ]);

  // New sessions must start with an explicit manual mic press.
  useEffect(() => {
    if (isNewSession) {
      hasManualRecordStartRef.current = false;
    }
  }, [isNewSession]);

  // Clear any pending auto-listen cooldown timer on unmount.
  useEffect(() => {
    return () => {
      if (autoListenCooldownTimerRef.current) {
        clearTimeout(autoListenCooldownTimerRef.current);
        autoListenCooldownTimerRef.current = null;
      }
    };
  }, []);

  // Surface recorder errors to the user via toast (and console for debugging).
  useEffect(() => {
    if (error) {
      console.error("Voice recorder error:", error);
      toast.error(error);
    }
  }, [error]);

  // Icon: show loader when processing, otherwise mic
  const icon = isProcessing ? SimpleLoader : SvgMicrophone;

  // Disable when processing or TTS is playing (don't want to pick up TTS audio)
  const isDisabled =
    disabled ||
    isProcessing ||
    isTTSPlaying ||
    isTTSLoading ||
    isAwaitingAutoPlaybackStart;

  // Recording = darkened (primary), not recording = light (tertiary)
  const prominence = isRecording ? "primary" : "tertiary";

  return (
    <Disabled disabled={isDisabled}>
      <Button
        icon={icon}
        onClick={handleClick}
        aria-label={isRecording ? "Stop recording" : "Start recording"}
        prominence={prominence}
      />
    </Disabled>
  );
}

export default MicrophoneButton;
|
||||
@@ -19,7 +19,35 @@ import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidE
|
||||
import { CombinedSettings } from "@/interfaces/settings";
|
||||
import SidebarTab from "@/refresh-components/buttons/SidebarTab";
|
||||
import SidebarBody from "@/sections/sidebar/SidebarBody";
|
||||
import { SvgArrowUpCircle } from "@opal/icons";
|
||||
import {
|
||||
SvgActions,
|
||||
SvgActivity,
|
||||
SvgArrowUpCircle,
|
||||
SvgBarChart,
|
||||
SvgCpu,
|
||||
SvgFileText,
|
||||
SvgFolder,
|
||||
SvgGlobe,
|
||||
SvgArrowExchange,
|
||||
SvgImage,
|
||||
SvgKey,
|
||||
SvgOnyxLogo,
|
||||
SvgOnyxOctagon,
|
||||
SvgSearch,
|
||||
SvgServer,
|
||||
SvgSettings,
|
||||
SvgShield,
|
||||
SvgThumbsUp,
|
||||
SvgUploadCloud,
|
||||
SvgUser,
|
||||
SvgUsers,
|
||||
SvgZoomIn,
|
||||
SvgPaintBrush,
|
||||
SvgDiscordMono,
|
||||
SvgWallet,
|
||||
SvgAudio,
|
||||
} from "@opal/icons";
|
||||
import SvgMcp from "@opal/icons/mcp";
|
||||
import { ADMIN_PATHS, sidebarItem } from "@/lib/admin-routes";
|
||||
import UserAvatarPopover from "@/sections/sidebar/UserAvatarPopover";
|
||||
|
||||
@@ -105,6 +133,11 @@ const collections = (
|
||||
sidebarItem(ADMIN_PATHS.LLM_MODELS),
|
||||
sidebarItem(ADMIN_PATHS.WEB_SEARCH),
|
||||
sidebarItem(ADMIN_PATHS.IMAGE_GENERATION),
|
||||
{
|
||||
name: "Voice",
|
||||
icon: SvgAudio,
|
||||
link: "/admin/configuration/voice",
|
||||
},
|
||||
sidebarItem(ADMIN_PATHS.CODE_INTERPRETER),
|
||||
...(!enableCloud && vectorDbEnabled
|
||||
? [
|
||||
|
||||
Reference in New Issue
Block a user