Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-03-12 03:02:43 +00:00)

Compare commits: voice-mode ... jamison/Co (1 commit)

Commit f1e26b1ae1
@@ -1,117 +0,0 @@
"""add_voice_provider_and_user_voice_prefs

Revision ID: 93a2e195e25c
Revises: b5c4d7e8f9a1
Create Date: 2026-02-23 15:16:39.507304

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy import column
from sqlalchemy import true
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision = "93a2e195e25c"
down_revision = "b5c4d7e8f9a1"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create voice_provider table
    op.create_table(
        "voice_provider",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("name", sa.String(), unique=True, nullable=False),
        sa.Column("provider_type", sa.String(), nullable=False),
        sa.Column("api_key", sa.LargeBinary(), nullable=True),
        sa.Column("api_base", sa.String(), nullable=True),
        sa.Column("custom_config", postgresql.JSONB(), nullable=True),
        sa.Column("stt_model", sa.String(), nullable=True),
        sa.Column("tts_model", sa.String(), nullable=True),
        sa.Column("default_voice", sa.String(), nullable=True),
        sa.Column(
            "is_default_stt", sa.Boolean(), nullable=False, server_default="false"
        ),
        sa.Column(
            "is_default_tts", sa.Boolean(), nullable=False, server_default="false"
        ),
        sa.Column("deleted", sa.Boolean(), nullable=False, server_default="false"),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            onupdate=sa.func.now(),
            nullable=False,
        ),
    )

    # Add partial unique indexes to enforce only one default STT/TTS provider
    op.create_index(
        "ix_voice_provider_one_default_stt",
        "voice_provider",
        ["is_default_stt"],
        unique=True,
        postgresql_where=column("is_default_stt") == true(),
    )
    op.create_index(
        "ix_voice_provider_one_default_tts",
        "voice_provider",
        ["is_default_tts"],
        unique=True,
        postgresql_where=column("is_default_tts") == true(),
    )

    # Add voice preference columns to user table
    op.add_column(
        "user",
        sa.Column(
            "voice_auto_send",
            sa.Boolean(),
            default=False,
            nullable=False,
            server_default="false",
        ),
    )
    op.add_column(
        "user",
        sa.Column(
            "voice_auto_playback",
            sa.Boolean(),
            default=False,
            nullable=False,
            server_default="false",
        ),
    )
    op.add_column(
        "user",
        sa.Column(
            "voice_playback_speed",
            sa.Float(),
            default=1.0,
            nullable=False,
            server_default="1.0",
        ),
    )


def downgrade() -> None:
    # Remove user voice preference columns
    op.drop_column("user", "voice_playback_speed")
    op.drop_column("user", "voice_auto_playback")
    op.drop_column("user", "voice_auto_send")

    op.drop_index("ix_voice_provider_one_default_tts", table_name="voice_provider")
    op.drop_index("ix_voice_provider_one_default_stt", table_name="voice_provider")

    # Drop voice_provider table
    op.drop_table("voice_provider")
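For reference, the two partial unique indexes created above only cover rows where the flag is true, so any number of non-default providers can coexist with at most one default of each type; the database, not application code, is the final backstop. A minimal sketch of that behavior (not part of the diff; assumes a SQLAlchemy engine against the migrated Postgres schema, DSN illustrative):

from sqlalchemy import create_engine, text

engine = create_engine("postgresql://localhost/onyx")  # hypothetical DSN

with engine.begin() as conn:
    # First default STT provider: allowed.
    conn.execute(text(
        "INSERT INTO voice_provider (name, provider_type, is_default_stt) "
        "VALUES ('a', 'openai', true)"
    ))
    # Rows with is_default_stt = false are not covered by the partial index.
    conn.execute(text(
        "INSERT INTO voice_provider (name, provider_type) VALUES ('b', 'azure')"
    ))
    # A second true value violates ix_voice_provider_one_default_stt and
    # raises sqlalchemy.exc.IntegrityError.
    conn.execute(text(
        "INSERT INTO voice_provider (name, provider_type, is_default_stt) "
        "VALUES ('c', 'elevenlabs', true)"
    ))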
@@ -29,7 +29,6 @@ from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi import status
-from fastapi import WebSocket
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager

@@ -122,7 +121,6 @@ from onyx.db.models import User
from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
-from onyx.redis.redis_pool import retrieve_ws_token_data
from onyx.server.settings.store import load_settings
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
@@ -1614,102 +1612,6 @@ async def current_admin_user(user: User = Depends(current_user)) -> User:
    return user


-async def _get_user_from_token_data(token_data: dict) -> User | None:
-    """Shared logic: token data dict → User object.
-
-    Args:
-        token_data: Decoded token data containing 'sub' (user ID).
-
-    Returns:
-        User object if found and active, None otherwise.
-    """
-    user_id = token_data.get("sub")
-    if not user_id:
-        return None
-
-    try:
-        user_uuid = uuid.UUID(user_id)
-    except ValueError:
-        return None
-
-    async with get_async_session_context_manager() as async_db_session:
-        user = await async_db_session.get(User, user_uuid)
-        if user is None or not user.is_active:
-            return None
-        return user
-
-
-async def current_user_from_websocket(
-    websocket: WebSocket,
-    token: str = Query(..., description="WebSocket authentication token"),
-) -> User:
-    """
-    WebSocket authentication dependency using query parameter.
-
-    Validates the WS token from query param and returns the User.
-    Raises BasicAuthenticationError if authentication fails.
-
-    The token must be obtained from POST /voice/ws-token before connecting.
-    Tokens are single-use and expire after 60 seconds.
-
-    Usage:
-        1. POST /voice/ws-token -> {"token": "xxx"}
-        2. Connect to ws://host/path?token=xxx
-
-    This applies the same auth checks as current_user() for HTTP endpoints.
-    """
-    # Check Origin header to prevent Cross-Site WebSocket Hijacking (CSWSH)
-    # Browsers always send Origin on WebSocket connections
-    origin = websocket.headers.get("origin")
-    expected_origin = WEB_DOMAIN.rstrip("/")
-    if not origin:
-        logger.warning("WS auth: missing Origin header")
-        raise BasicAuthenticationError(detail="Access denied. Missing origin.")
-
-    actual_origin = origin.rstrip("/")
-    if actual_origin != expected_origin:
-        logger.warning(
-            f"WS auth: origin mismatch. Expected {expected_origin}, got {actual_origin}"
-        )
-        raise BasicAuthenticationError(detail="Access denied. Invalid origin.")
-
-    # Validate WS token in Redis (single-use, deleted after retrieval)
-    try:
-        token_data = await retrieve_ws_token_data(token)
-        if token_data is None:
-            raise BasicAuthenticationError(
-                detail="Access denied. Invalid or expired authentication token."
-            )
-    except BasicAuthenticationError:
-        raise
-    except Exception as e:
-        logger.error(f"WS auth: error during token validation: {e}")
-        raise BasicAuthenticationError(
-            detail="Authentication verification failed."
-        ) from e
-
-    # Get user from token data
-    user = await _get_user_from_token_data(token_data)
-    if user is None:
-        logger.warning(f"WS auth: user not found for id={token_data.get('sub')}")
-        raise BasicAuthenticationError(
-            detail="Access denied. User not found or inactive."
-        )
-
-    # Apply same checks as HTTP auth (verification, OIDC expiry, role)
-    user = await double_check_user(user)
-
-    # Block LIMITED users (same as current_user)
-    if user.role == UserRole.LIMITED:
-        logger.warning(f"WS auth: user {user.email} has LIMITED role")
-        raise BasicAuthenticationError(
-            detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
-        )
-
-    logger.debug(f"WS auth: authenticated {user.email}")
-    return user
-
-
def get_default_admin_user_emails_() -> list[str]:
    # No default seeding available for Onyx MIT
    return []
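For context on the flow described in the removed docstring, a minimal client-side sketch of the two-step handshake (not part of the diff; assumes the httpx and websockets packages, with host, cookie name, and stream path as illustrative placeholders):

import asyncio

import httpx
import websockets


async def connect_voice_ws() -> None:
    # Step 1: exchange the regular session cookie for a single-use WS token.
    async with httpx.AsyncClient(
        base_url="http://localhost:8080",
        cookies={"fastapiusersauth": "..."},  # placeholder session cookie
    ) as client:
        token = (await client.post("/voice/ws-token")).json()["token"]

    # Step 2: connect within 60 seconds; the token rides in the query string,
    # and the Origin header must match WEB_DOMAIN exactly (the CSWSH check above).
    url = f"ws://localhost:8080/voice/transcribe/stream?token={token}"  # illustrative path
    async with websockets.connect(url, origin="http://localhost:3000") as ws:
        await ws.send(b"...audio bytes...")
        print(await ws.recv())


asyncio.run(connect_voice_ws())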
@@ -343,11 +343,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
    # organized in typical structured fashion
    # formatted as `displayName__provider__modelName`

-    # Voice preferences
-    voice_auto_send: Mapped[bool] = mapped_column(Boolean, default=False)
-    voice_auto_playback: Mapped[bool] = mapped_column(Boolean, default=False)
-    voice_playback_speed: Mapped[float] = mapped_column(Float, default=1.0)
-
    # relationships
    credentials: Mapped[list["Credential"]] = relationship(
        "Credential", back_populates="user"
@@ -3060,65 +3055,6 @@ class ImageGenerationConfig(Base):
    )


-class VoiceProvider(Base):
-    """Configuration for voice services (STT and TTS)."""
-
-    __tablename__ = "voice_provider"
-
-    id: Mapped[int] = mapped_column(Integer, primary_key=True)
-    name: Mapped[str] = mapped_column(String, unique=True)
-    provider_type: Mapped[str] = mapped_column(
-        String
-    )  # "openai", "azure", "elevenlabs"
-    api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
-        EncryptedString(), nullable=True
-    )
-    api_base: Mapped[str | None] = mapped_column(String, nullable=True)
-    custom_config: Mapped[dict[str, Any] | None] = mapped_column(
-        postgresql.JSONB(), nullable=True
-    )
-
-    # Model/voice configuration
-    stt_model: Mapped[str | None] = mapped_column(
-        String, nullable=True
-    )  # e.g., "whisper-1"
-    tts_model: Mapped[str | None] = mapped_column(
-        String, nullable=True
-    )  # e.g., "tts-1", "tts-1-hd"
-    default_voice: Mapped[str | None] = mapped_column(
-        String, nullable=True
-    )  # e.g., "alloy", "echo"
-
-    # STT and TTS can use different providers - only one provider per type
-    is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
-    is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
-
-    deleted: Mapped[bool] = mapped_column(Boolean, default=False)
-
-    time_created: Mapped[datetime.datetime] = mapped_column(
-        DateTime(timezone=True), server_default=func.now()
-    )
-    time_updated: Mapped[datetime.datetime] = mapped_column(
-        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
-    )
-
-    # Enforce only one default STT provider and one default TTS provider at DB level
-    __table_args__ = (
-        Index(
-            "ix_voice_provider_one_default_stt",
-            "is_default_stt",
-            unique=True,
-            postgresql_where=(is_default_stt == True),  # noqa: E712
-        ),
-        Index(
-            "ix_voice_provider_one_default_tts",
-            "is_default_tts",
-            unique=True,
-            postgresql_where=(is_default_tts == True),  # noqa: E712
-        ),
-    )
-
-
class CloudEmbeddingProvider(Base):
    __tablename__ = "embedding_provider"
@@ -1,248 +0,0 @@
from typing import Any
from uuid import UUID

from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session

from onyx.db.models import User
from onyx.db.models import VoiceProvider
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError

MIN_VOICE_PLAYBACK_SPEED = 0.5
MAX_VOICE_PLAYBACK_SPEED = 2.0


def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
    """Fetch all voice providers."""
    return list(
        db_session.scalars(
            select(VoiceProvider)
            .where(VoiceProvider.deleted.is_(False))
            .order_by(VoiceProvider.name)
        ).all()
    )


def fetch_voice_provider_by_id(
    db_session: Session, provider_id: int, include_deleted: bool = False
) -> VoiceProvider | None:
    """Fetch a voice provider by ID."""
    stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
    if not include_deleted:
        stmt = stmt.where(VoiceProvider.deleted.is_(False))
    return db_session.scalar(stmt)


def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
    """Fetch the default STT provider."""
    return db_session.scalar(
        select(VoiceProvider)
        .where(VoiceProvider.is_default_stt.is_(True))
        .where(VoiceProvider.deleted.is_(False))
    )


def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
    """Fetch the default TTS provider."""
    return db_session.scalar(
        select(VoiceProvider)
        .where(VoiceProvider.is_default_tts.is_(True))
        .where(VoiceProvider.deleted.is_(False))
    )


def fetch_voice_provider_by_type(
    db_session: Session, provider_type: str
) -> VoiceProvider | None:
    """Fetch a voice provider by type."""
    return db_session.scalar(
        select(VoiceProvider)
        .where(VoiceProvider.provider_type == provider_type)
        .where(VoiceProvider.deleted.is_(False))
    )


def upsert_voice_provider(
    *,
    db_session: Session,
    provider_id: int | None,
    name: str,
    provider_type: str,
    api_key: str | None,
    api_key_changed: bool,
    api_base: str | None = None,
    custom_config: dict[str, Any] | None = None,
    stt_model: str | None = None,
    tts_model: str | None = None,
    default_voice: str | None = None,
    activate_stt: bool = False,
    activate_tts: bool = False,
) -> VoiceProvider:
    """Create or update a voice provider."""
    provider: VoiceProvider | None = None

    if provider_id is not None:
        provider = fetch_voice_provider_by_id(db_session, provider_id)
        if provider is None:
            raise OnyxError(
                OnyxErrorCode.NOT_FOUND,
                f"No voice provider with id {provider_id} exists.",
            )
    else:
        provider = VoiceProvider()
        db_session.add(provider)

    # Apply updates
    provider.name = name
    provider.provider_type = provider_type
    provider.api_base = api_base
    provider.custom_config = custom_config
    provider.stt_model = stt_model
    provider.tts_model = tts_model
    provider.default_voice = default_voice

    # Only update API key if explicitly changed or if provider has no key
    if api_key_changed or provider.api_key is None:
        provider.api_key = api_key  # type: ignore[assignment]

    db_session.flush()

    if activate_stt:
        set_default_stt_provider(db_session=db_session, provider_id=provider.id)
    if activate_tts:
        set_default_tts_provider(db_session=db_session, provider_id=provider.id)

    db_session.refresh(provider)
    return provider


def delete_voice_provider(db_session: Session, provider_id: int) -> None:
    """Soft-delete a voice provider by ID."""
    provider = fetch_voice_provider_by_id(db_session, provider_id)
    if provider:
        provider.deleted = True
        db_session.flush()


def set_default_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
    """Set a voice provider as the default STT provider."""
    provider = fetch_voice_provider_by_id(db_session, provider_id)
    if provider is None:
        raise OnyxError(
            OnyxErrorCode.NOT_FOUND,
            f"No voice provider with id {provider_id} exists.",
        )

    # Deactivate all other STT providers
    db_session.execute(
        update(VoiceProvider)
        .where(
            VoiceProvider.is_default_stt.is_(True),
            VoiceProvider.id != provider_id,
        )
        .values(is_default_stt=False)
    )

    # Activate this provider
    provider.is_default_stt = True

    db_session.flush()
    db_session.refresh(provider)
    return provider


def set_default_tts_provider(
    *, db_session: Session, provider_id: int, tts_model: str | None = None
) -> VoiceProvider:
    """Set a voice provider as the default TTS provider."""
    provider = fetch_voice_provider_by_id(db_session, provider_id)
    if provider is None:
        raise OnyxError(
            OnyxErrorCode.NOT_FOUND,
            f"No voice provider with id {provider_id} exists.",
        )

    # Deactivate all other TTS providers
    db_session.execute(
        update(VoiceProvider)
        .where(
            VoiceProvider.is_default_tts.is_(True),
            VoiceProvider.id != provider_id,
        )
        .values(is_default_tts=False)
    )

    # Activate this provider
    provider.is_default_tts = True

    # Update the TTS model if specified
    if tts_model is not None:
        provider.tts_model = tts_model

    db_session.flush()
    db_session.refresh(provider)
    return provider


def deactivate_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
    """Remove the default STT status from a voice provider."""
    provider = fetch_voice_provider_by_id(db_session, provider_id)
    if provider is None:
        raise OnyxError(
            OnyxErrorCode.NOT_FOUND,
            f"No voice provider with id {provider_id} exists.",
        )

    provider.is_default_stt = False

    db_session.flush()
    db_session.refresh(provider)
    return provider


def deactivate_tts_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
    """Remove the default TTS status from a voice provider."""
    provider = fetch_voice_provider_by_id(db_session, provider_id)
    if provider is None:
        raise OnyxError(
            OnyxErrorCode.NOT_FOUND,
            f"No voice provider with id {provider_id} exists.",
        )

    provider.is_default_tts = False

    db_session.flush()
    db_session.refresh(provider)
    return provider


# User voice preferences


def update_user_voice_settings(
    db_session: Session,
    user_id: UUID,
    auto_send: bool | None = None,
    auto_playback: bool | None = None,
    playback_speed: float | None = None,
) -> None:
    """Update user's voice settings.

    For all fields, None means "don't update this field".
    """
    values: dict[str, bool | float] = {}

    if auto_send is not None:
        values["voice_auto_send"] = auto_send
    if auto_playback is not None:
        values["voice_auto_playback"] = auto_playback
    if playback_speed is not None:
        values["voice_playback_speed"] = max(
            MIN_VOICE_PLAYBACK_SPEED, min(MAX_VOICE_PLAYBACK_SPEED, playback_speed)
        )

    if values:
        db_session.execute(update(User).where(User.id == user_id).values(**values))  # type: ignore[arg-type]
        db_session.flush()
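One usage note on the clamp in update_user_voice_settings: out-of-range speeds are silently coerced rather than rejected. A hypothetical call (not from the diff):

# playback_speed is clamped into [MIN_VOICE_PLAYBACK_SPEED, MAX_VOICE_PLAYBACK_SPEED],
# so 3.0 is persisted as 2.0, and 0.1 would be persisted as 0.5.
update_user_voice_settings(
    db_session=db_session,
    user_id=user.id,
    playback_speed=3.0,
)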
@@ -66,11 +66,6 @@ class OnyxErrorCode(Enum):
    RATE_LIMITED = ("RATE_LIMITED", 429)
    SEAT_LIMIT_EXCEEDED = ("SEAT_LIMIT_EXCEEDED", 402)

-    # ------------------------------------------------------------------
-    # Payload (413)
-    # ------------------------------------------------------------------
-    PAYLOAD_TOO_LARGE = ("PAYLOAD_TOO_LARGE", 413)
-
    # ------------------------------------------------------------------
    # Connector / Credential Errors (400-range)
    # ------------------------------------------------------------------
@@ -123,11 +123,15 @@ class DocumentIndexingBatchAdapter:
        }

        doc_id_to_new_chunk_cnt: dict[str, int] = {
-            doc_id: 0 for doc_id in updatable_ids
+            document_id: len(
+                [
+                    chunk
+                    for chunk in chunks_with_embeddings
+                    if chunk.source_document.id == document_id
+                ]
+            )
+            for document_id in updatable_ids
        }
-        for chunk in chunks_with_embeddings:
-            if chunk.source_document.id in doc_id_to_new_chunk_cnt:
-                doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1

        # Get ancestor hierarchy node IDs for each document
        doc_id_to_ancestor_ids = self._get_ancestor_ids_for_documents(
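Both sides of this hunk compute the same per-document chunk counts: the removed form zero-initializes the dict and makes a single pass over the chunks, while the restored comprehension rescans chunks_with_embeddings once per document. A standalone illustration with stand-in data (not Onyx code):

updatable_ids = ["d1", "d2"]
chunk_doc_ids = ["d1", "d1", "d2"]  # stand-in for chunk.source_document.id

# Removed (single-pass) form:
counts = {doc_id: 0 for doc_id in updatable_ids}
for doc_id in chunk_doc_ids:
    if doc_id in counts:
        counts[doc_id] += 1

# Restored (per-document rescan) form, same result:
counts_2 = {d: len([c for c in chunk_doc_ids if c == d]) for d in updatable_ids}
assert counts == counts_2 == {"d1": 2, "d2": 1}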
@@ -16,7 +16,6 @@ from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import IndexChunk
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.utils.logger import setup_logger
-from onyx.utils.pydantic_util import shallow_model_dump
from onyx.utils.timing import log_function_time
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT

@@ -211,8 +210,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
            )[0]
            title_embed_dict[title] = title_embedding

-        new_embedded_chunk = IndexChunk.model_construct(
-            **shallow_model_dump(chunk),
+        new_embedded_chunk = IndexChunk(
+            **chunk.model_dump(),
            embeddings=ChunkEmbedding(
                full_embedding=chunk_embeddings[0],
                mini_chunk_embeddings=chunk_embeddings[1:],
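This hunk (and the matching one below) swaps pydantic's validation-skipping model_construct, fed by a shallow dump, back to the plain constructor, which re-runs field validation for every chunk. A generic pydantic v2 illustration of the difference (standalone, not Onyx code):

from pydantic import BaseModel, ValidationError


class Point(BaseModel):
    x: int


# model_construct trusts its inputs and skips validation entirely:
p = Point.model_construct(x="not-an-int")
print(p.x)  # "not-an-int" sails through as a raw string

# The regular constructor validates (coercing where possible, raising otherwise):
try:
    Point(x="not-an-int")
except ValidationError as e:
    print(e.error_count(), "validation error")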
@@ -12,7 +12,6 @@ from onyx.connectors.models import Document
from onyx.db.enums import EmbeddingPrecision
from onyx.db.enums import SwitchoverType
from onyx.utils.logger import setup_logger
-from onyx.utils.pydantic_util import shallow_model_dump
from shared_configs.enums import EmbeddingProvider
from shared_configs.model_server_models import Embedding

@@ -134,8 +133,9 @@ class DocMetadataAwareIndexChunk(IndexChunk):
        tenant_id: str,
        ancestor_hierarchy_node_ids: list[int] | None = None,
    ) -> "DocMetadataAwareIndexChunk":
-        return cls.model_construct(
-            **shallow_model_dump(index_chunk),
+        index_chunk_data = index_chunk.model_dump()
+        return cls(
+            **index_chunk_data,
            access=access,
            document_sets=document_sets,
            user_project=user_project,
@@ -43,7 +43,6 @@ WELL_KNOWN_PROVIDER_NAMES = [
    LlmProviderNames.AZURE,
    LlmProviderNames.OLLAMA_CHAT,
    LlmProviderNames.LM_STUDIO,
-    LlmProviderNames.LITELLM_PROXY,
]

@@ -60,7 +59,6 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
    "ollama": "Ollama",
    LlmProviderNames.OLLAMA_CHAT: "Ollama",
    LlmProviderNames.LM_STUDIO: "LM Studio",
-    LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
    "groq": "Groq",
    "anyscale": "Anyscale",
    "deepseek": "DeepSeek",

@@ -111,7 +109,6 @@ AGGREGATOR_PROVIDERS: set[str] = {
    LlmProviderNames.LM_STUDIO,
    LlmProviderNames.VERTEX_AI,
    LlmProviderNames.AZURE,
-    LlmProviderNames.LITELLM_PROXY,
}

# Model family name mappings for display name generation

@@ -11,8 +11,6 @@ OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"
LM_STUDIO_PROVIDER_NAME = "lm_studio"
LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"

-LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
-
# Providers that use optional Bearer auth from custom_config
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
    LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,

@@ -15,7 +15,6 @@ from onyx.llm.well_known_providers.auto_update_service import (
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
-from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME

@@ -48,7 +47,6 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
        OLLAMA_PROVIDER_NAME: [],  # Dynamic - fetched from Ollama API
        LM_STUDIO_PROVIDER_NAME: [],  # Dynamic - fetched from LM Studio API
        OPENROUTER_PROVIDER_NAME: [],  # Dynamic - fetched from OpenRouter API
-        LITELLM_PROXY_PROVIDER_NAME: [],  # Dynamic - fetched from LiteLLM proxy API
    }

@@ -333,7 +331,6 @@ def get_provider_display_name(provider_name: str) -> str:
        BEDROCK_PROVIDER_NAME: "Amazon Bedrock",
        VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
        OPENROUTER_PROVIDER_NAME: "OpenRouter",
-        LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
    }

    if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:
@@ -119,9 +119,6 @@ from onyx.server.manage.opensearch_migration.api import (
from onyx.server.manage.search_settings import router as search_settings_router
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router
-from onyx.server.manage.voice.api import admin_router as voice_admin_router
-from onyx.server.manage.voice.user_api import router as voice_router
-from onyx.server.manage.voice.websocket_api import router as voice_websocket_router
from onyx.server.manage.web_search.api import (
    admin_router as web_search_admin_router,
)

@@ -500,9 +497,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
    include_router_with_global_prefix_prepended(application, embedding_router)
    include_router_with_global_prefix_prepended(application, web_search_router)
    include_router_with_global_prefix_prepended(application, web_search_admin_router)
-    include_router_with_global_prefix_prepended(application, voice_admin_router)
-    include_router_with_global_prefix_prepended(application, voice_router)
-    include_router_with_global_prefix_prepended(application, voice_websocket_router)
    include_router_with_global_prefix_prepended(
        application, opensearch_migration_admin_router
    )
@@ -419,15 +419,12 @@ async def get_async_redis_connection() -> aioredis.Redis:
    return _async_redis_connection


-async def retrieve_auth_token_data(token: str) -> dict | None:
-    """Validate auth token against Redis and return token data.
-
-    Args:
-        token: The raw authentication token string.
-
-    Returns:
-        Token data dict if valid, None if invalid/expired.
-    """
+async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
+    token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
+    if not token:
+        logger.debug("No auth token cookie found")
+        return None
+
    try:
        redis = await get_async_redis_connection()
        redis_key = REDIS_AUTH_KEY_PREFIX + token

@@ -442,96 +439,12 @@ async def retrieve_auth_token_data(token: str) -> dict | None:
        logger.error("Error decoding token data from Redis")
        return None
    except Exception as e:
-        logger.error(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
-        raise ValueError(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
-
-
-async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
-    """Validate auth token from request cookie. Wrapper for backwards compatibility."""
-    token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
-    if not token:
-        logger.debug("No auth token cookie found")
-        return None
-    return await retrieve_auth_token_data(token)
-
-
-# WebSocket token prefix (separate from regular auth tokens)
-REDIS_WS_TOKEN_PREFIX = "ws_token:"
-# WebSocket tokens expire after 60 seconds
-WS_TOKEN_TTL_SECONDS = 60
-# Rate limit: max tokens per user per window
-WS_TOKEN_RATE_LIMIT_MAX = 10
-WS_TOKEN_RATE_LIMIT_WINDOW_SECONDS = 60
-REDIS_WS_TOKEN_RATE_LIMIT_PREFIX = "ws_token_rate:"
-
-
-class WsTokenRateLimitExceeded(Exception):
-    """Raised when a user exceeds the WS token generation rate limit."""
-
-
-async def store_ws_token(token: str, user_id: str) -> None:
-    """Store a short-lived WebSocket authentication token in Redis.
-
-    Args:
-        token: The generated WS token.
-        user_id: The user ID to associate with this token.
-
-    Raises:
-        WsTokenRateLimitExceeded: If the user has exceeded the rate limit.
-    """
-    redis = await get_async_redis_connection()
-
-    # Atomically increment and check rate limit to avoid TOCTOU races
-    rate_limit_key = REDIS_WS_TOKEN_RATE_LIMIT_PREFIX + user_id
-    pipe = redis.pipeline()
-    pipe.incr(rate_limit_key)
-    pipe.expire(rate_limit_key, WS_TOKEN_RATE_LIMIT_WINDOW_SECONDS)
-    results = await pipe.execute()
-    new_count = results[0]
-
-    if new_count > WS_TOKEN_RATE_LIMIT_MAX:
-        # Over limit — decrement back since we won't use this slot
-        await redis.decr(rate_limit_key)
-        logger.warning(f"WS token rate limit exceeded for user {user_id}")
-        raise WsTokenRateLimitExceeded(
-            f"Rate limit exceeded. Maximum {WS_TOKEN_RATE_LIMIT_MAX} tokens per minute."
-        )
-
-    # Store the actual token
-    redis_key = REDIS_WS_TOKEN_PREFIX + token
-    token_data = json.dumps({"sub": user_id})
-    await redis.set(redis_key, token_data, ex=WS_TOKEN_TTL_SECONDS)
-
-
-async def retrieve_ws_token_data(token: str) -> dict | None:
-    """Validate a WebSocket token and return the token data.
-
-    This uses GETDEL for atomic get-and-delete to prevent race conditions
-    where the same token could be used twice.
-
-    Args:
-        token: The WS token to validate.
-
-    Returns:
-        Token data dict with 'sub' (user ID) if valid, None if invalid/expired.
-    """
-    try:
-        redis = await get_async_redis_connection()
-        redis_key = REDIS_WS_TOKEN_PREFIX + token
-
-        # Atomic get-and-delete to prevent race conditions (Redis 6.2+)
-        token_data_str = await redis.getdel(redis_key)
-
-        if not token_data_str:
-            return None
-
-        return json.loads(token_data_str)
-    except json.JSONDecodeError:
-        logger.error("Error decoding WS token data from Redis")
-        return None
-    except Exception as e:
-        logger.error(f"Unexpected error in retrieve_ws_token_data: {str(e)}")
-        return None
+        logger.error(
+            f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
+        )
+        raise ValueError(
+            f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
+        )


def redis_lock_dump(lock: RedisLock, r: Redis) -> None:
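One caveat on the removed retrieve_ws_token_data: GETDEL requires Redis 6.2+. On older servers, the same single-use guarantee needs a different atomic primitive; a hedged sketch of one option, a small server-side Lua script (illustrative, not from the diff):

# Hypothetical fallback for Redis < 6.2: GET + DEL in one atomic Lua script,
# preserving the single-use guarantee that redis.getdel() provides natively.
GETDEL_LUA = """
local v = redis.call('GET', KEYS[1])
if v then redis.call('DEL', KEYS[1]) end
return v
"""


async def getdel_compat(redis: "aioredis.Redis", key: str) -> bytes | None:
    return await redis.eval(GETDEL_LUA, 1, key)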
@@ -9,7 +9,6 @@ from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
-from onyx.auth.users import current_user_from_websocket
from onyx.auth.users import current_user_with_expired_token
from onyx.configs.app_configs import APP_API_PREFIX
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop

@@ -130,7 +129,6 @@ def check_router_auth(
            or depends_fn == current_curator_or_admin_user
            or depends_fn == current_user_with_expired_token
            or depends_fn == current_chat_accessible_user
-            or depends_fn == current_user_from_websocket
            or depends_fn == control_plane_dep
            or depends_fn == current_cloud_superuser
            or depends_fn == verify_scim_token
@@ -732,7 +732,7 @@ def get_webapp_info(
    return WebappInfo(**webapp_info)


-@router.get("/{session_id}/webapp-download")
+@router.get("/{session_id}/webapp/download")
def download_webapp(
    session_id: UUID,
    user: User = Depends(current_user),
@@ -7424,9 +7424,9 @@
      }
    },
    "node_modules/hono": {
-      "version": "4.12.7",
-      "resolved": "https://registry.npmjs.org/hono/-/hono-4.12.7.tgz",
-      "integrity": "sha512-jq9l1DM0zVIvsm3lv9Nw9nlJnMNPOcAtsbsgiUhWcFzPE99Gvo6yRTlszSLLYacMeQ6quHD6hMfId8crVHvexw==",
+      "version": "4.12.5",
+      "resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz",
+      "integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==",
      "license": "MIT",
      "engines": {
        "node": ">=16.9.0"
@@ -85,11 +85,6 @@ class UserPreferences(BaseModel):
    chat_background: str | None = None
    default_app_mode: DefaultAppMode = DefaultAppMode.CHAT

-    # Voice preferences
-    voice_auto_send: bool | None = None
-    voice_auto_playback: bool | None = None
-    voice_playback_speed: float | None = None
-
    # controls which tools are enabled for the user for a specific assistant
    assistant_specific_configs: UserSpecificAssistantPreferences | None = None

@@ -169,9 +164,6 @@ class UserInfo(BaseModel):
            theme_preference=user.theme_preference,
            chat_background=user.chat_background,
            default_app_mode=user.default_app_mode,
-            voice_auto_send=user.voice_auto_send,
-            voice_auto_playback=user.voice_auto_playback,
-            voice_playback_speed=user.voice_playback_speed,
            assistant_specific_configs=assistant_specific_configs,
        )
    ),

@@ -248,12 +240,6 @@ class ChatBackgroundRequest(BaseModel):
    chat_background: str | None


-class VoiceSettingsUpdateRequest(BaseModel):
-    auto_send: bool | None = None
-    auto_playback: bool | None = None
-    playback_speed: float | None = Field(default=None, ge=0.5, le=2.0)
-
-
class PersonalizationUpdateRequest(BaseModel):
    name: str | None = None
    role: str | None = None
@@ -1,315 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Response
from sqlalchemy.orm import Session

from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import User
from onyx.db.models import VoiceProvider
from onyx.db.voice import deactivate_stt_provider
from onyx.db.voice import deactivate_tts_provider
from onyx.db.voice import delete_voice_provider
from onyx.db.voice import fetch_voice_provider_by_id
from onyx.db.voice import fetch_voice_provider_by_type
from onyx.db.voice import fetch_voice_providers
from onyx.db.voice import set_default_stt_provider
from onyx.db.voice import set_default_tts_provider
from onyx.db.voice import upsert_voice_provider
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.server.manage.voice.models import VoiceOption
from onyx.server.manage.voice.models import VoiceProviderTestRequest
from onyx.server.manage.voice.models import VoiceProviderUpdateSuccess
from onyx.server.manage.voice.models import VoiceProviderUpsertRequest
from onyx.server.manage.voice.models import VoiceProviderView
from onyx.utils.logger import setup_logger
from onyx.utils.url import SSRFException
from onyx.utils.url import validate_outbound_http_url
from onyx.voice.factory import get_voice_provider

logger = setup_logger()

admin_router = APIRouter(prefix="/admin/voice")


def _validate_voice_api_base(provider_type: str, api_base: str | None) -> str | None:
    """Validate and normalize provider api_base / target URI."""
    if api_base is None:
        return None

    allow_private_network = provider_type.lower() == "azure"
    try:
        return validate_outbound_http_url(
            api_base, allow_private_network=allow_private_network
        )
    except (ValueError, SSRFException) as e:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            f"Invalid target URI: {str(e)}",
        ) from e


def _provider_to_view(provider: VoiceProvider) -> VoiceProviderView:
    """Convert a VoiceProvider model to a VoiceProviderView."""
    return VoiceProviderView(
        id=provider.id,
        name=provider.name,
        provider_type=provider.provider_type,
        is_default_stt=provider.is_default_stt,
        is_default_tts=provider.is_default_tts,
        stt_model=provider.stt_model,
        tts_model=provider.tts_model,
        default_voice=provider.default_voice,
        has_api_key=bool(provider.api_key),
        target_uri=provider.api_base,  # api_base stores the target URI for Azure
    )


@admin_router.get("/providers")
def list_voice_providers(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[VoiceProviderView]:
    """List all configured voice providers."""
    providers = fetch_voice_providers(db_session)
    return [_provider_to_view(provider) for provider in providers]


@admin_router.post("/providers")
async def upsert_voice_provider_endpoint(
    request: VoiceProviderUpsertRequest,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> VoiceProviderView:
    """Create or update a voice provider."""
    api_key = request.api_key
    api_key_changed = request.api_key_changed

    # If llm_provider_id is specified, copy the API key from that LLM provider
    if request.llm_provider_id is not None:
        llm_provider = db_session.get(LLMProviderModel, request.llm_provider_id)
        if llm_provider is None:
            raise OnyxError(
                OnyxErrorCode.NOT_FOUND,
                f"LLM provider with id {request.llm_provider_id} not found.",
            )
        if llm_provider.api_key is None:
            raise OnyxError(
                OnyxErrorCode.VALIDATION_ERROR,
                "Selected LLM provider has no API key configured.",
            )
        api_key = llm_provider.api_key.get_value(apply_mask=False)
        api_key_changed = True

    # Use target_uri if provided, otherwise fall back to api_base
    api_base = _validate_voice_api_base(
        request.provider_type, request.target_uri or request.api_base
    )

    provider = upsert_voice_provider(
        db_session=db_session,
        provider_id=request.id,
        name=request.name,
        provider_type=request.provider_type,
        api_key=api_key,
        api_key_changed=api_key_changed,
        api_base=api_base,
        custom_config=request.custom_config,
        stt_model=request.stt_model,
        tts_model=request.tts_model,
        default_voice=request.default_voice,
        activate_stt=request.activate_stt,
        activate_tts=request.activate_tts,
    )

    # Validate credentials before committing - rollback on failure
    try:
        voice_provider = get_voice_provider(provider)
        await voice_provider.validate_credentials()
    except Exception as e:
        db_session.rollback()
        logger.error(f"Voice provider credential validation failed on save: {e}")
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            "Connection test failed. Please verify your API key and settings.",
        ) from e

    db_session.commit()

    return _provider_to_view(provider)


@admin_router.delete(
    "/providers/{provider_id}", status_code=204, response_class=Response
)
def delete_voice_provider_endpoint(
    provider_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> Response:
    """Delete a voice provider."""
    delete_voice_provider(db_session, provider_id)
    db_session.commit()
    return Response(status_code=204)


@admin_router.post("/providers/{provider_id}/activate-stt")
def activate_stt_provider_endpoint(
    provider_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> VoiceProviderView:
    """Set a voice provider as the default STT provider."""
    provider = set_default_stt_provider(db_session=db_session, provider_id=provider_id)
    db_session.commit()
    return _provider_to_view(provider)


@admin_router.post("/providers/{provider_id}/deactivate-stt")
def deactivate_stt_provider_endpoint(
    provider_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> VoiceProviderUpdateSuccess:
    """Remove the default STT status from a voice provider."""
    deactivate_stt_provider(db_session=db_session, provider_id=provider_id)
    db_session.commit()
    return VoiceProviderUpdateSuccess()


@admin_router.post("/providers/{provider_id}/activate-tts")
def activate_tts_provider_endpoint(
    provider_id: int,
    tts_model: str | None = None,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> VoiceProviderView:
    """Set a voice provider as the default TTS provider."""
    provider = set_default_tts_provider(
        db_session=db_session, provider_id=provider_id, tts_model=tts_model
    )
    db_session.commit()
    return _provider_to_view(provider)


@admin_router.post("/providers/{provider_id}/deactivate-tts")
def deactivate_tts_provider_endpoint(
    provider_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> VoiceProviderUpdateSuccess:
    """Remove the default TTS status from a voice provider."""
    deactivate_tts_provider(db_session=db_session, provider_id=provider_id)
    db_session.commit()
    return VoiceProviderUpdateSuccess()


@admin_router.post("/providers/test")
async def test_voice_provider(
    request: VoiceProviderTestRequest,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> VoiceProviderUpdateSuccess:
    """Test a voice provider connection by making a real API call."""
    api_key = request.api_key

    if request.use_stored_key:
        existing_provider = fetch_voice_provider_by_type(
            db_session, request.provider_type
        )
        if existing_provider is None or not existing_provider.api_key:
            raise OnyxError(
                OnyxErrorCode.VALIDATION_ERROR,
                "No stored API key found for this provider type.",
            )
        api_key = existing_provider.api_key.get_value(apply_mask=False)

    if not api_key:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            "API key is required. Either provide api_key or set use_stored_key to true.",
        )

    # Use target_uri if provided, otherwise fall back to api_base
    api_base = _validate_voice_api_base(
        request.provider_type, request.target_uri or request.api_base
    )

    # Create a temporary VoiceProvider for testing (not saved to DB)
    temp_provider = VoiceProvider(
        name="__test__",
        provider_type=request.provider_type,
        api_base=api_base,
        custom_config=request.custom_config or {},
    )
    temp_provider.api_key = api_key  # type: ignore[assignment]

    try:
        provider = get_voice_provider(temp_provider)
    except ValueError as exc:
        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc

    # Validate credentials with a real API call
    try:
        await provider.validate_credentials()
    except OnyxError:
        raise
    except Exception as e:
        logger.error(f"Voice provider connection test failed: {e}")
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            "Connection test failed. Please verify your API key and settings.",
        ) from e

    logger.info(f"Voice provider test succeeded for {request.provider_type}.")
    return VoiceProviderUpdateSuccess()


@admin_router.get("/providers/{provider_id}/voices")
def get_provider_voices(
    provider_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[VoiceOption]:
    """Get available voices for a provider."""
    provider_db = fetch_voice_provider_by_id(db_session, provider_id)
    if provider_db is None:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Voice provider not found.")

    if not provider_db.api_key:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR, "Provider has no API key configured."
        )

    try:
        provider = get_voice_provider(provider_db)
    except ValueError as exc:
        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc

    return [VoiceOption(**voice) for voice in provider.get_available_voices()]


@admin_router.get("/voices")
def get_voices_by_type(
    provider_type: str,
    _: User = Depends(current_admin_user),
) -> list[VoiceOption]:
    """Get available voices for a provider type.

    For providers like ElevenLabs and OpenAI, this fetches voices
    without requiring an existing provider configuration.
    """
    # Create a temporary VoiceProvider to get static voice list
    temp_provider = VoiceProvider(
        name="__temp__",
        provider_type=provider_type,
    )

    try:
        provider = get_voice_provider(temp_provider)
    except ValueError as exc:
        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc

    return [VoiceOption(**voice) for voice in provider.get_available_voices()]
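A hedged sketch of how an admin client might drive these endpoints (httpx; host, cookie, and key values are placeholders, and any global API prefix is omitted):

import httpx

client = httpx.Client(
    base_url="http://localhost:8080",
    cookies={"fastapiusersauth": "..."},  # placeholder admin session cookie
)

# Create an OpenAI-backed provider and make it the default for STT and TTS.
resp = client.post("/admin/voice/providers", json={
    "name": "openai-voice",
    "provider_type": "openai",
    "api_key": "sk-...",  # placeholder
    "api_key_changed": True,
    "stt_model": "whisper-1",
    "tts_model": "tts-1",
    "default_voice": "alloy",
    "activate_stt": True,
    "activate_tts": True,
})
resp.raise_for_status()

# List configured providers; has_api_key is derived server-side.
print(client.get("/admin/voice/providers").json())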
@@ -1,95 +0,0 @@
from typing import Any

from pydantic import BaseModel
from pydantic import Field


class VoiceProviderView(BaseModel):
    """Response model for voice provider listing."""

    id: int
    name: str
    provider_type: str  # "openai", "azure", "elevenlabs"
    is_default_stt: bool
    is_default_tts: bool
    stt_model: str | None
    tts_model: str | None
    default_voice: str | None
    has_api_key: bool = Field(
        default=False,
        description="Indicates whether an API key is stored for this provider.",
    )
    target_uri: str | None = Field(
        default=None,
        description="Target URI for Azure Speech Services.",
    )


class VoiceProviderUpdateSuccess(BaseModel):
    """Simple status response for voice provider actions."""

    status: str = "ok"


class VoiceOption(BaseModel):
    """Voice option returned by voice providers."""

    id: str
    name: str


class VoiceProviderUpsertRequest(BaseModel):
    """Request model for creating or updating a voice provider."""

    id: int | None = Field(default=None, description="Existing provider ID to update.")
    name: str
    provider_type: str  # "openai", "azure", "elevenlabs"
    api_key: str | None = Field(
        default=None,
        description="API key for the provider.",
    )
    api_key_changed: bool = Field(
        default=False,
        description="Set to true when providing a new API key for an existing provider.",
    )
    llm_provider_id: int | None = Field(
        default=None,
        description="If set, copies the API key from the specified LLM provider.",
    )
    api_base: str | None = None
    target_uri: str | None = Field(
        default=None,
        description="Target URI for Azure Speech Services (maps to api_base).",
    )
    custom_config: dict[str, Any] | None = None
    stt_model: str | None = None
    tts_model: str | None = None
    default_voice: str | None = None
    activate_stt: bool = Field(
        default=False,
        description="If true, sets this provider as the default STT provider after upsert.",
    )
    activate_tts: bool = Field(
        default=False,
        description="If true, sets this provider as the default TTS provider after upsert.",
    )


class VoiceProviderTestRequest(BaseModel):
    """Request model for testing a voice provider connection."""

    provider_type: str
    api_key: str | None = Field(
        default=None,
        description="API key for testing. If not provided, use_stored_key must be true.",
    )
    use_stored_key: bool = Field(
        default=False,
        description="If true, use the stored API key for this provider type.",
    )
    api_base: str | None = None
    target_uri: str | None = Field(
        default=None,
        description="Target URI for Azure Speech Services (maps to api_base).",
    )
    custom_config: dict[str, Any] | None = None
@@ -1,250 +0,0 @@
|
||||
import secrets
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import User
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.db.voice import update_user_voice_settings
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.redis.redis_pool import store_ws_token
|
||||
from onyx.redis.redis_pool import WsTokenRateLimitExceeded
|
||||
from onyx.server.manage.models import VoiceSettingsUpdateRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/voice")
|
||||
|
||||
# Max audio file size: 25MB (Whisper limit)
|
||||
MAX_AUDIO_SIZE = 25 * 1024 * 1024
|
||||
# Chunk size for streaming uploads (8KB)
|
||||
UPLOAD_READ_CHUNK_SIZE = 8192
|
||||
|
||||
|
||||
class VoiceStatusResponse(BaseModel):
|
||||
stt_enabled: bool
|
||||
tts_enabled: bool
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
def get_voice_status(
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceStatusResponse:
|
||||
"""Check whether STT and TTS providers are configured and ready."""
|
||||
stt_provider = fetch_default_stt_provider(db_session)
|
||||
tts_provider = fetch_default_tts_provider(db_session)
|
||||
return VoiceStatusResponse(
|
||||
stt_enabled=stt_provider is not None and stt_provider.api_key is not None,
|
||||
tts_enabled=tts_provider is not None and tts_provider.api_key is not None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/transcribe")
|
||||
async def transcribe_audio(
|
||||
audio: UploadFile = File(...),
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Transcribe audio to text using the default STT provider."""
|
||||
provider_db = fetch_default_stt_provider(db_session)
|
||||
if provider_db is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No speech-to-text provider configured. Please contact your administrator.",
|
||||
)
|
||||
|
||||
if not provider_db.api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Voice provider API key not configured.",
|
||||
)
|
||||
|
||||
# Read in chunks to enforce size limit during streaming (prevents OOM attacks)
|
||||
chunks: list[bytes] = []
|
||||
total = 0
|
||||
while chunk := await audio.read(UPLOAD_READ_CHUNK_SIZE):
|
||||
total += len(chunk)
|
||||
if total > MAX_AUDIO_SIZE:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.PAYLOAD_TOO_LARGE,
|
||||
f"Audio file too large. Maximum size is {MAX_AUDIO_SIZE // (1024 * 1024)}MB.",
|
||||
)
|
||||
chunks.append(chunk)
|
||||
audio_data = b"".join(chunks)
|
||||
|
||||
# Extract format from filename
|
||||
filename = audio.filename or "audio.webm"
|
||||
audio_format = filename.rsplit(".", 1)[-1] if "." in filename else "webm"
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
|
||||
|
||||
try:
|
||||
text = await provider.transcribe(audio_data, audio_format)
|
||||
return {"text": text}
|
||||
except NotImplementedError as exc:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_IMPLEMENTED,
|
||||
f"Speech-to-text not implemented for {provider_db.provider_type}.",
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
logger.error(f"Transcription failed: {exc}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Transcription failed. Please try again.",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post("/synthesize")
|
||||
async def synthesize_speech(
|
||||
text: str | None = Query(
|
||||
default=None, description="Text to synthesize", max_length=4096
|
||||
),
|
||||
voice: str | None = Query(default=None, description="Voice ID to use"),
|
||||
speed: float | None = Query(
|
||||
default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
|
||||
),
|
||||
user: User = Depends(current_user),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Synthesize text to speech using the default TTS provider.
|
||||
|
||||
Accepts parameters via query string for streaming compatibility.
|
||||
"""
|
||||
logger.info(
|
||||
f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
|
||||
)
|
||||
|
||||
if not text:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Text is required")
|
||||
|
||||
# Use short-lived session to fetch provider config, then release connection
|
||||
# before starting the long-running streaming response
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
provider_db = fetch_default_tts_provider(db_session)
|
||||
if provider_db is None:
|
||||
logger.error("No TTS provider configured")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No text-to-speech provider configured. Please contact your administrator.",
|
||||
)
|
||||
|
||||
if not provider_db.api_key:
|
||||
logger.error("TTS provider has no API key")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Voice provider API key not configured.",
|
||||
)
|
||||
|
||||
# Use request voice or provider default
|
||||
final_voice = voice or provider_db.default_voice
|
||||
# Use explicit None checks to avoid falsy float issues (0.0 would be skipped with `or`)
|
||||
final_speed = (
|
||||
speed
|
||||
if speed is not None
|
||||
else (
|
||||
user.voice_playback_speed
|
||||
if user.voice_playback_speed is not None
|
||||
else 1.0
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"TTS using provider: {provider_db.provider_type}, voice: {final_voice}, speed: {final_speed}"
|
||||
)
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
except ValueError as exc:
|
||||
logger.error(f"Failed to get voice provider: {exc}")
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
|
||||
|
||||
# Session is now closed - streaming response won't hold DB connection
|
||||
async def audio_stream() -> AsyncIterator[bytes]:
|
||||
try:
|
||||
chunk_count = 0
|
||||
async for chunk in provider.synthesize_stream(
|
||||
text=text, voice=final_voice, speed=final_speed
|
||||
):
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
|
||||
except NotImplementedError as exc:
|
||||
logger.error(f"TTS not implemented: {exc}")
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Synthesis failed: {exc}")
|
||||
raise
|
||||
|
||||
return StreamingResponse(
|
||||
audio_stream(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Content-Disposition": "inline; filename=speech.mp3",
|
||||
# Allow streaming by not setting content-length
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
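
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal client for the endpoint above. It assumes the router is mounted
# under /voice and that authentication is already handled by the caller (both
# assumptions, not confirmed by this diff). Uses httpx's streaming API so the
# MP3 can be consumed chunk by chunk instead of buffered whole.
#
#     import httpx
#
#     def save_speech(base_url: str, text: str, out_path: str = "speech.mp3") -> None:
#         with httpx.stream(
#             "POST", f"{base_url}/voice/synthesize", params={"text": text, "speed": 1.25}
#         ) as response:
#             response.raise_for_status()  # surfaces error responses as HTTP errors
#             with open(out_path, "wb") as f:
#                 for chunk in response.iter_bytes():
#                     f.write(chunk)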


@router.patch("/settings")
def update_voice_settings(
    request: VoiceSettingsUpdateRequest,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> dict[str, str]:
    """Update user's voice settings."""
    update_user_voice_settings(
        db_session=db_session,
        user_id=user.id,
        auto_send=request.auto_send,
        auto_playback=request.auto_playback,
        playback_speed=request.playback_speed,
    )
    return {"status": "ok"}


class WSTokenResponse(BaseModel):
    token: str


@router.post("/ws-token")
async def get_ws_token(
    user: User = Depends(current_user),
) -> WSTokenResponse:
    """
    Generate a short-lived token for WebSocket authentication.

    This token should be passed as a query parameter when connecting
    to voice WebSocket endpoints (e.g., /voice/transcribe/stream?token=xxx).

    The token expires after 60 seconds and is single-use.
    Rate limited to 10 tokens per minute per user.
    """
    token = secrets.token_urlsafe(32)
    try:
        await store_ws_token(token, str(user.id))
    except WsTokenRateLimitExceeded:
        raise OnyxError(
            OnyxErrorCode.RATE_LIMITED,
            "Too many token requests. Please wait before requesting another.",
        )
    return WSTokenResponse(token=token)
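
# --- Illustrative usage sketch (not part of the original module) ---
# The intended flow, per the docstring above: POST /voice/ws-token over the
# authenticated HTTP session, then pass the token as a query parameter when
# opening the WebSocket. URLs and auth handling are assumptions; 'websockets'
# is a third-party client library, not part of this codebase.
#
#     import httpx
#     import websockets
#
#     async def open_transcribe_socket(base_url: str, ws_base_url: str):
#         async with httpx.AsyncClient(base_url=base_url) as client:
#             resp = await client.post("/voice/ws-token")
#             resp.raise_for_status()
#             token = resp.json()["token"]
#         # Single-use, ~60s lifetime: connect promptly after fetching.
#         return await websockets.connect(
#             f"{ws_base_url}/voice/transcribe/stream?token={token}"
#         )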
@@ -1,860 +0,0 @@
"""WebSocket API for streaming speech-to-text and text-to-speech."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import WebSocket
|
||||
from fastapi import WebSocketDisconnect
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user_from_websocket
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
from onyx.db.models import User
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/voice")
|
||||
|
||||
|
||||
# Transcribe every ~0.5 seconds of audio (webm/opus is ~2-4KB/s, so ~1-2KB per 0.5s)
|
||||
MIN_CHUNK_BYTES = 1500
|
||||
VOICE_DISABLE_STREAMING_FALLBACK = (
|
||||
os.environ.get("VOICE_DISABLE_STREAMING_FALLBACK", "").lower() == "true"
|
||||
)
|
||||
|
||||
# WebSocket size limits to prevent memory exhaustion attacks
|
||||
WS_MAX_MESSAGE_SIZE = 64 * 1024 # 64KB per message (OWASP recommendation)
|
||||
WS_MAX_TOTAL_BYTES = 25 * 1024 * 1024 # 25MB total per connection (matches REST API)
|
||||
WS_MAX_TEXT_MESSAGE_SIZE = 16 * 1024 # 16KB for text/JSON messages
|
||||
WS_MAX_TTS_TEXT_LENGTH = 4096 # Max text length per synthesize call (matches REST API)


class ChunkedTranscriber:
    """Fallback transcriber for providers without streaming support."""

    def __init__(self, provider: Any, audio_format: str = "webm"):
        self.provider = provider
        self.audio_format = audio_format
        self.chunk_buffer = io.BytesIO()
        self.full_audio = io.BytesIO()
        self.chunk_bytes = 0
        self.transcripts: list[str] = []

    async def add_chunk(self, chunk: bytes) -> str | None:
        """Add audio chunk. Returns transcript if enough audio accumulated."""
        self.chunk_buffer.write(chunk)
        self.full_audio.write(chunk)
        self.chunk_bytes += len(chunk)

        if self.chunk_bytes >= MIN_CHUNK_BYTES:
            return await self._transcribe_chunk()
        return None

    async def _transcribe_chunk(self) -> str | None:
        """Transcribe current chunk and append to running transcript."""
        audio_data = self.chunk_buffer.getvalue()
        if not audio_data:
            return None

        try:
            transcript = await self.provider.transcribe(audio_data, self.audio_format)
            self.chunk_buffer = io.BytesIO()
            self.chunk_bytes = 0

            if transcript and transcript.strip():
                self.transcripts.append(transcript.strip())
                return " ".join(self.transcripts)
            return None
        except Exception as e:
            logger.error(f"Transcription error: {e}")
            self.chunk_buffer = io.BytesIO()
            self.chunk_bytes = 0
            return None

    async def flush(self) -> str:
        """Get final transcript from full audio for best accuracy."""
        full_audio_data = self.full_audio.getvalue()
        if full_audio_data:
            try:
                transcript = await self.provider.transcribe(
                    full_audio_data, self.audio_format
                )
                if transcript and transcript.strip():
                    return transcript.strip()
            except Exception as e:
                logger.error(f"Final transcription error: {e}")
        return " ".join(self.transcripts)


async def handle_streaming_transcription(
    websocket: WebSocket,
    transcriber: StreamingTranscriberProtocol,
) -> None:
    """Handle transcription using native streaming API."""
    logger.info("Streaming transcription: starting handler")
    last_transcript = ""
    chunk_count = 0
    total_bytes = 0

    async def receive_transcripts() -> None:
        """Background task to receive and send transcripts."""
        nonlocal last_transcript
        logger.info("Streaming transcription: starting transcript receiver")
        while True:
            result: TranscriptResult | None = await transcriber.receive_transcript()
            if result is None:  # End of stream
                logger.info("Streaming transcription: transcript stream ended")
                break
            # Send if text changed OR if VAD detected end of speech (for auto-send trigger)
            if result.text and (result.text != last_transcript or result.is_vad_end):
                last_transcript = result.text
                logger.debug(
                    f"Streaming transcription: got transcript: {result.text[:50]}... "
                    f"(is_vad_end={result.is_vad_end})"
                )
                await websocket.send_json(
                    {
                        "type": "transcript",
                        "text": result.text,
                        "is_final": result.is_vad_end,
                    }
                )

    # Start receiving transcripts in background
    receive_task = asyncio.create_task(receive_transcripts())

    try:
        while True:
            message = await websocket.receive()
            msg_type = message.get("type", "unknown")

            if msg_type == "websocket.disconnect":
                logger.info(
                    f"Streaming transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
                )
                break

            if "bytes" in message:
                chunk_size = len(message["bytes"])

                # Enforce per-message size limit
                if chunk_size > WS_MAX_MESSAGE_SIZE:
                    logger.warning(
                        f"Streaming transcription: message too large ({chunk_size} bytes)"
                    )
                    await websocket.send_json(
                        {"type": "error", "message": "Message too large"}
                    )
                    break

                # Enforce total connection size limit
                if total_bytes + chunk_size > WS_MAX_TOTAL_BYTES:
                    logger.warning(
                        f"Streaming transcription: total size limit exceeded ({total_bytes + chunk_size} bytes)"
                    )
                    await websocket.send_json(
                        {"type": "error", "message": "Total size limit exceeded"}
                    )
                    break

                chunk_count += 1
                total_bytes += chunk_size
                logger.debug(
                    f"Streaming transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
                )
                await transcriber.send_audio(message["bytes"])

            elif "text" in message:
                try:
                    data = json.loads(message["text"])
                    logger.debug(
                        f"Streaming transcription: received text message: {data}"
                    )
                    if data.get("type") == "end":
                        logger.info(
                            "Streaming transcription: end signal received, closing transcriber"
                        )
                        final_transcript = await transcriber.close()
                        receive_task.cancel()
                        logger.info(
                            "Streaming transcription: final transcript: "
                            f"{final_transcript[:100] if final_transcript else '(empty)'}..."
                        )
                        await websocket.send_json(
                            {
                                "type": "transcript",
                                "text": final_transcript,
                                "is_final": True,
                            }
                        )
                        break
                    elif data.get("type") == "reset":
                        # Reset accumulated transcript after auto-send
                        logger.info(
                            "Streaming transcription: reset signal received, clearing transcript"
                        )
                        transcriber.reset_transcript()
                except json.JSONDecodeError:
                    logger.warning(
                        f"Streaming transcription: failed to parse JSON: {message.get('text', '')[:100]}"
                    )
    except Exception as e:
        logger.error(f"Streaming transcription: error: {e}", exc_info=True)
        raise
    finally:
        receive_task.cancel()
        try:
            await receive_task
        except asyncio.CancelledError:
            pass
        logger.info(
            f"Streaming transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
        )


async def handle_chunked_transcription(
    websocket: WebSocket,
    transcriber: ChunkedTranscriber,
) -> None:
    """Handle transcription using chunked batch API."""
    logger.info("Chunked transcription: starting handler")
    chunk_count = 0
    total_bytes = 0

    while True:
        message = await websocket.receive()
        msg_type = message.get("type", "unknown")

        if msg_type == "websocket.disconnect":
            logger.info(
                f"Chunked transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
            )
            break

        if "bytes" in message:
            chunk_size = len(message["bytes"])

            # Enforce per-message size limit
            if chunk_size > WS_MAX_MESSAGE_SIZE:
                logger.warning(
                    f"Chunked transcription: message too large ({chunk_size} bytes)"
                )
                await websocket.send_json(
                    {"type": "error", "message": "Message too large"}
                )
                break

            # Enforce total connection size limit
            if total_bytes + chunk_size > WS_MAX_TOTAL_BYTES:
                logger.warning(
                    f"Chunked transcription: total size limit exceeded ({total_bytes + chunk_size} bytes)"
                )
                await websocket.send_json(
                    {"type": "error", "message": "Total size limit exceeded"}
                )
                break

            chunk_count += 1
            total_bytes += chunk_size
            logger.debug(
                f"Chunked transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
            )

            transcript = await transcriber.add_chunk(message["bytes"])
            if transcript:
                logger.debug(
                    f"Chunked transcription: got transcript: {transcript[:50]}..."
                )
                await websocket.send_json(
                    {
                        "type": "transcript",
                        "text": transcript,
                        "is_final": False,
                    }
                )

        elif "text" in message:
            try:
                data = json.loads(message["text"])
                logger.debug(f"Chunked transcription: received text message: {data}")
                if data.get("type") == "end":
                    logger.info("Chunked transcription: end signal received, flushing")
                    final_transcript = await transcriber.flush()
                    logger.info(
                        f"Chunked transcription: final transcript: {final_transcript[:100] if final_transcript else '(empty)'}..."
                    )
                    await websocket.send_json(
                        {
                            "type": "transcript",
                            "text": final_transcript,
                            "is_final": True,
                        }
                    )
                    break
            except json.JSONDecodeError:
                logger.warning(
                    f"Chunked transcription: failed to parse JSON: {message.get('text', '')[:100]}"
                )

    logger.info(
        f"Chunked transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
    )


@router.websocket("/transcribe/stream")
async def websocket_transcribe(
    websocket: WebSocket,
    _user: User = Depends(current_user_from_websocket),
) -> None:
    """
    WebSocket endpoint for streaming speech-to-text.

    Protocol:
    - Client sends binary audio chunks
    - Server sends JSON: {"type": "transcript", "text": "...", "is_final": false}
    - Client sends JSON {"type": "end"} to signal end
    - Server responds with final transcript and closes

    Authentication:
        Requires `token` query parameter (e.g., /voice/transcribe/stream?token=xxx).
        Applies the same auth checks as the HTTP endpoints (verification, role checks).
    """
    logger.info("WebSocket transcribe: connection request received (authenticated)")

    try:
        await websocket.accept()
        logger.info("WebSocket transcribe: connection accepted")
    except Exception as e:
        logger.error(f"WebSocket transcribe: failed to accept connection: {e}")
        return

    streaming_transcriber = None
    provider = None

    try:
        # Get STT provider
        logger.info("WebSocket transcribe: fetching STT provider from database")
        engine = get_sqlalchemy_engine()
        with Session(engine) as db_session:
            provider_db = fetch_default_stt_provider(db_session)
            if provider_db is None:
                logger.warning(
                    "WebSocket transcribe: no default STT provider configured"
                )
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "No speech-to-text provider configured",
                    }
                )
                return

            if not provider_db.api_key:
                logger.warning("WebSocket transcribe: STT provider has no API key")
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "Speech-to-text provider has no API key configured",
                    }
                )
                return

            logger.info(
                f"WebSocket transcribe: creating voice provider: {provider_db.provider_type}"
            )
            try:
                provider = get_voice_provider(provider_db)
                logger.info(
                    f"WebSocket transcribe: voice provider created, streaming supported: {provider.supports_streaming_stt()}"
                )
            except ValueError as e:
                logger.error(
                    f"WebSocket transcribe: failed to create voice provider: {e}"
                )
                await websocket.send_json({"type": "error", "message": str(e)})
                return

        # Use native streaming if provider supports it
        if provider.supports_streaming_stt():
            logger.info("WebSocket transcribe: using native streaming STT")
            try:
                streaming_transcriber = await provider.create_streaming_transcriber()
                logger.info(
                    "WebSocket transcribe: streaming transcriber created successfully"
                )
                await handle_streaming_transcription(websocket, streaming_transcriber)
            except Exception as e:
                logger.error(
                    f"WebSocket transcribe: failed to create streaming transcriber: {e}"
                )
                if VOICE_DISABLE_STREAMING_FALLBACK:
                    await websocket.send_json(
                        {"type": "error", "message": f"Streaming STT failed: {e}"}
                    )
                    return
                logger.info("WebSocket transcribe: falling back to chunked STT")
                # Browser stream provides raw PCM16 chunks over WebSocket.
                chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
                await handle_chunked_transcription(websocket, chunked_transcriber)
        else:
            # Fall back to chunked transcription
            if VOICE_DISABLE_STREAMING_FALLBACK:
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "Provider doesn't support streaming STT",
                    }
                )
                return
            logger.info(
                "WebSocket transcribe: using chunked STT (provider doesn't support streaming)"
            )
            chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
            await handle_chunked_transcription(websocket, chunked_transcriber)

    except WebSocketDisconnect:
        logger.debug("WebSocket transcribe: client disconnected")
    except Exception as e:
        logger.error(f"WebSocket transcribe: unhandled error: {e}", exc_info=True)
        try:
            # Send generic error to avoid leaking sensitive details
            await websocket.send_json(
                {"type": "error", "message": "An unexpected error occurred"}
            )
        except Exception:
            pass
    finally:
        if streaming_transcriber:
            try:
                await streaming_transcriber.close()
            except Exception:
                pass
        try:
            await websocket.close()
        except Exception:
            pass
        logger.info("WebSocket transcribe: connection closed")


async def handle_streaming_synthesis(
    websocket: WebSocket,
    synthesizer: StreamingSynthesizerProtocol,
) -> None:
    """Handle TTS using native streaming API."""
    logger.info("Streaming synthesis: starting handler")

    async def send_audio() -> None:
        """Background task to send audio chunks to client."""
        chunk_count = 0
        total_bytes = 0
        try:
            while True:
                audio_chunk = await synthesizer.receive_audio()
                if audio_chunk is None:
                    logger.info(
                        f"Streaming synthesis: audio stream ended, sent {chunk_count} chunks, {total_bytes} bytes"
                    )
                    try:
                        await websocket.send_json({"type": "audio_done"})
                        logger.info("Streaming synthesis: sent audio_done to client")
                    except Exception as e:
                        logger.warning(
                            f"Streaming synthesis: failed to send audio_done: {e}"
                        )
                    break
                if audio_chunk:  # Skip empty chunks
                    chunk_count += 1
                    total_bytes += len(audio_chunk)
                    try:
                        await websocket.send_bytes(audio_chunk)
                    except Exception as e:
                        logger.warning(
                            f"Streaming synthesis: failed to send chunk: {e}"
                        )
                        break
        except asyncio.CancelledError:
            logger.info(
                f"Streaming synthesis: send_audio cancelled after {chunk_count} chunks"
            )
        except Exception as e:
            logger.error(f"Streaming synthesis: send_audio error: {e}")

    send_task: asyncio.Task | None = None
    disconnected = False

    try:
        while not disconnected:
            try:
                message = await websocket.receive()
            except WebSocketDisconnect:
                logger.info("Streaming synthesis: client disconnected")
                break

            msg_type = message.get("type", "unknown")  # type: ignore[possibly-undefined]

            if msg_type == "websocket.disconnect":
                logger.info("Streaming synthesis: client disconnected")
                disconnected = True
                break

            if "text" in message:
                # Enforce text message size limit
                msg_size = len(message["text"])
                if msg_size > WS_MAX_TEXT_MESSAGE_SIZE:
                    logger.warning(
                        f"Streaming synthesis: text message too large ({msg_size} bytes)"
                    )
                    await websocket.send_json(
                        {"type": "error", "message": "Message too large"}
                    )
                    break

                try:
                    data = json.loads(message["text"])

                    if data.get("type") == "synthesize":
                        text = data.get("text", "")
                        # Enforce per-text size limit
                        if len(text) > WS_MAX_TTS_TEXT_LENGTH:
                            logger.warning(
                                f"Streaming synthesis: text too long ({len(text)} chars)"
                            )
                            await websocket.send_json(
                                {"type": "error", "message": "Text too long"}
                            )
                            continue
                        if text:
                            # Start audio receiver on first text chunk so playback
                            # can begin before the full assistant response completes.
                            if send_task is None:
                                send_task = asyncio.create_task(send_audio())
                            logger.debug(
                                f"Streaming synthesis: forwarding text chunk ({len(text)} chars)"
                            )
                            await synthesizer.send_text(text)

                    elif data.get("type") == "end":
                        logger.info("Streaming synthesis: end signal received")

                        # Ensure receiver is active even if no prior text chunks arrived.
                        if send_task is None:
                            send_task = asyncio.create_task(send_audio())

                        # Signal end of input
                        if hasattr(synthesizer, "flush"):
                            await synthesizer.flush()

                        # Wait for all audio to be sent
                        logger.info(
                            "Streaming synthesis: waiting for audio stream to complete"
                        )
                        try:
                            await asyncio.wait_for(send_task, timeout=60.0)
                        except asyncio.TimeoutError:
                            logger.warning(
                                "Streaming synthesis: timeout waiting for audio"
                            )
                        break

                except json.JSONDecodeError:
                    logger.warning(
                        f"Streaming synthesis: failed to parse JSON: {message.get('text', '')[:100]}"
                    )

    except WebSocketDisconnect:
        logger.debug("Streaming synthesis: client disconnected during synthesis")
    except Exception as e:
        logger.error(f"Streaming synthesis: error: {e}", exc_info=True)
    finally:
        if send_task and not send_task.done():
            logger.info("Streaming synthesis: waiting for send_task to finish")
            try:
                await asyncio.wait_for(send_task, timeout=30.0)
            except asyncio.TimeoutError:
                logger.warning("Streaming synthesis: timeout waiting for send_task")
                send_task.cancel()
                try:
                    await send_task
                except asyncio.CancelledError:
                    pass
            except asyncio.CancelledError:
                pass
        logger.info("Streaming synthesis: handler finished")


async def handle_chunked_synthesis(
    websocket: WebSocket,
    provider: Any,
    first_message: MutableMapping[str, Any] | None = None,
) -> None:
    """Fallback TTS handler using provider.synthesize_stream.

    Args:
        websocket: The WebSocket connection
        provider: Voice provider instance
        first_message: Optional first message already received (used when falling
            back from streaming mode, where the first message was already consumed)
    """
    logger.info("Chunked synthesis: starting handler")
    text_buffer: list[str] = []
    voice: str | None = None
    speed = 1.0

    # Process pre-received message if provided
    pending_message = first_message

    try:
        while True:
            if pending_message is not None:
                message = pending_message
                pending_message = None
            else:
                message = await websocket.receive()
            msg_type = message.get("type", "unknown")

            if msg_type == "websocket.disconnect":
                logger.info("Chunked synthesis: client disconnected")
                break

            if "text" not in message:
                continue

            # Enforce text message size limit
            msg_size = len(message["text"])
            if msg_size > WS_MAX_TEXT_MESSAGE_SIZE:
                logger.warning(
                    f"Chunked synthesis: text message too large ({msg_size} bytes)"
                )
                await websocket.send_json(
                    {"type": "error", "message": "Message too large"}
                )
                break

            try:
                data = json.loads(message["text"])
            except json.JSONDecodeError:
                logger.warning(
                    "Chunked synthesis: failed to parse JSON: "
                    f"{message.get('text', '')[:100]}"
                )
                continue

            msg_data_type = data.get("type")  # type: ignore[possibly-undefined]
            if msg_data_type == "synthesize":
                text = data.get("text", "")
                # Enforce per-text size limit
                if len(text) > WS_MAX_TTS_TEXT_LENGTH:
                    logger.warning(
                        f"Chunked synthesis: text too long ({len(text)} chars)"
                    )
                    await websocket.send_json(
                        {"type": "error", "message": "Text too long"}
                    )
                    continue
                if text:
                    text_buffer.append(text)
                    logger.debug(
                        f"Chunked synthesis: buffered text ({len(text)} chars), "
                        f"total buffered: {len(text_buffer)} chunks"
                    )
                if isinstance(data.get("voice"), str) and data["voice"]:
                    voice = data["voice"]
                if isinstance(data.get("speed"), (int, float)):
                    speed = float(data["speed"])
            elif msg_data_type == "end":
                logger.info("Chunked synthesis: end signal received")
                full_text = " ".join(text_buffer).strip()
                if not full_text:
                    await websocket.send_json({"type": "audio_done"})
                    logger.info("Chunked synthesis: no text, sent audio_done")
                    break

                chunk_count = 0
                total_bytes = 0
                logger.info(
                    f"Chunked synthesis: sending full text ({len(full_text)} chars)"
                )
                async for audio_chunk in provider.synthesize_stream(
                    full_text, voice=voice, speed=speed
                ):
                    if not audio_chunk:
                        continue
                    chunk_count += 1
                    total_bytes += len(audio_chunk)
                    await websocket.send_bytes(audio_chunk)
                await websocket.send_json({"type": "audio_done"})
                logger.info(
                    f"Chunked synthesis: sent audio_done after {chunk_count} chunks, {total_bytes} bytes"
                )
                break
    except WebSocketDisconnect:
        logger.debug("Chunked synthesis: client disconnected")
    except Exception as e:
        logger.error(f"Chunked synthesis: error: {e}", exc_info=True)
        raise
    finally:
        logger.info("Chunked synthesis: handler finished")


@router.websocket("/synthesize/stream")
async def websocket_synthesize(
    websocket: WebSocket,
    _user: User = Depends(current_user_from_websocket),
) -> None:
    """
    WebSocket endpoint for streaming text-to-speech.

    Protocol:
    - Client sends JSON: {"type": "synthesize", "text": "...", "voice": "...", "speed": 1.0}
    - Server sends binary audio chunks
    - Server sends JSON: {"type": "audio_done"} when synthesis completes
    - Client sends JSON {"type": "end"} to close connection

    Authentication:
        Requires `token` query parameter (e.g., /voice/synthesize/stream?token=xxx).
        Applies the same auth checks as the HTTP endpoints (verification, role checks).
    """
    logger.info("WebSocket synthesize: connection request received (authenticated)")

    try:
        await websocket.accept()
        logger.info("WebSocket synthesize: connection accepted")
    except Exception as e:
        logger.error(f"WebSocket synthesize: failed to accept connection: {e}")
        return

    streaming_synthesizer: StreamingSynthesizerProtocol | None = None
    provider = None

    try:
        # Get TTS provider
        logger.info("WebSocket synthesize: fetching TTS provider from database")
        engine = get_sqlalchemy_engine()
        with Session(engine) as db_session:
            provider_db = fetch_default_tts_provider(db_session)
            if provider_db is None:
                logger.warning(
                    "WebSocket synthesize: no default TTS provider configured"
                )
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "No text-to-speech provider configured",
                    }
                )
                return

            if not provider_db.api_key:
                logger.warning("WebSocket synthesize: TTS provider has no API key")
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "Text-to-speech provider has no API key configured",
                    }
                )
                return

            logger.info(
                f"WebSocket synthesize: creating voice provider: {provider_db.provider_type}"
            )
            try:
                provider = get_voice_provider(provider_db)
                logger.info(
                    f"WebSocket synthesize: voice provider created, streaming TTS supported: {provider.supports_streaming_tts()}"
                )
            except ValueError as e:
                logger.error(
                    f"WebSocket synthesize: failed to create voice provider: {e}"
                )
                await websocket.send_json({"type": "error", "message": str(e)})
                return

        # Use native streaming if provider supports it
        if provider.supports_streaming_tts():
            logger.info("WebSocket synthesize: using native streaming TTS")
            message = None  # Initialize to avoid UnboundLocalError in except block
            try:
                # Wait for initial config message with voice/speed
                message = await websocket.receive()
                voice = None
                speed = 1.0
                if "text" in message:
                    try:
                        data = json.loads(message["text"])
                        voice = data.get("voice")
                        speed = data.get("speed", 1.0)
                    except json.JSONDecodeError:
                        pass

                streaming_synthesizer = await provider.create_streaming_synthesizer(
                    voice=voice, speed=speed
                )
                logger.info(
                    "WebSocket synthesize: streaming synthesizer created successfully"
                )
                await handle_streaming_synthesis(websocket, streaming_synthesizer)
            except Exception as e:
                logger.error(
                    f"WebSocket synthesize: failed to create streaming synthesizer: {e}"
                )
                if VOICE_DISABLE_STREAMING_FALLBACK:
                    await websocket.send_json(
                        {"type": "error", "message": f"Streaming TTS failed: {e}"}
                    )
                    return
                logger.info(
                    "WebSocket synthesize: falling back to chunked TTS synthesis"
                )
                # Pass the first message so it's not lost in the fallback
                await handle_chunked_synthesis(
                    websocket, provider, first_message=message
                )
        else:
            if VOICE_DISABLE_STREAMING_FALLBACK:
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "Provider doesn't support streaming TTS",
                    }
                )
                return
            logger.info(
                "WebSocket synthesize: using chunked TTS (provider doesn't support streaming)"
            )
            await handle_chunked_synthesis(websocket, provider)

    except WebSocketDisconnect:
        logger.debug("WebSocket synthesize: client disconnected")
    except Exception as e:
        logger.error(f"WebSocket synthesize: unhandled error: {e}", exc_info=True)
        try:
            # Send generic error to avoid leaking sensitive details
            await websocket.send_json(
                {"type": "error", "message": "An unexpected error occurred"}
            )
        except Exception:
            pass
    finally:
        if streaming_synthesizer:
            try:
                await streaming_synthesizer.close()
            except Exception:
                pass
        try:
            await websocket.close()
        except Exception:
            pass
        logger.info("WebSocket synthesize: connection closed")
@@ -1,13 +0,0 @@
from typing import Any

from pydantic import BaseModel


def shallow_model_dump(model_instance: BaseModel) -> dict[str, Any]:
    """Like model_dump(), but returns references to field values instead of
    deep copies. Use with model_construct() to avoid unnecessary memory
    duplication when building subclass instances."""
    return {
        field_name: getattr(model_instance, field_name)
        for field_name in model_instance.__class__.model_fields
    }
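
# --- Illustrative usage sketch (not part of the original module) ---
# The intended pairing with model_construct(), per the docstring above: copy
# field references from a parent instance into a subclass without
# re-validating or deep-copying. The models here are hypothetical examples.
#
#     class Base(BaseModel):
#         items: list[int]
#
#     class Extended(Base):
#         label: str = ""
#
#     base = Base(items=[1, 2, 3])
#     ext = Extended.model_construct(**shallow_model_dump(base), label="x")
#     assert ext.items is base.items  # same list object, no deep copy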
@@ -140,44 +140,6 @@ def _validate_and_resolve_url(url: str) -> tuple[str, str, int]:
    return validated_ip, hostname, port


def validate_outbound_http_url(url: str, *, allow_private_network: bool = False) -> str:
    """
    Validate a URL that will be used by backend outbound HTTP calls.

    Returns:
        A normalized URL string with surrounding whitespace removed.

    Raises:
        ValueError: If URL is malformed.
        SSRFException: If URL fails SSRF checks.
    """
    normalized_url = url.strip()
    if not normalized_url:
        raise ValueError("URL cannot be empty")

    parsed = urlparse(normalized_url)

    if parsed.scheme not in ("http", "https"):
        raise SSRFException(
            f"Invalid URL scheme '{parsed.scheme}'. Only http and https are allowed."
        )

    if not parsed.hostname:
        raise ValueError("URL must contain a hostname")

    if parsed.username or parsed.password:
        raise SSRFException("URLs with embedded credentials are not allowed.")

    hostname = parsed.hostname.lower()
    if hostname in BLOCKED_HOSTNAMES:
        raise SSRFException(f"Access to hostname '{parsed.hostname}' is not allowed.")

    if not allow_private_network:
        _validate_and_resolve_url(normalized_url)

    return normalized_url
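
# --- Illustrative usage sketch (not part of the original module) ---
# Typical call pattern for the validator above; the wrapper and error
# handling are hypothetical, mirroring the documented raise behavior.
#
#     def checked_outbound_url(user_supplied_url: str) -> str:
#         try:
#             safe_url = validate_outbound_http_url(user_supplied_url)
#         except (ValueError, SSRFException) as exc:
#             raise RuntimeError(f"Rejected outbound URL: {exc}") from exc
#         # safe_url is whitespace-trimmed and has passed scheme/host checks;
#         # with allow_private_network=False, DNS was also resolved and
#         # private/internal addresses rejected by _validate_and_resolve_url.
#         return safe_url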


MAX_REDIRECTS = 10


@@ -1,70 +0,0 @@
from onyx.db.models import VoiceProvider
from onyx.voice.interface import VoiceProviderInterface


def get_voice_provider(provider: VoiceProvider) -> VoiceProviderInterface:
    """
    Factory function to get the appropriate voice provider implementation.

    Args:
        provider: VoiceProvider model instance (can be from DB or constructed temporarily)

    Returns:
        VoiceProviderInterface implementation

    Raises:
        ValueError: If provider_type is not supported
    """
    provider_type = provider.provider_type.lower()

    # Handle both SensitiveValue (from DB) and plain string (from temp model)
    if provider.api_key is None:
        api_key = None
    elif hasattr(provider.api_key, "get_value"):
        # SensitiveValue from database
        api_key = provider.api_key.get_value(apply_mask=False)
    else:
        # Plain string from temporary model
        api_key = provider.api_key  # type: ignore[assignment]
    api_base = provider.api_base
    custom_config = provider.custom_config
    stt_model = provider.stt_model
    tts_model = provider.tts_model
    default_voice = provider.default_voice

    if provider_type == "openai":
        from onyx.voice.providers.openai import OpenAIVoiceProvider

        return OpenAIVoiceProvider(
            api_key=api_key,
            api_base=api_base,
            stt_model=stt_model,
            tts_model=tts_model,
            default_voice=default_voice,
        )

    elif provider_type == "azure":
        from onyx.voice.providers.azure import AzureVoiceProvider

        return AzureVoiceProvider(
            api_key=api_key,
            api_base=api_base,
            custom_config=custom_config or {},
            stt_model=stt_model,
            tts_model=tts_model,
            default_voice=default_voice,
        )

    elif provider_type == "elevenlabs":
        from onyx.voice.providers.elevenlabs import ElevenLabsVoiceProvider

        return ElevenLabsVoiceProvider(
            api_key=api_key,
            api_base=api_base,
            stt_model=stt_model,
            tts_model=tts_model,
            default_voice=default_voice,
        )

    else:
        raise ValueError(f"Unsupported voice provider type: {provider_type}")
@@ -1,182 +0,0 @@
from abc import ABC
from abc import abstractmethod
from collections.abc import AsyncIterator
from typing import Protocol

from pydantic import BaseModel


class TranscriptResult(BaseModel):
    """Result from streaming transcription."""

    text: str
    """The accumulated transcript text."""

    is_vad_end: bool = False
    """True if VAD detected end of speech (silence). Use for auto-send."""


class StreamingTranscriberProtocol(Protocol):
    """Protocol for streaming transcription sessions."""

    async def send_audio(self, chunk: bytes) -> None:
        """Send an audio chunk for transcription."""
        ...

    async def receive_transcript(self) -> TranscriptResult | None:
        """
        Receive next transcript update.

        Returns:
            TranscriptResult with accumulated text and VAD status, or None when stream ends.
        """
        ...

    async def close(self) -> str:
        """Close the session and return final transcript."""
        ...

    def reset_transcript(self) -> None:
        """Reset accumulated transcript. Call after auto-send to start fresh."""
        ...


class StreamingSynthesizerProtocol(Protocol):
    """Protocol for streaming TTS sessions (real-time text-to-speech)."""

    async def connect(self) -> None:
        """Establish connection to TTS provider."""
        ...

    async def send_text(self, text: str) -> None:
        """Send text to be synthesized."""
        ...

    async def receive_audio(self) -> bytes | None:
        """
        Receive next audio chunk.

        Returns:
            Audio bytes, or None when stream ends.
        """
        ...

    async def flush(self) -> None:
        """Signal end of text input and wait for pending audio."""
        ...

    async def close(self) -> None:
        """Close the session."""
        ...


class VoiceProviderInterface(ABC):
    """Abstract base class for voice providers (STT and TTS)."""

    @abstractmethod
    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        """
        Convert audio to text (Speech-to-Text).

        Args:
            audio_data: Raw audio bytes
            audio_format: Audio format (e.g., "webm", "wav", "mp3")

        Returns:
            Transcribed text
        """

    @abstractmethod
    def synthesize_stream(
        self, text: str, voice: str | None = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio stream (Text-to-Speech).

        Streams audio chunks progressively for lower latency playback.

        Args:
            text: Text to convert to speech
            voice: Voice identifier (e.g., "alloy", "echo"), or None for default
            speed: Playback speed multiplier (0.25 to 4.0)

        Yields:
            Audio data chunks
        """

    @abstractmethod
    async def validate_credentials(self) -> None:
        """
        Validate that the provider credentials are correct by making a
        lightweight API call. Raises on failure.
        """

    @abstractmethod
    def get_available_voices(self) -> list[dict[str, str]]:
        """
        Get list of available voices for this provider.

        Returns:
            List of voice dictionaries with 'id' and 'name' keys
        """

    @abstractmethod
    def get_available_stt_models(self) -> list[dict[str, str]]:
        """
        Get list of available STT models for this provider.

        Returns:
            List of model dictionaries with 'id' and 'name' keys
        """

    @abstractmethod
    def get_available_tts_models(self) -> list[dict[str, str]]:
        """
        Get list of available TTS models for this provider.

        Returns:
            List of model dictionaries with 'id' and 'name' keys
        """

    def supports_streaming_stt(self) -> bool:
        """Returns True if this provider supports streaming STT."""
        return False

    def supports_streaming_tts(self) -> bool:
        """Returns True if this provider supports real-time streaming TTS."""
        return False

    async def create_streaming_transcriber(
        self, audio_format: str = "webm"
    ) -> StreamingTranscriberProtocol:
        """
        Create a streaming transcription session.

        Args:
            audio_format: Audio format being sent (e.g., "webm", "pcm16")

        Returns:
            A streaming transcriber that can send audio chunks and receive transcripts

        Raises:
            NotImplementedError: If streaming STT is not supported
        """
        raise NotImplementedError("Streaming STT not supported by this provider")

    async def create_streaming_synthesizer(
        self, voice: str | None = None, speed: float = 1.0
    ) -> "StreamingSynthesizerProtocol":
        """
        Create a streaming TTS session for real-time audio synthesis.

        Args:
            voice: Voice identifier
            speed: Playback speed multiplier

        Returns:
            A streaming synthesizer that can send text and receive audio chunks

        Raises:
            NotImplementedError: If streaming TTS is not supported
        """
        raise NotImplementedError("Streaming TTS not supported by this provider")
@@ -1,626 +0,0 @@
"""Azure Speech Services voice provider for STT and TTS.
|
||||
|
||||
Azure supports:
|
||||
- **STT**: Batch transcription via REST API (audio/wav POST) and real-time
|
||||
streaming via the Azure Speech SDK (push audio stream with continuous
|
||||
recognition). The SDK handles VAD natively through its recognizing/recognized
|
||||
events.
|
||||
- **TTS**: SSML-based synthesis via REST API (streaming response) and real-time
|
||||
synthesis via the Speech SDK. Text is escaped with ``xml.sax.saxutils.escape``
|
||||
and attributes with ``quoteattr`` to prevent SSML injection.
|
||||
|
||||
Both modes support Azure cloud endpoints (region-based URLs) and self-hosted
|
||||
Speech containers (custom endpoint URLs). The ``speech_region`` is validated to
|
||||
contain only ``[a-z0-9-]`` to prevent URL injection.
|
||||
|
||||
The Azure Speech SDK (``azure-cognitiveservices-speech``) is an optional C
|
||||
extension dependency — it is imported lazily inside streaming methods so the
|
||||
provider can still be instantiated and used for REST-based operations without it.
|
||||
|
||||
See https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/
|
||||
for API reference.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import re
|
||||
import struct
|
||||
import wave
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
from xml.sax.saxutils import escape
|
||||
from xml.sax.saxutils import quoteattr
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
# SSML namespace — W3C standard for Speech Synthesis Markup Language.
|
||||
# This is a fixed W3C specification and will not change.
|
||||
SSML_NAMESPACE = "http://www.w3.org/2001/10/synthesis"
|
||||
|
||||
# Common Azure Neural voices
|
||||
AZURE_VOICES = [
|
||||
{"id": "en-US-JennyNeural", "name": "Jenny (en-US, Female)"},
|
||||
{"id": "en-US-GuyNeural", "name": "Guy (en-US, Male)"},
|
||||
{"id": "en-US-AriaNeural", "name": "Aria (en-US, Female)"},
|
||||
{"id": "en-US-DavisNeural", "name": "Davis (en-US, Male)"},
|
||||
{"id": "en-US-AmberNeural", "name": "Amber (en-US, Female)"},
|
||||
{"id": "en-US-AnaNeural", "name": "Ana (en-US, Female)"},
|
||||
{"id": "en-US-BrandonNeural", "name": "Brandon (en-US, Male)"},
|
||||
{"id": "en-US-ChristopherNeural", "name": "Christopher (en-US, Male)"},
|
||||
{"id": "en-US-CoraNeural", "name": "Cora (en-US, Female)"},
|
||||
{"id": "en-GB-SoniaNeural", "name": "Sonia (en-GB, Female)"},
|
||||
{"id": "en-GB-RyanNeural", "name": "Ryan (en-GB, Male)"},
|
||||
]
|
||||
|
||||
|
||||
class AzureStreamingTranscriber(StreamingTranscriberProtocol):
|
||||
"""Streaming transcription using Azure Speech SDK."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
region: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
input_sample_rate: int = 24000,
|
||||
target_sample_rate: int = 16000,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.region = region
|
||||
self.endpoint = endpoint
|
||||
self.input_sample_rate = input_sample_rate
|
||||
self.target_sample_rate = target_sample_rate
|
||||
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
|
||||
self._accumulated_transcript = ""
|
||||
self._recognizer: Any = None
|
||||
self._audio_stream: Any = None
|
||||
self._closed = False
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize Azure Speech recognizer with push stream."""
|
||||
try:
|
||||
import azure.cognitiveservices.speech as speechsdk # type: ignore
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Azure Speech SDK is required for streaming STT. "
|
||||
"Install `azure-cognitiveservices-speech`."
|
||||
) from e
|
||||
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
# Use endpoint for self-hosted containers, region for Azure cloud
|
||||
if self.endpoint:
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key,
|
||||
endpoint=self.endpoint,
|
||||
)
|
||||
else:
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key,
|
||||
region=self.region,
|
||||
)
|
||||
|
||||
audio_format = speechsdk.audio.AudioStreamFormat(
|
||||
samples_per_second=16000,
|
||||
bits_per_sample=16,
|
||||
channels=1,
|
||||
)
|
||||
self._audio_stream = speechsdk.audio.PushAudioInputStream(audio_format)
|
||||
audio_config = speechsdk.audio.AudioConfig(stream=self._audio_stream)
|
||||
|
||||
self._recognizer = speechsdk.SpeechRecognizer(
|
||||
speech_config=speech_config,
|
||||
audio_config=audio_config,
|
||||
)
|
||||
|
||||
transcriber = self
|
||||
|
||||
def on_recognizing(evt: Any) -> None:
|
||||
if evt.result.text and transcriber._loop and not transcriber._closed:
|
||||
full_text = transcriber._accumulated_transcript
|
||||
if full_text:
|
||||
full_text += " " + evt.result.text
|
||||
else:
|
||||
full_text = evt.result.text
|
||||
transcriber._loop.call_soon_threadsafe(
|
||||
transcriber._transcript_queue.put_nowait,
|
||||
TranscriptResult(text=full_text, is_vad_end=False),
|
||||
)
|
||||
|
||||
def on_recognized(evt: Any) -> None:
|
||||
if evt.result.text and transcriber._loop and not transcriber._closed:
|
||||
if transcriber._accumulated_transcript:
|
||||
transcriber._accumulated_transcript += " " + evt.result.text
|
||||
else:
|
||||
transcriber._accumulated_transcript = evt.result.text
|
||||
transcriber._loop.call_soon_threadsafe(
|
||||
transcriber._transcript_queue.put_nowait,
|
||||
TranscriptResult(
|
||||
text=transcriber._accumulated_transcript, is_vad_end=True
|
||||
),
|
||||
)
|
||||
|
||||
self._recognizer.recognizing.connect(on_recognizing)
|
||||
self._recognizer.recognized.connect(on_recognized)
|
||||
self._recognizer.start_continuous_recognition_async()
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send audio chunk to Azure."""
|
||||
if self._audio_stream and not self._closed:
|
||||
self._audio_stream.write(self._resample_pcm16(chunk))
|
||||
|
||||
def _resample_pcm16(self, data: bytes) -> bytes:
|
||||
"""Resample PCM16 audio from input_sample_rate to target_sample_rate."""
|
||||
if self.input_sample_rate == self.target_sample_rate:
|
||||
return data
|
||||
|
||||
num_samples = len(data) // 2
|
||||
if num_samples == 0:
|
||||
return b""
|
||||
|
||||
samples = list(struct.unpack(f"<{num_samples}h", data))
|
||||
ratio = self.input_sample_rate / self.target_sample_rate
|
||||
new_length = int(num_samples / ratio)
|
||||
|
||||
resampled: list[int] = []
|
||||
for i in range(new_length):
|
||||
src_idx = i * ratio
|
||||
idx_floor = int(src_idx)
|
||||
idx_ceil = min(idx_floor + 1, num_samples - 1)
|
||||
frac = src_idx - idx_floor
|
||||
sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
|
||||
sample = max(-32768, min(32767, sample))
|
||||
resampled.append(sample)
|
||||
|
||||
return struct.pack(f"<{len(resampled)}h", *resampled)
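
    # Illustrative note (not part of the original class): worked example of
    # the linear-interpolation resample above. With input_sample_rate=24000
    # and target_sample_rate=16000, ratio = 1.5, so 6 input samples become 4:
    #   samples = [0, 300, 600, 900, 1200, 1500]
    #   output[i] reads source position i * 1.5, interpolating neighbours:
    #   i=0 -> 0.0 -> 0; i=1 -> 1.5 -> (300 + 600) / 2 = 450;
    #   i=2 -> 3.0 -> 900; i=3 -> 4.5 -> (1200 + 1500) / 2 = 1350
    #   resampled = [0, 450, 900, 1350]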

    async def receive_transcript(self) -> TranscriptResult | None:
        """Receive next transcript."""
        try:
            return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return TranscriptResult(text="", is_vad_end=False)

    async def close(self) -> str:
        """Stop recognition and return final transcript."""
        self._closed = True
        if self._recognizer:
            self._recognizer.stop_continuous_recognition_async()
        if self._audio_stream:
            self._audio_stream.close()
        self._loop = None
        return self._accumulated_transcript

    def reset_transcript(self) -> None:
        """Reset accumulated transcript."""
        self._accumulated_transcript = ""


class AzureStreamingSynthesizer(StreamingSynthesizerProtocol):
    """Real-time streaming TTS using Azure Speech SDK."""

    def __init__(
        self,
        api_key: str,
        region: str | None = None,
        endpoint: str | None = None,
        voice: str = "en-US-JennyNeural",
        speed: float = 1.0,
    ):
        self._logger = setup_logger()
        self.api_key = api_key
        self.region = region
        self.endpoint = endpoint
        self.voice = voice
        self.speed = max(0.5, min(2.0, speed))
        self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
        self._synthesizer: Any = None
        self._closed = False
        self._loop: asyncio.AbstractEventLoop | None = None

    async def connect(self) -> None:
        """Initialize Azure Speech synthesizer with push stream."""
        try:
            import azure.cognitiveservices.speech as speechsdk
        except ImportError as e:
            raise RuntimeError(
                "Azure Speech SDK is required for streaming TTS. "
                "Install `azure-cognitiveservices-speech`."
            ) from e

        self._logger.info("AzureStreamingSynthesizer: connecting")

        # Store the event loop for thread-safe queue operations
        self._loop = asyncio.get_running_loop()

        # Use endpoint for self-hosted containers, region for Azure cloud
        if self.endpoint:
            speech_config = speechsdk.SpeechConfig(
                subscription=self.api_key,
                endpoint=self.endpoint,
            )
        else:
            speech_config = speechsdk.SpeechConfig(
                subscription=self.api_key,
                region=self.region,
            )
        speech_config.speech_synthesis_voice_name = self.voice
        # Use MP3 format for streaming - compatible with MediaSource Extensions
        speech_config.set_speech_synthesis_output_format(
            speechsdk.SpeechSynthesisOutputFormat.Audio16Khz64KBitRateMonoMp3
        )

        # Create synthesizer with pull audio output stream
        self._synthesizer = speechsdk.SpeechSynthesizer(
            speech_config=speech_config,
            audio_config=None,  # We'll manually handle audio
        )

        # Connect to synthesis events
        self._synthesizer.synthesizing.connect(self._on_synthesizing)
        self._synthesizer.synthesis_completed.connect(self._on_completed)

        self._logger.info("AzureStreamingSynthesizer: connected")

    def _on_synthesizing(self, evt: Any) -> None:
        """Called when audio chunk is available (runs in Azure SDK thread)."""
        if evt.result.audio_data and self._loop and not self._closed:
            # Thread-safe way to put item in async queue
            self._loop.call_soon_threadsafe(
                self._audio_queue.put_nowait, evt.result.audio_data
            )

    def _on_completed(self, _evt: Any) -> None:
        """Called when synthesis is complete (runs in Azure SDK thread)."""
        if self._loop and not self._closed:
            self._loop.call_soon_threadsafe(self._audio_queue.put_nowait, None)

    async def send_text(self, text: str) -> None:
        """Send text to be synthesized using SSML for prosody control."""
        if self._synthesizer and not self._closed:
            # Build SSML with prosody for speed control
            rate = f"{int((self.speed - 1) * 100):+d}%"
            escaped_text = escape(text)
            ssml = f"""<speak version='1.0' xmlns='{SSML_NAMESPACE}' xml:lang='en-US'>
<voice name={quoteattr(self.voice)}>
<prosody rate='{rate}'>{escaped_text}</prosody>
</voice>
</speak>"""
            # Use speak_ssml_async for SSML support (includes speed/prosody)
            self._synthesizer.speak_ssml_async(ssml)

    async def receive_audio(self) -> bytes | None:
        """Receive next audio chunk."""
        try:
            return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return b""  # No audio yet, but not done

    async def flush(self) -> None:
        """Signal end of text input - wait for pending audio."""
        # Azure SDK handles flushing automatically

    async def close(self) -> None:
        """Close the session."""
        self._closed = True
        if self._synthesizer:
            self._synthesizer.synthesis_completed.disconnect_all()
            self._synthesizer.synthesizing.disconnect_all()
        self._loop = None
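
# --- Illustrative note (not part of the original module) ---
# Example of the SSML produced by AzureStreamingSynthesizer.send_text for
# voice="en-US-JennyNeural" and speed=1.25 (the rate formula yields "+25%");
# the spoken text is a made-up sample:
#
#   <speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>
#   <voice name="en-US-JennyNeural">
#   <prosody rate='+25%'>Hello there.</prosody>
#   </voice>
#   </speak>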
|
||||
|
||||
|
||||
class AzureVoiceProvider(VoiceProviderInterface):
|
||||
"""Azure Speech Services voice provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None,
|
||||
api_base: str | None,
|
||||
custom_config: dict[str, Any],
|
||||
stt_model: str | None = None,
|
||||
tts_model: str | None = None,
|
||||
default_voice: str | None = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.custom_config = custom_config
|
||||
raw_speech_region = (
|
||||
custom_config.get("speech_region")
|
||||
or self._extract_speech_region_from_uri(api_base)
|
||||
or ""
|
||||
)
|
||||
self.speech_region = self._validate_speech_region(raw_speech_region)
|
||||
self.stt_model = stt_model
|
||||
self.tts_model = tts_model
|
||||
self.default_voice = default_voice or "en-US-JennyNeural"
|
||||
|
||||
@staticmethod
|
||||
def _is_azure_cloud_url(uri: str | None) -> bool:
|
||||
"""Check if URI is an Azure cloud endpoint (vs custom/self-hosted)."""
|
||||
if not uri:
|
||||
return False
|
||||
try:
|
||||
hostname = (urlparse(uri).hostname or "").lower()
|
||||
except ValueError:
|
||||
return False
|
||||
return hostname.endswith(
|
||||
(
|
||||
".speech.microsoft.com",
|
||||
".api.cognitive.microsoft.com",
|
||||
".cognitiveservices.azure.com",
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_speech_region_from_uri(uri: str | None) -> str | None:
|
||||
"""Extract Azure speech region from endpoint URI.
|
||||
|
||||
Note: Custom domains (*.cognitiveservices.azure.com) contain the resource
|
||||
name, not the region. For custom domains, the region must be specified
|
||||
explicitly via custom_config["speech_region"].
|
||||
"""
|
||||
if not uri:
|
||||
return None
|
||||
# Accepted examples:
|
||||
# - https://eastus.tts.speech.microsoft.com/cognitiveservices/v1
|
||||
# - https://eastus.stt.speech.microsoft.com/speech/recognition/...
|
||||
# - https://westus.api.cognitive.microsoft.com/
|
||||
#
|
||||
# NOT supported (requires explicit speech_region config):
|
||||
# - https://<resource>.cognitiveservices.azure.com/ (resource name != region)
|
||||
try:
|
||||
hostname = (urlparse(uri).hostname or "").lower()
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
stt_tts_match = re.match(
|
||||
r"^([a-z0-9-]+)\.(?:tts|stt)\.speech\.microsoft\.com$", hostname
|
||||
)
|
||||
if stt_tts_match:
|
||||
return stt_tts_match.group(1)
|
||||
|
||||
api_match = re.match(
|
||||
r"^([a-z0-9-]+)\.api\.cognitive\.microsoft\.com$", hostname
|
||||
)
|
||||
if api_match:
|
||||
return api_match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _validate_speech_region(speech_region: str) -> str:
|
||||
normalized_region = speech_region.strip().lower()
|
||||
if not normalized_region:
|
||||
return ""
|
||||
if not re.fullmatch(r"[a-z0-9-]+", normalized_region):
|
||||
raise ValueError(
|
||||
"Invalid Azure speech_region. Use lowercase letters, digits, and hyphens only."
|
||||
)
|
||||
return normalized_region
|
||||
|
||||
def _get_stt_url(self) -> str:
|
||||
"""Get the STT endpoint URL (auto-detects cloud vs self-hosted)."""
|
||||
if self.api_base and not self._is_azure_cloud_url(self.api_base):
|
||||
# Self-hosted container endpoint
|
||||
return f"{self.api_base.rstrip('/')}/speech/recognition/conversation/cognitiveservices/v1"
|
||||
# Azure cloud endpoint
|
||||
return (
|
||||
f"https://{self.speech_region}.stt.speech.microsoft.com/"
|
||||
"speech/recognition/conversation/cognitiveservices/v1"
|
||||
)
|
||||
|
||||
def _get_tts_url(self) -> str:
|
||||
"""Get the TTS endpoint URL (auto-detects cloud vs self-hosted)."""
|
||||
if self.api_base and not self._is_azure_cloud_url(self.api_base):
|
||||
# Self-hosted container endpoint
|
||||
return f"{self.api_base.rstrip('/')}/cognitiveservices/v1"
|
||||
# Azure cloud endpoint
|
||||
return f"https://{self.speech_region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
|
||||
def _is_self_hosted(self) -> bool:
|
||||
"""Check if using self-hosted container vs Azure cloud."""
|
||||
return bool(self.api_base and not self._is_azure_cloud_url(self.api_base))
|
||||
|
||||
@staticmethod
|
||||
def _pcm16_to_wav(pcm_data: bytes, sample_rate: int = 24000) -> bytes:
|
||||
"""Wrap raw PCM16 mono bytes into a WAV container."""
|
||||
buffer = io.BytesIO()
|
||||
with wave.open(buffer, "wb") as wav_file:
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setframerate(sample_rate)
|
||||
wav_file.writeframes(pcm_data)
|
||||
return buffer.getvalue()
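
    # Illustrative sketch (not part of the original file): how _pcm16_to_wav
    # is meant to be used. One second of 24kHz mono PCM16 silence is
    # 24000 samples * 2 bytes = 48000 bytes; the wave module prepends a
    # standard 44-byte PCM header, so the wrapped payload is 48044 bytes.
    #
    #     silence = b"\x00\x00" * 24000
    #     wav_bytes = AzureVoiceProvider._pcm16_to_wav(silence, sample_rate=24000)
    #     assert len(wav_bytes) == 44 + len(silence)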

    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        if not self.api_key:
            raise ValueError("Azure API key required for STT")
        if not self._is_self_hosted() and not self.speech_region:
            raise ValueError("Azure speech region required for STT (cloud mode)")

        normalized_format = audio_format.lower()
        payload = audio_data
        content_type = f"audio/{normalized_format}"

        # WebSocket chunked fallback sends raw PCM16 bytes.
        if normalized_format in {"pcm", "pcm16", "raw"}:
            payload = self._pcm16_to_wav(audio_data, sample_rate=24000)
            content_type = "audio/wav"
        elif normalized_format in {"wav", "wave"}:
            content_type = "audio/wav"
        elif normalized_format == "webm":
            content_type = "audio/webm; codecs=opus"

        url = self._get_stt_url()
        params = {"language": "en-US", "format": "detailed"}
        headers = {
            "Ocp-Apim-Subscription-Key": self.api_key,
            "Content-Type": content_type,
            "Accept": "application/json",
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(
                url, params=params, headers=headers, data=payload
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(f"Azure STT failed: {error_text}")
                result = await response.json()

                if result.get("RecognitionStatus") != "Success":
                    return ""
                nbest = result.get("NBest") or []
                if nbest and isinstance(nbest, list):
                    display = nbest[0].get("Display")
                    if isinstance(display, str):
                        return display
                display_text = result.get("DisplayText", "")
                return display_text if isinstance(display_text, str) else ""

    async def synthesize_stream(
        self, text: str, voice: str | None = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio using Azure TTS with streaming.

        Args:
            text: Text to convert to speech
            voice: Voice name (defaults to provider's default voice)
            speed: Playback speed multiplier (0.5 to 2.0)

        Yields:
            Audio data chunks (mp3 format)
        """
        if not self.api_key:
            raise ValueError("Azure API key required for TTS")

        if not self._is_self_hosted() and not self.speech_region:
            raise ValueError("Azure speech region required for TTS (cloud mode)")

        voice_name = voice or self.default_voice

        # Clamp speed to valid range and convert to rate format
        speed = max(0.5, min(2.0, speed))
        rate = f"{int((speed - 1) * 100):+d}%"  # e.g., 1.0 -> "+0%", 1.5 -> "+50%"

        # Build SSML with escaped text and quoted attributes to prevent injection
        escaped_text = escape(text)
        ssml = f"""<speak version='1.0' xmlns='{SSML_NAMESPACE}' xml:lang='en-US'>
            <voice name={quoteattr(voice_name)}>
                <prosody rate='{rate}'>{escaped_text}</prosody>
            </voice>
        </speak>"""

        url = self._get_tts_url()

        headers = {
            "Ocp-Apim-Subscription-Key": self.api_key,
            "Content-Type": "application/ssml+xml",
            "X-Microsoft-OutputFormat": "audio-16khz-128kbitrate-mono-mp3",
            "User-Agent": "Onyx",
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, data=ssml) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(f"Azure TTS failed: {error_text}")

                # Use 8192 byte chunks for smoother streaming
                async for chunk in response.content.iter_chunked(8192):
                    if chunk:
                        yield chunk

    async def validate_credentials(self) -> None:
        """Validate Azure credentials by listing available voices."""
        if not self.api_key:
            raise ValueError("Azure API key required")
        if not self._is_self_hosted() and not self.speech_region:
            raise ValueError("Azure speech region required (cloud mode)")

        url = f"https://{self.speech_region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
        if self._is_self_hosted():
            url = f"{(self.api_base or '').rstrip('/')}/cognitiveservices/voices/list"

        headers = {"Ocp-Apim-Subscription-Key": self.api_key}
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(
                        f"Azure credential validation failed: {error_text}"
                    )

    def get_available_voices(self) -> list[dict[str, str]]:
        """Return common Azure Neural voices."""
        return AZURE_VOICES.copy()

    def get_available_stt_models(self) -> list[dict[str, str]]:
        return [
            {"id": "default", "name": "Azure Speech Recognition"},
        ]

    def get_available_tts_models(self) -> list[dict[str, str]]:
        return [
            {"id": "neural", "name": "Neural TTS"},
        ]

    def supports_streaming_stt(self) -> bool:
        """Azure supports streaming STT via Speech SDK."""
        return True

    def supports_streaming_tts(self) -> bool:
        """Azure supports real-time streaming TTS via Speech SDK."""
        return True

    async def create_streaming_transcriber(
        self, _audio_format: str = "webm"
    ) -> AzureStreamingTranscriber:
        """Create a streaming transcription session."""
        if not self.api_key:
            raise ValueError("API key required for streaming transcription")
        if not self._is_self_hosted() and not self.speech_region:
            raise ValueError(
                "Speech region required for Azure streaming transcription (cloud mode)"
            )

        # Use endpoint for self-hosted, region for cloud
        transcriber = AzureStreamingTranscriber(
            api_key=self.api_key,
            region=self.speech_region if not self._is_self_hosted() else None,
            endpoint=self.api_base if self._is_self_hosted() else None,
            input_sample_rate=24000,
            target_sample_rate=16000,
        )
        await transcriber.connect()
        return transcriber

    async def create_streaming_synthesizer(
        self, voice: str | None = None, speed: float = 1.0
    ) -> AzureStreamingSynthesizer:
        """Create a streaming TTS session."""
        if not self.api_key:
            raise ValueError("API key required for streaming TTS")
        if not self._is_self_hosted() and not self.speech_region:
            raise ValueError(
                "Speech region required for Azure streaming TTS (cloud mode)"
            )

        # Use endpoint for self-hosted, region for cloud
        synthesizer = AzureStreamingSynthesizer(
            api_key=self.api_key,
            region=self.speech_region if not self._is_self_hosted() else None,
            endpoint=self.api_base if self._is_self_hosted() else None,
            voice=voice or self.default_voice or "en-US-JennyNeural",
            speed=speed,
        )
        await synthesizer.connect()
        return synthesizer
@@ -1,858 +0,0 @@
"""ElevenLabs voice provider for STT and TTS.

ElevenLabs supports:
- **STT**: Scribe API (batch via REST, streaming via WebSocket with Scribe v2 Realtime).
  The streaming endpoint sends base64-encoded PCM16 audio chunks and receives JSON
  transcript messages (partial_transcript, committed_transcript, utterance_end).
- **TTS**: Text-to-speech via REST streaming and WebSocket stream-input.
  The WebSocket variant accepts incremental text chunks and returns audio in order,
  enabling low-latency playback before the full text is available.

See https://elevenlabs.io/docs for API reference.
"""

import asyncio
import base64
import json
from collections.abc import AsyncIterator
from enum import StrEnum
from typing import Any

import aiohttp

from onyx.voice.interface import StreamingSynthesizerProtocol
from onyx.voice.interface import StreamingTranscriberProtocol
from onyx.voice.interface import TranscriptResult
from onyx.voice.interface import VoiceProviderInterface

# Default ElevenLabs API base URL
DEFAULT_ELEVENLABS_API_BASE = "https://api.elevenlabs.io"

# Default sample rates for STT streaming
DEFAULT_INPUT_SAMPLE_RATE = 24000  # What the browser frontend sends
DEFAULT_TARGET_SAMPLE_RATE = 16000  # What ElevenLabs Scribe expects

# Default streaming TTS output format
DEFAULT_TTS_OUTPUT_FORMAT = "mp3_44100_64"

# Default TTS voice settings
DEFAULT_VOICE_STABILITY = 0.5
DEFAULT_VOICE_SIMILARITY_BOOST = 0.75

# Chunk length schedule for streaming TTS (optimized for real-time playback)
DEFAULT_CHUNK_LENGTH_SCHEDULE = [120, 160, 250, 290]

# Default STT streaming VAD configuration
DEFAULT_VAD_SILENCE_THRESHOLD_SECS = 1.0
DEFAULT_VAD_THRESHOLD = 0.4
DEFAULT_MIN_SPEECH_DURATION_MS = 100
DEFAULT_MIN_SILENCE_DURATION_MS = 300


class ElevenLabsSTTMessageType(StrEnum):
    """Message types from ElevenLabs Scribe Realtime STT API."""

    SESSION_STARTED = "session_started"
    PARTIAL_TRANSCRIPT = "partial_transcript"
    COMMITTED_TRANSCRIPT = "committed_transcript"
    UTTERANCE_END = "utterance_end"
    SESSION_ENDED = "session_ended"
    ERROR = "error"


class ElevenLabsTTSMessageType(StrEnum):
    """Message types from ElevenLabs stream-input TTS API."""

    AUDIO = "audio"
    ERROR = "error"


def _http_to_ws_url(http_url: str) -> str:
    """Convert http(s) URL to ws(s) URL for WebSocket connections."""
    if http_url.startswith("https://"):
        return "wss://" + http_url[8:]
    elif http_url.startswith("http://"):
        return "ws://" + http_url[7:]
    return http_url
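
# Illustrative sketch (not part of the original file): expected behavior of
# _http_to_ws_url. Only the scheme prefix is rewritten; host, port, path, and
# query survive as-is, and non-HTTP inputs pass through unchanged.
#
#     assert _http_to_ws_url("https://api.elevenlabs.io") == "wss://api.elevenlabs.io"
#     assert _http_to_ws_url("http://localhost:8080/x?y=1") == "ws://localhost:8080/x?y=1"
#     assert _http_to_ws_url("wss://already-ws") == "wss://already-ws"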


# Common ElevenLabs voices
ELEVENLABS_VOICES = [
    {"id": "21m00Tcm4TlvDq8ikWAM", "name": "Rachel"},
    {"id": "AZnzlk1XvdvUeBnXmlld", "name": "Domi"},
    {"id": "EXAVITQu4vr4xnSDxMaL", "name": "Bella"},
    {"id": "ErXwobaYiN019PkySvjV", "name": "Antoni"},
    {"id": "MF3mGyEYCl7XYWbV9V6O", "name": "Elli"},
    {"id": "TxGEqnHWrfWFTfGW9XjX", "name": "Josh"},
    {"id": "VR6AewLTigWG4xSOukaG", "name": "Arnold"},
    {"id": "pNInz6obpgDQGcFmaJgB", "name": "Adam"},
    {"id": "yoZ06aMxZJJ28mfd3POQ", "name": "Sam"},
]


class ElevenLabsStreamingTranscriber(StreamingTranscriberProtocol):
    """Streaming transcription session using ElevenLabs Scribe Realtime API."""

    def __init__(
        self,
        api_key: str,
        model: str = "scribe_v2_realtime",
        input_sample_rate: int = DEFAULT_INPUT_SAMPLE_RATE,
        target_sample_rate: int = DEFAULT_TARGET_SAMPLE_RATE,
        language_code: str = "en",
        api_base: str | None = None,
    ):
        # Import logger first
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()

        self._logger.info(
            f"ElevenLabsStreamingTranscriber: initializing with model {model}"
        )
        self.api_key = api_key
        self.model = model
        self.input_sample_rate = input_sample_rate
        self.target_sample_rate = target_sample_rate
        self.language_code = language_code
        self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._session: aiohttp.ClientSession | None = None
        self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
        self._final_transcript = ""
        self._receive_task: asyncio.Task | None = None
        self._closed = False

    async def connect(self) -> None:
        """Establish WebSocket connection to ElevenLabs."""
        self._logger.info(
            "ElevenLabsStreamingTranscriber: connecting to ElevenLabs API"
        )
        self._session = aiohttp.ClientSession()

        # VAD is configured via query parameters.
        # commit_strategy=vad enables automatic transcript commit on silence detection.
        # These params are part of the ElevenLabs Scribe Realtime API contract:
        # https://elevenlabs.io/docs/api-reference/speech-to-text/realtime
        ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
        url = (
            f"{ws_base}/v1/speech-to-text/realtime"
            f"?model_id={self.model}"
            f"&sample_rate={self.target_sample_rate}"
            f"&language_code={self.language_code}"
            f"&commit_strategy=vad"
            f"&vad_silence_threshold_secs={DEFAULT_VAD_SILENCE_THRESHOLD_SECS}"
            f"&vad_threshold={DEFAULT_VAD_THRESHOLD}"
            f"&min_speech_duration_ms={DEFAULT_MIN_SPEECH_DURATION_MS}"
            f"&min_silence_duration_ms={DEFAULT_MIN_SILENCE_DURATION_MS}"
        )
        self._logger.info(
            f"ElevenLabsStreamingTranscriber: connecting to {url} "
            f"(input={self.input_sample_rate}Hz, target={self.target_sample_rate}Hz)"
        )

        try:
            self._ws = await self._session.ws_connect(
                url,
                headers={"xi-api-key": self.api_key},
            )
            self._logger.info(
                f"ElevenLabsStreamingTranscriber: connected successfully, "
                f"ws.closed={self._ws.closed}, close_code={self._ws.close_code}"
            )
        except Exception as e:
            self._logger.error(
                f"ElevenLabsStreamingTranscriber: failed to connect: {e}"
            )
            if self._session:
                await self._session.close()
            raise

        # Start receiving transcripts in background
        self._receive_task = asyncio.create_task(self._receive_loop())

    async def _receive_loop(self) -> None:
        """Background task to receive transcripts from WebSocket."""
        self._logger.info("ElevenLabsStreamingTranscriber: receive loop started")
        if not self._ws:
            self._logger.warning(
                "ElevenLabsStreamingTranscriber: no WebSocket connection"
            )
            return

        try:
            async for msg in self._ws:
                self._logger.debug(
                    f"ElevenLabsStreamingTranscriber: raw message type: {msg.type}"
                )
                if msg.type == aiohttp.WSMsgType.TEXT:
                    parsed_data: Any = None
                    data: dict[str, Any]
                    try:
                        parsed_data = json.loads(msg.data)
                    except json.JSONDecodeError:
                        self._logger.error(
                            f"ElevenLabsStreamingTranscriber: failed to parse JSON: {msg.data[:200]}"
                        )
                        continue
                    if not isinstance(parsed_data, dict):
                        self._logger.error(
                            "ElevenLabsStreamingTranscriber: expected object JSON payload"
                        )
                        continue
                    data = parsed_data

                    # ElevenLabs uses message_type field - fail fast if missing
                    if "message_type" not in data and "type" not in data:
                        self._logger.error(
                            f"ElevenLabsStreamingTranscriber: malformed packet missing 'message_type' field: {data}"
                        )
                        continue
                    msg_type = data.get("message_type", data.get("type", ""))
                    self._logger.info(
                        f"ElevenLabsStreamingTranscriber: received message_type: '{msg_type}', data keys: {list(data.keys())}"
                    )
                    # Check for error in various formats
                    if "error" in data or msg_type == ElevenLabsSTTMessageType.ERROR:
                        error_msg = data.get("error", data.get("message", data))
                        self._logger.error(
                            f"ElevenLabsStreamingTranscriber: API error: {error_msg}"
                        )
                        continue

                    # Handle message types from ElevenLabs Scribe Realtime API.
                    # See https://elevenlabs.io/docs/api-reference/speech-to-text/realtime
                    if msg_type == ElevenLabsSTTMessageType.SESSION_STARTED:
                        self._logger.info(
                            f"ElevenLabsStreamingTranscriber: session started, "
                            f"id={data.get('session_id')}, config={data.get('config')}"
                        )
                    elif msg_type == ElevenLabsSTTMessageType.PARTIAL_TRANSCRIPT:
                        # Interim result — updated as more audio is processed
                        text = data.get("text", "")
                        if text:
                            self._logger.info(
                                f"ElevenLabsStreamingTranscriber: partial_transcript: {text[:50]}..."
                            )
                            self._final_transcript = text
                            await self._transcript_queue.put(
                                TranscriptResult(text=text, is_vad_end=False)
                            )
                    elif msg_type == ElevenLabsSTTMessageType.COMMITTED_TRANSCRIPT:
                        # Final transcript for the current utterance (VAD detected end)
                        text = data.get("text", "")
                        if text:
                            self._logger.info(
                                f"ElevenLabsStreamingTranscriber: committed_transcript: {text[:50]}..."
                            )
                            self._final_transcript = text
                            await self._transcript_queue.put(
                                TranscriptResult(text=text, is_vad_end=True)
                            )
                    elif msg_type == ElevenLabsSTTMessageType.UTTERANCE_END:
                        # VAD detected end of speech (may carry text or be empty)
                        text = data.get("text", "") or self._final_transcript
                        if text:
                            self._logger.info(
                                f"ElevenLabsStreamingTranscriber: utterance_end: {text[:50]}..."
                            )
                            self._final_transcript = text
                            await self._transcript_queue.put(
                                TranscriptResult(text=text, is_vad_end=True)
                            )
                    elif msg_type == ElevenLabsSTTMessageType.SESSION_ENDED:
                        self._logger.info(
                            "ElevenLabsStreamingTranscriber: session ended"
                        )
                        break
                    else:
                        # Log unhandled message types with full data for debugging
                        self._logger.warning(
                            f"ElevenLabsStreamingTranscriber: unhandled message_type: {msg_type}, full data: {data}"
                        )
                elif msg.type == aiohttp.WSMsgType.BINARY:
                    self._logger.debug(
                        f"ElevenLabsStreamingTranscriber: received binary message: {len(msg.data)} bytes"
                    )
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    close_code = self._ws.close_code if self._ws else "N/A"
                    self._logger.info(
                        "ElevenLabsStreamingTranscriber: WebSocket closed by "
                        f"server, close_code={close_code}"
                    )
                    break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    self._logger.error(
                        f"ElevenLabsStreamingTranscriber: WebSocket error: {self._ws.exception() if self._ws else 'N/A'}"
                    )
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSE:
                    self._logger.info(
                        f"ElevenLabsStreamingTranscriber: WebSocket CLOSE frame received, data={msg.data}, extra={msg.extra}"
                    )
                    break
        except Exception as e:
            self._logger.error(
                f"ElevenLabsStreamingTranscriber: error in receive loop: {e}",
                exc_info=True,
            )
        finally:
            close_code = self._ws.close_code if self._ws else "N/A"
            self._logger.info(
                f"ElevenLabsStreamingTranscriber: receive loop ended, close_code={close_code}"
            )
            await self._transcript_queue.put(None)  # Signal end

    def _resample_pcm16(self, data: bytes) -> bytes:
        """Resample PCM16 audio from input_sample_rate to target_sample_rate."""
        import struct

        if self.input_sample_rate == self.target_sample_rate:
            return data

        # Parse int16 samples
        num_samples = len(data) // 2
        samples = list(struct.unpack(f"<{num_samples}h", data))

        # Calculate resampling ratio
        ratio = self.input_sample_rate / self.target_sample_rate
        new_length = int(num_samples / ratio)

        # Linear interpolation resampling
        resampled = []
        for i in range(new_length):
            src_idx = i * ratio
            idx_floor = int(src_idx)
            idx_ceil = min(idx_floor + 1, num_samples - 1)
            frac = src_idx - idx_floor
            sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
            # Clamp to int16 range
            sample = max(-32768, min(32767, sample))
            resampled.append(sample)

        return struct.pack(f"<{len(resampled)}h", *resampled)
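
    # Hypothetical alternative (not part of the original file): the pure-Python
    # loop above pays interpreter overhead per sample. If numpy were available
    # in this environment (an assumption - this module does not import it), the
    # same linear-interpolation resampler could be vectorized roughly like so:
    #
    #     import numpy as np
    #
    #     def _resample_pcm16_np(data: bytes, in_rate: int, out_rate: int) -> bytes:
    #         samples = np.frombuffer(data, dtype=np.int16).astype(np.float64)
    #         n_out = int(len(samples) * out_rate / in_rate)
    #         src = np.arange(n_out) * (in_rate / out_rate)
    #         lo = src.astype(int)
    #         hi = np.minimum(lo + 1, len(samples) - 1)
    #         frac = src - lo
    #         out = samples[lo] * (1 - frac) + samples[hi] * frac
    #         return np.clip(out, -32768, 32767).astype(np.int16).tobytes()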

    async def send_audio(self, chunk: bytes) -> None:
        """Send an audio chunk for transcription."""
        if not self._ws:
            self._logger.warning("send_audio: no WebSocket connection")
            return
        if self._closed:
            self._logger.warning("send_audio: transcriber is closed")
            return
        if self._ws.closed:
            self._logger.warning(
                f"send_audio: WebSocket is closed, close_code={self._ws.close_code}"
            )
            return

        try:
            # Resample from input rate (24kHz) to target rate (16kHz)
            resampled = self._resample_pcm16(chunk)
            # ElevenLabs expects input_audio_chunk message format with audio_base_64
            audio_b64 = base64.b64encode(resampled).decode("utf-8")
            message = {
                "message_type": "input_audio_chunk",
                "audio_base_64": audio_b64,
                "sample_rate": self.target_sample_rate,
            }
            self._logger.info(
                f"send_audio: {len(chunk)} bytes -> {len(resampled)} bytes (resampled) -> {len(audio_b64)} chars base64"
            )
            await self._ws.send_str(json.dumps(message))
            self._logger.info("send_audio: message sent successfully")
        except Exception as e:
            self._logger.error(f"send_audio: failed to send: {e}", exc_info=True)
            raise

    async def receive_transcript(self) -> TranscriptResult | None:
        """Receive next transcript. Returns None when done."""
        try:
            return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return TranscriptResult(
                text="", is_vad_end=False
            )  # No transcript yet, but not done

    async def close(self) -> str:
        """Close the session and return final transcript."""
        self._logger.info("ElevenLabsStreamingTranscriber: closing session")
        self._closed = True
        if self._ws and not self._ws.closed:
            try:
                # Just close the WebSocket - ElevenLabs Scribe doesn't need a special end message
                self._logger.info(
                    "ElevenLabsStreamingTranscriber: closing WebSocket connection"
                )
                await self._ws.close()
            except Exception as e:
                self._logger.debug(f"Error closing WebSocket: {e}")
        if self._receive_task and not self._receive_task.done():
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass
        if self._session and not self._session.closed:
            await self._session.close()
        return self._final_transcript

    def reset_transcript(self) -> None:
        """Reset accumulated transcript. Call after auto-send to start fresh."""
        self._final_transcript = ""


class ElevenLabsStreamingSynthesizer(StreamingSynthesizerProtocol):
    """Real-time streaming TTS using ElevenLabs WebSocket API.

    Uses ElevenLabs' stream-input WebSocket which processes text as one
    continuous stream and returns audio in order.
    """

    def __init__(
        self,
        api_key: str,
        voice_id: str,
        model_id: str = "eleven_multilingual_v2",
        output_format: str = "mp3_44100_64",
        api_base: str | None = None,
        speed: float = 1.0,
    ):
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()
        self.api_key = api_key
        self.voice_id = voice_id
        self.model_id = model_id
        self.output_format = output_format
        self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
        self.speed = speed
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._session: aiohttp.ClientSession | None = None
        self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
        self._receive_task: asyncio.Task | None = None
        self._closed = False

    async def connect(self) -> None:
        """Establish WebSocket connection to ElevenLabs TTS."""
        self._logger.info("ElevenLabsStreamingSynthesizer: connecting")
        self._session = aiohttp.ClientSession()

        # WebSocket URL for streaming input TTS with output format for streaming compatibility
        # Using mp3_44100_64 for good quality with smaller chunks for real-time playback
        ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
        url = (
            f"{ws_base}/v1/text-to-speech/{self.voice_id}/stream-input"
            f"?model_id={self.model_id}&output_format={self.output_format}"
        )

        self._ws = await self._session.ws_connect(
            url,
            headers={"xi-api-key": self.api_key},
        )

        # Send initial configuration with generation settings optimized for streaming.
        # Note: API key is sent via header only (not in body to avoid log exposure).
        # See https://elevenlabs.io/docs/api-reference/text-to-speech/stream-input
        await self._ws.send_str(
            json.dumps(
                {
                    "text": " ",  # Initial space to start the stream
                    "voice_settings": {
                        "stability": DEFAULT_VOICE_STABILITY,
                        "similarity_boost": DEFAULT_VOICE_SIMILARITY_BOOST,
                        "speed": self.speed,
                    },
                    "generation_config": {
                        "chunk_length_schedule": DEFAULT_CHUNK_LENGTH_SCHEDULE,
                    },
                }
            )
        )

        # Start receiving audio in background
        self._receive_task = asyncio.create_task(self._receive_loop())
        self._logger.info("ElevenLabsStreamingSynthesizer: connected")

    async def _receive_loop(self) -> None:
        """Background task to receive audio chunks from WebSocket.

        Audio is returned in order as one continuous stream.
        """
        if not self._ws:
            return

        chunk_count = 0
        total_bytes = 0
        try:
            async for msg in self._ws:
                if self._closed:
                    self._logger.info(
                        "ElevenLabsStreamingSynthesizer: closed flag set, stopping "
                        "receive loop"
                    )
                    break
                if msg.type == aiohttp.WSMsgType.TEXT:
                    data = json.loads(msg.data)
                    # Process audio if present
                    if "audio" in data and data["audio"]:
                        audio_bytes = base64.b64decode(data["audio"])
                        chunk_count += 1
                        total_bytes += len(audio_bytes)
                        await self._audio_queue.put(audio_bytes)

                    # Check isFinal separately - a message can have both audio AND isFinal
                    if "isFinal" in data:
                        self._logger.info(
                            f"ElevenLabsStreamingSynthesizer: received isFinal={data['isFinal']}, "
                            f"chunks so far: {chunk_count}, bytes: {total_bytes}"
                        )
                        if data.get("isFinal"):
                            self._logger.info(
                                "ElevenLabsStreamingSynthesizer: isFinal=true, signaling end of audio"
                            )
                            await self._audio_queue.put(None)

                    # Check for errors
                    if "error" in data or data.get("type") == "error":
                        self._logger.error(
                            f"ElevenLabsStreamingSynthesizer: received error: {data}"
                        )
                elif msg.type == aiohttp.WSMsgType.BINARY:
                    chunk_count += 1
                    total_bytes += len(msg.data)
                    await self._audio_queue.put(msg.data)
                elif msg.type in (
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.ERROR,
                ):
                    self._logger.info(
                        f"ElevenLabsStreamingSynthesizer: WebSocket closed/error, type={msg.type}"
                    )
                    break
        except Exception as e:
            self._logger.error(f"ElevenLabsStreamingSynthesizer receive error: {e}")
        finally:
            self._logger.info(
                f"ElevenLabsStreamingSynthesizer: receive loop ended, {chunk_count} chunks, {total_bytes} bytes"
            )
            await self._audio_queue.put(None)  # Signal end of stream

    async def send_text(self, text: str) -> None:
        """Send text to be synthesized.

        ElevenLabs processes text as a continuous stream and returns
        audio in order. We let ElevenLabs handle buffering via chunk_length_schedule
        and only force generation when flush() is called at the end.

        Args:
            text: Text to synthesize
        """
        if self._ws and not self._closed and text.strip():
            self._logger.info(
                f"ElevenLabsStreamingSynthesizer: sending text ({len(text)} chars): '{text}'"
            )
            # Let ElevenLabs buffer and auto-generate based on chunk_length_schedule
            # Don't trigger generation here - wait for flush() at the end
            await self._ws.send_str(
                json.dumps(
                    {
                        "text": text + " ",  # Space for natural speech flow
                    }
                )
            )
            self._logger.info("ElevenLabsStreamingSynthesizer: text sent successfully")
        else:
            self._logger.warning(
                f"ElevenLabsStreamingSynthesizer: skipping send_text - "
                f"ws={self._ws is not None}, closed={self._closed}, text='{text[:30] if text else ''}'"
            )

    async def receive_audio(self) -> bytes | None:
        """Receive next audio chunk."""
        try:
            return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return b""  # No audio yet, but not done

    async def flush(self) -> None:
        """Signal end of text input. ElevenLabs will generate remaining audio and close."""
        if self._ws and not self._closed:
            # Send empty string to signal end of input
            # ElevenLabs will generate any remaining buffered text,
            # send all audio chunks, send isFinal, then close the connection
            self._logger.info(
                "ElevenLabsStreamingSynthesizer: sending end-of-input (empty string)"
            )
            await self._ws.send_str(json.dumps({"text": ""}))
            self._logger.info("ElevenLabsStreamingSynthesizer: end-of-input sent")
        else:
            self._logger.warning(
                f"ElevenLabsStreamingSynthesizer: skipping flush - "
                f"ws={self._ws is not None}, closed={self._closed}"
            )

    async def close(self) -> None:
        """Close the session."""
        self._closed = True
        if self._ws:
            await self._ws.close()
        if self._receive_task:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass
        if self._session:
            await self._session.close()
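

# Illustrative sketch (not part of the original file): the intended call
# pattern for the streaming synthesizer. receive_audio() returns b"" while
# audio is still pending and None once the isFinal sentinel has been queued,
# so a consumer loop looks roughly like this:
#
#     synth = ElevenLabsStreamingSynthesizer(api_key="...", voice_id="...")
#     await synth.connect()
#     await synth.send_text("Hello from Onyx.")
#     await synth.flush()
#     while (chunk := await synth.receive_audio()) is not None:
#         if chunk:
#             play(chunk)  # hypothetical playback sink
#     await synth.close()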


# Valid ElevenLabs model IDs
ELEVENLABS_STT_MODELS = {"scribe_v1", "scribe_v2_realtime"}
ELEVENLABS_TTS_MODELS = {
    "eleven_multilingual_v2",
    "eleven_turbo_v2_5",
    "eleven_monolingual_v1",
    "eleven_flash_v2_5",
    "eleven_flash_v2",
}


class ElevenLabsVoiceProvider(VoiceProviderInterface):
    """ElevenLabs voice provider."""

    def __init__(
        self,
        api_key: str | None,
        api_base: str | None = None,
        stt_model: str | None = None,
        tts_model: str | None = None,
        default_voice: str | None = None,
    ):
        self.api_key = api_key
        self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
        # Validate and default models - use valid ElevenLabs model IDs
        self.stt_model = (
            stt_model if stt_model in ELEVENLABS_STT_MODELS else "scribe_v1"
        )
        self.tts_model = (
            tts_model
            if tts_model in ELEVENLABS_TTS_MODELS
            else "eleven_multilingual_v2"
        )
        self.default_voice = default_voice

    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        """
        Transcribe audio using ElevenLabs Speech-to-Text API.

        Args:
            audio_data: Raw audio bytes
            audio_format: Format of the audio (e.g., 'webm', 'mp3', 'wav')

        Returns:
            Transcribed text
        """
        if not self.api_key:
            raise ValueError("ElevenLabs API key required for transcription")

        from onyx.utils.logger import setup_logger

        logger = setup_logger()

        url = f"{self.api_base}/v1/speech-to-text"

        # Map common formats to MIME types
        mime_types = {
            "webm": "audio/webm",
            "mp3": "audio/mpeg",
            "wav": "audio/wav",
            "ogg": "audio/ogg",
            "flac": "audio/flac",
            "m4a": "audio/mp4",
        }
        mime_type = mime_types.get(audio_format.lower(), f"audio/{audio_format}")

        headers = {
            "xi-api-key": self.api_key,
        }

        # ElevenLabs expects multipart form data
        form_data = aiohttp.FormData()
        form_data.add_field(
            "audio",
            audio_data,
            filename=f"audio.{audio_format}",
            content_type=mime_type,
        )
        # For batch STT, use scribe_v1 (not the realtime model)
        batch_model = (
            self.stt_model if self.stt_model in ("scribe_v1",) else "scribe_v1"
        )
        form_data.add_field("model_id", batch_model)

        logger.info(
            f"ElevenLabs transcribe: sending {len(audio_data)} bytes, format={audio_format}"
        )

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, data=form_data) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"ElevenLabs transcribe failed: {error_text}")
                    raise RuntimeError(f"ElevenLabs transcription failed: {error_text}")

                result = await response.json()
                text = result.get("text", "")
                logger.info(f"ElevenLabs transcribe: got result: {text[:50]}...")
                return text

    async def synthesize_stream(
        self, text: str, voice: str | None = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio using ElevenLabs TTS with streaming.

        Args:
            text: Text to convert to speech
            voice: Voice ID (defaults to provider's default voice or Rachel)
            speed: Playback speed multiplier

        Yields:
            Audio data chunks (mp3 format)
        """
        from onyx.utils.logger import setup_logger

        logger = setup_logger()

        if not self.api_key:
            raise ValueError("ElevenLabs API key required for TTS")

        voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM"  # Rachel

        url = f"{self.api_base}/v1/text-to-speech/{voice_id}/stream"

        logger.info(
            f"ElevenLabs TTS: starting synthesis, text='{text[:50]}...', "
            f"voice={voice_id}, model={self.tts_model}, speed={speed}"
        )

        headers = {
            "xi-api-key": self.api_key,
            "Content-Type": "application/json",
            "Accept": "audio/mpeg",
        }

        payload = {
            "text": text,
            "model_id": self.tts_model,
            "voice_settings": {
                "stability": DEFAULT_VOICE_STABILITY,
                "similarity_boost": DEFAULT_VOICE_SIMILARITY_BOOST,
                "speed": speed,
            },
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, json=payload) as response:
                logger.info(
                    f"ElevenLabs TTS: got response status={response.status}, "
                    f"content-type={response.headers.get('content-type')}"
                )
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"ElevenLabs TTS failed: {error_text}")
                    raise RuntimeError(f"ElevenLabs TTS failed: {error_text}")

                # Use 8192 byte chunks for smoother streaming
                chunk_count = 0
                total_bytes = 0
                async for chunk in response.content.iter_chunked(8192):
                    if chunk:
                        chunk_count += 1
                        total_bytes += len(chunk)
                        yield chunk
                logger.info(
                    f"ElevenLabs TTS: streaming complete, {chunk_count} chunks, "
                    f"{total_bytes} total bytes"
                )

    async def validate_credentials(self) -> None:
        """Validate ElevenLabs API key by fetching user info."""
        if not self.api_key:
            raise ValueError("ElevenLabs API key required")

        headers = {"xi-api-key": self.api_key}
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"{self.api_base}/v1/user", headers=headers
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(
                        f"ElevenLabs credential validation failed: {error_text}"
                    )

    def get_available_voices(self) -> list[dict[str, str]]:
        """Return common ElevenLabs voices."""
        return ELEVENLABS_VOICES.copy()

    def get_available_stt_models(self) -> list[dict[str, str]]:
        return [
            {"id": "scribe_v2_realtime", "name": "Scribe v2 Realtime (Streaming)"},
            {"id": "scribe_v1", "name": "Scribe v1 (Batch)"},
        ]

    def get_available_tts_models(self) -> list[dict[str, str]]:
        return [
            {"id": "eleven_multilingual_v2", "name": "Multilingual v2"},
            {"id": "eleven_turbo_v2_5", "name": "Turbo v2.5"},
            {"id": "eleven_monolingual_v1", "name": "Monolingual v1"},
        ]

    def supports_streaming_stt(self) -> bool:
        """ElevenLabs supports streaming via Scribe Realtime API."""
        return True

    def supports_streaming_tts(self) -> bool:
        """ElevenLabs supports real-time streaming TTS via WebSocket."""
        return True

    async def create_streaming_transcriber(
        self, _audio_format: str = "webm"
    ) -> ElevenLabsStreamingTranscriber:
        """Create a streaming transcription session."""
        if not self.api_key:
            raise ValueError("API key required for streaming transcription")
        # ElevenLabs realtime STT requires scribe_v2_realtime model.
        # Frontend sends PCM16 at DEFAULT_INPUT_SAMPLE_RATE (24kHz),
        # but ElevenLabs expects DEFAULT_TARGET_SAMPLE_RATE (16kHz).
        # The transcriber resamples automatically.
        transcriber = ElevenLabsStreamingTranscriber(
            api_key=self.api_key,
            model="scribe_v2_realtime",
            input_sample_rate=DEFAULT_INPUT_SAMPLE_RATE,
            target_sample_rate=DEFAULT_TARGET_SAMPLE_RATE,
            language_code="en",
            api_base=self.api_base,
        )
        await transcriber.connect()
        return transcriber

    async def create_streaming_synthesizer(
        self, voice: str | None = None, speed: float = 1.0
    ) -> ElevenLabsStreamingSynthesizer:
        """Create a streaming TTS session."""
        if not self.api_key:
            raise ValueError("API key required for streaming TTS")
        voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM"
        synthesizer = ElevenLabsStreamingSynthesizer(
            api_key=self.api_key,
            voice_id=voice_id,
            model_id=self.tts_model,
            output_format=DEFAULT_TTS_OUTPUT_FORMAT,
            api_base=self.api_base,
            speed=speed,
        )
        await synthesizer.connect()
        return synthesizer
@@ -1,626 +0,0 @@
"""OpenAI voice provider for STT and TTS.

OpenAI supports:
- **STT**: Whisper (batch transcription via REST) and Realtime API (streaming
  transcription via WebSocket with server-side VAD). Audio is sent as base64-encoded
  PCM16 at 24kHz mono. The Realtime API returns transcript deltas and completed
  transcription events per VAD-detected utterance.
- **TTS**: HTTP streaming endpoint that returns audio chunks progressively.
  Supported models: tts-1 (standard) and tts-1-hd (high quality).

See https://platform.openai.com/docs for API reference.
"""

import asyncio
import base64
import io
import json
from collections.abc import AsyncIterator
from enum import StrEnum
from typing import TYPE_CHECKING

import aiohttp

from onyx.voice.interface import StreamingSynthesizerProtocol
from onyx.voice.interface import StreamingTranscriberProtocol
from onyx.voice.interface import TranscriptResult
from onyx.voice.interface import VoiceProviderInterface

if TYPE_CHECKING:
    from openai import AsyncOpenAI

# Default OpenAI API base URL
DEFAULT_OPENAI_API_BASE = "https://api.openai.com"


class OpenAIRealtimeMessageType(StrEnum):
    """Message types from OpenAI Realtime transcription API."""

    ERROR = "error"
    SPEECH_STARTED = "input_audio_buffer.speech_started"
    SPEECH_STOPPED = "input_audio_buffer.speech_stopped"
    BUFFER_COMMITTED = "input_audio_buffer.committed"
    TRANSCRIPTION_DELTA = "conversation.item.input_audio_transcription.delta"
    TRANSCRIPTION_COMPLETED = "conversation.item.input_audio_transcription.completed"
    SESSION_CREATED = "transcription_session.created"
    SESSION_UPDATED = "transcription_session.updated"
    ITEM_CREATED = "conversation.item.created"


def _http_to_ws_url(http_url: str) -> str:
    """Convert http(s) URL to ws(s) URL for WebSocket connections."""
    if http_url.startswith("https://"):
        return "wss://" + http_url[8:]
    elif http_url.startswith("http://"):
        return "ws://" + http_url[7:]
    return http_url


class OpenAIStreamingTranscriber(StreamingTranscriberProtocol):
    """Streaming transcription using OpenAI Realtime API."""

    def __init__(
        self,
        api_key: str,
        model: str = "whisper-1",
        api_base: str | None = None,
    ):
        # Import logger first
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()

        self._logger.info(
            f"OpenAIStreamingTranscriber: initializing with model {model}"
        )
        self.api_key = api_key
        self.model = model
        self.api_base = api_base or DEFAULT_OPENAI_API_BASE
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._session: aiohttp.ClientSession | None = None
        self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
        self._current_turn_transcript = ""  # Transcript for current VAD turn
        self._accumulated_transcript = ""  # Accumulated across all turns
        self._receive_task: asyncio.Task | None = None
        self._closed = False

    async def connect(self) -> None:
        """Establish WebSocket connection to OpenAI Realtime API."""
        self._session = aiohttp.ClientSession()

        # OpenAI Realtime transcription endpoint
        ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
        url = f"{ws_base}/v1/realtime?intent=transcription"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "OpenAI-Beta": "realtime=v1",
        }

        try:
            self._ws = await self._session.ws_connect(url, headers=headers)
            self._logger.info("Connected to OpenAI Realtime API")
        except Exception as e:
            self._logger.error(f"Failed to connect to OpenAI Realtime API: {e}")
            raise

        # Configure the session for transcription
        # Enable server-side VAD (Voice Activity Detection) for automatic speech detection
        config_message = {
            "type": "transcription_session.update",
            "session": {
                "input_audio_format": "pcm16",  # 16-bit PCM at 24kHz mono
                "input_audio_transcription": {
                    "model": self.model,
                },
                "turn_detection": {
                    "type": "server_vad",
                    "threshold": 0.5,
                    "prefix_padding_ms": 300,
                    "silence_duration_ms": 500,
                },
            },
        }
        await self._ws.send_str(json.dumps(config_message))
        self._logger.info(f"Sent config for model: {self.model} with server VAD")

        # Start receiving transcripts
        self._receive_task = asyncio.create_task(self._receive_loop())

    async def _receive_loop(self) -> None:
        """Background task to receive transcripts."""
        if not self._ws:
            return

        try:
            async for msg in self._ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    data = json.loads(msg.data)
                    msg_type = data.get("type", "")
                    self._logger.debug(f"Received message type: {msg_type}")

                    # Handle errors
                    if msg_type == OpenAIRealtimeMessageType.ERROR:
                        error = data.get("error", {})
                        self._logger.error(f"OpenAI error: {error}")
                        continue

                    # Handle VAD events
                    if msg_type == OpenAIRealtimeMessageType.SPEECH_STARTED:
                        self._logger.info("OpenAI: Speech started")
                        # Reset current turn transcript for new speech
                        self._current_turn_transcript = ""
                        continue
                    elif msg_type == OpenAIRealtimeMessageType.SPEECH_STOPPED:
                        self._logger.info(
                            "OpenAI: Speech stopped (VAD detected silence)"
                        )
                        continue
                    elif msg_type == OpenAIRealtimeMessageType.BUFFER_COMMITTED:
                        self._logger.info("OpenAI: Audio buffer committed")
                        continue

                    # Handle transcription events
                    if msg_type == OpenAIRealtimeMessageType.TRANSCRIPTION_DELTA:
                        delta = data.get("delta", "")
                        if delta:
                            self._logger.info(f"OpenAI: Transcription delta: {delta}")
                            self._current_turn_transcript += delta
                            # Show accumulated + current turn transcript
                            full_transcript = self._accumulated_transcript
                            if full_transcript and self._current_turn_transcript:
                                full_transcript += " "
                            full_transcript += self._current_turn_transcript
                            await self._transcript_queue.put(
                                TranscriptResult(text=full_transcript, is_vad_end=False)
                            )
                    elif msg_type == OpenAIRealtimeMessageType.TRANSCRIPTION_COMPLETED:
                        transcript = data.get("transcript", "")
                        if transcript:
                            self._logger.info(
                                f"OpenAI: Transcription completed (VAD turn end): {transcript[:50]}..."
                            )
                            # This is the final transcript for this VAD turn
                            self._current_turn_transcript = transcript
                            # Accumulate this turn's transcript
                            if self._accumulated_transcript:
                                self._accumulated_transcript += " " + transcript
                            else:
                                self._accumulated_transcript = transcript
                            # Send with is_vad_end=True to trigger auto-send
                            await self._transcript_queue.put(
                                TranscriptResult(
                                    text=self._accumulated_transcript,
                                    is_vad_end=True,
                                )
                            )
                    elif msg_type not in (
                        OpenAIRealtimeMessageType.SESSION_CREATED,
                        OpenAIRealtimeMessageType.SESSION_UPDATED,
                        OpenAIRealtimeMessageType.ITEM_CREATED,
                    ):
                        # Log any other message types we might be missing
                        self._logger.info(
                            f"OpenAI: Unhandled message type '{msg_type}': {data}"
                        )

                elif msg.type == aiohttp.WSMsgType.ERROR:
                    self._logger.error(f"WebSocket error: {self._ws.exception()}")
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    self._logger.info("WebSocket closed by server")
                    break
        except Exception as e:
            self._logger.error(f"Error in receive loop: {e}")
        finally:
            await self._transcript_queue.put(None)

    async def send_audio(self, chunk: bytes) -> None:
        """Send audio chunk to OpenAI."""
        if self._ws and not self._closed:
            # OpenAI expects base64-encoded PCM16 audio at 24kHz mono
            # PCM16 at 24kHz: 24000 samples/sec * 2 bytes/sample = 48000 bytes/sec
            # So chunk_bytes / 48000 = duration in seconds
            duration_ms = (len(chunk) / 48000) * 1000
            self._logger.debug(
                f"Sending {len(chunk)} bytes ({duration_ms:.1f}ms) of audio to OpenAI. "
                f"First 10 bytes: {chunk[:10].hex() if len(chunk) >= 10 else chunk.hex()}"
            )
            message = {
                "type": "input_audio_buffer.append",
                "audio": base64.b64encode(chunk).decode("utf-8"),
            }
            await self._ws.send_str(json.dumps(message))

    def reset_transcript(self) -> None:
        """Reset accumulated transcript. Call after auto-send to start fresh."""
        self._logger.info("OpenAI: Resetting accumulated transcript")
        self._accumulated_transcript = ""
        self._current_turn_transcript = ""

    async def receive_transcript(self) -> TranscriptResult | None:
        """Receive next transcript."""
        try:
            return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return TranscriptResult(text="", is_vad_end=False)

    async def close(self) -> str:
        """Close session and return final transcript."""
        self._closed = True
        if self._ws:
            # With server VAD, the buffer is auto-committed when speech stops.
            # But we should still commit any remaining audio and wait for transcription.
            try:
                await self._ws.send_str(
                    json.dumps({"type": "input_audio_buffer.commit"})
                )
            except Exception as e:
                self._logger.debug(f"Error sending commit (may be expected): {e}")

            # Wait for *new* transcription to arrive (up to 5 seconds)
            self._logger.info("Waiting for transcription to complete...")
            transcript_before_commit = self._accumulated_transcript
            for _ in range(50):  # 50 * 100ms = 5 seconds max
                await asyncio.sleep(0.1)
                if self._accumulated_transcript != transcript_before_commit:
                    self._logger.info(
                        f"Got final transcript: {self._accumulated_transcript[:50]}..."
                    )
                    break
            else:
                self._logger.warning("Timed out waiting for transcription")

            await self._ws.close()
        if self._receive_task:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass
        if self._session:
            await self._session.close()
        return self._accumulated_transcript
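

# Illustrative sketch (not part of the original file): the intended call
# pattern for the streaming transcriber. receive_transcript() returns an empty
# TranscriptResult while nothing new has arrived and None once the receive
# loop ends; is_vad_end=True marks a completed utterance.
#
#     t = OpenAIStreamingTranscriber(api_key="...")
#     await t.connect()
#     await t.send_audio(pcm16_chunk)  # 24kHz mono PCM16 from the browser
#     result = await t.receive_transcript()
#     if result and result.is_vad_end:
#         handle_utterance(result.text)  # hypothetical consumer
#     final_text = await t.close()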


# OpenAI available voices for TTS
OPENAI_VOICES = [
    {"id": "alloy", "name": "Alloy"},
    {"id": "echo", "name": "Echo"},
    {"id": "fable", "name": "Fable"},
    {"id": "onyx", "name": "Onyx"},
    {"id": "nova", "name": "Nova"},
    {"id": "shimmer", "name": "Shimmer"},
]

# OpenAI available STT models (all support streaming via Realtime API)
OPENAI_STT_MODELS = [
    {"id": "whisper-1", "name": "Whisper v1"},
    {"id": "gpt-4o-transcribe", "name": "GPT-4o Transcribe"},
    {"id": "gpt-4o-mini-transcribe", "name": "GPT-4o Mini Transcribe"},
]

# OpenAI available TTS models
OPENAI_TTS_MODELS = [
    {"id": "tts-1", "name": "TTS-1 (Standard)"},
    {"id": "tts-1-hd", "name": "TTS-1 HD (High Quality)"},
]


def _create_wav_header(
    data_length: int,
    sample_rate: int = 24000,
    channels: int = 1,
    bits_per_sample: int = 16,
) -> bytes:
    """Create a WAV file header for PCM audio data."""
    import struct

    byte_rate = sample_rate * channels * bits_per_sample // 8
    block_align = channels * bits_per_sample // 8

    # WAV header is 44 bytes
    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF",  # ChunkID
        36 + data_length,  # ChunkSize
        b"WAVE",  # Format
        b"fmt ",  # Subchunk1ID
        16,  # Subchunk1Size (PCM)
        1,  # AudioFormat (1 = PCM)
        channels,  # NumChannels
        sample_rate,  # SampleRate
        byte_rate,  # ByteRate
        block_align,  # BlockAlign
        bits_per_sample,  # BitsPerSample
        b"data",  # Subchunk2ID
        data_length,  # Subchunk2Size
    )
    return header
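

# Illustrative sketch (not part of the original file): prepending the header
# turns raw PCM16 into a playable WAV byte string. For one second of 24kHz
# mono silence (48000 bytes of zeros) the result is 44 + 48000 bytes, since
# struct.calcsize("<4sI4s4sIHHIIHH4sI") == 44.
#
#     pcm = b"\x00\x00" * 24000
#     wav_bytes = _create_wav_header(len(pcm)) + pcm
#     assert len(wav_bytes) == 44 + len(pcm)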


class OpenAIStreamingSynthesizer(StreamingSynthesizerProtocol):
    """Streaming TTS using the OpenAI HTTP TTS API with streamed responses."""

    def __init__(
        self,
        api_key: str,
        voice: str = "alloy",
        model: str = "tts-1",
        speed: float = 1.0,
        api_base: str | None = None,
    ):
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()
        self.api_key = api_key
        self.voice = voice
        self.model = model
        self.speed = max(0.25, min(4.0, speed))
        self.api_base = api_base or DEFAULT_OPENAI_API_BASE
        self._session: aiohttp.ClientSession | None = None
        self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
        self._text_queue: asyncio.Queue[str | None] = asyncio.Queue()
        self._synthesis_task: asyncio.Task | None = None
        self._closed = False
        self._flushed = False

    async def connect(self) -> None:
        """Initialize the HTTP session for TTS requests."""
        self._logger.info("OpenAIStreamingSynthesizer: connecting")
        self._session = aiohttp.ClientSession()
        # Start background task to process text queue
        self._synthesis_task = asyncio.create_task(self._process_text_queue())
        self._logger.info("OpenAIStreamingSynthesizer: connected")

    async def _process_text_queue(self) -> None:
        """Background task to process queued text for synthesis."""
        while not self._closed:
            try:
                text = await asyncio.wait_for(self._text_queue.get(), timeout=0.1)
                if text is None:
                    break
                await self._synthesize_text(text)
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break
            except Exception as e:
                self._logger.error(f"Error processing text queue: {e}")

    async def _synthesize_text(self, text: str) -> None:
        """Make an HTTP TTS request and stream audio chunks to the queue."""
        if not self._session or self._closed:
            return

        url = f"{self.api_base.rstrip('/')}/v1/audio/speech"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "voice": self.voice,
            "input": text,
            "speed": self.speed,
            "response_format": "mp3",
        }

        try:
            async with self._session.post(
                url, headers=headers, json=payload
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    self._logger.error(f"OpenAI TTS error: {error_text}")
                    return

                # Use 8192 byte chunks for smoother streaming
                # (larger chunks = more complete MP3 frames, better playback)
                async for chunk in response.content.iter_chunked(8192):
                    if self._closed:
                        break
                    if chunk:
                        await self._audio_queue.put(chunk)
        except Exception as e:
            self._logger.error(f"OpenAIStreamingSynthesizer synthesis error: {e}")

    async def send_text(self, text: str) -> None:
        """Queue text to be synthesized via HTTP streaming."""
        if not text.strip() or self._closed:
            return
        await self._text_queue.put(text)

    async def receive_audio(self) -> bytes | None:
        """Receive the next audio chunk (MP3 format)."""
        try:
            return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return b""  # No audio yet, but not done

    async def flush(self) -> None:
        """Signal end of text input; wait for synthesis to complete."""
        if self._flushed:
            return
        self._flushed = True

        # Signal end of text input
        await self._text_queue.put(None)

        # Wait for synthesis task to complete processing all text
        if self._synthesis_task and not self._synthesis_task.done():
            try:
                await asyncio.wait_for(self._synthesis_task, timeout=60.0)
            except asyncio.TimeoutError:
                self._logger.warning("OpenAIStreamingSynthesizer: flush timeout")
                self._synthesis_task.cancel()
                try:
                    await self._synthesis_task
                except asyncio.CancelledError:
                    pass
            except asyncio.CancelledError:
                pass

        # Signal end of audio stream
        await self._audio_queue.put(None)

    async def close(self) -> None:
        """Close the session."""
        if self._closed:
            return
        self._closed = True

        # Signal end of queues only if flush wasn't already called
        if not self._flushed:
            await self._text_queue.put(None)
            await self._audio_queue.put(None)

        if self._synthesis_task and not self._synthesis_task.done():
            self._synthesis_task.cancel()
            try:
                await self._synthesis_task
            except asyncio.CancelledError:
                pass

        if self._session:
            await self._session.close()


class OpenAIVoiceProvider(VoiceProviderInterface):
    """OpenAI voice provider using Whisper for STT and the TTS API for synthesis."""

    def __init__(
        self,
        api_key: str | None,
        api_base: str | None = None,
        stt_model: str | None = None,
        tts_model: str | None = None,
        default_voice: str | None = None,
    ):
        self.api_key = api_key
        self.api_base = api_base
        self.stt_model = stt_model or "whisper-1"
        self.tts_model = tts_model or "tts-1"
        self.default_voice = default_voice or "alloy"

        self._client: "AsyncOpenAI | None" = None

    def _get_client(self) -> "AsyncOpenAI":
        if self._client is None:
            from openai import AsyncOpenAI

            self._client = AsyncOpenAI(
                api_key=self.api_key,
                base_url=self.api_base,
            )
        return self._client

    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        """
        Transcribe audio using OpenAI Whisper.

        Args:
            audio_data: Raw audio bytes
            audio_format: Audio format (e.g., "webm", "wav", "mp3")

        Returns:
            Transcribed text
        """
        client = self._get_client()

        # Create a file-like object from the audio bytes
        audio_file = io.BytesIO(audio_data)
        audio_file.name = f"audio.{audio_format}"

        response = await client.audio.transcriptions.create(
            model=self.stt_model,
            file=audio_file,
        )

        return response.text
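
        # Usage sketch (illustrative; webm_bytes is a stand-in for real audio):
        #   text = await provider.transcribe(webm_bytes, "webm")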

    async def synthesize_stream(
        self, text: str, voice: str | None = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio using OpenAI TTS with streaming.

        Args:
            text: Text to convert to speech
            voice: Voice identifier (defaults to provider's default voice)
            speed: Playback speed multiplier (0.25 to 4.0)

        Yields:
            Audio data chunks (mp3 format)
        """
        client = self._get_client()

        # Clamp speed to valid range
        speed = max(0.25, min(4.0, speed))

        # Use with_streaming_response for proper async streaming
        # Using 8192 byte chunks for better streaming performance
        # (larger chunks = fewer round-trips, more complete MP3 frames)
        async with client.audio.speech.with_streaming_response.create(
            model=self.tts_model,
            voice=voice or self.default_voice,
            input=text,
            speed=speed,
            response_format="mp3",
        ) as response:
            async for chunk in response.iter_bytes(chunk_size=8192):
                yield chunk
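
        # Usage sketch (illustrative; mp3_sink is a stand-in for any writer):
        #   async for chunk in provider.synthesize_stream("Hello", speed=1.25):
        #       mp3_sink.write(chunk)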

    async def validate_credentials(self) -> None:
        """Validate the OpenAI API key by listing models."""
        client = self._get_client()
        await client.models.list()

    def get_available_voices(self) -> list[dict[str, str]]:
        """Get available OpenAI TTS voices."""
        return OPENAI_VOICES.copy()

    def get_available_stt_models(self) -> list[dict[str, str]]:
        """Get available OpenAI STT models."""
        return OPENAI_STT_MODELS.copy()

    def get_available_tts_models(self) -> list[dict[str, str]]:
        """Get available OpenAI TTS models."""
        return OPENAI_TTS_MODELS.copy()

    def supports_streaming_stt(self) -> bool:
        """OpenAI supports streaming STT via the Realtime API for all STT models."""
        return True

    def supports_streaming_tts(self) -> bool:
        """OpenAI supports streaming TTS via the HTTP streaming speech API."""
        return True

    async def create_streaming_transcriber(
        self, _audio_format: str = "webm"
    ) -> OpenAIStreamingTranscriber:
        """Create a streaming transcription session using the Realtime API."""
        if not self.api_key:
            raise ValueError("API key required for streaming transcription")
        transcriber = OpenAIStreamingTranscriber(
            api_key=self.api_key,
            model=self.stt_model,
            api_base=self.api_base,
        )
        await transcriber.connect()
        return transcriber

    async def create_streaming_synthesizer(
        self, voice: str | None = None, speed: float = 1.0
    ) -> OpenAIStreamingSynthesizer:
        """Create a streaming TTS session using the HTTP streaming API."""
        if not self.api_key:
            raise ValueError("API key required for streaming TTS")
        synthesizer = OpenAIStreamingSynthesizer(
            api_key=self.api_key,
            voice=voice or self.default_voice or "alloy",
            model=self.tts_model or "tts-1",
            speed=speed,
            api_base=self.api_base,
        )
        await synthesizer.connect()
        return synthesizer
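
    # Lifecycle sketch (illustrative, not part of the module):
    #   synth = await provider.create_streaming_synthesizer(voice="alloy")
    #   await synth.send_text("Hello!")
    #   await synth.flush()
    #   while (chunk := await synth.receive_audio()) is not None:
    #       if chunk:
    #           ...  # forward MP3 bytes to the client
    #   await synth.close()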

@@ -67,8 +67,6 @@ attrs==25.4.0
    #   zeep
authlib==1.6.7
    # via fastmcp
azure-cognitiveservices-speech==1.38.0
    # via onyx
babel==2.17.0
    # via courlan
backoff==2.2.1

@@ -752,7 +750,7 @@ pypandoc-binary==1.16.2
    # via onyx
pyparsing==3.2.5
    # via httplib2
pypdf==6.8.0
pypdf==6.7.5
    # via
    #   onyx
    #   unstructured-client

@@ -406,7 +406,7 @@ referencing==0.36.2
    #   jsonschema-specifications
regex==2025.11.3
    # via tiktoken
release-tag==0.5.2
release-tag==0.4.3
    # via onyx
reorder-python-imports-black==3.14.0
    # via onyx

@@ -1,507 +0,0 @@
"""Unit tests for onyx.db.voice module."""

from unittest.mock import MagicMock
from uuid import uuid4

import pytest

from onyx.db.models import VoiceProvider
from onyx.db.voice import deactivate_stt_provider
from onyx.db.voice import deactivate_tts_provider
from onyx.db.voice import delete_voice_provider
from onyx.db.voice import fetch_default_stt_provider
from onyx.db.voice import fetch_default_tts_provider
from onyx.db.voice import fetch_voice_provider_by_id
from onyx.db.voice import fetch_voice_provider_by_type
from onyx.db.voice import fetch_voice_providers
from onyx.db.voice import MAX_VOICE_PLAYBACK_SPEED
from onyx.db.voice import MIN_VOICE_PLAYBACK_SPEED
from onyx.db.voice import set_default_stt_provider
from onyx.db.voice import set_default_tts_provider
from onyx.db.voice import update_user_voice_settings
from onyx.db.voice import upsert_voice_provider
from onyx.error_handling.exceptions import OnyxError


def _make_voice_provider(
    id: int = 1,
    name: str = "Test Provider",
    provider_type: str = "openai",
    is_default_stt: bool = False,
    is_default_tts: bool = False,
) -> VoiceProvider:
    """Create a VoiceProvider instance for testing."""
    provider = VoiceProvider()
    provider.id = id
    provider.name = name
    provider.provider_type = provider_type
    provider.is_default_stt = is_default_stt
    provider.is_default_tts = is_default_tts
    provider.api_key = None
    provider.api_base = None
    provider.custom_config = None
    provider.stt_model = None
    provider.tts_model = None
    provider.default_voice = None
    return provider


class TestFetchVoiceProviders:
    """Tests for fetch_voice_providers."""

    def test_returns_all_providers(self, mock_db_session: MagicMock) -> None:
        providers = [
            _make_voice_provider(id=1, name="Provider A"),
            _make_voice_provider(id=2, name="Provider B"),
        ]
        mock_db_session.scalars.return_value.all.return_value = providers

        result = fetch_voice_providers(mock_db_session)

        assert result == providers
        mock_db_session.scalars.assert_called_once()

    def test_returns_empty_list_when_no_providers(
        self, mock_db_session: MagicMock
    ) -> None:
        mock_db_session.scalars.return_value.all.return_value = []

        result = fetch_voice_providers(mock_db_session)

        assert result == []


class TestFetchVoiceProviderById:
    """Tests for fetch_voice_provider_by_id."""

    def test_returns_provider_when_found(self, mock_db_session: MagicMock) -> None:
        provider = _make_voice_provider(id=1)
        mock_db_session.scalar.return_value = provider

        result = fetch_voice_provider_by_id(mock_db_session, 1)

        assert result is provider
        mock_db_session.scalar.assert_called_once()

    def test_returns_none_when_not_found(self, mock_db_session: MagicMock) -> None:
        mock_db_session.scalar.return_value = None

        result = fetch_voice_provider_by_id(mock_db_session, 999)

        assert result is None


class TestFetchDefaultProviders:
    """Tests for fetch_default_stt_provider and fetch_default_tts_provider."""

    def test_fetch_default_stt_provider_returns_provider(
        self, mock_db_session: MagicMock
    ) -> None:
        provider = _make_voice_provider(id=1, is_default_stt=True)
        mock_db_session.scalar.return_value = provider

        result = fetch_default_stt_provider(mock_db_session)

        assert result is provider

    def test_fetch_default_stt_provider_returns_none_when_no_default(
        self, mock_db_session: MagicMock
    ) -> None:
        mock_db_session.scalar.return_value = None

        result = fetch_default_stt_provider(mock_db_session)

        assert result is None

    def test_fetch_default_tts_provider_returns_provider(
        self, mock_db_session: MagicMock
    ) -> None:
        provider = _make_voice_provider(id=1, is_default_tts=True)
        mock_db_session.scalar.return_value = provider

        result = fetch_default_tts_provider(mock_db_session)

        assert result is provider

    def test_fetch_default_tts_provider_returns_none_when_no_default(
        self, mock_db_session: MagicMock
    ) -> None:
        mock_db_session.scalar.return_value = None

        result = fetch_default_tts_provider(mock_db_session)

        assert result is None


class TestFetchVoiceProviderByType:
    """Tests for fetch_voice_provider_by_type."""

    def test_returns_provider_when_found(self, mock_db_session: MagicMock) -> None:
        provider = _make_voice_provider(id=1, provider_type="openai")
        mock_db_session.scalar.return_value = provider

        result = fetch_voice_provider_by_type(mock_db_session, "openai")

        assert result is provider

    def test_returns_none_when_not_found(self, mock_db_session: MagicMock) -> None:
        mock_db_session.scalar.return_value = None

        result = fetch_voice_provider_by_type(mock_db_session, "nonexistent")

        assert result is None


class TestUpsertVoiceProvider:
    """Tests for upsert_voice_provider."""

    def test_creates_new_provider_when_no_id(self, mock_db_session: MagicMock) -> None:
        mock_db_session.flush.return_value = None
        mock_db_session.refresh.return_value = None

        upsert_voice_provider(
            db_session=mock_db_session,
            provider_id=None,
            name="New Provider",
            provider_type="openai",
            api_key="test-key",
            api_key_changed=True,
        )

        mock_db_session.add.assert_called_once()
        mock_db_session.flush.assert_called()
        added_obj = mock_db_session.add.call_args[0][0]
        assert added_obj.name == "New Provider"
        assert added_obj.provider_type == "openai"

    def test_updates_existing_provider(self, mock_db_session: MagicMock) -> None:
        existing_provider = _make_voice_provider(id=1, name="Old Name")
        mock_db_session.scalar.return_value = existing_provider
        mock_db_session.flush.return_value = None
        mock_db_session.refresh.return_value = None

        upsert_voice_provider(
            db_session=mock_db_session,
            provider_id=1,
            name="Updated Name",
            provider_type="elevenlabs",
            api_key="new-key",
            api_key_changed=True,
        )

        mock_db_session.add.assert_not_called()
        assert existing_provider.name == "Updated Name"
        assert existing_provider.provider_type == "elevenlabs"

    def test_raises_when_provider_not_found(self, mock_db_session: MagicMock) -> None:
        mock_db_session.scalar.return_value = None

        with pytest.raises(OnyxError) as exc_info:
            upsert_voice_provider(
                db_session=mock_db_session,
                provider_id=999,
                name="Test",
                provider_type="openai",
                api_key=None,
                api_key_changed=False,
            )

        assert "No voice provider with id 999" in str(exc_info.value)

    def test_does_not_update_api_key_when_not_changed(
        self, mock_db_session: MagicMock
    ) -> None:
        existing_provider = _make_voice_provider(id=1)
        existing_provider.api_key = "original-key"  # type: ignore[assignment]
        original_api_key = existing_provider.api_key
        mock_db_session.scalar.return_value = existing_provider
        mock_db_session.flush.return_value = None
        mock_db_session.refresh.return_value = None

        upsert_voice_provider(
            db_session=mock_db_session,
            provider_id=1,
            name="Test",
            provider_type="openai",
            api_key="new-key",
            api_key_changed=False,
        )

        # api_key should remain unchanged (same object reference)
        assert existing_provider.api_key is original_api_key

    def test_activates_stt_when_requested(self, mock_db_session: MagicMock) -> None:
        existing_provider = _make_voice_provider(id=1)
        mock_db_session.scalar.return_value = existing_provider
        mock_db_session.flush.return_value = None
        mock_db_session.refresh.return_value = None
        mock_db_session.execute.return_value = None

        upsert_voice_provider(
            db_session=mock_db_session,
            provider_id=1,
            name="Test",
            provider_type="openai",
            api_key=None,
            api_key_changed=False,
            activate_stt=True,
        )

        assert existing_provider.is_default_stt is True

    def test_activates_tts_when_requested(self, mock_db_session: MagicMock) -> None:
        existing_provider = _make_voice_provider(id=1)
        mock_db_session.scalar.return_value = existing_provider
        mock_db_session.flush.return_value = None
        mock_db_session.refresh.return_value = None
        mock_db_session.execute.return_value = None

        upsert_voice_provider(
            db_session=mock_db_session,
            provider_id=1,
            name="Test",
            provider_type="openai",
            api_key=None,
            api_key_changed=False,
            activate_tts=True,
        )

        assert existing_provider.is_default_tts is True


class TestDeleteVoiceProvider:
    """Tests for delete_voice_provider."""

    def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
        provider = _make_voice_provider(id=1)
        mock_db_session.scalar.return_value = provider

        delete_voice_provider(mock_db_session, 1)

        assert provider.deleted is True
        mock_db_session.flush.assert_called_once()

    def test_does_nothing_when_provider_not_found(
        self, mock_db_session: MagicMock
    ) -> None:
        mock_db_session.scalar.return_value = None

        delete_voice_provider(mock_db_session, 999)

        mock_db_session.flush.assert_not_called()


class TestSetDefaultProviders:
    """Tests for set_default_stt_provider and set_default_tts_provider."""

    def test_set_default_stt_provider_deactivates_others(
        self, mock_db_session: MagicMock
    ) -> None:
        provider = _make_voice_provider(id=1)
        mock_db_session.scalar.return_value = provider
        mock_db_session.execute.return_value = None
        mock_db_session.flush.return_value = None
        mock_db_session.refresh.return_value = None

        result = set_default_stt_provider(db_session=mock_db_session, provider_id=1)

        mock_db_session.execute.assert_called_once()
        assert result.is_default_stt is True

    def test_set_default_stt_provider_raises_when_not_found(
        self, mock_db_session: MagicMock
    ) -> None:
        mock_db_session.scalar.return_value = None

        with pytest.raises(OnyxError) as exc_info:
            set_default_stt_provider(db_session=mock_db_session, provider_id=999)

        assert "No voice provider with id 999" in str(exc_info.value)

    def test_set_default_tts_provider_deactivates_others(
        self, mock_db_session: MagicMock
    ) -> None:
        provider = _make_voice_provider(id=1)
        mock_db_session.scalar.return_value = provider
        mock_db_session.execute.return_value = None
        mock_db_session.flush.return_value = None
        mock_db_session.refresh.return_value = None

        result = set_default_tts_provider(db_session=mock_db_session, provider_id=1)

        mock_db_session.execute.assert_called_once()
        assert result.is_default_tts is True

    def test_set_default_tts_provider_updates_model_when_provided(
        self, mock_db_session: MagicMock
    ) -> None:
        provider = _make_voice_provider(id=1)
        mock_db_session.scalar.return_value = provider
        mock_db_session.execute.return_value = None
        mock_db_session.flush.return_value = None
        mock_db_session.refresh.return_value = None

        result = set_default_tts_provider(
            db_session=mock_db_session, provider_id=1, tts_model="tts-1-hd"
        )

        assert result.tts_model == "tts-1-hd"

    def test_set_default_tts_provider_raises_when_not_found(
        self, mock_db_session: MagicMock
    ) -> None:
        mock_db_session.scalar.return_value = None

        with pytest.raises(OnyxError) as exc_info:
            set_default_tts_provider(db_session=mock_db_session, provider_id=999)

        assert "No voice provider with id 999" in str(exc_info.value)


class TestDeactivateProviders:
    """Tests for deactivate_stt_provider and deactivate_tts_provider."""

    def test_deactivate_stt_provider_sets_false(
        self, mock_db_session: MagicMock
    ) -> None:
        provider = _make_voice_provider(id=1, is_default_stt=True)
        mock_db_session.scalar.return_value = provider
        mock_db_session.flush.return_value = None
        mock_db_session.refresh.return_value = None

        result = deactivate_stt_provider(db_session=mock_db_session, provider_id=1)

        assert result.is_default_stt is False

    def test_deactivate_stt_provider_raises_when_not_found(
        self, mock_db_session: MagicMock
    ) -> None:
        mock_db_session.scalar.return_value = None

        with pytest.raises(OnyxError) as exc_info:
            deactivate_stt_provider(db_session=mock_db_session, provider_id=999)

        assert "No voice provider with id 999" in str(exc_info.value)

    def test_deactivate_tts_provider_sets_false(
        self, mock_db_session: MagicMock
    ) -> None:
        provider = _make_voice_provider(id=1, is_default_tts=True)
        mock_db_session.scalar.return_value = provider
        mock_db_session.flush.return_value = None
        mock_db_session.refresh.return_value = None

        result = deactivate_tts_provider(db_session=mock_db_session, provider_id=1)

        assert result.is_default_tts is False

    def test_deactivate_tts_provider_raises_when_not_found(
        self, mock_db_session: MagicMock
    ) -> None:
        mock_db_session.scalar.return_value = None

        with pytest.raises(OnyxError) as exc_info:
            deactivate_tts_provider(db_session=mock_db_session, provider_id=999)

        assert "No voice provider with id 999" in str(exc_info.value)


class TestUpdateUserVoiceSettings:
    """Tests for update_user_voice_settings."""

    def test_updates_auto_send(self, mock_db_session: MagicMock) -> None:
        user_id = uuid4()

        update_user_voice_settings(mock_db_session, user_id, auto_send=True)

        mock_db_session.execute.assert_called_once()
        mock_db_session.flush.assert_called_once()

    def test_updates_auto_playback(self, mock_db_session: MagicMock) -> None:
        user_id = uuid4()

        update_user_voice_settings(mock_db_session, user_id, auto_playback=True)

        mock_db_session.execute.assert_called_once()
        mock_db_session.flush.assert_called_once()

    def test_updates_playback_speed_within_range(
        self, mock_db_session: MagicMock
    ) -> None:
        user_id = uuid4()

        update_user_voice_settings(mock_db_session, user_id, playback_speed=1.5)

        mock_db_session.execute.assert_called_once()

    def test_clamps_playback_speed_to_min(self, mock_db_session: MagicMock) -> None:
        user_id = uuid4()

        update_user_voice_settings(mock_db_session, user_id, playback_speed=0.1)

        mock_db_session.execute.assert_called_once()
        stmt = mock_db_session.execute.call_args[0][0]
        compiled = stmt.compile(compile_kwargs={"literal_binds": True})
        assert str(MIN_VOICE_PLAYBACK_SPEED) in str(compiled)

    def test_clamps_playback_speed_to_max(self, mock_db_session: MagicMock) -> None:
        user_id = uuid4()

        update_user_voice_settings(mock_db_session, user_id, playback_speed=5.0)

        mock_db_session.execute.assert_called_once()
        stmt = mock_db_session.execute.call_args[0][0]
        compiled = stmt.compile(compile_kwargs={"literal_binds": True})
        assert str(MAX_VOICE_PLAYBACK_SPEED) in str(compiled)

    def test_updates_multiple_settings(self, mock_db_session: MagicMock) -> None:
        user_id = uuid4()

        update_user_voice_settings(
            mock_db_session,
            user_id,
            auto_send=True,
            auto_playback=False,
            playback_speed=1.25,
        )

        mock_db_session.execute.assert_called_once()
        mock_db_session.flush.assert_called_once()

    def test_does_nothing_when_no_settings_provided(
        self, mock_db_session: MagicMock
    ) -> None:
        user_id = uuid4()

        update_user_voice_settings(mock_db_session, user_id)

        mock_db_session.execute.assert_not_called()
        mock_db_session.flush.assert_not_called()


class TestSpeedClampingLogic:
    """Tests for the speed clamping constants and logic."""

    def test_min_speed_constant(self) -> None:
        assert MIN_VOICE_PLAYBACK_SPEED == 0.5

    def test_max_speed_constant(self) -> None:
        assert MAX_VOICE_PLAYBACK_SPEED == 2.0

    def test_clamping_formula(self) -> None:
        """Verify the clamping formula used in update_user_voice_settings."""
        test_cases = [
            (0.1, MIN_VOICE_PLAYBACK_SPEED),
            (0.5, 0.5),
            (1.0, 1.0),
            (1.5, 1.5),
            (2.0, 2.0),
            (3.0, MAX_VOICE_PLAYBACK_SPEED),
        ]
        for speed, expected in test_cases:
            clamped = max(
                MIN_VOICE_PLAYBACK_SPEED, min(MAX_VOICE_PLAYBACK_SPEED, speed)
            )
            assert (
                clamped == expected
            ), f"speed={speed} expected={expected} got={clamped}"

@@ -26,6 +26,14 @@ class TestIsTrueOpenAIModel:
        """Test that real OpenAI GPT-4o-mini model is correctly identified."""
        assert is_true_openai_model(LlmProviderNames.OPENAI, "gpt-4o-mini") is True

    def test_real_openai_o1_preview(self) -> None:
        """Test that real OpenAI o1-preview reasoning model is correctly identified."""
        assert is_true_openai_model(LlmProviderNames.OPENAI, "o1-preview") is True

    def test_real_openai_o1_mini(self) -> None:
        """Test that real OpenAI o1-mini reasoning model is correctly identified."""
        assert is_true_openai_model(LlmProviderNames.OPENAI, "o1-mini") is True

    def test_openai_with_provider_prefix(self) -> None:
        """Test that a provider-prefixed model name is not a true OpenAI model."""
        assert is_true_openai_model(LlmProviderNames.OPENAI, "openai/gpt-4") is False

@@ -1,23 +0,0 @@
import pytest

from onyx.error_handling.exceptions import OnyxError
from onyx.server.manage.voice.api import _validate_voice_api_base


def test_validate_voice_api_base_blocks_private_for_non_azure() -> None:
    with pytest.raises(OnyxError, match="Invalid target URI"):
        _validate_voice_api_base("openai", "http://127.0.0.1:11434")


def test_validate_voice_api_base_allows_private_for_azure() -> None:
    validated = _validate_voice_api_base("azure", "http://127.0.0.1:5000")
    assert validated == "http://127.0.0.1:5000"


def test_validate_voice_api_base_blocks_metadata_for_azure() -> None:
    with pytest.raises(OnyxError, match="Invalid target URI"):
        _validate_voice_api_base("azure", "http://metadata.google.internal/")


def test_validate_voice_api_base_returns_none_for_none() -> None:
    assert _validate_voice_api_base("openai", None) is None

@@ -14,7 +14,6 @@ from onyx.utils.url import _is_ip_private_or_reserved
from onyx.utils.url import _validate_and_resolve_url
from onyx.utils.url import ssrf_safe_get
from onyx.utils.url import SSRFException
from onyx.utils.url import validate_outbound_http_url


class TestIsIpPrivateOrReserved:

@@ -306,22 +305,3 @@ class TestSsrfSafeGet:

        call_args = mock_get.call_args
        assert call_args[1]["timeout"] == (5, 15)


class TestValidateOutboundHttpUrl:
    def test_rejects_private_ip_by_default(self) -> None:
        with pytest.raises(SSRFException, match="internal/private IP"):
            validate_outbound_http_url("http://127.0.0.1:8000")

    def test_allows_private_ip_when_explicitly_enabled(self) -> None:
        validated_url = validate_outbound_http_url(
            "http://127.0.0.1:8000", allow_private_network=True
        )
        assert validated_url == "http://127.0.0.1:8000"

    def test_blocks_metadata_hostname_when_private_is_enabled(self) -> None:
        with pytest.raises(SSRFException, match="not allowed"):
            validate_outbound_http_url(
                "http://metadata.google.internal/latest",
                allow_private_network=True,
            )

@@ -1,30 +0,0 @@
import pytest

from onyx.voice.providers.azure import AzureVoiceProvider


def test_azure_provider_extracts_region_from_target_uri() -> None:
    provider = AzureVoiceProvider(
        api_key="key",
        api_base="https://westus.api.cognitive.microsoft.com/",
        custom_config={},
    )
    assert provider.speech_region == "westus"


def test_azure_provider_normalizes_uppercase_region() -> None:
    provider = AzureVoiceProvider(
        api_key="key",
        api_base=None,
        custom_config={"speech_region": "WestUS2"},
    )
    assert provider.speech_region == "westus2"


def test_azure_provider_rejects_invalid_speech_region() -> None:
    with pytest.raises(ValueError, match="Invalid Azure speech_region"):
        AzureVoiceProvider(
            api_key="key",
            api_base=None,
            custom_config={"speech_region": "westus/../../etc"},
        )

@@ -1,194 +0,0 @@
import io
import struct
import wave

import pytest

from onyx.voice.providers.azure import AzureVoiceProvider


# --- _is_azure_cloud_url ---


def test_is_azure_cloud_url_speech_microsoft() -> None:
    assert AzureVoiceProvider._is_azure_cloud_url(
        "https://eastus.tts.speech.microsoft.com/cognitiveservices/v1"
    )


def test_is_azure_cloud_url_cognitive_microsoft() -> None:
    assert AzureVoiceProvider._is_azure_cloud_url(
        "https://westus.api.cognitive.microsoft.com/"
    )


def test_is_azure_cloud_url_rejects_custom_host() -> None:
    assert not AzureVoiceProvider._is_azure_cloud_url("https://my-custom-host.com/")


def test_is_azure_cloud_url_rejects_none() -> None:
    assert not AzureVoiceProvider._is_azure_cloud_url(None)


# --- _extract_speech_region_from_uri ---


def test_extract_region_from_tts_url() -> None:
    assert (
        AzureVoiceProvider._extract_speech_region_from_uri(
            "https://eastus.tts.speech.microsoft.com/cognitiveservices/v1"
        )
        == "eastus"
    )


def test_extract_region_from_cognitive_api_url() -> None:
    assert (
        AzureVoiceProvider._extract_speech_region_from_uri(
            "https://eastus.api.cognitive.microsoft.com/"
        )
        == "eastus"
    )


def test_extract_region_returns_none_for_custom_domain() -> None:
    """Custom domains use resource name, not region, so speech_region config is required."""
    assert (
        AzureVoiceProvider._extract_speech_region_from_uri(
            "https://myresource.cognitiveservices.azure.com/"
        )
        is None
    )


def test_extract_region_returns_none_for_none() -> None:
    assert AzureVoiceProvider._extract_speech_region_from_uri(None) is None


# --- _validate_speech_region ---


def test_validate_region_normalizes_to_lowercase() -> None:
    assert AzureVoiceProvider._validate_speech_region("WestUS2") == "westus2"


def test_validate_region_accepts_hyphens() -> None:
    assert AzureVoiceProvider._validate_speech_region("us-east-1") == "us-east-1"


def test_validate_region_rejects_path_traversal() -> None:
    with pytest.raises(ValueError, match="Invalid Azure speech_region"):
        AzureVoiceProvider._validate_speech_region("westus/../../etc")


def test_validate_region_rejects_dots() -> None:
    with pytest.raises(ValueError, match="Invalid Azure speech_region"):
        AzureVoiceProvider._validate_speech_region("west.us")


# --- _pcm16_to_wav ---


def test_pcm16_to_wav_produces_valid_wav() -> None:
    samples = [32767, -32768, 0, 1234]
    pcm_data = struct.pack(f"<{len(samples)}h", *samples)
    wav_bytes = AzureVoiceProvider._pcm16_to_wav(pcm_data, sample_rate=16000)

    with wave.open(io.BytesIO(wav_bytes), "rb") as wav_file:
        assert wav_file.getnchannels() == 1
        assert wav_file.getsampwidth() == 2
        assert wav_file.getframerate() == 16000
        frames = wav_file.readframes(4)
        recovered = struct.unpack(f"<{len(samples)}h", frames)
        assert list(recovered) == samples


# --- URL Construction ---


def test_get_tts_url_cloud() -> None:
    provider = AzureVoiceProvider(
        api_key="key", api_base=None, custom_config={"speech_region": "eastus"}
    )
    assert (
        provider._get_tts_url()
        == "https://eastus.tts.speech.microsoft.com/cognitiveservices/v1"
    )


def test_get_stt_url_cloud() -> None:
    provider = AzureVoiceProvider(
        api_key="key", api_base=None, custom_config={"speech_region": "westus2"}
    )
    assert "westus2.stt.speech.microsoft.com" in provider._get_stt_url()


def test_get_tts_url_self_hosted() -> None:
    provider = AzureVoiceProvider(
        api_key="key", api_base="http://localhost:5000", custom_config={}
    )
    assert provider._get_tts_url() == "http://localhost:5000/cognitiveservices/v1"


def test_get_tts_url_self_hosted_strips_trailing_slash() -> None:
    provider = AzureVoiceProvider(
        api_key="key", api_base="http://localhost:5000/", custom_config={}
    )
    assert provider._get_tts_url() == "http://localhost:5000/cognitiveservices/v1"


# --- _is_self_hosted ---


def test_is_self_hosted_true_for_custom_endpoint() -> None:
    provider = AzureVoiceProvider(
        api_key="key", api_base="http://localhost:5000", custom_config={}
    )
    assert provider._is_self_hosted() is True


def test_is_self_hosted_false_for_azure_cloud() -> None:
    provider = AzureVoiceProvider(
        api_key="key",
        api_base="https://eastus.api.cognitive.microsoft.com/",
        custom_config={},
    )
    assert provider._is_self_hosted() is False


# --- Resampling ---


def test_resample_pcm16_passthrough() -> None:
    from onyx.voice.providers.azure import AzureStreamingTranscriber

    t = AzureStreamingTranscriber.__new__(AzureStreamingTranscriber)
    t.input_sample_rate = 16000
    t.target_sample_rate = 16000

    data = struct.pack("<4h", 100, 200, 300, 400)
    assert t._resample_pcm16(data) == data


def test_resample_pcm16_downsamples() -> None:
    from onyx.voice.providers.azure import AzureStreamingTranscriber

    t = AzureStreamingTranscriber.__new__(AzureStreamingTranscriber)
    t.input_sample_rate = 24000
    t.target_sample_rate = 16000

    input_samples = [1000, 2000, 3000, 4000, 5000, 6000]
    data = struct.pack(f"<{len(input_samples)}h", *input_samples)

    result = t._resample_pcm16(data)
    assert len(result) // 2 == 4
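    # 6 input samples * (16000 / 24000) = 4 output samples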


def test_resample_pcm16_empty_data() -> None:
    from onyx.voice.providers.azure import AzureStreamingTranscriber

    t = AzureStreamingTranscriber.__new__(AzureStreamingTranscriber)
    t.input_sample_rate = 24000
    t.target_sample_rate = 16000

    assert t._resample_pcm16(b"") == b""

@@ -1,117 +0,0 @@
import struct

from onyx.voice.providers.elevenlabs import _http_to_ws_url
from onyx.voice.providers.elevenlabs import DEFAULT_ELEVENLABS_API_BASE
from onyx.voice.providers.elevenlabs import ElevenLabsSTTMessageType
from onyx.voice.providers.elevenlabs import ElevenLabsVoiceProvider


# --- _http_to_ws_url ---


def test_http_to_ws_url_converts_https_to_wss() -> None:
    assert _http_to_ws_url("https://api.elevenlabs.io") == "wss://api.elevenlabs.io"


def test_http_to_ws_url_converts_http_to_ws() -> None:
    assert _http_to_ws_url("http://localhost:8080") == "ws://localhost:8080"


def test_http_to_ws_url_passes_through_other_schemes() -> None:
    assert _http_to_ws_url("wss://already.ws") == "wss://already.ws"


def test_http_to_ws_url_preserves_path() -> None:
    assert (
        _http_to_ws_url("https://api.elevenlabs.io/v1/tts")
        == "wss://api.elevenlabs.io/v1/tts"
    )


# --- StrEnum comparison ---


def test_stt_message_type_compares_as_string() -> None:
    """StrEnum members should work in string comparisons (e.g. from JSON)."""
    assert str(ElevenLabsSTTMessageType.COMMITTED_TRANSCRIPT) == "committed_transcript"
    assert isinstance(ElevenLabsSTTMessageType.ERROR, str)
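    # The enum is assumed to look roughly like this (a sketch, not the source):
    #   class ElevenLabsSTTMessageType(StrEnum):
    #       COMMITTED_TRANSCRIPT = "committed_transcript"
    #       ERROR = "error"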


# --- Resampling ---


def test_resample_pcm16_passthrough_when_same_rate() -> None:
    from onyx.voice.providers.elevenlabs import ElevenLabsStreamingTranscriber

    t = ElevenLabsStreamingTranscriber.__new__(ElevenLabsStreamingTranscriber)
    t.input_sample_rate = 16000
    t.target_sample_rate = 16000

    data = struct.pack("<4h", 100, 200, 300, 400)
    assert t._resample_pcm16(data) == data


def test_resample_pcm16_downsamples() -> None:
    """24kHz -> 16kHz should produce fewer samples (ratio 3:2)."""
    from onyx.voice.providers.elevenlabs import ElevenLabsStreamingTranscriber

    t = ElevenLabsStreamingTranscriber.__new__(ElevenLabsStreamingTranscriber)
    t.input_sample_rate = 24000
    t.target_sample_rate = 16000

    input_samples = [1000, 2000, 3000, 4000, 5000, 6000]
    data = struct.pack(f"<{len(input_samples)}h", *input_samples)

    result = t._resample_pcm16(data)
    output_samples = struct.unpack(f"<{len(result) // 2}h", result)

    assert len(output_samples) == 4


def test_resample_pcm16_clamps_to_int16_range() -> None:
    from onyx.voice.providers.elevenlabs import ElevenLabsStreamingTranscriber

    t = ElevenLabsStreamingTranscriber.__new__(ElevenLabsStreamingTranscriber)
    t.input_sample_rate = 24000
    t.target_sample_rate = 16000

    input_samples = [32767, -32768, 32767, -32768, 32767, -32768]
    data = struct.pack(f"<{len(input_samples)}h", *input_samples)

    result = t._resample_pcm16(data)
    output_samples = struct.unpack(f"<{len(result) // 2}h", result)
    for s in output_samples:
        assert -32768 <= s <= 32767


# --- Provider Model Defaulting ---


def test_provider_defaults_invalid_stt_model() -> None:
    provider = ElevenLabsVoiceProvider(api_key="test", stt_model="invalid_model")
    assert provider.stt_model == "scribe_v1"


def test_provider_defaults_invalid_tts_model() -> None:
    provider = ElevenLabsVoiceProvider(api_key="test", tts_model="invalid_model")
    assert provider.tts_model == "eleven_multilingual_v2"


def test_provider_accepts_valid_models() -> None:
    provider = ElevenLabsVoiceProvider(
        api_key="test", stt_model="scribe_v2_realtime", tts_model="eleven_turbo_v2_5"
    )
    assert provider.stt_model == "scribe_v2_realtime"
    assert provider.tts_model == "eleven_turbo_v2_5"


def test_provider_defaults_api_base() -> None:
    provider = ElevenLabsVoiceProvider(api_key="test")
    assert provider.api_base == DEFAULT_ELEVENLABS_API_BASE


def test_provider_get_available_voices_returns_copy() -> None:
    provider = ElevenLabsVoiceProvider(api_key="test")
    voices = provider.get_available_voices()
    voices.clear()
    assert len(provider.get_available_voices()) > 0

@@ -1,97 +0,0 @@
import io
import struct
import wave

from onyx.voice.providers.openai import _create_wav_header
from onyx.voice.providers.openai import _http_to_ws_url
from onyx.voice.providers.openai import OpenAIRealtimeMessageType
from onyx.voice.providers.openai import OpenAIVoiceProvider


# --- _http_to_ws_url ---


def test_http_to_ws_url_converts_https_to_wss() -> None:
    assert _http_to_ws_url("https://api.openai.com") == "wss://api.openai.com"


def test_http_to_ws_url_converts_http_to_ws() -> None:
    assert _http_to_ws_url("http://localhost:9090") == "ws://localhost:9090"


def test_http_to_ws_url_passes_through_ws() -> None:
    assert _http_to_ws_url("wss://already.ws") == "wss://already.ws"


# --- StrEnum comparison ---


def test_realtime_message_type_compares_as_string() -> None:
    assert str(OpenAIRealtimeMessageType.ERROR) == "error"
    assert (
        str(OpenAIRealtimeMessageType.TRANSCRIPTION_DELTA)
        == "conversation.item.input_audio_transcription.delta"
    )
    assert isinstance(OpenAIRealtimeMessageType.ERROR, str)


# --- _create_wav_header ---


def test_wav_header_is_44_bytes() -> None:
    assert len(_create_wav_header(1000)) == 44


def test_wav_header_chunk_size_matches_data_length() -> None:
    data_length = 2000
    header = _create_wav_header(data_length)
    chunk_size = struct.unpack_from("<I", header, 4)[0]
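    # ChunkSize lives at byte offset 4 of the header, right after the "RIFF" tag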
    assert chunk_size == 36 + data_length


def test_wav_header_byte_rate() -> None:
    header = _create_wav_header(100, sample_rate=24000, channels=1, bits_per_sample=16)
    byte_rate = struct.unpack_from("<I", header, 28)[0]
    assert byte_rate == 24000 * 1 * 16 // 8
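    # 24000 Hz * 1 channel * 16 bits / 8 bits per byte = 48,000 bytes per second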


def test_wav_header_produces_valid_wav() -> None:
    """Header + PCM data should parse as valid WAV."""
    data_length = 100
    pcm_data = b"\x00" * data_length
    header = _create_wav_header(data_length, sample_rate=24000)

    with wave.open(io.BytesIO(header + pcm_data), "rb") as wav_file:
        assert wav_file.getnchannels() == 1
        assert wav_file.getsampwidth() == 2
        assert wav_file.getframerate() == 24000
        assert wav_file.getnframes() == data_length // 2


# --- Provider Defaults ---


def test_provider_default_models() -> None:
    provider = OpenAIVoiceProvider(api_key="test")
    assert provider.stt_model == "whisper-1"
    assert provider.tts_model == "tts-1"
    assert provider.default_voice == "alloy"


def test_provider_custom_models() -> None:
    provider = OpenAIVoiceProvider(
        api_key="test",
        stt_model="gpt-4o-transcribe",
        tts_model="tts-1-hd",
        default_voice="nova",
    )
    assert provider.stt_model == "gpt-4o-transcribe"
    assert provider.tts_model == "tts-1-hd"
    assert provider.default_voice == "nova"


def test_provider_get_available_voices_returns_copy() -> None:
    provider = OpenAIVoiceProvider(api_key="test")
    voices = provider.get_available_voices()
    voices.clear()
    assert len(provider.get_available_voices()) > 0

@@ -38,11 +38,6 @@ services:
  opensearch:
    ports:
      - "9200:9200"
    # Rootless Docker can reject the base OpenSearch ulimit settings, so clear
    # the inherited block entirely in the dev override.
    ulimits: !reset null
    environment:
      - bootstrap.memory_lock=false

  inference_model_server:
    ports:

@@ -35,7 +35,6 @@ backend = [
    "alembic==1.10.4",
    "asyncpg==0.30.0",
    "atlassian-python-api==3.41.16",
    "azure-cognitiveservices-speech==1.38.0",
    "beautifulsoup4==4.12.3",
    "boto3==1.39.11",
    "boto3-stubs[s3]==1.39.11",

@@ -92,7+91,7 @@ backend = [
    "python-gitlab==5.6.0",
    "python-pptx==0.6.23",
    "pypandoc_binary==1.16.2",
    "pypdf==6.8.0",
    "pypdf==6.7.5",
    "pytest-mock==3.12.0",
    "pytest-playwright==0.7.0",
    "python-docx==1.1.2",

@@ -154,7 +153,7 @@ dev = [
    "pytest-repeat==0.9.4",
    "pytest-xdist==3.8.0",
    "pytest==8.3.5",
    "release-tag==0.5.2",
    "release-tag==0.4.3",
    "reorder-python-imports-black==3.14.0",
    "ruff==0.12.0",
    "types-beautifulsoup4==4.12.0.3",

@@ -1,36 +0,0 @@
package cmd

import (
	"fmt"

	"github.com/jmelahman/tag/git"
	"github.com/spf13/cobra"
)

// NewLatestStableTagCommand creates the latest-stable-tag command.
func NewLatestStableTagCommand() *cobra.Command {
	cmd := &cobra.Command{
		Use:   "latest-stable-tag",
		Short: "Print the git tag that should receive the 'latest' Docker tag",
		Long: `Print the highest stable (non-pre-release) semver tag in the repository.

This is used during deployment to decide whether a given tag should
receive the "latest" tag on Docker Hub. Only the highest vX.Y.Z tag
qualifies. Tags with pre-release suffixes (e.g. v1.2.3-beta,
v1.2.3-cloud.1) are excluded.`,
		Args: cobra.NoArgs,
		RunE: func(c *cobra.Command, _ []string) error {
			tag, err := git.GetLatestStableSemverTag("")
			if err != nil {
				return fmt.Errorf("get latest stable semver tag: %w", err)
			}
			if tag == "" {
				return fmt.Errorf("no stable semver tag found in repository")
			}
			fmt.Println(tag)
			return nil
		},
	}

	return cmd
}
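
// Example invocation (illustrative; assumes the tool builds to an `ods` binary,
// and the printed tag depends on the repository's tags):
//
//	$ ods latest-stable-tag
//	v1.2.3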

@@ -52,7 +52,6 @@ func NewRootCommand() *cobra.Command {
	cmd.AddCommand(NewScreenshotDiffCommand())
	cmd.AddCommand(NewDesktopCommand())
	cmd.AddCommand(NewWebCommand())
	cmd.AddCommand(NewLatestStableTagCommand())
	cmd.AddCommand(NewWhoisCommand())

	return cmd

@@ -3,13 +3,12 @@ module github.com/onyx-dot-app/onyx/tools/ods
go 1.26.0

require (
	github.com/jmelahman/tag v0.5.2
	github.com/sirupsen/logrus v1.9.3
	github.com/spf13/cobra v1.10.2
	github.com/spf13/pflag v1.0.10
	github.com/spf13/cobra v1.10.1
	github.com/spf13/pflag v1.0.9
)

require (
	github.com/inconshreveable/mousetrap v1.1.0 // indirect
	golang.org/x/sys v0.39.0 // indirect
	golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect
)

@@ -4,26 +4,20 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jmelahman/tag v0.5.2 h1:g6A/aHehu5tkA31mPoDsXBNr1FigZ9A82Y8WVgb/WsM=
github.com/jmelahman/tag v0.5.2/go.mod h1:qmuqk19B1BKkpcg3kn7l/Eey+UqucLxgOWkteUGiG4Q=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

@@ -463,19 +463,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/00/3ed12264094ec91f534fae429945efbaa9f8c666f3aa7061cc3b2a26a0cd/authlib-1.6.7-py2.py3-none-any.whl", hash = "sha256:c637340d9a02789d2efa1d003a7437d10d3e565237bcb5fcbc6c134c7b95bab0", size = 244115, upload-time = "2026-02-06T14:04:12.141Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "azure-cognitiveservices-speech"
|
||||
version = "1.38.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/85/f4/4571c42cb00f8af317d5431f594b4ece1fbe59ab59f106947fea8e90cf89/azure_cognitiveservices_speech-1.38.0-py3-none-macosx_10_14_x86_64.whl", hash = "sha256:18dce915ab032711f687abb3297dd19176b9cbea562b322ee6fa7365ef4a5091", size = 6775838, upload-time = "2024-06-11T03:08:35.202Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/86/22/0ca2c59a573119950cad1f53531fec9872fc38810c405a4e1827f3d13a8e/azure_cognitiveservices_speech-1.38.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:9dd0800fbc4a8438c6dfd5747a658251914fe2d205a29e9b46158cadac6ab381", size = 6687975, upload-time = "2024-06-11T03:08:38.797Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/96/5436c09de3af3a9aefaa8cc00533c3a0f5d17aef5bbe017c17f0a30ad66e/azure_cognitiveservices_speech-1.38.0-py3-none-manylinux1_x86_64.whl", hash = "sha256:1c344e8a6faadb063cea451f0301e13b44d9724e1242337039bff601e81e6f86", size = 40022287, upload-time = "2024-06-11T03:08:16.777Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a9/2d/ba20d05ff77ec9870cd489e6e7a474ba7fe820524bcf6fd202025e0c11cf/azure_cognitiveservices_speech-1.38.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1e002595a749471efeac3a54c80097946570b76c13049760b97a4b881d9d24af", size = 39788653, upload-time = "2024-06-11T03:08:30.405Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0c/21/25f8c37fb6868db4346ca977c287ede9e87f609885d932653243c9ed5f63/azure_cognitiveservices_speech-1.38.0-py3-none-win32.whl", hash = "sha256:16a530e6c646eb49ea0bc05cb45a9d28b99e4b67613f6c3a6c54e26e6bf65241", size = 1428364, upload-time = "2024-06-11T03:08:03.965Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/05/a6414a3481c5ee30c4f32742abe055e5f3ce4ff69e936089d86ece354067/azure_cognitiveservices_speech-1.38.0-py3-none-win_amd64.whl", hash = "sha256:1d38d8c056fb3f513a9ff27ab4e77fd08ca487f8788cc7a6df772c1ab2c97b54", size = 1539297, upload-time = "2024-06-11T03:08:01.304Z" },
]

[[package]]
name = "babel"
version = "2.17.0"
@@ -4240,7 +4227,6 @@ backend = [
    { name = "asana" },
    { name = "asyncpg" },
    { name = "atlassian-python-api" },
    { name = "azure-cognitiveservices-speech" },
    { name = "beautifulsoup4" },
    { name = "boto3" },
    { name = "boto3-stubs", extra = ["s3"] },
@@ -4395,7 +4381,6 @@ requires-dist = [
    { name = "asana", marker = "extra == 'backend'", specifier = "==5.0.8" },
    { name = "asyncpg", marker = "extra == 'backend'", specifier = "==0.30.0" },
    { name = "atlassian-python-api", marker = "extra == 'backend'", specifier = "==3.41.16" },
    { name = "azure-cognitiveservices-speech", marker = "extra == 'backend'", specifier = "==1.38.0" },
    { name = "beautifulsoup4", marker = "extra == 'backend'", specifier = "==4.12.3" },
    { name = "black", marker = "extra == 'dev'", specifier = "==25.1.0" },
    { name = "boto3", marker = "extra == 'backend'", specifier = "==1.39.11" },
@@ -4481,7 +4466,7 @@ requires-dist = [
    { name = "pygithub", marker = "extra == 'backend'", specifier = "==2.5.0" },
    { name = "pympler", marker = "extra == 'backend'", specifier = "==1.1" },
    { name = "pypandoc-binary", marker = "extra == 'backend'", specifier = "==1.16.2" },
    { name = "pypdf", marker = "extra == 'backend'", specifier = "==6.8.0" },
    { name = "pypdf", marker = "extra == 'backend'", specifier = "==6.7.5" },
    { name = "pytest", marker = "extra == 'dev'", specifier = "==8.3.5" },
    { name = "pytest-alembic", marker = "extra == 'dev'", specifier = "==0.12.1" },
    { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==1.3.0" },
@@ -4500,7 +4485,7 @@ requires-dist = [
    { name = "pywikibot", marker = "extra == 'backend'", specifier = "==9.0.0" },
    { name = "rapidfuzz", marker = "extra == 'backend'", specifier = "==3.13.0" },
    { name = "redis", marker = "extra == 'backend'", specifier = "==5.0.8" },
    { name = "release-tag", marker = "extra == 'dev'", specifier = "==0.5.2" },
    { name = "release-tag", marker = "extra == 'dev'", specifier = "==0.4.3" },
    { name = "reorder-python-imports-black", marker = "extra == 'dev'", specifier = "==3.14.0" },
    { name = "requests", marker = "extra == 'backend'", specifier = "==2.32.5" },
    { name = "requests-oauthlib", marker = "extra == 'backend'", specifier = "==1.3.1" },
@@ -5728,11 +5713,11 @@ wheels = [

[[package]]
name = "pypdf"
version = "6.8.0"
version = "6.7.5"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b4/a3/e705b0805212b663a4c27b861c8a603dba0f8b4bb281f96f8e746576a50d/pypdf-6.8.0.tar.gz", hash = "sha256:cb7eaeaa4133ce76f762184069a854e03f4d9a08568f0e0623f7ea810407833b", size = 5307831, upload-time = "2026-03-09T13:37:40.591Z" }
sdist = { url = "https://files.pythonhosted.org/packages/f6/52/37cc0aa9e9d1bf7729a737a0d83f8b3f851c8eb137373d9f71eafb0a3405/pypdf-6.7.5.tar.gz", hash = "sha256:40bb2e2e872078655f12b9b89e2f900888bb505e88a82150b64f9f34fa25651d", size = 5304278, upload-time = "2026-03-02T09:05:21.464Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/8c/ec/4ccf3bb86b1afe5d7176e1c8abcdbf22b53dd682ec2eda50e1caadcf6846/pypdf-6.8.0-py3-none-any.whl", hash = "sha256:2a025080a8dd73f48123c89c57174a5ff3806c71763ee4e49572dc90454943c7", size = 332177, upload-time = "2026-03-09T13:37:38.774Z" },
    { url = "https://files.pythonhosted.org/packages/05/89/336673efd0a88956562658aba4f0bbef7cb92a6fbcbcaf94926dbc82b408/pypdf-6.7.5-py3-none-any.whl", hash = "sha256:07ba7f1d6e6d9aa2a17f5452e320a84718d4ce863367f7ede2fd72280349ab13", size = 331421, upload-time = "2026-03-02T09:05:19.722Z" },
]

[[package]]
@@ -6353,16 +6338,16 @@ wheels = [

[[package]]
name = "release-tag"
version = "0.5.2"
version = "0.4.3"
source = { registry = "https://pypi.org/simple" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/ab/92/01192a540b29cfadaa23850c8f6a2041d541b83a3fa1dc52a5f55212b3b6/release_tag-0.5.2-py3-none-any.whl", hash = "sha256:1e9ca7618bcfc63ad7a0728c84bbad52ef82d07586c4cc11365b44ea8f588069", size = 1264752, upload-time = "2026-03-11T00:27:18.674Z" },
    { url = "https://files.pythonhosted.org/packages/4f/77/81fb42a23cd0de61caf84266f7aac1950b1c324883788b7c48e5344f61ae/release_tag-0.5.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:8fbc61ff7bac2b96fab09566ec45c6508c201efc3f081f57702e1761bbc178d5", size = 1255075, upload-time = "2026-03-11T00:27:24.442Z" },
    { url = "https://files.pythonhosted.org/packages/98/e6/769f8be94304529c1a531e995f2f3ac83f3c54738ce488b0abde75b20851/release_tag-0.5.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa3d7e495a0c516858a81878d03803539712677a3d6e015503de21cce19bea5e", size = 1163627, upload-time = "2026-03-11T00:27:26.412Z" },
    { url = "https://files.pythonhosted.org/packages/45/68/7543e9daa0dfd41c487bf140d91fd5879327bb7c001a96aa5264667c30a1/release_tag-0.5.2-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:e8b60453218d6926da1fdcb99c2e17c851be0d7ab1975e97951f0bff5f32b565", size = 1140133, upload-time = "2026-03-11T00:27:20.633Z" },
    { url = "https://files.pythonhosted.org/packages/6a/30/9087825696271012d889d136310dbdf0811976ae2b2f5a490f4e437903e1/release_tag-0.5.2-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:0e302ed60c2bf8b7ba5634842be28a27d83cec995869e112b0348b3f01a84ff5", size = 1264767, upload-time = "2026-03-11T00:27:28.355Z" },
    { url = "https://files.pythonhosted.org/packages/79/a3/5b51b0cbdbf2299f545124beab182cfdfe01bf5b615efbc94aee3a64ea67/release_tag-0.5.2-py3-none-win_amd64.whl", hash = "sha256:e3c0629d373a16b9a3da965e89fca893640ce9878ec548865df3609b70989a89", size = 1340816, upload-time = "2026-03-11T00:27:22.622Z" },
    { url = "https://files.pythonhosted.org/packages/dd/6f/832c2023a8bd8414c93452bd8b43bf61cedfa5b9575f70c06fb911e51a29/release_tag-0.5.2-py3-none-win_arm64.whl", hash = "sha256:5f26b008e0be0c7a122acd8fcb1bb5c822f38e77fed0c0bf6c550cc226c6bf14", size = 1203191, upload-time = "2026-03-11T00:27:29.789Z" },
    { url = "https://files.pythonhosted.org/packages/39/18/c1d17d973f73f0aa7e2c45f852839ab909756e1bd9727d03babe400fcef0/release_tag-0.4.3-py3-none-any.whl", hash = "sha256:4206f4fa97df930c8176bfee4d3976a7385150ed14b317bd6bae7101ac8b66dd", size = 1181112, upload-time = "2025-12-03T00:18:19.445Z" },
    { url = "https://files.pythonhosted.org/packages/33/c7/ecc443953840ac313856b2181f55eb8d34fa2c733cdd1edd0bcceee0938d/release_tag-0.4.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:7a347a9ad3d2af16e5367e52b451fbc88a0b7b666850758e8f9a601554a8fb13", size = 1170517, upload-time = "2025-12-03T00:18:11.663Z" },
    { url = "https://files.pythonhosted.org/packages/ce/81/2f6ffa0d87c792364ca9958433fe088c8acc3d096ac9734040049c6ad506/release_tag-0.4.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2d1603aa37d8e4f5df63676bbfddc802fbc108a744ba28288ad25c997981c164", size = 1101663, upload-time = "2025-12-03T00:18:15.173Z" },
    { url = "https://files.pythonhosted.org/packages/7c/ed/9e4ebe400fc52e38dda6e6a45d9da9decd4535ab15e170b8d9b229a66730/release_tag-0.4.3-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:6db7b81a198e3ba6a87496a554684912c13f9297ea8db8600a80f4f971709d37", size = 1079322, upload-time = "2025-12-03T00:18:16.094Z" },
    { url = "https://files.pythonhosted.org/packages/2a/64/9e0ce6119e091ef9211fa82b9593f564eeec8bdd86eff6a97fe6e2fcb20f/release_tag-0.4.3-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:d79a9cf191dd2c29e1b3a35453fa364b08a7aadd15aeb2c556a7661c6cf4d5ad", size = 1181129, upload-time = "2025-12-03T00:18:15.82Z" },
    { url = "https://files.pythonhosted.org/packages/b8/09/d96acf18f0773b6355080a568ba48931faa9dbe91ab1abefc6f8c4df04a8/release_tag-0.4.3-py3-none-win_amd64.whl", hash = "sha256:3958b880375f2241d0cc2b9882363bf54b1d4d7ca8ffc6eecc63ab92f23307f0", size = 1260773, upload-time = "2025-12-03T00:18:14.723Z" },
    { url = "https://files.pythonhosted.org/packages/51/da/ecb6346df1ffb0752fe213e25062f802c10df2948717f0d5f9816c2df914/release_tag-0.4.3-py3-none-win_arm64.whl", hash = "sha256:7d5b08000e6e398d46f05a50139031046348fba6d47909f01e468bb7600c19df", size = 1142155, upload-time = "2025-12-03T00:18:20.647Z" },
]

[[package]]

@@ -1,6 +1,5 @@
import type { Meta, StoryObj } from "@storybook/react";
import { OpenButton } from "@opal/components";
import { Disabled as DisabledProvider } from "@opal/core";
import { SvgSettings } from "@opal/icons";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";

@@ -33,9 +32,16 @@ export const WithIcon: Story = {
  },
};

export const Selected: Story = {
  args: {
    selected: true,
    children: "Selected",
  },
};

export const Open: Story = {
  args: {
    interaction: "hover",
    transient: true,
    children: "Open state",
  },
};
@@ -47,27 +53,18 @@ export const Disabled: Story = {
  },
};

export const Foldable: Story = {
export const LightProminence: Story = {
  args: {
    foldable: true,
    icon: SvgSettings,
    children: "Settings",
    prominence: "light",
    children: "Light prominence",
  },
};

export const FoldableDisabled: Story = {
export const HeavyProminence: Story = {
  args: {
    foldable: true,
    icon: SvgSettings,
    children: "Settings",
    prominence: "heavy",
    children: "Heavy prominence",
  },
  decorators: [
    (Story) => (
      <DisabledProvider disabled>
        <Story />
      </DisabledProvider>
    ),
  ],
};

export const Sizes: Story = {
@@ -81,12 +78,3 @@ export const Sizes: Story = {
    </div>
  ),
};

export const WithTooltip: Story = {
  args: {
    icon: SvgSettings,
    children: "Settings",
    tooltip: "Open settings",
    tooltipSide: "bottom",
  },
};
@@ -17,9 +17,7 @@ OpenButton is a **tighter, specialized use-case** of SelectButton:
- It hardcodes `variant="select-heavy"` (SelectButton exposes `variant`)
- It adds a built-in chevron with CSS-driven rotation (SelectButton has no chevron)
- It auto-detects Radix `data-state="open"` to derive `interaction` (SelectButton has no Radix awareness)
- It does not support `rightIcon` (SelectButton does)

Both components support `foldable` using the same pattern: `interactive-foldable-host` class + `Interactive.Foldable` wrapper around the label and trailing icon. When foldable, the left icon stays visible while the rest collapses. If you change the foldable implementation in one, update the other to match.
- It does not support `foldable` or `rightIcon` (SelectButton does)

If you need a general-purpose stateful toggle, use `SelectButton`. If you need a popover/dropdown trigger with a chevron, use `OpenButton`.
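
For example, a typical OpenButton home is a Radix popover trigger, where the open state flows in through `data-state`. A minimal sketch (the Popover wiring here is illustrative, not part of this package):

```tsx
import * as PopoverPrimitive from "@radix-ui/react-popover";
import { OpenButton } from "@opal/components";
import { SvgSettings } from "@opal/icons";

// OpenButton reads the data-state="open" attribute that the Radix
// trigger injects, so the chevron rotation needs no extra wiring.
function SettingsMenu() {
  return (
    <PopoverPrimitive.Root>
      <PopoverPrimitive.Trigger asChild>
        <OpenButton icon={SvgSettings}>Settings</OpenButton>
      </PopoverPrimitive.Trigger>
      <PopoverPrimitive.Portal>
        <PopoverPrimitive.Content>{/* menu content */}</PopoverPrimitive.Content>
      </PopoverPrimitive.Portal>
    </PopoverPrimitive.Root>
  );
}
```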

@@ -28,12 +26,10 @@ If you need a general-purpose stateful toggle, use `SelectButton`. If you need a
```
Interactive.Stateful   <- variant="select-heavy", interaction, state, disabled, onClick
└─ Interactive.Container   <- height, rounding, padding (from `size`)
   └─ div.opal-button.interactive-foreground [.interactive-foldable-host]
   └─ div.opal-button.interactive-foreground
      ├─ div > Icon? (interactive-foreground-icon)
      ├─ [Foldable]? (wraps label + chevron when foldable)
      │  ├─ <span>? .opal-button-label
      │  └─ div > ChevronIcon .opal-open-button-chevron
      └─ <span>? / ChevronIcon (non-foldable)
      ├─ <span>? .opal-button-label
      └─ div > ChevronIcon .opal-open-button-chevron (interactive-foreground-icon)
```

- **`interaction` controls both the chevron and the hover visual state.** When `interaction` is `"hover"` (explicitly or via Radix `data-state="open"`), the chevron rotates 180° and the hover background activates.
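
  The override can also be set explicitly, e.g. to preview the open presentation outside any Radix context (illustrative snippet using props from this doc):

```tsx
// Chevron rotated 180° and hover background active, with no Radix
// data-state involved — purely JS-controlled.
<OpenButton icon={SvgSettings} interaction="hover">
  Filters
</OpenButton>
```
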
@@ -48,7 +44,6 @@ Interactive.Stateful   <- variant="select-heavy", interaction, state, di
| `interaction` | `"rest" \| "hover" \| "active"` | auto | JS-controlled interaction override. Falls back to Radix `data-state="open"` when omitted. |
| `icon` | `IconFunctionComponent` | — | Left icon component |
| `children` | `string` | — | Content between icon and chevron |
| `foldable` | `boolean` | `false` | When `true`, requires both `icon` and `children`; the left icon stays visible while the label + chevron collapse when not hovered. If `tooltip` is omitted on a disabled foldable button, the label text is used as the tooltip. |
| `size` | `SizeVariant` | `"lg"` | Size preset controlling height, rounding, and padding |
| `width` | `WidthVariant` | — | Width preset |
| `tooltip` | `string` | — | Tooltip text shown on hover |
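
A minimal sketch of the `foldable` behavior described in this table (prop names as listed above):

```tsx
// foldable requires both icon and children; the label and chevron
// collapse when the button is not hovered, leaving only the icon.
<OpenButton foldable icon={SvgSettings} size="sm" tooltip="Advanced options">
  Advanced options
</OpenButton>
```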

@@ -2,7 +2,6 @@ import "@opal/components/buttons/open-button/styles.css";
import "@opal/components/tooltip.css";
import {
  Interactive,
  useDisabled,
  type InteractiveStatefulProps,
  type InteractiveStatefulInteraction,
} from "@opal/core";
@@ -31,46 +30,27 @@ function ChevronIcon({ className, ...props }: IconProps) {
// Types
// ---------------------------------------------------------------------------

/**
 * Content props — a discriminated union on `foldable` that enforces:
 *
 * - `foldable: true` → `icon` and `children` are required (icon stays visible,
 *   label + chevron fold away)
 * - `foldable?: false` → at least one of `icon` or `children` must be provided
 */
type OpenButtonContentProps =
  | {
      foldable: true;
      icon: IconFunctionComponent;
      children: string;
    }
  | {
      foldable?: false;
      icon?: IconFunctionComponent;
      children: string;
    }
  | {
      foldable?: false;
      icon: IconFunctionComponent;
      children?: string;
    };
type OpenButtonProps = Omit<InteractiveStatefulProps, "variant"> & {
  /** Left icon. */
  icon?: IconFunctionComponent;

type OpenButtonProps = Omit<InteractiveStatefulProps, "variant"> &
  OpenButtonContentProps & {
    /**
     * Size preset — controls gap, text size, and Container height/rounding.
     */
    size?: SizeVariant;
  /** Button label text. */
  children?: string;

    /** Width preset. */
    width?: WidthVariant;
  /**
   * Size preset — controls gap, text size, and Container height/rounding.
   */
  size?: SizeVariant;

    /** Tooltip text shown on hover. */
    tooltip?: string;
  /** Width preset. */
  width?: WidthVariant;

    /** Which side the tooltip appears on. */
    tooltipSide?: TooltipSide;
  };
  /** Tooltip text shown on hover. */
  tooltip?: string;

  /** Which side the tooltip appears on. */
  tooltipSide?: TooltipSide;
};

// ---------------------------------------------------------------------------
// OpenButton
@@ -80,15 +60,12 @@ function OpenButton({
  icon: Icon,
  children,
  size = "lg",
  foldable,
  width,
  tooltip,
  tooltipSide = "top",
  interaction,
  ...statefulProps
}: OpenButtonProps) {
  const { isDisabled } = useDisabled();

  // Derive open state: explicit prop → Radix data-state (injected via Slot chain)
  const dataState = (statefulProps as Record<string, unknown>)["data-state"] as
    | string
@@ -98,17 +75,6 @@ function OpenButton({

  const isLarge = size === "lg";

  const labelEl = children ? (
    <span
      className={cn(
        "opal-button-label whitespace-nowrap",
        isLarge ? "font-main-ui-body" : "font-secondary-body"
      )}
    >
      {children}
    </span>
  ) : null;

  const button = (
    <Interactive.Stateful
      variant="select-heavy"
@@ -123,34 +89,25 @@ function OpenButton({
          isLarge ? "default" : size === "2xs" ? "mini" : "compact"
        }
      >
        <div
          className={cn(
            "opal-button interactive-foreground flex flex-row items-center gap-1",
            foldable && "interactive-foldable-host"
          )}
        >
          {iconWrapper(Icon, size, !foldable && !!children)}

          {foldable ? (
            <Interactive.Foldable>
              {labelEl}
              {iconWrapper(ChevronIcon, size, !!children)}
            </Interactive.Foldable>
          ) : (
            <>
              {labelEl}
              {iconWrapper(ChevronIcon, size, !!children)}
            </>
        <div className="opal-button interactive-foreground flex flex-row items-center gap-1">
          {iconWrapper(Icon, size, false)}
          {children && (
            <span
              className={cn(
                "opal-button-label whitespace-nowrap",
                isLarge ? "font-main-ui-body" : "font-secondary-body"
              )}
            >
              {children}
            </span>
          )}
          {iconWrapper(ChevronIcon, size, false)}
        </div>
      </Interactive.Container>
    </Interactive.Stateful>
  );

  const resolvedTooltip =
    tooltip ?? (foldable && isDisabled && children ? children : undefined);

  if (!resolvedTooltip) return button;
  if (!tooltip) return button;

  return (
    <TooltipPrimitive.Root>
@@ -161,7 +118,7 @@ function OpenButton({
          side={tooltipSide}
          sideOffset={4}
        >
          {resolvedTooltip}
          {tooltip}
        </TooltipPrimitive.Content>
      </TooltipPrimitive.Portal>
    </TooltipPrimitive.Root>

@@ -17,9 +17,7 @@ Interactive.Stateful → Interactive.Container → content row (icon + label + t
- OpenButton hardcodes `variant="select-heavy"` (SelectButton exposes `variant`)
- OpenButton adds a built-in chevron with CSS-driven rotation (SelectButton has no chevron)
- OpenButton auto-detects Radix `data-state="open"` to derive `interaction` (SelectButton has no Radix awareness)
- OpenButton does not support `rightIcon` (SelectButton does)

Both components support `foldable` using the same pattern: `interactive-foldable-host` class + `Interactive.Foldable` wrapper around the label and trailing icon. When foldable, the left icon stays visible while the rest collapses. If you change the foldable implementation in one, update the other to match.
- OpenButton does not support `foldable` or `rightIcon` (SelectButton does)

Use SelectButton for general-purpose stateful toggles. Use `OpenButton` for popover/dropdown triggers with a chevron.
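
As a rough sketch of that split — assuming SelectButton exposes the `selected`/`onClick` surface of `Interactive.Stateful` referenced above (the exact prop names are not confirmed by this doc):

```tsx
import { SelectButton } from "@opal/components";
import { SvgStar } from "@opal/icons";
import { useState } from "react";

// A general-purpose stateful toggle: selection lives in React state,
// with no Radix open/close semantics and no chevron.
function FavoriteToggle() {
  const [selected, setSelected] = useState(false);
  return (
    <SelectButton
      icon={SvgStar}
      selected={selected}
      onClick={() => setSelected((s) => !s)}
    >
      Favorites
    </SelectButton>
  );
}
```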


@@ -1,87 +0,0 @@
import type { Meta, StoryObj } from "@storybook/react";
import { Card } from "@opal/components";

const BACKGROUND_VARIANTS = ["none", "light", "heavy"] as const;
const BORDER_VARIANTS = ["none", "dashed", "solid"] as const;
const SIZE_VARIANTS = ["lg", "md", "sm", "xs", "2xs", "fit"] as const;

const meta: Meta<typeof Card> = {
  title: "opal/components/Card",
  component: Card,
  tags: ["autodocs"],
};

export default meta;
type Story = StoryObj<typeof Card>;

export const Default: Story = {
  render: () => (
    <Card>
      <p>Default card with light background, no border, lg size.</p>
    </Card>
  ),
};

export const BackgroundVariants: Story = {
  render: () => (
    <div className="flex flex-col gap-4 w-96">
      {BACKGROUND_VARIANTS.map((bg) => (
        <Card key={bg} backgroundVariant={bg} borderVariant="solid">
          <p>backgroundVariant: {bg}</p>
        </Card>
      ))}
    </div>
  ),
};

export const BorderVariants: Story = {
  render: () => (
    <div className="flex flex-col gap-4 w-96">
      {BORDER_VARIANTS.map((border) => (
        <Card key={border} borderVariant={border}>
          <p>borderVariant: {border}</p>
        </Card>
      ))}
    </div>
  ),
};

export const SizeVariants: Story = {
  render: () => (
    <div className="flex flex-col gap-4 w-96">
      {SIZE_VARIANTS.map((size) => (
        <Card key={size} sizeVariant={size} borderVariant="solid">
          <p>sizeVariant: {size}</p>
        </Card>
      ))}
    </div>
  ),
};

export const AllCombinations: Story = {
  render: () => (
    <div className="flex flex-col gap-8">
      {SIZE_VARIANTS.map((size) => (
        <div key={size}>
          <p className="font-bold pb-2">sizeVariant: {size}</p>
          <div className="grid grid-cols-3 gap-4">
            {BACKGROUND_VARIANTS.map((bg) =>
              BORDER_VARIANTS.map((border) => (
                <Card
                  key={`${size}-${bg}-${border}`}
                  sizeVariant={size}
                  backgroundVariant={bg}
                  borderVariant={border}
                >
                  <p className="text-xs">
                    bg: {bg}, border: {border}
                  </p>
                </Card>
              ))
            )}
          </div>
        </div>
      ))}
    </div>
  ),
};
@@ -1,67 +0,0 @@
# Card

**Import:** `import { Card, type CardProps } from "@opal/components";`

A plain container component with configurable background, border, padding, and rounding. Uses a simple `<div>` internally with `overflow-clip`.

## Architecture

The `sizeVariant` controls both padding and border-radius, mirroring the same mapping used by `Button` and `Interactive.Container`:

| Size | Padding | Rounding |
|-----------|---------|----------------|
| `lg` | `p-2` | `rounded-12` |
| `md` | `p-1` | `rounded-08` |
| `sm` | `p-1` | `rounded-08` |
| `xs` | `p-0.5` | `rounded-04` |
| `2xs` | `p-0.5` | `rounded-04` |
| `fit` | `p-0` | `rounded-12` |
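
Per this mapping, `fit` is the one preset that drops padding while keeping the large rounding — useful when the child supplies its own spacing (illustrative):

```tsx
// p-0 from the `fit` preset; content runs edge-to-edge while the
// card still clips to its rounded-12 corners via overflow-clip.
<Card sizeVariant="fit" borderVariant="solid">
  <img src="/preview.png" alt="Preview" />
</Card>
```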

## Props

| Prop | Type | Default | Description |
|------|------|---------|-------------|
| `sizeVariant` | `SizeVariant` | `"lg"` | Controls padding and border-radius |
| `backgroundVariant` | `"none" \| "light" \| "heavy"` | `"light"` | Background fill intensity |
| `borderVariant` | `"none" \| "dashed" \| "solid"` | `"none"` | Border style |
| `ref` | `React.Ref<HTMLDivElement>` | — | Ref forwarded to the root div |
| `children` | `React.ReactNode` | — | Card content |

## Background Variants

- **`none`** — Transparent background. Use for seamless inline content.
- **`light`** — Subtle tinted background (`bg-background-tint-00`). The default, suitable for most cards.
- **`heavy`** — Stronger tinted background (`bg-background-tint-01`). Use for emphasis or nested cards that need visual separation.

## Border Variants

- **`none`** — No border. Use when cards are visually grouped or in tight layouts.
- **`dashed`** — Dashed border. Use for placeholder or empty states.
- **`solid`** — Solid border. Use for prominent, standalone cards.

## Usage

```tsx
import { Card } from "@opal/components";

// Default card (light background, no border, lg padding + rounding)
<Card>
  <h2>Card Title</h2>
  <p>Card content</p>
</Card>

// Compact card with solid border
<Card borderVariant="solid" sizeVariant="sm">
  <p>Compact card</p>
</Card>

// Empty state card
<Card backgroundVariant="none" borderVariant="dashed">
  <p>No items yet</p>
</Card>

// Heavy background, tight padding
<Card backgroundVariant="heavy" sizeVariant="xs">
  <p>Highlighted content</p>
</Card>
```
@@ -1,101 +0,0 @@
import "@opal/components/cards/card/styles.css";
import type { SizeVariant } from "@opal/shared";
import { sizeVariants } from "@opal/shared";
import { cn } from "@opal/utils";

// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------

type BackgroundVariant = "none" | "light" | "heavy";
type BorderVariant = "none" | "dashed" | "solid";

type CardProps = {
  /**
   * Size preset — controls padding and border-radius.
   *
   * Padding comes from the shared size scale. Rounding follows the same
   * mapping as `Button` / `Interactive.Container`:
   *
   * | Size | Rounding |
   * |--------|------------|
   * | `lg` | `default` |
   * | `md`–`sm` | `compact` |
   * | `xs`–`2xs` | `mini` |
   * | `fit` | `default` |
   *
   * @default "lg"
   */
  sizeVariant?: SizeVariant;

  /**
   * Background fill intensity.
   * - `"none"`: transparent background.
   * - `"light"`: subtle tinted background (`bg-background-tint-00`).
   * - `"heavy"`: stronger tinted background (`bg-background-tint-01`).
   *
   * @default "light"
   */
  backgroundVariant?: BackgroundVariant;

  /**
   * Border style.
   * - `"none"`: no border.
   * - `"dashed"`: dashed border.
   * - `"solid"`: solid border.
   *
   * @default "none"
   */
  borderVariant?: BorderVariant;

  /** Ref forwarded to the root `<div>`. */
  ref?: React.Ref<HTMLDivElement>;

  children?: React.ReactNode;
};

// ---------------------------------------------------------------------------
// Rounding
// ---------------------------------------------------------------------------

/** Maps a size variant to a rounding class, mirroring the Button pattern. */
const roundingForSize: Record<SizeVariant, string> = {
  lg: "rounded-12",
  md: "rounded-08",
  sm: "rounded-08",
  xs: "rounded-04",
  "2xs": "rounded-04",
  fit: "rounded-12",
};

// ---------------------------------------------------------------------------
// Card
// ---------------------------------------------------------------------------

function Card({
  sizeVariant = "lg",
  backgroundVariant = "light",
  borderVariant = "none",
  ref,
  children,
}: CardProps) {
  const { padding } = sizeVariants[sizeVariant];
  const rounding = roundingForSize[sizeVariant];

  return (
    <div
      ref={ref}
      className={cn("opal-card", padding, rounding)}
      data-background={backgroundVariant}
      data-border={borderVariant}
    >
      {children}
    </div>
  );
}

// ---------------------------------------------------------------------------
// Exports
// ---------------------------------------------------------------------------

export { Card, type CardProps, type BackgroundVariant, type BorderVariant };
@@ -1,29 +0,0 @@
.opal-card {
  @apply w-full overflow-clip;
}

/* Background variants */
.opal-card[data-background="none"] {
  @apply bg-transparent;
}

.opal-card[data-background="light"] {
  @apply bg-background-tint-00;
}

.opal-card[data-background="heavy"] {
  @apply bg-background-tint-01;
}

/* Border variants */
.opal-card[data-border="none"] {
  border: none;
}

.opal-card[data-border="dashed"] {
  @apply border border-dashed;
}

.opal-card[data-border="solid"] {
  @apply border;
}
@@ -1,51 +0,0 @@
import type { Meta, StoryObj } from "@storybook/react";
import { EmptyMessageCard } from "@opal/components";
import { SvgSparkle, SvgUsers } from "@opal/icons";

const SIZE_VARIANTS = ["lg", "md", "sm", "xs", "2xs", "fit"] as const;

const meta: Meta<typeof EmptyMessageCard> = {
  title: "opal/components/EmptyMessageCard",
  component: EmptyMessageCard,
  tags: ["autodocs"],
};

export default meta;
type Story = StoryObj<typeof EmptyMessageCard>;

export const Default: Story = {
  args: {
    title: "No items available.",
  },
};

export const WithCustomIcon: Story = {
  args: {
    icon: SvgSparkle,
    title: "No agents selected.",
  },
};

export const SizeVariants: Story = {
  render: () => (
    <div className="flex flex-col gap-4 w-96">
      {SIZE_VARIANTS.map((size) => (
        <EmptyMessageCard
          key={size}
          sizeVariant={size}
          title={`sizeVariant: ${size}`}
        />
      ))}
    </div>
  ),
};

export const Multiple: Story = {
  render: () => (
    <div className="flex flex-col gap-4 w-96">
      <EmptyMessageCard title="No models available." />
      <EmptyMessageCard icon={SvgSparkle} title="No agents selected." />
      <EmptyMessageCard icon={SvgUsers} title="No groups added." />
    </div>
  ),
};
@@ -1,30 +0,0 @@
# EmptyMessageCard

**Import:** `import { EmptyMessageCard, type EmptyMessageCardProps } from "@opal/components";`

A pre-configured Card for empty states. Renders a transparent card with a dashed border containing a muted icon and message text using the `Content` layout.

## Props

| Prop | Type | Default | Description |
| ------------- | -------------------------- | ---------- | ------------------------------------------------ |
| `icon` | `IconFunctionComponent` | `SvgEmpty` | Icon displayed alongside the title |
| `title` | `string` | — | Primary message text (required) |
| `sizeVariant` | `SizeVariant` | `"lg"` | Size preset controlling padding and rounding |
| `ref` | `React.Ref<HTMLDivElement>` | — | Ref forwarded to the root div |

## Usage

```tsx
import { EmptyMessageCard } from "@opal/components";
import { SvgSparkle, SvgFileText } from "@opal/icons";

// Default empty state
<EmptyMessageCard title="No items yet." />

// With custom icon
<EmptyMessageCard icon={SvgSparkle} title="No agents selected." />

// With custom size
<EmptyMessageCard sizeVariant="sm" icon={SvgFileText} title="No documents available." />
```
@@ -1,57 +0,0 @@
import { Card } from "@opal/components/cards/card/components";
import { Content } from "@opal/layouts";
import { SvgEmpty } from "@opal/icons";
import type { SizeVariant } from "@opal/shared";
import type { IconFunctionComponent } from "@opal/types";

// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------

type EmptyMessageCardProps = {
  /** Icon displayed alongside the title. */
  icon?: IconFunctionComponent;

  /** Primary message text. */
  title: string;

  /** Size preset controlling padding and rounding of the card. */
  sizeVariant?: SizeVariant;

  /** Ref forwarded to the root Card div. */
  ref?: React.Ref<HTMLDivElement>;
};

// ---------------------------------------------------------------------------
// EmptyMessageCard
// ---------------------------------------------------------------------------

function EmptyMessageCard({
  icon = SvgEmpty,
  title,
  sizeVariant = "lg",
  ref,
}: EmptyMessageCardProps) {
  return (
    <Card
      ref={ref}
      backgroundVariant="none"
      borderVariant="dashed"
      sizeVariant={sizeVariant}
    >
      <Content
        icon={icon}
        title={title}
        sizePreset="secondary"
        variant="body"
        prominence="muted"
      />
    </Card>
  );
}

// ---------------------------------------------------------------------------
// Exports
// ---------------------------------------------------------------------------

export { EmptyMessageCard, type EmptyMessageCardProps };
@@ -31,17 +31,3 @@ export {
  type TagProps,
  type TagColor,
} from "@opal/components/tag/components";

/* Card */
export {
  Card,
  type CardProps,
  type BackgroundVariant,
  type BorderVariant,
} from "@opal/components/cards/card/components";

/* EmptyMessageCard */
export {
  EmptyMessageCard,
  type EmptyMessageCardProps,
} from "@opal/components/cards/empty-message-card/components";

@@ -1,20 +0,0 @@
import type { IconProps } from "@opal/types";
const SvgAudio = ({ size, ...props }: IconProps) => (
  <svg
    width={size}
    height={size}
    viewBox="0 0 16 16"
    fill="none"
    xmlns="http://www.w3.org/2000/svg"
    stroke="currentColor"
    {...props}
  >
    <path
      d="M2 10V6M5 14V2M11 11V5M14 9V7M8 10V6"
      strokeWidth={1.5}
      strokeLinecap="round"
      strokeLinejoin="round"
    />
  </svg>
);
export default SvgAudio;
@@ -17,7 +17,6 @@ export { default as SvgArrowUpDown } from "@opal/icons/arrow-up-down";
export { default as SvgArrowUpDot } from "@opal/icons/arrow-up-dot";
export { default as SvgArrowUpRight } from "@opal/icons/arrow-up-right";
export { default as SvgArrowWallRight } from "@opal/icons/arrow-wall-right";
export { default as SvgAudio } from "@opal/icons/audio";
export { default as SvgAudioEqSmall } from "@opal/icons/audio-eq-small";
export { default as SvgAws } from "@opal/icons/aws";
export { default as SvgAzure } from "@opal/icons/azure";
@@ -107,8 +106,6 @@ export { default as SvgLogOut } from "@opal/icons/log-out";
export { default as SvgMaximize2 } from "@opal/icons/maximize-2";
export { default as SvgMcp } from "@opal/icons/mcp";
export { default as SvgMenu } from "@opal/icons/menu";
export { default as SvgMicrophone } from "@opal/icons/microphone";
export { default as SvgMicrophoneOff } from "@opal/icons/microphone-off";
export { default as SvgMinus } from "@opal/icons/minus";
export { default as SvgMinusCircle } from "@opal/icons/minus-circle";
export { default as SvgMoon } from "@opal/icons/moon";
@@ -179,8 +176,6 @@ export { default as SvgUserManage } from "@opal/icons/user-manage";
export { default as SvgUserPlus } from "@opal/icons/user-plus";
export { default as SvgUserSync } from "@opal/icons/user-sync";
export { default as SvgUsers } from "@opal/icons/users";
export { default as SvgVolume } from "@opal/icons/volume";
export { default as SvgVolumeOff } from "@opal/icons/volume-off";
export { default as SvgWallet } from "@opal/icons/wallet";
export { default as SvgWorkflow } from "@opal/icons/workflow";
export { default as SvgX } from "@opal/icons/x";

@@ -1,29 +0,0 @@
import type { IconProps } from "@opal/types";

const SvgMicrophoneOff = ({ size, ...props }: IconProps) => (
  <svg
    width={size}
    height={size}
    viewBox="0 0 16 16"
    fill="none"
    xmlns="http://www.w3.org/2000/svg"
    stroke="currentColor"
    {...props}
  >
    {/* Microphone body */}
    <path
      d="M12.5 7V7.5C12.5 9.98528 10.4853 12 8 12M3.5 7V7.5C3.5 9.98528 5.51472 12 8 12M8 12V14.5M8 14.5H5M8 14.5H11M8 9.5C6.89543 9.5 6 8.60457 6 7.5V3.5C6 2.39543 6.89543 1.5 8 1.5C9.10457 1.5 10 2.39543 10 3.5V7.5C10 8.60457 9.10457 9.5 8 9.5Z"
      strokeWidth={1.5}
      strokeLinecap="round"
      strokeLinejoin="round"
    />
    {/* Diagonal slash */}
    <path
      d="M2 2L14 14"
      strokeWidth={1.5}
      strokeLinecap="round"
      strokeLinejoin="round"
    />
  </svg>
);
export default SvgMicrophoneOff;
@@ -1,21 +0,0 @@
import type { IconProps } from "@opal/types";

const SvgMicrophone = ({ size, ...props }: IconProps) => (
  <svg
    width={size}
    height={size}
    viewBox="0 0 16 16"
    fill="none"
    xmlns="http://www.w3.org/2000/svg"
    stroke="currentColor"
    {...props}
  >
    <path
      d="M12.5 7V7.5C12.5 9.98528 10.4853 12 8 12M3.5 7V7.5C3.5 9.98528 5.51472 12 8 12M8 12V14.5M8 14.5H5M8 14.5H11M8 9.5C6.89543 9.5 6 8.60457 6 7.5V3.5C6 2.39543 6.89543 1.5 8 1.5C9.10457 1.5 10 2.39543 10 3.5V7.5C10 8.60457 9.10457 9.5 8 9.5Z"
      strokeWidth={1.5}
      strokeLinecap="round"
      strokeLinejoin="round"
    />
  </svg>
);
export default SvgMicrophone;
@@ -1,26 +0,0 @@
import type { IconProps } from "@opal/types";
const SvgVolumeOff = ({ size, ...props }: IconProps) => (
  <svg
    width={size}
    height={size}
    viewBox="0 0 16 16"
    fill="none"
    xmlns="http://www.w3.org/2000/svg"
    stroke="currentColor"
    {...props}
  >
    <path
      d="M2 6V10H5L9 13V3L5 6H2Z"
      strokeWidth={1.5}
      strokeLinecap="round"
      strokeLinejoin="round"
    />
    <path
      d="M14 6L11 9M11 6L14 9"
      strokeWidth={1.5}
      strokeLinecap="round"
      strokeLinejoin="round"
    />
  </svg>
);
export default SvgVolumeOff;
@@ -1,26 +0,0 @@
import type { IconProps } from "@opal/types";
const SvgVolume = ({ size, ...props }: IconProps) => (
  <svg
    width={size}
    height={size}
    viewBox="0 0 16 16"
    fill="none"
    xmlns="http://www.w3.org/2000/svg"
    stroke="currentColor"
    {...props}
  >
    <path
      d="M2 6V10H5L9 13V3L5 6H2Z"
      strokeWidth={1.5}
      strokeLinecap="round"
      strokeLinejoin="round"
    />
    <path
      d="M11.5 5.5C12.3 6.3 12.8 7.4 12.8 8.5C12.8 9.6 12.3 10.7 11.5 11.5"
      strokeWidth={1.5}
      strokeLinecap="round"
      strokeLinejoin="round"
    />
  </svg>
);
export default SvgVolume;
web/lib/opal/src/layouts/content/BodyLayout.stories.tsx (new file, 87 lines)
@@ -0,0 +1,87 @@
import type { Meta, StoryObj } from "@storybook/react";
import { BodyLayout } from "./BodyLayout";
import { SvgSettings, SvgStar, SvgRefreshCw } from "@opal/icons";

const meta = {
  title: "Layouts/BodyLayout",
  component: BodyLayout,
  tags: ["autodocs"],
  parameters: {
    layout: "centered",
  },
} satisfies Meta<typeof BodyLayout>;

export default meta;

type Story = StoryObj<typeof meta>;

// ---------------------------------------------------------------------------
// Size presets
// ---------------------------------------------------------------------------

export const MainContent: Story = {
  args: {
    sizePreset: "main-content",
    title: "Last synced 2 minutes ago",
  },
};

export const MainUi: Story = {
  args: {
    sizePreset: "main-ui",
    title: "Document count: 1,234",
  },
};

export const Secondary: Story = {
  args: {
    sizePreset: "secondary",
    title: "Updated 5 min ago",
  },
};

// ---------------------------------------------------------------------------
// With icon
// ---------------------------------------------------------------------------

export const WithIcon: Story = {
  args: {
    sizePreset: "main-ui",
    title: "Settings",
    icon: SvgSettings,
  },
};

// ---------------------------------------------------------------------------
// Orientations
// ---------------------------------------------------------------------------

export const Vertical: Story = {
  args: {
    sizePreset: "main-ui",
    title: "Stacked layout",
    icon: SvgStar,
    orientation: "vertical",
  },
};

export const Reverse: Story = {
  args: {
    sizePreset: "main-ui",
    title: "Reverse layout",
    icon: SvgRefreshCw,
    orientation: "reverse",
  },
};

// ---------------------------------------------------------------------------
// Prominence
// ---------------------------------------------------------------------------

export const Muted: Story = {
  args: {
    sizePreset: "main-ui",
    title: "Muted body text",
    prominence: "muted",
  },
};
web/lib/opal/src/layouts/content/BodyLayout.tsx (new file, 134 lines)
@@ -0,0 +1,134 @@
"use client";

import type { IconFunctionComponent } from "@opal/types";
import { cn } from "@opal/utils";

// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------

type BodySizePreset = "main-content" | "main-ui" | "secondary";
type BodyOrientation = "vertical" | "inline" | "reverse";
type BodyProminence = "default" | "muted";

interface BodyPresetConfig {
  /** Icon width/height (CSS value). */
  iconSize: string;
  /** Tailwind padding class for the icon container. */
  iconContainerPadding: string;
  /** Tailwind font class for the title. */
  titleFont: string;
  /** Title line-height — also used as icon container min-height (CSS value). */
  lineHeight: string;
  /** Gap between icon container and title (CSS value). */
  gap: string;
}

/** Props for {@link BodyLayout}. Does not support editing or descriptions. */
interface BodyLayoutProps {
  /** Optional icon component. */
  icon?: IconFunctionComponent;

  /** Main title text (read-only — editing is not supported). */
  title: string;

  /** Size preset. Default: `"main-ui"`. */
  sizePreset?: BodySizePreset;

  /** Layout orientation. Default: `"inline"`. */
  orientation?: BodyOrientation;

  /** Title prominence. Default: `"default"`. */
  prominence?: BodyProminence;

  /** Ref forwarded to the root `<div>`. */
  ref?: React.Ref<HTMLDivElement>;
}

// ---------------------------------------------------------------------------
// Presets
// ---------------------------------------------------------------------------

const BODY_PRESETS: Record<BodySizePreset, BodyPresetConfig> = {
  "main-content": {
    iconSize: "1rem",
    iconContainerPadding: "p-1",
    titleFont: "font-main-content-body",
    lineHeight: "1.5rem",
    gap: "0.125rem",
  },
  "main-ui": {
    iconSize: "1rem",
    iconContainerPadding: "p-0.5",
    titleFont: "font-main-ui-action",
    lineHeight: "1.25rem",
    gap: "0.25rem",
  },
  secondary: {
    iconSize: "0.75rem",
    iconContainerPadding: "p-0.5",
    titleFont: "font-secondary-action",
    lineHeight: "1rem",
    gap: "0.125rem",
  },
};

// ---------------------------------------------------------------------------
// BodyLayout
// ---------------------------------------------------------------------------

function BodyLayout({
  icon: Icon,
  title,
  sizePreset = "main-ui",
  orientation = "inline",
  prominence = "default",
  ref,
}: BodyLayoutProps) {
  const config = BODY_PRESETS[sizePreset];
  const titleColorClass =
    prominence === "muted" ? "text-text-03" : "text-text-04";

  return (
    <div
      ref={ref}
      className="opal-content-body"
      data-orientation={orientation}
      style={{ gap: config.gap }}
    >
      {Icon && (
        <div
          className={cn(
            "opal-content-body-icon-container shrink-0",
            config.iconContainerPadding
          )}
          style={{ minHeight: config.lineHeight }}
        >
          <Icon
            className="opal-content-body-icon text-text-03"
            style={{ width: config.iconSize, height: config.iconSize }}
          />
        </div>
      )}

      <span
        className={cn(
          "opal-content-body-title",
          config.titleFont,
          titleColorClass
        )}
        style={{ height: config.lineHeight }}
      >
        {title}
      </span>
    </div>
  );
}

export {
  BodyLayout,
  type BodyLayoutProps,
  type BodySizePreset,
  type BodyOrientation,
  type BodyProminence,
};
web/lib/opal/src/layouts/content/HeadingLayout.stories.tsx (new file, 98 lines)
@@ -0,0 +1,98 @@
import type { Meta, StoryObj } from "@storybook/react";
import { HeadingLayout } from "./HeadingLayout";
import { SvgSettings, SvgStar } from "@opal/icons";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";

const meta = {
  title: "Layouts/HeadingLayout",
  component: HeadingLayout,
  tags: ["autodocs"],
  parameters: {
    layout: "centered",
  },
  decorators: [
    (Story) => (
      <TooltipPrimitive.Provider>
        <Story />
      </TooltipPrimitive.Provider>
    ),
  ],
} satisfies Meta<typeof HeadingLayout>;

export default meta;

type Story = StoryObj<typeof meta>;

// ---------------------------------------------------------------------------
// Size presets
// ---------------------------------------------------------------------------

export const Headline: Story = {
  args: {
    sizePreset: "headline",
    title: "Welcome to Onyx",
    description: "Your enterprise search and AI assistant platform.",
  },
};

export const Section: Story = {
  args: {
    sizePreset: "section",
    title: "Configuration",
  },
};

// ---------------------------------------------------------------------------
// With icon
// ---------------------------------------------------------------------------

export const WithIcon: Story = {
  args: {
    sizePreset: "headline",
    title: "Settings",
    icon: SvgSettings,
  },
};

export const SectionWithIcon: Story = {
  args: {
    sizePreset: "section",
    variant: "section",
    title: "Favorites",
    icon: SvgStar,
  },
};

// ---------------------------------------------------------------------------
// Variants
// ---------------------------------------------------------------------------

export const SectionVariant: Story = {
  args: {
    sizePreset: "headline",
    variant: "section",
    title: "Inline Icon Heading",
    icon: SvgSettings,
  },
};

// ---------------------------------------------------------------------------
// Editable
// ---------------------------------------------------------------------------

export const Editable: Story = {
  args: {
    sizePreset: "headline",
    title: "Click to edit me",
    editable: true,
  },
};

export const EditableSection: Story = {
  args: {
    sizePreset: "section",
    title: "Editable Section Title",
    editable: true,
    description: "This title can be edited inline.",
  },
};
web/lib/opal/src/layouts/content/HeadingLayout.tsx (new file, 218 lines)
@@ -0,0 +1,218 @@
"use client";

import { Button } from "@opal/components/buttons/button/components";
import type { SizeVariant } from "@opal/shared";
import SvgEdit from "@opal/icons/edit";
import type { IconFunctionComponent } from "@opal/types";
import { cn } from "@opal/utils";
import { useRef, useState } from "react";

// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------

type HeadingSizePreset = "headline" | "section";
type HeadingVariant = "heading" | "section";

interface HeadingPresetConfig {
  /** Icon width/height (CSS value). */
  iconSize: string;
  /** Tailwind padding class for the icon container. */
  iconContainerPadding: string;
  /** Gap between icon container and content (CSS value). */
  gap: string;
  /** Tailwind font class for the title. */
  titleFont: string;
  /** Title line-height — also used as icon container min-height (CSS value). */
  lineHeight: string;
  /** Button `size` prop for the edit button. Uses the shared `SizeVariant` scale. */
  editButtonSize: SizeVariant;
  /** Tailwind padding class for the edit button container. */
  editButtonPadding: string;
}

interface HeadingLayoutProps {
  /** Optional icon component. */
  icon?: IconFunctionComponent;

  /** Main title text. */
  title: string;

  /** Optional description below the title. */
  description?: string;

  /** Enable inline editing of the title. */
  editable?: boolean;

  /** Called when the user commits an edit. */
  onTitleChange?: (newTitle: string) => void;

  /** Size preset. Default: `"headline"`. */
  sizePreset?: HeadingSizePreset;

  /** Variant controls icon placement. `"heading"` = top, `"section"` = inline. Default: `"heading"`. */
  variant?: HeadingVariant;

  /** Ref forwarded to the root `<div>`. */
  ref?: React.Ref<HTMLDivElement>;
}

// ---------------------------------------------------------------------------
// Presets
// ---------------------------------------------------------------------------

const HEADING_PRESETS: Record<HeadingSizePreset, HeadingPresetConfig> = {
  headline: {
    iconSize: "2rem",
    iconContainerPadding: "p-0.5",
    gap: "0.25rem",
    titleFont: "font-heading-h2",
    lineHeight: "2.25rem",
    editButtonSize: "md",
    editButtonPadding: "p-1",
  },
  section: {
    iconSize: "1.25rem",
    iconContainerPadding: "p-1",
    gap: "0rem",
    titleFont: "font-heading-h3",
    lineHeight: "1.75rem",
    editButtonSize: "sm",
    editButtonPadding: "p-0.5",
  },
};

// ---------------------------------------------------------------------------
// HeadingLayout
// ---------------------------------------------------------------------------

function HeadingLayout({
  sizePreset = "headline",
  variant = "heading",
  icon: Icon,
  title,
  description,
  editable,
  onTitleChange,
  ref,
}: HeadingLayoutProps) {
  const [editing, setEditing] = useState(false);
  const [editValue, setEditValue] = useState(title);
  const inputRef = useRef<HTMLInputElement>(null);

  const config = HEADING_PRESETS[sizePreset];
  const iconPlacement = variant === "heading" ? "top" : "left";

  function startEditing() {
    setEditValue(title);
    setEditing(true);
  }

  function commit() {
    const value = editValue.trim();
    if (value && value !== title) onTitleChange?.(value);
    setEditing(false);
  }

  return (
    <div
      ref={ref}
      className="opal-content-heading"
      data-icon-placement={iconPlacement}
      style={{ gap: iconPlacement === "left" ? config.gap : undefined }}
    >
      {Icon && (
        <div
          className={cn(
            "opal-content-heading-icon-container shrink-0",
            config.iconContainerPadding
          )}
          style={{ minHeight: config.lineHeight }}
        >
          <Icon
            className="opal-content-heading-icon"
            style={{ width: config.iconSize, height: config.iconSize }}
          />
        </div>
      )}

      <div className="opal-content-heading-body">
        <div className="opal-content-heading-title-row">
          {editing ? (
            <div className="opal-content-heading-input-sizer">
              <span
                className={cn(
                  "opal-content-heading-input-mirror",
                  config.titleFont
                )}
              >
                {editValue || "\u00A0"}
              </span>
              <input
                ref={inputRef}
                className={cn(
                  "opal-content-heading-input",
                  config.titleFont,
                  "text-text-04"
                )}
                value={editValue}
                onChange={(e) => setEditValue(e.target.value)}
                size={1}
                autoFocus
                onFocus={(e) => e.currentTarget.select()}
                onBlur={commit}
                onKeyDown={(e) => {
                  if (e.key === "Enter") commit();
                  if (e.key === "Escape") {
                    setEditValue(title);
                    setEditing(false);
                  }
                }}
                style={{ height: config.lineHeight }}
              />
            </div>
          ) : (
            <span
              className={cn(
                "opal-content-heading-title",
                config.titleFont,
                "text-text-04",
                editable && "cursor-pointer"
              )}
              onClick={editable ? startEditing : undefined}
              style={{ height: config.lineHeight }}
            >
              {title}
            </span>
          )}

          {editable && !editing && (
            <div
              className={cn(
                "opal-content-heading-edit-button",
                config.editButtonPadding
              )}
            >
              <Button
                icon={SvgEdit}
                prominence="internal"
                size={config.editButtonSize}
                tooltip="Edit"
                tooltipSide="right"
                onClick={startEditing}
              />
            </div>
          )}
        </div>

        {description && (
          <div className="opal-content-heading-description font-secondary-body text-text-03">
            {description}
          </div>
        )}
      </div>
    </div>
  );
}

export { HeadingLayout, type HeadingLayoutProps, type HeadingSizePreset };
web/lib/opal/src/layouts/content/LabelLayout.stories.tsx (new file, 154 lines)
@@ -0,0 +1,154 @@
import type { Meta, StoryObj } from "@storybook/react";
import { LabelLayout } from "./LabelLayout";
import { SvgSettings, SvgStar } from "@opal/icons";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";

const meta = {
  title: "Layouts/LabelLayout",
  component: LabelLayout,
  tags: ["autodocs"],
  parameters: {
    layout: "centered",
  },
  decorators: [
    (Story) => (
      <TooltipPrimitive.Provider>
        <Story />
      </TooltipPrimitive.Provider>
    ),
  ],
} satisfies Meta<typeof LabelLayout>;

export default meta;

type Story = StoryObj<typeof meta>;

// ---------------------------------------------------------------------------
// Size presets
// ---------------------------------------------------------------------------

export const MainContent: Story = {
  args: {
    sizePreset: "main-content",
    title: "Display Name",
  },
};

export const MainUi: Story = {
  args: {
    sizePreset: "main-ui",
    title: "Email Address",
  },
};

export const SecondaryPreset: Story = {
  args: {
    sizePreset: "secondary",
    title: "API Key",
  },
};

// ---------------------------------------------------------------------------
// With description
// ---------------------------------------------------------------------------

export const WithDescription: Story = {
  args: {
    sizePreset: "main-content",
    title: "Workspace Name",
    description: "The name displayed across your organization.",
  },
};

// ---------------------------------------------------------------------------
// With icon
// ---------------------------------------------------------------------------

export const WithIcon: Story = {
  args: {
    sizePreset: "main-ui",
    title: "Settings",
    icon: SvgSettings,
  },
};

// ---------------------------------------------------------------------------
// Optional
// ---------------------------------------------------------------------------

export const Optional: Story = {
  args: {
    sizePreset: "main-content",
    title: "Phone Number",
    optional: true,
  },
};

// ---------------------------------------------------------------------------
// Aux icons
// ---------------------------------------------------------------------------

export const AuxInfoGray: Story = {
  args: {
    sizePreset: "main-content",
    title: "Connection Status",
    auxIcon: "info-gray",
  },
};

export const AuxWarning: Story = {
  args: {
    sizePreset: "main-content",
    title: "Rate Limit",
    auxIcon: "warning",
  },
};

export const AuxError: Story = {
  args: {
    sizePreset: "main-content",
    title: "API Key",
    auxIcon: "error",
  },
};

// ---------------------------------------------------------------------------
// With tag
// ---------------------------------------------------------------------------

export const WithTag: Story = {
  args: {
    sizePreset: "main-ui",
    title: "Knowledge Graph",
    tag: { title: "Beta", color: "blue" },
  },
};

// ---------------------------------------------------------------------------
// Editable
// ---------------------------------------------------------------------------

export const Editable: Story = {
  args: {
    sizePreset: "main-ui",
    title: "Click to edit",
    editable: true,
  },
};

// ---------------------------------------------------------------------------
// Combined
// ---------------------------------------------------------------------------

export const FullFeatured: Story = {
  args: {
    sizePreset: "main-content",
    title: "Custom Field",
    icon: SvgStar,
    description: "A custom field with all extras enabled.",
    optional: true,
    auxIcon: "info-blue",
    tag: { title: "New", color: "green" },
    editable: true,
  },
};
286
web/lib/opal/src/layouts/content/LabelLayout.tsx
Normal file
@@ -0,0 +1,286 @@
"use client";

import { Button } from "@opal/components/buttons/button/components";
import { Tag, type TagProps } from "@opal/components/tag/components";
import type { SizeVariant } from "@opal/shared";
import SvgAlertCircle from "@opal/icons/alert-circle";
import SvgAlertTriangle from "@opal/icons/alert-triangle";
import SvgEdit from "@opal/icons/edit";
import SvgXOctagon from "@opal/icons/x-octagon";
import type { IconFunctionComponent } from "@opal/types";
import { cn } from "@opal/utils";
import { useRef, useState } from "react";

// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------

type LabelSizePreset = "main-content" | "main-ui" | "secondary";

type LabelAuxIcon = "info-gray" | "info-blue" | "warning" | "error";

interface LabelPresetConfig {
  iconSize: string;
  iconContainerPadding: string;
  iconColorClass: string;
  titleFont: string;
  lineHeight: string;
  gap: string;
  /** Button `size` prop for the edit button. Uses the shared `SizeVariant` scale. */
  editButtonSize: SizeVariant;
  editButtonPadding: string;
  optionalFont: string;
  /** Aux icon size = lineHeight − 2 × p-0.5. */
  auxIconSize: string;
}

interface LabelLayoutProps {
  /** Optional icon component. */
  icon?: IconFunctionComponent;

  /** Main title text. */
  title: string;

  /** Optional description text below the title. */
  description?: string;

  /** Enable inline editing of the title. */
  editable?: boolean;

  /** Called when the user commits an edit. */
  onTitleChange?: (newTitle: string) => void;

  /** When `true`, renders "(Optional)" beside the title. */
  optional?: boolean;

  /** Auxiliary status icon rendered beside the title. */
  auxIcon?: LabelAuxIcon;

  /** Tag rendered beside the title. */
  tag?: TagProps;

  /** Size preset. Default: `"main-ui"`. */
  sizePreset?: LabelSizePreset;

  /** Ref forwarded to the root `<div>`. */
  ref?: React.Ref<HTMLDivElement>;
}

// ---------------------------------------------------------------------------
// Presets
// ---------------------------------------------------------------------------

const LABEL_PRESETS: Record<LabelSizePreset, LabelPresetConfig> = {
  "main-content": {
    iconSize: "1rem",
    iconContainerPadding: "p-1",
    iconColorClass: "text-text-04",
    titleFont: "font-main-content-emphasis",
    lineHeight: "1.5rem",
    gap: "0.125rem",
    editButtonSize: "sm",
    editButtonPadding: "p-0",
    optionalFont: "font-main-content-muted",
    auxIconSize: "1.25rem",
  },
  "main-ui": {
    iconSize: "1rem",
    iconContainerPadding: "p-0.5",
    iconColorClass: "text-text-03",
    titleFont: "font-main-ui-action",
    lineHeight: "1.25rem",
    gap: "0.25rem",
    editButtonSize: "xs",
    editButtonPadding: "p-0",
    optionalFont: "font-main-ui-muted",
    auxIconSize: "1rem",
  },
  secondary: {
    iconSize: "0.75rem",
    iconContainerPadding: "p-0.5",
    iconColorClass: "text-text-04",
    titleFont: "font-secondary-action",
    lineHeight: "1rem",
    gap: "0.125rem",
    editButtonSize: "2xs",
    editButtonPadding: "p-0",
    optionalFont: "font-secondary-action",
    auxIconSize: "0.75rem",
  },
};

// ---------------------------------------------------------------------------
// LabelLayout
// ---------------------------------------------------------------------------

const AUX_ICON_CONFIG: Record<
  LabelAuxIcon,
  { icon: IconFunctionComponent; colorClass: string }
> = {
  "info-gray": { icon: SvgAlertCircle, colorClass: "text-text-02" },
  "info-blue": { icon: SvgAlertCircle, colorClass: "text-status-info-05" },
  warning: { icon: SvgAlertTriangle, colorClass: "text-status-warning-05" },
  error: { icon: SvgXOctagon, colorClass: "text-status-error-05" },
};

function LabelLayout({
  icon: Icon,
  title,
  description,
  editable,
  onTitleChange,
  optional,
  auxIcon,
  tag,
  sizePreset = "main-ui",
  ref,
}: LabelLayoutProps) {
  const [editing, setEditing] = useState(false);
  const [editValue, setEditValue] = useState(title);
  const inputRef = useRef<HTMLInputElement>(null);

  const config = LABEL_PRESETS[sizePreset];

  function startEditing() {
    setEditValue(title);
    setEditing(true);
  }

  function commit() {
    const value = editValue.trim();
    if (value && value !== title) onTitleChange?.(value);
    setEditing(false);
  }

  return (
    <div ref={ref} className="opal-content-label" style={{ gap: config.gap }}>
      {Icon && (
        <div
          className={cn(
            "opal-content-label-icon-container shrink-0",
            config.iconContainerPadding
          )}
          style={{ minHeight: config.lineHeight }}
        >
          <Icon
            className={cn("opal-content-label-icon", config.iconColorClass)}
            style={{ width: config.iconSize, height: config.iconSize }}
          />
        </div>
      )}

      <div className="opal-content-label-body">
        <div className="opal-content-label-title-row">
          {editing ? (
            <div className="opal-content-label-input-sizer">
              <span
                className={cn(
                  "opal-content-label-input-mirror",
                  config.titleFont
                )}
              >
                {editValue || "\u00A0"}
              </span>
              <input
                ref={inputRef}
                className={cn(
                  "opal-content-label-input",
                  config.titleFont,
                  "text-text-04"
                )}
                value={editValue}
                onChange={(e) => setEditValue(e.target.value)}
                size={1}
                autoFocus
                onFocus={(e) => e.currentTarget.select()}
                onBlur={commit}
                onKeyDown={(e) => {
                  if (e.key === "Enter") commit();
                  if (e.key === "Escape") {
                    setEditValue(title);
                    setEditing(false);
                  }
                }}
                style={{ height: config.lineHeight }}
              />
            </div>
          ) : (
            <span
              className={cn(
                "opal-content-label-title",
                config.titleFont,
                "text-text-04",
                editable && "cursor-pointer"
              )}
              onClick={editable ? startEditing : undefined}
              style={{ height: config.lineHeight }}
            >
              {title}
            </span>
          )}

          {optional && (
            <span
              className={cn(config.optionalFont, "text-text-03 shrink-0")}
              style={{ height: config.lineHeight }}
            >
              (Optional)
            </span>
          )}

          {auxIcon &&
            (() => {
              const { icon: AuxIcon, colorClass } = AUX_ICON_CONFIG[auxIcon];
              return (
                <div
                  className="opal-content-label-aux-icon shrink-0 p-0.5"
                  style={{ height: config.lineHeight }}
                >
                  <AuxIcon
                    className={colorClass}
                    style={{
                      width: config.auxIconSize,
                      height: config.auxIconSize,
                    }}
                  />
                </div>
              );
            })()}

          {tag && <Tag {...tag} />}

          {editable && !editing && (
            <div
              className={cn(
                "opal-content-label-edit-button",
                config.editButtonPadding
              )}
            >
              <Button
                icon={SvgEdit}
                prominence="internal"
                size={config.editButtonSize}
                tooltip="Edit"
                tooltipSide="right"
                onClick={startEditing}
              />
            </div>
          )}
        </div>

        {description && (
          <div className="opal-content-label-description font-secondary-body text-text-03">
            {description}
          </div>
        )}
      </div>
    </div>
  );
}

export {
  LabelLayout,
  type LabelLayoutProps,
  type LabelSizePreset,
  type LabelAuxIcon,
};
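A minimal usage sketch of the component above. The import alias and render context are assumptions (the file lives at `web/lib/opal/src/layouts/content/LabelLayout.tsx`; the real export path may differ); prop names and values come straight from `LabelLayoutProps` and the stories earlier in this diff:

```tsx
// Sketch only - the "@opal/layouts" alias is hypothetical; adjust to the
// package's actual export path. Props mirror LabelLayoutProps above.
import { LabelLayout } from "@opal/layouts/content/LabelLayout";
import SvgSettings from "@opal/icons/settings";

export function WorkspaceNameLabel() {
  return (
    <LabelLayout
      sizePreset="main-content"
      icon={SvgSettings}
      title="Workspace Name"
      description="The name displayed across your organization."
      optional
      auxIcon="info-blue"
      tag={{ title: "Beta", color: "blue" }}
      editable
      // commit() above fires this only when the trimmed value is non-empty
      // and actually changed.
      onTitleChange={(next) => console.log("title committed:", next)}
    />
  );
}
```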
@@ -59,7 +59,7 @@ const nextConfig = {
          {
            key: "Permissions-Policy",
            value:
              "accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), cross-origin-isolated=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(self), midi=(), navigation-override=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()",
              "accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), cross-origin-isolated=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), navigation-override=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()",
          },
        ],
      },
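The only change in the hunk above is the microphone directive: `microphone=(self)` granted same-origin pages microphone access (needed for voice input), while `microphone=()` denies it everywhere. A client-side sketch of what that denial looks like; the Permissions API call is real, though browser support varies and this snippet is not from the diff:

```ts
// Sketch: with "microphone=()" in Permissions-Policy, this query resolves to
// "denied" on every page, and getUserMedia({ audio: true }) will reject.
async function canUseMicrophone(): Promise<boolean> {
  const status = await navigator.permissions.query({
    // "microphone" is not in TypeScript's default PermissionName union,
    // hence the cast.
    name: "microphone" as PermissionName,
  });
  return status.state !== "denied";
}
```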
@@ -1,4 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M10.5 2H13V14H10.5V2Z" fill="currentColor"/>
<path d="M3 2H5.5V14H3V2Z" fill="currentColor"/>
</svg>

Before Width: | Height: | Size: 206 B |
@@ -1,4 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M10.5 2H13V14H10.5V2Z" fill="white"/>
<path d="M3 2H5.5V14H3V2Z" fill="white"/>
</svg>

Before Width: | Height: | Size: 192 B |
@@ -14,7 +14,6 @@ import {
  QwenIcon,
  OllamaIcon,
  LMStudioIcon,
  LiteLLMIcon,
  ZAIIcon,
} from "@/components/icons/icons";
import {
@@ -22,14 +21,12 @@ import {
  OpenRouterModelResponse,
  BedrockModelResponse,
  LMStudioModelResponse,
  LiteLLMProxyModelResponse,
  ModelConfiguration,
  LLMProviderName,
  BedrockFetchParams,
  OllamaFetchParams,
  LMStudioFetchParams,
  OpenRouterFetchParams,
  LiteLLMProxyFetchParams,
} from "@/interfaces/llm";
import { SvgAws, SvgOpenrouter } from "@opal/icons";

@@ -40,7 +37,6 @@ export const AGGREGATOR_PROVIDERS = new Set([
  "openrouter",
  "ollama_chat",
  "lm_studio",
  "litellm_proxy",
  "vertex_ai",
]);

@@ -77,7 +73,6 @@ export const getProviderIcon = (
    bedrock: SvgAws,
    bedrock_converse: SvgAws,
    openrouter: SvgOpenrouter,
    litellm_proxy: LiteLLMIcon,
    vertex_ai: GeminiIcon,
  };

@@ -343,65 +338,6 @@ export const fetchLMStudioModels = async (
  }
};

/**
 * Fetches LiteLLM Proxy models directly without any form state dependencies.
 * Uses snake_case params to match API structure.
 */
export const fetchLiteLLMProxyModels = async (
  params: LiteLLMProxyFetchParams
): Promise<{ models: ModelConfiguration[]; error?: string }> => {
  const apiBase = params.api_base;
  const apiKey = params.api_key;
  if (!apiBase) {
    return { models: [], error: "API Base is required" };
  }
  if (!apiKey) {
    return { models: [], error: "API Key is required" };
  }

  try {
    const response = await fetch("/api/admin/llm/litellm/available-models", {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        api_base: apiBase,
        api_key: apiKey,
        provider_name: params.provider_name,
      }),
      signal: params.signal,
    });

    if (!response.ok) {
      let errorMessage = "Failed to fetch models";
      try {
        const errorData = await response.json();
        errorMessage = errorData.detail || errorData.message || errorMessage;
      } catch {
        // ignore JSON parsing errors
      }
      return { models: [], error: errorMessage };
    }

    const data: LiteLLMProxyModelResponse[] = await response.json();
    const models: ModelConfiguration[] = data.map((modelData) => ({
      name: modelData.model_name,
      display_name: modelData.model_name,
      is_visible: true,
      max_input_tokens: null,
      supports_image_input: false,
      supports_reasoning: false,
    }));

    return { models };
  } catch (error) {
    const errorMessage =
      error instanceof Error ? error.message : "Unknown error";
    return { models: [], error: errorMessage };
  }
};

/**
 * Fetches models for a provider. Accepts form values directly and maps them
 * to the expected fetch params format internally.
@@ -449,13 +385,6 @@ export const fetchModels = async (
        api_key: formValues.api_key,
        provider_name: formValues.name,
      });
    case LLMProviderName.LITELLM_PROXY:
      return fetchLiteLLMProxyModels({
        api_base: formValues.api_base,
        api_key: formValues.api_key,
        provider_name: formValues.name,
        signal,
      });
    default:
      return { models: [], error: `Unknown provider: ${providerName}` };
  }

@@ -468,7 +397,6 @@ export function canProviderFetchModels(providerName?: string) {
    case LLMProviderName.OLLAMA_CHAT:
    case LLMProviderName.LM_STUDIO:
    case LLMProviderName.OPENROUTER:
    case LLMProviderName.LITELLM_PROXY:
      return true;
    default:
      return false;
@@ -1,507 +0,0 @@
"use client";

import Image from "next/image";
import { FunctionComponent, useState, useEffect } from "react";
import {
  AzureIcon,
  ElevenLabsIcon,
  OpenAIIcon,
} from "@/components/icons/icons";
import Modal from "@/refresh-components/Modal";
import Button from "@/refresh-components/buttons/Button";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
import { Vertical, Horizontal } from "@/layouts/input-layouts";
import { toast } from "@/hooks/useToast";
import { Section } from "@/layouts/general-layouts";
import { SvgArrowExchange, SvgOnyxLogo } from "@opal/icons";
import type { IconProps } from "@opal/types";
import { VoiceProviderView } from "@/hooks/useVoiceProviders";
import {
  testVoiceProvider,
  upsertVoiceProvider,
  fetchVoicesByType,
  fetchLLMProviders,
} from "@/lib/admin/voice/svc";

interface VoiceOption {
  value: string;
  label: string;
  description?: string;
}

interface LLMProviderView {
  id: number;
  name: string;
  provider: string;
  api_key: string | null;
}

interface ApiKeyOption {
  value: string;
  label: string;
  description?: string;
}

interface VoiceProviderSetupModalProps {
  providerType: string;
  existingProvider: VoiceProviderView | null;
  mode: "stt" | "tts";
  defaultModelId?: string | null;
  onClose: () => void;
  onSuccess: () => void;
}

const PROVIDER_LABELS: Record<string, string> = {
  openai: "OpenAI",
  azure: "Azure Speech Services",
  elevenlabs: "ElevenLabs",
};

const PROVIDER_API_KEY_URLS: Record<string, string> = {
  openai: "https://platform.openai.com/api-keys",
  azure: "https://portal.azure.com/",
  elevenlabs: "https://elevenlabs.io/app/settings/api-keys",
};

const PROVIDER_LOGO_URLS: Record<string, string> = {
  openai: "/Openai.svg",
  azure: "/Azure.png",
  elevenlabs: "/ElevenLabs.svg",
};

const PROVIDER_DOCS_URLS: Record<string, string> = {
  openai: "https://platform.openai.com/docs/guides/text-to-speech",
  azure: "https://learn.microsoft.com/en-us/azure/ai-services/speech-service/",
  elevenlabs: "https://elevenlabs.io/docs",
};

const PROVIDER_VOICE_DOCS_URLS: Record<string, { url: string; label: string }> =
  {
    openai: {
      url: "https://platform.openai.com/docs/guides/text-to-speech#voice-options",
      label: "OpenAI",
    },
    azure: {
      url: "https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts",
      label: "Azure",
    },
    elevenlabs: {
      url: "https://elevenlabs.io/docs/voices/premade-voices",
      label: "ElevenLabs",
    },
  };

const OPENAI_STT_MODELS = [{ id: "whisper-1", name: "Whisper v1" }];

const OPENAI_TTS_MODELS = [
  { id: "tts-1", name: "TTS-1" },
  { id: "tts-1-hd", name: "TTS-1 HD" },
];

// Map model IDs from cards to actual API model IDs
const MODEL_ID_MAP: Record<string, string> = {
  "tts-1": "tts-1",
  "tts-1-hd": "tts-1-hd",
  whisper: "whisper-1",
};

export default function VoiceProviderSetupModal({
  providerType,
  existingProvider,
  mode,
  defaultModelId,
  onClose,
  onSuccess,
}: VoiceProviderSetupModalProps) {
  // Map the card model ID to the actual API model ID
  // Prioritize defaultModelId (from the clicked card) over stored value
  const initialTtsModel = defaultModelId
    ? MODEL_ID_MAP[defaultModelId] ?? "tts-1"
    : existingProvider?.tts_model ?? "tts-1";

  const [apiKey, setApiKey] = useState("");
  const [apiKeyChanged, setApiKeyChanged] = useState(false);
  const [targetUri, setTargetUri] = useState(
    existingProvider?.target_uri ?? ""
  );
  const [selectedLlmProviderId, setSelectedLlmProviderId] = useState<
    number | null
  >(null);
  const [sttModel, setSttModel] = useState(
    existingProvider?.stt_model ?? "whisper-1"
  );
  const [ttsModel, setTtsModel] = useState(initialTtsModel);
  const [defaultVoice, setDefaultVoice] = useState(
    existingProvider?.default_voice ?? ""
  );
  const [isSubmitting, setIsSubmitting] = useState(false);

  // Dynamic voices fetched from backend
  const [voiceOptions, setVoiceOptions] = useState<VoiceOption[]>([]);
  const [isLoadingVoices, setIsLoadingVoices] = useState(false);

  // Existing OpenAI LLM providers for API key reuse
  const [existingApiKeyOptions, setExistingApiKeyOptions] = useState<
    ApiKeyOption[]
  >([]);
  const [llmProviderMap, setLlmProviderMap] = useState<Map<string, number>>(
    new Map()
  );

  // Fetch existing OpenAI LLM providers (for API key reuse)
  useEffect(() => {
    if (providerType !== "openai") return;

    fetchLLMProviders()
      .then((res) => res.json())
      .then((data: LLMProviderView[]) => {
        const openaiProviders = data.filter(
          (p) => p.provider === "openai" && p.api_key
        );
        const options: ApiKeyOption[] = openaiProviders.map((p) => ({
          value: p.api_key!,
          label: p.api_key!,
          description: `Used for LLM provider ${p.name}`,
        }));
        setExistingApiKeyOptions(options);

        // Map masked API keys to provider IDs for lookup on selection
        const providerMap = new Map<string, number>();
        openaiProviders.forEach((p) => {
          if (p.api_key) {
            providerMap.set(p.api_key, p.id);
          }
        });
        setLlmProviderMap(providerMap);
      })
      .catch(() => {
        setExistingApiKeyOptions([]);
      });
  }, [providerType]);

  // Fetch voices on mount (works without API key for ElevenLabs/OpenAI)
  useEffect(() => {
    setIsLoadingVoices(true);
    fetchVoicesByType(providerType)
      .then((res) => res.json())
      .then((data: Array<{ id: string; name: string }>) => {
        const options = data.map((v) => ({
          value: v.id,
          label: v.name,
          description: v.id,
        }));
        setVoiceOptions(options);
        // Set default voice to first option if not already set,
        // or if current value doesn't exist in the new options
        setDefaultVoice((prev) => {
          if (!prev) return options[0]?.value ?? "";
          const existsInOptions = options.some((opt) => opt.value === prev);
          return existsInOptions ? prev : options[0]?.value ?? "";
        });
      })
      .catch(() => {
        setVoiceOptions([]);
      })
      .finally(() => {
        setIsLoadingVoices(false);
      });
  }, [providerType]);

  const isEditing = !!existingProvider;
  const label = PROVIDER_LABELS[providerType] ?? providerType;

  // Logo arrangement component for the modal header
  // No useMemo needed - providerType and label are stable props
  const LogoArrangement: FunctionComponent<IconProps> = () => (
    <div className="flex items-center gap-2">
      <div className="flex items-center justify-center size-7 shrink-0 overflow-clip">
        {providerType === "openai" ? (
          <OpenAIIcon size={24} />
        ) : providerType === "azure" ? (
          <AzureIcon size={24} />
        ) : providerType === "elevenlabs" ? (
          <ElevenLabsIcon size={24} />
        ) : (
          <Image
            src={PROVIDER_LOGO_URLS[providerType] ?? "/Openai.svg"}
            alt={`${label} logo`}
            width={24}
            height={24}
            className="object-contain"
          />
        )}
      </div>
      <div className="flex items-center justify-center size-4 shrink-0">
        <SvgArrowExchange className="size-3 text-text-04" />
      </div>
      <div className="flex items-center justify-center size-7 p-0.5 shrink-0 overflow-clip">
        <SvgOnyxLogo size={24} className="text-text-04 shrink-0" />
      </div>
    </div>
  );

  const handleSubmit = async () => {
    // API key required for new providers, or when explicitly changed during edit
    if (!selectedLlmProviderId) {
      if (!isEditing && !apiKey) {
        toast.error("API key is required");
        return;
      }
      if (isEditing && apiKeyChanged && !apiKey) {
        toast.error(
          "API key cannot be empty. Leave blank to keep existing key."
        );
        return;
      }
    }

    if (providerType === "azure" && !isEditing && !targetUri) {
      toast.error("Target URI is required");
      return;
    }

    setIsSubmitting(true);
    try {
      // Test the connection first (skip if reusing LLM provider key - it's already validated)
      if (!selectedLlmProviderId) {
        const testResponse = await testVoiceProvider({
          provider_type: providerType,
          api_key: apiKeyChanged ? apiKey : undefined,
          target_uri: targetUri || undefined,
          use_stored_key: isEditing && !apiKeyChanged,
        });

        if (!testResponse.ok) {
          const data = await testResponse.json();
          toast.error(data.detail || "Connection test failed");
          setIsSubmitting(false);
          return;
        }
      }

      // Save the provider
      const response = await upsertVoiceProvider({
        id: existingProvider?.id,
        name: label,
        provider_type: providerType,
        api_key: selectedLlmProviderId
          ? undefined
          : apiKeyChanged
          ? apiKey
          : undefined,
        api_key_changed: selectedLlmProviderId ? false : apiKeyChanged,
        target_uri: targetUri || undefined,
        llm_provider_id: selectedLlmProviderId,
        stt_model: sttModel,
        tts_model: ttsModel,
        default_voice: defaultVoice,
        activate_stt: mode === "stt",
        activate_tts: mode === "tts",
      });

      if (response.ok) {
        toast.success(isEditing ? "Provider updated" : "Provider connected");
        onSuccess();
      } else {
        const data = await response.json();
        toast.error(data.detail || "Failed to save provider");
      }
    } catch {
      toast.error("Failed to save provider");
    } finally {
      setIsSubmitting(false);
    }
  };

  return (
    <Modal open onOpenChange={(isOpen) => !isOpen && onClose()}>
      <Modal.Content width="sm">
        <Modal.Header
          icon={LogoArrangement}
          title={isEditing ? `Edit ${label}` : `Set up ${label}`}
          description={`Connect to ${label} and set up your voice models.`}
          onClose={onClose}
        />
        <Modal.Body>
          <Section gap={1} alignItems="stretch">
            <Vertical
              title="API Key"
              subDescription={
                isEditing ? (
                  "Leave blank to keep existing key"
                ) : (
                  <>
                    Paste your{" "}
                    <a
                      href={PROVIDER_API_KEY_URLS[providerType]}
                      target="_blank"
                      rel="noopener noreferrer"
                      className="underline"
                    >
                      API key
                    </a>{" "}
                    from {label} to access your models.
                  </>
                )
              }
              nonInteractive
            >
              {providerType === "openai" && existingApiKeyOptions.length > 0 ? (
                <InputComboBox
                  placeholder={isEditing ? "••••••••" : "Enter API key"}
                  value={apiKey}
                  onChange={(e) => {
                    setApiKey(e.target.value);
                    setApiKeyChanged(true);
                    setSelectedLlmProviderId(null);
                  }}
                  onValueChange={(value) => {
                    setApiKey(value);
                    // Check if this is an existing key
                    const llmProviderId = llmProviderMap.get(value);
                    if (llmProviderId) {
                      setSelectedLlmProviderId(llmProviderId);
                      setApiKeyChanged(false);
                    } else {
                      setSelectedLlmProviderId(null);
                      setApiKeyChanged(true);
                    }
                  }}
                  options={existingApiKeyOptions}
                  separatorLabel="Reuse OpenAI API Keys"
                  strict={false}
                  showAddPrefix
                />
              ) : (
                <InputTypeIn
                  type="password"
                  placeholder={isEditing ? "••••••••" : "Enter API key"}
                  value={apiKey}
                  onChange={(e) => {
                    setApiKey(e.target.value);
                    setApiKeyChanged(true);
                  }}
                />
              )}
            </Vertical>

            {providerType === "azure" && (
              <Vertical
                title="Target URI"
                subDescription={
                  <>
                    Paste the endpoint shown in{" "}
                    <a
                      href="https://portal.azure.com/"
                      target="_blank"
                      rel="noopener noreferrer"
                      className="underline"
                    >
                      Azure Portal (Keys and Endpoint)
                    </a>
                    . Onyx extracts the speech region from this URL. Examples:
                    https://westus.api.cognitive.microsoft.com/ or
                    https://westus.tts.speech.microsoft.com/.
                  </>
                }
                nonInteractive
              >
                <InputTypeIn
                  placeholder={
                    isEditing
                      ? "Leave blank to keep existing"
                      : "https://<region>.api.cognitive.microsoft.com/"
                  }
                  value={targetUri}
                  onChange={(e) => setTargetUri(e.target.value)}
                />
              </Vertical>
            )}

            {providerType === "openai" && mode === "stt" && (
              <Horizontal title="STT Model" center nonInteractive>
                <InputSelect value={sttModel} onValueChange={setSttModel}>
                  <InputSelect.Trigger />
                  <InputSelect.Content>
                    {OPENAI_STT_MODELS.map((model) => (
                      <InputSelect.Item key={model.id} value={model.id}>
                        {model.name}
                      </InputSelect.Item>
                    ))}
                  </InputSelect.Content>
                </InputSelect>
              </Horizontal>
            )}

            {providerType === "openai" && mode === "tts" && (
              <Vertical
                title="Default Model"
                subDescription="This model will be used by Onyx by default for text-to-speech."
                nonInteractive
              >
                <InputSelect value={ttsModel} onValueChange={setTtsModel}>
                  <InputSelect.Trigger />
                  <InputSelect.Content>
                    {OPENAI_TTS_MODELS.map((model) => (
                      <InputSelect.Item key={model.id} value={model.id}>
                        {model.name}
                      </InputSelect.Item>
                    ))}
                  </InputSelect.Content>
                </InputSelect>
              </Vertical>
            )}

            {mode === "tts" && (
              <Vertical
                title="Voice"
                subDescription={
                  <>
                    This voice will be used for spoken responses. See full list
                    of supported languages and voices at{" "}
                    <a
                      href={
                        PROVIDER_VOICE_DOCS_URLS[providerType]?.url ??
                        PROVIDER_DOCS_URLS[providerType]
                      }
                      target="_blank"
                      rel="noopener noreferrer"
                      className="underline"
                    >
                      {PROVIDER_VOICE_DOCS_URLS[providerType]?.label ?? label}
                    </a>
                    .
                  </>
                }
                nonInteractive
              >
                <InputComboBox
                  value={defaultVoice}
                  onValueChange={setDefaultVoice}
                  options={voiceOptions}
                  placeholder={
                    isLoadingVoices
                      ? "Loading voices..."
                      : "Select a voice or enter voice ID"
                  }
                  disabled={isLoadingVoices}
                  strict={false}
                />
              </Vertical>
            )}
          </Section>
        </Modal.Body>
        <Modal.Footer>
          <Button secondary onClick={onClose}>
            Cancel
          </Button>
          <Button onClick={handleSubmit} disabled={isSubmitting}>
            {isSubmitting ? "Connecting..." : isEditing ? "Save" : "Connect"}
          </Button>
        </Modal.Footer>
      </Modal.Content>
    </Modal>
  );
}
@@ -1,630 +0,0 @@
"use client";

import Image from "next/image";
import { useMemo, useState } from "react";
import { AdminPageTitle } from "@/components/admin/Title";
import {
  AzureIcon,
  ElevenLabsIcon,
  InfoIcon,
  OpenAIIcon,
} from "@/components/icons/icons";
import Text from "@/refresh-components/texts/Text";
import Separator from "@/refresh-components/Separator";
import { FetchError } from "@/lib/fetcher";
import {
  useVoiceProviders,
  VoiceProviderView,
} from "@/hooks/useVoiceProviders";
import {
  activateVoiceProvider,
  deactivateVoiceProvider,
} from "@/lib/admin/voice/svc";
import { ThreeDotsLoader } from "@/components/Loading";
import { Callout } from "@/components/ui/callout";
import Button from "@/refresh-components/buttons/Button";
import { Button as OpalButton } from "@opal/components";
import { cn } from "@/lib/utils";
import {
  SvgArrowExchange,
  SvgArrowRightCircle,
  SvgAudio,
  SvgCheckSquare,
  SvgEdit,
  SvgMicrophone,
  SvgX,
} from "@opal/icons";
import VoiceProviderSetupModal from "./VoiceProviderSetupModal";

interface ModelDetails {
  id: string;
  label: string;
  subtitle: string;
  logoSrc?: string;
  providerType: string;
}

interface ProviderGroup {
  providerType: string;
  providerLabel: string;
  logoSrc?: string;
  models: ModelDetails[];
}

// STT Models - individual cards
const STT_MODELS: ModelDetails[] = [
  {
    id: "whisper",
    label: "Whisper",
    subtitle: "OpenAI's general purpose speech recognition model.",
    logoSrc: "/Openai.svg",
    providerType: "openai",
  },
  {
    id: "azure-speech-stt",
    label: "Azure Speech",
    subtitle: "Speech to text in Microsoft Foundry Tools.",
    logoSrc: "/Azure.png",
    providerType: "azure",
  },
  {
    id: "elevenlabs-stt",
    label: "ElevenAPI",
    subtitle: "ElevenLabs Speech to Text API.",
    logoSrc: "/ElevenLabs.svg",
    providerType: "elevenlabs",
  },
];

// TTS Models - grouped by provider
const TTS_PROVIDER_GROUPS: ProviderGroup[] = [
  {
    providerType: "openai",
    providerLabel: "OpenAI",
    logoSrc: "/Openai.svg",
    models: [
      {
        id: "tts-1",
        label: "TTS-1",
        subtitle: "OpenAI's text-to-speech model optimized for speed.",
        logoSrc: "/Openai.svg",
        providerType: "openai",
      },
      {
        id: "tts-1-hd",
        label: "TTS-1 HD",
        subtitle: "OpenAI's text-to-speech model optimized for quality.",
        logoSrc: "/Openai.svg",
        providerType: "openai",
      },
    ],
  },
  {
    providerType: "azure",
    providerLabel: "Azure",
    logoSrc: "/Azure.png",
    models: [
      {
        id: "azure-speech-tts",
        label: "Azure Speech",
        subtitle: "Text to speech in Microsoft Foundry Tools.",
        logoSrc: "/Azure.png",
        providerType: "azure",
      },
    ],
  },
  {
    providerType: "elevenlabs",
    providerLabel: "ElevenLabs",
    logoSrc: "/ElevenLabs.svg",
    models: [
      {
        id: "elevenlabs-tts",
        label: "ElevenAPI",
        subtitle: "ElevenLabs Text to Speech API.",
        logoSrc: "/ElevenLabs.svg",
        providerType: "elevenlabs",
      },
    ],
  },
];

interface HoverIconButtonProps extends React.ComponentProps<typeof Button> {
  isHovered: boolean;
  onMouseEnter: () => void;
  onMouseLeave: () => void;
  children: React.ReactNode;
}

function HoverIconButton({
  isHovered,
  onMouseEnter,
  onMouseLeave,
  children,
  ...buttonProps
}: HoverIconButtonProps) {
  return (
    <div onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
      <Button {...buttonProps} rightIcon={isHovered ? SvgX : SvgCheckSquare}>
        {children}
      </Button>
    </div>
  );
}

type ProviderMode = "stt" | "tts";

export default function VoiceConfigurationPage() {
  const [modalOpen, setModalOpen] = useState(false);
  const [selectedProvider, setSelectedProvider] = useState<string | null>(null);
  const [editingProvider, setEditingProvider] =
    useState<VoiceProviderView | null>(null);
  const [modalMode, setModalMode] = useState<ProviderMode>("stt");
  const [selectedModelId, setSelectedModelId] = useState<string | null>(null);
  const [sttActivationError, setSTTActivationError] = useState<string | null>(
    null
  );
  const [ttsActivationError, setTTSActivationError] = useState<string | null>(
    null
  );
  const [hoveredButtonKey, setHoveredButtonKey] = useState<string | null>(null);

  const { providers, error, isLoading, refresh: mutate } = useVoiceProviders();

  const handleConnect = (
    providerType: string,
    mode: ProviderMode,
    modelId?: string
  ) => {
    setSelectedProvider(providerType);
    setEditingProvider(null);
    setModalMode(mode);
    setSelectedModelId(modelId ?? null);
    setModalOpen(true);
    setSTTActivationError(null);
    setTTSActivationError(null);
  };

  const handleEdit = (
    provider: VoiceProviderView,
    mode: ProviderMode,
    modelId?: string
  ) => {
    setSelectedProvider(provider.provider_type);
    setEditingProvider(provider);
    setModalMode(mode);
    setSelectedModelId(modelId ?? null);
    setModalOpen(true);
  };

  const handleSetDefault = async (
    providerId: number,
    mode: ProviderMode,
    modelId?: string
  ) => {
    const setError =
      mode === "stt" ? setSTTActivationError : setTTSActivationError;
    setError(null);
    try {
      const response = await activateVoiceProvider(providerId, mode, modelId);
      if (!response.ok) {
        const errorBody = await response.json().catch(() => ({}));
        throw new Error(
          typeof errorBody?.detail === "string"
            ? errorBody.detail
            : `Failed to set provider as default ${mode.toUpperCase()}.`
        );
      }
      await mutate();
    } catch (err) {
      const message =
        err instanceof Error ? err.message : "Unexpected error occurred.";
      setError(message);
    }
  };

  const handleDeactivate = async (providerId: number, mode: ProviderMode) => {
    const setError =
      mode === "stt" ? setSTTActivationError : setTTSActivationError;
    setError(null);
    try {
      const response = await deactivateVoiceProvider(providerId, mode);
      if (!response.ok) {
        const errorBody = await response.json().catch(() => ({}));
        throw new Error(
          typeof errorBody?.detail === "string"
            ? errorBody.detail
            : `Failed to deactivate ${mode.toUpperCase()} provider.`
        );
      }
      await mutate();
    } catch (err) {
      const message =
        err instanceof Error ? err.message : "Unexpected error occurred.";
      setError(message);
    }
  };

  const handleModalClose = () => {
    setModalOpen(false);
    setSelectedProvider(null);
    setEditingProvider(null);
    setSelectedModelId(null);
  };

  const handleModalSuccess = () => {
    mutate();
    handleModalClose();
  };

  const isProviderConfigured = (provider?: VoiceProviderView): boolean => {
    return !!provider?.has_api_key;
  };

  // Map provider types to their configured provider data
  const providersByType = useMemo(() => {
    return new Map((providers ?? []).map((p) => [p.provider_type, p] as const));
  }, [providers]);

  const hasActiveSTTProvider =
    providers?.some((p) => p.is_default_stt) ?? false;
  const hasActiveTTSProvider =
    providers?.some((p) => p.is_default_tts) ?? false;

  const renderLogo = ({
    logoSrc,
    providerType,
    alt,
    size = 16,
  }: {
    logoSrc?: string;
    providerType: string;
    alt: string;
    size?: number;
  }) => {
    const containerSizeClass = size === 24 ? "size-7" : "size-5";

    return (
      <div
        className={cn(
          "flex items-center justify-center px-0.5 py-0 shrink-0 overflow-clip",
          containerSizeClass
        )}
      >
        {providerType === "openai" ? (
          <OpenAIIcon size={size} />
        ) : providerType === "azure" ? (
          <AzureIcon size={size} />
        ) : providerType === "elevenlabs" ? (
          <ElevenLabsIcon size={size} />
        ) : logoSrc ? (
          <Image
            src={logoSrc}
            alt={alt}
            width={size}
            height={size}
            className="object-contain"
          />
        ) : (
          <SvgMicrophone size={size} className="text-text-02" />
        )}
      </div>
    );
  };

  const renderModelCard = ({
    model,
    mode,
  }: {
    model: ModelDetails;
    mode: ProviderMode;
  }) => {
    const provider = providersByType.get(model.providerType);
    const isConfigured = isProviderConfigured(provider);
    // For TTS, also check that this specific model is the default (not just the provider)
    const isActive =
      mode === "stt"
        ? provider?.is_default_stt
        : provider?.is_default_tts && provider?.tts_model === model.id;
    const isHighlighted = isActive ?? false;
    const providerId = provider?.id;

    const buttonState = (() => {
      if (!provider || !isConfigured) {
        return {
          label: "Connect",
          disabled: false,
          icon: "arrow" as const,
          onClick: () => handleConnect(model.providerType, mode, model.id),
        };
      }

      if (isActive) {
        return {
          label: "Current Default",
          disabled: false,
          icon: "check" as const,
          onClick: providerId
            ? () => handleDeactivate(providerId, mode)
            : undefined,
        };
      }

      return {
        label: "Set as Default",
        disabled: false,
        icon: "arrow-circle" as const,
        onClick: providerId
          ? () => handleSetDefault(providerId, mode, model.id)
          : undefined,
      };
    })();

    const buttonKey = `${mode}-${model.id}`;
    const isButtonHovered = hoveredButtonKey === buttonKey;
    const isCardClickable =
      buttonState.icon === "arrow" &&
      typeof buttonState.onClick === "function" &&
      !buttonState.disabled;

    const handleCardClick = () => {
      if (isCardClickable) {
        buttonState.onClick?.();
      }
    };

    return (
      <div
        key={`${mode}-${model.id}`}
        onClick={isCardClickable ? handleCardClick : undefined}
        className={cn(
          "flex items-start justify-between gap-4 rounded-16 border p-2 bg-background-neutral-01",
          isHighlighted ? "border-action-link-05" : "border-border-01",
          isCardClickable &&
            "cursor-pointer hover:bg-background-tint-01 transition-colors"
        )}
      >
        <div className="flex flex-1 items-start gap-2.5 p-2">
          {renderLogo({
            logoSrc: model.logoSrc,
            providerType: model.providerType,
            alt: `${model.label} logo`,
            size: 16,
          })}
          <div className="flex flex-col gap-0.5">
            <Text as="p" mainUiAction text04>
              {model.label}
            </Text>
            <Text as="p" secondaryBody text03>
              {model.subtitle}
            </Text>
          </div>
        </div>
        <div className="flex items-center justify-end gap-1.5 self-center">
          {isConfigured && (
            <OpalButton
              icon={SvgEdit}
              tooltip="Edit"
              prominence="tertiary"
              size="sm"
              onClick={(e) => {
                e.stopPropagation();
                if (provider) handleEdit(provider, mode, model.id);
              }}
              aria-label={`Edit ${model.label}`}
            />
          )}
          {buttonState.icon === "check" ? (
            <HoverIconButton
              isHovered={isButtonHovered}
              onMouseEnter={() => setHoveredButtonKey(buttonKey)}
              onMouseLeave={() => setHoveredButtonKey(null)}
              action={true}
              tertiary
              disabled={buttonState.disabled}
              onClick={(e) => {
                e.stopPropagation();
                buttonState.onClick?.();
              }}
            >
              {buttonState.label}
            </HoverIconButton>
          ) : (
            <Button
              action={false}
              tertiary
              disabled={buttonState.disabled || !buttonState.onClick}
              onClick={(e) => {
                e.stopPropagation();
                buttonState.onClick?.();
              }}
              rightIcon={
                buttonState.icon === "arrow"
                  ? SvgArrowExchange
                  : buttonState.icon === "arrow-circle"
                  ? SvgArrowRightCircle
                  : undefined
              }
            >
              {buttonState.label}
            </Button>
          )}
        </div>
      </div>
    );
  };

  if (error) {
    const message = error?.message || "Unable to load voice configuration.";
    const detail =
      error instanceof FetchError && typeof error.info?.detail === "string"
        ? error.info.detail
        : undefined;

    return (
      <>
        <AdminPageTitle
          title="Voice"
          icon={SvgMicrophone}
          includeDivider={false}
        />
        <Callout type="danger" title="Failed to load voice settings">
          {message}
          {detail && (
            <Text as="p" className="mt-2 text-text-03" mainContentBody text03>
              {detail}
            </Text>
          )}
        </Callout>
      </>
    );
  }

  if (isLoading) {
    return (
      <>
        <AdminPageTitle
          title="Voice"
          icon={SvgMicrophone}
          includeDivider={false}
        />
        <div className="mt-8">
          <ThreeDotsLoader />
        </div>
      </>
    );
  }

  return (
    <>
      <AdminPageTitle icon={SvgAudio} title="Voice" />
      <div className="pt-4 pb-4">
        <Text as="p" secondaryBody text03>
          Speech to text (STT) and text to speech (TTS) capabilities.
        </Text>
      </div>

      <Separator />

      <div className="flex w-full flex-col gap-8 pb-6">
        {/* Speech-to-Text Section */}
        <div className="flex w-full max-w-[960px] flex-col gap-3">
          <div className="flex flex-col">
            <Text as="p" mainContentEmphasis text04>
              Speech to Text
            </Text>
            <Text as="p" secondaryBody text03>
              Select a model to transcribe speech to text in chats.
            </Text>
          </div>

          {sttActivationError && (
            <Callout type="danger" title="Unable to update STT provider">
              {sttActivationError}
            </Callout>
          )}

          {!hasActiveSTTProvider && (
            <div
              className="flex items-start rounded-16 border p-2"
              style={{
                backgroundColor: "var(--status-info-00)",
                borderColor: "var(--status-info-02)",
              }}
            >
              <div className="flex items-start gap-1 p-2">
                <div
                  className="flex size-5 items-center justify-center rounded-full p-0.5"
                  style={{
                    backgroundColor: "var(--status-info-01)",
                  }}
                >
                  <div style={{ color: "var(--status-text-info-05)" }}>
                    <InfoIcon size={16} />
                  </div>
                </div>
                <Text as="p" className="flex-1 px-0.5" mainUiBody text04>
                  Connect a speech to text provider to use in chat.
                </Text>
              </div>
            </div>
          )}

          <div className="flex flex-col gap-2">
            {STT_MODELS.map((model) => renderModelCard({ model, mode: "stt" }))}
          </div>
        </div>

        {/* Text-to-Speech Section */}
        <div className="flex w-full max-w-[960px] flex-col gap-3">
          <div className="flex flex-col">
            <Text as="p" mainContentEmphasis text04>
              Text to Speech
            </Text>
            <Text as="p" secondaryBody text03>
              Select a model to speak out chat responses.
            </Text>
          </div>

          {ttsActivationError && (
            <Callout type="danger" title="Unable to update TTS provider">
              {ttsActivationError}
            </Callout>
          )}

          {!hasActiveTTSProvider && (
            <div
              className="flex items-start rounded-16 border p-2"
              style={{
                backgroundColor: "var(--status-info-00)",
                borderColor: "var(--status-info-02)",
              }}
            >
              <div className="flex items-start gap-1 p-2">
                <div
                  className="flex size-5 items-center justify-center rounded-full p-0.5"
                  style={{
                    backgroundColor: "var(--status-info-01)",
                  }}
                >
                  <div style={{ color: "var(--status-text-info-05)" }}>
                    <InfoIcon size={16} />
                  </div>
                </div>
                <Text as="p" className="flex-1 px-0.5" mainUiBody text04>
                  Connect a text to speech provider to use in chat.
                </Text>
              </div>
            </div>
          )}

          <div className="flex flex-col gap-4">
            {TTS_PROVIDER_GROUPS.map((group) => (
              <div key={group.providerType} className="flex flex-col gap-2">
                <Text as="p" secondaryBody text03 className="px-0.5">
                  {group.providerLabel}
                </Text>
                <div className="flex flex-col gap-2">
                  {group.models.map((model) =>
                    renderModelCard({ model, mode: "tts" })
                  )}
                </div>
              </div>
            ))}
          </div>
        </div>
      </div>

      {modalOpen && selectedProvider && (
        <VoiceProviderSetupModal
          providerType={selectedProvider}
          existingProvider={editingProvider}
          mode={modalMode}
          defaultModelId={selectedModelId}
          onClose={handleModalClose}
          onSuccess={handleModalSuccess}
        />
      )}
    </>
  );
}
@@ -1 +0,0 @@
export { default } from "@/refresh-pages/admin/UsersPage";
@@ -3,7 +3,6 @@ import type { Route } from "next";
import { unstable_noStore as noStore } from "next/cache";
import { requireAuth } from "@/lib/auth/requireAuth";
import { ProjectsProvider } from "@/providers/ProjectsContext";
import { VoiceModeProvider } from "@/providers/VoiceModeProvider";
import AppSidebar from "@/sections/sidebar/AppSidebar";

export interface LayoutProps {
@@ -22,15 +21,10 @@ export default async function Layout({ children }: LayoutProps) {

  return (
    <ProjectsProvider>
      {/* VoiceModeProvider wraps the full app layout so TTS playback state
          persists across page navigations (e.g., sidebar clicks during playback).
          It only activates WebSocket connections when TTS is actually triggered. */}
      <VoiceModeProvider>
        <div className="flex flex-row w-full h-full">
          <AppSidebar />
          {children}
        </div>
      </VoiceModeProvider>
      <div className="flex flex-row w-full h-full">
        <AppSidebar />
        {children}
      </div>
    </ProjectsProvider>
  );
}
@@ -1,12 +1,6 @@
"use client";

import React, {
  useRef,
  RefObject,
  useMemo,
  useEffect,
  useLayoutEffect,
} from "react";
import React, { useRef, RefObject, useMemo } from "react";
import { Packet, StopReason } from "@/app/app/services/streamingModels";
import CustomToolAuthCard from "@/app/app/message/messageComponents/CustomToolAuthCard";
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
@@ -22,9 +16,6 @@ import { LlmDescriptor, LlmManager } from "@/lib/hooks";
import { Message } from "@/app/app/interfaces";
import Text from "@/refresh-components/texts/Text";
import { AgentTimeline } from "@/app/app/message/messageComponents/timeline/AgentTimeline";
import { useVoiceMode } from "@/providers/VoiceModeProvider";
import { getTextContent } from "@/app/app/services/packetUtils";
import { removeThinkingTokens } from "@/app/app/services/thinkingTokens";

// Type for the regeneration factory function passed from ChatUI
export type RegenerationFactory = (regenerationRequest: {
@@ -84,7 +75,6 @@ function arePropsEqual(

const AgentMessage = React.memo(function AgentMessage({
  rawPackets,
  packetCount,
  chatState,
  nodeId,
  messageId,
@@ -172,59 +162,6 @@ const AgentMessage = React.memo(function AgentMessage({
    onMessageSelection,
  });

  // Streaming TTS integration
  const { streamTTS, resetTTS, stopTTS } = useVoiceMode();
  const ttsCompletedRef = useRef(false);
  const streamTTSRef = useRef(streamTTS);

  // Keep streamTTS ref in sync without triggering effect re-runs
  useEffect(() => {
    streamTTSRef.current = streamTTS;
  }, [streamTTS]);

  // Stream TTS as text content arrives - only for messages still streaming
  // Uses ref for streamTTS to avoid re-triggering when its identity changes
  // Note: packetCount is used instead of rawPackets because the array is mutated in place
  useLayoutEffect(() => {
    // Skip if we've already finished TTS for this message
    if (ttsCompletedRef.current) return;

    // If user cancelled generation, do not send more text to TTS.
    if (stopPacketSeen && stopReason === StopReason.USER_CANCELLED) {
      ttsCompletedRef.current = true;
      return;
    }

    const textContent = removeThinkingTokens(getTextContent(rawPackets));
    if (typeof textContent === "string" && textContent.length > 0) {
      streamTTSRef.current(textContent, isComplete, nodeId);

      // Mark as completed once the message is done streaming
      if (isComplete) {
        ttsCompletedRef.current = true;
      }
    }
  }, [packetCount, isComplete, rawPackets, nodeId, stopPacketSeen, stopReason]); // packetCount triggers on new packets since rawPackets is mutated in place

  // Stop TTS immediately when user cancels generation.
  useEffect(() => {
    if (stopPacketSeen && stopReason === StopReason.USER_CANCELLED) {
      stopTTS({ manual: true });
    }
  }, [stopPacketSeen, stopReason, stopTTS]);

  // Reset TTS completed flag when nodeId changes (new message)
  useEffect(() => {
    ttsCompletedRef.current = false;
  }, [nodeId]);

  // Reset TTS when component unmounts or nodeId changes
  useEffect(() => {
    return () => {
      resetTTS();
    };
  }, [nodeId, resetTTS]);

  return (
    <div
      className="flex flex-col gap-3"
@@ -271,8 +208,6 @@ const AgentMessage = React.memo(function AgentMessage({
          key={`${displayGroup.turn_index}-${displayGroup.tab_index}`}
          packets={displayGroup.packets}
          chatState={effectiveChatState}
          messageNodeId={nodeId}
          hasTimelineThinking={pacedTurnGroups.length > 0 || hasSteps}
          onComplete={() => {
            // Only mark complete on the last display group
            // Hook handles the finalAnswerComing check internally
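The removed effect above leans on a subtle React detail called out in its own comments: `rawPackets` is mutated in place, so its identity never changes and it cannot drive the effect on its own; `packetCount` is what re-fires it. A standalone sketch of the pattern, with hook and packet names that are illustrative rather than from this codebase:

```tsx
// Sketch: useEffect compares dependencies with Object.is, so pushing into the
// same array is invisible to it. A counter bumped alongside each push is what
// actually re-runs the effect, while the ref always holds the latest packets.
import { useEffect, useRef, useState } from "react";

function usePacketLog() {
  const packetsRef = useRef<string[]>([]);
  const [packetCount, setPacketCount] = useState(0);

  const push = (packet: string) => {
    packetsRef.current.push(packet); // same array identity; deps see no change
    setPacketCount((n) => n + 1); // new number; effect re-runs
  };

  useEffect(() => {
    // Reads the latest packets even though only packetCount is a dependency.
    console.log("packets so far:", packetsRef.current.length);
  }, [packetCount]);

  return push;
}
```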
@@ -29,9 +29,6 @@ import FeedbackModal, {
  FeedbackModalProps,
} from "@/sections/modals/FeedbackModal";
import { Button, SelectButton } from "@opal/components";
import TTSButton from "./TTSButton";
import { useVoiceMode } from "@/providers/VoiceModeProvider";
import { useVoiceStatus } from "@/hooks/useVoiceStatus";

// Wrapper component for SourceTag in toolbar to handle memoization
const SourcesTagWrapper = React.memo(function SourcesTagWrapper({
@@ -147,14 +144,6 @@ export default function MessageToolbar({
    (state) => state.updateCurrentSelectedNodeForDocDisplay
  );

  // Voice mode - hide toolbar during TTS playback for this message
  const { isTTSPlaying, activeMessageNodeId, isAwaitingAutoPlaybackStart } =
    useVoiceMode();
  const { ttsEnabled } = useVoiceStatus();
  const isTTSActiveForThisMessage =
    (isTTSPlaying || isAwaitingAutoPlaybackStart) &&
    activeMessageNodeId === nodeId;

  // Feedback modal state and handlers
  const { handleFeedbackChange } = useFeedbackController();
  const modal = useCreateModal();
@@ -215,11 +204,6 @@ export default function MessageToolbar({
    [messageId, currentFeedback, handleFeedbackChange, modal]
  );

  // Hide toolbar while TTS is playing for this message
  if (isTTSActiveForThisMessage) {
    return null;
  }

  return (
    <>
      <modal.Provider>
@@ -265,7 +249,6 @@ export default function MessageToolbar({
          <SelectButton
            icon={SvgThumbsUp}
            onClick={() => handleFeedbackClick("like")}
            variant="select-light"
            state={isFeedbackTransient("like") ? "selected" : "empty"}
            tooltip={
              currentFeedback === "like" ? "Remove Like" : "Good Response"
@@ -275,7 +258,6 @@ export default function MessageToolbar({
          <SelectButton
            icon={SvgThumbsDown}
            onClick={() => handleFeedbackClick("dislike")}
            variant="select-light"
            state={isFeedbackTransient("dislike") ? "selected" : "empty"}
            tooltip={
              currentFeedback === "dislike"
@@ -284,13 +266,6 @@ export default function MessageToolbar({
            }
            data-testid="AgentMessage/dislike-button"
          />
          {ttsEnabled && (
            <TTSButton
              text={
                removeThinkingTokens(getTextContent(rawPackets)) as string
              }
            />
          )}

          {onRegenerate &&
            messageId !== undefined &&
@@ -308,7 +283,7 @@ export default function MessageToolbar({
              });
              regenerator(llmDescriptor);
            }}
            foldable
            folded
          />
        </div>
      )}
@@ -1,90 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect } from "react";
|
||||
import { SvgPlayCircle, SvgStop } from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
import { useVoicePlayback } from "@/hooks/useVoicePlayback";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
|
||||
interface TTSButtonProps {
|
||||
text: string;
|
||||
voice?: string;
|
||||
speed?: number;
|
||||
}
|
||||
|
||||
function TTSButton({ text, voice, speed }: TTSButtonProps) {
|
||||
const { isPlaying, isLoading, error, play, pause, stop } = useVoicePlayback();
|
||||
const { isTTSPlaying, isTTSLoading, isAwaitingAutoPlaybackStart, stopTTS } =
|
||||
useVoiceMode();
|
||||
|
||||
const isGlobalTTSActive =
|
||||
isTTSPlaying || isTTSLoading || isAwaitingAutoPlaybackStart;
|
||||
const isButtonPlaying = isGlobalTTSActive || isPlaying;
|
||||
const isButtonLoading = !isGlobalTTSActive && isLoading;
|
||||
|
||||
const handleClick = useCallback(async () => {
|
||||
if (isGlobalTTSActive) {
|
||||
// Stop auto-playback voice mode stream from the toolbar button.
|
||||
stopTTS({ manual: true });
|
||||
stop();
|
||||
} else if (isPlaying) {
|
||||
pause();
|
||||
} else if (isButtonLoading) {
|
||||
stop();
|
||||
} else {
|
||||
try {
|
||||
// Ensure no voice-mode stream is active before starting manual playback.
|
||||
stopTTS();
|
||||
await play(text, voice, speed);
|
||||
} catch (err) {
|
||||
console.error("TTS playback failed:", err);
|
||||
toast.error("Could not play audio");
|
||||
}
|
||||
}
|
||||
}, [
|
||||
isGlobalTTSActive,
|
||||
isPlaying,
|
||||
isButtonLoading,
|
||||
text,
|
||||
voice,
|
||||
speed,
|
||||
play,
|
||||
pause,
|
||||
stop,
|
||||
stopTTS,
|
||||
]);
|
||||
|
||||
// Surface streaming voice playback errors to the user via toast
|
||||
useEffect(() => {
|
||||
if (error) {
|
||||
console.error("Voice playback error:", error);
|
||||
toast.error(error);
|
||||
}
|
||||
}, [error]);
|
||||
|
||||
const icon = isButtonLoading
|
||||
? SimpleLoader
|
||||
: isButtonPlaying
|
||||
? SvgStop
|
||||
: SvgPlayCircle;
|
||||
|
||||
const tooltip = isButtonPlaying
|
||||
? "Stop playback"
|
||||
: isButtonLoading
|
||||
? "Loading..."
|
||||
: "Read aloud";
|
||||
|
||||
return (
|
||||
<Button
|
||||
icon={icon}
|
||||
onClick={handleClick}
|
||||
prominence="tertiary"
|
||||
tooltip={tooltip}
|
||||
data-testid="AgentMessage/tts-button"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export default TTSButton;
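
For reference, the removed TTSButton only required the text prop; a minimal usage sketch (the voice and speed values below are illustrative, not defaults from this diff):

// Hypothetical usage of the removed TTSButton; "alloy" and 1.25 are example values.
<TTSButton
  text="Hello from Onyx"
  voice="alloy"
  speed={1.25}
/>
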
@@ -67,10 +67,6 @@ export type MessageRenderer<
> = React.ComponentType<{
packets: T[];
state: S;
/** Node id for the message currently being rendered */
messageNodeId?: number;
/** True when timeline/thinking UI is already shown above this text block */
hasTimelineThinking?: boolean;
onComplete: () => void;
renderType: RenderType;
animate: boolean;

@@ -166,8 +166,6 @@ function MixedContentHandler({
chatPackets,
imagePackets,
chatState,
messageNodeId,
hasTimelineThinking,
onComplete,
animate,
stopPacketSeen,
@@ -177,8 +175,6 @@ function MixedContentHandler({
chatPackets: Packet[];
imagePackets: Packet[];
chatState: FullChatState;
messageNodeId?: number;
hasTimelineThinking?: boolean;
onComplete: () => void;
animate: boolean;
stopPacketSeen: boolean;
@@ -189,8 +185,6 @@ function MixedContentHandler({
<MessageTextRenderer
packets={chatPackets as ChatPacket[]}
state={chatState}
messageNodeId={messageNodeId}
hasTimelineThinking={hasTimelineThinking}
onComplete={() => {}}
animate={animate}
renderType={RenderType.FULL}
@@ -218,8 +212,6 @@ function MixedContentHandler({
interface RendererComponentProps {
packets: Packet[];
chatState: FullChatState;
messageNodeId?: number;
hasTimelineThinking?: boolean;
onComplete: () => void;
animate: boolean;
stopPacketSeen: boolean;
@@ -237,8 +229,7 @@ function areRendererPropsEqual(
prev.stopPacketSeen === next.stopPacketSeen &&
prev.stopReason === next.stopReason &&
prev.animate === next.animate &&
prev.chatState.agent?.id === next.chatState.agent?.id &&
prev.messageNodeId === next.messageNodeId
prev.chatState.agent?.id === next.chatState.agent?.id
// Skip: onComplete, children (function refs), chatState (memoized upstream)
);
}
@@ -247,8 +238,6 @@ function areRendererPropsEqual(
export const RendererComponent = memo(function RendererComponent({
packets,
chatState,
messageNodeId,
hasTimelineThinking,
onComplete,
animate,
stopPacketSeen,
@@ -283,8 +272,6 @@ export const RendererComponent = memo(function RendererComponent({
chatPackets={chatPackets}
imagePackets={imagePackets}
chatState={chatState}
messageNodeId={messageNodeId}
hasTimelineThinking={hasTimelineThinking}
onComplete={onComplete}
animate={animate}
stopPacketSeen={stopPacketSeen}
@@ -305,8 +292,6 @@ export const RendererComponent = memo(function RendererComponent({
<RendererFn
packets={packets as any}
state={chatState}
messageNodeId={messageNodeId}
hasTimelineThinking={hasTimelineThinking}
onComplete={onComplete}
animate={animate}
renderType={RenderType.FULL}

@@ -1,4 +1,4 @@
import React, { useEffect, useMemo, useRef, useState } from "react";
import React, { useEffect, useMemo, useState } from "react";
import Text from "@/refresh-components/texts/Text";

import {
@@ -10,55 +10,6 @@ import { MessageRenderer, FullChatState } from "../interfaces";
import { isFinalAnswerComplete } from "../../../services/packetUtils";
import { useMarkdownRenderer } from "../markdownUtils";
import { BlinkingBar } from "../../BlinkingBar";
import { useVoiceMode } from "@/providers/VoiceModeProvider";

/**
* Maps a cleaned character position to the corresponding position in markdown text.
* This allows progressive reveal to work with markdown formatting.
*/
function getRevealPosition(markdown: string, cleanChars: number): number {
// Skip patterns that don't contribute to visible character count
const skipChars = new Set(["*", "`", "#"]);
let cleanIndex = 0;
let mdIndex = 0;

while (cleanIndex < cleanChars && mdIndex < markdown.length) {
const char = markdown[mdIndex];

// Skip markdown formatting characters
if (char !== undefined && skipChars.has(char)) {
mdIndex++;
continue;
}

// Handle link syntax [text](url) - skip the (url) part but count the text
if (
char === "]" &&
mdIndex + 1 < markdown.length &&
markdown[mdIndex + 1] === "("
) {
const closeIdx = markdown.indexOf(")", mdIndex + 2);
if (closeIdx > 0) {
mdIndex = closeIdx + 1;
continue;
}
}

cleanIndex++;
mdIndex++;
}

// Extend to word boundary to avoid cutting mid-word
while (
mdIndex < markdown.length &&
markdown[mdIndex] !== " " &&
markdown[mdIndex] !== "\n"
) {
mdIndex++;
}

return mdIndex;
}
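
A quick worked trace of getRevealPosition, using only the rules visible above (the example itself is mine, not from the diff):

// Illustrative trace, assuming getRevealPosition from the removed hunk above.
const md = "**Bold** text";
// cleanChars = 4 counts "B", "o", "l", "d" while skipping the leading "**";
// the word-boundary loop then extends past the trailing "**" to the space.
const pos = getRevealPosition(md, 4); // -> 8
console.log(md.slice(0, pos)); // "**Bold**"
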

// Control the rate of packet streaming (packets per second)
const PACKET_DELAY_MS = 10;
@@ -69,8 +20,6 @@ export const MessageTextRenderer: MessageRenderer<
> = ({
packets,
state,
messageNodeId,
hasTimelineThinking,
onComplete,
renderType,
animate,
@@ -87,17 +36,6 @@ export const MessageTextRenderer: MessageRenderer<

const [displayedPacketCount, setDisplayedPacketCount] =
useState(initialPacketCount);
const lastStableSyncedContentRef = useRef("");
const lastVisibleContentRef = useRef("");

// Get voice mode context for progressive text reveal synced with audio
const {
revealedCharCount,
autoPlayback,
isAudioSyncActive,
activeMessageNodeId,
isAwaitingAutoPlaybackStart,
} = useVoiceMode();

// Get the full content from all packets
const fullContent = packets
@@ -112,11 +50,6 @@ export const MessageTextRenderer: MessageRenderer<
})
.join("");

const shouldUseAutoPlaybackSync =
autoPlayback &&
typeof messageNodeId === "number" &&
activeMessageNodeId === messageNodeId;

// Animation effect - gradually increase displayed packets at controlled rate
useEffect(() => {
if (!animate) {
@@ -160,37 +93,13 @@ export const MessageTextRenderer: MessageRenderer<
}
}, [packets, onComplete, animate, displayedPacketCount]);

// Get content based on displayed packet count or audio progress
const computedContent = useMemo(() => {
// Hold response in "thinking" state only while autoplay startup is pending.
if (shouldUseAutoPlaybackSync && isAwaitingAutoPlaybackStart) {
return "";
}

// Sync text with audio only for the message currently being spoken.
if (shouldUseAutoPlaybackSync && isAudioSyncActive) {
const MIN_REVEAL_CHARS = 12;
if (revealedCharCount < MIN_REVEAL_CHARS) {
return "";
}

// Reveal text progressively based on audio progress
const revealPos = getRevealPosition(fullContent, revealedCharCount);
return fullContent.slice(0, Math.max(revealPos, 0));
}

// During an active synced turn, if sync temporarily drops, keep current reveal
// instead of jumping to full content or blanking.
if (shouldUseAutoPlaybackSync && !stopPacketSeen) {
return lastStableSyncedContentRef.current;
}

// Standard behavior when auto-playback is off
// Get content based on displayed packet count
const content = useMemo(() => {
if (!animate || displayedPacketCount === -1) {
return fullContent; // Show all content
}

// Packet-based reveal (when auto-playback is disabled)
// Only show content from packets up to displayedPacketCount
return packets
.slice(0, displayedPacketCount)
.map((packet) => {
@@ -203,109 +112,31 @@ export const MessageTextRenderer: MessageRenderer<
return "";
})
.join("");
}, [
animate,
displayedPacketCount,
fullContent,
packets,
revealedCharCount,
autoPlayback,
isAudioSyncActive,
activeMessageNodeId,
isAwaitingAutoPlaybackStart,
messageNodeId,
shouldUseAutoPlaybackSync,
stopPacketSeen,
]);

// Keep synced text monotonic: once visible, never regress or disappear between chunks.
const content = useMemo(() => {
const wasUserCancelled = stopReason === StopReason.USER_CANCELLED;

// On user cancel, freeze at exactly what was already visible.
if (wasUserCancelled) {
return lastVisibleContentRef.current;
}

if (!shouldUseAutoPlaybackSync) {
return computedContent;
}

if (computedContent.length === 0) {
return lastStableSyncedContentRef.current;
}

const last = lastStableSyncedContentRef.current;
if (computedContent.startsWith(last)) {
return computedContent;
}

// If content shape changed unexpectedly mid-stream, prefer the stable version
// to avoid flicker/dumps.
if (!stopPacketSeen || wasUserCancelled) {
return last;
}

// For normal completed responses, allow final full content.
return computedContent;
}, [computedContent, shouldUseAutoPlaybackSync, stopPacketSeen, stopReason]);
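
The memo above encodes one invariant worth stating on its own: during a synced turn, rendered text may only grow by extension, and only a completed turn may replace it wholesale. A standalone sketch of that rule (function and parameter names are mine, not the component's API):

// Hypothetical reduction of the monotonic-reveal guard used above.
function nextVisibleText(
  candidate: string,
  lastStable: string,
  turnDone: boolean
): string {
  if (candidate.length === 0) return lastStable; // never blank mid-turn
  if (candidate.startsWith(lastStable)) return candidate; // pure extension is safe
  return turnDone ? candidate : lastStable; // otherwise hold until the turn completes
}
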

// Sync the stable ref outside of useMemo to avoid side effects during render.
useEffect(() => {
if (stopReason === StopReason.USER_CANCELLED) {
return;
}
if (!shouldUseAutoPlaybackSync) {
lastStableSyncedContentRef.current = "";
} else if (content.length > 0) {
lastStableSyncedContentRef.current = content;
}
}, [content, shouldUseAutoPlaybackSync, stopReason]);

// Track last actually rendered content so cancel can freeze without dumping buffered text.
useEffect(() => {
if (content.length > 0) {
lastVisibleContentRef.current = content;
}
}, [content]);

const shouldShowThinkingPlaceholder =
shouldUseAutoPlaybackSync &&
isAwaitingAutoPlaybackStart &&
!hasTimelineThinking &&
!stopPacketSeen;

const shouldShowSpeechWarmupIndicator =
shouldUseAutoPlaybackSync &&
!isAwaitingAutoPlaybackStart &&
content.length === 0 &&
fullContent.length > 0 &&
!hasTimelineThinking &&
!stopPacketSeen;

const shouldShowCursor =
content.length > 0 &&
(!stopPacketSeen ||
(shouldUseAutoPlaybackSync && content.length < fullContent.length));
}, [animate, displayedPacketCount, fullContent, packets]);

const { renderedContent } = useMarkdownRenderer(
// the [*]() is a hack to show a blinking dot when the packet is not complete
shouldShowCursor ? content + " [*]() " : content,
stopPacketSeen ? content : content + " [*]() ",
state,
"font-main-content-body"
);

const wasUserCancelled = stopReason === StopReason.USER_CANCELLED;

return children([
{
icon: null,
status: null,
content:
shouldShowThinkingPlaceholder || shouldShowSpeechWarmupIndicator ? (
<Text as="span" secondaryBody text04 className="italic">
Thinking
</Text>
) : content.length > 0 ? (
<>{renderedContent}</>
content.length > 0 || packets.length > 0 ? (
<>
{renderedContent}
{wasUserCancelled && (
<Text as="p" secondaryBody text04>
User has stopped generation
</Text>
)}
</>
) : (
<BlinkingBar addMargin />
),

@@ -98,7 +98,7 @@ export default function ArtifactsTab({
const handleWebappDownload = () => {
if (!sessionId) return;
const link = document.createElement("a");
link.href = `/api/build/sessions/${sessionId}/webapp-download`;
link.href = `/api/build/sessions/${sessionId}/webapp/download`;
link.download = "";
document.body.appendChild(link);
link.click();
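
The hunk above only changes the URL path (webapp-download → webapp/download); the surrounding code is the standard programmatic-anchor download trick. A self-contained sketch of that pattern (the helper name is mine):

// Hypothetical helper showing the same anchor-click download technique.
function triggerDownload(url: string, filename = ""): void {
  const link = document.createElement("a");
  link.href = url;
  link.download = filename; // empty string keeps the server-suggested name
  document.body.appendChild(link); // some browsers require the anchor in the DOM
  link.click();
  link.remove(); // clean up the temporary anchor
}
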

@@ -566,21 +566,6 @@ textarea {
animation: fadeIn 0.2s ease-out forwards;
}

/* Recording waveform animation */
@keyframes waveform {
0%,
100% {
transform: scaleY(0.3);
}
50% {
transform: scaleY(1);
}
}

.animate-waveform {
animation: waveform 0.8s ease-in-out infinite;
}
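
For context, the removed animate-waveform class pulses each bar's scaleY between 0.3 and 1 every 0.8s; staggering animation-delay per bar is what produced the rolling wave. A hedged markup sketch (everything except the class name is illustrative):

// Illustrative only: three bars sharing the removed keyframes, offset into a wave.
<div style={{ display: "flex", gap: 2 }}>
  <div className="animate-waveform" style={{ width: 3, height: 12 }} />
  <div className="animate-waveform" style={{ width: 3, height: 12, animationDelay: "0.025s" }} />
  <div className="animate-waveform" style={{ width: 3, height: 12, animationDelay: "0.05s" }} />
</div>
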

.container {
margin-bottom: 1rem;
}

@@ -11,7 +11,7 @@ import { Button } from "@opal/components";
import { SvgBubbleText, SvgSearchMenu, SvgSidebar } from "@opal/icons";
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
import { useSettingsContext } from "@/providers/SettingsProvider";
import type { AppMode } from "@/providers/QueryControllerProvider";
import { AppMode, useAppMode } from "@/providers/AppModeProvider";
import useAppFocus from "@/hooks/useAppFocus";
import { useQueryController } from "@/providers/QueryControllerProvider";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
@@ -58,15 +58,15 @@ const footerMarkdownComponents = {
*/
export default function NRFChrome() {
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
const { state, setAppMode } = useQueryController();
const { appMode, setAppMode } = useAppMode();
const settings = useSettingsContext();
const { isMobile } = useScreenSize();
const { setFolded } = useAppSidebarContext();
const appFocus = useAppFocus();
const { classification } = useQueryController();
const [modePopoverOpen, setModePopoverOpen] = useState(false);

const effectiveMode: AppMode =
appFocus.isNewSession() && state.phase === "idle" ? state.appMode : "chat";
const effectiveMode: AppMode = appFocus.isNewSession() ? appMode : "chat";

const customFooterContent =
settings?.enterpriseSettings?.custom_lower_disclaimer_content ||
@@ -78,7 +78,7 @@ export default function NRFChrome() {
isPaidEnterpriseFeaturesEnabled &&
settings.isSearchModeAvailable &&
appFocus.isNewSession() &&
state.phase === "idle";
!classification;

const showHeader = isMobile || showModeToggle;

@@ -175,7 +175,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
const isStreaming = currentChatState === "streaming";

// Query controller for search/chat classification (EE feature)
const { submit: submitQuery, state } = useQueryController();
const { submit: submitQuery, classification } = useQueryController();

// Determine if retrieval (search) is enabled based on the agent
const retrievalEnabled = useMemo(() => {
@@ -186,8 +186,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
}, [liveAgent]);

// Check if we're in search mode
const isSearch =
state.phase === "searching" || state.phase === "search-results";
const isSearch = classification === "search";

// Anchor for scroll positioning (matches ChatPage pattern)
const anchorMessage = messageHistory.at(-2) ?? messageHistory[0];
@@ -318,7 +317,7 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
};

// Use submitQuery which will classify the query and either:
// - Route to search (sets phase to "searching"/"search-results" and shows SearchUI)
// - Route to search (sets classification to "search" and shows SearchUI)
// - Route to chat (calls onChat callback)
await submitQuery(submittedMessage, onChat);
},

@@ -1,5 +1,4 @@
import { ProjectsProvider } from "@/providers/ProjectsContext";
import { VoiceModeProvider } from "@/providers/VoiceModeProvider";

export interface LayoutProps {
children: React.ReactNode;
@@ -12,9 +11,5 @@ export interface LayoutProps {
* Sidebar and chrome are handled by sub-layouts / individual pages.
*/
export default function Layout({ children }: LayoutProps) {
return (
<ProjectsProvider>
<VoiceModeProvider>{children}</VoiceModeProvider>
</ProjectsProvider>
);
return <ProjectsProvider>{children}</ProjectsProvider>;
}

@@ -31,7 +31,6 @@ const SETTINGS_LAYOUT_PREFIXES = [
ADMIN_PATHS.LLM_MODELS,
ADMIN_PATHS.AGENTS,
ADMIN_PATHS.USERS,
ADMIN_PATHS.USERS_V2,
ADMIN_PATHS.TOKEN_RATE_LIMITS,
ADMIN_PATHS.SEARCH_SETTINGS,
ADMIN_PATHS.DOCUMENT_PROCESSING,

@@ -11,14 +11,13 @@ import rehypeHighlight from "rehype-highlight";
import remarkMath from "remark-math";
import rehypeKatex from "rehype-katex";
import "katex/dist/katex.min.css";
import { transformLinkUri } from "@/lib/utils";
import { cn, transformLinkUri } from "@/lib/utils";

type MinimalMarkdownComponentOverrides = Partial<Components>;

interface MinimalMarkdownProps {
content: string;
className?: string;
style?: CSSProperties;
showHeader?: boolean;
/**
* Override specific markdown renderers.
@@ -30,7 +29,6 @@ interface MinimalMarkdownProps {
export default function MinimalMarkdown({
content,
className = "",
style,
showHeader = true,
components,
}: MinimalMarkdownProps) {
@@ -63,19 +61,17 @@ export default function MinimalMarkdown({
}, [content, components, showHeader]);

return (
<div style={style || {}} className={`${className}`}>
<ReactMarkdown
className="prose dark:prose-invert max-w-full text-sm break-words"
components={markdownComponents}
rehypePlugins={[rehypeHighlight, rehypeKatex]}
remarkPlugins={[
remarkGfm,
[remarkMath, { singleDollarTextMath: false }],
]}
urlTransform={transformLinkUri}
>
{content}
</ReactMarkdown>
</div>
<ReactMarkdown
className={cn(
"prose dark:prose-invert max-w-full text-sm break-words",
className
)}
components={markdownComponents}
rehypePlugins={[rehypeHighlight, rehypeKatex]}
remarkPlugins={[remarkGfm, [remarkMath, { singleDollarTextMath: false }]]}
urlTransform={transformLinkUri}
>
{content}
</ReactMarkdown>
);
}

@@ -39,8 +39,6 @@ import document360Icon from "@public/Document360.png";
import dropboxIcon from "@public/Dropbox.png";
import drupalwikiIcon from "@public/DrupalWiki.png";
import egnyteIcon from "@public/Egnyte.png";
import elevenLabsDarkSVG from "@public/ElevenLabsDark.svg";
import elevenLabsSVG from "@public/ElevenLabs.svg";
import firefliesIcon from "@public/Fireflies.png";
import freshdeskIcon from "@public/Freshdesk.png";
import geminiSVG from "@public/Gemini.svg";
@@ -845,9 +843,6 @@ export const Document360Icon = createLogoIcon(document360Icon);
export const DropboxIcon = createLogoIcon(dropboxIcon);
export const DrupalWikiIcon = createLogoIcon(drupalwikiIcon);
export const EgnyteIcon = createLogoIcon(egnyteIcon);
export const ElevenLabsIcon = createLogoIcon(elevenLabsSVG, {
darkSrc: elevenLabsDarkSVG,
});
export const FirefliesIcon = createLogoIcon(firefliesIcon);
export const FreshdeskIcon = createLogoIcon(freshdeskIcon);
export const GeminiIcon = createLogoIcon(geminiSVG);

@@ -1,206 +0,0 @@
"use client";

import { useEffect, useState, useMemo, useRef } from "react";
import { cn } from "@/lib/utils";
import { formatElapsedTime } from "@/lib/dateUtils";
import { Button } from "@opal/components";
import {
SvgMicrophone,
SvgMicrophoneOff,
SvgVolume,
SvgVolumeOff,
} from "@opal/icons";

// Recording waveform constants
const RECORDING_BAR_COUNT = 120;
const MIN_BAR_HEIGHT = 2;
const MAX_BAR_HEIGHT = 16;

// Speaking waveform constants
const SPEAKING_BAR_COUNT = 28;

interface WaveformProps {
/** Visual style and behavior variant */
variant: "speaking" | "recording";
/** Whether the waveform is actively animating */
isActive: boolean;
/** Whether audio is muted */
isMuted?: boolean;
/** Current microphone audio level (0-1), only used for recording variant */
audioLevel?: number;
/** Callback when mute button is clicked */
onMuteToggle?: () => void;
}

function Waveform({
variant,
isActive,
isMuted = false,
audioLevel = 0,
onMuteToggle,
}: WaveformProps) {
// ─── Recording variant state ───────────────────────────────────────────────
const [elapsedSeconds, setElapsedSeconds] = useState(0);
const [barHeights, setBarHeights] = useState<number[]>(
() => new Array(RECORDING_BAR_COUNT).fill(MIN_BAR_HEIGHT) as number[]
);
const animationRef = useRef<number | null>(null);
const lastPushTimeRef = useRef(0);
const audioLevelRef = useRef(audioLevel);
audioLevelRef.current = audioLevel;

// ─── Speaking variant bars ─────────────────────────────────────────────────
const speakingBars = useMemo(() => {
return Array.from({ length: SPEAKING_BAR_COUNT }, (_, i) => ({
id: i,
// Create a natural wave pattern with height variation
baseHeight: Math.sin(i * 0.4) * 5 + 8,
delay: i * 0.025,
}));
}, []);

// ─── Recording: Timer effect ───────────────────────────────────────────────
useEffect(() => {
if (variant !== "recording") return;

if (!isActive) {
setElapsedSeconds(0);
return;
}

const interval = setInterval(() => {
setElapsedSeconds((prev) => prev + 1);
}, 1000);

return () => clearInterval(interval);
}, [variant, isActive]);

// ─── Recording: Audio level visualization effect ───────────────────────────
useEffect(() => {
if (variant !== "recording") return;

if (!isActive) {
setBarHeights(
new Array(RECORDING_BAR_COUNT).fill(MIN_BAR_HEIGHT) as number[]
);
lastPushTimeRef.current = 0;
return;
}

const updateBars = (timestamp: number) => {
// Push a new bar roughly every 50ms (~20fps scrolling)
if (timestamp - lastPushTimeRef.current >= 50) {
lastPushTimeRef.current = timestamp;
const level = isMuted ? 0 : audioLevelRef.current;
const height =
MIN_BAR_HEIGHT + level * (MAX_BAR_HEIGHT - MIN_BAR_HEIGHT);

setBarHeights((prev) => {
const next = prev.slice(1);
next.push(height);
return next;
});
}

animationRef.current = requestAnimationFrame(updateBars);
};

animationRef.current = requestAnimationFrame(updateBars);

return () => {
if (animationRef.current) {
cancelAnimationFrame(animationRef.current);
animationRef.current = null;
}
};
}, [variant, isActive, isMuted]);

const formattedTime = useMemo(
() => formatElapsedTime(elapsedSeconds),
[elapsedSeconds]
);

if (!isActive) {
return null;
}

// ─── Speaking variant render ───────────────────────────────────────────────
if (variant === "speaking") {
return (
<div className="flex items-center gap-0.5 p-1.5 bg-background-tint-00 rounded-16 shadow-01">
{/* Waveform container */}
<div className="flex items-center p-1 bg-background-tint-00 rounded-12 max-w-[144px] min-h-[32px]">
<div className="flex items-center p-1">
{/* Waveform bars */}
<div className="flex items-center justify-center gap-[2px] h-4 w-[120px] overflow-hidden">
{speakingBars.map((bar) => (
<div
key={bar.id}
className={cn(
"w-[3px] rounded-full",
isMuted ? "bg-text-03" : "bg-theme-blue-05",
!isMuted && "animate-waveform"
)}
style={{
height: isMuted ? "2px" : `${bar.baseHeight}px`,
animationDelay: isMuted ? undefined : `${bar.delay}s`,
}}
/>
))}
</div>
</div>
</div>

{/* Divider */}
<div className="w-0.5 self-stretch bg-border-02" />

{/* Volume button */}
{onMuteToggle && (
<div className="flex items-center p-1 bg-background-tint-00 rounded-12">
<Button
icon={isMuted ? SvgVolumeOff : SvgVolume}
onClick={onMuteToggle}
prominence="tertiary"
size="sm"
tooltip={isMuted ? "Unmute" : "Mute"}
/>
</div>
)}
</div>
);
}

// ─── Recording variant render ──────────────────────────────────────────────
return (
<div className="flex items-center gap-3 px-3 py-2 bg-background-tint-00 rounded-12 min-h-[32px]">
{/* Waveform visualization driven by real audio levels */}
<div className="flex-1 flex items-center justify-between h-4 overflow-hidden">
{barHeights.map((height, i) => (
<div
key={i}
className="w-[1.5px] bg-text-03 rounded-full shrink-0 transition-[height] duration-75"
style={{ height: `${height}px` }}
/>
))}
</div>

{/* Timer */}
<span className="font-mono text-xs text-text-03 tabular-nums shrink-0">
{formattedTime}
</span>

{/* Mute button */}
{onMuteToggle && (
<Button
icon={isMuted ? SvgMicrophoneOff : SvgMicrophone}
onClick={onMuteToggle}
prominence="tertiary"
size="sm"
aria-label={isMuted ? "Unmute microphone" : "Mute microphone"}
/>
)}
</div>
);
}

export default Waveform;
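
The recording variant above is a fixed-size scrolling buffer: roughly every 50ms one sample is appended and the oldest dropped, so 120 bars at ~20 samples/second always show the last ~6 seconds of mic level. The core of that technique, reduced to a sketch (assumes the same MIN/MAX constants as the removed component):

// Minimal sketch of the scrolling-buffer update driving the recording waveform.
function pushSample(bars: number[], level: number): number[] {
  // Map a 0..1 mic level onto the bar-height range.
  const height = MIN_BAR_HEIGHT + level * (MAX_BAR_HEIGHT - MIN_BAR_HEIGHT);
  return [...bars.slice(1), height]; // drop oldest, append newest; length stays constant
}
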

web/src/ee/providers/AppModeProvider.tsx (new file, 55 lines)
@@ -0,0 +1,55 @@
"use client";

import React, { useState, useCallback, useEffect } from "react";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { AppModeContext, AppMode } from "@/providers/AppModeProvider";
import { useUser } from "@/providers/UserProvider";
import { useSettingsContext } from "@/providers/SettingsProvider";

export interface AppModeProviderProps {
children: React.ReactNode;
}

/**
* Provider for application mode (Search/Chat).
*
* This controls how user queries are handled:
* - **search**: Forces search mode - quick document lookup
* - **chat**: Forces chat mode - conversation with follow-up questions
*
* The initial mode is read from the user's persisted `default_app_mode` preference.
* When search mode is unavailable (admin setting or no connectors), the mode is locked to "chat".
*/
export function AppModeProvider({ children }: AppModeProviderProps) {
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
const { user } = useUser();
const { isSearchModeAvailable } = useSettingsContext();

const persistedMode = user?.preferences?.default_app_mode;
const [appMode, setAppModeState] = useState<AppMode>("chat");

useEffect(() => {
if (!isPaidEnterpriseFeaturesEnabled || !isSearchModeAvailable) {
setAppModeState("chat");
return;
}

if (persistedMode) {
setAppModeState(persistedMode.toLowerCase() as AppMode);
}
}, [isPaidEnterpriseFeaturesEnabled, isSearchModeAvailable, persistedMode]);

const setAppMode = useCallback(
(mode: AppMode) => {
if (!isPaidEnterpriseFeaturesEnabled || !isSearchModeAvailable) return;
setAppModeState(mode);
},
[isPaidEnterpriseFeaturesEnabled, isSearchModeAvailable]
);

return (
<AppModeContext.Provider value={{ appMode, setAppMode }}>
{children}
</AppModeContext.Provider>
);
}
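
Consumption follows the usual React context pattern; a minimal usage sketch (the component here is hypothetical):

// Hypothetical consumer of the AppModeProvider above.
function ModeToggle() {
  const { appMode, setAppMode } = useAppMode();
  // setAppMode is a no-op when search mode is unavailable, per the guard above.
  return (
    <button onClick={() => setAppMode(appMode === "chat" ? "search" : "chat")}>
      Mode: {appMode}
    </button>
  );
}
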

@@ -8,15 +8,14 @@ import {
SearchFullResponse,
} from "@/lib/search/interfaces";
import { classifyQuery, searchDocuments } from "@/ee/lib/search/svc";
import { useAppMode } from "@/providers/AppModeProvider";
import useAppFocus from "@/hooks/useAppFocus";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import { useSettingsContext } from "@/providers/SettingsProvider";
import { useUser } from "@/providers/UserProvider";
import {
QueryControllerContext,
QueryClassification,
QueryControllerValue,
QueryState,
AppMode,
} from "@/providers/QueryControllerProvider";

interface QueryControllerProviderProps {
@@ -26,53 +25,19 @@ interface QueryControllerProviderProps {
export function QueryControllerProvider({
children,
}: QueryControllerProviderProps) {
const { appMode, setAppMode } = useAppMode();
const appFocus = useAppFocus();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
const settings = useSettingsContext();
const { isSearchModeAvailable: searchUiEnabled } = settings;
const { user } = useUser();

// ── Merged query state (discriminated union) ──────────────────────────
const [state, setState] = useState<QueryState>({
phase: "idle",
appMode: "chat",
});

// Persistent app-mode preference — survives phase transitions and is
// used to restore the correct mode when resetting back to idle.
const appModeRef = useRef<AppMode>("chat");

// ── App mode sync from user preferences ───────────────────────────────
const persistedMode = user?.preferences?.default_app_mode;

useEffect(() => {
let mode: AppMode = "chat";
if (isPaidEnterpriseFeaturesEnabled && searchUiEnabled && persistedMode) {
const lower = persistedMode.toLowerCase();
mode = (["auto", "search", "chat"] as const).includes(lower as AppMode)
? (lower as AppMode)
: "chat";
}
appModeRef.current = mode;
setState((prev) =>
prev.phase === "idle" ? { phase: "idle", appMode: mode } : prev
);
}, [isPaidEnterpriseFeaturesEnabled, searchUiEnabled, persistedMode]);

const setAppMode = useCallback(
(mode: AppMode) => {
if (!isPaidEnterpriseFeaturesEnabled || !searchUiEnabled) return;
setState((prev) => {
if (prev.phase !== "idle") return prev;
appModeRef.current = mode;
return { phase: "idle", appMode: mode };
});
},
[isPaidEnterpriseFeaturesEnabled, searchUiEnabled]
);

// ── Ancillary state ───────────────────────────────────────────────────
// Query state
const [query, setQuery] = useState<string | null>(null);
const [classification, setClassification] =
useState<QueryClassification>(null);
const [isClassifying, setIsClassifying] = useState(false);

// Search state
const [searchResults, setSearchResults] = useState<SearchDocWithContent[]>(
[]
);
@@ -86,7 +51,7 @@ export function QueryControllerProvider({
const searchAbortRef = useRef<AbortController | null>(null);

/**
* Perform document search (pure data-fetching, no phase side effects)
* Perform document search
*/
const performSearch = useCallback(
async (searchQuery: string, filters?: BaseFilters): Promise<void> => {
@@ -120,15 +85,19 @@ export function QueryControllerProvider({
setLlmSelectedDocIds(response.llm_selected_doc_ids ?? null);
} catch (err) {
if (err instanceof Error && err.name === "AbortError") {
throw err;
return;
}

setError("Document search failed. Please try again.");
setSearchResults([]);
setLlmSelectedDocIds(null);
} finally {
// After we've performed a search, we automatically switch to "search" mode.
// This is a "sticky" implementation; on purpose.
setAppMode("search");
}
},
[]
[setAppMode]
);

/**
@@ -143,6 +112,8 @@ export function QueryControllerProvider({
const controller = new AbortController();
classifyAbortRef.current = controller;

setIsClassifying(true);

try {
const response: SearchFlowClassificationResponse = await classifyQuery(
classifyQueryText,
@@ -158,6 +129,8 @@ export function QueryControllerProvider({

setError("Query classification failed. Falling back to chat.");
return "chat";
} finally {
setIsClassifying(false);
}
},
[]
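
performClassification above follows a common abortable-request shape: stash the AbortController in a ref so a newer call can cancel the in-flight one, flip the loading flag around the await, and fall back to "chat" on failure. A standalone sketch of that shape (endpoint and names are illustrative, not the provider's API):

// Hypothetical reduction of the classify-with-fallback pattern.
let inflight: AbortController | null = null;

async function classifyWithFallback(text: string): Promise<"search" | "chat"> {
  inflight?.abort(); // cancel any previous classification
  const controller = new AbortController();
  inflight = controller;
  try {
    const res = await fetch("/api/classify", {
      // the endpoint is made up for illustration
      method: "POST",
      body: JSON.stringify({ text }),
      signal: controller.signal,
    });
    const data = (await res.json()) as { classification: "search" | "chat" };
    return data.classification;
  } catch {
    return "chat"; // on error or abort, fall back to the chat flow
  }
}
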
@@ -175,51 +148,62 @@ export function QueryControllerProvider({
setQuery(submitQuery);
setError(null);

const currentAppMode = appModeRef.current;

// Always route through chat if:
// 1. Not Enterprise Enabled
// 2. Admin has disabled the Search UI
// 3. Not in the "New Session" tab
// 4. In "New Session" tab but app-mode is "Chat"
// 1.
// We always route through chat if we're not Enterprise Enabled.
//
// 2.
// We always route through chat if the admin has disabled the Search UI.
//
// 3.
// We only go down the classification route if we're in the "New Session" tab.
// Everywhere else, we always use the chat-flow.
//
// 4.
// If we're in the "New Session" tab and the app-mode is "Chat", we continue with the chat-flow anyways.
if (
!isPaidEnterpriseFeaturesEnabled ||
!searchUiEnabled ||
!appFocus.isNewSession() ||
currentAppMode === "chat"
appMode === "chat"
) {
setState({ phase: "chat" });
setClassification("chat");
setSearchResults([]);
setLlmSelectedDocIds(null);
onChat(submitQuery);
return;
}

// Search mode: immediately show SearchUI with loading state
if (currentAppMode === "search") {
setState({ phase: "searching" });
try {
await performSearch(submitQuery, filters);
} catch (err) {
if (err instanceof Error && err.name === "AbortError") return;
throw err;
}
setState({ phase: "search-results" });
if (appMode === "search") {
await performSearch(submitQuery, filters);
setClassification("search");
return;
}

// # Note (@raunakab)
//
// Interestingly enough, for search, we do:
// 1. setClassification("search")
// 2. performSearch
//
// But for chat, we do:
// 1. performChat
// 2. setClassification("chat")
//
// The ChatUI has a nice loading UI, so it's fine for us to prematurely set the
// classification-state before the chat has finished loading.
//
// However, the SearchUI does not. Prematurely setting the classification-state
// will lead to a slightly ugly UI.

// Auto mode: classify first, then route
setState({ phase: "classifying" });
try {
const result = await performClassification(submitQuery);

if (result === "search") {
setState({ phase: "searching" });
await performSearch(submitQuery, filters);
setState({ phase: "search-results" });
appModeRef.current = "search";
setClassification("search");
} else {
setState({ phase: "chat" });
setClassification("chat");
setSearchResults([]);
setLlmSelectedDocIds(null);
onChat(submitQuery);
@@ -229,13 +213,14 @@ export function QueryControllerProvider({
return;
}

setState({ phase: "chat" });
setClassification("chat");
setSearchResults([]);
setLlmSelectedDocIds(null);
onChat(submitQuery);
}
},
[
appMode,
appFocus,
performClassification,
performSearch,
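
The four routing rules commented in the submit handler above reduce to a single boolean; as a pure predicate (the function name is mine):

// Hypothetical predicate mirroring the four chat-routing rules above.
function routesThroughChat(
  isEnterpriseEnabled: boolean,
  searchUiEnabled: boolean,
  isNewSession: boolean,
  appMode: "auto" | "search" | "chat"
): boolean {
  return (
    !isEnterpriseEnabled || !searchUiEnabled || !isNewSession || appMode === "chat"
  );
}
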
@@ -250,14 +235,7 @@ export function QueryControllerProvider({
const refineSearch = useCallback(
async (filters: BaseFilters): Promise<void> => {
if (!query) return;
setState({ phase: "searching" });
try {
await performSearch(query, filters);
} catch (err) {
if (err instanceof Error && err.name === "AbortError") return;
throw err;
}
setState({ phase: "search-results" });
await performSearch(query, filters);
},
[query, performSearch]
);
@@ -276,7 +254,7 @@ export function QueryControllerProvider({
}

setQuery(null);
setState({ phase: "idle", appMode: appModeRef.current });
setClassification(null);
setSearchResults([]);
setLlmSelectedDocIds(null);
setError(null);
@@ -284,8 +262,8 @@ export function QueryControllerProvider({

const value: QueryControllerValue = useMemo(
() => ({
state,
setAppMode,
classification,
isClassifying,
searchResults,
llmSelectedDocIds,
error,
@@ -294,8 +272,8 @@ export function QueryControllerProvider({
reset,
}),
[
state,
setAppMode,
classification,
isClassifying,
searchResults,
llmSelectedDocIds,
error,
@@ -305,7 +283,7 @@ export function QueryControllerProvider({
]
);

// Sync state with navigation context
// Sync classification state with navigation context
useEffect(reset, [appFocus, reset]);

return (

@@ -56,7 +56,7 @@ export default function SearchCard({

return (
<Interactive.Stateless onClick={handleClick} prominence="secondary">
<Interactive.Container heightVariant="fit" widthVariant="full">
<Interactive.Container heightVariant="fit">
<Section alignItems="start" gap={0} padding={0.25}>
{/* Title Row */}
<Section

@@ -18,17 +18,16 @@ import { getTimeFilterDate, TimeFilter } from "@/lib/time";
import useTags from "@/hooks/useTags";
import { SourceIcon } from "@/components/SourceIcon";
import Text from "@/refresh-components/texts/Text";
import LineItem from "@/refresh-components/buttons/LineItem";
import { Section } from "@/layouts/general-layouts";
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
import { SvgCheck, SvgClock, SvgTag } from "@opal/icons";
import FilterButton from "@/refresh-components/buttons/FilterButton";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import useFilter from "@/hooks/useFilter";
import { LineItemButton } from "@opal/components";
import { useQueryController } from "@/providers/QueryControllerProvider";
import { cn } from "@/lib/utils";
import { toast } from "@/hooks/useToast";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";

// ============================================================================
// Types
@@ -52,17 +51,22 @@ const TIME_FILTER_OPTIONS: { value: TimeFilter; label: string }[] = [
{ value: "year", label: "Past year" },
];

// ============================================================================
// SearchResults Component (default export)
// ============================================================================

/**
* Component for displaying search results with source filter sidebar.
*/
export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
// Available tags from backend
const { tags: availableTags } = useTags();
const {
state,
searchResults: results,
llmSelectedDocIds,
error,
refineSearch: onRefineSearch,
} = useQueryController();

const prevErrorRef = useRef<string | null>(null);

// Show a toast notification when a new error occurs
@@ -193,15 +197,6 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {

const showEmpty = !error && results.length === 0;

// Show a centered spinner while search is in-flight (after all hooks)
if (state.phase === "searching") {
return (
<div className="flex-1 min-h-0 w-full flex items-center justify-center">
<SimpleLoader />
</div>
);
}

return (
<div className="flex-1 min-h-0 w-full flex flex-col gap-3">
{/* ── Top row: Filters + Result count ── */}
@@ -231,19 +226,18 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
<Popover.Content align="start" width="md">
<PopoverMenu>
{TIME_FILTER_OPTIONS.map((opt) => (
<LineItemButton
<LineItem
key={opt.value}
onClick={() => {
setTimeFilter(opt.value);
setTimeFilterOpen(false);
onRefineSearch(buildFilters({ time: opt.value }));
}}
state={timeFilter === opt.value ? "selected" : "empty"}
selected={timeFilter === opt.value}
icon={timeFilter === opt.value ? SvgCheck : SvgClock}
title={opt.label}
sizePreset="main-ui"
variant="section"
/>
>
{opt.label}
</LineItem>
))}
</PopoverMenu>
</Popover.Content>
@@ -284,7 +278,7 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
t.tag_value === tag.tag_value
);
return (
<LineItemButton
<LineItem
key={`${tag.tag_key}=${tag.tag_value}`}
onClick={() => {
const next = isSelected
@@ -297,12 +291,11 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
setSelectedTags(next);
onRefineSearch(buildFilters({ tags: next }));
}}
state={isSelected ? "selected" : "empty"}
selected={isSelected}
icon={isSelected ? SvgCheck : SvgTag}
title={tag.tag_value}
sizePreset="main-ui"
variant="section"
/>
>
{tag.tag_value}
</LineItem>
);
})}
</PopoverMenu>
@@ -364,7 +357,7 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
<div className="flex-1 min-h-0 overflow-y-auto flex flex-col gap-4 px-1">
<Section gap={0.25} height="fit">
{sourcesWithMeta.map(({ source, meta, count }) => (
<LineItemButton
<LineItem
key={source}
icon={(props) => (
<SourceIcon
@@ -374,15 +367,12 @@ export default function SearchUI({ onDocumentClick }: SearchResultsProps) {
/>
)}
onClick={() => handleSourceToggle(source)}
state={
selectedSources.includes(source) ? "selected" : "empty"
}
title={meta.displayName}
selectVariant="select-heavy"
sizePreset="main-ui"
variant="section"
selected={selectedSources.includes(source)}
emphasized
rightChildren={<Text text03>{count}</Text>}
/>
>
{meta.displayName}
</LineItem>
))}
</Section>
</div>

@@ -5,7 +5,6 @@
//
// This is useful in determining what `SidebarTab` should be active, for example.

import { useMemo } from "react";
import { SEARCH_PARAM_NAMES } from "@/app/app/services/searchParams";
import { usePathname, useSearchParams } from "next/navigation";

@@ -67,25 +66,31 @@ export default function useAppFocus(): AppFocus {
const pathname = usePathname();
const searchParams = useSearchParams();

const chatId = searchParams.get(SEARCH_PARAM_NAMES.CHAT_ID);
const agentId = searchParams.get(SEARCH_PARAM_NAMES.PERSONA_ID);
const projectId = searchParams.get(SEARCH_PARAM_NAMES.PROJECT_ID);
// Check if we're viewing a shared chat
if (pathname.startsWith("/app/shared/")) {
return new AppFocus("shared-chat");
}

// Memoize on the values that determine which AppFocus is constructed.
// AppFocus is immutable, so same inputs → same instance.
return useMemo(() => {
if (pathname.startsWith("/app/shared/")) {
return new AppFocus("shared-chat");
}
if (pathname.startsWith("/app/settings")) {
return new AppFocus("user-settings");
}
if (pathname.startsWith("/app/agents")) {
return new AppFocus("more-agents");
}
if (chatId) return new AppFocus({ type: "chat", id: chatId });
if (agentId) return new AppFocus({ type: "agent", id: agentId });
if (projectId) return new AppFocus({ type: "project", id: projectId });
return new AppFocus("new-session");
}, [pathname, chatId, agentId, projectId]);
// Check if we're on the user settings page
if (pathname.startsWith("/app/settings")) {
return new AppFocus("user-settings");
}

// Check if we're on the agents page
if (pathname.startsWith("/app/agents")) {
return new AppFocus("more-agents");
}

// Check search params for chat, agent, or project
const chatId = searchParams.get(SEARCH_PARAM_NAMES.CHAT_ID);
if (chatId) return new AppFocus({ type: "chat", id: chatId });

const agentId = searchParams.get(SEARCH_PARAM_NAMES.PERSONA_ID);
if (agentId) return new AppFocus({ type: "agent", id: agentId });

const projectId = searchParams.get(SEARCH_PARAM_NAMES.PROJECT_ID);
if (projectId) return new AppFocus({ type: "project", id: projectId });

// No search params means we're on a new session
return new AppFocus("new-session");
}
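
The useMemo introduced above matters because AppFocus is compared by reference downstream (e.g. useEffect(reset, [appFocus, reset]) in the query controller); with stable inputs the hook now returns the same instance, so such effects stop re-firing on unrelated renders. A small consumer-side sketch:

// Illustrative consumer: with the memoized hook, this effect re-runs only
// when the focus inputs actually change, not on every render.
const appFocus = useAppFocus();
useEffect(() => {
  console.log("focus changed:", appFocus);
}, [appFocus]);
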

Some files were not shown because too many files have changed in this diff.