mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-15 04:32:39 +00:00
Compare commits
19 Commits
bo/hook
...
nikg/admin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
091ba93c0a | ||
|
|
8545045339 | ||
|
|
9516e61305 | ||
|
|
fa70a22b9d | ||
|
|
5bad364d01 | ||
|
|
a279baad44 | ||
|
|
b3949d37d9 | ||
|
|
16e2b18a3f | ||
|
|
dd55795d0e | ||
|
|
26dd44559b | ||
|
|
d433a86d01 | ||
|
|
57cb2f6920 | ||
|
|
eba021c221 | ||
|
|
945c27268c | ||
|
|
463cce7d76 | ||
|
|
a3fe5bff39 | ||
|
|
ad070f61a9 | ||
|
|
aade59215c | ||
|
|
6c101ee90d |
3
.github/CODEOWNERS
vendored
3
.github/CODEOWNERS
vendored
@@ -8,6 +8,3 @@
|
||||
# Agent context files
|
||||
/CLAUDE.md @Weves
|
||||
/AGENTS.md @Weves
|
||||
|
||||
# Beta cherry-pick workflow owners
|
||||
/.github/workflows/post-merge-beta-cherry-pick.yml @justin-tahara @jmelahman
|
||||
|
||||
31
.github/actions/slack-notify/action.yml
vendored
31
.github/actions/slack-notify/action.yml
vendored
@@ -1,14 +1,11 @@
|
||||
name: "Slack Notify"
|
||||
description: "Sends a Slack notification for workflow events"
|
||||
name: "Slack Notify on Failure"
|
||||
description: "Sends a Slack notification when a workflow fails"
|
||||
inputs:
|
||||
webhook-url:
|
||||
description: "Slack webhook URL (can also use SLACK_WEBHOOK_URL env var)"
|
||||
required: false
|
||||
details:
|
||||
description: "Additional message body content"
|
||||
required: false
|
||||
failed-jobs:
|
||||
description: "Deprecated alias for details"
|
||||
description: "List of failed job names (newline-separated)"
|
||||
required: false
|
||||
title:
|
||||
description: "Title for the notification"
|
||||
@@ -24,7 +21,6 @@ runs:
|
||||
shell: bash
|
||||
env:
|
||||
SLACK_WEBHOOK_URL: ${{ inputs.webhook-url }}
|
||||
DETAILS: ${{ inputs.details }}
|
||||
FAILED_JOBS: ${{ inputs.failed-jobs }}
|
||||
TITLE: ${{ inputs.title }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
@@ -48,18 +44,6 @@ runs:
|
||||
REF_NAME="$GITHUB_REF_NAME"
|
||||
fi
|
||||
|
||||
if [ -z "$DETAILS" ]; then
|
||||
DETAILS="$FAILED_JOBS"
|
||||
fi
|
||||
|
||||
normalize_multiline() {
|
||||
printf '%s' "$1" | awk 'BEGIN { ORS=""; first=1 } { if (!first) printf "\\n"; printf "%s", $0; first=0 }'
|
||||
}
|
||||
|
||||
DETAILS="$(normalize_multiline "$DETAILS")"
|
||||
REF_NAME="$(normalize_multiline "$REF_NAME")"
|
||||
TITLE="$(normalize_multiline "$TITLE")"
|
||||
|
||||
# Escape JSON special characters
|
||||
escape_json() {
|
||||
local input="$1"
|
||||
@@ -75,12 +59,12 @@ runs:
|
||||
}
|
||||
|
||||
REF_NAME_ESC=$(escape_json "$REF_NAME")
|
||||
DETAILS_ESC=$(escape_json "$DETAILS")
|
||||
FAILED_JOBS_ESC=$(escape_json "$FAILED_JOBS")
|
||||
WORKFLOW_URL_ESC=$(escape_json "$WORKFLOW_URL")
|
||||
TITLE_ESC=$(escape_json "$TITLE")
|
||||
|
||||
# Build JSON payload piece by piece
|
||||
# Note: DETAILS_ESC already contains \n sequences that should remain as \n in JSON
|
||||
# Note: FAILED_JOBS_ESC already contains \n sequences that should remain as \n in JSON
|
||||
PAYLOAD="{"
|
||||
PAYLOAD="${PAYLOAD}\"text\":\"${TITLE_ESC}\","
|
||||
PAYLOAD="${PAYLOAD}\"blocks\":[{"
|
||||
@@ -95,10 +79,10 @@ runs:
|
||||
PAYLOAD="${PAYLOAD}{\"type\":\"mrkdwn\",\"text\":\"*Run ID:*\\n#${RUN_NUMBER}\"}"
|
||||
PAYLOAD="${PAYLOAD}]"
|
||||
PAYLOAD="${PAYLOAD}}"
|
||||
if [ -n "$DETAILS" ]; then
|
||||
if [ -n "$FAILED_JOBS" ]; then
|
||||
PAYLOAD="${PAYLOAD},{"
|
||||
PAYLOAD="${PAYLOAD}\"type\":\"section\","
|
||||
PAYLOAD="${PAYLOAD}\"text\":{\"type\":\"mrkdwn\",\"text\":\"${DETAILS_ESC}\"}"
|
||||
PAYLOAD="${PAYLOAD}\"text\":{\"type\":\"mrkdwn\",\"text\":\"*Failed Jobs:*\\n${FAILED_JOBS_ESC}\"}"
|
||||
PAYLOAD="${PAYLOAD}}"
|
||||
fi
|
||||
PAYLOAD="${PAYLOAD},{"
|
||||
@@ -115,3 +99,4 @@ runs:
|
||||
curl -X POST -H 'Content-type: application/json' \
|
||||
--data "$PAYLOAD" \
|
||||
"$SLACK_WEBHOOK_URL"
|
||||
|
||||
|
||||
104
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
104
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
@@ -37,27 +37,10 @@ jobs:
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
MERGE_COMMIT_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
MERGED_BY: ${{ github.event.pull_request.merged_by.login }}
|
||||
# Explicit merger allowlist used because pull_request_target runs with
|
||||
# the default GITHUB_TOKEN, which cannot reliably read org/team
|
||||
# membership for this repository context.
|
||||
ALLOWED_MERGERS: |
|
||||
acaprau
|
||||
bo-onyx
|
||||
danelegend
|
||||
duo-onyx
|
||||
evan-onyx
|
||||
jessicasingh7
|
||||
jmelahman
|
||||
joachim-danswer
|
||||
justin-tahara
|
||||
nmgarza5
|
||||
raunakab
|
||||
rohoswagger
|
||||
subash-mohan
|
||||
trial2onyx
|
||||
wenxi-onyx
|
||||
weves
|
||||
yuhongsun96
|
||||
# GitHub team slug authorized to trigger cherry-picks (e.g. "core-eng").
|
||||
# For private/secret teams the GITHUB_TOKEN may need org:read scope;
|
||||
# visible teams work with the default token.
|
||||
ALLOWED_TEAM: "onyx-core-team"
|
||||
run: |
|
||||
echo "pr_number=${PR_NUMBER}" >> "$GITHUB_OUTPUT"
|
||||
echo "merged_by=${MERGED_BY}" >> "$GITHUB_OUTPUT"
|
||||
@@ -81,11 +64,19 @@ jobs:
|
||||
|
||||
echo "merge_commit_sha=${MERGE_COMMIT_SHA}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
normalized_merged_by="$(printf '%s' "${MERGED_BY}" | tr '[:upper:]' '[:lower:]')"
|
||||
normalized_allowed_mergers="$(printf '%s\n' "${ALLOWED_MERGERS}" | tr '[:upper:]' '[:lower:]')"
|
||||
if ! printf '%s\n' "${normalized_allowed_mergers}" | grep -Fxq "${normalized_merged_by}"; then
|
||||
echo "gate_error=not-allowed-merger" >> "$GITHUB_OUTPUT"
|
||||
echo "::error::${MERGED_BY} is not in the explicit cherry-pick merger allowlist. Failing cherry-pick gate."
|
||||
member_state_file="$(mktemp)"
|
||||
member_err_file="$(mktemp)"
|
||||
if ! gh api "orgs/${GITHUB_REPOSITORY_OWNER}/teams/${ALLOWED_TEAM}/memberships/${MERGED_BY}" --jq '.state' >"${member_state_file}" 2>"${member_err_file}"; then
|
||||
api_err="$(tr '\n' ' ' < "${member_err_file}" | sed 's/[[:space:]]\+/ /g' | cut -c1-300)"
|
||||
echo "gate_error=team-api-error" >> "$GITHUB_OUTPUT"
|
||||
echo "::error::Team membership API call failed for ${MERGED_BY} in ${ALLOWED_TEAM}: ${api_err}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
member_state="$(cat "${member_state_file}")"
|
||||
if [ "${member_state}" != "active" ]; then
|
||||
echo "gate_error=not-team-member" >> "$GITHUB_OUTPUT"
|
||||
echo "::error::${MERGED_BY} is not an active member of team ${ALLOWED_TEAM} (state: ${member_state}). Failing cherry-pick gate."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
@@ -99,7 +90,6 @@ jobs:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
outputs:
|
||||
cherry_pick_pr_url: ${{ steps.run_cherry_pick.outputs.pr_url }}
|
||||
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
|
||||
runs-on: ubuntu-latest
|
||||
@@ -147,11 +137,7 @@ jobs:
|
||||
fi
|
||||
|
||||
if [ "${exit_code}" -eq 0 ]; then
|
||||
pr_url="$(sed -n 's/^.*PR created successfully: \(https:\/\/github\.com\/[^[:space:]]\+\/pull\/[0-9]\+\).*$/\1/p' "$output_file" | tail -n 1)"
|
||||
echo "status=success" >> "$GITHUB_OUTPUT"
|
||||
if [ -n "${pr_url}" ]; then
|
||||
echo "pr_url=${pr_url}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
|
||||
@@ -177,54 +163,6 @@ jobs:
|
||||
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
|
||||
exit 1
|
||||
|
||||
notify-slack-on-cherry-pick-success:
|
||||
needs:
|
||||
- resolve-cherry-pick-request
|
||||
- cherry-pick-to-latest-release
|
||||
if: needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && needs.resolve-cherry-pick-request.result == 'success' && needs.cherry-pick-to-latest-release.result == 'success'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Fail if Slack webhook secret is missing
|
||||
env:
|
||||
CHERRY_PICK_PRS_WEBHOOK: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
|
||||
run: |
|
||||
if [ -z "${CHERRY_PICK_PRS_WEBHOOK}" ]; then
|
||||
echo "::error::CHERRY_PICK_PRS_WEBHOOK is not configured."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Build cherry-pick success summary
|
||||
id: success-summary
|
||||
env:
|
||||
SOURCE_PR_NUMBER: ${{ needs.resolve-cherry-pick-request.outputs.pr_number }}
|
||||
MERGE_COMMIT_SHA: ${{ needs.resolve-cherry-pick-request.outputs.merge_commit_sha }}
|
||||
CHERRY_PICK_PR_URL: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_pr_url }}
|
||||
run: |
|
||||
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
|
||||
details="*Cherry-pick PR opened successfully.*\\n• source PR: ${source_pr_url}"
|
||||
if [ -n "${CHERRY_PICK_PR_URL}" ]; then
|
||||
details="${details}\\n• cherry-pick PR: ${CHERRY_PICK_PR_URL}"
|
||||
fi
|
||||
if [ -n "${MERGE_COMMIT_SHA}" ]; then
|
||||
details="${details}\\n• merge SHA: ${MERGE_COMMIT_SHA}"
|
||||
fi
|
||||
|
||||
echo "details=${details}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Notify #cherry-pick-prs about cherry-pick success
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
|
||||
details: ${{ steps.success-summary.outputs.details }}
|
||||
title: "✅ Automated Cherry-Pick PR Opened"
|
||||
ref-name: ${{ github.event.pull_request.base.ref }}
|
||||
|
||||
notify-slack-on-cherry-pick-failure:
|
||||
needs:
|
||||
- resolve-cherry-pick-request
|
||||
@@ -261,8 +199,10 @@ jobs:
|
||||
reason_text="cherry-pick command failed"
|
||||
if [ "${GATE_ERROR}" = "missing-merge-commit-sha" ]; then
|
||||
reason_text="requested cherry-pick but merge commit SHA was missing"
|
||||
elif [ "${GATE_ERROR}" = "not-allowed-merger" ]; then
|
||||
reason_text="merger is not in the explicit cherry-pick allowlist"
|
||||
elif [ "${GATE_ERROR}" = "team-api-error" ]; then
|
||||
reason_text="team membership lookup failed while validating cherry-pick permissions"
|
||||
elif [ "${GATE_ERROR}" = "not-team-member" ]; then
|
||||
reason_text="merger is not an active member of the allowed team"
|
||||
elif [ "${CHERRY_PICK_REASON}" = "output-capture-failed" ]; then
|
||||
reason_text="failed to capture cherry-pick output for classification"
|
||||
elif [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
|
||||
@@ -289,6 +229,6 @@ jobs:
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
|
||||
details: ${{ steps.failure-summary.outputs.jobs }}
|
||||
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
|
||||
title: "🚨 Automated Cherry-Pick Failed"
|
||||
ref-name: ${{ github.event.pull_request.base.ref }}
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
"""add_hook_and_hook_execution_log_tables
|
||||
|
||||
Revision ID: 689433b0d8de
|
||||
Revises: 93a2e195e25c
|
||||
Create Date: 2026-03-13 11:25:06.547474
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID as PGUUID
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "689433b0d8de"
|
||||
down_revision = "93a2e195e25c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"hook",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"hook_point",
|
||||
sa.Enum("document_ingestion", "query_processing", native_enum=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("endpoint_url", sa.Text(), nullable=True),
|
||||
sa.Column("api_key", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("is_reachable", sa.Boolean(), nullable=True),
|
||||
sa.Column(
|
||||
"fail_strategy",
|
||||
sa.Enum("hard", "soft", native_enum=False),
|
||||
nullable=False,
|
||||
server_default="hard",
|
||||
),
|
||||
sa.Column("timeout_seconds", sa.Float(), nullable=False, server_default="30.0"),
|
||||
sa.Column(
|
||||
"is_active", sa.Boolean(), nullable=False, server_default=sa.text("false")
|
||||
),
|
||||
sa.Column(
|
||||
"deleted", sa.Boolean(), nullable=False, server_default=sa.text("false")
|
||||
),
|
||||
sa.Column("creator_id", PGUUID(as_uuid=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(["creator_id"], ["user.id"], ondelete="SET NULL"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_hook_one_active_per_point",
|
||||
"hook",
|
||||
["hook_point"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("is_active = true AND deleted = false"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"hook_execution_log",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("hook_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"hook_point",
|
||||
sa.Enum("document_ingestion", "query_processing", native_enum=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("error_message", sa.Text(), nullable=True),
|
||||
sa.Column("status_code", sa.Integer(), nullable=True),
|
||||
sa.Column("duration_ms", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(["hook_id"], ["hook.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("ix_hook_execution_log_hook_id", "hook_execution_log", ["hook_id"])
|
||||
op.create_index(
|
||||
"ix_hook_execution_log_created_at", "hook_execution_log", ["created_at"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_hook_execution_log_created_at", table_name="hook_execution_log")
|
||||
op.drop_index("ix_hook_execution_log_hook_id", table_name="hook_execution_log")
|
||||
op.drop_table("hook_execution_log")
|
||||
|
||||
op.drop_index("ix_hook_one_active_per_point", table_name="hook")
|
||||
op.drop_table("hook")
|
||||
@@ -1,117 +0,0 @@
|
||||
"""add_voice_provider_and_user_voice_prefs
|
||||
|
||||
Revision ID: 93a2e195e25c
|
||||
Revises: 27fb147a843f
|
||||
Create Date: 2026-02-23 15:16:39.507304
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import column
|
||||
from sqlalchemy import true
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "93a2e195e25c"
|
||||
down_revision = "27fb147a843f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create voice_provider table
|
||||
op.create_table(
|
||||
"voice_provider",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("name", sa.String(), unique=True, nullable=False),
|
||||
sa.Column("provider_type", sa.String(), nullable=False),
|
||||
sa.Column("api_key", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("api_base", sa.String(), nullable=True),
|
||||
sa.Column("custom_config", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("stt_model", sa.String(), nullable=True),
|
||||
sa.Column("tts_model", sa.String(), nullable=True),
|
||||
sa.Column("default_voice", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"is_default_stt", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.Column(
|
||||
"is_default_tts", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.Column("deleted", sa.Boolean(), nullable=False, server_default="false"),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Add partial unique indexes to enforce only one default STT/TTS provider
|
||||
op.create_index(
|
||||
"ix_voice_provider_one_default_stt",
|
||||
"voice_provider",
|
||||
["is_default_stt"],
|
||||
unique=True,
|
||||
postgresql_where=column("is_default_stt") == true(),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_voice_provider_one_default_tts",
|
||||
"voice_provider",
|
||||
["is_default_tts"],
|
||||
unique=True,
|
||||
postgresql_where=column("is_default_tts") == true(),
|
||||
)
|
||||
|
||||
# Add voice preference columns to user table
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_auto_send",
|
||||
sa.Boolean(),
|
||||
default=False,
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_auto_playback",
|
||||
sa.Boolean(),
|
||||
default=False,
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_playback_speed",
|
||||
sa.Float(),
|
||||
default=1.0,
|
||||
nullable=False,
|
||||
server_default="1.0",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove user voice preference columns
|
||||
op.drop_column("user", "voice_playback_speed")
|
||||
op.drop_column("user", "voice_auto_playback")
|
||||
op.drop_column("user", "voice_auto_send")
|
||||
|
||||
op.drop_index("ix_voice_provider_one_default_tts", table_name="voice_provider")
|
||||
op.drop_index("ix_voice_provider_one_default_stt", table_name="voice_provider")
|
||||
|
||||
# Drop voice_provider table
|
||||
op.drop_table("voice_provider")
|
||||
@@ -31,7 +31,6 @@ from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi import WebSocket
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
@@ -130,7 +129,6 @@ from onyx.error_handling.exceptions import log_onyx_error
|
||||
from onyx.error_handling.exceptions import onyx_error_to_json_response
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import retrieve_ws_token_data
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -1622,102 +1620,6 @@ async def current_admin_user(user: User = Depends(current_user)) -> User:
|
||||
return user
|
||||
|
||||
|
||||
async def _get_user_from_token_data(token_data: dict) -> User | None:
|
||||
"""Shared logic: token data dict → User object.
|
||||
|
||||
Args:
|
||||
token_data: Decoded token data containing 'sub' (user ID).
|
||||
|
||||
Returns:
|
||||
User object if found and active, None otherwise.
|
||||
"""
|
||||
user_id = token_data.get("sub")
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
async with get_async_session_context_manager() as async_db_session:
|
||||
user = await async_db_session.get(User, user_uuid)
|
||||
if user is None or not user.is_active:
|
||||
return None
|
||||
return user
|
||||
|
||||
|
||||
async def current_user_from_websocket(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(..., description="WebSocket authentication token"),
|
||||
) -> User:
|
||||
"""
|
||||
WebSocket authentication dependency using query parameter.
|
||||
|
||||
Validates the WS token from query param and returns the User.
|
||||
Raises BasicAuthenticationError if authentication fails.
|
||||
|
||||
The token must be obtained from POST /voice/ws-token before connecting.
|
||||
Tokens are single-use and expire after 60 seconds.
|
||||
|
||||
Usage:
|
||||
1. POST /voice/ws-token -> {"token": "xxx"}
|
||||
2. Connect to ws://host/path?token=xxx
|
||||
|
||||
This applies the same auth checks as current_user() for HTTP endpoints.
|
||||
"""
|
||||
# Check Origin header to prevent Cross-Site WebSocket Hijacking (CSWSH)
|
||||
# Browsers always send Origin on WebSocket connections
|
||||
origin = websocket.headers.get("origin")
|
||||
expected_origin = WEB_DOMAIN.rstrip("/")
|
||||
if not origin:
|
||||
logger.warning("WS auth: missing Origin header")
|
||||
raise BasicAuthenticationError(detail="Access denied. Missing origin.")
|
||||
|
||||
actual_origin = origin.rstrip("/")
|
||||
if actual_origin != expected_origin:
|
||||
logger.warning(
|
||||
f"WS auth: origin mismatch. Expected {expected_origin}, got {actual_origin}"
|
||||
)
|
||||
raise BasicAuthenticationError(detail="Access denied. Invalid origin.")
|
||||
|
||||
# Validate WS token in Redis (single-use, deleted after retrieval)
|
||||
try:
|
||||
token_data = await retrieve_ws_token_data(token)
|
||||
if token_data is None:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. Invalid or expired authentication token."
|
||||
)
|
||||
except BasicAuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"WS auth: error during token validation: {e}")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Authentication verification failed."
|
||||
) from e
|
||||
|
||||
# Get user from token data
|
||||
user = await _get_user_from_token_data(token_data)
|
||||
if user is None:
|
||||
logger.warning(f"WS auth: user not found for id={token_data.get('sub')}")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User not found or inactive."
|
||||
)
|
||||
|
||||
# Apply same checks as HTTP auth (verification, OIDC expiry, role)
|
||||
user = await double_check_user(user)
|
||||
|
||||
# Block LIMITED users (same as current_user)
|
||||
if user.role == UserRole.LIMITED:
|
||||
logger.warning(f"WS auth: user {user.email} has LIMITED role")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
|
||||
)
|
||||
|
||||
logger.debug(f"WS auth: authenticated {user.email}")
|
||||
return user
|
||||
|
||||
|
||||
def get_default_admin_user_emails_() -> list[str]:
|
||||
# No default seeding available for Onyx MIT
|
||||
return []
|
||||
|
||||
@@ -282,23 +282,6 @@ def _log_and_raise_for_status(response: requests.Response) -> None:
|
||||
raise
|
||||
|
||||
|
||||
GRAPH_INVALID_REQUEST_CODE = "invalidRequest"
|
||||
|
||||
|
||||
def _is_graph_invalid_request(response: requests.Response) -> bool:
|
||||
"""Return True if the response body is the generic Graph API
|
||||
``{"error": {"code": "invalidRequest", "message": "Invalid request"}}``
|
||||
shape. This particular error has no actionable inner error code and is
|
||||
returned by the site-pages endpoint when a page has a corrupt canvas layout
|
||||
(e.g. duplicate web-part IDs — see SharePoint/sp-dev-docs#8822)."""
|
||||
try:
|
||||
body = response.json()
|
||||
except Exception:
|
||||
return False
|
||||
error = body.get("error", {})
|
||||
return error.get("code") == GRAPH_INVALID_REQUEST_CODE
|
||||
|
||||
|
||||
def load_certificate_from_pfx(pfx_data: bytes, password: str) -> CertificateData | None:
|
||||
"""Load certificate from .pfx file for MSAL authentication"""
|
||||
try:
|
||||
@@ -1269,35 +1252,19 @@ class SharepointConnector(
|
||||
site.execute_query()
|
||||
site_id = site.id
|
||||
|
||||
site_pages_base = (
|
||||
f"{self.graph_api_base}/sites/{site_id}/pages/microsoft.graph.sitePage"
|
||||
page_url: str | None = (
|
||||
f"{self.graph_api_base}/sites/{site_id}" f"/pages/microsoft.graph.sitePage"
|
||||
)
|
||||
page_url: str | None = site_pages_base
|
||||
params: dict[str, str] | None = {"$expand": "canvasLayout"}
|
||||
total_yielded = 0
|
||||
yielded_ids: set[str] = set()
|
||||
|
||||
while page_url:
|
||||
try:
|
||||
data = self._graph_api_get_json(page_url, params)
|
||||
except HTTPError as e:
|
||||
if e.response is not None and e.response.status_code == 404:
|
||||
if e.response.status_code == 404:
|
||||
logger.warning(f"Site page not found: {page_url}")
|
||||
break
|
||||
if (
|
||||
e.response is not None
|
||||
and e.response.status_code == 400
|
||||
and _is_graph_invalid_request(e.response)
|
||||
):
|
||||
logger.warning(
|
||||
f"$expand=canvasLayout on the LIST endpoint returned 400 "
|
||||
f"for site {site_descriptor.url}. Falling back to "
|
||||
f"per-page expansion."
|
||||
)
|
||||
yield from self._fetch_site_pages_individually(
|
||||
site_pages_base, start, end, skip_ids=yielded_ids
|
||||
)
|
||||
return
|
||||
raise
|
||||
|
||||
params = None # nextLink already embeds query params
|
||||
@@ -1306,98 +1273,12 @@ class SharepointConnector(
|
||||
if not _site_page_in_time_window(page, start, end):
|
||||
continue
|
||||
total_yielded += 1
|
||||
page_id = page.get("id")
|
||||
if page_id:
|
||||
yielded_ids.add(page_id)
|
||||
yield page
|
||||
|
||||
page_url = data.get("@odata.nextLink")
|
||||
|
||||
logger.debug(f"Yielded {total_yielded} site pages for {site_descriptor.url}")
|
||||
|
||||
def _fetch_site_pages_individually(
|
||||
self,
|
||||
site_pages_base: str,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
skip_ids: set[str] | None = None,
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
"""Fallback for _fetch_site_pages: list pages without $expand, then
|
||||
expand canvasLayout on each page individually.
|
||||
|
||||
The Graph API's LIST endpoint can return 400 when $expand=canvasLayout
|
||||
is used and *any* page in the site has a corrupt canvas layout (e.g.
|
||||
duplicate web part IDs — see SharePoint/sp-dev-docs#8822). Since the
|
||||
LIST expansion is all-or-nothing, a single bad page poisons the entire
|
||||
response. This method works around it by fetching metadata first, then
|
||||
expanding each page individually so only the broken page loses its
|
||||
canvas content.
|
||||
|
||||
``skip_ids`` contains page IDs already yielded by the caller before the
|
||||
fallback was triggered, preventing duplicates.
|
||||
"""
|
||||
page_url: str | None = site_pages_base
|
||||
total_yielded = 0
|
||||
_skip_ids = skip_ids or set()
|
||||
|
||||
while page_url:
|
||||
try:
|
||||
data = self._graph_api_get_json(page_url)
|
||||
except HTTPError as e:
|
||||
if e.response is not None and e.response.status_code == 404:
|
||||
break
|
||||
raise
|
||||
|
||||
for page in data.get("value", []):
|
||||
if not _site_page_in_time_window(page, start, end):
|
||||
continue
|
||||
|
||||
page_id = page.get("id")
|
||||
if page_id and page_id in _skip_ids:
|
||||
continue
|
||||
|
||||
if not page_id:
|
||||
total_yielded += 1
|
||||
yield page
|
||||
continue
|
||||
|
||||
expanded = self._try_expand_single_page(site_pages_base, page_id, page)
|
||||
total_yielded += 1
|
||||
yield expanded
|
||||
|
||||
page_url = data.get("@odata.nextLink")
|
||||
|
||||
logger.debug(
|
||||
f"Yielded {total_yielded} site pages (per-page expansion fallback)"
|
||||
)
|
||||
|
||||
def _try_expand_single_page(
|
||||
self,
|
||||
site_pages_base: str,
|
||||
page_id: str,
|
||||
fallback_page: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Try to GET a single page with $expand=canvasLayout. On 400, return
|
||||
the metadata-only fallback so the page is still indexed (without canvas
|
||||
content)."""
|
||||
pages_collection = site_pages_base.removesuffix("/microsoft.graph.sitePage")
|
||||
single_url = f"{pages_collection}/{page_id}/microsoft.graph.sitePage"
|
||||
try:
|
||||
return self._graph_api_get_json(single_url, {"$expand": "canvasLayout"})
|
||||
except HTTPError as e:
|
||||
if (
|
||||
e.response is not None
|
||||
and e.response.status_code == 400
|
||||
and _is_graph_invalid_request(e.response)
|
||||
):
|
||||
page_name = fallback_page.get("name", page_id)
|
||||
logger.warning(
|
||||
f"$expand=canvasLayout failed for page '{page_name}' "
|
||||
f"({page_id}). Indexing metadata only."
|
||||
)
|
||||
return fallback_page
|
||||
raise
|
||||
|
||||
def _acquire_token(self) -> dict[str, Any]:
|
||||
"""
|
||||
Acquire token via MSAL
|
||||
|
||||
@@ -304,13 +304,3 @@ class LLMModelFlowType(str, PyEnum):
|
||||
CHAT = "chat"
|
||||
VISION = "vision"
|
||||
CONTEXTUAL_RAG = "contextual_rag"
|
||||
|
||||
|
||||
class HookPoint(str, PyEnum):
|
||||
DOCUMENT_INGESTION = "document_ingestion"
|
||||
QUERY_PROCESSING = "query_processing"
|
||||
|
||||
|
||||
class HookFailStrategy(str, PyEnum):
|
||||
HARD = "hard" # exception propagates, pipeline aborts
|
||||
SOFT = "soft" # log error, return original input, pipeline continues
|
||||
|
||||
@@ -64,8 +64,6 @@ from onyx.db.enums import (
|
||||
BuildSessionStatus,
|
||||
EmbeddingPrecision,
|
||||
HierarchyNodeType,
|
||||
HookFailStrategy,
|
||||
HookPoint,
|
||||
IndexingMode,
|
||||
OpenSearchDocumentMigrationStatus,
|
||||
OpenSearchTenantMigrationStatus,
|
||||
@@ -355,11 +353,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
# organized in typical structured fashion
|
||||
# formatted as `displayName__provider__modelName`
|
||||
|
||||
# Voice preferences
|
||||
voice_auto_send: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
voice_auto_playback: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
voice_playback_speed: Mapped[float] = mapped_column(Float, default=1.0)
|
||||
|
||||
# relationships
|
||||
credentials: Mapped[list["Credential"]] = relationship(
|
||||
"Credential", back_populates="user"
|
||||
@@ -3072,65 +3065,6 @@ class ImageGenerationConfig(Base):
|
||||
)
|
||||
|
||||
|
||||
class VoiceProvider(Base):
|
||||
"""Configuration for voice services (STT and TTS)."""
|
||||
|
||||
__tablename__ = "voice_provider"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
provider_type: Mapped[str] = mapped_column(
|
||||
String
|
||||
) # "openai", "azure", "elevenlabs"
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
api_base: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
custom_config: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
|
||||
# Model/voice configuration
|
||||
stt_model: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
) # e.g., "whisper-1"
|
||||
tts_model: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
) # e.g., "tts-1", "tts-1-hd"
|
||||
default_voice: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
) # e.g., "alloy", "echo"
|
||||
|
||||
# STT and TTS can use different providers - only one provider per type
|
||||
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
# Enforce only one default STT provider and one default TTS provider at DB level
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_voice_provider_one_default_stt",
|
||||
"is_default_stt",
|
||||
unique=True,
|
||||
postgresql_where=(is_default_stt == True), # noqa: E712
|
||||
),
|
||||
Index(
|
||||
"ix_voice_provider_one_default_tts",
|
||||
"is_default_tts",
|
||||
unique=True,
|
||||
postgresql_where=(is_default_tts == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(Base):
|
||||
__tablename__ = "embedding_provider"
|
||||
|
||||
@@ -5174,94 +5108,3 @@ class CacheStore(Base):
|
||||
expires_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class Hook(Base):
    """Pairs a HookPoint with a customer-provided API endpoint.

    At most one Hook per HookPoint can be active at a time, enforced by a
    partial unique index on (hook_point) where is_active=true AND deleted=false.
    """

    __tablename__ = "hook"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Admin-facing display name for the hook.
    name: Mapped[str] = mapped_column(String, nullable=False)
    hook_point: Mapped[HookPoint] = mapped_column(
        Enum(HookPoint, native_enum=False), nullable=False
    )
    endpoint_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Encrypted at rest via EncryptedString.
    api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
        EncryptedString(), nullable=True
    )
    is_reachable: Mapped[bool | None] = mapped_column(
        Boolean, nullable=True, default=None
    )  # null = never validated, true = last check passed, false = last check failed
    fail_strategy: Mapped[HookFailStrategy] = mapped_column(
        Enum(HookFailStrategy, native_enum=False),
        nullable=False,
        default=HookFailStrategy.HARD,
        server_default=HookFailStrategy.HARD.value,
    )
    timeout_seconds: Mapped[float] = mapped_column(
        Float, nullable=False, default=30.0, server_default="30.0"
    )
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    # Soft-delete flag; deleted rows are excluded from the partial unique index below.
    deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    # Nulled (not cascaded) when the creating user is removed.
    creator_id: Mapped[UUID | None] = mapped_column(
        PGUUID(as_uuid=True),
        ForeignKey("user.id", ondelete="SET NULL"),
        nullable=True,
    )
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    updated_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    creator: Mapped["User | None"] = relationship("User", foreign_keys=[creator_id])
    # Logs are removed together with the hook (delete-orphan cascade).
    execution_logs: Mapped[list["HookExecutionLog"]] = relationship(
        "HookExecutionLog", back_populates="hook", cascade="all, delete-orphan"
    )

    __table_args__ = (
        # Enforce at most one active, non-deleted hook per hook point at the DB level.
        Index(
            "ix_hook_one_active_per_point",
            "hook_point",
            unique=True,
            postgresql_where=(is_active == True) & (deleted == False),  # noqa: E712
        ),
    )
|
||||
|
||||
|
||||
class HookExecutionLog(Base):
    """Records each failed hook execution for health monitoring and debugging.

    Only failures are logged. Retention: rows older than 30 days are deleted
    by a nightly Celery task.
    """

    __tablename__ = "hook_execution_log"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Rows are removed together with their parent hook (ON DELETE CASCADE).
    hook_id: Mapped[int] = mapped_column(
        Integer,
        ForeignKey("hook.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    hook_point: Mapped[HookPoint] = mapped_column(
        Enum(HookPoint, native_enum=False), nullable=False
    )  # denormalized for query convenience
    error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
    # NOTE(review): presumably the HTTP status returned by the endpoint, when
    # a response was received at all — confirm against the hook executor.
    status_code: Mapped[int | None] = mapped_column(Integer, nullable=True)
    duration_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Indexed so the age-based retention cleanup can scan efficiently.
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False, index=True
    )

    hook: Mapped["Hook"] = relationship("Hook", back_populates="execution_logs")
|
||||
|
||||
@@ -1,248 +0,0 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
# Bounds used to clamp a user's requested voice playback speed
# (see update_user_voice_settings).
MIN_VOICE_PLAYBACK_SPEED = 0.5
MAX_VOICE_PLAYBACK_SPEED = 2.0
|
||||
|
||||
|
||||
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
    """Return every non-deleted voice provider, ordered by name."""
    stmt = (
        select(VoiceProvider)
        .where(VoiceProvider.deleted.is_(False))
        .order_by(VoiceProvider.name)
    )
    return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def fetch_voice_provider_by_id(
    db_session: Session, provider_id: int, include_deleted: bool = False
) -> VoiceProvider | None:
    """Look up a single voice provider by primary key.

    Soft-deleted rows are excluded unless include_deleted is True.
    """
    filters = [VoiceProvider.id == provider_id]
    if not include_deleted:
        filters.append(VoiceProvider.deleted.is_(False))
    return db_session.scalar(select(VoiceProvider).where(*filters))
|
||||
|
||||
|
||||
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
    """Return the provider currently flagged as the default for STT, if any."""
    stmt = select(VoiceProvider).where(
        VoiceProvider.is_default_stt.is_(True),
        VoiceProvider.deleted.is_(False),
    )
    return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
    """Return the provider currently flagged as the default for TTS, if any."""
    stmt = select(VoiceProvider).where(
        VoiceProvider.is_default_tts.is_(True),
        VoiceProvider.deleted.is_(False),
    )
    return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_type(
    db_session: Session, provider_type: str
) -> VoiceProvider | None:
    """Return a non-deleted provider matching the given provider type, if any."""
    stmt = select(VoiceProvider).where(
        VoiceProvider.provider_type == provider_type,
        VoiceProvider.deleted.is_(False),
    )
    return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def upsert_voice_provider(
    *,
    db_session: Session,
    provider_id: int | None,
    name: str,
    provider_type: str,
    api_key: str | None,
    api_key_changed: bool,
    api_base: str | None = None,
    custom_config: dict[str, Any] | None = None,
    stt_model: str | None = None,
    tts_model: str | None = None,
    default_voice: str | None = None,
    activate_stt: bool = False,
    activate_tts: bool = False,
) -> VoiceProvider:
    """Create or update a voice provider.

    When provider_id is None a new row is created; otherwise the existing
    non-deleted row is updated in place. Does not commit — callers own the
    transaction.

    Raises:
        OnyxError: NOT_FOUND when provider_id is given but no such provider exists.
    """
    provider: VoiceProvider | None = None

    if provider_id is not None:
        provider = fetch_voice_provider_by_id(db_session, provider_id)
        if provider is None:
            raise OnyxError(
                OnyxErrorCode.NOT_FOUND,
                f"No voice provider with id {provider_id} exists.",
            )
    else:
        provider = VoiceProvider()
        db_session.add(provider)

    # Apply updates
    provider.name = name
    provider.provider_type = provider_type
    provider.api_base = api_base
    provider.custom_config = custom_config
    provider.stt_model = stt_model
    provider.tts_model = tts_model
    provider.default_voice = default_voice

    # Only update API key if explicitly changed or if provider has no key
    if api_key_changed or provider.api_key is None:
        provider.api_key = api_key  # type: ignore[assignment]

    # Flush first so a freshly created provider has an id for the activations below.
    db_session.flush()

    if activate_stt:
        set_default_stt_provider(db_session=db_session, provider_id=provider.id)
    if activate_tts:
        set_default_tts_provider(db_session=db_session, provider_id=provider.id)

    db_session.refresh(provider)
    return provider
|
||||
|
||||
|
||||
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
    """Soft-delete a voice provider; silently does nothing for unknown ids."""
    provider = fetch_voice_provider_by_id(db_session, provider_id)
    if provider is None:
        return
    provider.deleted = True
    db_session.flush()
|
||||
|
||||
|
||||
def set_default_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
    """Make the given provider the default STT provider.

    Clears the default-STT flag on every other provider first, then flags
    this one, so the partial unique index is never violated.

    Raises:
        OnyxError: NOT_FOUND when the provider does not exist.
    """
    provider = fetch_voice_provider_by_id(db_session, provider_id)
    if provider is None:
        raise OnyxError(
            OnyxErrorCode.NOT_FOUND,
            f"No voice provider with id {provider_id} exists.",
        )

    clear_others = (
        update(VoiceProvider)
        .where(
            VoiceProvider.is_default_stt.is_(True),
            VoiceProvider.id != provider_id,
        )
        .values(is_default_stt=False)
    )
    db_session.execute(clear_others)

    provider.is_default_stt = True

    db_session.flush()
    db_session.refresh(provider)
    return provider
|
||||
|
||||
|
||||
def set_default_tts_provider(
    *, db_session: Session, provider_id: int, tts_model: str | None = None
) -> VoiceProvider:
    """Make the given provider the default TTS provider.

    Clears the default-TTS flag on every other provider, then flags this one.
    Optionally updates the provider's TTS model in the same operation.

    Raises:
        OnyxError: NOT_FOUND when the provider does not exist.
    """
    provider = fetch_voice_provider_by_id(db_session, provider_id)
    if provider is None:
        raise OnyxError(
            OnyxErrorCode.NOT_FOUND,
            f"No voice provider with id {provider_id} exists.",
        )

    clear_others = (
        update(VoiceProvider)
        .where(
            VoiceProvider.is_default_tts.is_(True),
            VoiceProvider.id != provider_id,
        )
        .values(is_default_tts=False)
    )
    db_session.execute(clear_others)

    provider.is_default_tts = True
    if tts_model is not None:
        provider.tts_model = tts_model

    db_session.flush()
    db_session.refresh(provider)
    return provider
|
||||
|
||||
|
||||
def deactivate_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
    """Clear the default-STT flag on a voice provider.

    Raises:
        OnyxError: NOT_FOUND when the provider does not exist.
    """
    provider = fetch_voice_provider_by_id(db_session, provider_id)
    if provider is None:
        raise OnyxError(
            OnyxErrorCode.NOT_FOUND,
            f"No voice provider with id {provider_id} exists.",
        )

    provider.is_default_stt = False
    db_session.flush()
    db_session.refresh(provider)
    return provider
|
||||
|
||||
|
||||
def deactivate_tts_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
    """Clear the default-TTS flag on a voice provider.

    Raises:
        OnyxError: NOT_FOUND when the provider does not exist.
    """
    provider = fetch_voice_provider_by_id(db_session, provider_id)
    if provider is None:
        raise OnyxError(
            OnyxErrorCode.NOT_FOUND,
            f"No voice provider with id {provider_id} exists.",
        )

    provider.is_default_tts = False
    db_session.flush()
    db_session.refresh(provider)
    return provider
|
||||
|
||||
|
||||
# User voice preferences
|
||||
|
||||
|
||||
def update_user_voice_settings(
    db_session: Session,
    user_id: UUID,
    auto_send: bool | None = None,
    auto_playback: bool | None = None,
    playback_speed: float | None = None,
) -> None:
    """Update a user's voice preferences.

    A None argument means "leave this field unchanged". The playback speed
    is clamped into [MIN_VOICE_PLAYBACK_SPEED, MAX_VOICE_PLAYBACK_SPEED].
    """
    updates: dict[str, bool | float] = {}

    if auto_send is not None:
        updates["voice_auto_send"] = auto_send
    if auto_playback is not None:
        updates["voice_auto_playback"] = auto_playback
    if playback_speed is not None:
        clamped = min(
            MAX_VOICE_PLAYBACK_SPEED, max(MIN_VOICE_PLAYBACK_SPEED, playback_speed)
        )
        updates["voice_playback_speed"] = clamped

    if not updates:
        return

    db_session.execute(update(User).where(User.id == user_id).values(**updates))  # type: ignore[arg-type]
    db_session.flush()
|
||||
@@ -66,11 +66,6 @@ class OnyxErrorCode(Enum):
|
||||
RATE_LIMITED = ("RATE_LIMITED", 429)
|
||||
SEAT_LIMIT_EXCEEDED = ("SEAT_LIMIT_EXCEEDED", 402)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Payload (413)
|
||||
# ------------------------------------------------------------------
|
||||
PAYLOAD_TOO_LARGE = ("PAYLOAD_TOO_LARGE", 413)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connector / Credential Errors (400-range)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -120,9 +120,6 @@ from onyx.server.manage.opensearch_migration.api import (
|
||||
from onyx.server.manage.search_settings import router as search_settings_router
|
||||
from onyx.server.manage.slack_bot import router as slack_bot_management_router
|
||||
from onyx.server.manage.users import router as user_router
|
||||
from onyx.server.manage.voice.api import admin_router as voice_admin_router
|
||||
from onyx.server.manage.voice.user_api import router as voice_router
|
||||
from onyx.server.manage.voice.websocket_api import router as voice_websocket_router
|
||||
from onyx.server.manage.web_search.api import (
|
||||
admin_router as web_search_admin_router,
|
||||
)
|
||||
@@ -501,9 +498,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, embedding_router)
|
||||
include_router_with_global_prefix_prepended(application, web_search_router)
|
||||
include_router_with_global_prefix_prepended(application, web_search_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_websocket_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, opensearch_migration_admin_router
|
||||
)
|
||||
|
||||
@@ -419,15 +419,12 @@ async def get_async_redis_connection() -> aioredis.Redis:
|
||||
return _async_redis_connection
|
||||
|
||||
|
||||
async def retrieve_auth_token_data(token: str) -> dict | None:
|
||||
"""Validate auth token against Redis and return token data.
|
||||
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
|
||||
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
if not token:
|
||||
logger.debug("No auth token cookie found")
|
||||
return None
|
||||
|
||||
Args:
|
||||
token: The raw authentication token string.
|
||||
|
||||
Returns:
|
||||
Token data dict if valid, None if invalid/expired.
|
||||
"""
|
||||
try:
|
||||
redis = await get_async_redis_connection()
|
||||
redis_key = REDIS_AUTH_KEY_PREFIX + token
|
||||
@@ -442,96 +439,12 @@ async def retrieve_auth_token_data(token: str) -> dict | None:
|
||||
logger.error("Error decoding token data from Redis")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
|
||||
raise ValueError(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
|
||||
|
||||
|
||||
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
    """Validate the auth cookie on a request. Wrapper kept for backwards compatibility."""
    token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
    if token:
        return await retrieve_auth_token_data(token)
    logger.debug("No auth token cookie found")
    return None
|
||||
|
||||
|
||||
# WebSocket token prefix (separate from regular auth tokens)
REDIS_WS_TOKEN_PREFIX = "ws_token:"
# WebSocket tokens expire after 60 seconds
WS_TOKEN_TTL_SECONDS = 60
# Rate limit: max tokens per user per window
WS_TOKEN_RATE_LIMIT_MAX = 10
# Length of the rate-limit window, in seconds.
WS_TOKEN_RATE_LIMIT_WINDOW_SECONDS = 60
# Redis key prefix for per-user token-generation counters.
REDIS_WS_TOKEN_RATE_LIMIT_PREFIX = "ws_token_rate:"


class WsTokenRateLimitExceeded(Exception):
    """Raised when a user exceeds the WS token generation rate limit."""
|
||||
|
||||
async def store_ws_token(token: str, user_id: str) -> None:
|
||||
"""Store a short-lived WebSocket authentication token in Redis.
|
||||
|
||||
Args:
|
||||
token: The generated WS token.
|
||||
user_id: The user ID to associate with this token.
|
||||
|
||||
Raises:
|
||||
WsTokenRateLimitExceeded: If the user has exceeded the rate limit.
|
||||
"""
|
||||
redis = await get_async_redis_connection()
|
||||
|
||||
# Atomically increment and check rate limit to avoid TOCTOU races
|
||||
rate_limit_key = REDIS_WS_TOKEN_RATE_LIMIT_PREFIX + user_id
|
||||
pipe = redis.pipeline()
|
||||
pipe.incr(rate_limit_key)
|
||||
pipe.expire(rate_limit_key, WS_TOKEN_RATE_LIMIT_WINDOW_SECONDS)
|
||||
results = await pipe.execute()
|
||||
new_count = results[0]
|
||||
|
||||
if new_count > WS_TOKEN_RATE_LIMIT_MAX:
|
||||
# Over limit — decrement back since we won't use this slot
|
||||
await redis.decr(rate_limit_key)
|
||||
logger.warning(f"WS token rate limit exceeded for user {user_id}")
|
||||
raise WsTokenRateLimitExceeded(
|
||||
f"Rate limit exceeded. Maximum {WS_TOKEN_RATE_LIMIT_MAX} tokens per minute."
|
||||
logger.error(
|
||||
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
|
||||
)
|
||||
|
||||
# Store the actual token
|
||||
redis_key = REDIS_WS_TOKEN_PREFIX + token
|
||||
token_data = json.dumps({"sub": user_id})
|
||||
await redis.set(redis_key, token_data, ex=WS_TOKEN_TTL_SECONDS)
|
||||
|
||||
|
||||
async def retrieve_ws_token_data(token: str) -> dict | None:
    """Validate a WebSocket token and return the token data.

    The token is consumed atomically via GETDEL (Redis 6.2+) so it can be
    used at most once, preventing replay races.

    Args:
        token: The WS token to validate.

    Returns:
        Token data dict with 'sub' (user ID) if valid, None if invalid/expired.
    """
    try:
        redis = await get_async_redis_connection()
        raw = await redis.getdel(REDIS_WS_TOKEN_PREFIX + token)
        if not raw:
            return None
        return json.loads(raw)
    except json.JSONDecodeError:
        logger.error("Error decoding WS token data from Redis")
        return None
    except Exception as e:
        logger.error(f"Unexpected error in retrieve_ws_token_data: {str(e)}")
        return None
|
||||
|
||||
|
||||
def redis_lock_dump(lock: RedisLock, r: Redis) -> None:
|
||||
|
||||
@@ -9,7 +9,6 @@ from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import current_user_from_websocket
|
||||
from onyx.auth.users import current_user_with_expired_token
|
||||
from onyx.configs.app_configs import APP_API_PREFIX
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
@@ -130,7 +129,6 @@ def check_router_auth(
|
||||
or depends_fn == current_curator_or_admin_user
|
||||
or depends_fn == current_user_with_expired_token
|
||||
or depends_fn == current_chat_accessible_user
|
||||
or depends_fn == current_user_from_websocket
|
||||
or depends_fn == control_plane_dep
|
||||
or depends_fn == current_cloud_superuser
|
||||
or depends_fn == verify_scim_token
|
||||
|
||||
@@ -85,11 +85,6 @@ class UserPreferences(BaseModel):
|
||||
chat_background: str | None = None
|
||||
default_app_mode: DefaultAppMode = DefaultAppMode.CHAT
|
||||
|
||||
# Voice preferences
|
||||
voice_auto_send: bool | None = None
|
||||
voice_auto_playback: bool | None = None
|
||||
voice_playback_speed: float | None = None
|
||||
|
||||
# controls which tools are enabled for the user for a specific assistant
|
||||
assistant_specific_configs: UserSpecificAssistantPreferences | None = None
|
||||
|
||||
@@ -169,9 +164,6 @@ class UserInfo(BaseModel):
|
||||
theme_preference=user.theme_preference,
|
||||
chat_background=user.chat_background,
|
||||
default_app_mode=user.default_app_mode,
|
||||
voice_auto_send=user.voice_auto_send,
|
||||
voice_auto_playback=user.voice_auto_playback,
|
||||
voice_playback_speed=user.voice_playback_speed,
|
||||
assistant_specific_configs=assistant_specific_configs,
|
||||
)
|
||||
),
|
||||
@@ -248,12 +240,6 @@ class ChatBackgroundRequest(BaseModel):
|
||||
chat_background: str | None
|
||||
|
||||
|
||||
class VoiceSettingsUpdateRequest(BaseModel):
    """Request body for updating a user's voice preferences.

    None for any field means "leave this setting unchanged".
    """

    auto_send: bool | None = None
    auto_playback: bool | None = None
    # Validated to the [0.5, 2.0] range at the schema level.
    playback_speed: float | None = Field(default=None, ge=0.5, le=2.0)
|
||||
|
||||
|
||||
class PersonalizationUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
role: str | None = None
|
||||
|
||||
@@ -1,322 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.db.voice import deactivate_stt_provider
|
||||
from onyx.db.voice import deactivate_tts_provider
|
||||
from onyx.db.voice import delete_voice_provider
|
||||
from onyx.db.voice import fetch_voice_provider_by_id
|
||||
from onyx.db.voice import fetch_voice_provider_by_type
|
||||
from onyx.db.voice import fetch_voice_providers
|
||||
from onyx.db.voice import set_default_stt_provider
|
||||
from onyx.db.voice import set_default_tts_provider
|
||||
from onyx.db.voice import upsert_voice_provider
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.voice.models import VoiceOption
|
||||
from onyx.server.manage.voice.models import VoiceProviderTestRequest
|
||||
from onyx.server.manage.voice.models import VoiceProviderUpdateSuccess
|
||||
from onyx.server.manage.voice.models import VoiceProviderUpsertRequest
|
||||
from onyx.server.manage.voice.models import VoiceProviderView
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.url import SSRFException
|
||||
from onyx.utils.url import validate_outbound_http_url
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
|
||||
logger = setup_logger()

# All routes below require admin privileges (see current_admin_user deps).
admin_router = APIRouter(prefix="/admin/voice")

# Generic user-facing message for failed credential checks; the underlying
# exception is logged but intentionally not surfaced to the client.
VOICE_PROVIDER_VALIDATION_FAILURE_MESSAGE = (
    "Connection test failed. Please verify your API key and settings."
)
|
||||
|
||||
|
||||
def _validate_voice_api_base(provider_type: str, api_base: str | None) -> str | None:
    """Validate and normalize provider api_base / target URI.

    Azure endpoints may resolve to private network addresses; all other
    provider types are restricted to public destinations (SSRF guard).

    Raises:
        OnyxError: VALIDATION_ERROR when the URL fails validation.
    """
    if api_base is None:
        return None

    try:
        return validate_outbound_http_url(
            api_base,
            allow_private_network=(provider_type.lower() == "azure"),
        )
    except (ValueError, SSRFException) as e:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            f"Invalid target URI: {str(e)}",
        ) from e
|
||||
|
||||
|
||||
def _provider_to_view(provider: VoiceProvider) -> VoiceProviderView:
    """Build the API-facing view of a stored VoiceProvider row.

    Never exposes the API key itself — only whether one is configured.
    """
    return VoiceProviderView(
        id=provider.id,
        name=provider.name,
        provider_type=provider.provider_type,
        stt_model=provider.stt_model,
        tts_model=provider.tts_model,
        default_voice=provider.default_voice,
        is_default_stt=provider.is_default_stt,
        is_default_tts=provider.is_default_tts,
        has_api_key=bool(provider.api_key),
        target_uri=provider.api_base,  # api_base stores the target URI for Azure
    )
|
||||
|
||||
|
||||
@admin_router.get("/providers")
|
||||
def list_voice_providers(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[VoiceProviderView]:
|
||||
"""List all configured voice providers."""
|
||||
providers = fetch_voice_providers(db_session)
|
||||
return [_provider_to_view(provider) for provider in providers]
|
||||
|
||||
|
||||
@admin_router.post("/providers")
|
||||
async def upsert_voice_provider_endpoint(
|
||||
request: VoiceProviderUpsertRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderView:
|
||||
"""Create or update a voice provider."""
|
||||
api_key = request.api_key
|
||||
api_key_changed = request.api_key_changed
|
||||
|
||||
# If llm_provider_id is specified, copy the API key from that LLM provider
|
||||
if request.llm_provider_id is not None:
|
||||
llm_provider = db_session.get(LLMProviderModel, request.llm_provider_id)
|
||||
if llm_provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"LLM provider with id {request.llm_provider_id} not found.",
|
||||
)
|
||||
if llm_provider.api_key is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Selected LLM provider has no API key configured.",
|
||||
)
|
||||
api_key = llm_provider.api_key.get_value(apply_mask=False)
|
||||
api_key_changed = True
|
||||
|
||||
# Use target_uri if provided, otherwise fall back to api_base
|
||||
api_base = _validate_voice_api_base(
|
||||
request.provider_type, request.target_uri or request.api_base
|
||||
)
|
||||
|
||||
provider = upsert_voice_provider(
|
||||
db_session=db_session,
|
||||
provider_id=request.id,
|
||||
name=request.name,
|
||||
provider_type=request.provider_type,
|
||||
api_key=api_key,
|
||||
api_key_changed=api_key_changed,
|
||||
api_base=api_base,
|
||||
custom_config=request.custom_config,
|
||||
stt_model=request.stt_model,
|
||||
tts_model=request.tts_model,
|
||||
default_voice=request.default_voice,
|
||||
activate_stt=request.activate_stt,
|
||||
activate_tts=request.activate_tts,
|
||||
)
|
||||
|
||||
# Validate credentials before committing - rollback on failure
|
||||
try:
|
||||
voice_provider = get_voice_provider(provider)
|
||||
await voice_provider.validate_credentials()
|
||||
except OnyxError:
|
||||
db_session.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.error(f"Voice provider credential validation failed on save: {e}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
VOICE_PROVIDER_VALIDATION_FAILURE_MESSAGE,
|
||||
) from e
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return _provider_to_view(provider)
|
||||
|
||||
|
||||
@admin_router.delete(
    "/providers/{provider_id}", status_code=204, response_class=Response
)
def delete_voice_provider_endpoint(
    provider_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> Response:
    """Soft-delete a voice provider and commit the change."""
    delete_voice_provider(db_session, provider_id)
    db_session.commit()
    return Response(status_code=204)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/activate-stt")
def activate_stt_provider_endpoint(
    provider_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> VoiceProviderView:
    """Mark the given provider as the default STT provider."""
    updated = set_default_stt_provider(
        db_session=db_session, provider_id=provider_id
    )
    db_session.commit()
    return _provider_to_view(updated)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/deactivate-stt")
def deactivate_stt_provider_endpoint(
    provider_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> VoiceProviderUpdateSuccess:
    """Clear the default-STT status from the given provider."""
    deactivate_stt_provider(db_session=db_session, provider_id=provider_id)
    db_session.commit()
    return VoiceProviderUpdateSuccess()
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/activate-tts")
def activate_tts_provider_endpoint(
    provider_id: int,
    tts_model: str | None = None,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> VoiceProviderView:
    """Mark the given provider as the default TTS provider.

    Optionally switches its TTS model in the same request.
    """
    updated = set_default_tts_provider(
        db_session=db_session, provider_id=provider_id, tts_model=tts_model
    )
    db_session.commit()
    return _provider_to_view(updated)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/deactivate-tts")
def deactivate_tts_provider_endpoint(
    provider_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> VoiceProviderUpdateSuccess:
    """Clear the default-TTS status from the given provider."""
    deactivate_tts_provider(db_session=db_session, provider_id=provider_id)
    db_session.commit()
    return VoiceProviderUpdateSuccess()
|
||||
|
||||
|
||||
@admin_router.post("/providers/test")
|
||||
async def test_voice_provider(
|
||||
request: VoiceProviderTestRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderUpdateSuccess:
|
||||
"""Test a voice provider connection by making a real API call."""
|
||||
api_key = request.api_key
|
||||
|
||||
if request.use_stored_key:
|
||||
existing_provider = fetch_voice_provider_by_type(
|
||||
db_session, request.provider_type
|
||||
)
|
||||
if existing_provider is None or not existing_provider.api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No stored API key found for this provider type.",
|
||||
)
|
||||
api_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
|
||||
if not api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API key is required. Either provide api_key or set use_stored_key to true.",
|
||||
)
|
||||
|
||||
# Use target_uri if provided, otherwise fall back to api_base
|
||||
api_base = _validate_voice_api_base(
|
||||
request.provider_type, request.target_uri or request.api_base
|
||||
)
|
||||
|
||||
# Create a temporary VoiceProvider for testing (not saved to DB)
|
||||
temp_provider = VoiceProvider(
|
||||
name="__test__",
|
||||
provider_type=request.provider_type,
|
||||
api_base=api_base,
|
||||
custom_config=request.custom_config or {},
|
||||
)
|
||||
temp_provider.api_key = api_key # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(temp_provider)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
|
||||
|
||||
# Validate credentials with a real API call
|
||||
try:
|
||||
await provider.validate_credentials()
|
||||
except OnyxError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Voice provider connection test failed: {e}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
VOICE_PROVIDER_VALIDATION_FAILURE_MESSAGE,
|
||||
) from e
|
||||
|
||||
logger.info(f"Voice provider test succeeded for {request.provider_type}.")
|
||||
return VoiceProviderUpdateSuccess()
|
||||
|
||||
|
||||
@admin_router.get("/providers/{provider_id}/voices")
def get_provider_voices(
    provider_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[VoiceOption]:
    """Return the voices available from a configured provider.

    Raises:
        OnyxError: NOT_FOUND for an unknown provider; VALIDATION_ERROR when
            the provider has no key or its type is unsupported.
    """
    record = fetch_voice_provider_by_id(db_session, provider_id)
    if record is None:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "Voice provider not found.")

    if not record.api_key:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR, "Provider has no API key configured."
        )

    try:
        client = get_voice_provider(record)
    except ValueError as exc:
        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc

    return [VoiceOption(**voice) for voice in client.get_available_voices()]
|
||||
|
||||
|
||||
@admin_router.get("/voices")
def get_voices_by_type(
    provider_type: str,
    _: User = Depends(current_admin_user),
) -> list[VoiceOption]:
    """Return the static voice list for a provider type.

    For providers like ElevenLabs and OpenAI the voice catalogue is fixed,
    so no saved provider configuration (or API key) is needed.
    """
    # Transient, never-persisted provider object used only to resolve the
    # factory implementation for this provider type.
    placeholder = VoiceProvider(
        name="__temp__",
        provider_type=provider_type,
    )

    try:
        voice_client = get_voice_provider(placeholder)
    except ValueError as exc:
        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc

    return [VoiceOption(**entry) for entry in voice_client.get_available_voices()]
|
||||
@@ -1,95 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class VoiceProviderView(BaseModel):
    """API view of a configured voice provider (listing endpoints)."""

    # Identity
    id: int
    name: str
    provider_type: str  # one of "openai", "azure", "elevenlabs"

    # Default-provider flags for each direction (speech-to-text / text-to-speech)
    is_default_stt: bool
    is_default_tts: bool

    # Model / voice selection (None when unset)
    stt_model: str | None
    tts_model: str | None
    default_voice: str | None

    # The key itself is never returned; only whether one exists.
    has_api_key: bool = Field(
        default=False,
        description="Indicates whether an API key is stored for this provider.",
    )
    target_uri: str | None = Field(
        default=None,
        description="Target URI for Azure Speech Services.",
    )
|
||||
|
||||
|
||||
class VoiceProviderUpdateSuccess(BaseModel):
    """Minimal acknowledgement payload for voice-provider mutations."""

    # Always "ok"; present so the response body is non-empty JSON.
    status: str = "ok"
|
||||
|
||||
|
||||
class VoiceOption(BaseModel):
    """A single selectable voice as reported by a voice provider."""

    # Provider-specific voice identifier
    id: str
    # Human-readable display name
    name: str
|
||||
|
||||
|
||||
class VoiceProviderUpsertRequest(BaseModel):
    """Payload for creating a new voice provider or updating an existing one."""

    # When set, the request updates this provider instead of creating one.
    id: int | None = Field(default=None, description="Existing provider ID to update.")
    name: str
    provider_type: str  # one of "openai", "azure", "elevenlabs"

    # Credential sources — either a raw key, or a copy from an LLM provider.
    api_key: str | None = Field(
        default=None,
        description="API key for the provider.",
    )
    api_key_changed: bool = Field(
        default=False,
        description="Set to true when providing a new API key for an existing provider.",
    )
    llm_provider_id: int | None = Field(
        default=None,
        description="If set, copies the API key from the specified LLM provider.",
    )

    # Endpoint configuration
    api_base: str | None = None
    target_uri: str | None = Field(
        default=None,
        description="Target URI for Azure Speech Services (maps to api_base).",
    )
    custom_config: dict[str, Any] | None = None

    # Model / voice selection
    stt_model: str | None = None
    tts_model: str | None = None
    default_voice: str | None = None

    # Optional post-upsert activation as the tenant default
    activate_stt: bool = Field(
        default=False,
        description="If true, sets this provider as the default STT provider after upsert.",
    )
    activate_tts: bool = Field(
        default=False,
        description="If true, sets this provider as the default TTS provider after upsert.",
    )
|
||||
|
||||
|
||||
class VoiceProviderTestRequest(BaseModel):
    """Payload for a connection test against a voice provider."""

    provider_type: str

    # Exactly one credential source: an inline key, or the key already
    # stored for this provider type (use_stored_key=True).
    api_key: str | None = Field(
        default=None,
        description="API key for testing. If not provided, use_stored_key must be true.",
    )
    use_stored_key: bool = Field(
        default=False,
        description="If true, use the stored API key for this provider type.",
    )

    # Endpoint configuration
    api_base: str | None = None
    target_uri: str | None = Field(
        default=None,
        description="Target URI for Azure Speech Services (maps to api_base).",
    )
    custom_config: dict[str, Any] | None = None
|
||||
@@ -1,251 +0,0 @@
|
||||
import secrets
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import User
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.db.voice import update_user_voice_settings
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.redis.redis_pool import store_ws_token
|
||||
from onyx.redis.redis_pool import WsTokenRateLimitExceeded
|
||||
from onyx.server.manage.models import VoiceSettingsUpdateRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/voice")
|
||||
|
||||
# Max audio file size: 25MB (Whisper limit)
|
||||
MAX_AUDIO_SIZE = 25 * 1024 * 1024
|
||||
# Chunk size for streaming uploads (8KB)
|
||||
UPLOAD_READ_CHUNK_SIZE = 8192
|
||||
|
||||
|
||||
class VoiceStatusResponse(BaseModel):
    """Readiness flags for the two voice directions."""

    # True when a default provider with an API key exists for the direction.
    stt_enabled: bool
    tts_enabled: bool
|
||||
|
||||
|
||||
@router.get("/status")
def get_voice_status(
    _: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> VoiceStatusResponse:
    """Report whether default STT and TTS providers are configured and usable."""

    def _ready(provider: object) -> bool:
        # A direction is usable only if a default provider exists AND it
        # has an API key stored.
        return provider is not None and provider.api_key is not None

    return VoiceStatusResponse(
        stt_enabled=_ready(fetch_default_stt_provider(db_session)),
        tts_enabled=_ready(fetch_default_tts_provider(db_session)),
    )
|
||||
|
||||
|
||||
@router.post("/transcribe")
async def transcribe_audio(
    audio: UploadFile = File(...),
    _: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> dict[str, str]:
    """Transcribe an uploaded audio file with the default STT provider.

    Returns ``{"text": <transcript>}``. Raises OnyxError when no provider
    is configured, the upload exceeds the size cap, or transcription fails.
    """
    stt = fetch_default_stt_provider(db_session)
    if stt is None:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            "No speech-to-text provider configured. Please contact your administrator.",
        )

    if not stt.api_key:
        raise OnyxError(
            OnyxErrorCode.VALIDATION_ERROR,
            "Voice provider API key not configured.",
        )

    # Stream the upload in small reads and enforce the cap as we go, so an
    # oversized body never gets fully buffered in memory.
    buf = bytearray()
    while chunk := await audio.read(UPLOAD_READ_CHUNK_SIZE):
        if len(buf) + len(chunk) > MAX_AUDIO_SIZE:
            raise OnyxError(
                OnyxErrorCode.PAYLOAD_TOO_LARGE,
                f"Audio file too large. Maximum size is {MAX_AUDIO_SIZE // (1024 * 1024)}MB.",
            )
        buf.extend(chunk)
    audio_data = bytes(buf)

    # Infer the container format from the filename extension; default to webm.
    filename = audio.filename or "audio.webm"
    audio_format = filename.rsplit(".", 1)[-1] if "." in filename else "webm"

    try:
        provider = get_voice_provider(stt)
    except ValueError as exc:
        raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc

    try:
        transcript = await provider.transcribe(audio_data, audio_format)
        return {"text": transcript}
    except NotImplementedError as exc:
        raise OnyxError(
            OnyxErrorCode.NOT_IMPLEMENTED,
            f"Speech-to-text not implemented for {stt.provider_type}.",
        ) from exc
    except Exception as exc:
        # Log the real cause but return a generic message to the client.
        logger.error(f"Transcription failed: {exc}")
        raise OnyxError(
            OnyxErrorCode.INTERNAL_ERROR,
            "Transcription failed. Please try again.",
        ) from exc
|
||||
|
||||
|
||||
@router.post("/synthesize")
async def synthesize_speech(
    text: str | None = Query(
        default=None, description="Text to synthesize", max_length=4096
    ),
    voice: str | None = Query(default=None, description="Voice ID to use"),
    speed: float | None = Query(
        default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
    ),
    user: User = Depends(current_user),
) -> StreamingResponse:
    """
    Synthesize text to speech using the default TTS provider.

    Parameters arrive via the query string so the MP3 response can be
    streamed directly.
    """
    logger.info(
        f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
    )

    if not text:
        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Text is required")

    # Fetch provider config in a short-lived session and release the DB
    # connection before the (potentially long) streaming response begins.
    with get_session_with_current_tenant() as db_session:
        tts = fetch_default_tts_provider(db_session)
        if tts is None:
            logger.error("No TTS provider configured")
            raise OnyxError(
                OnyxErrorCode.VALIDATION_ERROR,
                "No text-to-speech provider configured. Please contact your administrator.",
            )

        if not tts.api_key:
            logger.error("TTS provider has no API key")
            raise OnyxError(
                OnyxErrorCode.VALIDATION_ERROR,
                "Voice provider API key not configured.",
            )

        # Request value wins; otherwise fall back to the provider default.
        final_voice = voice or tts.default_voice

        # Explicit None checks: `or` would wrongly skip a legitimate 0.0.
        if speed is not None:
            final_speed = speed
        elif user.voice_playback_speed is not None:
            final_speed = user.voice_playback_speed
        else:
            final_speed = 1.0

        logger.info(
            f"TTS using provider: {tts.provider_type}, voice: {final_voice}, speed: {final_speed}"
        )

        try:
            provider = get_voice_provider(tts)
        except ValueError as exc:
            logger.error(f"Failed to get voice provider: {exc}")
            raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc

    # Session is closed here — the streaming generator holds no DB connection.
    async def audio_stream() -> AsyncIterator[bytes]:
        try:
            sent = 0
            async for piece in provider.synthesize_stream(
                text=text, voice=final_voice, speed=final_speed
            ):
                sent += 1
                yield piece
            logger.info(f"TTS streaming complete: {sent} chunks sent")
        except NotImplementedError as exc:
            logger.error(f"TTS not implemented: {exc}")
            raise
        except Exception as exc:
            logger.error(f"Synthesis failed: {exc}")
            raise

    return StreamingResponse(
        audio_stream(),
        media_type="audio/mpeg",
        headers={
            "Content-Disposition": "inline; filename=speech.mp3",
            # No content-length header: the body is streamed.
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
        },
    )
|
||||
|
||||
|
||||
@router.patch("/settings")
def update_voice_settings(
    request: VoiceSettingsUpdateRequest,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> dict[str, str]:
    """Persist the calling user's voice preferences (auto-send,
    auto-playback, playback speed)."""
    update_user_voice_settings(
        db_session=db_session,
        user_id=user.id,
        auto_send=request.auto_send,
        auto_playback=request.auto_playback,
        playback_speed=request.playback_speed,
    )
    # The helper does not commit; finalize the transaction here.
    db_session.commit()
    return {"status": "ok"}
|
||||
|
||||
|
||||
class WSTokenResponse(BaseModel):
    """Short-lived token handed out for WebSocket authentication."""

    token: str
|
||||
|
||||
|
||||
@router.post("/ws-token")
async def get_ws_token(
    user: User = Depends(current_user),
) -> WSTokenResponse:
    """
    Generate a short-lived token for WebSocket authentication.

    This token should be passed as a query parameter when connecting
    to voice WebSocket endpoints (e.g., /voice/transcribe/stream?token=xxx).

    The token expires after 60 seconds and is single-use.
    Rate limited to 10 tokens per minute per user.

    Raises:
        OnyxError(RATE_LIMITED): when the per-user token rate limit is hit.
    """
    # 32 random bytes -> ~43-char URL-safe token; cryptographically secure.
    token = secrets.token_urlsafe(32)
    try:
        await store_ws_token(token, str(user.id))
    except WsTokenRateLimitExceeded as exc:
        # Chain explicitly (`from exc`), matching every other handler in
        # this module, so the rate-limit cause is preserved in tracebacks.
        raise OnyxError(
            OnyxErrorCode.RATE_LIMITED,
            "Too many token requests. Please wait before requesting another.",
        ) from exc
    return WSTokenResponse(token=token)
|
||||
@@ -1,860 +0,0 @@
|
||||
"""WebSocket API for streaming speech-to-text and text-to-speech."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import WebSocket
|
||||
from fastapi import WebSocketDisconnect
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user_from_websocket
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
from onyx.db.models import User
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/voice")
|
||||
|
||||
|
||||
# Transcribe every ~0.5 seconds of audio (webm/opus is ~2-4KB/s, so ~1-2KB per 0.5s)
|
||||
MIN_CHUNK_BYTES = 1500
|
||||
VOICE_DISABLE_STREAMING_FALLBACK = (
|
||||
os.environ.get("VOICE_DISABLE_STREAMING_FALLBACK", "").lower() == "true"
|
||||
)
|
||||
|
||||
# WebSocket size limits to prevent memory exhaustion attacks
|
||||
WS_MAX_MESSAGE_SIZE = 64 * 1024 # 64KB per message (OWASP recommendation)
|
||||
WS_MAX_TOTAL_BYTES = 25 * 1024 * 1024 # 25MB total per connection (matches REST API)
|
||||
WS_MAX_TEXT_MESSAGE_SIZE = 16 * 1024 # 16KB for text/JSON messages
|
||||
WS_MAX_TTS_TEXT_LENGTH = 4096 # Max text length per synthesize call (matches REST API)
|
||||
|
||||
|
||||
class ChunkedTranscriber:
    """Fallback transcriber for providers that lack a streaming STT API.

    Buffers incoming audio and periodically runs batch transcription on the
    accumulated chunk; `flush()` re-transcribes the full recording for a
    higher-accuracy final result.
    """

    def __init__(self, provider: Any, audio_format: str = "webm"):
        self.provider = provider
        self.audio_format = audio_format
        self.chunk_buffer = io.BytesIO()  # audio since last transcription
        self.full_audio = io.BytesIO()  # entire recording, for flush()
        self.chunk_bytes = 0
        self.transcripts: list[str] = []

    def _reset_chunk_buffer(self) -> None:
        """Discard the pending chunk and its byte counter."""
        self.chunk_buffer = io.BytesIO()
        self.chunk_bytes = 0

    async def add_chunk(self, chunk: bytes) -> str | None:
        """Buffer an audio chunk; return the running transcript once enough
        audio has accumulated, else None."""
        self.chunk_buffer.write(chunk)
        self.full_audio.write(chunk)
        self.chunk_bytes += len(chunk)

        if self.chunk_bytes < MIN_CHUNK_BYTES:
            return None
        return await self._transcribe_chunk()

    async def _transcribe_chunk(self) -> str | None:
        """Transcribe the pending chunk and extend the running transcript."""
        pending = self.chunk_buffer.getvalue()
        if not pending:
            return None

        try:
            piece = await self.provider.transcribe(pending, self.audio_format)
            self._reset_chunk_buffer()

            cleaned = piece.strip() if piece else ""
            if not cleaned:
                return None
            self.transcripts.append(cleaned)
            return " ".join(self.transcripts)
        except Exception as e:
            # Best-effort: drop the failed chunk and keep accepting audio.
            logger.error(f"Transcription error: {e}")
            self._reset_chunk_buffer()
            return None

    async def flush(self) -> str:
        """Return the final transcript, preferring a single pass over the
        complete audio; fall back to the accumulated partials on failure."""
        recording = self.full_audio.getvalue()
        if recording:
            try:
                final = await self.provider.transcribe(
                    recording, self.audio_format
                )
                cleaned = final.strip() if final else ""
                if cleaned:
                    return cleaned
            except Exception as e:
                logger.error(f"Final transcription error: {e}")
        return " ".join(self.transcripts)
|
||||
|
||||
|
||||
async def handle_streaming_transcription(
    websocket: WebSocket,
    transcriber: StreamingTranscriberProtocol,
) -> None:
    """Bridge a client WebSocket to a provider's native streaming STT API.

    Forwards binary audio to the transcriber while a background task pumps
    transcript updates back to the client; enforces per-message and total
    size limits on incoming audio.
    """
    logger.info("Streaming transcription: starting handler")
    last_transcript = ""
    chunk_count = 0
    total_bytes = 0

    async def _pump_transcripts() -> None:
        """Forward transcript updates from the provider to the client."""
        nonlocal last_transcript
        logger.info("Streaming transcription: starting transcript receiver")
        while True:
            update: TranscriptResult | None = await transcriber.receive_transcript()
            if update is None:
                # Provider closed its transcript stream.
                logger.info("Streaming transcription: transcript stream ended")
                break
            # Emit when text changed, or when VAD flagged end-of-speech so
            # the client's auto-send can trigger even on unchanged text.
            if update.text and (update.text != last_transcript or update.is_vad_end):
                last_transcript = update.text
                logger.debug(
                    f"Streaming transcription: got transcript: {update.text[:50]}... "
                    f"(is_vad_end={update.is_vad_end})"
                )
                await websocket.send_json(
                    {
                        "type": "transcript",
                        "text": update.text,
                        "is_final": update.is_vad_end,
                    }
                )

    reader_task = asyncio.create_task(_pump_transcripts())

    try:
        while True:
            message = await websocket.receive()
            msg_type = message.get("type", "unknown")

            if msg_type == "websocket.disconnect":
                logger.info(
                    f"Streaming transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
                )
                break

            if "bytes" in message:
                size = len(message["bytes"])

                # Per-message cap.
                if size > WS_MAX_MESSAGE_SIZE:
                    logger.warning(
                        f"Streaming transcription: message too large ({size} bytes)"
                    )
                    await websocket.send_json(
                        {"type": "error", "message": "Message too large"}
                    )
                    break

                # Whole-connection cap.
                if total_bytes + size > WS_MAX_TOTAL_BYTES:
                    logger.warning(
                        f"Streaming transcription: total size limit exceeded ({total_bytes + size} bytes)"
                    )
                    await websocket.send_json(
                        {"type": "error", "message": "Total size limit exceeded"}
                    )
                    break

                chunk_count += 1
                total_bytes += size
                logger.debug(
                    f"Streaming transcription: received chunk {chunk_count} ({size} bytes, total: {total_bytes})"
                )
                await transcriber.send_audio(message["bytes"])

            elif "text" in message:
                try:
                    payload = json.loads(message["text"])
                    logger.debug(
                        f"Streaming transcription: received text message: {payload}"
                    )
                    if payload.get("type") == "end":
                        logger.info(
                            "Streaming transcription: end signal received, closing transcriber"
                        )
                        final_transcript = await transcriber.close()
                        reader_task.cancel()
                        logger.info(
                            "Streaming transcription: final transcript: "
                            f"{final_transcript[:100] if final_transcript else '(empty)'}..."
                        )
                        await websocket.send_json(
                            {
                                "type": "transcript",
                                "text": final_transcript,
                                "is_final": True,
                            }
                        )
                        break
                    elif payload.get("type") == "reset":
                        # Client auto-sent the current text; start fresh.
                        logger.info(
                            "Streaming transcription: reset signal received, clearing transcript"
                        )
                        transcriber.reset_transcript()
                except json.JSONDecodeError:
                    logger.warning(
                        f"Streaming transcription: failed to parse JSON: {message.get('text', '')[:100]}"
                    )
    except Exception as e:
        logger.error(f"Streaming transcription: error: {e}", exc_info=True)
        raise
    finally:
        # Always tear down the transcript pump, even on error paths.
        reader_task.cancel()
        try:
            await reader_task
        except asyncio.CancelledError:
            pass
        logger.info(
            f"Streaming transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
        )
|
||||
|
||||
|
||||
async def handle_chunked_transcription(
    websocket: WebSocket,
    transcriber: ChunkedTranscriber,
) -> None:
    """Drive a client WebSocket through the chunked (batch) STT fallback.

    Feeds binary audio to the ChunkedTranscriber, relays partial transcripts
    as they become available, and returns the flushed final transcript on the
    client's "end" signal. Enforces the same size limits as streaming mode.
    """
    logger.info("Chunked transcription: starting handler")
    chunk_count = 0
    total_bytes = 0

    while True:
        message = await websocket.receive()
        msg_type = message.get("type", "unknown")

        if msg_type == "websocket.disconnect":
            logger.info(
                f"Chunked transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
            )
            break

        if "bytes" in message:
            size = len(message["bytes"])

            # Per-message cap.
            if size > WS_MAX_MESSAGE_SIZE:
                logger.warning(
                    f"Chunked transcription: message too large ({size} bytes)"
                )
                await websocket.send_json(
                    {"type": "error", "message": "Message too large"}
                )
                break

            # Whole-connection cap.
            if total_bytes + size > WS_MAX_TOTAL_BYTES:
                logger.warning(
                    f"Chunked transcription: total size limit exceeded ({total_bytes + size} bytes)"
                )
                await websocket.send_json(
                    {"type": "error", "message": "Total size limit exceeded"}
                )
                break

            chunk_count += 1
            total_bytes += size
            logger.debug(
                f"Chunked transcription: received chunk {chunk_count} ({size} bytes, total: {total_bytes})"
            )

            # add_chunk returns a running transcript only once enough audio
            # has accumulated to transcribe.
            partial = await transcriber.add_chunk(message["bytes"])
            if partial:
                logger.debug(
                    f"Chunked transcription: got transcript: {partial[:50]}..."
                )
                await websocket.send_json(
                    {
                        "type": "transcript",
                        "text": partial,
                        "is_final": False,
                    }
                )

        elif "text" in message:
            try:
                payload = json.loads(message["text"])
                logger.debug(f"Chunked transcription: received text message: {payload}")
                if payload.get("type") == "end":
                    logger.info("Chunked transcription: end signal received, flushing")
                    final_transcript = await transcriber.flush()
                    logger.info(
                        f"Chunked transcription: final transcript: {final_transcript[:100] if final_transcript else '(empty)'}..."
                    )
                    await websocket.send_json(
                        {
                            "type": "transcript",
                            "text": final_transcript,
                            "is_final": True,
                        }
                    )
                    break
            except json.JSONDecodeError:
                logger.warning(
                    f"Chunked transcription: failed to parse JSON: {message.get('text', '')[:100]}"
                )

    logger.info(
        f"Chunked transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
    )
|
||||
|
||||
|
||||
@router.websocket("/transcribe/stream")
async def websocket_transcribe(
    websocket: WebSocket,
    _user: User = Depends(current_user_from_websocket),
) -> None:
    """
    WebSocket endpoint for streaming speech-to-text.

    Protocol:
    - Client sends binary audio chunks
    - Server sends JSON: {"type": "transcript", "text": "...", "is_final": false}
    - Client sends JSON {"type": "end"} to signal end
    - Server responds with final transcript and closes

    Authentication:
        Requires `token` query parameter (e.g., /voice/transcribe/stream?token=xxx).
        Applies same auth checks as HTTP endpoints (verification, role checks).
    """
    logger.info("WebSocket transcribe: connection request received (authenticated)")

    try:
        await websocket.accept()
        logger.info("WebSocket transcribe: connection accepted")
    except Exception as exc:
        logger.error(f"WebSocket transcribe: failed to accept connection: {exc}")
        return

    stream_transcriber = None
    provider = None

    try:
        # Resolve the default STT provider; the session is only needed for
        # this lookup, not for the streaming itself.
        logger.info("WebSocket transcribe: fetching STT provider from database")
        engine = get_sqlalchemy_engine()
        with Session(engine) as db_session:
            stored = fetch_default_stt_provider(db_session)
            if stored is None:
                logger.warning(
                    "WebSocket transcribe: no default STT provider configured"
                )
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "No speech-to-text provider configured",
                    }
                )
                return

            if not stored.api_key:
                logger.warning("WebSocket transcribe: STT provider has no API key")
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "Speech-to-text provider has no API key configured",
                    }
                )
                return

            logger.info(
                f"WebSocket transcribe: creating voice provider: {stored.provider_type}"
            )
            try:
                provider = get_voice_provider(stored)
                logger.info(
                    f"WebSocket transcribe: voice provider created, streaming supported: {provider.supports_streaming_stt()}"
                )
            except ValueError as exc:
                logger.error(
                    f"WebSocket transcribe: failed to create voice provider: {exc}"
                )
                await websocket.send_json({"type": "error", "message": str(exc)})
                return

        # Prefer the provider's native streaming API when available.
        if provider.supports_streaming_stt():
            logger.info("WebSocket transcribe: using native streaming STT")
            try:
                stream_transcriber = await provider.create_streaming_transcriber()
                logger.info(
                    "WebSocket transcribe: streaming transcriber created successfully"
                )
                await handle_streaming_transcription(websocket, stream_transcriber)
            except Exception as exc:
                logger.error(
                    f"WebSocket transcribe: failed to create streaming transcriber: {exc}"
                )
                if VOICE_DISABLE_STREAMING_FALLBACK:
                    await websocket.send_json(
                        {"type": "error", "message": f"Streaming STT failed: {exc}"}
                    )
                    return
                logger.info("WebSocket transcribe: falling back to chunked STT")
                # Browser stream provides raw PCM16 chunks over WebSocket.
                fallback = ChunkedTranscriber(provider, audio_format="pcm16")
                await handle_chunked_transcription(websocket, fallback)
        else:
            # Provider only supports batch transcription.
            if VOICE_DISABLE_STREAMING_FALLBACK:
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "Provider doesn't support streaming STT",
                    }
                )
                return
            logger.info(
                "WebSocket transcribe: using chunked STT (provider doesn't support streaming)"
            )
            fallback = ChunkedTranscriber(provider, audio_format="pcm16")
            await handle_chunked_transcription(websocket, fallback)

    except WebSocketDisconnect:
        logger.debug("WebSocket transcribe: client disconnected")
    except Exception as exc:
        logger.error(f"WebSocket transcribe: unhandled error: {exc}", exc_info=True)
        try:
            # Send generic error to avoid leaking sensitive details
            await websocket.send_json(
                {"type": "error", "message": "An unexpected error occurred"}
            )
        except Exception:
            pass
    finally:
        # Best-effort teardown: close the provider stream, then the socket.
        if stream_transcriber:
            try:
                await stream_transcriber.close()
            except Exception:
                pass
        try:
            await websocket.close()
        except Exception:
            pass
        logger.info("WebSocket transcribe: connection closed")
|
||||
|
||||
|
||||
async def handle_streaming_synthesis(
|
||||
websocket: WebSocket,
|
||||
synthesizer: StreamingSynthesizerProtocol,
|
||||
) -> None:
|
||||
"""Handle TTS using native streaming API."""
|
||||
logger.info("Streaming synthesis: starting handler")
|
||||
|
||||
async def send_audio() -> None:
|
||||
"""Background task to send audio chunks to client."""
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
try:
|
||||
while True:
|
||||
audio_chunk = await synthesizer.receive_audio()
|
||||
if audio_chunk is None:
|
||||
logger.info(
|
||||
f"Streaming synthesis: audio stream ended, sent {chunk_count} chunks, {total_bytes} bytes"
|
||||
)
|
||||
try:
|
||||
await websocket.send_json({"type": "audio_done"})
|
||||
logger.info("Streaming synthesis: sent audio_done to client")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Streaming synthesis: failed to send audio_done: {e}"
|
||||
)
|
||||
break
|
||||
if audio_chunk: # Skip empty chunks
|
||||
chunk_count += 1
|
||||
total_bytes += len(audio_chunk)
|
||||
try:
|
||||
await websocket.send_bytes(audio_chunk)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Streaming synthesis: failed to send chunk: {e}"
|
||||
)
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
logger.info(
|
||||
f"Streaming synthesis: send_audio cancelled after {chunk_count} chunks"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming synthesis: send_audio error: {e}")
|
||||
|
||||
send_task: asyncio.Task | None = None
|
||||
disconnected = False
|
||||
|
||||
try:
|
||||
while not disconnected:
|
||||
try:
|
||||
message = await websocket.receive()
|
||||
except WebSocketDisconnect:
|
||||
logger.info("Streaming synthesis: client disconnected")
|
||||
break
|
||||
|
||||
msg_type = message.get("type", "unknown") # type: ignore[possibly-undefined]
|
||||
|
||||
if msg_type == "websocket.disconnect":
|
||||
logger.info("Streaming synthesis: client disconnected")
|
||||
disconnected = True
|
||||
break
|
||||
|
||||
if "text" in message:
|
||||
# Enforce text message size limit
|
||||
msg_size = len(message["text"])
|
||||
if msg_size > WS_MAX_TEXT_MESSAGE_SIZE:
|
||||
logger.warning(
|
||||
f"Streaming synthesis: text message too large ({msg_size} bytes)"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "Message too large"}
|
||||
)
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(message["text"])
|
||||
|
||||
if data.get("type") == "synthesize":
|
||||
text = data.get("text", "")
|
||||
# Enforce per-text size limit
|
||||
if len(text) > WS_MAX_TTS_TEXT_LENGTH:
|
||||
logger.warning(
|
||||
f"Streaming synthesis: text too long ({len(text)} chars)"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "Text too long"}
|
||||
)
|
||||
continue
|
||||
if text:
|
||||
# Start audio receiver on first text chunk so playback
|
||||
# can begin before the full assistant response completes.
|
||||
if send_task is None:
|
||||
send_task = asyncio.create_task(send_audio())
|
||||
logger.debug(
|
||||
f"Streaming synthesis: forwarding text chunk ({len(text)} chars)"
|
||||
)
|
||||
await synthesizer.send_text(text)
|
||||
|
||||
elif data.get("type") == "end":
|
||||
logger.info("Streaming synthesis: end signal received")
|
||||
|
||||
# Ensure receiver is active even if no prior text chunks arrived.
|
||||
if send_task is None:
|
||||
send_task = asyncio.create_task(send_audio())
|
||||
|
||||
# Signal end of input
|
||||
if hasattr(synthesizer, "flush"):
|
||||
await synthesizer.flush()
|
||||
|
||||
# Wait for all audio to be sent
|
||||
logger.info(
|
||||
"Streaming synthesis: waiting for audio stream to complete"
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(send_task, timeout=60.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"Streaming synthesis: timeout waiting for audio"
|
||||
)
|
||||
break
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Streaming synthesis: failed to parse JSON: {message.get('text', '')[:100]}"
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.debug("Streaming synthesis: client disconnected during synthesis")
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming synthesis: error: {e}", exc_info=True)
|
||||
finally:
|
||||
if send_task and not send_task.done():
|
||||
logger.info("Streaming synthesis: waiting for send_task to finish")
|
||||
try:
|
||||
await asyncio.wait_for(send_task, timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Streaming synthesis: timeout waiting for send_task")
|
||||
send_task.cancel()
|
||||
try:
|
||||
await send_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("Streaming synthesis: handler finished")
|
||||
|
||||
|
||||
async def handle_chunked_synthesis(
|
||||
websocket: WebSocket,
|
||||
provider: Any,
|
||||
first_message: MutableMapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Fallback TTS handler using provider.synthesize_stream.
|
||||
|
||||
Args:
|
||||
websocket: The WebSocket connection
|
||||
provider: Voice provider instance
|
||||
first_message: Optional first message already received (used when falling
|
||||
back from streaming mode, where the first message was already consumed)
|
||||
"""
|
||||
logger.info("Chunked synthesis: starting handler")
|
||||
text_buffer: list[str] = []
|
||||
voice: str | None = None
|
||||
speed = 1.0
|
||||
|
||||
# Process pre-received message if provided
|
||||
pending_message = first_message
|
||||
|
||||
try:
|
||||
while True:
|
||||
if pending_message is not None:
|
||||
message = pending_message
|
||||
pending_message = None
|
||||
else:
|
||||
message = await websocket.receive()
|
||||
msg_type = message.get("type", "unknown")
|
||||
|
||||
if msg_type == "websocket.disconnect":
|
||||
logger.info("Chunked synthesis: client disconnected")
|
||||
break
|
||||
|
||||
if "text" not in message:
|
||||
continue
|
||||
|
||||
# Enforce text message size limit
|
||||
msg_size = len(message["text"])
|
||||
if msg_size > WS_MAX_TEXT_MESSAGE_SIZE:
|
||||
logger.warning(
|
||||
f"Chunked synthesis: text message too large ({msg_size} bytes)"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "Message too large"}
|
||||
)
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(message["text"])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Chunked synthesis: failed to parse JSON: "
|
||||
f"{message.get('text', '')[:100]}"
|
||||
)
|
||||
continue
|
||||
|
||||
msg_data_type = data.get("type") # type: ignore[possibly-undefined]
|
||||
if msg_data_type == "synthesize":
|
||||
text = data.get("text", "")
|
||||
# Enforce per-text size limit
|
||||
if len(text) > WS_MAX_TTS_TEXT_LENGTH:
|
||||
logger.warning(
|
||||
f"Chunked synthesis: text too long ({len(text)} chars)"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "Text too long"}
|
||||
)
|
||||
continue
|
||||
if text:
|
||||
text_buffer.append(text)
|
||||
logger.debug(
|
||||
f"Chunked synthesis: buffered text ({len(text)} chars), "
|
||||
f"total buffered: {len(text_buffer)} chunks"
|
||||
)
|
||||
if isinstance(data.get("voice"), str) and data["voice"]:
|
||||
voice = data["voice"]
|
||||
if isinstance(data.get("speed"), (int, float)):
|
||||
speed = float(data["speed"])
|
||||
elif msg_data_type == "end":
|
||||
logger.info("Chunked synthesis: end signal received")
|
||||
full_text = " ".join(text_buffer).strip()
|
||||
if not full_text:
|
||||
await websocket.send_json({"type": "audio_done"})
|
||||
logger.info("Chunked synthesis: no text, sent audio_done")
|
||||
break
|
||||
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
logger.info(
|
||||
f"Chunked synthesis: sending full text ({len(full_text)} chars)"
|
||||
)
|
||||
async for audio_chunk in provider.synthesize_stream(
|
||||
full_text, voice=voice, speed=speed
|
||||
):
|
||||
if not audio_chunk:
|
||||
continue
|
||||
chunk_count += 1
|
||||
total_bytes += len(audio_chunk)
|
||||
await websocket.send_bytes(audio_chunk)
|
||||
await websocket.send_json({"type": "audio_done"})
|
||||
logger.info(
|
||||
f"Chunked synthesis: sent audio_done after {chunk_count} chunks, {total_bytes} bytes"
|
||||
)
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
logger.debug("Chunked synthesis: client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"Chunked synthesis: error: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
logger.info("Chunked synthesis: handler finished")
|
||||
|
||||
|
||||
@router.websocket("/synthesize/stream")
|
||||
async def websocket_synthesize(
|
||||
websocket: WebSocket,
|
||||
_user: User = Depends(current_user_from_websocket),
|
||||
) -> None:
|
||||
"""
|
||||
WebSocket endpoint for streaming text-to-speech.
|
||||
|
||||
Protocol:
|
||||
- Client sends JSON: {"type": "synthesize", "text": "...", "voice": "...", "speed": 1.0}
|
||||
- Server sends binary audio chunks
|
||||
- Server sends JSON: {"type": "audio_done"} when synthesis completes
|
||||
- Client sends JSON {"type": "end"} to close connection
|
||||
|
||||
Authentication:
|
||||
Requires `token` query parameter (e.g., /voice/synthesize/stream?token=xxx).
|
||||
Applies same auth checks as HTTP endpoints (verification, role checks).
|
||||
"""
|
||||
logger.info("WebSocket synthesize: connection request received (authenticated)")
|
||||
|
||||
try:
|
||||
await websocket.accept()
|
||||
logger.info("WebSocket synthesize: connection accepted")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket synthesize: failed to accept connection: {e}")
|
||||
return
|
||||
|
||||
streaming_synthesizer: StreamingSynthesizerProtocol | None = None
|
||||
provider = None
|
||||
|
||||
try:
|
||||
# Get TTS provider
|
||||
logger.info("WebSocket synthesize: fetching TTS provider from database")
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
provider_db = fetch_default_tts_provider(db_session)
|
||||
if provider_db is None:
|
||||
logger.warning(
|
||||
"WebSocket synthesize: no default TTS provider configured"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "No text-to-speech provider configured",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
if not provider_db.api_key:
|
||||
logger.warning("WebSocket synthesize: TTS provider has no API key")
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Text-to-speech provider has no API key configured",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"WebSocket synthesize: creating voice provider: {provider_db.provider_type}"
|
||||
)
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
logger.info(
|
||||
f"WebSocket synthesize: voice provider created, streaming TTS supported: {provider.supports_streaming_tts()}"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"WebSocket synthesize: failed to create voice provider: {e}"
|
||||
)
|
||||
await websocket.send_json({"type": "error", "message": str(e)})
|
||||
return
|
||||
|
||||
# Use native streaming if provider supports it
|
||||
if provider.supports_streaming_tts():
|
||||
logger.info("WebSocket synthesize: using native streaming TTS")
|
||||
message = None # Initialize to avoid UnboundLocalError in except block
|
||||
try:
|
||||
# Wait for initial config message with voice/speed
|
||||
message = await websocket.receive()
|
||||
voice = None
|
||||
speed = 1.0
|
||||
if "text" in message:
|
||||
try:
|
||||
data = json.loads(message["text"])
|
||||
voice = data.get("voice")
|
||||
speed = data.get("speed", 1.0)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
streaming_synthesizer = await provider.create_streaming_synthesizer(
|
||||
voice=voice, speed=speed
|
||||
)
|
||||
logger.info(
|
||||
"WebSocket synthesize: streaming synthesizer created successfully"
|
||||
)
|
||||
await handle_streaming_synthesis(websocket, streaming_synthesizer)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"WebSocket synthesize: failed to create streaming synthesizer: {e}"
|
||||
)
|
||||
if VOICE_DISABLE_STREAMING_FALLBACK:
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": f"Streaming TTS failed: {e}"}
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"WebSocket synthesize: falling back to chunked TTS synthesis"
|
||||
)
|
||||
# Pass the first message so it's not lost in the fallback
|
||||
await handle_chunked_synthesis(
|
||||
websocket, provider, first_message=message
|
||||
)
|
||||
else:
|
||||
if VOICE_DISABLE_STREAMING_FALLBACK:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Provider doesn't support streaming TTS",
|
||||
}
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"WebSocket synthesize: using chunked TTS (provider doesn't support streaming)"
|
||||
)
|
||||
await handle_chunked_synthesis(websocket, provider)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.debug("WebSocket synthesize: client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket synthesize: unhandled error: {e}", exc_info=True)
|
||||
try:
|
||||
# Send generic error to avoid leaking sensitive details
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "An unexpected error occurred"}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
if streaming_synthesizer:
|
||||
try:
|
||||
await streaming_synthesizer.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("WebSocket synthesize: connection closed")
|
||||
@@ -140,44 +140,6 @@ def _validate_and_resolve_url(url: str) -> tuple[str, str, int]:
|
||||
return validated_ip, hostname, port
|
||||
|
||||
|
||||
def validate_outbound_http_url(url: str, *, allow_private_network: bool = False) -> str:
|
||||
"""
|
||||
Validate a URL that will be used by backend outbound HTTP calls.
|
||||
|
||||
Returns:
|
||||
A normalized URL string with surrounding whitespace removed.
|
||||
|
||||
Raises:
|
||||
ValueError: If URL is malformed.
|
||||
SSRFException: If URL fails SSRF checks.
|
||||
"""
|
||||
normalized_url = url.strip()
|
||||
if not normalized_url:
|
||||
raise ValueError("URL cannot be empty")
|
||||
|
||||
parsed = urlparse(normalized_url)
|
||||
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise SSRFException(
|
||||
f"Invalid URL scheme '{parsed.scheme}'. Only http and https are allowed."
|
||||
)
|
||||
|
||||
if not parsed.hostname:
|
||||
raise ValueError("URL must contain a hostname")
|
||||
|
||||
if parsed.username or parsed.password:
|
||||
raise SSRFException("URLs with embedded credentials are not allowed.")
|
||||
|
||||
hostname = parsed.hostname.lower()
|
||||
if hostname in BLOCKED_HOSTNAMES:
|
||||
raise SSRFException(f"Access to hostname '{parsed.hostname}' is not allowed.")
|
||||
|
||||
if not allow_private_network:
|
||||
_validate_and_resolve_url(normalized_url)
|
||||
|
||||
return normalized_url
|
||||
|
||||
|
||||
MAX_REDIRECTS = 10
|
||||
|
||||
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
|
||||
def get_voice_provider(provider: VoiceProvider) -> VoiceProviderInterface:
|
||||
"""
|
||||
Factory function to get the appropriate voice provider implementation.
|
||||
|
||||
Args:
|
||||
provider: VoiceProvider model instance (can be from DB or constructed temporarily)
|
||||
|
||||
Returns:
|
||||
VoiceProviderInterface implementation
|
||||
|
||||
Raises:
|
||||
ValueError: If provider_type is not supported
|
||||
"""
|
||||
provider_type = provider.provider_type.lower()
|
||||
|
||||
# Handle both SensitiveValue (from DB) and plain string (from temp model)
|
||||
if provider.api_key is None:
|
||||
api_key = None
|
||||
elif hasattr(provider.api_key, "get_value"):
|
||||
# SensitiveValue from database
|
||||
api_key = provider.api_key.get_value(apply_mask=False)
|
||||
else:
|
||||
# Plain string from temporary model
|
||||
api_key = provider.api_key # type: ignore[assignment]
|
||||
api_base = provider.api_base
|
||||
custom_config = provider.custom_config
|
||||
stt_model = provider.stt_model
|
||||
tts_model = provider.tts_model
|
||||
default_voice = provider.default_voice
|
||||
|
||||
if provider_type == "openai":
|
||||
from onyx.voice.providers.openai import OpenAIVoiceProvider
|
||||
|
||||
return OpenAIVoiceProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
stt_model=stt_model,
|
||||
tts_model=tts_model,
|
||||
default_voice=default_voice,
|
||||
)
|
||||
|
||||
elif provider_type == "azure":
|
||||
from onyx.voice.providers.azure import AzureVoiceProvider
|
||||
|
||||
return AzureVoiceProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
custom_config=custom_config or {},
|
||||
stt_model=stt_model,
|
||||
tts_model=tts_model,
|
||||
default_voice=default_voice,
|
||||
)
|
||||
|
||||
elif provider_type == "elevenlabs":
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsVoiceProvider
|
||||
|
||||
return ElevenLabsVoiceProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
stt_model=stt_model,
|
||||
tts_model=tts_model,
|
||||
default_voice=default_voice,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported voice provider type: {provider_type}")
|
||||
@@ -1,182 +0,0 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TranscriptResult(BaseModel):
|
||||
"""Result from streaming transcription."""
|
||||
|
||||
text: str
|
||||
"""The accumulated transcript text."""
|
||||
|
||||
is_vad_end: bool = False
|
||||
"""True if VAD detected end of speech (silence). Use for auto-send."""
|
||||
|
||||
|
||||
class StreamingTranscriberProtocol(Protocol):
|
||||
"""Protocol for streaming transcription sessions."""
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send an audio chunk for transcription."""
|
||||
...
|
||||
|
||||
async def receive_transcript(self) -> TranscriptResult | None:
|
||||
"""
|
||||
Receive next transcript update.
|
||||
|
||||
Returns:
|
||||
TranscriptResult with accumulated text and VAD status, or None when stream ends.
|
||||
"""
|
||||
...
|
||||
|
||||
async def close(self) -> str:
|
||||
"""Close the session and return final transcript."""
|
||||
...
|
||||
|
||||
def reset_transcript(self) -> None:
|
||||
"""Reset accumulated transcript. Call after auto-send to start fresh."""
|
||||
...
|
||||
|
||||
|
||||
class StreamingSynthesizerProtocol(Protocol):
|
||||
"""Protocol for streaming TTS sessions (real-time text-to-speech)."""
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish connection to TTS provider."""
|
||||
...
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Send text to be synthesized."""
|
||||
...
|
||||
|
||||
async def receive_audio(self) -> bytes | None:
|
||||
"""
|
||||
Receive next audio chunk.
|
||||
|
||||
Returns:
|
||||
Audio bytes, or None when stream ends.
|
||||
"""
|
||||
...
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Signal end of text input and wait for pending audio."""
|
||||
...
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the session."""
|
||||
...
|
||||
|
||||
|
||||
class VoiceProviderInterface(ABC):
|
||||
"""Abstract base class for voice providers (STT and TTS)."""
|
||||
|
||||
@abstractmethod
|
||||
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
|
||||
"""
|
||||
Convert audio to text (Speech-to-Text).
|
||||
|
||||
Args:
|
||||
audio_data: Raw audio bytes
|
||||
audio_format: Audio format (e.g., "webm", "wav", "mp3")
|
||||
|
||||
Returns:
|
||||
Transcribed text
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def synthesize_stream(
|
||||
self, text: str, voice: str | None = None, speed: float = 1.0
|
||||
) -> AsyncIterator[bytes]:
|
||||
"""
|
||||
Convert text to audio stream (Text-to-Speech).
|
||||
|
||||
Streams audio chunks progressively for lower latency playback.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech
|
||||
voice: Voice identifier (e.g., "alloy", "echo"), or None for default
|
||||
speed: Playback speed multiplier (0.25 to 4.0)
|
||||
|
||||
Yields:
|
||||
Audio data chunks
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def validate_credentials(self) -> None:
|
||||
"""
|
||||
Validate that the provider credentials are correct by making a
|
||||
lightweight API call. Raises on failure.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_available_voices(self) -> list[dict[str, str]]:
|
||||
"""
|
||||
Get list of available voices for this provider.
|
||||
|
||||
Returns:
|
||||
List of voice dictionaries with 'id' and 'name' keys
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_available_stt_models(self) -> list[dict[str, str]]:
|
||||
"""
|
||||
Get list of available STT models for this provider.
|
||||
|
||||
Returns:
|
||||
List of model dictionaries with 'id' and 'name' keys
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_available_tts_models(self) -> list[dict[str, str]]:
|
||||
"""
|
||||
Get list of available TTS models for this provider.
|
||||
|
||||
Returns:
|
||||
List of model dictionaries with 'id' and 'name' keys
|
||||
"""
|
||||
|
||||
def supports_streaming_stt(self) -> bool:
|
||||
"""Returns True if this provider supports streaming STT."""
|
||||
return False
|
||||
|
||||
def supports_streaming_tts(self) -> bool:
|
||||
"""Returns True if this provider supports real-time streaming TTS."""
|
||||
return False
|
||||
|
||||
async def create_streaming_transcriber(
|
||||
self, audio_format: str = "webm"
|
||||
) -> StreamingTranscriberProtocol:
|
||||
"""
|
||||
Create a streaming transcription session.
|
||||
|
||||
Args:
|
||||
audio_format: Audio format being sent (e.g., "webm", "pcm16")
|
||||
|
||||
Returns:
|
||||
A streaming transcriber that can send audio chunks and receive transcripts
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If streaming STT is not supported
|
||||
"""
|
||||
raise NotImplementedError("Streaming STT not supported by this provider")
|
||||
|
||||
async def create_streaming_synthesizer(
|
||||
self, voice: str | None = None, speed: float = 1.0
|
||||
) -> "StreamingSynthesizerProtocol":
|
||||
"""
|
||||
Create a streaming TTS session for real-time audio synthesis.
|
||||
|
||||
Args:
|
||||
voice: Voice identifier
|
||||
speed: Playback speed multiplier
|
||||
|
||||
Returns:
|
||||
A streaming synthesizer that can send text and receive audio chunks
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If streaming TTS is not supported
|
||||
"""
|
||||
raise NotImplementedError("Streaming TTS not supported by this provider")
|
||||
@@ -1,625 +0,0 @@
|
||||
"""Azure Speech Services voice provider for STT and TTS.
|
||||
|
||||
Azure supports:
|
||||
- **STT**: Batch transcription via REST API (audio/wav POST) and real-time
|
||||
streaming via the Azure Speech SDK (push audio stream with continuous
|
||||
recognition). The SDK handles VAD natively through its recognizing/recognized
|
||||
events.
|
||||
- **TTS**: SSML-based synthesis via REST API (streaming response) and real-time
|
||||
synthesis via the Speech SDK. Text is escaped with ``xml.sax.saxutils.escape``
|
||||
and attributes with ``quoteattr`` to prevent SSML injection.
|
||||
|
||||
Both modes support Azure cloud endpoints (region-based URLs) and self-hosted
|
||||
Speech containers (custom endpoint URLs). The ``speech_region`` is validated to
|
||||
contain only ``[a-z0-9-]`` to prevent URL injection.
|
||||
|
||||
The Azure Speech SDK (``azure-cognitiveservices-speech``) is an optional C
|
||||
extension dependency — it is imported lazily inside streaming methods so the
|
||||
provider can still be instantiated and used for REST-based operations without it.
|
||||
|
||||
See https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/
|
||||
for API reference.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import re
|
||||
import struct
|
||||
import wave
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
from xml.sax.saxutils import escape
|
||||
from xml.sax.saxutils import quoteattr
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
# SSML namespace — W3C standard for Speech Synthesis Markup Language.
|
||||
# This is a fixed W3C specification and will not change.
|
||||
SSML_NAMESPACE = "http://www.w3.org/2001/10/synthesis"
|
||||
|
||||
# Common Azure Neural voices
|
||||
AZURE_VOICES = [
|
||||
{"id": "en-US-JennyNeural", "name": "Jenny (en-US, Female)"},
|
||||
{"id": "en-US-GuyNeural", "name": "Guy (en-US, Male)"},
|
||||
{"id": "en-US-AriaNeural", "name": "Aria (en-US, Female)"},
|
||||
{"id": "en-US-DavisNeural", "name": "Davis (en-US, Male)"},
|
||||
{"id": "en-US-AmberNeural", "name": "Amber (en-US, Female)"},
|
||||
{"id": "en-US-AnaNeural", "name": "Ana (en-US, Female)"},
|
||||
{"id": "en-US-BrandonNeural", "name": "Brandon (en-US, Male)"},
|
||||
{"id": "en-US-ChristopherNeural", "name": "Christopher (en-US, Male)"},
|
||||
{"id": "en-US-CoraNeural", "name": "Cora (en-US, Female)"},
|
||||
{"id": "en-GB-SoniaNeural", "name": "Sonia (en-GB, Female)"},
|
||||
{"id": "en-GB-RyanNeural", "name": "Ryan (en-GB, Male)"},
|
||||
]
|
||||
|
||||
|
||||
class AzureStreamingTranscriber(StreamingTranscriberProtocol):
|
||||
"""Streaming transcription using Azure Speech SDK."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
region: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
input_sample_rate: int = 24000,
|
||||
target_sample_rate: int = 16000,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.region = region
|
||||
self.endpoint = endpoint
|
||||
self.input_sample_rate = input_sample_rate
|
||||
self.target_sample_rate = target_sample_rate
|
||||
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
|
||||
self._accumulated_transcript = ""
|
||||
self._recognizer: Any = None
|
||||
self._audio_stream: Any = None
|
||||
self._closed = False
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize Azure Speech recognizer with push stream."""
|
||||
try:
|
||||
import azure.cognitiveservices.speech as speechsdk # type: ignore
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Azure Speech SDK is required for streaming STT. "
|
||||
"Install `azure-cognitiveservices-speech`."
|
||||
) from e
|
||||
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
# Use endpoint for self-hosted containers, region for Azure cloud
|
||||
if self.endpoint:
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key,
|
||||
endpoint=self.endpoint,
|
||||
)
|
||||
else:
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key,
|
||||
region=self.region,
|
||||
)
|
||||
|
||||
audio_format = speechsdk.audio.AudioStreamFormat(
|
||||
samples_per_second=16000,
|
||||
bits_per_sample=16,
|
||||
channels=1,
|
||||
)
|
||||
self._audio_stream = speechsdk.audio.PushAudioInputStream(audio_format)
|
||||
audio_config = speechsdk.audio.AudioConfig(stream=self._audio_stream)
|
||||
|
||||
self._recognizer = speechsdk.SpeechRecognizer(
|
||||
speech_config=speech_config,
|
||||
audio_config=audio_config,
|
||||
)
|
||||
|
||||
transcriber = self
|
||||
|
||||
def on_recognizing(evt: Any) -> None:
|
||||
if evt.result.text and transcriber._loop and not transcriber._closed:
|
||||
full_text = transcriber._accumulated_transcript
|
||||
if full_text:
|
||||
full_text += " " + evt.result.text
|
||||
else:
|
||||
full_text = evt.result.text
|
||||
transcriber._loop.call_soon_threadsafe(
|
||||
transcriber._transcript_queue.put_nowait,
|
||||
TranscriptResult(text=full_text, is_vad_end=False),
|
||||
)
|
||||
|
||||
def on_recognized(evt: Any) -> None:
|
||||
if evt.result.text and transcriber._loop and not transcriber._closed:
|
||||
if transcriber._accumulated_transcript:
|
||||
transcriber._accumulated_transcript += " " + evt.result.text
|
||||
else:
|
||||
transcriber._accumulated_transcript = evt.result.text
|
||||
transcriber._loop.call_soon_threadsafe(
|
||||
transcriber._transcript_queue.put_nowait,
|
||||
TranscriptResult(
|
||||
text=transcriber._accumulated_transcript, is_vad_end=True
|
||||
),
|
||||
)
|
||||
|
||||
self._recognizer.recognizing.connect(on_recognizing)
|
||||
self._recognizer.recognized.connect(on_recognized)
|
||||
self._recognizer.start_continuous_recognition_async()
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send audio chunk to Azure."""
|
||||
if self._audio_stream and not self._closed:
|
||||
self._audio_stream.write(self._resample_pcm16(chunk))
|
||||
|
||||
def _resample_pcm16(self, data: bytes) -> bytes:
|
||||
"""Resample PCM16 audio from input_sample_rate to target_sample_rate."""
|
||||
if self.input_sample_rate == self.target_sample_rate:
|
||||
return data
|
||||
|
||||
num_samples = len(data) // 2
|
||||
if num_samples == 0:
|
||||
return b""
|
||||
|
||||
samples = list(struct.unpack(f"<{num_samples}h", data))
|
||||
ratio = self.input_sample_rate / self.target_sample_rate
|
||||
new_length = int(num_samples / ratio)
|
||||
|
||||
resampled: list[int] = []
|
||||
for i in range(new_length):
|
||||
src_idx = i * ratio
|
||||
idx_floor = int(src_idx)
|
||||
idx_ceil = min(idx_floor + 1, num_samples - 1)
|
||||
frac = src_idx - idx_floor
|
||||
sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
|
||||
sample = max(-32768, min(32767, sample))
|
||||
resampled.append(sample)
|
||||
|
||||
return struct.pack(f"<{len(resampled)}h", *resampled)
|
||||
|
||||
async def receive_transcript(self) -> TranscriptResult | None:
|
||||
"""Receive next transcript."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return TranscriptResult(text="", is_vad_end=False)
|
||||
|
||||
async def close(self) -> str:
|
||||
"""Stop recognition and return final transcript."""
|
||||
self._closed = True
|
||||
if self._recognizer:
|
||||
self._recognizer.stop_continuous_recognition_async()
|
||||
if self._audio_stream:
|
||||
self._audio_stream.close()
|
||||
self._loop = None
|
||||
return self._accumulated_transcript
|
||||
|
||||
def reset_transcript(self) -> None:
|
||||
"""Reset accumulated transcript."""
|
||||
self._accumulated_transcript = ""
|
||||
|
||||
|
||||
class AzureStreamingSynthesizer(StreamingSynthesizerProtocol):
|
||||
"""Real-time streaming TTS using Azure Speech SDK."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
region: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
voice: str = "en-US-JennyNeural",
|
||||
speed: float = 1.0,
|
||||
):
|
||||
self._logger = setup_logger()
|
||||
self.api_key = api_key
|
||||
self.region = region
|
||||
self.endpoint = endpoint
|
||||
self.voice = voice
|
||||
self.speed = max(0.5, min(2.0, speed))
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._synthesizer: Any = None
|
||||
self._closed = False
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize Azure Speech synthesizer with push stream."""
|
||||
try:
|
||||
import azure.cognitiveservices.speech as speechsdk
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Azure Speech SDK is required for streaming TTS. "
|
||||
"Install `azure-cognitiveservices-speech`."
|
||||
) from e
|
||||
|
||||
self._logger.info("AzureStreamingSynthesizer: connecting")
|
||||
|
||||
# Store the event loop for thread-safe queue operations
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
# Use endpoint for self-hosted containers, region for Azure cloud
|
||||
if self.endpoint:
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key,
|
||||
endpoint=self.endpoint,
|
||||
)
|
||||
else:
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key,
|
||||
region=self.region,
|
||||
)
|
||||
speech_config.speech_synthesis_voice_name = self.voice
|
||||
# Use MP3 format for streaming - compatible with MediaSource Extensions
|
||||
speech_config.set_speech_synthesis_output_format(
|
||||
speechsdk.SpeechSynthesisOutputFormat.Audio16Khz64KBitRateMonoMp3
|
||||
)
|
||||
|
||||
# Create synthesizer with pull audio output stream
|
||||
self._synthesizer = speechsdk.SpeechSynthesizer(
|
||||
speech_config=speech_config,
|
||||
audio_config=None, # We'll manually handle audio
|
||||
)
|
||||
|
||||
# Connect to synthesis events
|
||||
self._synthesizer.synthesizing.connect(self._on_synthesizing)
|
||||
self._synthesizer.synthesis_completed.connect(self._on_completed)
|
||||
|
||||
self._logger.info("AzureStreamingSynthesizer: connected")
|
||||
|
||||
def _on_synthesizing(self, evt: Any) -> None:
|
||||
"""Called when audio chunk is available (runs in Azure SDK thread)."""
|
||||
if evt.result.audio_data and self._loop and not self._closed:
|
||||
# Thread-safe way to put item in async queue
|
||||
self._loop.call_soon_threadsafe(
|
||||
self._audio_queue.put_nowait, evt.result.audio_data
|
||||
)
|
||||
|
||||
def _on_completed(self, _evt: Any) -> None:
|
||||
"""Called when synthesis is complete (runs in Azure SDK thread)."""
|
||||
if self._loop and not self._closed:
|
||||
self._loop.call_soon_threadsafe(self._audio_queue.put_nowait, None)
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Send text to be synthesized using SSML for prosody control."""
|
||||
if self._synthesizer and not self._closed:
|
||||
# Build SSML with prosody for speed control
|
||||
rate = f"{int((self.speed - 1) * 100):+d}%"
|
||||
escaped_text = escape(text)
|
||||
ssml = f"""<speak version='1.0' xmlns='{SSML_NAMESPACE}' xml:lang='en-US'>
|
||||
<voice name={quoteattr(self.voice)}>
|
||||
<prosody rate='{rate}'>{escaped_text}</prosody>
|
||||
</voice>
|
||||
</speak>"""
|
||||
# Use speak_ssml_async for SSML support (includes speed/prosody)
|
||||
self._synthesizer.speak_ssml_async(ssml)
|
||||
|
||||
async def receive_audio(self) -> bytes | None:
|
||||
"""Receive next audio chunk."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return b"" # No audio yet, but not done
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Signal end of text input - wait for pending audio."""
|
||||
# Azure SDK handles flushing automatically
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the session."""
|
||||
self._closed = True
|
||||
if self._synthesizer:
|
||||
self._synthesizer.synthesis_completed.disconnect_all()
|
||||
self._synthesizer.synthesizing.disconnect_all()
|
||||
self._loop = None
|
||||
|
||||
|
||||
class AzureVoiceProvider(VoiceProviderInterface):
|
||||
"""Azure Speech Services voice provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None,
|
||||
api_base: str | None,
|
||||
custom_config: dict[str, Any],
|
||||
stt_model: str | None = None,
|
||||
tts_model: str | None = None,
|
||||
default_voice: str | None = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.custom_config = custom_config
|
||||
raw_speech_region = (
|
||||
custom_config.get("speech_region")
|
||||
or self._extract_speech_region_from_uri(api_base)
|
||||
or ""
|
||||
)
|
||||
self.speech_region = self._validate_speech_region(raw_speech_region)
|
||||
self.stt_model = stt_model
|
||||
self.tts_model = tts_model
|
||||
self.default_voice = default_voice or "en-US-JennyNeural"
|
||||
|
||||
@staticmethod
|
||||
def _is_azure_cloud_url(uri: str | None) -> bool:
|
||||
"""Check if URI is an Azure cloud endpoint (vs custom/self-hosted)."""
|
||||
if not uri:
|
||||
return False
|
||||
try:
|
||||
hostname = (urlparse(uri).hostname or "").lower()
|
||||
except ValueError:
|
||||
return False
|
||||
return hostname.endswith(
|
||||
(
|
||||
".speech.microsoft.com",
|
||||
".api.cognitive.microsoft.com",
|
||||
".cognitiveservices.azure.com",
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_speech_region_from_uri(uri: str | None) -> str | None:
|
||||
"""Extract Azure speech region from endpoint URI.
|
||||
|
||||
Note: Custom domains (*.cognitiveservices.azure.com) contain the resource
|
||||
name, not the region. For custom domains, the region must be specified
|
||||
explicitly via custom_config["speech_region"].
|
||||
"""
|
||||
if not uri:
|
||||
return None
|
||||
# Accepted examples:
|
||||
# - https://eastus.tts.speech.microsoft.com/cognitiveservices/v1
|
||||
# - https://eastus.stt.speech.microsoft.com/speech/recognition/...
|
||||
# - https://westus.api.cognitive.microsoft.com/
|
||||
#
|
||||
# NOT supported (requires explicit speech_region config):
|
||||
# - https://<resource>.cognitiveservices.azure.com/ (resource name != region)
|
||||
try:
|
||||
hostname = (urlparse(uri).hostname or "").lower()
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
stt_tts_match = re.match(
|
||||
r"^([a-z0-9-]+)\.(?:tts|stt)\.speech\.microsoft\.com$", hostname
|
||||
)
|
||||
if stt_tts_match:
|
||||
return stt_tts_match.group(1)
|
||||
|
||||
api_match = re.match(
|
||||
r"^([a-z0-9-]+)\.api\.cognitive\.microsoft\.com$", hostname
|
||||
)
|
||||
if api_match:
|
||||
return api_match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _validate_speech_region(speech_region: str) -> str:
|
||||
normalized_region = speech_region.strip().lower()
|
||||
if not normalized_region:
|
||||
return ""
|
||||
if not re.fullmatch(r"[a-z0-9-]+", normalized_region):
|
||||
raise ValueError(
|
||||
"Invalid Azure speech_region. Use lowercase letters, digits, and hyphens only."
|
||||
)
|
||||
return normalized_region
|
||||
|
||||
def _get_stt_url(self) -> str:
|
||||
"""Get the STT endpoint URL (auto-detects cloud vs self-hosted)."""
|
||||
if self.api_base and not self._is_azure_cloud_url(self.api_base):
|
||||
# Self-hosted container endpoint
|
||||
return f"{self.api_base.rstrip('/')}/speech/recognition/conversation/cognitiveservices/v1"
|
||||
# Azure cloud endpoint
|
||||
return (
|
||||
f"https://{self.speech_region}.stt.speech.microsoft.com/"
|
||||
"speech/recognition/conversation/cognitiveservices/v1"
|
||||
)
|
||||
|
||||
def _get_tts_url(self) -> str:
|
||||
"""Get the TTS endpoint URL (auto-detects cloud vs self-hosted)."""
|
||||
if self.api_base and not self._is_azure_cloud_url(self.api_base):
|
||||
# Self-hosted container endpoint
|
||||
return f"{self.api_base.rstrip('/')}/cognitiveservices/v1"
|
||||
# Azure cloud endpoint
|
||||
return f"https://{self.speech_region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
|
||||
def _is_self_hosted(self) -> bool:
|
||||
"""Check if using self-hosted container vs Azure cloud."""
|
||||
return bool(self.api_base and not self._is_azure_cloud_url(self.api_base))
|
||||
|
||||
@staticmethod
|
||||
def _pcm16_to_wav(pcm_data: bytes, sample_rate: int = 24000) -> bytes:
|
||||
"""Wrap raw PCM16 mono bytes into a WAV container."""
|
||||
buffer = io.BytesIO()
|
||||
with wave.open(buffer, "wb") as wav_file:
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setframerate(sample_rate)
|
||||
wav_file.writeframes(pcm_data)
|
||||
return buffer.getvalue()
|
||||
|
||||
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure API key required for STT")
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError("Azure speech region required for STT (cloud mode)")
|
||||
|
||||
normalized_format = audio_format.lower()
|
||||
payload = audio_data
|
||||
content_type = f"audio/{normalized_format}"
|
||||
|
||||
# WebSocket chunked fallback sends raw PCM16 bytes.
|
||||
if normalized_format in {"pcm", "pcm16", "raw"}:
|
||||
payload = self._pcm16_to_wav(audio_data, sample_rate=24000)
|
||||
content_type = "audio/wav"
|
||||
elif normalized_format in {"wav", "wave"}:
|
||||
content_type = "audio/wav"
|
||||
elif normalized_format == "webm":
|
||||
content_type = "audio/webm; codecs=opus"
|
||||
|
||||
url = self._get_stt_url()
|
||||
params = {"language": "en-US", "format": "detailed"}
|
||||
headers = {
|
||||
"Ocp-Apim-Subscription-Key": self.api_key,
|
||||
"Content-Type": content_type,
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url, params=params, headers=headers, data=payload
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(f"Azure STT failed: {error_text}")
|
||||
result = await response.json()
|
||||
|
||||
if result.get("RecognitionStatus") != "Success":
|
||||
return ""
|
||||
nbest = result.get("NBest") or []
|
||||
if nbest and isinstance(nbest, list):
|
||||
display = nbest[0].get("Display")
|
||||
if isinstance(display, str):
|
||||
return display
|
||||
display_text = result.get("DisplayText", "")
|
||||
return display_text if isinstance(display_text, str) else ""
|
||||
|
||||
async def synthesize_stream(
|
||||
self, text: str, voice: str | None = None, speed: float = 1.0
|
||||
) -> AsyncIterator[bytes]:
|
||||
"""
|
||||
Convert text to audio using Azure TTS with streaming.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech
|
||||
voice: Voice name (defaults to provider's default voice)
|
||||
speed: Playback speed multiplier (0.5 to 2.0)
|
||||
|
||||
Yields:
|
||||
Audio data chunks (mp3 format)
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure API key required for TTS")
|
||||
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError("Azure speech region required for TTS (cloud mode)")
|
||||
|
||||
voice_name = voice or self.default_voice
|
||||
|
||||
# Clamp speed to valid range and convert to rate format
|
||||
speed = max(0.5, min(2.0, speed))
|
||||
rate = f"{int((speed - 1) * 100):+d}%" # e.g., 1.0 -> "+0%", 1.5 -> "+50%"
|
||||
|
||||
# Build SSML with escaped text and quoted attributes to prevent injection
|
||||
escaped_text = escape(text)
|
||||
ssml = f"""<speak version='1.0' xmlns='{SSML_NAMESPACE}' xml:lang='en-US'>
|
||||
<voice name={quoteattr(voice_name)}>
|
||||
<prosody rate='{rate}'>{escaped_text}</prosody>
|
||||
</voice>
|
||||
</speak>"""
|
||||
|
||||
url = self._get_tts_url()
|
||||
|
||||
headers = {
|
||||
"Ocp-Apim-Subscription-Key": self.api_key,
|
||||
"Content-Type": "application/ssml+xml",
|
||||
"X-Microsoft-OutputFormat": "audio-16khz-128kbitrate-mono-mp3",
|
||||
"User-Agent": "Onyx",
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, data=ssml) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(f"Azure TTS failed: {error_text}")
|
||||
|
||||
# Use 8192 byte chunks for smoother streaming
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
async def validate_credentials(self) -> None:
|
||||
"""Validate Azure credentials by listing available voices."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure API key required")
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError("Azure speech region required (cloud mode)")
|
||||
|
||||
url = f"https://{self.speech_region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
|
||||
if self._is_self_hosted():
|
||||
url = f"{(self.api_base or '').rstrip('/')}/cognitiveservices/voices/list"
|
||||
|
||||
headers = {"Ocp-Apim-Subscription-Key": self.api_key}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status in (401, 403):
|
||||
raise RuntimeError("Invalid Azure API key.")
|
||||
if response.status != 200:
|
||||
raise RuntimeError("Azure credential validation failed.")
|
||||
|
||||
def get_available_voices(self) -> list[dict[str, str]]:
|
||||
"""Return common Azure Neural voices."""
|
||||
return AZURE_VOICES.copy()
|
||||
|
||||
def get_available_stt_models(self) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"id": "default", "name": "Azure Speech Recognition"},
|
||||
]
|
||||
|
||||
def get_available_tts_models(self) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"id": "neural", "name": "Neural TTS"},
|
||||
]
|
||||
|
||||
def supports_streaming_stt(self) -> bool:
|
||||
"""Azure supports streaming STT via Speech SDK."""
|
||||
return True
|
||||
|
||||
def supports_streaming_tts(self) -> bool:
|
||||
"""Azure supports real-time streaming TTS via Speech SDK."""
|
||||
return True
|
||||
|
||||
async def create_streaming_transcriber(
|
||||
self, _audio_format: str = "webm"
|
||||
) -> AzureStreamingTranscriber:
|
||||
"""Create a streaming transcription session."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming transcription")
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError(
|
||||
"Speech region required for Azure streaming transcription (cloud mode)"
|
||||
)
|
||||
|
||||
# Use endpoint for self-hosted, region for cloud
|
||||
transcriber = AzureStreamingTranscriber(
|
||||
api_key=self.api_key,
|
||||
region=self.speech_region if not self._is_self_hosted() else None,
|
||||
endpoint=self.api_base if self._is_self_hosted() else None,
|
||||
input_sample_rate=24000,
|
||||
target_sample_rate=16000,
|
||||
)
|
||||
await transcriber.connect()
|
||||
return transcriber
|
||||
|
||||
async def create_streaming_synthesizer(
|
||||
self, voice: str | None = None, speed: float = 1.0
|
||||
) -> AzureStreamingSynthesizer:
|
||||
"""Create a streaming TTS session."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming TTS")
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError(
|
||||
"Speech region required for Azure streaming TTS (cloud mode)"
|
||||
)
|
||||
|
||||
# Use endpoint for self-hosted, region for cloud
|
||||
synthesizer = AzureStreamingSynthesizer(
|
||||
api_key=self.api_key,
|
||||
region=self.speech_region if not self._is_self_hosted() else None,
|
||||
endpoint=self.api_base if self._is_self_hosted() else None,
|
||||
voice=voice or self.default_voice or "en-US-JennyNeural",
|
||||
speed=speed,
|
||||
)
|
||||
await synthesizer.connect()
|
||||
return synthesizer
|
||||
@@ -1,876 +0,0 @@
|
||||
"""ElevenLabs voice provider for STT and TTS.
|
||||
|
||||
ElevenLabs supports:
|
||||
- **STT**: Scribe API (batch via REST, streaming via WebSocket with Scribe v2 Realtime).
|
||||
The streaming endpoint sends base64-encoded PCM16 audio chunks and receives JSON
|
||||
transcript messages (partial_transcript, committed_transcript, utterance_end).
|
||||
- **TTS**: Text-to-speech via REST streaming and WebSocket stream-input.
|
||||
The WebSocket variant accepts incremental text chunks and returns audio in order,
|
||||
enabling low-latency playback before the full text is available.
|
||||
|
||||
See https://elevenlabs.io/docs for API reference.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
# Default ElevenLabs API base URL
|
||||
DEFAULT_ELEVENLABS_API_BASE = "https://api.elevenlabs.io"
|
||||
|
||||
# Default sample rates for STT streaming
|
||||
DEFAULT_INPUT_SAMPLE_RATE = 24000 # What the browser frontend sends
|
||||
DEFAULT_TARGET_SAMPLE_RATE = 16000 # What ElevenLabs Scribe expects
|
||||
|
||||
# Default streaming TTS output format
|
||||
DEFAULT_TTS_OUTPUT_FORMAT = "mp3_44100_64"
|
||||
|
||||
# Default TTS voice settings
|
||||
DEFAULT_VOICE_STABILITY = 0.5
|
||||
DEFAULT_VOICE_SIMILARITY_BOOST = 0.75
|
||||
|
||||
# Chunk length schedule for streaming TTS (optimized for real-time playback)
|
||||
DEFAULT_CHUNK_LENGTH_SCHEDULE = [120, 160, 250, 290]
|
||||
|
||||
# Default STT streaming VAD configuration
|
||||
DEFAULT_VAD_SILENCE_THRESHOLD_SECS = 1.0
|
||||
DEFAULT_VAD_THRESHOLD = 0.4
|
||||
DEFAULT_MIN_SPEECH_DURATION_MS = 100
|
||||
DEFAULT_MIN_SILENCE_DURATION_MS = 300
|
||||
|
||||
|
||||
class ElevenLabsSTTMessageType(StrEnum):
|
||||
"""Message types from ElevenLabs Scribe Realtime STT API."""
|
||||
|
||||
SESSION_STARTED = "session_started"
|
||||
PARTIAL_TRANSCRIPT = "partial_transcript"
|
||||
COMMITTED_TRANSCRIPT = "committed_transcript"
|
||||
UTTERANCE_END = "utterance_end"
|
||||
SESSION_ENDED = "session_ended"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class ElevenLabsTTSMessageType(StrEnum):
|
||||
"""Message types from ElevenLabs stream-input TTS API."""
|
||||
|
||||
AUDIO = "audio"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
def _http_to_ws_url(http_url: str) -> str:
|
||||
"""Convert http(s) URL to ws(s) URL for WebSocket connections."""
|
||||
if http_url.startswith("https://"):
|
||||
return "wss://" + http_url[8:]
|
||||
elif http_url.startswith("http://"):
|
||||
return "ws://" + http_url[7:]
|
||||
return http_url
|
||||
|
||||
|
||||
# Common ElevenLabs voices
|
||||
ELEVENLABS_VOICES = [
|
||||
{"id": "21m00Tcm4TlvDq8ikWAM", "name": "Rachel"},
|
||||
{"id": "AZnzlk1XvdvUeBnXmlld", "name": "Domi"},
|
||||
{"id": "EXAVITQu4vr4xnSDxMaL", "name": "Bella"},
|
||||
{"id": "ErXwobaYiN019PkySvjV", "name": "Antoni"},
|
||||
{"id": "MF3mGyEYCl7XYWbV9V6O", "name": "Elli"},
|
||||
{"id": "TxGEqnHWrfWFTfGW9XjX", "name": "Josh"},
|
||||
{"id": "VR6AewLTigWG4xSOukaG", "name": "Arnold"},
|
||||
{"id": "pNInz6obpgDQGcFmaJgB", "name": "Adam"},
|
||||
{"id": "yoZ06aMxZJJ28mfd3POQ", "name": "Sam"},
|
||||
]
|
||||
|
||||
|
||||
class ElevenLabsStreamingTranscriber(StreamingTranscriberProtocol):
|
||||
"""Streaming transcription session using ElevenLabs Scribe Realtime API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
model: str = "scribe_v2_realtime",
|
||||
input_sample_rate: int = DEFAULT_INPUT_SAMPLE_RATE,
|
||||
target_sample_rate: int = DEFAULT_TARGET_SAMPLE_RATE,
|
||||
language_code: str = "en",
|
||||
api_base: str | None = None,
|
||||
):
|
||||
# Import logger first
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
self._logger = setup_logger()
|
||||
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: initializing with model {model}"
|
||||
)
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.input_sample_rate = input_sample_rate
|
||||
self.target_sample_rate = target_sample_rate
|
||||
self.language_code = language_code
|
||||
self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
|
||||
self._final_transcript = ""
|
||||
self._receive_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish WebSocket connection to ElevenLabs."""
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingTranscriber: connecting to ElevenLabs API"
|
||||
)
|
||||
self._session = aiohttp.ClientSession()
|
||||
|
||||
# VAD is configured via query parameters.
|
||||
# commit_strategy=vad enables automatic transcript commit on silence detection.
|
||||
# These params are part of the ElevenLabs Scribe Realtime API contract:
|
||||
# https://elevenlabs.io/docs/api-reference/speech-to-text/realtime
|
||||
ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
|
||||
url = (
|
||||
f"{ws_base}/v1/speech-to-text/realtime"
|
||||
f"?model_id={self.model}"
|
||||
f"&sample_rate={self.target_sample_rate}"
|
||||
f"&language_code={self.language_code}"
|
||||
f"&commit_strategy=vad"
|
||||
f"&vad_silence_threshold_secs={DEFAULT_VAD_SILENCE_THRESHOLD_SECS}"
|
||||
f"&vad_threshold={DEFAULT_VAD_THRESHOLD}"
|
||||
f"&min_speech_duration_ms={DEFAULT_MIN_SPEECH_DURATION_MS}"
|
||||
f"&min_silence_duration_ms={DEFAULT_MIN_SILENCE_DURATION_MS}"
|
||||
)
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: connecting to {url} "
|
||||
f"(input={self.input_sample_rate}Hz, target={self.target_sample_rate}Hz)"
|
||||
)
|
||||
|
||||
try:
|
||||
self._ws = await self._session.ws_connect(
|
||||
url,
|
||||
headers={"xi-api-key": self.api_key},
|
||||
)
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: connected successfully, "
|
||||
f"ws.closed={self._ws.closed}, close_code={self._ws.close_code}"
|
||||
)
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: failed to connect: {e}"
|
||||
)
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
raise
|
||||
|
||||
# Start receiving transcripts in background
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
"""Background task to receive transcripts from WebSocket."""
|
||||
self._logger.info("ElevenLabsStreamingTranscriber: receive loop started")
|
||||
if not self._ws:
|
||||
self._logger.warning(
|
||||
"ElevenLabsStreamingTranscriber: no WebSocket connection"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
self._logger.debug(
|
||||
f"ElevenLabsStreamingTranscriber: raw message type: {msg.type}"
|
||||
)
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
parsed_data: Any = None
|
||||
data: dict[str, Any]
|
||||
try:
|
||||
parsed_data = json.loads(msg.data)
|
||||
except json.JSONDecodeError:
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: failed to parse JSON: {msg.data[:200]}"
|
||||
)
|
||||
continue
|
||||
if not isinstance(parsed_data, dict):
|
||||
self._logger.error(
|
||||
"ElevenLabsStreamingTranscriber: expected object JSON payload"
|
||||
)
|
||||
continue
|
||||
data = parsed_data
|
||||
|
||||
# ElevenLabs uses message_type field - fail fast if missing
|
||||
if "message_type" not in data and "type" not in data:
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: malformed packet missing 'message_type' field: {data}"
|
||||
)
|
||||
continue
|
||||
msg_type = data.get("message_type", data.get("type", ""))
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: received message_type: '{msg_type}', data keys: {list(data.keys())}"
|
||||
)
|
||||
# Check for error in various formats
|
||||
if "error" in data or msg_type == ElevenLabsSTTMessageType.ERROR:
|
||||
error_msg = data.get("error", data.get("message", data))
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: API error: {error_msg}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Handle message types from ElevenLabs Scribe Realtime API.
|
||||
# See https://elevenlabs.io/docs/api-reference/speech-to-text/realtime
|
||||
if msg_type == ElevenLabsSTTMessageType.SESSION_STARTED:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: session started, "
|
||||
f"id={data.get('session_id')}, config={data.get('config')}"
|
||||
)
|
||||
elif msg_type == ElevenLabsSTTMessageType.PARTIAL_TRANSCRIPT:
|
||||
# Interim result — updated as more audio is processed
|
||||
text = data.get("text", "")
|
||||
if text:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: partial_transcript: {text[:50]}..."
|
||||
)
|
||||
self._final_transcript = text
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(text=text, is_vad_end=False)
|
||||
)
|
||||
elif msg_type == ElevenLabsSTTMessageType.COMMITTED_TRANSCRIPT:
|
||||
# Final transcript for the current utterance (VAD detected end)
|
||||
text = data.get("text", "")
|
||||
if text:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: committed_transcript: {text[:50]}..."
|
||||
)
|
||||
self._final_transcript = text
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(text=text, is_vad_end=True)
|
||||
)
|
||||
elif msg_type == ElevenLabsSTTMessageType.UTTERANCE_END:
|
||||
# VAD detected end of speech (may carry text or be empty)
|
||||
text = data.get("text", "") or self._final_transcript
|
||||
if text:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: utterance_end: {text[:50]}..."
|
||||
)
|
||||
self._final_transcript = text
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(text=text, is_vad_end=True)
|
||||
)
|
||||
elif msg_type == ElevenLabsSTTMessageType.SESSION_ENDED:
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingTranscriber: session ended"
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Log unhandled message types with full data for debugging
|
||||
self._logger.warning(
|
||||
f"ElevenLabsStreamingTranscriber: unhandled message_type: {msg_type}, full data: {data}"
|
||||
)
|
||||
elif msg.type == aiohttp.WSMsgType.BINARY:
|
||||
self._logger.debug(
|
||||
f"ElevenLabsStreamingTranscriber: received binary message: {len(msg.data)} bytes"
|
||||
)
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
close_code = self._ws.close_code if self._ws else "N/A"
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingTranscriber: WebSocket closed by "
|
||||
f"server, close_code={close_code}"
|
||||
)
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: WebSocket error: {self._ws.exception() if self._ws else 'N/A'}"
|
||||
)
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSE:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: WebSocket CLOSE frame received, data={msg.data}, extra={msg.extra}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: error in receive loop: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
close_code = self._ws.close_code if self._ws else "N/A"
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: receive loop ended, close_code={close_code}"
|
||||
)
|
||||
await self._transcript_queue.put(None) # Signal end
|
||||
|
||||
def _resample_pcm16(self, data: bytes) -> bytes:
|
||||
"""Resample PCM16 audio from input_sample_rate to target_sample_rate."""
|
||||
import struct
|
||||
|
||||
if self.input_sample_rate == self.target_sample_rate:
|
||||
return data
|
||||
|
||||
# Parse int16 samples
|
||||
num_samples = len(data) // 2
|
||||
samples = list(struct.unpack(f"<{num_samples}h", data))
|
||||
|
||||
# Calculate resampling ratio
|
||||
ratio = self.input_sample_rate / self.target_sample_rate
|
||||
new_length = int(num_samples / ratio)
|
||||
|
||||
# Linear interpolation resampling
|
||||
resampled = []
|
||||
for i in range(new_length):
|
||||
src_idx = i * ratio
|
||||
idx_floor = int(src_idx)
|
||||
idx_ceil = min(idx_floor + 1, num_samples - 1)
|
||||
frac = src_idx - idx_floor
|
||||
sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
|
||||
# Clamp to int16 range
|
||||
sample = max(-32768, min(32767, sample))
|
||||
resampled.append(sample)
|
||||
|
||||
return struct.pack(f"<{len(resampled)}h", *resampled)
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send an audio chunk for transcription."""
|
||||
if not self._ws:
|
||||
self._logger.warning("send_audio: no WebSocket connection")
|
||||
return
|
||||
if self._closed:
|
||||
self._logger.warning("send_audio: transcriber is closed")
|
||||
return
|
||||
if self._ws.closed:
|
||||
self._logger.warning(
|
||||
f"send_audio: WebSocket is closed, close_code={self._ws.close_code}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Resample from input rate (24kHz) to target rate (16kHz)
|
||||
resampled = self._resample_pcm16(chunk)
|
||||
# ElevenLabs expects input_audio_chunk message format with audio_base_64
|
||||
audio_b64 = base64.b64encode(resampled).decode("utf-8")
|
||||
message = {
|
||||
"message_type": "input_audio_chunk",
|
||||
"audio_base_64": audio_b64,
|
||||
"sample_rate": self.target_sample_rate,
|
||||
}
|
||||
self._logger.info(
|
||||
f"send_audio: {len(chunk)} bytes -> {len(resampled)} bytes (resampled) -> {len(audio_b64)} chars base64"
|
||||
)
|
||||
await self._ws.send_str(json.dumps(message))
|
||||
self._logger.info("send_audio: message sent successfully")
|
||||
except Exception as e:
|
||||
self._logger.error(f"send_audio: failed to send: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def receive_transcript(self) -> TranscriptResult | None:
|
||||
"""Receive next transcript. Returns None when done."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return TranscriptResult(
|
||||
text="", is_vad_end=False
|
||||
) # No transcript yet, but not done
|
||||
|
||||
async def close(self) -> str:
|
||||
"""Close the session and return final transcript."""
|
||||
self._logger.info("ElevenLabsStreamingTranscriber: closing session")
|
||||
self._closed = True
|
||||
if self._ws and not self._ws.closed:
|
||||
try:
|
||||
# Just close the WebSocket - ElevenLabs Scribe doesn't need a special end message
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingTranscriber: closing WebSocket connection"
|
||||
)
|
||||
await self._ws.close()
|
||||
except Exception as e:
|
||||
self._logger.debug(f"Error closing WebSocket: {e}")
|
||||
if self._receive_task and not self._receive_task.done():
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
return self._final_transcript
|
||||
|
||||
def reset_transcript(self) -> None:
|
||||
"""Reset accumulated transcript. Call after auto-send to start fresh."""
|
||||
self._final_transcript = ""
|
||||
|
||||
|
||||
class ElevenLabsStreamingSynthesizer(StreamingSynthesizerProtocol):
    """Real-time streaming TTS using ElevenLabs WebSocket API.

    Uses ElevenLabs' stream-input WebSocket which processes text as one
    continuous stream and returns audio in order.

    Lifecycle: connect() -> send_text() any number of times ->
    flush() to force generation of buffered text -> receive_audio()
    until it yields None -> close(). A ``None`` placed on the internal
    audio queue is the end-of-stream sentinel.
    """

    def __init__(
        self,
        api_key: str,
        voice_id: str,
        model_id: str = "eleven_multilingual_v2",
        output_format: str = "mp3_44100_64",
        api_base: str | None = None,
        speed: float = 1.0,
    ):
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()
        self.api_key = api_key
        self.voice_id = voice_id
        self.model_id = model_id
        self.output_format = output_format
        self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
        self.speed = speed
        # WebSocket / HTTP session are created lazily in connect().
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._session: aiohttp.ClientSession | None = None
        # Audio chunks from the receive loop; None marks end of stream.
        self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
        self._receive_task: asyncio.Task | None = None
        # Set by close(); checked by send_text/flush and the receive loop.
        self._closed = False

    async def connect(self) -> None:
        """Establish WebSocket connection to ElevenLabs TTS."""
        self._logger.info("ElevenLabsStreamingSynthesizer: connecting")
        self._session = aiohttp.ClientSession()

        # WebSocket URL for streaming input TTS with output format for streaming compatibility
        # Using mp3_44100_64 for good quality with smaller chunks for real-time playback
        ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
        url = (
            f"{ws_base}/v1/text-to-speech/{self.voice_id}/stream-input"
            f"?model_id={self.model_id}&output_format={self.output_format}"
        )

        self._ws = await self._session.ws_connect(
            url,
            headers={"xi-api-key": self.api_key},
        )

        # Send initial configuration with generation settings optimized for streaming.
        # Note: API key is sent via header only (not in body to avoid log exposure).
        # See https://elevenlabs.io/docs/api-reference/text-to-speech/stream-input
        await self._ws.send_str(
            json.dumps(
                {
                    "text": " ",  # Initial space to start the stream
                    "voice_settings": {
                        "stability": DEFAULT_VOICE_STABILITY,
                        "similarity_boost": DEFAULT_VOICE_SIMILARITY_BOOST,
                        "speed": self.speed,
                    },
                    "generation_config": {
                        "chunk_length_schedule": DEFAULT_CHUNK_LENGTH_SCHEDULE,
                    },
                }
            )
        )

        # Start receiving audio in background
        self._receive_task = asyncio.create_task(self._receive_loop())
        self._logger.info("ElevenLabsStreamingSynthesizer: connected")

    async def _receive_loop(self) -> None:
        """Background task to receive audio chunks from WebSocket.

        Audio is returned in order as one continuous stream.
        Puts decoded audio bytes on the queue; puts None (end sentinel)
        when isFinal arrives or the socket closes.
        """
        if not self._ws:
            return

        chunk_count = 0
        total_bytes = 0
        try:
            async for msg in self._ws:
                if self._closed:
                    self._logger.info(
                        "ElevenLabsStreamingSynthesizer: closed flag set, stopping "
                        "receive loop"
                    )
                    break
                if msg.type == aiohttp.WSMsgType.TEXT:
                    data = json.loads(msg.data)
                    # Process audio if present (base64-encoded in JSON messages)
                    if "audio" in data and data["audio"]:
                        audio_bytes = base64.b64decode(data["audio"])
                        chunk_count += 1
                        total_bytes += len(audio_bytes)
                        await self._audio_queue.put(audio_bytes)

                    # Check isFinal separately - a message can have both audio AND isFinal
                    if "isFinal" in data:
                        self._logger.info(
                            f"ElevenLabsStreamingSynthesizer: received isFinal={data['isFinal']}, "
                            f"chunks so far: {chunk_count}, bytes: {total_bytes}"
                        )
                        if data.get("isFinal"):
                            self._logger.info(
                                "ElevenLabsStreamingSynthesizer: isFinal=true, signaling end of audio"
                            )
                            await self._audio_queue.put(None)

                    # Check for errors (logged only; the stream keeps running)
                    if "error" in data or data.get("type") == "error":
                        self._logger.error(
                            f"ElevenLabsStreamingSynthesizer: received error: {data}"
                        )
                elif msg.type == aiohttp.WSMsgType.BINARY:
                    # Some responses arrive as raw binary frames
                    chunk_count += 1
                    total_bytes += len(msg.data)
                    await self._audio_queue.put(msg.data)
                elif msg.type in (
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.ERROR,
                ):
                    self._logger.info(
                        f"ElevenLabsStreamingSynthesizer: WebSocket closed/error, type={msg.type}"
                    )
                    break
        except Exception as e:
            self._logger.error(f"ElevenLabsStreamingSynthesizer receive error: {e}")
        finally:
            self._logger.info(
                f"ElevenLabsStreamingSynthesizer: receive loop ended, {chunk_count} chunks, {total_bytes} bytes"
            )
            await self._audio_queue.put(None)  # Signal end of stream

    async def send_text(self, text: str) -> None:
        """Send text to be synthesized.

        ElevenLabs processes text as a continuous stream and returns
        audio in order. We let ElevenLabs handle buffering via chunk_length_schedule
        and only force generation when flush() is called at the end.

        Args:
            text: Text to synthesize
        """
        if self._ws and not self._closed and text.strip():
            self._logger.info(
                f"ElevenLabsStreamingSynthesizer: sending text ({len(text)} chars): '{text}'"
            )
            # Let ElevenLabs buffer and auto-generate based on chunk_length_schedule
            # Don't trigger generation here - wait for flush() at the end
            await self._ws.send_str(
                json.dumps(
                    {
                        "text": text + " ",  # Space for natural speech flow
                    }
                )
            )
            self._logger.info("ElevenLabsStreamingSynthesizer: text sent successfully")
        else:
            # Silently-dropped text would be hard to debug, so log why.
            self._logger.warning(
                f"ElevenLabsStreamingSynthesizer: skipping send_text - "
                f"ws={self._ws is not None}, closed={self._closed}, text='{text[:30] if text else ''}'"
            )

    async def receive_audio(self) -> bytes | None:
        """Receive next audio chunk.

        Returns b"" when no audio is ready yet (poll again), None when
        the stream has ended.
        """
        try:
            return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return b""  # No audio yet, but not done

    async def flush(self) -> None:
        """Signal end of text input. ElevenLabs will generate remaining audio and close."""
        if self._ws and not self._closed:
            # Send empty string to signal end of input
            # ElevenLabs will generate any remaining buffered text,
            # send all audio chunks, send isFinal, then close the connection
            self._logger.info(
                "ElevenLabsStreamingSynthesizer: sending end-of-input (empty string)"
            )
            await self._ws.send_str(json.dumps({"text": ""}))
            self._logger.info("ElevenLabsStreamingSynthesizer: end-of-input sent")
        else:
            self._logger.warning(
                f"ElevenLabsStreamingSynthesizer: skipping flush - "
                f"ws={self._ws is not None}, closed={self._closed}"
            )

    async def close(self) -> None:
        """Close the session.

        Sets the closed flag first so the receive loop exits, then tears
        down the WebSocket, the background task, and the HTTP session.
        """
        self._closed = True
        if self._ws:
            await self._ws.close()
        if self._receive_task:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass
        if self._session:
            await self._session.close()
|
||||
|
||||
|
||||
# Valid ElevenLabs model IDs.
# STT: "scribe_v1" is the batch model, "scribe_v2_realtime" the streaming one;
# TTS: any unrecognized model ID falls back to "eleven_multilingual_v2"
# in ElevenLabsVoiceProvider.
ELEVENLABS_STT_MODELS = {
    "scribe_v1",
    "scribe_v2_realtime",
}
ELEVENLABS_TTS_MODELS = {
    "eleven_flash_v2",
    "eleven_flash_v2_5",
    "eleven_monolingual_v1",
    "eleven_multilingual_v2",
    "eleven_turbo_v2_5",
}
|
||||
|
||||
|
||||
class ElevenLabsVoiceProvider(VoiceProviderInterface):
    """ElevenLabs voice provider.

    Offers batch transcription (REST), streaming transcription
    (Scribe realtime WebSocket), batch/streaming TTS, and credential
    validation against the ElevenLabs API.
    """

    def __init__(
        self,
        api_key: str | None,
        api_base: str | None = None,
        stt_model: str | None = None,
        tts_model: str | None = None,
        default_voice: str | None = None,
    ):
        self.api_key = api_key
        self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
        # Validate and default models - use valid ElevenLabs model IDs;
        # unknown IDs silently fall back to the defaults below.
        self.stt_model = (
            stt_model if stt_model in ELEVENLABS_STT_MODELS else "scribe_v1"
        )
        self.tts_model = (
            tts_model
            if tts_model in ELEVENLABS_TTS_MODELS
            else "eleven_multilingual_v2"
        )
        self.default_voice = default_voice

    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        """
        Transcribe audio using ElevenLabs Speech-to-Text API.

        Args:
            audio_data: Raw audio bytes
            audio_format: Format of the audio (e.g., 'webm', 'mp3', 'wav')

        Returns:
            Transcribed text

        Raises:
            ValueError: if no API key is configured.
            RuntimeError: if the API returns a non-200 status.
        """
        if not self.api_key:
            raise ValueError("ElevenLabs API key required for transcription")

        from onyx.utils.logger import setup_logger

        logger = setup_logger()

        url = f"{self.api_base}/v1/speech-to-text"

        # Map common formats to MIME types; unknown formats get a
        # best-effort "audio/<format>" content type.
        mime_types = {
            "webm": "audio/webm",
            "mp3": "audio/mpeg",
            "wav": "audio/wav",
            "ogg": "audio/ogg",
            "flac": "audio/flac",
            "m4a": "audio/mp4",
        }
        mime_type = mime_types.get(audio_format.lower(), f"audio/{audio_format}")

        headers = {
            "xi-api-key": self.api_key,
        }

        # ElevenLabs expects multipart form data
        # NOTE(review): the ElevenLabs speech-to-text docs name the upload
        # field "file" — confirm that "audio" is accepted by the API.
        form_data = aiohttp.FormData()
        form_data.add_field(
            "audio",
            audio_data,
            filename=f"audio.{audio_format}",
            content_type=mime_type,
        )
        # For batch STT, use scribe_v1 (not the realtime model)
        batch_model = (
            self.stt_model if self.stt_model in ("scribe_v1",) else "scribe_v1"
        )
        form_data.add_field("model_id", batch_model)

        logger.info(
            f"ElevenLabs transcribe: sending {len(audio_data)} bytes, format={audio_format}"
        )

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, data=form_data) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"ElevenLabs transcribe failed: {error_text}")
                    raise RuntimeError(f"ElevenLabs transcription failed: {error_text}")

                result = await response.json()
                text = result.get("text", "")
                logger.info(f"ElevenLabs transcribe: got result: {text[:50]}...")
                return text

    async def synthesize_stream(
        self, text: str, voice: str | None = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio using ElevenLabs TTS with streaming.

        Args:
            text: Text to convert to speech
            voice: Voice ID (defaults to provider's default voice or Rachel)
            speed: Playback speed multiplier

        Yields:
            Audio data chunks (mp3 format)

        Raises:
            ValueError: if no API key is configured.
            RuntimeError: if the API returns a non-200 status.
        """
        from onyx.utils.logger import setup_logger

        logger = setup_logger()

        if not self.api_key:
            raise ValueError("ElevenLabs API key required for TTS")

        voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM"  # Rachel

        url = f"{self.api_base}/v1/text-to-speech/{voice_id}/stream"

        logger.info(
            f"ElevenLabs TTS: starting synthesis, text='{text[:50]}...', "
            f"voice={voice_id}, model={self.tts_model}, speed={speed}"
        )

        headers = {
            "xi-api-key": self.api_key,
            "Content-Type": "application/json",
            "Accept": "audio/mpeg",
        }

        payload = {
            "text": text,
            "model_id": self.tts_model,
            "voice_settings": {
                "stability": DEFAULT_VOICE_STABILITY,
                "similarity_boost": DEFAULT_VOICE_SIMILARITY_BOOST,
                "speed": speed,
            },
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, json=payload) as response:
                logger.info(
                    f"ElevenLabs TTS: got response status={response.status}, "
                    f"content-type={response.headers.get('content-type')}"
                )
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"ElevenLabs TTS failed: {error_text}")
                    raise RuntimeError(f"ElevenLabs TTS failed: {error_text}")

                # Use 8192 byte chunks for smoother streaming
                chunk_count = 0
                total_bytes = 0
                async for chunk in response.content.iter_chunked(8192):
                    if chunk:
                        chunk_count += 1
                        total_bytes += len(chunk)
                        yield chunk
                logger.info(
                    f"ElevenLabs TTS: streaming complete, {chunk_count} chunks, "
                    f"{total_bytes} total bytes"
                )

    async def validate_credentials(self) -> None:
        """Validate ElevenLabs API key.

        Calls /v1/models as a lightweight check. ElevenLabs returns 401 for
        both truly invalid keys and valid keys with restricted scopes, so we
        inspect the response body: a "missing_permissions" status means the
        key authenticated successfully but lacks a specific scope.

        Raises:
            ValueError: if no API key is configured.
            RuntimeError: if the key is invalid or validation fails.
        """
        if not self.api_key:
            raise ValueError("ElevenLabs API key required")

        headers = {"xi-api-key": self.api_key}
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"{self.api_base}/v1/models", headers=headers
            ) as response:
                if response.status == 200:
                    return
                if response.status in (401, 403):
                    try:
                        body = await response.json()
                        detail = body.get("detail", {})
                        status = (
                            detail.get("status", "") if isinstance(detail, dict) else ""
                        )
                    except Exception:
                        # Non-JSON or malformed body: treat as unknown status
                        status = ""
                    # "missing_permissions" means the key is valid but
                    # lacks this specific scope — that's fine.
                    if status == "missing_permissions":
                        return
                    raise RuntimeError("Invalid ElevenLabs API key.")
                raise RuntimeError("ElevenLabs credential validation failed.")

    def get_available_voices(self) -> list[dict[str, str]]:
        """Return common ElevenLabs voices."""
        # Copy so callers can't mutate the module-level list.
        return ELEVENLABS_VOICES.copy()

    def get_available_stt_models(self) -> list[dict[str, str]]:
        # id/name pairs surfaced to the UI for model selection
        return [
            {"id": "scribe_v2_realtime", "name": "Scribe v2 Realtime (Streaming)"},
            {"id": "scribe_v1", "name": "Scribe v1 (Batch)"},
        ]

    def get_available_tts_models(self) -> list[dict[str, str]]:
        # id/name pairs surfaced to the UI for model selection
        return [
            {"id": "eleven_multilingual_v2", "name": "Multilingual v2"},
            {"id": "eleven_turbo_v2_5", "name": "Turbo v2.5"},
            {"id": "eleven_monolingual_v1", "name": "Monolingual v1"},
        ]

    def supports_streaming_stt(self) -> bool:
        """ElevenLabs supports streaming via Scribe Realtime API."""
        return True

    def supports_streaming_tts(self) -> bool:
        """ElevenLabs supports real-time streaming TTS via WebSocket."""
        return True

    async def create_streaming_transcriber(
        self, _audio_format: str = "webm"
    ) -> ElevenLabsStreamingTranscriber:
        """Create a streaming transcription session (already connected)."""
        if not self.api_key:
            raise ValueError("API key required for streaming transcription")
        # ElevenLabs realtime STT requires scribe_v2_realtime model.
        # Frontend sends PCM16 at DEFAULT_INPUT_SAMPLE_RATE (24kHz),
        # but ElevenLabs expects DEFAULT_TARGET_SAMPLE_RATE (16kHz).
        # The transcriber resamples automatically.
        transcriber = ElevenLabsStreamingTranscriber(
            api_key=self.api_key,
            model="scribe_v2_realtime",
            input_sample_rate=DEFAULT_INPUT_SAMPLE_RATE,
            target_sample_rate=DEFAULT_TARGET_SAMPLE_RATE,
            language_code="en",
            api_base=self.api_base,
        )
        await transcriber.connect()
        return transcriber

    async def create_streaming_synthesizer(
        self, voice: str | None = None, speed: float = 1.0
    ) -> ElevenLabsStreamingSynthesizer:
        """Create a streaming TTS session (already connected)."""
        if not self.api_key:
            raise ValueError("API key required for streaming TTS")
        # Fall back to the provider default, then Rachel.
        voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM"
        synthesizer = ElevenLabsStreamingSynthesizer(
            api_key=self.api_key,
            voice_id=voice_id,
            model_id=self.tts_model,
            output_format=DEFAULT_TTS_OUTPUT_FORMAT,
            api_base=self.api_base,
            speed=speed,
        )
        await synthesizer.connect()
        return synthesizer
|
||||
@@ -1,633 +0,0 @@
|
||||
"""OpenAI voice provider for STT and TTS.
|
||||
|
||||
OpenAI supports:
|
||||
- **STT**: Whisper (batch transcription via REST) and Realtime API (streaming
|
||||
transcription via WebSocket with server-side VAD). Audio is sent as base64-encoded
|
||||
PCM16 at 24kHz mono. The Realtime API returns transcript deltas and completed
|
||||
transcription events per VAD-detected utterance.
|
||||
- **TTS**: HTTP streaming endpoint that returns audio chunks progressively.
|
||||
Supported models: tts-1 (standard) and tts-1-hd (high quality).
|
||||
|
||||
See https://platform.openai.com/docs for API reference.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Default OpenAI API base URL
# Used whenever a caller does not supply an explicit api_base.
DEFAULT_OPENAI_API_BASE = "https://api.openai.com"
|
||||
|
||||
|
||||
class OpenAIRealtimeMessageType(StrEnum):
    """Message types from OpenAI Realtime transcription API.

    Values mirror the ``type`` field of JSON events received over the
    Realtime WebSocket; the transcriber's receive loop matches on them
    to route error, VAD, and transcription events.
    """

    # Server-reported error event
    ERROR = "error"
    # Server VAD detected the start of speech in the input buffer
    SPEECH_STARTED = "input_audio_buffer.speech_started"
    # Server VAD detected silence / end of speech
    SPEECH_STOPPED = "input_audio_buffer.speech_stopped"
    # The input audio buffer was committed for transcription
    BUFFER_COMMITTED = "input_audio_buffer.committed"
    # Incremental transcript text for the current turn
    TRANSCRIPTION_DELTA = "conversation.item.input_audio_transcription.delta"
    # Final transcript for a VAD-detected turn
    TRANSCRIPTION_COMPLETED = "conversation.item.input_audio_transcription.completed"
    # Session lifecycle / bookkeeping events (ignored by the receive loop)
    SESSION_CREATED = "transcription_session.created"
    SESSION_UPDATED = "transcription_session.updated"
    ITEM_CREATED = "conversation.item.created"
|
||||
|
||||
|
||||
def _http_to_ws_url(http_url: str) -> str:
|
||||
"""Convert http(s) URL to ws(s) URL for WebSocket connections."""
|
||||
if http_url.startswith("https://"):
|
||||
return "wss://" + http_url[8:]
|
||||
elif http_url.startswith("http://"):
|
||||
return "ws://" + http_url[7:]
|
||||
return http_url
|
||||
|
||||
|
||||
class OpenAIStreamingTranscriber(StreamingTranscriberProtocol):
    """Streaming transcription using OpenAI Realtime API.

    Sends base64-encoded PCM16 audio over a WebSocket; server-side VAD
    segments speech into turns. Transcript deltas and per-turn completions
    are pushed onto an internal queue consumed via receive_transcript().
    A ``None`` on the queue marks the end of the stream.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "whisper-1",
        api_base: str | None = None,
    ):
        # Import logger first
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()

        self._logger.info(
            f"OpenAIStreamingTranscriber: initializing with model {model}"
        )
        self.api_key = api_key
        self.model = model
        self.api_base = api_base or DEFAULT_OPENAI_API_BASE
        # WebSocket / HTTP session are created in connect().
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._session: aiohttp.ClientSession | None = None
        # Results queue; None is the end-of-stream sentinel.
        self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
        self._current_turn_transcript = ""  # Transcript for current VAD turn
        self._accumulated_transcript = ""  # Accumulated across all turns
        self._receive_task: asyncio.Task | None = None
        self._closed = False

    async def connect(self) -> None:
        """Establish WebSocket connection to OpenAI Realtime API."""
        self._session = aiohttp.ClientSession()

        # OpenAI Realtime transcription endpoint
        ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
        url = f"{ws_base}/v1/realtime?intent=transcription"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "OpenAI-Beta": "realtime=v1",
        }

        try:
            self._ws = await self._session.ws_connect(url, headers=headers)
            self._logger.info("Connected to OpenAI Realtime API")
        except Exception as e:
            self._logger.error(f"Failed to connect to OpenAI Realtime API: {e}")
            raise

        # Configure the session for transcription
        # Enable server-side VAD (Voice Activity Detection) for automatic speech detection
        config_message = {
            "type": "transcription_session.update",
            "session": {
                "input_audio_format": "pcm16",  # 16-bit PCM at 24kHz mono
                "input_audio_transcription": {
                    "model": self.model,
                },
                "turn_detection": {
                    "type": "server_vad",
                    "threshold": 0.5,
                    "prefix_padding_ms": 300,
                    "silence_duration_ms": 500,
                },
            },
        }
        await self._ws.send_str(json.dumps(config_message))
        self._logger.info(f"Sent config for model: {self.model} with server VAD")

        # Start receiving transcripts
        self._receive_task = asyncio.create_task(self._receive_loop())

    async def _receive_loop(self) -> None:
        """Background task to receive transcripts.

        Routes Realtime events: deltas extend the current turn, completed
        events fold the turn into the accumulated transcript with
        is_vad_end=True (which triggers auto-send upstream).
        """
        if not self._ws:
            return

        try:
            async for msg in self._ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    data = json.loads(msg.data)
                    msg_type = data.get("type", "")
                    self._logger.debug(f"Received message type: {msg_type}")

                    # Handle errors
                    if msg_type == OpenAIRealtimeMessageType.ERROR:
                        error = data.get("error", {})
                        self._logger.error(f"OpenAI error: {error}")
                        continue

                    # Handle VAD events
                    if msg_type == OpenAIRealtimeMessageType.SPEECH_STARTED:
                        self._logger.info("OpenAI: Speech started")
                        # Reset current turn transcript for new speech
                        self._current_turn_transcript = ""
                        continue
                    elif msg_type == OpenAIRealtimeMessageType.SPEECH_STOPPED:
                        self._logger.info(
                            "OpenAI: Speech stopped (VAD detected silence)"
                        )
                        continue
                    elif msg_type == OpenAIRealtimeMessageType.BUFFER_COMMITTED:
                        self._logger.info("OpenAI: Audio buffer committed")
                        continue

                    # Handle transcription events
                    if msg_type == OpenAIRealtimeMessageType.TRANSCRIPTION_DELTA:
                        delta = data.get("delta", "")
                        if delta:
                            self._logger.info(f"OpenAI: Transcription delta: {delta}")
                            self._current_turn_transcript += delta
                            # Show accumulated + current turn transcript
                            full_transcript = self._accumulated_transcript
                            if full_transcript and self._current_turn_transcript:
                                full_transcript += " "
                            full_transcript += self._current_turn_transcript
                            await self._transcript_queue.put(
                                TranscriptResult(text=full_transcript, is_vad_end=False)
                            )
                    elif msg_type == OpenAIRealtimeMessageType.TRANSCRIPTION_COMPLETED:
                        transcript = data.get("transcript", "")
                        if transcript:
                            self._logger.info(
                                f"OpenAI: Transcription completed (VAD turn end): {transcript[:50]}..."
                            )
                            # This is the final transcript for this VAD turn
                            self._current_turn_transcript = transcript
                            # Accumulate this turn's transcript
                            if self._accumulated_transcript:
                                self._accumulated_transcript += " " + transcript
                            else:
                                self._accumulated_transcript = transcript
                            # Send with is_vad_end=True to trigger auto-send
                            await self._transcript_queue.put(
                                TranscriptResult(
                                    text=self._accumulated_transcript,
                                    is_vad_end=True,
                                )
                            )
                    elif msg_type not in (
                        OpenAIRealtimeMessageType.SESSION_CREATED,
                        OpenAIRealtimeMessageType.SESSION_UPDATED,
                        OpenAIRealtimeMessageType.ITEM_CREATED,
                    ):
                        # Log any other message types we might be missing
                        self._logger.info(
                            f"OpenAI: Unhandled message type '{msg_type}': {data}"
                        )

                elif msg.type == aiohttp.WSMsgType.ERROR:
                    self._logger.error(f"WebSocket error: {self._ws.exception()}")
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    self._logger.info("WebSocket closed by server")
                    break
        except Exception as e:
            self._logger.error(f"Error in receive loop: {e}")
        finally:
            # End-of-stream sentinel for receive_transcript() consumers
            await self._transcript_queue.put(None)

    async def send_audio(self, chunk: bytes) -> None:
        """Send audio chunk to OpenAI."""
        if self._ws and not self._closed:
            # OpenAI expects base64-encoded PCM16 audio at 24kHz mono
            # PCM16 at 24kHz: 24000 samples/sec * 2 bytes/sample = 48000 bytes/sec
            # So chunk_bytes / 48000 = duration in seconds
            duration_ms = (len(chunk) / 48000) * 1000
            self._logger.debug(
                f"Sending {len(chunk)} bytes ({duration_ms:.1f}ms) of audio to OpenAI. "
                f"First 10 bytes: {chunk[:10].hex() if len(chunk) >= 10 else chunk.hex()}"
            )
            message = {
                "type": "input_audio_buffer.append",
                "audio": base64.b64encode(chunk).decode("utf-8"),
            }
            await self._ws.send_str(json.dumps(message))

    def reset_transcript(self) -> None:
        """Reset accumulated transcript. Call after auto-send to start fresh."""
        self._logger.info("OpenAI: Resetting accumulated transcript")
        self._accumulated_transcript = ""
        self._current_turn_transcript = ""

    async def receive_transcript(self) -> TranscriptResult | None:
        """Receive next transcript.

        Returns an empty (non-VAD-end) result on timeout so callers can
        poll without blocking.
        """
        try:
            return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return TranscriptResult(text="", is_vad_end=False)

    async def close(self) -> str:
        """Close session and return final transcript."""
        self._closed = True
        if self._ws:
            # With server VAD, the buffer is auto-committed when speech stops.
            # But we should still commit any remaining audio and wait for transcription.
            try:
                await self._ws.send_str(
                    json.dumps({"type": "input_audio_buffer.commit"})
                )
            except Exception as e:
                self._logger.debug(f"Error sending commit (may be expected): {e}")

            # Wait for *new* transcription to arrive (up to 5 seconds)
            self._logger.info("Waiting for transcription to complete...")
            transcript_before_commit = self._accumulated_transcript
            for _ in range(50):  # 50 * 100ms = 5 seconds max
                await asyncio.sleep(0.1)
                if self._accumulated_transcript != transcript_before_commit:
                    self._logger.info(
                        f"Got final transcript: {self._accumulated_transcript[:50]}..."
                    )
                    break
            else:
                self._logger.warning("Timed out waiting for transcription")

            await self._ws.close()
        if self._receive_task:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass
        if self._session:
            await self._session.close()
        return self._accumulated_transcript
|
||||
|
||||
|
||||
# OpenAI available voices for TTS
OPENAI_VOICES = [
    {"id": voice_id, "name": voice_name}
    for voice_id, voice_name in (
        ("alloy", "Alloy"),
        ("echo", "Echo"),
        ("fable", "Fable"),
        ("onyx", "Onyx"),
        ("nova", "Nova"),
        ("shimmer", "Shimmer"),
    )
]

# OpenAI available STT models (all support streaming via Realtime API)
OPENAI_STT_MODELS = [
    {"id": model_id, "name": model_name}
    for model_id, model_name in (
        ("whisper-1", "Whisper v1"),
        ("gpt-4o-transcribe", "GPT-4o Transcribe"),
        ("gpt-4o-mini-transcribe", "GPT-4o Mini Transcribe"),
    )
]

# OpenAI available TTS models
OPENAI_TTS_MODELS = [
    {"id": model_id, "name": model_name}
    for model_id, model_name in (
        ("tts-1", "TTS-1 (Standard)"),
        ("tts-1-hd", "TTS-1 HD (High Quality)"),
    )
]
|
||||
|
||||
|
||||
def _create_wav_header(
|
||||
data_length: int,
|
||||
sample_rate: int = 24000,
|
||||
channels: int = 1,
|
||||
bits_per_sample: int = 16,
|
||||
) -> bytes:
|
||||
"""Create a WAV file header for PCM audio data."""
|
||||
import struct
|
||||
|
||||
byte_rate = sample_rate * channels * bits_per_sample // 8
|
||||
block_align = channels * bits_per_sample // 8
|
||||
|
||||
# WAV header is 44 bytes
|
||||
header = struct.pack(
|
||||
"<4sI4s4sIHHIIHH4sI",
|
||||
b"RIFF", # ChunkID
|
||||
36 + data_length, # ChunkSize
|
||||
b"WAVE", # Format
|
||||
b"fmt ", # Subchunk1ID
|
||||
16, # Subchunk1Size (PCM)
|
||||
1, # AudioFormat (1 = PCM)
|
||||
channels, # NumChannels
|
||||
sample_rate, # SampleRate
|
||||
byte_rate, # ByteRate
|
||||
block_align, # BlockAlign
|
||||
bits_per_sample, # BitsPerSample
|
||||
b"data", # Subchunk2ID
|
||||
data_length, # Subchunk2Size
|
||||
)
|
||||
return header
|
||||
|
||||
|
||||
class OpenAIStreamingSynthesizer(StreamingSynthesizerProtocol):
|
||||
"""Streaming TTS using OpenAI HTTP TTS API with streaming responses."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
voice: str = "alloy",
|
||||
model: str = "tts-1",
|
||||
speed: float = 1.0,
|
||||
api_base: str | None = None,
|
||||
):
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
self._logger = setup_logger()
|
||||
self.api_key = api_key
|
||||
self.voice = voice
|
||||
self.model = model
|
||||
self.speed = max(0.25, min(4.0, speed))
|
||||
self.api_base = api_base or DEFAULT_OPENAI_API_BASE
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._text_queue: asyncio.Queue[str | None] = asyncio.Queue()
|
||||
self._synthesis_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
self._flushed = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize HTTP session for TTS requests."""
|
||||
self._logger.info("OpenAIStreamingSynthesizer: connecting")
|
||||
self._session = aiohttp.ClientSession()
|
||||
# Start background task to process text queue
|
||||
self._synthesis_task = asyncio.create_task(self._process_text_queue())
|
||||
self._logger.info("OpenAIStreamingSynthesizer: connected")
|
||||
|
||||
async def _process_text_queue(self) -> None:
|
||||
"""Background task to process queued text for synthesis."""
|
||||
while not self._closed:
|
||||
try:
|
||||
text = await asyncio.wait_for(self._text_queue.get(), timeout=0.1)
|
||||
if text is None:
|
||||
break
|
||||
await self._synthesize_text(text)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error processing text queue: {e}")
|
||||
|
||||
async def _synthesize_text(self, text: str) -> None:
|
||||
"""Make HTTP TTS request and stream audio to queue."""
|
||||
if not self._session or self._closed:
|
||||
return
|
||||
|
||||
url = f"{self.api_base.rstrip('/')}/v1/audio/speech"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"voice": self.voice,
|
||||
"input": text,
|
||||
"speed": self.speed,
|
||||
"response_format": "mp3",
|
||||
}
|
||||
|
||||
try:
|
||||
async with self._session.post(
|
||||
url, headers=headers, json=payload
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
self._logger.error(f"OpenAI TTS error: {error_text}")
|
||||
return
|
||||
|
||||
# Use 8192 byte chunks for smoother streaming
|
||||
# (larger chunks = more complete MP3 frames, better playback)
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
if self._closed:
|
||||
break
|
||||
if chunk:
|
||||
await self._audio_queue.put(chunk)
|
||||
except Exception as e:
|
||||
self._logger.error(f"OpenAIStreamingSynthesizer synthesis error: {e}")
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Queue text to be synthesized via HTTP streaming."""
|
||||
if not text.strip() or self._closed:
|
||||
return
|
||||
await self._text_queue.put(text)
|
||||
|
||||
async def receive_audio(self) -> bytes | None:
|
||||
"""Receive next audio chunk (MP3 format)."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return b"" # No audio yet, but not done
|
||||
|
||||
async def flush(self) -> None:
    """Signal end of text input and wait for synthesis to complete.

    Idempotent: only the first call enqueues the sentinels and waits on
    the synthesis task. Ends by putting the None sentinel on the audio
    queue so consumers know the stream is finished.
    """
    if self._flushed:
        return
    self._flushed = True

    # None is the end-of-input sentinel for the text worker.
    await self._text_queue.put(None)

    # Give the synthesis task up to 60s to drain the remaining text.
    task = self._synthesis_task
    if task and not task.done():
        try:
            await asyncio.wait_for(task, timeout=60.0)
        except asyncio.TimeoutError:
            self._logger.warning("OpenAIStreamingSynthesizer: flush timeout")
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        except asyncio.CancelledError:
            pass

    # Tell audio consumers the stream is complete.
    await self._audio_queue.put(None)
|
||||
async def close(self) -> None:
    """Cancel synthesis and release the HTTP session (idempotent)."""
    if self._closed:
        return
    self._closed = True

    # If flush() already ran, the sentinels are in the queues already.
    if not self._flushed:
        await self._text_queue.put(None)
        await self._audio_queue.put(None)

    task = self._synthesis_task
    if task and not task.done():
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    if self._session:
        await self._session.close()
||||
|
||||
class OpenAIVoiceProvider(VoiceProviderInterface):
    """OpenAI voice provider using Whisper for STT and TTS API for speech synthesis."""

    def __init__(
        self,
        api_key: str | None,
        api_base: str | None = None,
        stt_model: str | None = None,
        tts_model: str | None = None,
        default_voice: str | None = None,
    ):
        """Store credentials and model choices, falling back to OpenAI defaults.

        Args:
            api_key: OpenAI API key (may be None; streaming factories reject it).
            api_base: Custom endpoint (e.g. a proxy); None uses the OpenAI default.
            stt_model: Speech-to-text model; defaults to "whisper-1".
            tts_model: Text-to-speech model; defaults to "tts-1".
            default_voice: Voice used when a caller does not pick one; "alloy".
        """
        self.api_key = api_key
        self.api_base = api_base
        # Falsy values (None, "") fall back to the OpenAI defaults.
        self.stt_model = stt_model or "whisper-1"
        self.tts_model = tts_model or "tts-1"
        self.default_voice = default_voice or "alloy"

        # Lazily-created client; see _get_client.
        self._client: "AsyncOpenAI | None" = None

    def _get_client(self) -> "AsyncOpenAI":
        # Import lazily so the openai package is only required once a
        # provider method is actually used.
        if self._client is None:
            from openai import AsyncOpenAI

            self._client = AsyncOpenAI(
                api_key=self.api_key,
                base_url=self.api_base,
            )
        return self._client

    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        """
        Transcribe audio using OpenAI Whisper.

        Args:
            audio_data: Raw audio bytes
            audio_format: Audio format (e.g., "webm", "wav", "mp3")

        Returns:
            Transcribed text
        """
        client = self._get_client()

        # Create a file-like object from the audio bytes; the .name
        # attribute tells the API which container format to expect.
        audio_file = io.BytesIO(audio_data)
        audio_file.name = f"audio.{audio_format}"

        response = await client.audio.transcriptions.create(
            model=self.stt_model,
            file=audio_file,
        )

        return response.text

    async def synthesize_stream(
        self, text: str, voice: str | None = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio using OpenAI TTS with streaming.

        Args:
            text: Text to convert to speech
            voice: Voice identifier (defaults to provider's default voice)
            speed: Playback speed multiplier (0.25 to 4.0)

        Yields:
            Audio data chunks (mp3 format)
        """
        client = self._get_client()

        # Clamp speed to the range the OpenAI API accepts.
        speed = max(0.25, min(4.0, speed))

        # Use with_streaming_response for proper async streaming.
        # Using 8192 byte chunks for better streaming performance
        # (larger chunks = fewer round-trips, more complete MP3 frames).
        async with client.audio.speech.with_streaming_response.create(
            model=self.tts_model,
            voice=voice or self.default_voice,
            input=text,
            speed=speed,
            response_format="mp3",
        ) as response:
            async for chunk in response.iter_bytes(chunk_size=8192):
                yield chunk

    async def validate_credentials(self) -> None:
        """Validate OpenAI API key by listing models.

        Raises:
            RuntimeError: if the key is invalid or lacks permissions.
        """
        from openai import AuthenticationError, PermissionDeniedError

        client = self._get_client()
        try:
            await client.models.list()
        except AuthenticationError:
            raise RuntimeError("Invalid OpenAI API key.")
        except PermissionDeniedError:
            raise RuntimeError("OpenAI API key does not have sufficient permissions.")

    def get_available_voices(self) -> list[dict[str, str]]:
        """Get available OpenAI TTS voices (copy, safe to mutate)."""
        return OPENAI_VOICES.copy()

    def get_available_stt_models(self) -> list[dict[str, str]]:
        """Get available OpenAI STT models (copy, safe to mutate)."""
        return OPENAI_STT_MODELS.copy()

    def get_available_tts_models(self) -> list[dict[str, str]]:
        """Get available OpenAI TTS models (copy, safe to mutate)."""
        return OPENAI_TTS_MODELS.copy()

    def supports_streaming_stt(self) -> bool:
        """OpenAI supports streaming STT via the Realtime API for all STT models."""
        return True

    def supports_streaming_tts(self) -> bool:
        """OpenAI supports streaming TTS (served here via HTTP chunked streaming,
        see create_streaming_synthesizer)."""
        return True

    async def create_streaming_transcriber(
        self, _audio_format: str = "webm"
    ) -> OpenAIStreamingTranscriber:
        """Create a streaming transcription session using Realtime API.

        Raises:
            ValueError: if no API key is configured.
        """
        if not self.api_key:
            raise ValueError("API key required for streaming transcription")
        transcriber = OpenAIStreamingTranscriber(
            api_key=self.api_key,
            model=self.stt_model,
            api_base=self.api_base,
        )
        await transcriber.connect()
        return transcriber

    async def create_streaming_synthesizer(
        self, voice: str | None = None, speed: float = 1.0
    ) -> OpenAIStreamingSynthesizer:
        """Create a streaming TTS session using HTTP streaming API.

        Raises:
            ValueError: if no API key is configured.
        """
        if not self.api_key:
            raise ValueError("API key required for streaming TTS")
        synthesizer = OpenAIStreamingSynthesizer(
            api_key=self.api_key,
            voice=voice or self.default_voice or "alloy",
            model=self.tts_model or "tts-1",
            speed=speed,
            api_base=self.api_base,
        )
        await synthesizer.connect()
        return synthesizer
||||
@@ -67,8 +67,6 @@ attrs==25.4.0
|
||||
# zeep
|
||||
authlib==1.6.7
|
||||
# via fastmcp
|
||||
azure-cognitiveservices-speech==1.38.0
|
||||
# via onyx
|
||||
babel==2.17.0
|
||||
# via courlan
|
||||
backoff==2.2.1
|
||||
|
||||
@@ -1,43 +1,33 @@
|
||||
"""Unit tests for SharepointConnector._fetch_site_pages error handling.
|
||||
"""Unit tests for SharepointConnector._fetch_site_pages 404 handling.
|
||||
|
||||
Covers 404 handling (classic sites / no modern pages) and 400
|
||||
canvasLayout fallback (corrupt pages causing $expand=canvasLayout to
|
||||
fail on the LIST endpoint).
|
||||
The Graph Pages API returns 404 for classic sites or sites without
|
||||
modern pages enabled. _fetch_site_pages should gracefully skip these
|
||||
rather than crashing the entire indexing run.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from requests import Response
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from onyx.connectors.sharepoint.connector import GRAPH_INVALID_REQUEST_CODE
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.connectors.sharepoint.connector import SiteDescriptor
|
||||
|
||||
SITE_URL = "https://tenant.sharepoint.com/sites/ClassicSite"
|
||||
FAKE_SITE_ID = "tenant.sharepoint.com,abc123,def456"
|
||||
PAGES_COLLECTION = f"https://graph.microsoft.com/v1.0/sites/{FAKE_SITE_ID}/pages"
|
||||
SITE_PAGES_BASE = f"{PAGES_COLLECTION}/microsoft.graph.sitePage"
|
||||
|
||||
|
||||
def _site_descriptor() -> SiteDescriptor:
|
||||
return SiteDescriptor(url=SITE_URL, drive_name=None, folder_path=None)
|
||||
|
||||
|
||||
def _make_http_error(
    status_code: int,
    error_code: str = "itemNotFound",
    message: str = "Item not found",
) -> HTTPError:
    """Build an HTTPError whose response mimics a Graph API error body.

    Defaults model the itemNotFound payload Graph returns for 404s;
    callers override error_code/message to simulate other failures
    (e.g. 400 invalidRequest for the canvasLayout fallback tests).
    """
    body = {"error": {"code": error_code, "message": message}}
    response = Response()
    response.status_code = status_code
    response._content = json.dumps(body).encode()
    response.headers["Content-Type"] = "application/json"
    return HTTPError(response=response)
|
||||
|
||||
|
||||
@@ -187,139 +177,3 @@ class TestFetchSitePages404:
|
||||
pages = list(connector._fetch_site_pages(_site_descriptor()))
|
||||
assert len(pages) == 1
|
||||
assert pages[0]["id"] == "page-1"
|
||||
|
||||
|
||||
class TestFetchSitePages400Fallback:
    """When $expand=canvasLayout on the LIST endpoint returns 400
    invalidRequest, _fetch_site_pages should fall back to listing
    without expansion, then expanding each page individually."""

    # A page whose individual $expand GET succeeds.
    GOOD_PAGE: dict[str, Any] = {
        "id": "good-1",
        "name": "Good.aspx",
        "title": "Good Page",
        "lastModifiedDateTime": "2025-06-01T00:00:00Z",
    }
    # A corrupt page whose individual $expand GET also 400s.
    BAD_PAGE: dict[str, Any] = {
        "id": "bad-1",
        "name": "Bad.aspx",
        "title": "Bad Page",
        "lastModifiedDateTime": "2025-06-01T00:00:00Z",
    }
    # GOOD_PAGE as returned when $expand=canvasLayout succeeds.
    GOOD_PAGE_EXPANDED: dict[str, Any] = {
        **GOOD_PAGE,
        "canvasLayout": {"horizontalSections": []},
    }

    def test_fallback_expands_good_pages_individually(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """On 400 from the LIST expand, the connector should list without
        expand, then GET each page individually with $expand=canvasLayout."""
        connector = _setup_connector(monkeypatch)
        good_page = self.GOOD_PAGE
        bad_page = self.BAD_PAGE
        good_page_expanded = self.GOOD_PAGE_EXPANDED

        def fake_get_json(
            self: SharepointConnector,  # noqa: ARG001
            url: str,
            params: dict[str, str] | None = None,
        ) -> dict[str, Any]:
            # LIST with $expand fails -> triggers the fallback path.
            if url == SITE_PAGES_BASE and params == {"$expand": "canvasLayout"}:
                raise _make_http_error(
                    400, GRAPH_INVALID_REQUEST_CODE, "Invalid request"
                )
            # Fallback LIST without expansion returns both pages.
            if url == SITE_PAGES_BASE and params is None:
                return {"value": [good_page, bad_page]}
            expand_params = {"$expand": "canvasLayout"}
            # Per-page GETs: the good page expands, the bad page 400s again.
            if url == f"{PAGES_COLLECTION}/good-1/microsoft.graph.sitePage":
                assert params == expand_params, f"Expected $expand params, got {params}"
                return good_page_expanded
            if url == f"{PAGES_COLLECTION}/bad-1/microsoft.graph.sitePage":
                assert params == expand_params, f"Expected $expand params, got {params}"
                raise _make_http_error(
                    400, GRAPH_INVALID_REQUEST_CODE, "Invalid request"
                )
            raise AssertionError(f"Unexpected call: {url} {params}")

        _patch_graph_api_get_json(monkeypatch, fake_get_json)
        pages = list(connector._fetch_site_pages(_site_descriptor()))

        # Both pages are yielded; the bad page simply lacks canvasLayout.
        assert len(pages) == 2
        assert pages[0].get("canvasLayout") is not None
        assert pages[1].get("canvasLayout") is None
        assert pages[1]["id"] == "bad-1"

    def test_mid_pagination_400_does_not_duplicate(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """If the first paginated batch succeeds but a later nextLink
        returns 400, pages from the first batch must not be re-yielded
        by the fallback."""
        connector = _setup_connector(monkeypatch)
        good_page = self.GOOD_PAGE
        good_page_expanded = self.GOOD_PAGE_EXPANDED
        bad_page = self.BAD_PAGE
        second_page = {
            "id": "page-2",
            "name": "Second.aspx",
            "title": "Second Page",
            "lastModifiedDateTime": "2025-06-01T00:00:00Z",
        }
        next_link = "https://graph.microsoft.com/v1.0/next-page-link"

        def fake_get_json(
            self: SharepointConnector,  # noqa: ARG001
            url: str,
            params: dict[str, str] | None = None,
        ) -> dict[str, Any]:
            # First batch of the expanded LIST succeeds, pointing at a nextLink.
            if url == SITE_PAGES_BASE and params == {"$expand": "canvasLayout"}:
                return {
                    "value": [good_page],
                    "@odata.nextLink": next_link,
                }
            # ...but following the nextLink fails with 400 invalidRequest.
            if url == next_link:
                raise _make_http_error(
                    400, GRAPH_INVALID_REQUEST_CODE, "Invalid request"
                )
            # Fallback LIST (no expand) returns everything, including good-1.
            if url == SITE_PAGES_BASE and params is None:
                return {"value": [good_page, bad_page, second_page]}
            expand_params = {"$expand": "canvasLayout"}
            if url == f"{PAGES_COLLECTION}/good-1/microsoft.graph.sitePage":
                assert params == expand_params, f"Expected $expand params, got {params}"
                return good_page_expanded
            if url == f"{PAGES_COLLECTION}/bad-1/microsoft.graph.sitePage":
                assert params == expand_params, f"Expected $expand params, got {params}"
                raise _make_http_error(
                    400, GRAPH_INVALID_REQUEST_CODE, "Invalid request"
                )
            if url == f"{PAGES_COLLECTION}/page-2/microsoft.graph.sitePage":
                assert params == expand_params, f"Expected $expand params, got {params}"
                return {**second_page, "canvasLayout": {"horizontalSections": []}}
            raise AssertionError(f"Unexpected call: {url} {params}")

        _patch_graph_api_get_json(monkeypatch, fake_get_json)
        pages = list(connector._fetch_site_pages(_site_descriptor()))

        # good-1 must appear exactly once despite being present in both
        # the first expanded batch and the fallback listing.
        ids = [p["id"] for p in pages]
        assert ids == ["good-1", "bad-1", "page-2"]

    def test_non_invalid_request_400_still_raises(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A 400 with a different error code (not invalidRequest) should
        propagate, not trigger the fallback."""
        connector = _setup_connector(monkeypatch)

        def fake_get_json(
            self: SharepointConnector,  # noqa: ARG001
            url: str,  # noqa: ARG001
            params: dict[str, str] | None = None,  # noqa: ARG001
        ) -> dict[str, Any]:
            raise _make_http_error(400, "badRequest", "Something else went wrong")

        _patch_graph_api_get_json(monkeypatch, fake_get_json)

        with pytest.raises(HTTPError):
            list(connector._fetch_site_pages(_site_descriptor()))
|
||||
|
||||
@@ -1,507 +0,0 @@
|
||||
"""Unit tests for onyx.db.voice module."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.db.voice import deactivate_stt_provider
|
||||
from onyx.db.voice import deactivate_tts_provider
|
||||
from onyx.db.voice import delete_voice_provider
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.db.voice import fetch_voice_provider_by_id
|
||||
from onyx.db.voice import fetch_voice_provider_by_type
|
||||
from onyx.db.voice import fetch_voice_providers
|
||||
from onyx.db.voice import MAX_VOICE_PLAYBACK_SPEED
|
||||
from onyx.db.voice import MIN_VOICE_PLAYBACK_SPEED
|
||||
from onyx.db.voice import set_default_stt_provider
|
||||
from onyx.db.voice import set_default_tts_provider
|
||||
from onyx.db.voice import update_user_voice_settings
|
||||
from onyx.db.voice import upsert_voice_provider
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
def _make_voice_provider(
|
||||
id: int = 1,
|
||||
name: str = "Test Provider",
|
||||
provider_type: str = "openai",
|
||||
is_default_stt: bool = False,
|
||||
is_default_tts: bool = False,
|
||||
) -> VoiceProvider:
|
||||
"""Create a VoiceProvider instance for testing."""
|
||||
provider = VoiceProvider()
|
||||
provider.id = id
|
||||
provider.name = name
|
||||
provider.provider_type = provider_type
|
||||
provider.is_default_stt = is_default_stt
|
||||
provider.is_default_tts = is_default_tts
|
||||
provider.api_key = None
|
||||
provider.api_base = None
|
||||
provider.custom_config = None
|
||||
provider.stt_model = None
|
||||
provider.tts_model = None
|
||||
provider.default_voice = None
|
||||
return provider
|
||||
|
||||
|
||||
class TestFetchVoiceProviders:
|
||||
"""Tests for fetch_voice_providers."""
|
||||
|
||||
def test_returns_all_providers(self, mock_db_session: MagicMock) -> None:
|
||||
providers = [
|
||||
_make_voice_provider(id=1, name="Provider A"),
|
||||
_make_voice_provider(id=2, name="Provider B"),
|
||||
]
|
||||
mock_db_session.scalars.return_value.all.return_value = providers
|
||||
|
||||
result = fetch_voice_providers(mock_db_session)
|
||||
|
||||
assert result == providers
|
||||
mock_db_session.scalars.assert_called_once()
|
||||
|
||||
def test_returns_empty_list_when_no_providers(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
|
||||
result = fetch_voice_providers(mock_db_session)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestFetchVoiceProviderById:
|
||||
"""Tests for fetch_voice_provider_by_id."""
|
||||
|
||||
def test_returns_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
result = fetch_voice_provider_by_id(mock_db_session, 1)
|
||||
|
||||
assert result is provider
|
||||
mock_db_session.scalar.assert_called_once()
|
||||
|
||||
def test_returns_none_when_not_found(self, mock_db_session: MagicMock) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
result = fetch_voice_provider_by_id(mock_db_session, 999)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestFetchDefaultProviders:
|
||||
"""Tests for fetch_default_stt_provider and fetch_default_tts_provider."""
|
||||
|
||||
def test_fetch_default_stt_provider_returns_provider(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1, is_default_stt=True)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
result = fetch_default_stt_provider(mock_db_session)
|
||||
|
||||
assert result is provider
|
||||
|
||||
def test_fetch_default_stt_provider_returns_none_when_no_default(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
result = fetch_default_stt_provider(mock_db_session)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_fetch_default_tts_provider_returns_provider(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1, is_default_tts=True)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
result = fetch_default_tts_provider(mock_db_session)
|
||||
|
||||
assert result is provider
|
||||
|
||||
def test_fetch_default_tts_provider_returns_none_when_no_default(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
result = fetch_default_tts_provider(mock_db_session)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestFetchVoiceProviderByType:
|
||||
"""Tests for fetch_voice_provider_by_type."""
|
||||
|
||||
def test_returns_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
provider = _make_voice_provider(id=1, provider_type="openai")
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
result = fetch_voice_provider_by_type(mock_db_session, "openai")
|
||||
|
||||
assert result is provider
|
||||
|
||||
def test_returns_none_when_not_found(self, mock_db_session: MagicMock) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
result = fetch_voice_provider_by_type(mock_db_session, "nonexistent")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestUpsertVoiceProvider:
|
||||
"""Tests for upsert_voice_provider."""
|
||||
|
||||
def test_creates_new_provider_when_no_id(self, mock_db_session: MagicMock) -> None:
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=None,
|
||||
name="New Provider",
|
||||
provider_type="openai",
|
||||
api_key="test-key",
|
||||
api_key_changed=True,
|
||||
)
|
||||
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.flush.assert_called()
|
||||
added_obj = mock_db_session.add.call_args[0][0]
|
||||
assert added_obj.name == "New Provider"
|
||||
assert added_obj.provider_type == "openai"
|
||||
|
||||
def test_updates_existing_provider(self, mock_db_session: MagicMock) -> None:
|
||||
existing_provider = _make_voice_provider(id=1, name="Old Name")
|
||||
mock_db_session.scalar.return_value = existing_provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=1,
|
||||
name="Updated Name",
|
||||
provider_type="elevenlabs",
|
||||
api_key="new-key",
|
||||
api_key_changed=True,
|
||||
)
|
||||
|
||||
mock_db_session.add.assert_not_called()
|
||||
assert existing_provider.name == "Updated Name"
|
||||
assert existing_provider.provider_type == "elevenlabs"
|
||||
|
||||
def test_raises_when_provider_not_found(self, mock_db_session: MagicMock) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=999,
|
||||
name="Test",
|
||||
provider_type="openai",
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
)
|
||||
|
||||
assert "No voice provider with id 999" in str(exc_info.value)
|
||||
|
||||
def test_does_not_update_api_key_when_not_changed(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
existing_provider = _make_voice_provider(id=1)
|
||||
existing_provider.api_key = "original-key" # type: ignore[assignment]
|
||||
original_api_key = existing_provider.api_key
|
||||
mock_db_session.scalar.return_value = existing_provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=1,
|
||||
name="Test",
|
||||
provider_type="openai",
|
||||
api_key="new-key",
|
||||
api_key_changed=False,
|
||||
)
|
||||
|
||||
# api_key should remain unchanged (same object reference)
|
||||
assert existing_provider.api_key is original_api_key
|
||||
|
||||
def test_activates_stt_when_requested(self, mock_db_session: MagicMock) -> None:
|
||||
existing_provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = existing_provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
mock_db_session.execute.return_value = None
|
||||
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=1,
|
||||
name="Test",
|
||||
provider_type="openai",
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
activate_stt=True,
|
||||
)
|
||||
|
||||
assert existing_provider.is_default_stt is True
|
||||
|
||||
def test_activates_tts_when_requested(self, mock_db_session: MagicMock) -> None:
|
||||
existing_provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = existing_provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
mock_db_session.execute.return_value = None
|
||||
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=1,
|
||||
name="Test",
|
||||
provider_type="openai",
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
activate_tts=True,
|
||||
)
|
||||
|
||||
assert existing_provider.is_default_tts is True
|
||||
|
||||
|
||||
class TestDeleteVoiceProvider:
|
||||
"""Tests for delete_voice_provider."""
|
||||
|
||||
def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
delete_voice_provider(mock_db_session, 1)
|
||||
|
||||
assert provider.deleted is True
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_does_nothing_when_provider_not_found(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
delete_voice_provider(mock_db_session, 999)
|
||||
|
||||
mock_db_session.flush.assert_not_called()
|
||||
|
||||
|
||||
class TestSetDefaultProviders:
|
||||
"""Tests for set_default_stt_provider and set_default_tts_provider."""
|
||||
|
||||
def test_set_default_stt_provider_deactivates_others(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
mock_db_session.execute.return_value = None
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
result = set_default_stt_provider(db_session=mock_db_session, provider_id=1)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
assert result.is_default_stt is True
|
||||
|
||||
def test_set_default_stt_provider_raises_when_not_found(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
set_default_stt_provider(db_session=mock_db_session, provider_id=999)
|
||||
|
||||
assert "No voice provider with id 999" in str(exc_info.value)
|
||||
|
||||
def test_set_default_tts_provider_deactivates_others(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
mock_db_session.execute.return_value = None
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
result = set_default_tts_provider(db_session=mock_db_session, provider_id=1)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
assert result.is_default_tts is True
|
||||
|
||||
def test_set_default_tts_provider_updates_model_when_provided(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
mock_db_session.execute.return_value = None
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
result = set_default_tts_provider(
|
||||
db_session=mock_db_session, provider_id=1, tts_model="tts-1-hd"
|
||||
)
|
||||
|
||||
assert result.tts_model == "tts-1-hd"
|
||||
|
||||
def test_set_default_tts_provider_raises_when_not_found(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
set_default_tts_provider(db_session=mock_db_session, provider_id=999)
|
||||
|
||||
assert "No voice provider with id 999" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestDeactivateProviders:
|
||||
"""Tests for deactivate_stt_provider and deactivate_tts_provider."""
|
||||
|
||||
def test_deactivate_stt_provider_sets_false(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1, is_default_stt=True)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
result = deactivate_stt_provider(db_session=mock_db_session, provider_id=1)
|
||||
|
||||
assert result.is_default_stt is False
|
||||
|
||||
def test_deactivate_stt_provider_raises_when_not_found(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
deactivate_stt_provider(db_session=mock_db_session, provider_id=999)
|
||||
|
||||
assert "No voice provider with id 999" in str(exc_info.value)
|
||||
|
||||
def test_deactivate_tts_provider_sets_false(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1, is_default_tts=True)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
result = deactivate_tts_provider(db_session=mock_db_session, provider_id=1)
|
||||
|
||||
assert result.is_default_tts is False
|
||||
|
||||
def test_deactivate_tts_provider_raises_when_not_found(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
deactivate_tts_provider(db_session=mock_db_session, provider_id=999)
|
||||
|
||||
assert "No voice provider with id 999" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestUpdateUserVoiceSettings:
|
||||
"""Tests for update_user_voice_settings."""
|
||||
|
||||
def test_updates_auto_send(self, mock_db_session: MagicMock) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id, auto_send=True)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_updates_auto_playback(self, mock_db_session: MagicMock) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id, auto_playback=True)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_updates_playback_speed_within_range(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id, playback_speed=1.5)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
def test_clamps_playback_speed_to_min(self, mock_db_session: MagicMock) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id, playback_speed=0.1)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
stmt = mock_db_session.execute.call_args[0][0]
|
||||
compiled = stmt.compile(compile_kwargs={"literal_binds": True})
|
||||
assert str(MIN_VOICE_PLAYBACK_SPEED) in str(compiled)
|
||||
|
||||
def test_clamps_playback_speed_to_max(self, mock_db_session: MagicMock) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id, playback_speed=5.0)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
stmt = mock_db_session.execute.call_args[0][0]
|
||||
compiled = stmt.compile(compile_kwargs={"literal_binds": True})
|
||||
assert str(MAX_VOICE_PLAYBACK_SPEED) in str(compiled)
|
||||
|
||||
def test_updates_multiple_settings(self, mock_db_session: MagicMock) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(
|
||||
mock_db_session,
|
||||
user_id,
|
||||
auto_send=True,
|
||||
auto_playback=False,
|
||||
playback_speed=1.25,
|
||||
)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_does_nothing_when_no_settings_provided(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id)
|
||||
|
||||
mock_db_session.execute.assert_not_called()
|
||||
mock_db_session.flush.assert_not_called()
|
||||
|
||||
|
||||
class TestSpeedClampingLogic:
|
||||
"""Tests for the speed clamping constants and logic."""
|
||||
|
||||
def test_min_speed_constant(self) -> None:
|
||||
assert MIN_VOICE_PLAYBACK_SPEED == 0.5
|
||||
|
||||
def test_max_speed_constant(self) -> None:
|
||||
assert MAX_VOICE_PLAYBACK_SPEED == 2.0
|
||||
|
||||
def test_clamping_formula(self) -> None:
|
||||
"""Verify the clamping formula used in update_user_voice_settings."""
|
||||
test_cases = [
|
||||
(0.1, MIN_VOICE_PLAYBACK_SPEED),
|
||||
(0.5, 0.5),
|
||||
(1.0, 1.0),
|
||||
(1.5, 1.5),
|
||||
(2.0, 2.0),
|
||||
(3.0, MAX_VOICE_PLAYBACK_SPEED),
|
||||
]
|
||||
for speed, expected in test_cases:
|
||||
clamped = max(
|
||||
MIN_VOICE_PLAYBACK_SPEED, min(MAX_VOICE_PLAYBACK_SPEED, speed)
|
||||
)
|
||||
assert (
|
||||
clamped == expected
|
||||
), f"speed={speed} expected={expected} got={clamped}"
|
||||
@@ -1,23 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.voice.api import _validate_voice_api_base
|
||||
|
||||
|
||||
def test_validate_voice_api_base_blocks_private_for_non_azure() -> None:
|
||||
with pytest.raises(OnyxError, match="Invalid target URI"):
|
||||
_validate_voice_api_base("openai", "http://127.0.0.1:11434")
|
||||
|
||||
|
||||
def test_validate_voice_api_base_allows_private_for_azure() -> None:
|
||||
validated = _validate_voice_api_base("azure", "http://127.0.0.1:5000")
|
||||
assert validated == "http://127.0.0.1:5000"
|
||||
|
||||
|
||||
def test_validate_voice_api_base_blocks_metadata_for_azure() -> None:
|
||||
with pytest.raises(OnyxError, match="Invalid target URI"):
|
||||
_validate_voice_api_base("azure", "http://metadata.google.internal/")
|
||||
|
||||
|
||||
def test_validate_voice_api_base_returns_none_for_none() -> None:
|
||||
assert _validate_voice_api_base("openai", None) is None
|
||||
@@ -14,7 +14,6 @@ from onyx.utils.url import _is_ip_private_or_reserved
|
||||
from onyx.utils.url import _validate_and_resolve_url
|
||||
from onyx.utils.url import ssrf_safe_get
|
||||
from onyx.utils.url import SSRFException
|
||||
from onyx.utils.url import validate_outbound_http_url
|
||||
|
||||
|
||||
class TestIsIpPrivateOrReserved:
|
||||
@@ -306,22 +305,3 @@ class TestSsrfSafeGet:
|
||||
|
||||
call_args = mock_get.call_args
|
||||
assert call_args[1]["timeout"] == (5, 15)
|
||||
|
||||
|
||||
class TestValidateOutboundHttpUrl:
|
||||
def test_rejects_private_ip_by_default(self) -> None:
|
||||
with pytest.raises(SSRFException, match="internal/private IP"):
|
||||
validate_outbound_http_url("http://127.0.0.1:8000")
|
||||
|
||||
def test_allows_private_ip_when_explicitly_enabled(self) -> None:
|
||||
validated_url = validate_outbound_http_url(
|
||||
"http://127.0.0.1:8000", allow_private_network=True
|
||||
)
|
||||
assert validated_url == "http://127.0.0.1:8000"
|
||||
|
||||
def test_blocks_metadata_hostname_when_private_is_enabled(self) -> None:
|
||||
with pytest.raises(SSRFException, match="not allowed"):
|
||||
validate_outbound_http_url(
|
||||
"http://metadata.google.internal/latest",
|
||||
allow_private_network=True,
|
||||
)
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from onyx.voice.providers.azure import AzureVoiceProvider
|
||||
|
||||
|
||||
def test_azure_provider_extracts_region_from_target_uri() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key",
|
||||
api_base="https://westus.api.cognitive.microsoft.com/",
|
||||
custom_config={},
|
||||
)
|
||||
assert provider.speech_region == "westus"
|
||||
|
||||
|
||||
def test_azure_provider_normalizes_uppercase_region() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key",
|
||||
api_base=None,
|
||||
custom_config={"speech_region": "WestUS2"},
|
||||
)
|
||||
assert provider.speech_region == "westus2"
|
||||
|
||||
|
||||
def test_azure_provider_rejects_invalid_speech_region() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid Azure speech_region"):
|
||||
AzureVoiceProvider(
|
||||
api_key="key",
|
||||
api_base=None,
|
||||
custom_config={"speech_region": "westus/../../etc"},
|
||||
)
|
||||
@@ -1,194 +0,0 @@
|
||||
import io
|
||||
import struct
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.voice.providers.azure import AzureVoiceProvider
|
||||
|
||||
|
||||
# --- _is_azure_cloud_url ---
|
||||
|
||||
|
||||
def test_is_azure_cloud_url_speech_microsoft() -> None:
|
||||
assert AzureVoiceProvider._is_azure_cloud_url(
|
||||
"https://eastus.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
)
|
||||
|
||||
|
||||
def test_is_azure_cloud_url_cognitive_microsoft() -> None:
|
||||
assert AzureVoiceProvider._is_azure_cloud_url(
|
||||
"https://westus.api.cognitive.microsoft.com/"
|
||||
)
|
||||
|
||||
|
||||
def test_is_azure_cloud_url_rejects_custom_host() -> None:
|
||||
assert not AzureVoiceProvider._is_azure_cloud_url("https://my-custom-host.com/")
|
||||
|
||||
|
||||
def test_is_azure_cloud_url_rejects_none() -> None:
|
||||
assert not AzureVoiceProvider._is_azure_cloud_url(None)
|
||||
|
||||
|
||||
# --- _extract_speech_region_from_uri ---
|
||||
|
||||
|
||||
def test_extract_region_from_tts_url() -> None:
|
||||
assert (
|
||||
AzureVoiceProvider._extract_speech_region_from_uri(
|
||||
"https://eastus.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
)
|
||||
== "eastus"
|
||||
)
|
||||
|
||||
|
||||
def test_extract_region_from_cognitive_api_url() -> None:
|
||||
assert (
|
||||
AzureVoiceProvider._extract_speech_region_from_uri(
|
||||
"https://eastus.api.cognitive.microsoft.com/"
|
||||
)
|
||||
== "eastus"
|
||||
)
|
||||
|
||||
|
||||
def test_extract_region_returns_none_for_custom_domain() -> None:
|
||||
"""Custom domains use resource name, not region — must use speech_region config."""
|
||||
assert (
|
||||
AzureVoiceProvider._extract_speech_region_from_uri(
|
||||
"https://myresource.cognitiveservices.azure.com/"
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_extract_region_returns_none_for_none() -> None:
|
||||
assert AzureVoiceProvider._extract_speech_region_from_uri(None) is None
|
||||
|
||||
|
||||
# --- _validate_speech_region ---
|
||||
|
||||
|
||||
def test_validate_region_normalizes_to_lowercase() -> None:
|
||||
assert AzureVoiceProvider._validate_speech_region("WestUS2") == "westus2"
|
||||
|
||||
|
||||
def test_validate_region_accepts_hyphens() -> None:
|
||||
assert AzureVoiceProvider._validate_speech_region("us-east-1") == "us-east-1"
|
||||
|
||||
|
||||
def test_validate_region_rejects_path_traversal() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid Azure speech_region"):
|
||||
AzureVoiceProvider._validate_speech_region("westus/../../etc")
|
||||
|
||||
|
||||
def test_validate_region_rejects_dots() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid Azure speech_region"):
|
||||
AzureVoiceProvider._validate_speech_region("west.us")
|
||||
|
||||
|
||||
# --- _pcm16_to_wav ---
|
||||
|
||||
|
||||
def test_pcm16_to_wav_produces_valid_wav() -> None:
|
||||
samples = [32767, -32768, 0, 1234]
|
||||
pcm_data = struct.pack(f"<{len(samples)}h", *samples)
|
||||
wav_bytes = AzureVoiceProvider._pcm16_to_wav(pcm_data, sample_rate=16000)
|
||||
|
||||
with wave.open(io.BytesIO(wav_bytes), "rb") as wav_file:
|
||||
assert wav_file.getnchannels() == 1
|
||||
assert wav_file.getsampwidth() == 2
|
||||
assert wav_file.getframerate() == 16000
|
||||
frames = wav_file.readframes(4)
|
||||
recovered = struct.unpack(f"<{len(samples)}h", frames)
|
||||
assert list(recovered) == samples
|
||||
|
||||
|
||||
# --- URL Construction ---
|
||||
|
||||
|
||||
def test_get_tts_url_cloud() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key", api_base=None, custom_config={"speech_region": "eastus"}
|
||||
)
|
||||
assert (
|
||||
provider._get_tts_url()
|
||||
== "https://eastus.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
)
|
||||
|
||||
|
||||
def test_get_stt_url_cloud() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key", api_base=None, custom_config={"speech_region": "westus2"}
|
||||
)
|
||||
assert "westus2.stt.speech.microsoft.com" in provider._get_stt_url()
|
||||
|
||||
|
||||
def test_get_tts_url_self_hosted() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key", api_base="http://localhost:5000", custom_config={}
|
||||
)
|
||||
assert provider._get_tts_url() == "http://localhost:5000/cognitiveservices/v1"
|
||||
|
||||
|
||||
def test_get_tts_url_self_hosted_strips_trailing_slash() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key", api_base="http://localhost:5000/", custom_config={}
|
||||
)
|
||||
assert provider._get_tts_url() == "http://localhost:5000/cognitiveservices/v1"
|
||||
|
||||
|
||||
# --- _is_self_hosted ---
|
||||
|
||||
|
||||
def test_is_self_hosted_true_for_custom_endpoint() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key", api_base="http://localhost:5000", custom_config={}
|
||||
)
|
||||
assert provider._is_self_hosted() is True
|
||||
|
||||
|
||||
def test_is_self_hosted_false_for_azure_cloud() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key",
|
||||
api_base="https://eastus.api.cognitive.microsoft.com/",
|
||||
custom_config={},
|
||||
)
|
||||
assert provider._is_self_hosted() is False
|
||||
|
||||
|
||||
# --- Resampling ---
|
||||
|
||||
|
||||
def test_resample_pcm16_passthrough() -> None:
|
||||
from onyx.voice.providers.azure import AzureStreamingTranscriber
|
||||
|
||||
t = AzureStreamingTranscriber.__new__(AzureStreamingTranscriber)
|
||||
t.input_sample_rate = 16000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
data = struct.pack("<4h", 100, 200, 300, 400)
|
||||
assert t._resample_pcm16(data) == data
|
||||
|
||||
|
||||
def test_resample_pcm16_downsamples() -> None:
|
||||
from onyx.voice.providers.azure import AzureStreamingTranscriber
|
||||
|
||||
t = AzureStreamingTranscriber.__new__(AzureStreamingTranscriber)
|
||||
t.input_sample_rate = 24000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
input_samples = [1000, 2000, 3000, 4000, 5000, 6000]
|
||||
data = struct.pack(f"<{len(input_samples)}h", *input_samples)
|
||||
|
||||
result = t._resample_pcm16(data)
|
||||
assert len(result) // 2 == 4
|
||||
|
||||
|
||||
def test_resample_pcm16_empty_data() -> None:
|
||||
from onyx.voice.providers.azure import AzureStreamingTranscriber
|
||||
|
||||
t = AzureStreamingTranscriber.__new__(AzureStreamingTranscriber)
|
||||
t.input_sample_rate = 24000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
assert t._resample_pcm16(b"") == b""
|
||||
@@ -1,117 +0,0 @@
|
||||
import struct
|
||||
|
||||
from onyx.voice.providers.elevenlabs import _http_to_ws_url
|
||||
from onyx.voice.providers.elevenlabs import DEFAULT_ELEVENLABS_API_BASE
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsSTTMessageType
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsVoiceProvider
|
||||
|
||||
|
||||
# --- _http_to_ws_url ---
|
||||
|
||||
|
||||
def test_http_to_ws_url_converts_https_to_wss() -> None:
|
||||
assert _http_to_ws_url("https://api.elevenlabs.io") == "wss://api.elevenlabs.io"
|
||||
|
||||
|
||||
def test_http_to_ws_url_converts_http_to_ws() -> None:
|
||||
assert _http_to_ws_url("http://localhost:8080") == "ws://localhost:8080"
|
||||
|
||||
|
||||
def test_http_to_ws_url_passes_through_other_schemes() -> None:
|
||||
assert _http_to_ws_url("wss://already.ws") == "wss://already.ws"
|
||||
|
||||
|
||||
def test_http_to_ws_url_preserves_path() -> None:
|
||||
assert (
|
||||
_http_to_ws_url("https://api.elevenlabs.io/v1/tts")
|
||||
== "wss://api.elevenlabs.io/v1/tts"
|
||||
)
|
||||
|
||||
|
||||
# --- StrEnum comparison ---
|
||||
|
||||
|
||||
def test_stt_message_type_compares_as_string() -> None:
|
||||
"""StrEnum members should work in string comparisons (e.g. from JSON)."""
|
||||
assert str(ElevenLabsSTTMessageType.COMMITTED_TRANSCRIPT) == "committed_transcript"
|
||||
assert isinstance(ElevenLabsSTTMessageType.ERROR, str)
|
||||
|
||||
|
||||
# --- Resampling ---
|
||||
|
||||
|
||||
def test_resample_pcm16_passthrough_when_same_rate() -> None:
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsStreamingTranscriber
|
||||
|
||||
t = ElevenLabsStreamingTranscriber.__new__(ElevenLabsStreamingTranscriber)
|
||||
t.input_sample_rate = 16000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
data = struct.pack("<4h", 100, 200, 300, 400)
|
||||
assert t._resample_pcm16(data) == data
|
||||
|
||||
|
||||
def test_resample_pcm16_downsamples() -> None:
|
||||
"""24kHz -> 16kHz should produce fewer samples (ratio 3:2)."""
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsStreamingTranscriber
|
||||
|
||||
t = ElevenLabsStreamingTranscriber.__new__(ElevenLabsStreamingTranscriber)
|
||||
t.input_sample_rate = 24000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
input_samples = [1000, 2000, 3000, 4000, 5000, 6000]
|
||||
data = struct.pack(f"<{len(input_samples)}h", *input_samples)
|
||||
|
||||
result = t._resample_pcm16(data)
|
||||
output_samples = struct.unpack(f"<{len(result) // 2}h", result)
|
||||
|
||||
assert len(output_samples) == 4
|
||||
|
||||
|
||||
def test_resample_pcm16_clamps_to_int16_range() -> None:
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsStreamingTranscriber
|
||||
|
||||
t = ElevenLabsStreamingTranscriber.__new__(ElevenLabsStreamingTranscriber)
|
||||
t.input_sample_rate = 24000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
input_samples = [32767, -32768, 32767, -32768, 32767, -32768]
|
||||
data = struct.pack(f"<{len(input_samples)}h", *input_samples)
|
||||
|
||||
result = t._resample_pcm16(data)
|
||||
output_samples = struct.unpack(f"<{len(result) // 2}h", result)
|
||||
for s in output_samples:
|
||||
assert -32768 <= s <= 32767
|
||||
|
||||
|
||||
# --- Provider Model Defaulting ---
|
||||
|
||||
|
||||
def test_provider_defaults_invalid_stt_model() -> None:
|
||||
provider = ElevenLabsVoiceProvider(api_key="test", stt_model="invalid_model")
|
||||
assert provider.stt_model == "scribe_v1"
|
||||
|
||||
|
||||
def test_provider_defaults_invalid_tts_model() -> None:
|
||||
provider = ElevenLabsVoiceProvider(api_key="test", tts_model="invalid_model")
|
||||
assert provider.tts_model == "eleven_multilingual_v2"
|
||||
|
||||
|
||||
def test_provider_accepts_valid_models() -> None:
|
||||
provider = ElevenLabsVoiceProvider(
|
||||
api_key="test", stt_model="scribe_v2_realtime", tts_model="eleven_turbo_v2_5"
|
||||
)
|
||||
assert provider.stt_model == "scribe_v2_realtime"
|
||||
assert provider.tts_model == "eleven_turbo_v2_5"
|
||||
|
||||
|
||||
def test_provider_defaults_api_base() -> None:
|
||||
provider = ElevenLabsVoiceProvider(api_key="test")
|
||||
assert provider.api_base == DEFAULT_ELEVENLABS_API_BASE
|
||||
|
||||
|
||||
def test_provider_get_available_voices_returns_copy() -> None:
|
||||
provider = ElevenLabsVoiceProvider(api_key="test")
|
||||
voices = provider.get_available_voices()
|
||||
voices.clear()
|
||||
assert len(provider.get_available_voices()) > 0
|
||||
@@ -1,97 +0,0 @@
|
||||
import io
|
||||
import struct
|
||||
import wave
|
||||
|
||||
from onyx.voice.providers.openai import _create_wav_header
|
||||
from onyx.voice.providers.openai import _http_to_ws_url
|
||||
from onyx.voice.providers.openai import OpenAIRealtimeMessageType
|
||||
from onyx.voice.providers.openai import OpenAIVoiceProvider
|
||||
|
||||
|
||||
# --- _http_to_ws_url ---
|
||||
|
||||
|
||||
def test_http_to_ws_url_converts_https_to_wss() -> None:
|
||||
assert _http_to_ws_url("https://api.openai.com") == "wss://api.openai.com"
|
||||
|
||||
|
||||
def test_http_to_ws_url_converts_http_to_ws() -> None:
|
||||
assert _http_to_ws_url("http://localhost:9090") == "ws://localhost:9090"
|
||||
|
||||
|
||||
def test_http_to_ws_url_passes_through_ws() -> None:
|
||||
assert _http_to_ws_url("wss://already.ws") == "wss://already.ws"
|
||||
|
||||
|
||||
# --- StrEnum comparison ---
|
||||
|
||||
|
||||
def test_realtime_message_type_compares_as_string() -> None:
|
||||
assert str(OpenAIRealtimeMessageType.ERROR) == "error"
|
||||
assert (
|
||||
str(OpenAIRealtimeMessageType.TRANSCRIPTION_DELTA)
|
||||
== "conversation.item.input_audio_transcription.delta"
|
||||
)
|
||||
assert isinstance(OpenAIRealtimeMessageType.ERROR, str)
|
||||
|
||||
|
||||
# --- _create_wav_header ---
|
||||
|
||||
|
||||
def test_wav_header_is_44_bytes() -> None:
|
||||
assert len(_create_wav_header(1000)) == 44
|
||||
|
||||
|
||||
def test_wav_header_chunk_size_matches_data_length() -> None:
|
||||
data_length = 2000
|
||||
header = _create_wav_header(data_length)
|
||||
chunk_size = struct.unpack_from("<I", header, 4)[0]
|
||||
assert chunk_size == 36 + data_length
|
||||
|
||||
|
||||
def test_wav_header_byte_rate() -> None:
|
||||
header = _create_wav_header(100, sample_rate=24000, channels=1, bits_per_sample=16)
|
||||
byte_rate = struct.unpack_from("<I", header, 28)[0]
|
||||
assert byte_rate == 24000 * 1 * 16 // 8
|
||||
|
||||
|
||||
def test_wav_header_produces_valid_wav() -> None:
|
||||
"""Header + PCM data should parse as valid WAV."""
|
||||
data_length = 100
|
||||
pcm_data = b"\x00" * data_length
|
||||
header = _create_wav_header(data_length, sample_rate=24000)
|
||||
|
||||
with wave.open(io.BytesIO(header + pcm_data), "rb") as wav_file:
|
||||
assert wav_file.getnchannels() == 1
|
||||
assert wav_file.getsampwidth() == 2
|
||||
assert wav_file.getframerate() == 24000
|
||||
assert wav_file.getnframes() == data_length // 2
|
||||
|
||||
|
||||
# --- Provider Defaults ---
|
||||
|
||||
|
||||
def test_provider_default_models() -> None:
|
||||
provider = OpenAIVoiceProvider(api_key="test")
|
||||
assert provider.stt_model == "whisper-1"
|
||||
assert provider.tts_model == "tts-1"
|
||||
assert provider.default_voice == "alloy"
|
||||
|
||||
|
||||
def test_provider_custom_models() -> None:
|
||||
provider = OpenAIVoiceProvider(
|
||||
api_key="test",
|
||||
stt_model="gpt-4o-transcribe",
|
||||
tts_model="tts-1-hd",
|
||||
default_voice="nova",
|
||||
)
|
||||
assert provider.stt_model == "gpt-4o-transcribe"
|
||||
assert provider.tts_model == "tts-1-hd"
|
||||
assert provider.default_voice == "nova"
|
||||
|
||||
|
||||
def test_provider_get_available_voices_returns_copy() -> None:
|
||||
provider = OpenAIVoiceProvider(api_key="test")
|
||||
voices = provider.get_available_voices()
|
||||
voices.clear()
|
||||
assert len(provider.get_available_voices()) > 0
|
||||
@@ -5,7 +5,7 @@ home: https://www.onyx.app/
|
||||
sources:
|
||||
- "https://github.com/onyx-dot-app/onyx"
|
||||
type: application
|
||||
version: 0.4.35
|
||||
version: 0.4.34
|
||||
appVersion: latest
|
||||
annotations:
|
||||
category: Productivity
|
||||
|
||||
@@ -1184,9 +1184,8 @@ auth:
|
||||
opensearch_admin_username: "admin"
|
||||
opensearch_admin_password: "OnyxDev1!"
|
||||
userauth:
|
||||
# -- Used for password reset / verification tokens and OAuth/OIDC state signing.
|
||||
# Disabled by default to preserve upgrade compatibility for existing Helm customers.
|
||||
enabled: false
|
||||
# -- Used for signing password reset tokens, email verification tokens, and JWT tokens.
|
||||
enabled: true
|
||||
# -- Overwrite the default secret name, ignored if existingSecret is defined
|
||||
secretName: 'onyx-userauth'
|
||||
# -- Use a secret specified elsewhere
|
||||
@@ -1194,16 +1193,15 @@ auth:
|
||||
# -- This defines the env var to secret map
|
||||
secretKeys:
|
||||
USER_AUTH_SECRET: user_auth_secret
|
||||
# -- Secret value. Required when this secret is enabled - generate with: openssl rand -hex 32
|
||||
# If not set, helm install/upgrade will fail when auth.userauth.enabled=true.
|
||||
# -- Secret value. Required - generate with: openssl rand -hex 32
|
||||
# If not set, helm install/upgrade will fail.
|
||||
values:
|
||||
user_auth_secret: ""
|
||||
|
||||
configMap:
|
||||
# Auth type: "basic" (default), "google_oauth", "oidc", or "saml"
|
||||
# UPGRADE NOTE: Default changed from "disabled" to "basic" in 0.4.34.
|
||||
# Set auth.userauth.enabled=true and provide auth.userauth.values.user_auth_secret
|
||||
# before enabling flows that require it.
|
||||
# You must also set auth.userauth.values.user_auth_secret.
|
||||
AUTH_TYPE: "basic"
|
||||
# Enable PKCE for OIDC login flow. Leave empty/false for backward compatibility.
|
||||
OIDC_PKCE_ENABLED: ""
|
||||
|
||||
@@ -35,7 +35,6 @@ backend = [
|
||||
"alembic==1.10.4",
|
||||
"asyncpg==0.30.0",
|
||||
"atlassian-python-api==3.41.16",
|
||||
"azure-cognitiveservices-speech==1.38.0",
|
||||
"beautifulsoup4==4.12.3",
|
||||
"boto3==1.39.11",
|
||||
"boto3-stubs[s3]==1.39.11",
|
||||
|
||||
15
uv.lock
generated
15
uv.lock
generated
@@ -463,19 +463,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/00/3ed12264094ec91f534fae429945efbaa9f8c666f3aa7061cc3b2a26a0cd/authlib-1.6.7-py2.py3-none-any.whl", hash = "sha256:c637340d9a02789d2efa1d003a7437d10d3e565237bcb5fcbc6c134c7b95bab0", size = 244115, upload-time = "2026-02-06T14:04:12.141Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "azure-cognitiveservices-speech"
|
||||
version = "1.38.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/85/f4/4571c42cb00f8af317d5431f594b4ece1fbe59ab59f106947fea8e90cf89/azure_cognitiveservices_speech-1.38.0-py3-none-macosx_10_14_x86_64.whl", hash = "sha256:18dce915ab032711f687abb3297dd19176b9cbea562b322ee6fa7365ef4a5091", size = 6775838, upload-time = "2024-06-11T03:08:35.202Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/86/22/0ca2c59a573119950cad1f53531fec9872fc38810c405a4e1827f3d13a8e/azure_cognitiveservices_speech-1.38.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:9dd0800fbc4a8438c6dfd5747a658251914fe2d205a29e9b46158cadac6ab381", size = 6687975, upload-time = "2024-06-11T03:08:38.797Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/96/5436c09de3af3a9aefaa8cc00533c3a0f5d17aef5bbe017c17f0a30ad66e/azure_cognitiveservices_speech-1.38.0-py3-none-manylinux1_x86_64.whl", hash = "sha256:1c344e8a6faadb063cea451f0301e13b44d9724e1242337039bff601e81e6f86", size = 40022287, upload-time = "2024-06-11T03:08:16.777Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a9/2d/ba20d05ff77ec9870cd489e6e7a474ba7fe820524bcf6fd202025e0c11cf/azure_cognitiveservices_speech-1.38.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1e002595a749471efeac3a54c80097946570b76c13049760b97a4b881d9d24af", size = 39788653, upload-time = "2024-06-11T03:08:30.405Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0c/21/25f8c37fb6868db4346ca977c287ede9e87f609885d932653243c9ed5f63/azure_cognitiveservices_speech-1.38.0-py3-none-win32.whl", hash = "sha256:16a530e6c646eb49ea0bc05cb45a9d28b99e4b67613f6c3a6c54e26e6bf65241", size = 1428364, upload-time = "2024-06-11T03:08:03.965Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/05/a6414a3481c5ee30c4f32742abe055e5f3ce4ff69e936089d86ece354067/azure_cognitiveservices_speech-1.38.0-py3-none-win_amd64.whl", hash = "sha256:1d38d8c056fb3f513a9ff27ab4e77fd08ca487f8788cc7a6df772c1ab2c97b54", size = 1539297, upload-time = "2024-06-11T03:08:01.304Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "babel"
|
||||
version = "2.17.0"
|
||||
@@ -4240,7 +4227,6 @@ backend = [
|
||||
{ name = "asana" },
|
||||
{ name = "asyncpg" },
|
||||
{ name = "atlassian-python-api" },
|
||||
{ name = "azure-cognitiveservices-speech" },
|
||||
{ name = "beautifulsoup4" },
|
||||
{ name = "boto3" },
|
||||
{ name = "boto3-stubs", extra = ["s3"] },
|
||||
@@ -4395,7 +4381,6 @@ requires-dist = [
|
||||
{ name = "asana", marker = "extra == 'backend'", specifier = "==5.0.8" },
|
||||
{ name = "asyncpg", marker = "extra == 'backend'", specifier = "==0.30.0" },
|
||||
{ name = "atlassian-python-api", marker = "extra == 'backend'", specifier = "==3.41.16" },
|
||||
{ name = "azure-cognitiveservices-speech", marker = "extra == 'backend'", specifier = "==1.38.0" },
|
||||
{ name = "beautifulsoup4", marker = "extra == 'backend'", specifier = "==4.12.3" },
|
||||
{ name = "black", marker = "extra == 'dev'", specifier = "==25.1.0" },
|
||||
{ name = "boto3", marker = "extra == 'backend'", specifier = "==1.39.11" },
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
type InteractiveStatefulInteraction,
|
||||
} from "@opal/core";
|
||||
import type { SizeVariant, WidthVariant } from "@opal/shared";
|
||||
import type { InteractiveContainerRoundingVariant } from "@opal/core";
|
||||
import type { TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent, IconProps } from "@opal/types";
|
||||
import { SvgChevronDownSmall } from "@opal/icons";
|
||||
@@ -80,6 +81,9 @@ type OpenButtonProps = Omit<InteractiveStatefulProps, "variant"> & {
|
||||
|
||||
/** Which side the tooltip appears on. */
|
||||
tooltipSide?: TooltipSide;
|
||||
|
||||
/** Override the default rounding derived from `size`. */
|
||||
roundingVariant?: InteractiveContainerRoundingVariant;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -95,6 +99,7 @@ function OpenButton({
|
||||
justifyContent,
|
||||
tooltip,
|
||||
tooltipSide = "top",
|
||||
roundingVariant: roundingVariantOverride,
|
||||
interaction,
|
||||
variant = "select-heavy",
|
||||
...statefulProps
|
||||
@@ -132,7 +137,8 @@ function OpenButton({
|
||||
heightVariant={size}
|
||||
widthVariant={width}
|
||||
roundingVariant={
|
||||
isLarge ? "default" : size === "2xs" ? "mini" : "compact"
|
||||
roundingVariantOverride ??
|
||||
(isLarge ? "default" : size === "2xs" ? "mini" : "compact")
|
||||
}
|
||||
>
|
||||
<div
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgAudio = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M2 10V6M5 14V2M11 11V5M14 9V7M8 10V6"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgAudio;
|
||||
@@ -17,7 +17,6 @@ export { default as SvgArrowUpDown } from "@opal/icons/arrow-up-down";
|
||||
export { default as SvgArrowUpDot } from "@opal/icons/arrow-up-dot";
|
||||
export { default as SvgArrowUpRight } from "@opal/icons/arrow-up-right";
|
||||
export { default as SvgArrowWallRight } from "@opal/icons/arrow-wall-right";
|
||||
export { default as SvgAudio } from "@opal/icons/audio";
|
||||
export { default as SvgAudioEqSmall } from "@opal/icons/audio-eq-small";
|
||||
export { default as SvgAws } from "@opal/icons/aws";
|
||||
export { default as SvgAzure } from "@opal/icons/azure";
|
||||
@@ -108,8 +107,6 @@ export { default as SvgLogOut } from "@opal/icons/log-out";
|
||||
export { default as SvgMaximize2 } from "@opal/icons/maximize-2";
|
||||
export { default as SvgMcp } from "@opal/icons/mcp";
|
||||
export { default as SvgMenu } from "@opal/icons/menu";
|
||||
export { default as SvgMicrophone } from "@opal/icons/microphone";
|
||||
export { default as SvgMicrophoneOff } from "@opal/icons/microphone-off";
|
||||
export { default as SvgMinus } from "@opal/icons/minus";
|
||||
export { default as SvgMinusCircle } from "@opal/icons/minus-circle";
|
||||
export { default as SvgMoon } from "@opal/icons/moon";
|
||||
@@ -180,8 +177,6 @@ export { default as SvgUserManage } from "@opal/icons/user-manage";
|
||||
export { default as SvgUserPlus } from "@opal/icons/user-plus";
|
||||
export { default as SvgUserSync } from "@opal/icons/user-sync";
|
||||
export { default as SvgUsers } from "@opal/icons/users";
|
||||
export { default as SvgVolume } from "@opal/icons/volume";
|
||||
export { default as SvgVolumeOff } from "@opal/icons/volume-off";
|
||||
export { default as SvgWallet } from "@opal/icons/wallet";
|
||||
export { default as SvgWorkflow } from "@opal/icons/workflow";
|
||||
export { default as SvgX } from "@opal/icons/x";
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgMicrophoneOff = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
{/* Microphone body */}
|
||||
<path
|
||||
d="M12.5 7V7.5C12.5 9.98528 10.4853 12 8 12M3.5 7V7.5C3.5 9.98528 5.51472 12 8 12M8 12V14.5M8 14.5H5M8 14.5H11M8 9.5C6.89543 9.5 6 8.60457 6 7.5V3.5C6 2.39543 6.89543 1.5 8 1.5C9.10457 1.5 10 2.39543 10 3.5V7.5C10 8.60457 9.10457 9.5 8 9.5Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
{/* Diagonal slash */}
|
||||
<path
|
||||
d="M2 2L14 14"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgMicrophoneOff;
|
||||
@@ -1,21 +0,0 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgMicrophone = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M12.5 7V7.5C12.5 9.98528 10.4853 12 8 12M3.5 7V7.5C3.5 9.98528 5.51472 12 8 12M8 12V14.5M8 14.5H5M8 14.5H11M8 9.5C6.89543 9.5 6 8.60457 6 7.5V3.5C6 2.39543 6.89543 1.5 8 1.5C9.10457 1.5 10 2.39543 10 3.5V7.5C10 8.60457 9.10457 9.5 8 9.5Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgMicrophone;
|
||||
@@ -1,26 +0,0 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgVolumeOff = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M2 6V10H5L9 13V3L5 6H2Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M14 6L11 9M11 6L14 9"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgVolumeOff;
|
||||
@@ -1,26 +0,0 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgVolume = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M2 6V10H5L9 13V3L5 6H2Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M11.5 5.5C12.3 6.3 12.8 7.4 12.8 8.5C12.8 9.6 12.3 10.7 11.5 11.5"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgVolume;
|
||||
@@ -59,7 +59,7 @@ const nextConfig = {
|
||||
{
|
||||
key: "Permissions-Policy",
|
||||
value:
|
||||
"accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), cross-origin-isolated=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(self), midi=(), navigation-override=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()",
|
||||
"accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), cross-origin-isolated=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), navigation-override=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()",
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
6
web/package-lock.json
generated
6
web/package-lock.json
generated
@@ -53,7 +53,7 @@
|
||||
"formik": "^2.2.9",
|
||||
"highlight.js": "^11.11.1",
|
||||
"js-cookie": "^3.0.5",
|
||||
"katex": "^0.16.38",
|
||||
"katex": "^0.16.17",
|
||||
"linguist-languages": "^9.3.1",
|
||||
"lodash": "^4.17.23",
|
||||
"lowlight": "^3.3.0",
|
||||
@@ -12794,9 +12794,7 @@
|
||||
}
|
||||
},
|
||||
"node_modules/katex": {
|
||||
"version": "0.16.38",
|
||||
"resolved": "https://registry.npmjs.org/katex/-/katex-0.16.38.tgz",
|
||||
"integrity": "sha512-cjHooZUmIAUmDsHBN+1n8LaZdpmbj03LtYeYPyuYB7OuloiaeaV6N4LcfjcnHVzGWjVQmKrxxTrpDcmSzEZQwQ==",
|
||||
"version": "0.16.25",
|
||||
"funding": [
|
||||
"https://opencollective.com/katex",
|
||||
"https://github.com/sponsors/katex"
|
||||
|
||||
@@ -71,7 +71,7 @@
|
||||
"formik": "^2.2.9",
|
||||
"highlight.js": "^11.11.1",
|
||||
"js-cookie": "^3.0.5",
|
||||
"katex": "^0.16.38",
|
||||
"katex": "^0.16.17",
|
||||
"linguist-languages": "^9.3.1",
|
||||
"lodash": "^4.17.23",
|
||||
"lowlight": "^3.3.0",
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M10.5 2H13V14H10.5V2Z" fill="currentColor"/>
|
||||
<path d="M3 2H5.5V14H3V2Z" fill="currentColor"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 206 B |
@@ -1,4 +0,0 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M10.5 2H13V14H10.5V2Z" fill="white"/>
|
||||
<path d="M3 2H5.5V14H3V2Z" fill="white"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 192 B |
@@ -1,558 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import Image from "next/image";
|
||||
import { FunctionComponent, useState, useEffect } from "react";
|
||||
import {
|
||||
AzureIcon,
|
||||
ElevenLabsIcon,
|
||||
OpenAIIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import PasswordInputTypeIn from "@/refresh-components/inputs/PasswordInputTypeIn";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
|
||||
import { FormField } from "@/refresh-components/form/FormField";
|
||||
import { Vertical, Horizontal } from "@/layouts/input-layouts";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { SvgArrowExchange, SvgOnyxLogo } from "@opal/icons";
|
||||
import { Disabled } from "@opal/core";
|
||||
import type { IconProps } from "@opal/types";
|
||||
import { VoiceProviderView } from "@/hooks/useVoiceProviders";
|
||||
import {
|
||||
testVoiceProvider,
|
||||
upsertVoiceProvider,
|
||||
fetchVoicesByType,
|
||||
fetchLLMProviders,
|
||||
} from "@/lib/admin/voice/svc";
|
||||
|
||||
interface VoiceOption {
|
||||
value: string;
|
||||
label: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
interface LLMProviderView {
|
||||
id: number;
|
||||
name: string;
|
||||
provider: string;
|
||||
api_key: string | null;
|
||||
}
|
||||
|
||||
interface ApiKeyOption {
|
||||
value: string;
|
||||
label: string;
|
||||
description?: string;
|
||||
}
|
||||
|
||||
interface VoiceProviderSetupModalProps {
|
||||
providerType: string;
|
||||
existingProvider: VoiceProviderView | null;
|
||||
mode: "stt" | "tts";
|
||||
defaultModelId?: string | null;
|
||||
onClose: () => void;
|
||||
onSuccess: () => void;
|
||||
}
|
||||
|
||||
const PROVIDER_LABELS: Record<string, string> = {
|
||||
openai: "OpenAI",
|
||||
azure: "Azure Speech Services",
|
||||
elevenlabs: "ElevenLabs",
|
||||
};
|
||||
|
||||
const PROVIDER_API_KEY_URLS: Record<string, string> = {
|
||||
openai: "https://platform.openai.com/api-keys",
|
||||
azure: "https://portal.azure.com/",
|
||||
elevenlabs: "https://elevenlabs.io/app/settings/api-keys",
|
||||
};
|
||||
|
||||
const PROVIDER_LOGO_URLS: Record<string, string> = {
|
||||
openai: "/Openai.svg",
|
||||
azure: "/Azure.png",
|
||||
elevenlabs: "/ElevenLabs.svg",
|
||||
};
|
||||
|
||||
const PROVIDER_DOCS_URLS: Record<string, string> = {
|
||||
openai: "https://platform.openai.com/docs/guides/text-to-speech",
|
||||
azure: "https://learn.microsoft.com/en-us/azure/ai-services/speech-service/",
|
||||
elevenlabs: "https://elevenlabs.io/docs",
|
||||
};
|
||||
|
||||
const PROVIDER_VOICE_DOCS_URLS: Record<string, { url: string; label: string }> =
|
||||
{
|
||||
openai: {
|
||||
url: "https://platform.openai.com/docs/guides/text-to-speech#voice-options",
|
||||
label: "OpenAI",
|
||||
},
|
||||
azure: {
|
||||
url: "https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts",
|
||||
label: "Azure",
|
||||
},
|
||||
elevenlabs: {
|
||||
url: "https://elevenlabs.io/docs/voices/premade-voices",
|
||||
label: "ElevenLabs",
|
||||
},
|
||||
};
|
||||
|
||||
const OPENAI_STT_MODELS = [{ id: "whisper-1", name: "Whisper v1" }];
|
||||
|
||||
const OPENAI_TTS_MODELS = [
|
||||
{ id: "tts-1", name: "TTS-1" },
|
||||
{ id: "tts-1-hd", name: "TTS-1 HD" },
|
||||
];
|
||||
|
||||
// Map model IDs from cards to actual API model IDs
|
||||
const MODEL_ID_MAP: Record<string, string> = {
|
||||
"tts-1": "tts-1",
|
||||
"tts-1-hd": "tts-1-hd",
|
||||
whisper: "whisper-1",
|
||||
};
|
||||
|
||||
type Phase = "idle" | "validating" | "saving";
|
||||
type MessageState = {
|
||||
kind: "status" | "error" | "success";
|
||||
text: string;
|
||||
} | null;
|
||||
|
||||
export default function VoiceProviderSetupModal({
|
||||
providerType,
|
||||
existingProvider,
|
||||
mode,
|
||||
defaultModelId,
|
||||
onClose,
|
||||
onSuccess,
|
||||
}: VoiceProviderSetupModalProps) {
|
||||
// Map the card model ID to the actual API model ID
|
||||
// Prioritize defaultModelId (from the clicked card) over stored value
|
||||
const initialTtsModel = defaultModelId
|
||||
? MODEL_ID_MAP[defaultModelId] ?? "tts-1"
|
||||
: existingProvider?.tts_model ?? "tts-1";
|
||||
|
||||
const [apiKey, setApiKey] = useState("");
|
||||
const [apiKeyChanged, setApiKeyChanged] = useState(false);
|
||||
const [targetUri, setTargetUri] = useState(
|
||||
existingProvider?.target_uri ?? ""
|
||||
);
|
||||
const [selectedLlmProviderId, setSelectedLlmProviderId] = useState<
|
||||
number | null
|
||||
>(null);
|
||||
const [sttModel, setSttModel] = useState(
|
||||
existingProvider?.stt_model ?? "whisper-1"
|
||||
);
|
||||
const [ttsModel, setTtsModel] = useState(initialTtsModel);
|
||||
const [defaultVoice, setDefaultVoice] = useState(
|
||||
existingProvider?.default_voice ?? ""
|
||||
);
|
||||
const [phase, setPhase] = useState<Phase>("idle");
|
||||
const [message, setMessage] = useState<MessageState>(null);
|
||||
|
||||
// Dynamic voices fetched from backend
|
||||
const [voiceOptions, setVoiceOptions] = useState<VoiceOption[]>([]);
|
||||
const [isLoadingVoices, setIsLoadingVoices] = useState(false);
|
||||
|
||||
// Existing OpenAI LLM providers for API key reuse
|
||||
const [existingApiKeyOptions, setExistingApiKeyOptions] = useState<
|
||||
ApiKeyOption[]
|
||||
>([]);
|
||||
const [llmProviderMap, setLlmProviderMap] = useState<Map<string, number>>(
|
||||
new Map()
|
||||
);
|
||||
|
||||
// Fetch existing OpenAI LLM providers (for API key reuse)
|
||||
useEffect(() => {
|
||||
if (providerType !== "openai") return;
|
||||
|
||||
fetchLLMProviders()
|
||||
.then((res) => res.json())
|
||||
.then((data: { providers: LLMProviderView[] } | LLMProviderView[]) => {
|
||||
const providers = Array.isArray(data) ? data : data.providers ?? [];
|
||||
const openaiProviders = providers.filter(
|
||||
(p) => p.provider === "openai" && p.api_key
|
||||
);
|
||||
const options: ApiKeyOption[] = openaiProviders.map((p) => ({
|
||||
value: p.api_key!,
|
||||
label: p.api_key!,
|
||||
description: `Used for LLM provider **${p.name}**`,
|
||||
}));
|
||||
setExistingApiKeyOptions(options);
|
||||
|
||||
// Map masked API keys to provider IDs for lookup on selection
|
||||
const providerMap = new Map<string, number>();
|
||||
openaiProviders.forEach((p) => {
|
||||
if (p.api_key) {
|
||||
providerMap.set(p.api_key, p.id);
|
||||
}
|
||||
});
|
||||
setLlmProviderMap(providerMap);
|
||||
})
|
||||
.catch(() => {
|
||||
setExistingApiKeyOptions([]);
|
||||
});
|
||||
}, [providerType]);
|
||||
|
||||
// Fetch voices on mount (works without API key for ElevenLabs/OpenAI)
|
||||
useEffect(() => {
|
||||
setIsLoadingVoices(true);
|
||||
fetchVoicesByType(providerType)
|
||||
.then((res) => res.json())
|
||||
.then((data: Array<{ id: string; name: string }>) => {
|
||||
const options = data.map((v) => ({
|
||||
value: v.id,
|
||||
label: v.name,
|
||||
description: v.id,
|
||||
}));
|
||||
setVoiceOptions(options);
|
||||
// Set default voice to first option if not already set,
|
||||
// or if current value doesn't exist in the new options
|
||||
setDefaultVoice((prev) => {
|
||||
if (!prev) return options[0]?.value ?? "";
|
||||
const existsInOptions = options.some((opt) => opt.value === prev);
|
||||
return existsInOptions ? prev : options[0]?.value ?? "";
|
||||
});
|
||||
})
|
||||
.catch(() => {
|
||||
setVoiceOptions([]);
|
||||
})
|
||||
.finally(() => {
|
||||
setIsLoadingVoices(false);
|
||||
});
|
||||
}, [providerType]);
|
||||
|
||||
const isEditing = !!existingProvider;
|
||||
const label = PROVIDER_LABELS[providerType] ?? providerType;
|
||||
const isProcessing = phase !== "idle";
|
||||
const hasNonEmptyApiKey = apiKey.trim().length > 0;
|
||||
const shouldSendApiKey =
|
||||
!selectedLlmProviderId && apiKeyChanged && hasNonEmptyApiKey;
|
||||
const shouldUseStoredKey =
|
||||
isEditing && !selectedLlmProviderId && !shouldSendApiKey;
|
||||
|
||||
const canConnect = (() => {
|
||||
if (selectedLlmProviderId) return true;
|
||||
if (!isEditing && !apiKey) return false;
|
||||
if (providerType === "azure" && !isEditing && !targetUri) return false;
|
||||
return true;
|
||||
})();
|
||||
|
||||
// Logo arrangement component for the modal header
|
||||
// No useMemo needed - providerType and label are stable props
|
||||
const LogoArrangement: FunctionComponent<IconProps> = () => (
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="flex items-center justify-center size-7 shrink-0 overflow-clip">
|
||||
{providerType === "openai" ? (
|
||||
<OpenAIIcon size={24} />
|
||||
) : providerType === "azure" ? (
|
||||
<AzureIcon size={24} />
|
||||
) : providerType === "elevenlabs" ? (
|
||||
<ElevenLabsIcon size={24} />
|
||||
) : (
|
||||
<Image
|
||||
src={PROVIDER_LOGO_URLS[providerType] ?? "/Openai.svg"}
|
||||
alt={`${label} logo`}
|
||||
width={24}
|
||||
height={24}
|
||||
className="object-contain"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center justify-center size-4 shrink-0">
|
||||
<SvgArrowExchange className="size-3 text-text-04" />
|
||||
</div>
|
||||
<div className="flex items-center justify-center size-7 p-0.5 shrink-0 overflow-clip">
|
||||
<SvgOnyxLogo size={24} className="text-text-04 shrink-0" />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
const formFieldState: "idle" | "error" | "success" =
|
||||
message?.kind === "error"
|
||||
? "error"
|
||||
: message?.kind === "success"
|
||||
? "success"
|
||||
: "idle";
|
||||
|
||||
const handleSubmit = async () => {
|
||||
if (!canConnect) return;
|
||||
|
||||
setMessage(null);
|
||||
|
||||
try {
|
||||
// Test the connection first (skip if reusing LLM provider key - validated on save)
|
||||
if (!selectedLlmProviderId) {
|
||||
setPhase("validating");
|
||||
setMessage({ kind: "status", text: "Validating API key..." });
|
||||
|
||||
const testResponse = await testVoiceProvider({
|
||||
provider_type: providerType,
|
||||
api_key: shouldSendApiKey ? apiKey : undefined,
|
||||
target_uri: targetUri || undefined,
|
||||
use_stored_key: shouldUseStoredKey,
|
||||
});
|
||||
|
||||
if (!testResponse.ok) {
|
||||
const data = await testResponse.json().catch(() => ({}));
|
||||
const detail =
|
||||
typeof data?.detail === "string"
|
||||
? data.detail
|
||||
: "Connection test failed";
|
||||
setPhase("idle");
|
||||
setMessage({ kind: "error", text: detail });
|
||||
return;
|
||||
}
|
||||
|
||||
setMessage({
|
||||
kind: "status",
|
||||
text: "API key validated. Saving provider...",
|
||||
});
|
||||
}
|
||||
|
||||
// Save the provider
|
||||
setPhase("saving");
|
||||
const response = await upsertVoiceProvider({
|
||||
id: existingProvider?.id,
|
||||
name: label,
|
||||
provider_type: providerType,
|
||||
api_key: shouldSendApiKey ? apiKey : undefined,
|
||||
api_key_changed: shouldSendApiKey,
|
||||
target_uri: targetUri || undefined,
|
||||
llm_provider_id: selectedLlmProviderId,
|
||||
stt_model: sttModel,
|
||||
tts_model: ttsModel,
|
||||
default_voice: defaultVoice,
|
||||
activate_stt: mode === "stt",
|
||||
activate_tts: mode === "tts",
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
onSuccess();
|
||||
} else {
|
||||
const data = await response.json().catch(() => ({}));
|
||||
const detail =
|
||||
typeof data?.detail === "string"
|
||||
? data.detail
|
||||
: "Failed to save provider";
|
||||
setPhase("idle");
|
||||
setMessage({ kind: "error", text: detail });
|
||||
}
|
||||
} catch {
|
||||
setPhase("idle");
|
||||
setMessage({ kind: "error", text: "Failed to save provider" });
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal open onOpenChange={(isOpen) => !isOpen && onClose()}>
|
||||
<Modal.Content width="sm">
|
||||
<Modal.Header
|
||||
icon={LogoArrangement}
|
||||
title={isEditing ? `Edit ${label}` : `Set up ${label}`}
|
||||
description={`Connect to ${label} and set up your voice models.`}
|
||||
onClose={onClose}
|
||||
/>
|
||||
<Modal.Body>
|
||||
<Section gap={1} alignItems="stretch">
|
||||
<FormField name="api_key" state={formFieldState} className="w-full">
|
||||
<FormField.Label>API Key</FormField.Label>
|
||||
<FormField.Description>
|
||||
{isEditing ? (
|
||||
"Leave blank to keep existing key"
|
||||
) : (
|
||||
<>
|
||||
Paste your{" "}
|
||||
<a
|
||||
href={PROVIDER_API_KEY_URLS[providerType]}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="underline"
|
||||
>
|
||||
API key
|
||||
</a>{" "}
|
||||
from {label} to access your models.
|
||||
</>
|
||||
)}
|
||||
</FormField.Description>
|
||||
<FormField.Control asChild>
|
||||
{providerType === "openai" &&
|
||||
existingApiKeyOptions.length > 0 ? (
|
||||
<InputComboBox
|
||||
placeholder={isEditing ? "••••••••" : "Enter API key"}
|
||||
value={apiKey}
|
||||
onChange={(e) => {
|
||||
setApiKey(e.target.value);
|
||||
setApiKeyChanged(true);
|
||||
setSelectedLlmProviderId(null);
|
||||
setMessage(null);
|
||||
}}
|
||||
onValueChange={(value) => {
|
||||
setApiKey(value);
|
||||
// Check if this is an existing key
|
||||
const llmProviderId = llmProviderMap.get(value);
|
||||
if (llmProviderId) {
|
||||
setSelectedLlmProviderId(llmProviderId);
|
||||
setApiKeyChanged(false);
|
||||
} else {
|
||||
setSelectedLlmProviderId(null);
|
||||
setApiKeyChanged(true);
|
||||
}
|
||||
setMessage(null);
|
||||
}}
|
||||
options={existingApiKeyOptions}
|
||||
separatorLabel="Reuse OpenAI API Keys"
|
||||
strict={false}
|
||||
showAddPrefix
|
||||
/>
|
||||
) : (
|
||||
<PasswordInputTypeIn
|
||||
placeholder={isEditing ? "••••••••" : "Enter API key"}
|
||||
value={apiKey}
|
||||
onChange={(e) => {
|
||||
setApiKey(e.target.value);
|
||||
setApiKeyChanged(true);
|
||||
setMessage(null);
|
||||
}}
|
||||
showClearButton={false}
|
||||
/>
|
||||
)}
|
||||
</FormField.Control>
|
||||
{isProcessing ? (
|
||||
<FormField.APIMessage
|
||||
state="loading"
|
||||
messages={{
|
||||
loading: message?.text ?? "Validating API key...",
|
||||
}}
|
||||
/>
|
||||
) : message ? (
|
||||
<FormField.Message
|
||||
messages={{
|
||||
idle: "",
|
||||
error: message.kind === "error" ? message.text : "",
|
||||
success: message.kind === "success" ? message.text : "",
|
||||
}}
|
||||
/>
|
||||
) : null}
|
||||
</FormField>
|
||||
|
||||
{providerType === "azure" && (
|
||||
<Vertical
|
||||
title="Target URI"
|
||||
subDescription={
|
||||
<>
|
||||
Paste the endpoint shown in{" "}
|
||||
<a
|
||||
href="https://portal.azure.com/"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="underline"
|
||||
>
|
||||
Azure Portal (Keys and Endpoint)
|
||||
</a>
|
||||
. Onyx extracts the speech region from this URL. Examples:
|
||||
https://westus.api.cognitive.microsoft.com/ or
|
||||
https://westus.tts.speech.microsoft.com/.
|
||||
</>
|
||||
}
|
||||
nonInteractive
|
||||
>
|
||||
<InputTypeIn
|
||||
placeholder={
|
||||
isEditing
|
||||
? "Leave blank to keep existing"
|
||||
: "https://<region>.api.cognitive.microsoft.com/"
|
||||
}
|
||||
value={targetUri}
|
||||
onChange={(e) => setTargetUri(e.target.value)}
|
||||
/>
|
||||
</Vertical>
|
||||
)}
|
||||
|
||||
{providerType === "openai" && mode === "stt" && (
|
||||
<Horizontal title="STT Model" center nonInteractive>
|
||||
<InputSelect value={sttModel} onValueChange={setSttModel}>
|
||||
<InputSelect.Trigger />
|
||||
<InputSelect.Content>
|
||||
{OPENAI_STT_MODELS.map((model) => (
|
||||
<InputSelect.Item key={model.id} value={model.id}>
|
||||
{model.name}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Horizontal>
|
||||
)}
|
||||
|
||||
{providerType === "openai" && mode === "tts" && (
|
||||
<Vertical
|
||||
title="Default Model"
|
||||
subDescription="This model will be used by Onyx by default for text-to-speech."
|
||||
nonInteractive
|
||||
>
|
||||
<InputSelect value={ttsModel} onValueChange={setTtsModel}>
|
||||
<InputSelect.Trigger />
|
||||
<InputSelect.Content>
|
||||
{OPENAI_TTS_MODELS.map((model) => (
|
||||
<InputSelect.Item key={model.id} value={model.id}>
|
||||
{model.name}
|
||||
</InputSelect.Item>
|
||||
))}
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Vertical>
|
||||
)}
|
||||
|
||||
{mode === "tts" && (
|
||||
<Vertical
|
||||
title="Voice"
|
||||
subDescription={
|
||||
<>
|
||||
This voice will be used for spoken responses. See full list
|
||||
of supported languages and voices at{" "}
|
||||
<a
|
||||
href={
|
||||
PROVIDER_VOICE_DOCS_URLS[providerType]?.url ??
|
||||
PROVIDER_DOCS_URLS[providerType]
|
||||
}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="underline"
|
||||
>
|
||||
{PROVIDER_VOICE_DOCS_URLS[providerType]?.label ?? label}
|
||||
</a>
|
||||
.
|
||||
</>
|
||||
}
|
||||
nonInteractive
|
||||
>
|
||||
<InputComboBox
|
||||
value={defaultVoice}
|
||||
onValueChange={setDefaultVoice}
|
||||
options={voiceOptions}
|
||||
placeholder={
|
||||
isLoadingVoices
|
||||
? "Loading voices..."
|
||||
: "Select a voice or enter voice ID"
|
||||
}
|
||||
disabled={isLoadingVoices}
|
||||
strict={false}
|
||||
/>
|
||||
</Vertical>
|
||||
)}
|
||||
</Section>
|
||||
</Modal.Body>
|
||||
<Modal.Footer>
|
||||
<Button secondary onClick={onClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Disabled disabled={!canConnect || isProcessing}>
|
||||
<Button
|
||||
onClick={handleSubmit}
|
||||
disabled={!canConnect || isProcessing}
|
||||
>
|
||||
{isProcessing ? "Connecting..." : isEditing ? "Save" : "Connect"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
</Modal.Footer>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
@@ -1,630 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import Image from "next/image";
|
||||
import { useMemo, useState } from "react";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import {
|
||||
AzureIcon,
|
||||
ElevenLabsIcon,
|
||||
InfoIcon,
|
||||
OpenAIIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { FetchError } from "@/lib/fetcher";
|
||||
import {
|
||||
useVoiceProviders,
|
||||
VoiceProviderView,
|
||||
} from "@/hooks/useVoiceProviders";
|
||||
import {
|
||||
activateVoiceProvider,
|
||||
deactivateVoiceProvider,
|
||||
} from "@/lib/admin/voice/svc";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { cn } from "@/lib/utils";
|
||||
import {
|
||||
SvgArrowExchange,
|
||||
SvgArrowRightCircle,
|
||||
SvgAudio,
|
||||
SvgCheckSquare,
|
||||
SvgEdit,
|
||||
SvgMicrophone,
|
||||
SvgX,
|
||||
} from "@opal/icons";
|
||||
import VoiceProviderSetupModal from "./VoiceProviderSetupModal";
|
||||
|
||||
interface ModelDetails {
|
||||
id: string;
|
||||
label: string;
|
||||
subtitle: string;
|
||||
logoSrc?: string;
|
||||
providerType: string;
|
||||
}
|
||||
|
||||
interface ProviderGroup {
|
||||
providerType: string;
|
||||
providerLabel: string;
|
||||
logoSrc?: string;
|
||||
models: ModelDetails[];
|
||||
}
|
||||
|
||||
// STT Models - individual cards
|
||||
const STT_MODELS: ModelDetails[] = [
|
||||
{
|
||||
id: "whisper",
|
||||
label: "Whisper",
|
||||
subtitle: "OpenAI's general purpose speech recognition model.",
|
||||
logoSrc: "/Openai.svg",
|
||||
providerType: "openai",
|
||||
},
|
||||
{
|
||||
id: "azure-speech-stt",
|
||||
label: "Azure Speech",
|
||||
subtitle: "Speech to text in Microsoft Foundry Tools.",
|
||||
logoSrc: "/Azure.png",
|
||||
providerType: "azure",
|
||||
},
|
||||
{
|
||||
id: "elevenlabs-stt",
|
||||
label: "ElevenAPI",
|
||||
subtitle: "ElevenLabs Speech to Text API.",
|
||||
logoSrc: "/ElevenLabs.svg",
|
||||
providerType: "elevenlabs",
|
||||
},
|
||||
];
|
||||
|
||||
// TTS Models - grouped by provider
|
||||
const TTS_PROVIDER_GROUPS: ProviderGroup[] = [
|
||||
{
|
||||
providerType: "openai",
|
||||
providerLabel: "OpenAI",
|
||||
logoSrc: "/Openai.svg",
|
||||
models: [
|
||||
{
|
||||
id: "tts-1",
|
||||
label: "TTS-1",
|
||||
subtitle: "OpenAI's text-to-speech model optimized for speed.",
|
||||
logoSrc: "/Openai.svg",
|
||||
providerType: "openai",
|
||||
},
|
||||
{
|
||||
id: "tts-1-hd",
|
||||
label: "TTS-1 HD",
|
||||
subtitle: "OpenAI's text-to-speech model optimized for quality.",
|
||||
logoSrc: "/Openai.svg",
|
||||
providerType: "openai",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
providerType: "azure",
|
||||
providerLabel: "Azure",
|
||||
logoSrc: "/Azure.png",
|
||||
models: [
|
||||
{
|
||||
id: "azure-speech-tts",
|
||||
label: "Azure Speech",
|
||||
subtitle: "Text to speech in Microsoft Foundry Tools.",
|
||||
logoSrc: "/Azure.png",
|
||||
providerType: "azure",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
providerType: "elevenlabs",
|
||||
providerLabel: "ElevenLabs",
|
||||
logoSrc: "/ElevenLabs.svg",
|
||||
models: [
|
||||
{
|
||||
id: "elevenlabs-tts",
|
||||
label: "ElevenAPI",
|
||||
subtitle: "ElevenLabs Text to Speech API.",
|
||||
logoSrc: "/ElevenLabs.svg",
|
||||
providerType: "elevenlabs",
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
interface HoverIconButtonProps extends React.ComponentProps<typeof Button> {
|
||||
isHovered: boolean;
|
||||
onMouseEnter: () => void;
|
||||
onMouseLeave: () => void;
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
function HoverIconButton({
|
||||
isHovered,
|
||||
onMouseEnter,
|
||||
onMouseLeave,
|
||||
children,
|
||||
...buttonProps
|
||||
}: HoverIconButtonProps) {
|
||||
return (
|
||||
<div onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
|
||||
<Button {...buttonProps} rightIcon={isHovered ? SvgX : SvgCheckSquare}>
|
||||
{children}
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
type ProviderMode = "stt" | "tts";
|
||||
|
||||
export default function VoiceConfigurationPage() {
|
||||
const [modalOpen, setModalOpen] = useState(false);
|
||||
const [selectedProvider, setSelectedProvider] = useState<string | null>(null);
|
||||
const [editingProvider, setEditingProvider] =
|
||||
useState<VoiceProviderView | null>(null);
|
||||
const [modalMode, setModalMode] = useState<ProviderMode>("stt");
|
||||
const [selectedModelId, setSelectedModelId] = useState<string | null>(null);
|
||||
const [sttActivationError, setSTTActivationError] = useState<string | null>(
|
||||
null
|
||||
);
|
||||
const [ttsActivationError, setTTSActivationError] = useState<string | null>(
|
||||
null
|
||||
);
|
||||
const [hoveredButtonKey, setHoveredButtonKey] = useState<string | null>(null);
|
||||
|
||||
const { providers, error, isLoading, refresh: mutate } = useVoiceProviders();
|
||||
|
||||
const handleConnect = (
|
||||
providerType: string,
|
||||
mode: ProviderMode,
|
||||
modelId?: string
|
||||
) => {
|
||||
setSelectedProvider(providerType);
|
||||
setEditingProvider(null);
|
||||
setModalMode(mode);
|
||||
setSelectedModelId(modelId ?? null);
|
||||
setModalOpen(true);
|
||||
setSTTActivationError(null);
|
||||
setTTSActivationError(null);
|
||||
};
|
||||
|
||||
const handleEdit = (
|
||||
provider: VoiceProviderView,
|
||||
mode: ProviderMode,
|
||||
modelId?: string
|
||||
) => {
|
||||
setSelectedProvider(provider.provider_type);
|
||||
setEditingProvider(provider);
|
||||
setModalMode(mode);
|
||||
setSelectedModelId(modelId ?? null);
|
||||
setModalOpen(true);
|
||||
};
|
||||
|
||||
const handleSetDefault = async (
|
||||
providerId: number,
|
||||
mode: ProviderMode,
|
||||
modelId?: string
|
||||
) => {
|
||||
const setError =
|
||||
mode === "stt" ? setSTTActivationError : setTTSActivationError;
|
||||
setError(null);
|
||||
try {
|
||||
const response = await activateVoiceProvider(providerId, mode, modelId);
|
||||
if (!response.ok) {
|
||||
const errorBody = await response.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: `Failed to set provider as default ${mode.toUpperCase()}.`
|
||||
);
|
||||
}
|
||||
await mutate();
|
||||
} catch (err) {
|
||||
const message =
|
||||
err instanceof Error ? err.message : "Unexpected error occurred.";
|
||||
setError(message);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDeactivate = async (providerId: number, mode: ProviderMode) => {
|
||||
const setError =
|
||||
mode === "stt" ? setSTTActivationError : setTTSActivationError;
|
||||
setError(null);
|
||||
try {
|
||||
const response = await deactivateVoiceProvider(providerId, mode);
|
||||
if (!response.ok) {
|
||||
const errorBody = await response.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
typeof errorBody?.detail === "string"
|
||||
? errorBody.detail
|
||||
: `Failed to deactivate ${mode.toUpperCase()} provider.`
|
||||
);
|
||||
}
|
||||
await mutate();
|
||||
} catch (err) {
|
||||
const message =
|
||||
err instanceof Error ? err.message : "Unexpected error occurred.";
|
||||
setError(message);
|
||||
}
|
||||
};
|
||||
|
||||
const handleModalClose = () => {
|
||||
setModalOpen(false);
|
||||
setSelectedProvider(null);
|
||||
setEditingProvider(null);
|
||||
setSelectedModelId(null);
|
||||
};
|
||||
|
||||
const handleModalSuccess = () => {
|
||||
mutate();
|
||||
handleModalClose();
|
||||
};
|
||||
|
||||
const isProviderConfigured = (provider?: VoiceProviderView): boolean => {
|
||||
return !!provider?.has_api_key;
|
||||
};
|
||||
|
||||
// Map provider types to their configured provider data
|
||||
const providersByType = useMemo(() => {
|
||||
return new Map((providers ?? []).map((p) => [p.provider_type, p] as const));
|
||||
}, [providers]);
|
||||
|
||||
const hasActiveSTTProvider =
|
||||
providers?.some((p) => p.is_default_stt) ?? false;
|
||||
const hasActiveTTSProvider =
|
||||
providers?.some((p) => p.is_default_tts) ?? false;
|
||||
|
||||
const renderLogo = ({
|
||||
logoSrc,
|
||||
providerType,
|
||||
alt,
|
||||
size = 16,
|
||||
}: {
|
||||
logoSrc?: string;
|
||||
providerType: string;
|
||||
alt: string;
|
||||
size?: number;
|
||||
}) => {
|
||||
const containerSizeClass = size === 24 ? "size-7" : "size-5";
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center justify-center px-0.5 py-0 shrink-0 overflow-clip",
|
||||
containerSizeClass
|
||||
)}
|
||||
>
|
||||
{providerType === "openai" ? (
|
||||
<OpenAIIcon size={size} />
|
||||
) : providerType === "azure" ? (
|
||||
<AzureIcon size={size} />
|
||||
) : providerType === "elevenlabs" ? (
|
||||
<ElevenLabsIcon size={size} />
|
||||
) : logoSrc ? (
|
||||
<Image
|
||||
src={logoSrc}
|
||||
alt={alt}
|
||||
width={size}
|
||||
height={size}
|
||||
className="object-contain"
|
||||
/>
|
||||
) : (
|
||||
<SvgMicrophone size={size} className="text-text-02" />
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
const renderModelCard = ({
|
||||
model,
|
||||
mode,
|
||||
}: {
|
||||
model: ModelDetails;
|
||||
mode: ProviderMode;
|
||||
}) => {
|
||||
const provider = providersByType.get(model.providerType);
|
||||
const isConfigured = isProviderConfigured(provider);
|
||||
// For TTS, also check that this specific model is the default (not just the provider)
|
||||
const isActive =
|
||||
mode === "stt"
|
||||
? provider?.is_default_stt
|
||||
: provider?.is_default_tts && provider?.tts_model === model.id;
|
||||
const isHighlighted = isActive ?? false;
|
||||
const providerId = provider?.id;
|
||||
|
||||
const buttonState = (() => {
|
||||
if (!provider || !isConfigured) {
|
||||
return {
|
||||
label: "Connect",
|
||||
disabled: false,
|
||||
icon: "arrow" as const,
|
||||
onClick: () => handleConnect(model.providerType, mode, model.id),
|
||||
};
|
||||
}
|
||||
|
||||
if (isActive) {
|
||||
return {
|
||||
label: "Current Default",
|
||||
disabled: false,
|
||||
icon: "check" as const,
|
||||
onClick: providerId
|
||||
? () => handleDeactivate(providerId, mode)
|
||||
: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
label: "Set as Default",
|
||||
disabled: false,
|
||||
icon: "arrow-circle" as const,
|
||||
onClick: providerId
|
||||
? () => handleSetDefault(providerId, mode, model.id)
|
||||
: undefined,
|
||||
};
|
||||
})();
|
||||
|
||||
const buttonKey = `${mode}-${model.id}`;
|
||||
const isButtonHovered = hoveredButtonKey === buttonKey;
|
||||
const isCardClickable =
|
||||
buttonState.icon === "arrow" &&
|
||||
typeof buttonState.onClick === "function" &&
|
||||
!buttonState.disabled;
|
||||
|
||||
const handleCardClick = () => {
|
||||
if (isCardClickable) {
|
||||
buttonState.onClick?.();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
key={`${mode}-${model.id}`}
|
||||
onClick={isCardClickable ? handleCardClick : undefined}
|
||||
className={cn(
|
||||
"flex items-start justify-between gap-4 rounded-16 border p-2 bg-background-neutral-01",
|
||||
isHighlighted ? "border-action-link-05" : "border-border-01",
|
||||
isCardClickable &&
|
||||
"cursor-pointer hover:bg-background-tint-01 transition-colors"
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-1 items-start gap-2.5 p-2">
|
||||
{renderLogo({
|
||||
logoSrc: model.logoSrc,
|
||||
providerType: model.providerType,
|
||||
alt: `${model.label} logo`,
|
||||
size: 16,
|
||||
})}
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<Text as="p" mainUiAction text04>
|
||||
{model.label}
|
||||
</Text>
|
||||
<Text as="p" secondaryBody text03>
|
||||
{model.subtitle}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-1.5 self-center">
|
||||
{isConfigured && (
|
||||
<OpalButton
|
||||
icon={SvgEdit}
|
||||
tooltip="Edit"
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
if (provider) handleEdit(provider, mode, model.id);
|
||||
}}
|
||||
aria-label={`Edit ${model.label}`}
|
||||
/>
|
||||
)}
|
||||
{buttonState.icon === "check" ? (
|
||||
<HoverIconButton
|
||||
isHovered={isButtonHovered}
|
||||
onMouseEnter={() => setHoveredButtonKey(buttonKey)}
|
||||
onMouseLeave={() => setHoveredButtonKey(null)}
|
||||
action={true}
|
||||
tertiary
|
||||
disabled={buttonState.disabled}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
>
|
||||
{buttonState.label}
|
||||
</HoverIconButton>
|
||||
) : (
|
||||
<Button
|
||||
action={false}
|
||||
tertiary
|
||||
disabled={buttonState.disabled || !buttonState.onClick}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
buttonState.onClick?.();
|
||||
}}
|
||||
rightIcon={
|
||||
buttonState.icon === "arrow"
|
||||
? SvgArrowExchange
|
||||
: buttonState.icon === "arrow-circle"
|
||||
? SvgArrowRightCircle
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{buttonState.label}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
if (error) {
|
||||
const message = error?.message || "Unable to load voice configuration.";
|
||||
const detail =
|
||||
error instanceof FetchError && typeof error.info?.detail === "string"
|
||||
? error.info.detail
|
||||
: undefined;
|
||||
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
title="Voice"
|
||||
icon={SvgMicrophone}
|
||||
includeDivider={false}
|
||||
/>
|
||||
<Callout type="danger" title="Failed to load voice settings">
|
||||
{message}
|
||||
{detail && (
|
||||
<Text as="p" className="mt-2 text-text-03" mainContentBody text03>
|
||||
{detail}
|
||||
</Text>
|
||||
)}
|
||||
</Callout>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
title="Voice"
|
||||
icon={SvgMicrophone}
|
||||
includeDivider={false}
|
||||
/>
|
||||
<div className="mt-8">
|
||||
<ThreeDotsLoader />
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle icon={SvgAudio} title="Voice" />
|
||||
<div className="pt-4 pb-4">
|
||||
<Text as="p" secondaryBody text03>
|
||||
Speech to text (STT) and text to speech (TTS) capabilities.
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<div className="flex w-full flex-col gap-8 pb-6">
|
||||
{/* Speech-to-Text Section */}
|
||||
<div className="flex w-full max-w-[960px] flex-col gap-3">
|
||||
<div className="flex flex-col">
|
||||
<Text as="p" mainContentEmphasis text04>
|
||||
Speech to Text
|
||||
</Text>
|
||||
<Text as="p" secondaryBody text03>
|
||||
Select a model to transcribe speech to text in chats.
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
{sttActivationError && (
|
||||
<Callout type="danger" title="Unable to update STT provider">
|
||||
{sttActivationError}
|
||||
</Callout>
|
||||
)}
|
||||
|
||||
{!hasActiveSTTProvider && (
|
||||
<div
|
||||
className="flex items-start rounded-16 border p-2"
|
||||
style={{
|
||||
backgroundColor: "var(--status-info-00)",
|
||||
borderColor: "var(--status-info-02)",
|
||||
}}
|
||||
>
|
||||
<div className="flex items-start gap-1 p-2">
|
||||
<div
|
||||
className="flex size-5 items-center justify-center rounded-full p-0.5"
|
||||
style={{
|
||||
backgroundColor: "var(--status-info-01)",
|
||||
}}
|
||||
>
|
||||
<div style={{ color: "var(--status-text-info-05)" }}>
|
||||
<InfoIcon size={16} />
|
||||
</div>
|
||||
</div>
|
||||
<Text as="p" className="flex-1 px-0.5" mainUiBody text04>
|
||||
Connect a speech to text provider to use in chat.
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex flex-col gap-2">
|
||||
{STT_MODELS.map((model) => renderModelCard({ model, mode: "stt" }))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Text-to-Speech Section */}
|
||||
<div className="flex w-full max-w-[960px] flex-col gap-3">
|
||||
<div className="flex flex-col">
|
||||
<Text as="p" mainContentEmphasis text04>
|
||||
Text to Speech
|
||||
</Text>
|
||||
<Text as="p" secondaryBody text03>
|
||||
Select a model to speak out chat responses.
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
{ttsActivationError && (
|
||||
<Callout type="danger" title="Unable to update TTS provider">
|
||||
{ttsActivationError}
|
||||
</Callout>
|
||||
)}
|
||||
|
||||
{!hasActiveTTSProvider && (
|
||||
<div
|
||||
className="flex items-start rounded-16 border p-2"
|
||||
style={{
|
||||
backgroundColor: "var(--status-info-00)",
|
||||
borderColor: "var(--status-info-02)",
|
||||
}}
|
||||
>
|
||||
<div className="flex items-start gap-1 p-2">
|
||||
<div
|
||||
className="flex size-5 items-center justify-center rounded-full p-0.5"
|
||||
style={{
|
||||
backgroundColor: "var(--status-info-01)",
|
||||
}}
|
||||
>
|
||||
<div style={{ color: "var(--status-text-info-05)" }}>
|
||||
<InfoIcon size={16} />
|
||||
</div>
|
||||
</div>
|
||||
<Text as="p" className="flex-1 px-0.5" mainUiBody text04>
|
||||
Connect a text to speech provider to use in chat.
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex flex-col gap-4">
|
||||
{TTS_PROVIDER_GROUPS.map((group) => (
|
||||
<div key={group.providerType} className="flex flex-col gap-2">
|
||||
<Text as="p" secondaryBody text03 className="px-0.5">
|
||||
{group.providerLabel}
|
||||
</Text>
|
||||
<div className="flex flex-col gap-2">
|
||||
{group.models.map((model) =>
|
||||
renderModelCard({ model, mode: "tts" })
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{modalOpen && selectedProvider && (
|
||||
<VoiceProviderSetupModal
|
||||
providerType={selectedProvider}
|
||||
existingProvider={editingProvider}
|
||||
mode={modalMode}
|
||||
defaultModelId={selectedModelId}
|
||||
onClose={handleModalClose}
|
||||
onSuccess={handleModalSuccess}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -3,7 +3,6 @@ import type { Route } from "next";
|
||||
import { unstable_noStore as noStore } from "next/cache";
|
||||
import { requireAuth } from "@/lib/auth/requireAuth";
|
||||
import { ProjectsProvider } from "@/providers/ProjectsContext";
|
||||
import { VoiceModeProvider } from "@/providers/VoiceModeProvider";
|
||||
import AppSidebar from "@/sections/sidebar/AppSidebar";
|
||||
|
||||
export interface LayoutProps {
|
||||
@@ -22,15 +21,10 @@ export default async function Layout({ children }: LayoutProps) {
|
||||
|
||||
return (
|
||||
<ProjectsProvider>
|
||||
{/* VoiceModeProvider wraps the full app layout so TTS playback state
|
||||
persists across page navigations (e.g., sidebar clicks during playback).
|
||||
It only activates WebSocket connections when TTS is actually triggered. */}
|
||||
<VoiceModeProvider>
|
||||
<div className="flex flex-row w-full h-full">
|
||||
<AppSidebar />
|
||||
{children}
|
||||
</div>
|
||||
</VoiceModeProvider>
|
||||
<div className="flex flex-row w-full h-full">
|
||||
<AppSidebar />
|
||||
{children}
|
||||
</div>
|
||||
</ProjectsProvider>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import React, {
|
||||
useRef,
|
||||
RefObject,
|
||||
useMemo,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
} from "react";
|
||||
import React, { useRef, RefObject, useMemo } from "react";
|
||||
import { Packet, StopReason } from "@/app/app/services/streamingModels";
|
||||
import CustomToolAuthCard from "@/app/app/message/messageComponents/CustomToolAuthCard";
|
||||
import { FullChatState } from "@/app/app/message/messageComponents/interfaces";
|
||||
@@ -22,9 +16,6 @@ import { LlmDescriptor, LlmManager } from "@/lib/hooks";
|
||||
import { Message } from "@/app/app/interfaces";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { AgentTimeline } from "@/app/app/message/messageComponents/timeline/AgentTimeline";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import { getTextContent } from "@/app/app/services/packetUtils";
|
||||
import { removeThinkingTokens } from "@/app/app/services/thinkingTokens";
|
||||
|
||||
// Type for the regeneration factory function passed from ChatUI
|
||||
export type RegenerationFactory = (regenerationRequest: {
|
||||
@@ -84,7 +75,6 @@ function arePropsEqual(
|
||||
|
||||
const AgentMessage = React.memo(function AgentMessage({
|
||||
rawPackets,
|
||||
packetCount,
|
||||
chatState,
|
||||
nodeId,
|
||||
messageId,
|
||||
@@ -172,80 +162,6 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
onMessageSelection,
|
||||
});
|
||||
|
||||
// Streaming TTS integration
|
||||
const { streamTTS, resetTTS, stopTTS } = useVoiceMode();
|
||||
const ttsCompletedRef = useRef(false);
|
||||
const hasStreamedIncompleteRef = useRef(false);
|
||||
const hasObservedPacketGrowthRef = useRef(false);
|
||||
const lastSeenPacketCountRef = useRef(packetCount ?? rawPackets.length);
|
||||
const streamTTSRef = useRef(streamTTS);
|
||||
|
||||
// Keep streamTTS ref in sync without triggering effect re-runs
|
||||
useEffect(() => {
|
||||
streamTTSRef.current = streamTTS;
|
||||
}, [streamTTS]);
|
||||
|
||||
// Stream TTS as text content arrives - only for messages still streaming
|
||||
// Uses ref for streamTTS to avoid re-triggering when its identity changes
|
||||
// Note: packetCount is used instead of rawPackets because the array is mutated in place
|
||||
useLayoutEffect(() => {
|
||||
const effectivePacketCount = packetCount ?? rawPackets.length;
|
||||
if (effectivePacketCount > lastSeenPacketCountRef.current) {
|
||||
hasObservedPacketGrowthRef.current = true;
|
||||
}
|
||||
lastSeenPacketCountRef.current = effectivePacketCount;
|
||||
|
||||
// Skip if we've already finished TTS for this message
|
||||
if (ttsCompletedRef.current) return;
|
||||
|
||||
// If user cancelled generation, do not send more text to TTS.
|
||||
if (stopPacketSeen && stopReason === StopReason.USER_CANCELLED) {
|
||||
ttsCompletedRef.current = true;
|
||||
return;
|
||||
}
|
||||
|
||||
const textContent = removeThinkingTokens(getTextContent(rawPackets));
|
||||
if (!(typeof textContent === "string" && textContent.length > 0)) return;
|
||||
|
||||
// Only autoplay messages that were observed streaming in this lifecycle.
|
||||
// Prevents historical, already-complete chats from re-triggering read-aloud on mount.
|
||||
if (!isComplete) {
|
||||
if (!hasObservedPacketGrowthRef.current) {
|
||||
return;
|
||||
}
|
||||
hasStreamedIncompleteRef.current = true;
|
||||
streamTTSRef.current(textContent, false, nodeId);
|
||||
return;
|
||||
}
|
||||
|
||||
if (hasStreamedIncompleteRef.current) {
|
||||
streamTTSRef.current(textContent, true, nodeId);
|
||||
ttsCompletedRef.current = true;
|
||||
}
|
||||
}, [packetCount, isComplete, rawPackets, nodeId, stopPacketSeen, stopReason]); // packetCount triggers on new packets since rawPackets is mutated in place
|
||||
|
||||
// Stop TTS immediately when user cancels generation.
|
||||
useEffect(() => {
|
||||
if (stopPacketSeen && stopReason === StopReason.USER_CANCELLED) {
|
||||
stopTTS({ manual: true });
|
||||
}
|
||||
}, [stopPacketSeen, stopReason, stopTTS]);
|
||||
|
||||
// Reset TTS completed flag when nodeId changes (new message)
|
||||
useEffect(() => {
|
||||
ttsCompletedRef.current = false;
|
||||
hasStreamedIncompleteRef.current = false;
|
||||
hasObservedPacketGrowthRef.current = false;
|
||||
lastSeenPacketCountRef.current = packetCount ?? rawPackets.length;
|
||||
}, [nodeId]);
|
||||
|
||||
// Reset TTS when component unmounts or nodeId changes
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
resetTTS();
|
||||
};
|
||||
}, [nodeId, resetTTS]);
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex flex-col gap-3"
|
||||
@@ -292,8 +208,6 @@ const AgentMessage = React.memo(function AgentMessage({
|
||||
key={`${displayGroup.turn_index}-${displayGroup.tab_index}`}
|
||||
packets={displayGroup.packets}
|
||||
chatState={effectiveChatState}
|
||||
messageNodeId={nodeId}
|
||||
hasTimelineThinking={pacedTurnGroups.length > 0 || hasSteps}
|
||||
onComplete={() => {
|
||||
// Only mark complete on the last display group
|
||||
// Hook handles the finalAnswerComing check internally
|
||||
|
||||
@@ -29,9 +29,6 @@ import FeedbackModal, {
|
||||
FeedbackModalProps,
|
||||
} from "@/sections/modals/FeedbackModal";
|
||||
import { Button, SelectButton } from "@opal/components";
|
||||
import TTSButton from "./TTSButton";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import { useVoiceStatus } from "@/hooks/useVoiceStatus";
|
||||
|
||||
// Wrapper component for SourceTag in toolbar to handle memoization
|
||||
const SourcesTagWrapper = React.memo(function SourcesTagWrapper({
|
||||
@@ -147,14 +144,6 @@ export default function MessageToolbar({
|
||||
(state) => state.updateCurrentSelectedNodeForDocDisplay
|
||||
);
|
||||
|
||||
// Voice mode - hide toolbar during TTS playback for this message
|
||||
const { isTTSPlaying, activeMessageNodeId, isAwaitingAutoPlaybackStart } =
|
||||
useVoiceMode();
|
||||
const { ttsEnabled } = useVoiceStatus();
|
||||
const isTTSActiveForThisMessage =
|
||||
(isTTSPlaying || isAwaitingAutoPlaybackStart) &&
|
||||
activeMessageNodeId === nodeId;
|
||||
|
||||
// Feedback modal state and handlers
|
||||
const { handleFeedbackChange } = useFeedbackController();
|
||||
const modal = useCreateModal();
|
||||
@@ -215,11 +204,6 @@ export default function MessageToolbar({
|
||||
[messageId, currentFeedback, handleFeedbackChange, modal]
|
||||
);
|
||||
|
||||
// Hide toolbar while TTS is playing for this message
|
||||
if (isTTSActiveForThisMessage) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<modal.Provider>
|
||||
@@ -284,13 +268,6 @@ export default function MessageToolbar({
|
||||
}
|
||||
data-testid="AgentMessage/dislike-button"
|
||||
/>
|
||||
{ttsEnabled && (
|
||||
<TTSButton
|
||||
text={
|
||||
removeThinkingTokens(getTextContent(rawPackets)) as string
|
||||
}
|
||||
/>
|
||||
)}
|
||||
|
||||
{onRegenerate &&
|
||||
messageId !== undefined &&
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect } from "react";
|
||||
import { SvgPlayCircle, SvgStop } from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
import { useVoicePlayback } from "@/hooks/useVoicePlayback";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
|
||||
interface TTSButtonProps {
|
||||
text: string;
|
||||
voice?: string;
|
||||
speed?: number;
|
||||
}
|
||||
|
||||
function TTSButton({ text, voice, speed }: TTSButtonProps) {
|
||||
const { isPlaying, isLoading, error, play, pause, stop } = useVoicePlayback();
|
||||
const { isTTSPlaying, isTTSLoading, isAwaitingAutoPlaybackStart, stopTTS } =
|
||||
useVoiceMode();
|
||||
|
||||
const isGlobalTTSActive =
|
||||
isTTSPlaying || isTTSLoading || isAwaitingAutoPlaybackStart;
|
||||
const isButtonPlaying = isGlobalTTSActive || isPlaying;
|
||||
const isButtonLoading = !isGlobalTTSActive && isLoading;
|
||||
|
||||
const handleClick = useCallback(async () => {
|
||||
if (isGlobalTTSActive) {
|
||||
// Stop auto-playback voice mode stream from the toolbar button.
|
||||
stopTTS({ manual: true });
|
||||
stop();
|
||||
} else if (isPlaying) {
|
||||
pause();
|
||||
} else if (isButtonLoading) {
|
||||
stop();
|
||||
} else {
|
||||
try {
|
||||
// Ensure no voice-mode stream is active before starting manual playback.
|
||||
stopTTS();
|
||||
await play(text, voice, speed);
|
||||
} catch (err) {
|
||||
console.error("TTS playback failed:", err);
|
||||
toast.error("Could not play audio");
|
||||
}
|
||||
}
|
||||
}, [
|
||||
isGlobalTTSActive,
|
||||
isPlaying,
|
||||
isButtonLoading,
|
||||
text,
|
||||
voice,
|
||||
speed,
|
||||
play,
|
||||
pause,
|
||||
stop,
|
||||
stopTTS,
|
||||
]);
|
||||
|
||||
// Surface streaming voice playback errors to the user via toast
|
||||
useEffect(() => {
|
||||
if (error) {
|
||||
console.error("Voice playback error:", error);
|
||||
toast.error(error);
|
||||
}
|
||||
}, [error]);
|
||||
|
||||
const icon = isButtonLoading
|
||||
? SimpleLoader
|
||||
: isButtonPlaying
|
||||
? SvgStop
|
||||
: SvgPlayCircle;
|
||||
|
||||
const tooltip = isButtonPlaying
|
||||
? "Stop playback"
|
||||
: isButtonLoading
|
||||
? "Loading..."
|
||||
: "Read aloud";
|
||||
|
||||
return (
|
||||
<Button
|
||||
icon={icon}
|
||||
onClick={handleClick}
|
||||
prominence="tertiary"
|
||||
tooltip={tooltip}
|
||||
data-testid="AgentMessage/tts-button"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export default TTSButton;
|
||||
@@ -67,10 +67,6 @@ export type MessageRenderer<
|
||||
> = React.ComponentType<{
|
||||
packets: T[];
|
||||
state: S;
|
||||
/** Node id for the message currently being rendered */
|
||||
messageNodeId?: number;
|
||||
/** True when timeline/thinking UI is already shown above this text block */
|
||||
hasTimelineThinking?: boolean;
|
||||
onComplete: () => void;
|
||||
renderType: RenderType;
|
||||
animate: boolean;
|
||||
|
||||
@@ -166,8 +166,6 @@ function MixedContentHandler({
|
||||
chatPackets,
|
||||
imagePackets,
|
||||
chatState,
|
||||
messageNodeId,
|
||||
hasTimelineThinking,
|
||||
onComplete,
|
||||
animate,
|
||||
stopPacketSeen,
|
||||
@@ -177,8 +175,6 @@ function MixedContentHandler({
|
||||
chatPackets: Packet[];
|
||||
imagePackets: Packet[];
|
||||
chatState: FullChatState;
|
||||
messageNodeId?: number;
|
||||
hasTimelineThinking?: boolean;
|
||||
onComplete: () => void;
|
||||
animate: boolean;
|
||||
stopPacketSeen: boolean;
|
||||
@@ -189,8 +185,6 @@ function MixedContentHandler({
|
||||
<MessageTextRenderer
|
||||
packets={chatPackets as ChatPacket[]}
|
||||
state={chatState}
|
||||
messageNodeId={messageNodeId}
|
||||
hasTimelineThinking={hasTimelineThinking}
|
||||
onComplete={() => {}}
|
||||
animate={animate}
|
||||
renderType={RenderType.FULL}
|
||||
@@ -218,8 +212,6 @@ function MixedContentHandler({
|
||||
interface RendererComponentProps {
|
||||
packets: Packet[];
|
||||
chatState: FullChatState;
|
||||
messageNodeId?: number;
|
||||
hasTimelineThinking?: boolean;
|
||||
onComplete: () => void;
|
||||
animate: boolean;
|
||||
stopPacketSeen: boolean;
|
||||
@@ -237,8 +229,7 @@ function areRendererPropsEqual(
|
||||
prev.stopPacketSeen === next.stopPacketSeen &&
|
||||
prev.stopReason === next.stopReason &&
|
||||
prev.animate === next.animate &&
|
||||
prev.chatState.agent?.id === next.chatState.agent?.id &&
|
||||
prev.messageNodeId === next.messageNodeId
|
||||
prev.chatState.agent?.id === next.chatState.agent?.id
|
||||
// Skip: onComplete, children (function refs), chatState (memoized upstream)
|
||||
);
|
||||
}
|
||||
@@ -247,8 +238,6 @@ function areRendererPropsEqual(
|
||||
export const RendererComponent = memo(function RendererComponent({
|
||||
packets,
|
||||
chatState,
|
||||
messageNodeId,
|
||||
hasTimelineThinking,
|
||||
onComplete,
|
||||
animate,
|
||||
stopPacketSeen,
|
||||
@@ -283,8 +272,6 @@ export const RendererComponent = memo(function RendererComponent({
|
||||
chatPackets={chatPackets}
|
||||
imagePackets={imagePackets}
|
||||
chatState={chatState}
|
||||
messageNodeId={messageNodeId}
|
||||
hasTimelineThinking={hasTimelineThinking}
|
||||
onComplete={onComplete}
|
||||
animate={animate}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
@@ -305,8 +292,6 @@ export const RendererComponent = memo(function RendererComponent({
|
||||
<RendererFn
|
||||
packets={packets as any}
|
||||
state={chatState}
|
||||
messageNodeId={messageNodeId}
|
||||
hasTimelineThinking={hasTimelineThinking}
|
||||
onComplete={onComplete}
|
||||
animate={animate}
|
||||
renderType={RenderType.FULL}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useEffect, useMemo, useRef, useState } from "react";
|
||||
import React, { useEffect, useMemo, useState } from "react";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
import {
|
||||
@@ -10,55 +10,6 @@ import { MessageRenderer, FullChatState } from "../interfaces";
|
||||
import { isFinalAnswerComplete } from "../../../services/packetUtils";
|
||||
import { useMarkdownRenderer } from "../markdownUtils";
|
||||
import { BlinkingBar } from "../../BlinkingBar";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
|
||||
/**
|
||||
* Maps a cleaned character position to the corresponding position in markdown text.
|
||||
* This allows progressive reveal to work with markdown formatting.
|
||||
*/
|
||||
function getRevealPosition(markdown: string, cleanChars: number): number {
|
||||
// Skip patterns that don't contribute to visible character count
|
||||
const skipChars = new Set(["*", "`", "#"]);
|
||||
let cleanIndex = 0;
|
||||
let mdIndex = 0;
|
||||
|
||||
while (cleanIndex < cleanChars && mdIndex < markdown.length) {
|
||||
const char = markdown[mdIndex];
|
||||
|
||||
// Skip markdown formatting characters
|
||||
if (char !== undefined && skipChars.has(char)) {
|
||||
mdIndex++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle link syntax [text](url) - skip the (url) part but count the text
|
||||
if (
|
||||
char === "]" &&
|
||||
mdIndex + 1 < markdown.length &&
|
||||
markdown[mdIndex + 1] === "("
|
||||
) {
|
||||
const closeIdx = markdown.indexOf(")", mdIndex + 2);
|
||||
if (closeIdx > 0) {
|
||||
mdIndex = closeIdx + 1;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
cleanIndex++;
|
||||
mdIndex++;
|
||||
}
|
||||
|
||||
// Extend to word boundary to avoid cutting mid-word
|
||||
while (
|
||||
mdIndex < markdown.length &&
|
||||
markdown[mdIndex] !== " " &&
|
||||
markdown[mdIndex] !== "\n"
|
||||
) {
|
||||
mdIndex++;
|
||||
}
|
||||
|
||||
return mdIndex;
|
||||
}
|
||||
|
||||
// Control the rate of packet streaming (packets per second)
|
||||
const PACKET_DELAY_MS = 10;
|
||||
@@ -69,8 +20,6 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
> = ({
|
||||
packets,
|
||||
state,
|
||||
messageNodeId,
|
||||
hasTimelineThinking,
|
||||
onComplete,
|
||||
renderType,
|
||||
animate,
|
||||
@@ -87,17 +36,6 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
|
||||
const [displayedPacketCount, setDisplayedPacketCount] =
|
||||
useState(initialPacketCount);
|
||||
const lastStableSyncedContentRef = useRef("");
|
||||
const lastVisibleContentRef = useRef("");
|
||||
|
||||
// Get voice mode context for progressive text reveal synced with audio
|
||||
const {
|
||||
revealedCharCount,
|
||||
autoPlayback,
|
||||
isAudioSyncActive,
|
||||
activeMessageNodeId,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
} = useVoiceMode();
|
||||
|
||||
// Get the full content from all packets
|
||||
const fullContent = packets
|
||||
@@ -112,11 +50,6 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
})
|
||||
.join("");
|
||||
|
||||
const shouldUseAutoPlaybackSync =
|
||||
autoPlayback &&
|
||||
typeof messageNodeId === "number" &&
|
||||
activeMessageNodeId === messageNodeId;
|
||||
|
||||
// Animation effect - gradually increase displayed packets at controlled rate
|
||||
useEffect(() => {
|
||||
if (!animate) {
|
||||
@@ -160,37 +93,13 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
}
|
||||
}, [packets, onComplete, animate, displayedPacketCount]);
|
||||
|
||||
// Get content based on displayed packet count or audio progress
|
||||
const computedContent = useMemo(() => {
|
||||
// Hold response in "thinking" state only while autoplay startup is pending.
|
||||
if (shouldUseAutoPlaybackSync && isAwaitingAutoPlaybackStart) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Sync text with audio only for the message currently being spoken.
|
||||
if (shouldUseAutoPlaybackSync && isAudioSyncActive) {
|
||||
const MIN_REVEAL_CHARS = 12;
|
||||
if (revealedCharCount < MIN_REVEAL_CHARS) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Reveal text progressively based on audio progress
|
||||
const revealPos = getRevealPosition(fullContent, revealedCharCount);
|
||||
return fullContent.slice(0, Math.max(revealPos, 0));
|
||||
}
|
||||
|
||||
// During an active synced turn, if sync temporarily drops, keep current reveal
|
||||
// instead of jumping to full content or blanking.
|
||||
if (shouldUseAutoPlaybackSync && !stopPacketSeen) {
|
||||
return lastStableSyncedContentRef.current;
|
||||
}
|
||||
|
||||
// Standard behavior when auto-playback is off
|
||||
// Get content based on displayed packet count
|
||||
const content = useMemo(() => {
|
||||
if (!animate || displayedPacketCount === -1) {
|
||||
return fullContent; // Show all content
|
||||
}
|
||||
|
||||
// Packet-based reveal (when auto-playback is disabled)
|
||||
// Only show content from packets up to displayedPacketCount
|
||||
return packets
|
||||
.slice(0, displayedPacketCount)
|
||||
.map((packet) => {
|
||||
@@ -203,109 +112,31 @@ export const MessageTextRenderer: MessageRenderer<
|
||||
return "";
|
||||
})
|
||||
.join("");
|
||||
}, [
|
||||
animate,
|
||||
displayedPacketCount,
|
||||
fullContent,
|
||||
packets,
|
||||
revealedCharCount,
|
||||
autoPlayback,
|
||||
isAudioSyncActive,
|
||||
activeMessageNodeId,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
messageNodeId,
|
||||
shouldUseAutoPlaybackSync,
|
||||
stopPacketSeen,
|
||||
]);
|
||||
|
||||
// Keep synced text monotonic: once visible, never regress or disappear between chunks.
|
||||
const content = useMemo(() => {
|
||||
const wasUserCancelled = stopReason === StopReason.USER_CANCELLED;
|
||||
|
||||
// On user cancel, freeze at exactly what was already visible.
|
||||
if (wasUserCancelled) {
|
||||
return lastVisibleContentRef.current;
|
||||
}
|
||||
|
||||
if (!shouldUseAutoPlaybackSync) {
|
||||
return computedContent;
|
||||
}
|
||||
|
||||
if (computedContent.length === 0) {
|
||||
return lastStableSyncedContentRef.current;
|
||||
}
|
||||
|
||||
const last = lastStableSyncedContentRef.current;
|
||||
if (computedContent.startsWith(last)) {
|
||||
return computedContent;
|
||||
}
|
||||
|
||||
// If content shape changed unexpectedly mid-stream, prefer the stable version
|
||||
// to avoid flicker/dumps.
|
||||
if (!stopPacketSeen || wasUserCancelled) {
|
||||
return last;
|
||||
}
|
||||
|
||||
// For normal completed responses, allow final full content.
|
||||
return computedContent;
|
||||
}, [computedContent, shouldUseAutoPlaybackSync, stopPacketSeen, stopReason]);
|
||||
|
||||
// Sync the stable ref outside of useMemo to avoid side effects during render.
|
||||
useEffect(() => {
|
||||
if (stopReason === StopReason.USER_CANCELLED) {
|
||||
return;
|
||||
}
|
||||
if (!shouldUseAutoPlaybackSync) {
|
||||
lastStableSyncedContentRef.current = "";
|
||||
} else if (content.length > 0) {
|
||||
lastStableSyncedContentRef.current = content;
|
||||
}
|
||||
}, [content, shouldUseAutoPlaybackSync, stopReason]);
|
||||
|
||||
// Track last actually rendered content so cancel can freeze without dumping buffered text.
|
||||
useEffect(() => {
|
||||
if (content.length > 0) {
|
||||
lastVisibleContentRef.current = content;
|
||||
}
|
||||
}, [content]);
|
||||
|
||||
const shouldShowThinkingPlaceholder =
|
||||
shouldUseAutoPlaybackSync &&
|
||||
isAwaitingAutoPlaybackStart &&
|
||||
!hasTimelineThinking &&
|
||||
!stopPacketSeen;
|
||||
|
||||
const shouldShowSpeechWarmupIndicator =
|
||||
shouldUseAutoPlaybackSync &&
|
||||
!isAwaitingAutoPlaybackStart &&
|
||||
content.length === 0 &&
|
||||
fullContent.length > 0 &&
|
||||
!hasTimelineThinking &&
|
||||
!stopPacketSeen;
|
||||
|
||||
const shouldShowCursor =
|
||||
content.length > 0 &&
|
||||
(!stopPacketSeen ||
|
||||
(shouldUseAutoPlaybackSync && content.length < fullContent.length));
|
||||
}, [animate, displayedPacketCount, fullContent, packets]);
|
||||
|
||||
const { renderedContent } = useMarkdownRenderer(
|
||||
// the [*]() is a hack to show a blinking dot when the packet is not complete
|
||||
shouldShowCursor ? content + " [*]() " : content,
|
||||
stopPacketSeen ? content : content + " [*]() ",
|
||||
state,
|
||||
"font-main-content-body"
|
||||
);
|
||||
|
||||
const wasUserCancelled = stopReason === StopReason.USER_CANCELLED;
|
||||
|
||||
return children([
|
||||
{
|
||||
icon: null,
|
||||
status: null,
|
||||
content:
|
||||
shouldShowThinkingPlaceholder || shouldShowSpeechWarmupIndicator ? (
|
||||
<Text as="span" secondaryBody text04 className="italic">
|
||||
Thinking
|
||||
</Text>
|
||||
) : content.length > 0 ? (
|
||||
<>{renderedContent}</>
|
||||
content.length > 0 || packets.length > 0 ? (
|
||||
<>
|
||||
{renderedContent}
|
||||
{wasUserCancelled && (
|
||||
<Text as="p" secondaryBody text04>
|
||||
User has stopped generation
|
||||
</Text>
|
||||
)}
|
||||
</>
|
||||
) : (
|
||||
<BlinkingBar addMargin />
|
||||
),
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import { useCallback, useState, useEffect, useRef, useMemo } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { track, AnalyticsEvent } from "@/lib/analytics";
|
||||
import { usePostHog } from "posthog-js/react";
|
||||
import {
|
||||
useSession,
|
||||
useSessionId,
|
||||
@@ -61,6 +61,7 @@ export default function BuildChatPanel({
|
||||
existingSessionId,
|
||||
}: BuildChatPanelProps) {
|
||||
const router = useRouter();
|
||||
const posthog = usePostHog();
|
||||
const outputPanelOpen = useOutputPanelOpen();
|
||||
const session = useSession();
|
||||
const sessionId = useSessionId();
|
||||
@@ -253,7 +254,7 @@ export default function BuildChatPanel({
|
||||
return;
|
||||
}
|
||||
|
||||
track(AnalyticsEvent.SENT_CRAFT_MESSAGE);
|
||||
posthog?.capture("sent_craft_message");
|
||||
|
||||
if (hasSession && sessionId) {
|
||||
// Existing session flow
|
||||
@@ -366,6 +367,7 @@ export default function BuildChatPanel({
|
||||
hasUploadingFiles,
|
||||
limits,
|
||||
refreshLimits,
|
||||
posthog,
|
||||
]
|
||||
);
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import { useEffect } from "react";
|
||||
import { motion } from "motion/react";
|
||||
import { track, AnalyticsEvent } from "@/lib/analytics";
|
||||
import { usePostHog } from "posthog-js/react";
|
||||
import { OnyxLogoTypeIcon } from "@/components/icons/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import BigButton from "@/app/craft/components/BigButton";
|
||||
@@ -16,10 +16,12 @@ export default function BuildModeIntroContent({
|
||||
onClose,
|
||||
onTryBuildMode,
|
||||
}: BuildModeIntroContentProps) {
|
||||
const posthog = usePostHog();
|
||||
|
||||
// Track when user sees the craft intro
|
||||
useEffect(() => {
|
||||
track(AnalyticsEvent.SAW_CRAFT_INTRO);
|
||||
}, []);
|
||||
posthog?.capture("saw_craft_intro");
|
||||
}, [posthog]);
|
||||
|
||||
return (
|
||||
<div className="absolute inset-0 flex flex-col items-center justify-center pointer-events-none">
|
||||
@@ -73,7 +75,7 @@ export default function BuildModeIntroContent({
|
||||
className="!border-white !text-white hover:!bg-white/10 active:!bg-white/20 !w-[160px]"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
track(AnalyticsEvent.CLICKED_GO_HOME);
|
||||
posthog?.capture("clicked_go_home");
|
||||
onClose();
|
||||
}}
|
||||
>
|
||||
@@ -84,7 +86,7 @@ export default function BuildModeIntroContent({
|
||||
className="!bg-white !text-black hover:!bg-gray-200 active:!bg-gray-300 !w-[160px]"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
track(AnalyticsEvent.CLICKED_TRY_CRAFT);
|
||||
posthog?.capture("clicked_try_craft");
|
||||
onTryBuildMode();
|
||||
}}
|
||||
>
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect, useMemo } from "react";
|
||||
import {
|
||||
track,
|
||||
AnalyticsEvent,
|
||||
LLMProviderConfiguredSource,
|
||||
} from "@/lib/analytics";
|
||||
import { usePostHog } from "posthog-js/react";
|
||||
import { SvgArrowRight, SvgArrowLeft, SvgX } from "@opal/icons";
|
||||
import { cn } from "@/lib/utils";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
@@ -116,6 +112,8 @@ export default function BuildOnboardingModal({
|
||||
onLlmComplete,
|
||||
onClose,
|
||||
}: BuildOnboardingModalProps) {
|
||||
const posthog = usePostHog();
|
||||
|
||||
// Compute steps based on mode
|
||||
const steps = useMemo(
|
||||
() => getStepsForMode(mode, isAdmin, allProvidersConfigured, hasUserInfo),
|
||||
@@ -285,12 +283,6 @@ export default function BuildOnboardingModal({
|
||||
modelName: selectedModel,
|
||||
});
|
||||
|
||||
track(AnalyticsEvent.CONFIGURED_LLM_PROVIDER, {
|
||||
provider: currentProviderConfig.providerName,
|
||||
is_creation: true,
|
||||
source: LLMProviderConfiguredSource.CRAFT_ONBOARDING,
|
||||
});
|
||||
|
||||
setConnectionStatus("success");
|
||||
} catch (error) {
|
||||
console.error("Error connecting LLM provider:", error);
|
||||
@@ -355,7 +347,7 @@ export default function BuildOnboardingModal({
|
||||
level: level || undefined,
|
||||
});
|
||||
|
||||
track(AnalyticsEvent.COMPLETED_CRAFT_ONBOARDING);
|
||||
posthog?.capture("completed_craft_onboarding");
|
||||
onClose();
|
||||
} catch (error) {
|
||||
console.error("Error completing onboarding:", error);
|
||||
@@ -473,7 +465,7 @@ export default function BuildOnboardingModal({
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
track(AnalyticsEvent.COMPLETED_CRAFT_USER_INFO, {
|
||||
posthog?.capture("completed_craft_user_info", {
|
||||
first_name: firstName.trim(),
|
||||
last_name: lastName.trim() || undefined,
|
||||
work_area: workArea,
|
||||
|
||||
@@ -566,21 +566,6 @@ textarea {
|
||||
animation: fadeIn 0.2s ease-out forwards;
|
||||
}
|
||||
|
||||
/* Recording waveform animation */
|
||||
@keyframes waveform {
|
||||
0%,
|
||||
100% {
|
||||
transform: scaleY(0.3);
|
||||
}
|
||||
50% {
|
||||
transform: scaleY(1);
|
||||
}
|
||||
}
|
||||
|
||||
.animate-waveform {
|
||||
animation: waveform 0.8s ease-in-out infinite;
|
||||
}
|
||||
|
||||
.container {
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { ProjectsProvider } from "@/providers/ProjectsContext";
|
||||
import { VoiceModeProvider } from "@/providers/VoiceModeProvider";
|
||||
|
||||
export interface LayoutProps {
|
||||
children: React.ReactNode;
|
||||
@@ -12,9 +11,5 @@ export interface LayoutProps {
|
||||
* Sidebar and chrome are handled by sub-layouts / individual pages.
|
||||
*/
|
||||
export default function Layout({ children }: LayoutProps) {
|
||||
return (
|
||||
<ProjectsProvider>
|
||||
<VoiceModeProvider>{children}</VoiceModeProvider>
|
||||
</ProjectsProvider>
|
||||
);
|
||||
return <ProjectsProvider>{children}</ProjectsProvider>;
|
||||
}
|
||||
|
||||
@@ -39,8 +39,6 @@ import document360Icon from "@public/Document360.png";
|
||||
import dropboxIcon from "@public/Dropbox.png";
|
||||
import drupalwikiIcon from "@public/DrupalWiki.png";
|
||||
import egnyteIcon from "@public/Egnyte.png";
|
||||
import elevenLabsDarkSVG from "@public/ElevenLabsDark.svg";
|
||||
import elevenLabsSVG from "@public/ElevenLabs.svg";
|
||||
import firefliesIcon from "@public/Fireflies.png";
|
||||
import freshdeskIcon from "@public/Freshdesk.png";
|
||||
import geminiSVG from "@public/Gemini.svg";
|
||||
@@ -845,9 +843,6 @@ export const Document360Icon = createLogoIcon(document360Icon);
|
||||
export const DropboxIcon = createLogoIcon(dropboxIcon);
|
||||
export const DrupalWikiIcon = createLogoIcon(drupalwikiIcon);
|
||||
export const EgnyteIcon = createLogoIcon(egnyteIcon);
|
||||
export const ElevenLabsIcon = createLogoIcon(elevenLabsSVG, {
|
||||
darkSrc: elevenLabsDarkSVG,
|
||||
});
|
||||
export const FirefliesIcon = createLogoIcon(firefliesIcon);
|
||||
export const FreshdeskIcon = createLogoIcon(freshdeskIcon);
|
||||
export const GeminiIcon = createLogoIcon(geminiSVG);
|
||||
|
||||
@@ -1,206 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useState, useMemo, useRef } from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { formatElapsedTime } from "@/lib/dateUtils";
|
||||
import { Button } from "@opal/components";
|
||||
import {
|
||||
SvgMicrophone,
|
||||
SvgMicrophoneOff,
|
||||
SvgVolume,
|
||||
SvgVolumeOff,
|
||||
} from "@opal/icons";
|
||||
|
||||
// Recording waveform constants
|
||||
const RECORDING_BAR_COUNT = 120;
|
||||
const MIN_BAR_HEIGHT = 2;
|
||||
const MAX_BAR_HEIGHT = 16;
|
||||
|
||||
// Speaking waveform constants
|
||||
const SPEAKING_BAR_COUNT = 28;
|
||||
|
||||
interface WaveformProps {
|
||||
/** Visual style and behavior variant */
|
||||
variant: "speaking" | "recording";
|
||||
/** Whether the waveform is actively animating */
|
||||
isActive: boolean;
|
||||
/** Whether audio is muted */
|
||||
isMuted?: boolean;
|
||||
/** Current microphone audio level (0-1), only used for recording variant */
|
||||
audioLevel?: number;
|
||||
/** Callback when mute button is clicked */
|
||||
onMuteToggle?: () => void;
|
||||
}
|
||||
|
||||
function Waveform({
|
||||
variant,
|
||||
isActive,
|
||||
isMuted = false,
|
||||
audioLevel = 0,
|
||||
onMuteToggle,
|
||||
}: WaveformProps) {
|
||||
// ─── Recording variant state ───────────────────────────────────────────────
|
||||
const [elapsedSeconds, setElapsedSeconds] = useState(0);
|
||||
const [barHeights, setBarHeights] = useState<number[]>(
|
||||
() => new Array(RECORDING_BAR_COUNT).fill(MIN_BAR_HEIGHT) as number[]
|
||||
);
|
||||
const animationRef = useRef<number | null>(null);
|
||||
const lastPushTimeRef = useRef(0);
|
||||
const audioLevelRef = useRef(audioLevel);
|
||||
audioLevelRef.current = audioLevel;
|
||||
|
||||
// ─── Speaking variant bars ─────────────────────────────────────────────────
|
||||
const speakingBars = useMemo(() => {
|
||||
return Array.from({ length: SPEAKING_BAR_COUNT }, (_, i) => ({
|
||||
id: i,
|
||||
// Create a natural wave pattern with height variation
|
||||
baseHeight: Math.sin(i * 0.4) * 5 + 8,
|
||||
delay: i * 0.025,
|
||||
}));
|
||||
}, []);
|
||||
|
||||
// ─── Recording: Timer effect ───────────────────────────────────────────────
|
||||
useEffect(() => {
|
||||
if (variant !== "recording") return;
|
||||
|
||||
if (!isActive) {
|
||||
setElapsedSeconds(0);
|
||||
return;
|
||||
}
|
||||
|
||||
const interval = setInterval(() => {
|
||||
setElapsedSeconds((prev) => prev + 1);
|
||||
}, 1000);
|
||||
|
||||
return () => clearInterval(interval);
|
||||
}, [variant, isActive]);
|
||||
|
||||
// ─── Recording: Audio level visualization effect ───────────────────────────
|
||||
useEffect(() => {
|
||||
if (variant !== "recording") return;
|
||||
|
||||
if (!isActive) {
|
||||
setBarHeights(
|
||||
new Array(RECORDING_BAR_COUNT).fill(MIN_BAR_HEIGHT) as number[]
|
||||
);
|
||||
lastPushTimeRef.current = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
const updateBars = (timestamp: number) => {
|
||||
// Push a new bar roughly every 50ms (~20fps scrolling)
|
||||
if (timestamp - lastPushTimeRef.current >= 50) {
|
||||
lastPushTimeRef.current = timestamp;
|
||||
const level = isMuted ? 0 : audioLevelRef.current;
|
||||
const height =
|
||||
MIN_BAR_HEIGHT + level * (MAX_BAR_HEIGHT - MIN_BAR_HEIGHT);
|
||||
|
||||
setBarHeights((prev) => {
|
||||
const next = prev.slice(1);
|
||||
next.push(height);
|
||||
return next;
|
||||
});
|
||||
}
|
||||
|
||||
animationRef.current = requestAnimationFrame(updateBars);
|
||||
};
|
||||
|
||||
animationRef.current = requestAnimationFrame(updateBars);
|
||||
|
||||
return () => {
|
||||
if (animationRef.current) {
|
||||
cancelAnimationFrame(animationRef.current);
|
||||
animationRef.current = null;
|
||||
}
|
||||
};
|
||||
}, [variant, isActive, isMuted]);
|
||||
|
||||
const formattedTime = useMemo(
|
||||
() => formatElapsedTime(elapsedSeconds),
|
||||
[elapsedSeconds]
|
||||
);
|
||||
|
||||
if (!isActive) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// ─── Speaking variant render ───────────────────────────────────────────────
|
||||
if (variant === "speaking") {
|
||||
return (
|
||||
<div className="flex items-center gap-0.5 p-1.5 bg-background-tint-00 rounded-16 shadow-01">
|
||||
{/* Waveform container */}
|
||||
<div className="flex items-center p-1 bg-background-tint-00 rounded-12 max-w-[144px] min-h-[32px]">
|
||||
<div className="flex items-center p-1">
|
||||
{/* Waveform bars */}
|
||||
<div className="flex items-center justify-center gap-[2px] h-4 w-[120px] overflow-hidden">
|
||||
{speakingBars.map((bar) => (
|
||||
<div
|
||||
key={bar.id}
|
||||
className={cn(
|
||||
"w-[3px] rounded-full",
|
||||
isMuted ? "bg-text-03" : "bg-theme-blue-05",
|
||||
!isMuted && "animate-waveform"
|
||||
)}
|
||||
style={{
|
||||
height: isMuted ? "2px" : `${bar.baseHeight}px`,
|
||||
animationDelay: isMuted ? undefined : `${bar.delay}s`,
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Divider */}
|
||||
<div className="w-0.5 self-stretch bg-border-02" />
|
||||
|
||||
{/* Volume button */}
|
||||
{onMuteToggle && (
|
||||
<div className="flex items-center p-1 bg-background-tint-00 rounded-12">
|
||||
<Button
|
||||
icon={isMuted ? SvgVolumeOff : SvgVolume}
|
||||
onClick={onMuteToggle}
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
tooltip={isMuted ? "Unmute" : "Mute"}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Recording variant render ──────────────────────────────────────────────
|
||||
return (
|
||||
<div className="flex items-center gap-3 px-3 py-2 bg-background-tint-00 rounded-12 min-h-[32px]">
|
||||
{/* Waveform visualization driven by real audio levels */}
|
||||
<div className="flex-1 flex items-center justify-between h-4 overflow-hidden">
|
||||
{barHeights.map((height, i) => (
|
||||
<div
|
||||
key={i}
|
||||
className="w-[1.5px] bg-text-03 rounded-full shrink-0 transition-[height] duration-75"
|
||||
style={{ height: `${height}px` }}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* Timer */}
|
||||
<span className="font-mono text-xs text-text-03 tabular-nums shrink-0">
|
||||
{formattedTime}
|
||||
</span>
|
||||
|
||||
{/* Mute button */}
|
||||
{onMuteToggle && (
|
||||
<Button
|
||||
icon={isMuted ? SvgMicrophoneOff : SvgMicrophone}
|
||||
onClick={onMuteToggle}
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
aria-label={isMuted ? "Unmute microphone" : "Mute microphone"}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default Waveform;
|
||||
@@ -58,7 +58,7 @@ import {
|
||||
useRouter,
|
||||
useSearchParams,
|
||||
} from "next/navigation";
|
||||
import { track, AnalyticsEvent } from "@/lib/analytics";
|
||||
import { usePostHog } from "posthog-js/react";
|
||||
import { getExtensionContext } from "@/lib/extension/utils";
|
||||
import useChatSessions from "@/hooks/useChatSessions";
|
||||
import { usePinnedAgents } from "@/hooks/useAgents";
|
||||
@@ -147,6 +147,7 @@ export default function useChatController({
|
||||
const { forcedToolIds } = useForcedTools();
|
||||
const { fetchProjects, setCurrentMessageFiles, beginUpload } =
|
||||
useProjectsContext();
|
||||
const posthog = usePostHog();
|
||||
|
||||
// Use selectors to access only the specific fields we need
|
||||
const currentSessionId = useChatSessionStore(
|
||||
@@ -763,8 +764,8 @@ export default function useChatController({
|
||||
.user_message_id;
|
||||
|
||||
// Track extension queries in PostHog (reuses isExtension/extensionContext from above)
|
||||
if (isExtension) {
|
||||
track(AnalyticsEvent.EXTENSION_CHAT_QUERY, {
|
||||
if (isExtension && posthog) {
|
||||
posthog.capture("extension_chat_query", {
|
||||
extension_context: extensionContext,
|
||||
assistant_id: liveAgent?.id,
|
||||
has_files: effectiveFileDescriptors.length > 0,
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
import { useState, useRef, useCallback, useEffect } from "react";
|
||||
import { StreamingTTSPlayer } from "@/lib/streamingTTS";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
|
||||
export interface UseVoicePlaybackReturn {
|
||||
isPlaying: boolean;
|
||||
isLoading: boolean;
|
||||
error: string | null;
|
||||
play: (text: string, voice?: string, speed?: number) => Promise<void>;
|
||||
pause: () => void;
|
||||
stop: () => void;
|
||||
}
|
||||
|
||||
export function useVoicePlayback(): UseVoicePlaybackReturn {
|
||||
const [isPlaying, setIsPlaying] = useState(false);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const playerRef = useRef<StreamingTTSPlayer | null>(null);
|
||||
const suppressPlayerErrorsRef = useRef(false);
|
||||
const { setManualTTSPlaying, isTTSMuted, registerManualTTSMuteHandler } =
|
||||
useVoiceMode();
|
||||
|
||||
useEffect(() => {
|
||||
registerManualTTSMuteHandler((muted) => {
|
||||
playerRef.current?.setMuted(muted);
|
||||
});
|
||||
return () => {
|
||||
registerManualTTSMuteHandler(null);
|
||||
};
|
||||
}, [registerManualTTSMuteHandler]);
|
||||
|
||||
const stop = useCallback(() => {
|
||||
suppressPlayerErrorsRef.current = true;
|
||||
if (playerRef.current) {
|
||||
playerRef.current.stop();
|
||||
playerRef.current = null;
|
||||
}
|
||||
setManualTTSPlaying(false);
|
||||
setError(null);
|
||||
setIsPlaying(false);
|
||||
setIsLoading(false);
|
||||
}, [setManualTTSPlaying]);
|
||||
|
||||
const pause = useCallback(() => {
|
||||
// Streaming player currently supports stop/resume via restart, not true pause.
|
||||
stop();
|
||||
}, [stop]);
|
||||
|
||||
const play = useCallback(
|
||||
async (text: string, voice?: string, speed?: number) => {
|
||||
// Stop any existing playback
|
||||
stop();
|
||||
suppressPlayerErrorsRef.current = false;
|
||||
setError(null);
|
||||
setIsLoading(true);
|
||||
|
||||
try {
|
||||
const player = new StreamingTTSPlayer({
|
||||
onPlayingChange: (playing) => {
|
||||
setIsPlaying(playing);
|
||||
setManualTTSPlaying(playing);
|
||||
if (playing) {
|
||||
setIsLoading(false);
|
||||
}
|
||||
},
|
||||
onError: (playbackError) => {
|
||||
if (suppressPlayerErrorsRef.current) {
|
||||
return;
|
||||
}
|
||||
console.error("Voice playback error:", playbackError);
|
||||
setManualTTSPlaying(false);
|
||||
setError(playbackError);
|
||||
setIsLoading(false);
|
||||
setIsPlaying(false);
|
||||
},
|
||||
});
|
||||
playerRef.current = player;
|
||||
player.setMuted(isTTSMuted);
|
||||
|
||||
await player.speak(text, voice, speed);
|
||||
setIsLoading(false);
|
||||
} catch (err) {
|
||||
if (err instanceof Error && err.name === "AbortError") {
|
||||
// Request was cancelled, not an error
|
||||
return;
|
||||
}
|
||||
const message =
|
||||
err instanceof Error ? err.message : "Speech synthesis failed";
|
||||
setError(message);
|
||||
setIsLoading(false);
|
||||
setIsPlaying(false);
|
||||
setManualTTSPlaying(false);
|
||||
}
|
||||
},
|
||||
[isTTSMuted, setManualTTSPlaying, stop]
|
||||
);
|
||||
|
||||
return {
|
||||
isPlaying,
|
||||
isLoading,
|
||||
error,
|
||||
play,
|
||||
pause,
|
||||
stop,
|
||||
};
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
|
||||
export interface VoiceProviderView {
|
||||
id: number;
|
||||
name: string;
|
||||
provider_type: string;
|
||||
is_default_stt: boolean;
|
||||
is_default_tts: boolean;
|
||||
stt_model: string | null;
|
||||
tts_model: string | null;
|
||||
default_voice: string | null;
|
||||
has_api_key: boolean;
|
||||
target_uri: string | null;
|
||||
}
|
||||
|
||||
const VOICE_PROVIDERS_URL = "/api/admin/voice/providers";
|
||||
|
||||
export function useVoiceProviders() {
|
||||
const { data, error, isLoading, mutate } = useSWR<VoiceProviderView[]>(
|
||||
VOICE_PROVIDERS_URL,
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
);
|
||||
|
||||
return {
|
||||
providers: data ?? [],
|
||||
isLoading,
|
||||
error,
|
||||
refresh: mutate,
|
||||
};
|
||||
}
|
||||
@@ -1,525 +0,0 @@
|
||||
import { useState, useRef, useCallback, useEffect } from "react";
|
||||
|
||||
// Target format for OpenAI Realtime API
|
||||
const TARGET_SAMPLE_RATE = 24000;
|
||||
const CHUNK_INTERVAL_MS = 250;
|
||||
const DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS = 1500;
|
||||
|
||||
interface TranscriptMessage {
|
||||
type: "transcript" | "error";
|
||||
text?: string;
|
||||
message?: string;
|
||||
is_final?: boolean;
|
||||
}
|
||||
|
||||
export interface UseVoiceRecorderOptions {
|
||||
/** Called when VAD detects silence and final transcript is received */
|
||||
onFinalTranscript?: (text: string) => void;
|
||||
/** If true, automatically stop recording when VAD detects silence */
|
||||
autoStopOnSilence?: boolean;
|
||||
}
|
||||
|
||||
export interface UseVoiceRecorderReturn {
|
||||
isRecording: boolean;
|
||||
isProcessing: boolean;
|
||||
isMuted: boolean;
|
||||
error: string | null;
|
||||
liveTranscript: string;
|
||||
/** Current microphone audio level (0-1, RMS-based) */
|
||||
audioLevel: number;
|
||||
startRecording: () => Promise<void>;
|
||||
stopRecording: () => Promise<string | null>;
|
||||
setMuted: (muted: boolean) => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Encapsulates all browser resources for a voice recording session.
|
||||
* Manages WebSocket, Web Audio API, and audio buffering.
|
||||
*/
|
||||
class VoiceRecorderSession {
|
||||
// Browser resources
|
||||
private websocket: WebSocket | null = null;
|
||||
private audioContext: AudioContext | null = null;
|
||||
private scriptNode: ScriptProcessorNode | null = null;
|
||||
private sourceNode: MediaStreamAudioSourceNode | null = null;
|
||||
private mediaStream: MediaStream | null = null;
|
||||
private sendInterval: NodeJS.Timeout | null = null;
|
||||
|
||||
// State
|
||||
private audioBuffer: Float32Array[] = [];
|
||||
private transcript = "";
|
||||
private stopResolver: ((text: string | null) => void) | null = null;
|
||||
private isActive = false;
|
||||
// Guard: true once onFinalTranscript has fired for the current utterance.
|
||||
// Prevents the same transcript from being delivered twice when VAD-triggered
|
||||
// stop causes the server to echo the final transcript a second time.
|
||||
private finalTranscriptDelivered = false;
|
||||
private lastDeliveredFinalText: string | null = null;
|
||||
private lastDeliveredFinalAtMs = 0;
|
||||
|
||||
// Callbacks to update React state
|
||||
private onTranscriptChange: (text: string) => void;
|
||||
private onFinalTranscript: ((text: string) => void) | null;
|
||||
private onError: (error: string) => void;
|
||||
private onAudioLevel: (level: number) => void;
|
||||
private onSilenceTimeout: (() => void) | null;
|
||||
private onVADStop: (() => void) | null;
|
||||
private autoStopOnSilence: boolean;
|
||||
|
||||
constructor(
|
||||
onTranscriptChange: (text: string) => void,
|
||||
onFinalTranscript: ((text: string) => void) | null,
|
||||
onError: (error: string) => void,
|
||||
onAudioLevel: (level: number) => void,
|
||||
onSilenceTimeout?: () => void,
|
||||
autoStopOnSilence?: boolean,
|
||||
onVADStop?: () => void
|
||||
) {
|
||||
this.onTranscriptChange = onTranscriptChange;
|
||||
this.onFinalTranscript = onFinalTranscript;
|
||||
this.onError = onError;
|
||||
this.onAudioLevel = onAudioLevel;
|
||||
this.onSilenceTimeout = onSilenceTimeout || null;
|
||||
this.autoStopOnSilence = autoStopOnSilence ?? false;
|
||||
this.onVADStop = onVADStop || null;
|
||||
}
|
||||
|
||||
get recording(): boolean {
|
||||
return this.isActive;
|
||||
}
|
||||
|
||||
get currentTranscript(): string {
|
||||
return this.transcript;
|
||||
}
|
||||
|
||||
setMuted(muted: boolean): void {
|
||||
if (this.mediaStream) {
|
||||
this.mediaStream.getAudioTracks().forEach((track) => {
|
||||
track.enabled = !muted;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async start(): Promise<void> {
|
||||
if (this.isActive) return;
|
||||
|
||||
this.cleanup();
|
||||
this.transcript = "";
|
||||
this.audioBuffer = [];
|
||||
this.finalTranscriptDelivered = false;
|
||||
this.lastDeliveredFinalText = null;
|
||||
this.lastDeliveredFinalAtMs = 0;
|
||||
|
||||
// Get microphone
|
||||
this.mediaStream = await navigator.mediaDevices.getUserMedia({
|
||||
audio: {
|
||||
channelCount: 1,
|
||||
sampleRate: { ideal: TARGET_SAMPLE_RATE },
|
||||
echoCancellation: true,
|
||||
noiseSuppression: true,
|
||||
},
|
||||
});
|
||||
|
||||
// Get WS token and connect WebSocket
|
||||
const wsUrl = await this.getWebSocketUrl();
|
||||
this.websocket = new WebSocket(wsUrl);
|
||||
this.websocket.onmessage = this.handleMessage;
|
||||
this.websocket.onerror = () => this.onError("Connection failed");
|
||||
this.websocket.onclose = () => {
|
||||
if (this.stopResolver) {
|
||||
this.stopResolver(this.transcript || null);
|
||||
this.stopResolver = null;
|
||||
}
|
||||
};
|
||||
|
||||
await this.waitForConnection();
|
||||
|
||||
// Restore error handler after connection (waitForConnection overwrites it)
|
||||
this.websocket.onerror = () => this.onError("Connection failed");
|
||||
|
||||
// Set up audio capture
|
||||
this.audioContext = new AudioContext({ sampleRate: TARGET_SAMPLE_RATE });
|
||||
this.sourceNode = this.audioContext.createMediaStreamSource(
|
||||
this.mediaStream
|
||||
);
|
||||
this.scriptNode = this.audioContext.createScriptProcessor(4096, 1, 1);
|
||||
|
||||
this.scriptNode.onaudioprocess = (event) => {
|
||||
const inputData = event.inputBuffer.getChannelData(0);
|
||||
this.audioBuffer.push(new Float32Array(inputData));
|
||||
|
||||
// Compute RMS audio level (0-1) for waveform visualization
|
||||
let sum = 0;
|
||||
for (let i = 0; i < inputData.length; i++) {
|
||||
sum += inputData[i]! * inputData[i]!;
|
||||
}
|
||||
const rms = Math.sqrt(sum / inputData.length);
|
||||
// Scale RMS to a more visible range (raw RMS is usually very small)
|
||||
this.onAudioLevel(Math.min(1, rms * 5));
|
||||
};
|
||||
|
||||
this.sourceNode.connect(this.scriptNode);
|
||||
this.scriptNode.connect(this.audioContext.destination);
|
||||
|
||||
// Start sending audio chunks
|
||||
this.sendInterval = setInterval(
|
||||
() => this.sendAudioBuffer(),
|
||||
CHUNK_INTERVAL_MS
|
||||
);
|
||||
this.isActive = true;
|
||||
}
|
||||
|
||||
async stop(): Promise<string | null> {
|
||||
if (!this.isActive) return this.transcript || null;
|
||||
|
||||
// Stop audio capture
|
||||
if (this.sendInterval) {
|
||||
clearInterval(this.sendInterval);
|
||||
this.sendInterval = null;
|
||||
}
|
||||
if (this.scriptNode) {
|
||||
this.scriptNode.disconnect();
|
||||
this.scriptNode = null;
|
||||
}
|
||||
if (this.sourceNode) {
|
||||
this.sourceNode.disconnect();
|
||||
this.sourceNode = null;
|
||||
}
|
||||
if (this.audioContext) {
|
||||
this.audioContext.close();
|
||||
this.audioContext = null;
|
||||
}
|
||||
if (this.mediaStream) {
|
||||
this.mediaStream.getTracks().forEach((track) => track.stop());
|
||||
this.mediaStream = null;
|
||||
}
|
||||
|
||||
this.audioBuffer = [];
|
||||
this.isActive = false;
|
||||
|
||||
// Get final transcript from server
|
||||
if (this.websocket?.readyState === WebSocket.OPEN) {
|
||||
return new Promise((resolve) => {
|
||||
this.stopResolver = resolve;
|
||||
this.websocket!.send(JSON.stringify({ type: "end" }));
|
||||
|
||||
// Timeout fallback
|
||||
setTimeout(() => {
|
||||
if (this.stopResolver) {
|
||||
this.stopResolver(this.transcript || null);
|
||||
this.stopResolver = null;
|
||||
}
|
||||
}, 3000);
|
||||
});
|
||||
}
|
||||
|
||||
return this.transcript || null;
|
||||
}
|
||||
|
||||
cleanup(): void {
|
||||
if (this.sendInterval) clearInterval(this.sendInterval);
|
||||
if (this.scriptNode) this.scriptNode.disconnect();
|
||||
if (this.sourceNode) this.sourceNode.disconnect();
|
||||
if (this.audioContext) this.audioContext.close();
|
||||
if (this.mediaStream) this.mediaStream.getTracks().forEach((t) => t.stop());
|
||||
if (this.websocket) this.websocket.close();
|
||||
|
||||
this.sendInterval = null;
|
||||
this.scriptNode = null;
|
||||
this.sourceNode = null;
|
||||
this.audioContext = null;
|
||||
this.mediaStream = null;
|
||||
this.websocket = null;
|
||||
this.isActive = false;
|
||||
}
|
||||
|
||||
private async getWebSocketUrl(): Promise<string> {
|
||||
// Fetch short-lived WS token
|
||||
const tokenResponse = await fetch("/api/voice/ws-token", {
|
||||
method: "POST",
|
||||
credentials: "include",
|
||||
});
|
||||
if (!tokenResponse.ok) {
|
||||
throw new Error("Failed to get WebSocket authentication token");
|
||||
}
|
||||
const { token } = await tokenResponse.json();
|
||||
|
||||
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
|
||||
const isDev = window.location.port === "3000";
|
||||
const host = isDev ? "localhost:8080" : window.location.host;
|
||||
const path = isDev
|
||||
? "/voice/transcribe/stream"
|
||||
: "/api/voice/transcribe/stream";
|
||||
return `${protocol}//${host}${path}?token=${encodeURIComponent(token)}`;
|
||||
}
|
||||
|
||||
private waitForConnection(): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (!this.websocket) return reject(new Error("No WebSocket"));
|
||||
|
||||
const timeout = setTimeout(
|
||||
() => reject(new Error("Connection timeout")),
|
||||
5000
|
||||
);
|
||||
|
||||
this.websocket.onopen = () => {
|
||||
clearTimeout(timeout);
|
||||
resolve();
|
||||
};
|
||||
this.websocket.onerror = () => {
|
||||
clearTimeout(timeout);
|
||||
reject(new Error("Connection failed"));
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
private handleMessage = (event: MessageEvent): void => {
|
||||
try {
|
||||
const data: TranscriptMessage = JSON.parse(event.data);
|
||||
|
||||
if (data.type === "transcript") {
|
||||
if (data.text) {
|
||||
this.transcript = data.text;
|
||||
this.onTranscriptChange(data.text);
|
||||
}
|
||||
|
||||
if (data.is_final && data.text) {
|
||||
// VAD detected silence - trigger callback (only once per utterance)
|
||||
const now = Date.now();
|
||||
const isLikelyDuplicateFinal =
|
||||
this.autoStopOnSilence &&
|
||||
this.lastDeliveredFinalText === data.text &&
|
||||
now - this.lastDeliveredFinalAtMs <
|
||||
DUPLICATE_FINAL_TRANSCRIPT_WINDOW_MS;
|
||||
|
||||
if (
|
||||
this.onFinalTranscript &&
|
||||
!this.finalTranscriptDelivered &&
|
||||
!isLikelyDuplicateFinal
|
||||
) {
|
||||
this.finalTranscriptDelivered = true;
|
||||
this.lastDeliveredFinalText = data.text;
|
||||
this.lastDeliveredFinalAtMs = now;
|
||||
this.onFinalTranscript(data.text);
|
||||
}
|
||||
|
||||
// Auto-stop recording if enabled
|
||||
if (this.autoStopOnSilence) {
|
||||
// Trigger stop callback to update React state
|
||||
if (this.onVADStop) {
|
||||
this.onVADStop();
|
||||
}
|
||||
} else {
|
||||
// If not auto-stopping, reset for next utterance
|
||||
this.transcript = "";
|
||||
this.finalTranscriptDelivered = false;
|
||||
this.onTranscriptChange("");
|
||||
this.resetBackendTranscript();
|
||||
}
|
||||
|
||||
// Resolve stop promise if waiting
|
||||
if (this.stopResolver) {
|
||||
this.stopResolver(data.text);
|
||||
this.stopResolver = null;
|
||||
}
|
||||
}
|
||||
} else if (data.type === "error") {
|
||||
this.onError(data.message || "Transcription error");
|
||||
}
|
||||
} catch (e) {
|
||||
console.error("Failed to parse transcript message:", e);
|
||||
}
|
||||
};
|
||||
|
||||
private resetBackendTranscript(): void {
|
||||
if (this.websocket?.readyState === WebSocket.OPEN) {
|
||||
this.websocket.send(JSON.stringify({ type: "reset" }));
|
||||
}
|
||||
}
|
||||
|
||||
private sendAudioBuffer(): void {
|
||||
if (
|
||||
!this.websocket ||
|
||||
this.websocket.readyState !== WebSocket.OPEN ||
|
||||
!this.audioContext ||
|
||||
this.audioBuffer.length === 0
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Concatenate buffered chunks
|
||||
const totalLength = this.audioBuffer.reduce(
|
||||
(sum, chunk) => sum + chunk.length,
|
||||
0
|
||||
);
|
||||
|
||||
// Prevent buffer overflow
|
||||
if (totalLength > this.audioContext.sampleRate * 0.5 * 2) {
|
||||
this.audioBuffer = this.audioBuffer.slice(-10);
|
||||
return;
|
||||
}
|
||||
|
||||
const concatenated = new Float32Array(totalLength);
|
||||
let offset = 0;
|
||||
for (const chunk of this.audioBuffer) {
|
||||
concatenated.set(chunk, offset);
|
||||
offset += chunk.length;
|
||||
}
|
||||
this.audioBuffer = [];
|
||||
|
||||
// Resample and convert to PCM16
|
||||
const resampled = this.resampleAudio(
|
||||
concatenated,
|
||||
this.audioContext.sampleRate
|
||||
);
|
||||
const pcm16 = this.float32ToInt16(resampled);
|
||||
|
||||
this.websocket.send(pcm16.buffer);
|
||||
}
|
||||
|
||||
private resampleAudio(input: Float32Array, inputRate: number): Float32Array {
|
||||
if (inputRate === TARGET_SAMPLE_RATE) return input;
|
||||
|
||||
const ratio = inputRate / TARGET_SAMPLE_RATE;
|
||||
const outputLength = Math.round(input.length / ratio);
|
||||
const output = new Float32Array(outputLength);
|
||||
|
||||
for (let i = 0; i < outputLength; i++) {
|
||||
const srcIndex = i * ratio;
|
||||
const floor = Math.floor(srcIndex);
|
||||
const ceil = Math.min(floor + 1, input.length - 1);
|
||||
const fraction = srcIndex - floor;
|
||||
output[i] = input[floor]! * (1 - fraction) + input[ceil]! * fraction;
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
private float32ToInt16(float32: Float32Array): Int16Array {
|
||||
const int16 = new Int16Array(float32.length);
|
||||
for (let i = 0; i < float32.length; i++) {
|
||||
const s = Math.max(-1, Math.min(1, float32[i]!));
|
||||
int16[i] = s < 0 ? s * 0x8000 : s * 0x7fff;
|
||||
}
|
||||
return int16;
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Hook for voice recording with streaming transcription.
 *
 * Wraps a `VoiceRecorderSession` and exposes declarative React state
 * (recording/processing/mute flags, live transcript, audio level) plus
 * imperative start/stop/mute controls.
 */
export function useVoiceRecorder(
  options?: UseVoiceRecorderOptions
): UseVoiceRecorderReturn {
  const [isRecording, setIsRecording] = useState(false);
  const [isProcessing, setIsProcessing] = useState(false);
  const [isMuted, setIsMutedState] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const [liveTranscript, setLiveTranscript] = useState("");
  const [audioLevel, setAudioLevel] = useState(0);

  // The active session lives in a ref so the stable callbacks below can
  // always reach the latest instance without re-creating themselves.
  const sessionRef = useRef<VoiceRecorderSession | null>(null);
  const onFinalTranscriptRef = useRef(options?.onFinalTranscript);
  const autoStopOnSilenceRef = useRef(options?.autoStopOnSilence ?? true); // Default to true

  // Keep callback ref in sync
  useEffect(() => {
    onFinalTranscriptRef.current = options?.onFinalTranscript;
    autoStopOnSilenceRef.current = options?.autoStopOnSilence ?? true;
  }, [options?.onFinalTranscript, options?.autoStopOnSilence]);

  // Cleanup on unmount
  useEffect(() => {
    return () => {
      sessionRef.current?.cleanup();
    };
  }, []);

  const startRecording = useCallback(async () => {
    // Already recording: starting again would clobber the live session.
    if (sessionRef.current?.recording) return;

    setError(null);
    setLiveTranscript("");

    // Clear any stale, inactive session before starting a new one.
    if (sessionRef.current && !sessionRef.current.recording) {
      sessionRef.current.cleanup();
      sessionRef.current = null;
    }

    // Create VAD stop handler that will stop the session
    const currentSession = new VoiceRecorderSession(
      setLiveTranscript,
      (text) => onFinalTranscriptRef.current?.(text),
      setError,
      setAudioLevel,
      undefined, // onSilenceTimeout
      autoStopOnSilenceRef.current,
      () => {
        // Stop only this session instance, and only clear recording state if it
        // is still the active session when stop resolves.
        currentSession.stop().then(() => {
          if (sessionRef.current === currentSession) {
            setIsRecording(false);
            setIsMutedState(false);
            sessionRef.current = null;
          }
        });
      }
    );
    sessionRef.current = currentSession;

    try {
      await currentSession.start();
      // Only flip the flag if no newer session replaced this one while
      // start() was in flight.
      if (sessionRef.current === currentSession) {
        setIsRecording(true);
      }
    } catch (err) {
      currentSession.cleanup();
      setError(
        err instanceof Error ? err.message : "Failed to start recording"
      );
      if (sessionRef.current === currentSession) {
        sessionRef.current = null;
      }
      // Re-throw so callers can react to the failed start as well.
      throw err;
    }
  }, []);

  const stopRecording = useCallback(async (): Promise<string | null> => {
    if (!sessionRef.current) return null;
    const currentSession = sessionRef.current;

    setIsProcessing(true);

    try {
      const transcript = await currentSession.stop();
      return transcript;
    } finally {
      // Only clear state if this is still the active session.
      if (sessionRef.current === currentSession) {
        setIsRecording(false);
        setIsMutedState(false); // Reset mute state when recording stops
        sessionRef.current = null;
      }
      setIsProcessing(false);
    }
  }, []);

  // Mirrors the mute flag into React state and into the live session (if any).
  const setMuted = useCallback((muted: boolean) => {
    setIsMutedState(muted);
    sessionRef.current?.setMuted(muted);
  }, []);

  return {
    isRecording,
    isProcessing,
    isMuted,
    error,
    liveTranscript,
    audioLevel,
    startRecording,
    stopRecording,
    setMuted,
  };
}
@@ -1,25 +0,0 @@
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
|
||||
// Shape of the `/api/voice/status` response consumed by useVoiceStatus.
interface VoiceStatus {
  stt_enabled: boolean; // server-reported speech-to-text availability
  tts_enabled: boolean; // server-reported text-to-speech availability
}
|
||||
export function useVoiceStatus() {
|
||||
const { data, error, isLoading } = useSWR<VoiceStatus>(
|
||||
"/api/voice/status",
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
dedupingInterval: 60000,
|
||||
}
|
||||
);
|
||||
|
||||
return {
|
||||
sttEnabled: data?.stt_enabled ?? false,
|
||||
ttsEnabled: data?.tts_enabled ?? false,
|
||||
isLoading,
|
||||
error,
|
||||
};
|
||||
}
|
||||
@@ -1,150 +0,0 @@
|
||||
import { useState, useRef, useCallback, useEffect } from "react";
|
||||
|
||||
// Connection lifecycle states surfaced by useWebSocket.
export type WebSocketStatus =
  | "connecting"
  | "connected"
  | "disconnected"
  | "error";

export interface UseWebSocketOptions<T> {
  /** URL to connect to */
  url: string;
  /** Called when a message is received */
  onMessage?: (data: T) => void;
  /** Called when connection opens */
  onOpen?: () => void;
  /** Called when connection closes */
  onClose?: () => void;
  /** Called on error */
  onError?: (error: Event) => void;
  /** Auto-connect on mount */
  autoConnect?: boolean;
}

export interface UseWebSocketReturn<T> {
  /** Current connection status */
  status: WebSocketStatus;
  /** Send JSON data */
  sendJson: (data: T) => void;
  /** Send binary data */
  sendBinary: (data: Blob | ArrayBuffer) => void;
  /** Connect to WebSocket */
  connect: () => Promise<void>;
  /** Disconnect from WebSocket */
  disconnect: () => void;
}

/**
 * React hook wrapping a single WebSocket connection.
 *
 * `TReceive` is the expected shape of parsed incoming JSON; `TSend` is the
 * payload type accepted by `sendJson`. Callbacks are held in refs so the
 * connection does not need to be re-established when they change.
 * Non-JSON incoming messages are silently ignored.
 */
export function useWebSocket<TReceive = unknown, TSend = unknown>({
  url,
  onMessage,
  onOpen,
  onClose,
  onError,
  autoConnect = false,
}: UseWebSocketOptions<TReceive>): UseWebSocketReturn<TSend> {
  const [status, setStatus] = useState<WebSocketStatus>("disconnected");
  const wsRef = useRef<WebSocket | null>(null);
  const onMessageRef = useRef(onMessage);
  const onOpenRef = useRef(onOpen);
  const onCloseRef = useRef(onClose);
  const onErrorRef = useRef(onError);

  // Keep refs updated
  useEffect(() => {
    onMessageRef.current = onMessage;
    onOpenRef.current = onOpen;
    onCloseRef.current = onClose;
    onErrorRef.current = onError;
  }, [onMessage, onOpen, onClose, onError]);

  const connect = useCallback(async (): Promise<void> => {
    // No-op if a connection is already open or being established.
    if (
      wsRef.current?.readyState === WebSocket.OPEN ||
      wsRef.current?.readyState === WebSocket.CONNECTING
    ) {
      return;
    }

    setStatus("connecting");

    // Resolve once the socket opens; reject on error or 10s timeout.
    return new Promise((resolve, reject) => {
      const ws = new WebSocket(url);
      wsRef.current = ws;

      const timeout = setTimeout(() => {
        ws.close();
        reject(new Error("WebSocket connection timeout"));
      }, 10000);

      ws.onopen = () => {
        clearTimeout(timeout);
        setStatus("connected");
        onOpenRef.current?.();
        resolve();
      };

      ws.onmessage = (event) => {
        try {
          const data = JSON.parse(event.data) as TReceive;
          onMessageRef.current?.(data);
        } catch {
          // Non-JSON message, ignore or handle differently
        }
      };

      ws.onclose = () => {
        clearTimeout(timeout);
        setStatus("disconnected");
        onCloseRef.current?.();
        wsRef.current = null;
      };

      ws.onerror = (error) => {
        clearTimeout(timeout);
        setStatus("error");
        onErrorRef.current?.(error);
        // If the socket was already open, this reject is a no-op (the
        // promise has resolved); the onError callback still fires.
        reject(new Error("WebSocket connection failed"));
      };
    });
  }, [url]);

  const disconnect = useCallback(() => {
    if (wsRef.current) {
      wsRef.current.close();
      wsRef.current = null;
    }
    setStatus("disconnected");
  }, []);

  // Serialize and send; silently dropped if the socket is not open.
  const sendJson = useCallback((data: TSend) => {
    if (wsRef.current?.readyState === WebSocket.OPEN) {
      wsRef.current.send(JSON.stringify(data));
    }
  }, []);

  // Send raw binary; silently dropped if the socket is not open.
  const sendBinary = useCallback((data: Blob | ArrayBuffer) => {
    if (wsRef.current?.readyState === WebSocket.OPEN) {
      wsRef.current.send(data);
    }
  }, []);

  // Auto-connect if enabled
  useEffect(() => {
    if (autoConnect) {
      connect().catch(() => {
        // Error handled via onError callback
      });
    }
    return () => {
      disconnect();
    };
  }, [autoConnect, connect, disconnect]);

  return {
    status,
    sendJson,
    sendBinary,
    connect,
    disconnect,
  };
}
||||
@@ -1,58 +0,0 @@
|
||||
const VOICE_PROVIDERS_URL = "/api/admin/voice/providers";
|
||||
|
||||
export async function activateVoiceProvider(
|
||||
providerId: number,
|
||||
mode: "stt" | "tts",
|
||||
ttsModel?: string
|
||||
): Promise<Response> {
|
||||
const url = new URL(
|
||||
`${VOICE_PROVIDERS_URL}/${providerId}/activate-${mode}`,
|
||||
window.location.origin
|
||||
);
|
||||
if (mode === "tts" && ttsModel) {
|
||||
url.searchParams.set("tts_model", ttsModel);
|
||||
}
|
||||
return fetch(url.toString(), { method: "POST" });
|
||||
}
|
||||
|
||||
export async function deactivateVoiceProvider(
|
||||
providerId: number,
|
||||
mode: "stt" | "tts"
|
||||
): Promise<Response> {
|
||||
return fetch(`${VOICE_PROVIDERS_URL}/${providerId}/deactivate-${mode}`, {
|
||||
method: "POST",
|
||||
});
|
||||
}
|
||||
|
||||
export async function testVoiceProvider(request: {
|
||||
provider_type: string;
|
||||
api_key?: string;
|
||||
target_uri?: string;
|
||||
use_stored_key?: boolean;
|
||||
}): Promise<Response> {
|
||||
return fetch(`${VOICE_PROVIDERS_URL}/test`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(request),
|
||||
});
|
||||
}
|
||||
|
||||
export async function upsertVoiceProvider(
|
||||
request: Record<string, unknown>
|
||||
): Promise<Response> {
|
||||
return fetch(VOICE_PROVIDERS_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(request),
|
||||
});
|
||||
}
|
||||
|
||||
export async function fetchVoicesByType(
|
||||
providerType: string
|
||||
): Promise<Response> {
|
||||
return fetch(`/api/admin/voice/voices?provider_type=${providerType}`);
|
||||
}
|
||||
|
||||
export async function fetchLLMProviders(): Promise<Response> {
|
||||
return fetch("/api/admin/llm/provider");
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
import posthog from "posthog-js";
|
||||
|
||||
// ─── Event Registry ────────────────────────────────────────────────────────
|
||||
// All tracked event names. Add new events here to get type-safe tracking.
|
||||
|
||||
// Every tracked event name; values are the raw strings sent to PostHog.
export enum AnalyticsEvent {
  CONFIGURED_LLM_PROVIDER = "configured_llm_provider",
  COMPLETED_CRAFT_ONBOARDING = "completed_craft_onboarding",
  COMPLETED_CRAFT_USER_INFO = "completed_craft_user_info",
  SENT_CRAFT_MESSAGE = "sent_craft_message",
  SAW_CRAFT_INTRO = "saw_craft_intro",
  CLICKED_GO_HOME = "clicked_go_home",
  CLICKED_TRY_CRAFT = "clicked_try_craft",
  CLICKED_CRAFT_IN_SIDEBAR = "clicked_craft_in_sidebar",
  RELEASE_NOTIFICATION_CLICKED = "release_notification_clicked",
  EXTENSION_CHAT_QUERY = "extension_chat_query",
}

// ─── Shared Enums ──────────────────────────────────────────────────────────

// Where an LLM provider configuration was initiated from.
export enum LLMProviderConfiguredSource {
  ADMIN_PAGE = "admin_page",
  CHAT_ONBOARDING = "chat_onboarding",
  CRAFT_ONBOARDING = "craft_onboarding",
}

// ─── Event Property Types ──────────────────────────────────────────────────
// Maps each event to its required properties. Use `void` for events with no
// properties — this makes the second argument to `track()` optional for those
// events while requiring it for events that carry data.

interface AnalyticsEventProperties {
  [AnalyticsEvent.CONFIGURED_LLM_PROVIDER]: {
    provider: string;
    is_creation: boolean;
    source: LLMProviderConfiguredSource;
  };
  [AnalyticsEvent.COMPLETED_CRAFT_ONBOARDING]: void;
  [AnalyticsEvent.COMPLETED_CRAFT_USER_INFO]: {
    first_name: string;
    last_name: string | undefined;
    work_area: string | undefined;
    level: string | undefined;
  };
  [AnalyticsEvent.SENT_CRAFT_MESSAGE]: void;
  [AnalyticsEvent.SAW_CRAFT_INTRO]: void;
  [AnalyticsEvent.CLICKED_GO_HOME]: void;
  [AnalyticsEvent.CLICKED_TRY_CRAFT]: void;
  [AnalyticsEvent.CLICKED_CRAFT_IN_SIDEBAR]: void;
  [AnalyticsEvent.RELEASE_NOTIFICATION_CLICKED]: {
    version: string | undefined;
  };
  [AnalyticsEvent.EXTENSION_CHAT_QUERY]: {
    extension_context: string | null | undefined;
    assistant_id: number | undefined;
    has_files: boolean;
    deep_research: boolean;
  };
}
|
||||
// ─── Typed Track Function ──────────────────────────────────────────────────
|
||||
|
||||
export function track<E extends AnalyticsEvent>(
|
||||
...args: AnalyticsEventProperties[E] extends void
|
||||
? [event: E]
|
||||
: [event: E, properties: AnalyticsEventProperties[E]]
|
||||
): void {
|
||||
const [event, properties] = args as [E, Record<string, unknown>?];
|
||||
posthog.capture(event, properties ?? {});
|
||||
}
|
||||
@@ -151,17 +151,6 @@ export function formatMmDdYyyy(d: string): string {
|
||||
return `${date.getMonth() + 1}/${date.getDate()}/${date.getFullYear()}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Format a duration in seconds as MM:SS (e.g. 65 → "01:05").
|
||||
*/
|
||||
export function formatElapsedTime(totalSeconds: number): string {
|
||||
const minutes = Math.floor(totalSeconds / 60);
|
||||
const seconds = totalSeconds % 60;
|
||||
return `${minutes.toString().padStart(2, "0")}:${seconds
|
||||
.toString()
|
||||
.padStart(2, "0")}`;
|
||||
}
|
||||
|
||||
export const getFormattedDateTime = (date: Date | null) => {
|
||||
if (!date) return null;
|
||||
|
||||
|
||||
@@ -1,614 +0,0 @@
|
||||
/**
|
||||
* Real-time streaming TTS using HTTP streaming with MediaSource Extensions.
|
||||
* Plays audio chunks as they arrive for smooth, low-latency playback.
|
||||
*/
|
||||
|
||||
/**
 * HTTPStreamingTTSPlayer - Uses HTTP streaming with MediaSource Extensions
 * for smooth, gapless audio playback. This is the recommended approach for
 * real-time TTS as it properly handles MP3 frame boundaries.
 */
export class HTTPStreamingTTSPlayer {
  // MediaSource pipeline state; all of it is torn down in cleanup().
  private mediaSource: MediaSource | null = null;
  private mediaSourceUrl: string | null = null; // object URL backing the <audio> src
  private sourceBuffer: SourceBuffer | null = null;
  private audioElement: HTMLAudioElement | null = null;
  // Chunks fetched from the network but not yet appended to the SourceBuffer.
  private pendingChunks: Uint8Array[] = [];
  private isAppending: boolean = false; // true while an appendBuffer() is in flight
  private isPlaying: boolean = false;
  private streamComplete: boolean = false; // set once the HTTP body is fully read
  private onPlayingChange?: (playing: boolean) => void;
  private onError?: (error: string) => void;
  private abortController: AbortController | null = null; // cancels the in-flight fetch
  private isMuted: boolean = false;

  constructor(options?: {
    onPlayingChange?: (playing: boolean) => void;
    onError?: (error: string) => void;
  }) {
    this.onPlayingChange = options?.onPlayingChange;
    this.onError = options?.onError;
  }

  private getAPIUrl(): string {
    // Always go through the frontend proxy to ensure cookies are sent correctly
    // The Next.js proxy at /api/* forwards to the backend
    return "/api/voice/synthesize";
  }

  /**
   * Speak text using HTTP streaming with real-time playback.
   * Audio begins playing as soon as the first chunks arrive.
   */
  async speak(
    text: string,
    voice?: string,
    speed: number = 1.0
  ): Promise<void> {
    // Cleanup any previous playback
    this.cleanup();

    // Create abort controller for this request
    this.abortController = new AbortController();

    // Build URL with query params
    const params = new URLSearchParams();
    params.set("text", text);
    if (voice) params.set("voice", voice);
    params.set("speed", speed.toString());

    const url = `${this.getAPIUrl()}?${params}`;

    // Check if MediaSource is supported
    if (!window.MediaSource || !MediaSource.isTypeSupported("audio/mpeg")) {
      // Fallback to simple buffered playback
      return this.fallbackSpeak(url);
    }

    // Create MediaSource and audio element
    this.mediaSource = new MediaSource();
    this.audioElement = new Audio();
    this.mediaSourceUrl = URL.createObjectURL(this.mediaSource);
    this.audioElement.src = this.mediaSourceUrl;
    this.audioElement.muted = this.isMuted;

    // Set up audio element event handlers
    this.audioElement.onplay = () => {
      if (!this.isPlaying) {
        this.isPlaying = true;
        this.onPlayingChange?.(true);
      }
    };

    this.audioElement.onended = () => {
      this.isPlaying = false;
      this.onPlayingChange?.(false);
    };

    this.audioElement.onerror = () => {
      this.onError?.("Audio playback error");
      this.isPlaying = false;
      this.onPlayingChange?.(false);
    };

    // Wait for MediaSource to be ready
    await new Promise<void>((resolve, reject) => {
      if (!this.mediaSource) {
        reject(new Error("MediaSource not initialized"));
        return;
      }

      this.mediaSource.onsourceopen = () => {
        try {
          // Create SourceBuffer for MP3
          this.sourceBuffer = this.mediaSource!.addSourceBuffer("audio/mpeg");
          this.sourceBuffer.mode = "sequence";

          // Each completed append re-triggers the drain of pendingChunks.
          this.sourceBuffer.onupdateend = () => {
            this.isAppending = false;
            this.processNextChunk();
          };

          resolve();
        } catch (err) {
          reject(err);
        }
      };

      // MediaSource doesn't have onerror in all browsers, use onsourceclose as fallback
      this.mediaSource.onsourceclose = () => {
        if (this.mediaSource?.readyState === "closed") {
          reject(new Error("MediaSource closed unexpectedly"));
        }
      };
    });

    // Start fetching and streaming audio
    try {
      const response = await fetch(url, {
        method: "POST",
        signal: this.abortController.signal,
        credentials: "include", // Include cookies for authentication
      });

      if (!response.ok) {
        const errorText = await response.text();
        throw new Error(
          `TTS request failed: ${response.status} - ${errorText}`
        );
      }

      const reader = response.body?.getReader();
      if (!reader) {
        throw new Error("No response body");
      }

      // Start playback as soon as we have some data
      let firstChunk = true;

      while (true) {
        const { done, value } = await reader.read();

        if (done) {
          this.streamComplete = true;
          // End the stream when all chunks are appended
          this.finalizeStream();
          break;
        }

        if (value) {
          this.pendingChunks.push(value);
          this.processNextChunk();

          // Start playback after first chunk
          if (firstChunk && this.audioElement) {
            firstChunk = false;
            // Small delay to buffer a bit before starting
            setTimeout(() => {
              this.audioElement?.play().catch(() => {
                // Ignore playback start errors
              });
            }, 100);
          }
        }
      }
    } catch (err) {
      // A user-initiated stop() aborts the fetch; that is not an error.
      if (err instanceof Error && err.name === "AbortError") {
        return;
      }
      this.onError?.(err instanceof Error ? err.message : "TTS error");
      throw err;
    }
  }

  /**
   * Process next chunk from the queue.
   * Appends at most one chunk at a time; onupdateend drives the next append.
   */
  private processNextChunk(): void {
    if (
      this.isAppending ||
      this.pendingChunks.length === 0 ||
      !this.sourceBuffer ||
      this.sourceBuffer.updating
    ) {
      return;
    }

    const chunk = this.pendingChunks.shift();
    if (chunk) {
      this.isAppending = true;
      try {
        // Use ArrayBuffer directly for better TypeScript compatibility
        const buffer = chunk.buffer.slice(
          chunk.byteOffset,
          chunk.byteOffset + chunk.byteLength
        ) as ArrayBuffer;
        this.sourceBuffer.appendBuffer(buffer);
      } catch {
        // appendBuffer can throw (e.g. QuotaExceededError); drop the chunk.
        this.isAppending = false;
        // Try next chunk
        this.processNextChunk();
      }
    }
  }

  /**
   * Finalize the stream when all data has been received.
   * Polls until the append queue drains, then signals end-of-stream.
   */
  private finalizeStream(): void {
    if (this.pendingChunks.length > 0 || this.isAppending) {
      // Wait for remaining chunks to be appended
      setTimeout(() => this.finalizeStream(), 50);
      return;
    }

    if (
      this.mediaSource &&
      this.mediaSource.readyState === "open" &&
      this.sourceBuffer &&
      !this.sourceBuffer.updating
    ) {
      try {
        this.mediaSource.endOfStream();
      } catch {
        // Ignore errors when ending stream
      }
    }
  }

  /**
   * Fallback for browsers that don't support MediaSource Extensions.
   * Buffers all audio before playing.
   */
  private async fallbackSpeak(url: string): Promise<void> {
    const response = await fetch(url, {
      method: "POST",
      signal: this.abortController?.signal,
      credentials: "include", // Include cookies for authentication
    });

    if (!response.ok) {
      const errorText = await response.text();
      throw new Error(`TTS request failed: ${response.status} - ${errorText}`);
    }

    const audioData = await response.arrayBuffer();

    const blob = new Blob([audioData], { type: "audio/mpeg" });
    const audioUrl = URL.createObjectURL(blob);

    this.audioElement = new Audio(audioUrl);
    this.audioElement.muted = this.isMuted;

    this.audioElement.onplay = () => {
      this.isPlaying = true;
      this.onPlayingChange?.(true);
    };

    this.audioElement.onended = () => {
      this.isPlaying = false;
      this.onPlayingChange?.(false);
      // Release the blob URL once playback finishes.
      URL.revokeObjectURL(audioUrl);
    };

    this.audioElement.onerror = () => {
      this.onError?.("Audio playback error");
    };

    await this.audioElement.play();
  }

  /**
   * Stop playback and cleanup resources.
   */
  stop(): void {
    // Abort any ongoing request
    if (this.abortController) {
      this.abortController.abort();
      this.abortController = null;
    }

    this.cleanup();
  }

  // Mute/unmute the current (and any future) audio element.
  setMuted(muted: boolean): void {
    this.isMuted = muted;
    if (this.audioElement) {
      this.audioElement.muted = muted;
    }
  }

  /**
   * Cleanup all resources.
   */
  private cleanup(): void {
    // Revoke Object URL to prevent memory leak
    if (this.mediaSourceUrl) {
      URL.revokeObjectURL(this.mediaSourceUrl);
      this.mediaSourceUrl = null;
    }

    // Stop and cleanup audio element
    if (this.audioElement) {
      this.audioElement.pause();
      this.audioElement.src = "";
      this.audioElement = null;
    }

    // Cleanup MediaSource
    if (this.mediaSource && this.mediaSource.readyState === "open") {
      try {
        if (this.sourceBuffer) {
          this.mediaSource.removeSourceBuffer(this.sourceBuffer);
        }
        this.mediaSource.endOfStream();
      } catch {
        // Ignore cleanup errors
      }
    }

    this.mediaSource = null;
    this.sourceBuffer = null;
    this.pendingChunks = [];
    this.isAppending = false;
    this.streamComplete = false;

    // Notify listeners if we were still playing when torn down.
    if (this.isPlaying) {
      this.isPlaying = false;
      this.onPlayingChange?.(false);
    }
  }

  // Whether audio is currently playing.
  get playing(): boolean {
    return this.isPlaying;
  }
}
|
||||
/**
 * WebSocketStreamingTTSPlayer - Uses WebSocket for bidirectional streaming.
 * Useful for scenarios where you want to stream text in and get audio out
 * incrementally (e.g., as LLM generates text).
 */
export class WebSocketStreamingTTSPlayer {
  private websocket: WebSocket | null = null;
  // MediaSource pipeline state; torn down in cleanup().
  private mediaSource: MediaSource | null = null;
  private mediaSourceUrl: string | null = null; // object URL backing the <audio> src
  private sourceBuffer: SourceBuffer | null = null;
  private audioElement: HTMLAudioElement | null = null;
  // Audio chunks received over the socket but not yet appended.
  private pendingChunks: Uint8Array[] = [];
  private isAppending: boolean = false; // true while an appendBuffer() is in flight
  private isPlaying: boolean = false;
  private onPlayingChange?: (playing: boolean) => void;
  private onError?: (error: string) => void;
  private hasStartedPlayback: boolean = false; // playback begins on first chunk

  constructor(options?: {
    onPlayingChange?: (playing: boolean) => void;
    onError?: (error: string) => void;
  }) {
    this.onPlayingChange = options?.onPlayingChange;
    this.onError = options?.onError;
  }

  // Build the authenticated WebSocket URL for the synthesis stream.
  private async getWebSocketUrl(): Promise<string> {
    // Fetch short-lived WS token
    const tokenResponse = await fetch("/api/voice/ws-token", {
      method: "POST",
      credentials: "include",
    });
    if (!tokenResponse.ok) {
      throw new Error("Failed to get WebSocket authentication token");
    }
    const { token } = await tokenResponse.json();

    const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
    // Port 3000 is assumed to be the local Next.js dev server; in that case
    // connect straight to the backend on 8080.
    const isDev = window.location.port === "3000";
    const host = isDev ? "localhost:8080" : window.location.host;
    const path = isDev
      ? "/voice/synthesize/stream"
      : "/api/voice/synthesize/stream";
    return `${protocol}//${host}${path}?token=${encodeURIComponent(token)}`;
  }

  /**
   * Open the WebSocket and set up the MediaSource pipeline.
   * Resolves once the socket is open and the initial config has been sent.
   */
  async connect(voice?: string, speed?: number): Promise<void> {
    // Cleanup any previous connection
    this.cleanup();

    // Check MediaSource support
    if (!window.MediaSource || !MediaSource.isTypeSupported("audio/mpeg")) {
      throw new Error("MediaSource Extensions not supported");
    }

    // Create MediaSource and audio element
    this.mediaSource = new MediaSource();
    this.audioElement = new Audio();
    this.mediaSourceUrl = URL.createObjectURL(this.mediaSource);
    this.audioElement.src = this.mediaSourceUrl;

    this.audioElement.onplay = () => {
      if (!this.isPlaying) {
        this.isPlaying = true;
        this.onPlayingChange?.(true);
      }
    };

    this.audioElement.onended = () => {
      this.isPlaying = false;
      this.onPlayingChange?.(false);
    };

    // Wait for MediaSource to be ready
    await new Promise<void>((resolve, reject) => {
      this.mediaSource!.onsourceopen = () => {
        try {
          this.sourceBuffer = this.mediaSource!.addSourceBuffer("audio/mpeg");
          this.sourceBuffer.mode = "sequence";
          // Each completed append re-triggers the drain of pendingChunks.
          this.sourceBuffer.onupdateend = () => {
            this.isAppending = false;
            this.processNextChunk();
          };
          resolve();
        } catch (err) {
          reject(err);
        }
      };
    });

    // Connect WebSocket
    const url = await this.getWebSocketUrl();
    return new Promise((resolve, reject) => {
      this.websocket = new WebSocket(url);

      this.websocket.onopen = () => {
        // Send initial config
        this.websocket?.send(
          JSON.stringify({
            type: "config",
            voice: voice,
            speed: speed || 1.0,
          })
        );
        resolve();
      };

      this.websocket.onerror = () => {
        reject(new Error("WebSocket connection failed"));
      };

      this.websocket.onmessage = async (event) => {
        if (event.data instanceof Blob) {
          // Audio chunk received
          const arrayBuffer = await event.data.arrayBuffer();
          this.pendingChunks.push(new Uint8Array(arrayBuffer));
          this.processNextChunk();

          // Start playback after first chunk
          if (!this.hasStartedPlayback && this.audioElement) {
            this.hasStartedPlayback = true;
            // Small delay lets a little audio buffer before starting.
            setTimeout(() => {
              this.audioElement?.play().catch(() => {
                // Ignore playback errors
              });
            }, 100);
          }
        } else {
          // JSON message
          try {
            const data = JSON.parse(event.data);
            if (data.type === "audio_done") {
              this.finalizeStream();
            } else if (data.type === "error") {
              this.onError?.(data.message);
            }
          } catch {
            // Ignore parse errors
          }
        }
      };

      this.websocket.onclose = () => {
        this.finalizeStream();
      };
    });
  }

  // Append at most one queued chunk; onupdateend drives the next append.
  private processNextChunk(): void {
    if (
      this.isAppending ||
      this.pendingChunks.length === 0 ||
      !this.sourceBuffer ||
      this.sourceBuffer.updating
    ) {
      return;
    }

    const chunk = this.pendingChunks.shift();
    if (chunk) {
      this.isAppending = true;
      try {
        // Use ArrayBuffer directly for better TypeScript compatibility
        const buffer = chunk.buffer.slice(
          chunk.byteOffset,
          chunk.byteOffset + chunk.byteLength
        ) as ArrayBuffer;
        this.sourceBuffer.appendBuffer(buffer);
      } catch {
        // appendBuffer can throw; drop the chunk and continue.
        this.isAppending = false;
        this.processNextChunk();
      }
    }
  }

  // Poll until the append queue drains, then signal end-of-stream.
  private finalizeStream(): void {
    if (this.pendingChunks.length > 0 || this.isAppending) {
      setTimeout(() => this.finalizeStream(), 50);
      return;
    }

    if (
      this.mediaSource &&
      this.mediaSource.readyState === "open" &&
      this.sourceBuffer &&
      !this.sourceBuffer.updating
    ) {
      try {
        this.mediaSource.endOfStream();
      } catch {
        // Ignore
      }
    }
  }

  /**
   * Request synthesis of a text fragment over the open socket.
   * Throws if connect() has not completed.
   */
  async speak(text: string): Promise<void> {
    if (!this.websocket || this.websocket.readyState !== WebSocket.OPEN) {
      throw new Error("WebSocket not connected");
    }

    this.websocket.send(
      JSON.stringify({
        type: "synthesize",
        text: text,
      })
    );
  }

  // Stop playback and release resources (socket included).
  stop(): void {
    this.cleanup();
  }

  // Gracefully end the session: tell the server we're done, then tear down.
  disconnect(): void {
    if (this.websocket && this.websocket.readyState === WebSocket.OPEN) {
      this.websocket.send(JSON.stringify({ type: "end" }));
      this.websocket.close();
    }
    this.cleanup();
  }

  // Release every resource: socket, object URL, audio element, MediaSource.
  private cleanup(): void {
    if (this.websocket) {
      this.websocket.close();
      this.websocket = null;
    }

    // Revoke Object URL to prevent memory leak
    if (this.mediaSourceUrl) {
      URL.revokeObjectURL(this.mediaSourceUrl);
      this.mediaSourceUrl = null;
    }

    if (this.audioElement) {
      this.audioElement.pause();
      this.audioElement.src = "";
      this.audioElement = null;
    }

    if (this.mediaSource && this.mediaSource.readyState === "open") {
      try {
        if (this.sourceBuffer) {
          this.mediaSource.removeSourceBuffer(this.sourceBuffer);
        }
        this.mediaSource.endOfStream();
      } catch {
        // Ignore
      }
    }

    this.mediaSource = null;
    this.sourceBuffer = null;
    this.pendingChunks = [];
    this.isAppending = false;
    this.hasStartedPlayback = false;

    // Notify listeners if we were still playing when torn down.
    if (this.isPlaying) {
      this.isPlaying = false;
      this.onPlayingChange?.(false);
    }
  }

  // Whether audio is currently playing.
  get playing(): boolean {
    return this.isPlaying;
  }
}
|
||||
// Export the HTTP player as the default/recommended option
|
||||
export { HTTPStreamingTTSPlayer as StreamingTTSPlayer };
|
||||
@@ -32,10 +32,6 @@ interface UserPreferences {
|
||||
theme_preference: ThemePreference | null;
|
||||
chat_background: string | null;
|
||||
default_app_mode: "AUTO" | "CHAT" | "SEARCH";
|
||||
// Voice preferences
|
||||
voice_auto_send?: boolean;
|
||||
voice_auto_playback?: boolean;
|
||||
voice_playback_speed?: number;
|
||||
}
|
||||
|
||||
export interface MemoryItem {
|
||||
|
||||
@@ -8,7 +8,6 @@ import { toast, toastStore, MAX_VISIBLE_TOASTS } from "@/hooks/useToast";
|
||||
import type { Toast, ToastLevel } from "@/hooks/useToast";
|
||||
|
||||
const ANIMATION_DURATION = 200; // matches tailwind fade-out-scale (0.2s)
|
||||
const MAX_TOAST_MESSAGE_LENGTH = 150;
|
||||
|
||||
function levelProps(level: ToastLevel): Record<string, boolean> {
|
||||
switch (level) {
|
||||
@@ -59,36 +58,29 @@ function ToastContainer() {
|
||||
data-testid="toast-container"
|
||||
className={cn(
|
||||
"fixed bottom-4 right-4 z-[10000]",
|
||||
"flex flex-col gap-2 items-end",
|
||||
"max-w-[420px]"
|
||||
"flex flex-col gap-2 items-end"
|
||||
)}
|
||||
>
|
||||
{visible.map((t) => {
|
||||
const text =
|
||||
t.message.length > MAX_TOAST_MESSAGE_LENGTH
|
||||
? t.message.slice(0, MAX_TOAST_MESSAGE_LENGTH) + "…"
|
||||
: t.message;
|
||||
return (
|
||||
<div
|
||||
key={t.id}
|
||||
className={cn(
|
||||
t.leaving ? "animate-fade-out-scale" : "animate-fade-in-scale"
|
||||
)}
|
||||
>
|
||||
<Message
|
||||
flash
|
||||
medium
|
||||
{...levelProps(t.level ?? "info")}
|
||||
text={text}
|
||||
description={buildDescription(t)}
|
||||
close={t.dismissible}
|
||||
onClose={() => handleClose(t.id)}
|
||||
actions={t.actionLabel ? t.actionLabel : undefined}
|
||||
onAction={t.onAction}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
{visible.map((t) => (
|
||||
<div
|
||||
key={t.id}
|
||||
className={cn(
|
||||
t.leaving ? "animate-fade-out-scale" : "animate-fade-in-scale"
|
||||
)}
|
||||
>
|
||||
<Message
|
||||
flash
|
||||
medium
|
||||
{...levelProps(t.level ?? "info")}
|
||||
text={t.message}
|
||||
description={buildDescription(t)}
|
||||
close={t.dismissible}
|
||||
onClose={() => handleClose(t.id)}
|
||||
actions={t.actionLabel ? t.actionLabel : undefined}
|
||||
onAction={t.onAction}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -46,11 +46,6 @@ interface UserContextType {
|
||||
updateUserChatBackground: (chatBackground: string | null) => Promise<void>;
|
||||
updateUserDefaultModel: (defaultModel: string | null) => Promise<void>;
|
||||
updateUserDefaultAppMode: (mode: "CHAT" | "SEARCH") => Promise<void>;
|
||||
updateUserVoiceSettings: (settings: {
|
||||
auto_send?: boolean;
|
||||
auto_playback?: boolean;
|
||||
playback_speed?: number;
|
||||
}) => Promise<void>;
|
||||
}
|
||||
|
||||
const UserContext = createContext<UserContextType | undefined>(undefined);
|
||||
@@ -465,50 +460,6 @@ export function UserProvider({
|
||||
}
|
||||
};
|
||||
|
||||
const updateUserVoiceSettings = async (settings: {
|
||||
auto_send?: boolean;
|
||||
auto_playback?: boolean;
|
||||
playback_speed?: number;
|
||||
}) => {
|
||||
try {
|
||||
setUpToDateUser((prevUser) => {
|
||||
if (prevUser) {
|
||||
return {
|
||||
...prevUser,
|
||||
preferences: {
|
||||
...prevUser.preferences,
|
||||
voice_auto_send:
|
||||
settings.auto_send ?? prevUser.preferences.voice_auto_send,
|
||||
voice_auto_playback:
|
||||
settings.auto_playback ??
|
||||
prevUser.preferences.voice_auto_playback,
|
||||
voice_playback_speed:
|
||||
settings.playback_speed ??
|
||||
prevUser.preferences.voice_playback_speed,
|
||||
},
|
||||
};
|
||||
}
|
||||
return prevUser;
|
||||
});
|
||||
|
||||
const response = await fetch("/api/voice/settings", {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(settings),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
await refreshUser();
|
||||
throw new Error("Failed to update voice settings");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error updating voice settings:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
const refreshUser = async () => {
|
||||
await fetchUser();
|
||||
};
|
||||
@@ -527,7 +478,6 @@ export function UserProvider({
|
||||
updateUserChatBackground,
|
||||
updateUserDefaultModel,
|
||||
updateUserDefaultAppMode,
|
||||
updateUserVoiceSettings,
|
||||
toggleAgentPinnedStatus,
|
||||
isAdmin: upToDateUser?.role === UserRole.ADMIN,
|
||||
// Curator status applies for either global or basic curator
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -124,7 +124,7 @@ function InputChipField({
|
||||
value={value}
|
||||
onChange={(e) => onChange(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder={chips.length === 0 ? placeholder : undefined}
|
||||
placeholder={placeholder}
|
||||
className={cn(
|
||||
"flex-1 min-w-[80px] h-[1.5rem] bg-transparent p-0.5 focus:outline-none",
|
||||
innerClasses[variant],
|
||||
|
||||
@@ -210,15 +210,10 @@ describe("InputComboBox", () => {
|
||||
|
||||
await user.type(input, "app");
|
||||
|
||||
// In non-strict mode, searching shows:
|
||||
// 1) a create option for the current input and
|
||||
// 2) matched options.
|
||||
// Search should only show matching options by default
|
||||
const options = screen.getAllByRole("option");
|
||||
expect(options.length).toBe(2);
|
||||
expect(screen.getByLabelText('Create "app"')).toBeInTheDocument();
|
||||
expect(
|
||||
options.some((option) => option.textContent?.includes("Apple"))
|
||||
).toBe(true);
|
||||
expect(options.length).toBe(1);
|
||||
expect(options[0]!.textContent).toBe("Apple");
|
||||
expect(screen.queryByText("Banana")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
|
||||
@@ -130,7 +130,6 @@ const InputComboBox = ({
|
||||
leftSearchIcon = false,
|
||||
rightSection,
|
||||
separatorLabel = "Other options",
|
||||
showAddPrefix = false,
|
||||
showOtherOptions = false,
|
||||
...rest
|
||||
}: WithoutStyles<InputComboBoxProps>) => {
|
||||
@@ -158,11 +157,14 @@ const InputComboBox = ({
|
||||
const visibleUnmatchedOptions =
|
||||
hasSearchTerm && showOtherOptions ? unmatchedOptions : [];
|
||||
|
||||
// Whether to show the create option (always show when typing in non-strict mode)
|
||||
const showCreateOption = !strict && hasSearchTerm && inputValue.trim() !== "";
|
||||
// Whether to show the create option (only when no partial matches)
|
||||
const showCreateOption =
|
||||
!strict &&
|
||||
hasSearchTerm &&
|
||||
inputValue.trim() !== "" &&
|
||||
matchedOptions.length === 0;
|
||||
|
||||
// Combined list for keyboard navigation (includes create option when shown)
|
||||
// Only show matched options when searching (hide unmatched)
|
||||
const allVisibleOptions = useMemo(() => {
|
||||
const baseOptions = [...matchedOptions, ...visibleUnmatchedOptions];
|
||||
if (showCreateOption) {
|
||||
@@ -448,7 +450,6 @@ const InputComboBox = ({
|
||||
inputValue={inputValue}
|
||||
allowCreate={!strict}
|
||||
showCreateOption={showCreateOption}
|
||||
showAddPrefix={showAddPrefix}
|
||||
/>
|
||||
</>
|
||||
|
||||
|
||||
@@ -27,8 +27,6 @@ interface ComboBoxDropdownProps {
|
||||
allowCreate: boolean;
|
||||
/** Whether to show create option (pre-computed by parent) */
|
||||
showCreateOption: boolean;
|
||||
/** Show "Add" prefix in create option */
|
||||
showAddPrefix: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -60,7 +58,6 @@ export const ComboBoxDropdown = forwardRef<
|
||||
inputValue,
|
||||
allowCreate,
|
||||
showCreateOption,
|
||||
showAddPrefix,
|
||||
},
|
||||
ref
|
||||
) => {
|
||||
@@ -135,7 +132,6 @@ export const ComboBoxDropdown = forwardRef<
|
||||
inputValue={inputValue}
|
||||
allowCreate={allowCreate}
|
||||
showCreateOption={showCreateOption}
|
||||
showAddPrefix={showAddPrefix}
|
||||
/>
|
||||
</div>,
|
||||
document.body
|
||||
|
||||
@@ -24,8 +24,6 @@ interface OptionsListProps {
|
||||
allowCreate: boolean;
|
||||
/** Whether to show create option (pre-computed by parent) */
|
||||
showCreateOption: boolean;
|
||||
/** Show "Add" prefix in create option */
|
||||
showAddPrefix: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -47,7 +45,6 @@ export const OptionsList: React.FC<OptionsListProps> = ({
|
||||
inputValue,
|
||||
allowCreate,
|
||||
showCreateOption,
|
||||
showAddPrefix,
|
||||
}) => {
|
||||
// Index offset for other options when create option is shown
|
||||
const indexOffset = showCreateOption ? 1 : 0;
|
||||
@@ -73,7 +70,7 @@ export const OptionsList: React.FC<OptionsListProps> = ({
|
||||
data-index={0}
|
||||
role="option"
|
||||
aria-selected={false}
|
||||
aria-label={`${showAddPrefix ? "Add" : "Create"} "${inputValue}"`}
|
||||
aria-label={`Create "${inputValue}"`}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onSelect({ value: inputValue, label: inputValue });
|
||||
@@ -84,48 +81,19 @@ export const OptionsList: React.FC<OptionsListProps> = ({
|
||||
onMouseEnter={() => onMouseEnter(0)}
|
||||
onMouseMove={onMouseMove}
|
||||
className={cn(
|
||||
"cursor-pointer transition-colors",
|
||||
"px-3 py-2 cursor-pointer transition-colors",
|
||||
"flex items-center justify-between rounded-08",
|
||||
highlightedIndex === 0 && "bg-background-tint-02",
|
||||
"hover:bg-background-tint-02",
|
||||
showAddPrefix ? "px-1.5 py-1.5" : "px-3 py-2"
|
||||
"hover:bg-background-tint-02"
|
||||
)}
|
||||
>
|
||||
<span
|
||||
className={cn(
|
||||
"font-main-ui-action truncate min-w-0",
|
||||
showAddPrefix ? "px-1" : ""
|
||||
)}
|
||||
>
|
||||
{showAddPrefix ? (
|
||||
<>
|
||||
<span className="text-text-03">Add</span>
|
||||
<span className="text-text-04">{` ${inputValue}`}</span>
|
||||
</>
|
||||
) : (
|
||||
<span className="text-text-04">{inputValue}</span>
|
||||
)}
|
||||
<span className="font-main-ui-action text-text-04 truncate min-w-0">
|
||||
{inputValue}
|
||||
</span>
|
||||
<SvgPlus
|
||||
className={cn(
|
||||
"w-4 h-4 flex-shrink-0",
|
||||
showAddPrefix ? "text-text-04 mx-1" : "text-text-03 ml-2"
|
||||
)}
|
||||
/>
|
||||
<SvgPlus className="w-4 h-4 text-text-03 flex-shrink-0 ml-2" />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Separator - show when there are options to display */}
|
||||
{separatorLabel &&
|
||||
(matchedOptions.length > 0 ||
|
||||
(!hasSearchTerm && unmatchedOptions.length > 0)) && (
|
||||
<div className="px-3 py-1">
|
||||
<Text as="p" text03 secondaryBody>
|
||||
{separatorLabel}
|
||||
</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Matched/Filtered Options */}
|
||||
{matchedOptions.map((option, idx) => {
|
||||
const globalIndex = idx + indexOffset;
|
||||
@@ -148,27 +116,37 @@ export const OptionsList: React.FC<OptionsListProps> = ({
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Unmatched Options - only show when NOT searching */}
|
||||
{!hasSearchTerm &&
|
||||
unmatchedOptions.map((option, idx) => {
|
||||
const globalIndex = matchedOptions.length + idx + indexOffset;
|
||||
const isExact = isExactMatch(option);
|
||||
return (
|
||||
<OptionItem
|
||||
key={option.value}
|
||||
option={option}
|
||||
index={globalIndex}
|
||||
fieldId={fieldId}
|
||||
isHighlighted={globalIndex === highlightedIndex}
|
||||
isSelected={value === option.value}
|
||||
isExact={isExact}
|
||||
onSelect={onSelect}
|
||||
onMouseEnter={onMouseEnter}
|
||||
onMouseMove={onMouseMove}
|
||||
searchTerm={inputValue}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
{/* Separator - only show if there are unmatched options and a search term */}
|
||||
{hasSearchTerm && unmatchedOptions.length > 0 && (
|
||||
<div className="px-3 py-2 pt-3">
|
||||
<div className="border-t border-border-01 pt-2">
|
||||
<Text as="p" text04 secondaryBody className="text-text-02">
|
||||
{separatorLabel}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Unmatched Options */}
|
||||
{unmatchedOptions.map((option, idx) => {
|
||||
const globalIndex = matchedOptions.length + idx + indexOffset;
|
||||
const isExact = isExactMatch(option);
|
||||
return (
|
||||
<OptionItem
|
||||
key={option.value}
|
||||
option={option}
|
||||
index={globalIndex}
|
||||
fieldId={fieldId}
|
||||
isHighlighted={globalIndex === highlightedIndex}
|
||||
isSelected={value === option.value}
|
||||
isExact={isExact}
|
||||
onSelect={onSelect}
|
||||
onMouseEnter={onMouseEnter}
|
||||
onMouseMove={onMouseMove}
|
||||
searchTerm={inputValue}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useState, useEffect, useCallback, useMemo, useRef } from "react";
|
||||
import { useState, useEffect, useCallback, useMemo, RefObject } from "react";
|
||||
import { ComboBoxOption } from "./types";
|
||||
|
||||
// =============================================================================
|
||||
@@ -19,7 +19,6 @@ export function useComboBoxState({ value, options }: UseComboBoxStateProps) {
|
||||
const [inputValue, setInputValue] = useState(value);
|
||||
const [highlightedIndex, setHighlightedIndex] = useState(-1);
|
||||
const [isKeyboardNav, setIsKeyboardNav] = useState(false);
|
||||
const prevIsOpenRef = useRef(false);
|
||||
|
||||
// Sync inputValue with the external value prop.
|
||||
// When the dropdown is closed, always reflect the controlled value.
|
||||
|
||||
@@ -40,8 +40,6 @@ export interface InputComboBoxProps
|
||||
rightSection?: React.ReactNode;
|
||||
/** Label for the separator between matched and unmatched options */
|
||||
separatorLabel?: string;
|
||||
/** Show "Add" prefix in create option (e.g., "Add [value]") */
|
||||
showAddPrefix?: boolean;
|
||||
/**
|
||||
* When true, keep non-matching options visible under a separator while searching.
|
||||
* Defaults to false so search results are strictly filtered.
|
||||
|
||||
@@ -751,7 +751,6 @@ function ChatPreferencesSettings() {
|
||||
updateUserShortcuts,
|
||||
updateUserDefaultModel,
|
||||
updateUserDefaultAppMode,
|
||||
updateUserVoiceSettings,
|
||||
} = useUser();
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const settings = useSettingsContext();
|
||||
@@ -768,43 +767,6 @@ function ChatPreferencesSettings() {
|
||||
onSuccess: () => toast.success("Preferences saved"),
|
||||
onError: () => toast.error("Failed to save preferences"),
|
||||
});
|
||||
const [draftVoicePlaybackSpeed, setDraftVoicePlaybackSpeed] = useState(
|
||||
user?.preferences.voice_playback_speed ?? 1
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
setDraftVoicePlaybackSpeed(user?.preferences.voice_playback_speed ?? 1);
|
||||
}, [user?.preferences.voice_playback_speed]);
|
||||
|
||||
const saveVoiceSettings = useCallback(
|
||||
async (settings: {
|
||||
auto_send?: boolean;
|
||||
auto_playback?: boolean;
|
||||
playback_speed?: number;
|
||||
}) => {
|
||||
try {
|
||||
await updateUserVoiceSettings(settings);
|
||||
toast.success("Preferences saved");
|
||||
} catch {
|
||||
toast.error("Failed to save preferences");
|
||||
}
|
||||
},
|
||||
[updateUserVoiceSettings]
|
||||
);
|
||||
|
||||
const commitVoicePlaybackSpeed = useCallback(() => {
|
||||
const currentSpeed = user?.preferences.voice_playback_speed ?? 1;
|
||||
if (Math.abs(currentSpeed - draftVoicePlaybackSpeed) < 0.001) {
|
||||
return;
|
||||
}
|
||||
void saveVoiceSettings({
|
||||
playback_speed: draftVoicePlaybackSpeed,
|
||||
});
|
||||
}, [
|
||||
draftVoicePlaybackSpeed,
|
||||
saveVoiceSettings,
|
||||
user?.preferences.voice_playback_speed,
|
||||
]);
|
||||
|
||||
// Wrapper to save memories and return success/failure
|
||||
const handleSaveMemories = useCallback(
|
||||
@@ -974,69 +936,6 @@ function ChatPreferencesSettings() {
|
||||
{user?.preferences?.shortcut_enabled && <PromptShortcuts />}
|
||||
</Card>
|
||||
</Section>
|
||||
|
||||
<Section gap={0.75}>
|
||||
<Content
|
||||
title="Voice"
|
||||
sizePreset="main-content"
|
||||
variant="section"
|
||||
widthVariant="full"
|
||||
/>
|
||||
<Card>
|
||||
<InputLayouts.Horizontal
|
||||
title="Auto-Send"
|
||||
description="Automatically send voice input when recording stops."
|
||||
>
|
||||
<Switch
|
||||
checked={user?.preferences.voice_auto_send ?? false}
|
||||
onCheckedChange={(checked) => {
|
||||
void saveVoiceSettings({ auto_send: checked });
|
||||
}}
|
||||
/>
|
||||
</InputLayouts.Horizontal>
|
||||
|
||||
<InputLayouts.Horizontal
|
||||
title="Auto-Playback"
|
||||
description="Automatically play voice responses."
|
||||
>
|
||||
<Switch
|
||||
checked={user?.preferences.voice_auto_playback ?? false}
|
||||
onCheckedChange={(checked) => {
|
||||
void saveVoiceSettings({ auto_playback: checked });
|
||||
}}
|
||||
/>
|
||||
</InputLayouts.Horizontal>
|
||||
|
||||
<InputLayouts.Horizontal
|
||||
title="Playback Speed"
|
||||
description="Adjust the speed of voice playback."
|
||||
>
|
||||
<div className="flex items-center gap-3">
|
||||
<input
|
||||
type="range"
|
||||
min="0.5"
|
||||
max="2"
|
||||
step="0.1"
|
||||
value={draftVoicePlaybackSpeed}
|
||||
onChange={(e) => {
|
||||
setDraftVoicePlaybackSpeed(parseFloat(e.target.value));
|
||||
}}
|
||||
onMouseUp={commitVoicePlaybackSpeed}
|
||||
onTouchEnd={commitVoicePlaybackSpeed}
|
||||
onKeyUp={(e) => {
|
||||
if (e.key === "ArrowLeft" || e.key === "ArrowRight") {
|
||||
commitVoicePlaybackSpeed();
|
||||
}
|
||||
}}
|
||||
className="w-24 h-2 rounded-lg appearance-none cursor-pointer bg-background-neutral-02"
|
||||
/>
|
||||
<span className="text-sm text-text-02 w-10">
|
||||
{draftVoicePlaybackSpeed.toFixed(1)}x
|
||||
</span>
|
||||
</div>
|
||||
</InputLayouts.Horizontal>
|
||||
</Card>
|
||||
</Section>
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import { Disabled } from "@opal/core";
|
||||
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
|
||||
import InputChipField from "@/refresh-components/inputs/InputChipField";
|
||||
import type { ChipItem } from "@/refresh-components/inputs/InputChipField";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { inviteUsers } from "./svc";
|
||||
|
||||
@@ -148,6 +149,11 @@ export default function InviteUsersModal({
|
||||
placeholder="Add emails to invite, comma separated"
|
||||
layout="stacked"
|
||||
/>
|
||||
{chips.some((c) => c.error) && (
|
||||
<Text secondaryBody className="text-status-warning-text pt-1">
|
||||
Some email addresses are invalid and will be skipped.
|
||||
</Text>
|
||||
)}
|
||||
</Modal.Body>
|
||||
|
||||
<Modal.Footer>
|
||||
|
||||
@@ -166,6 +166,7 @@ export default function UserFilters({
|
||||
<Popover>
|
||||
<Popover.Trigger asChild>
|
||||
<FilterButton
|
||||
data-testid="filter-role"
|
||||
leftIcon={SvgUsers}
|
||||
active={hasRoleFilter}
|
||||
onClear={() => onRolesChange([])}
|
||||
@@ -213,6 +214,7 @@ export default function UserFilters({
|
||||
>
|
||||
<Popover.Trigger asChild>
|
||||
<FilterButton
|
||||
data-testid="filter-group"
|
||||
leftIcon={SvgUsers}
|
||||
active={hasGroupFilter}
|
||||
onClear={() => onGroupsChange([])}
|
||||
@@ -267,6 +269,7 @@ export default function UserFilters({
|
||||
<Popover>
|
||||
<Popover.Trigger asChild>
|
||||
<FilterButton
|
||||
data-testid="filter-status"
|
||||
leftIcon={SvgUsers}
|
||||
active={hasStatusFilter}
|
||||
onClear={() => onStatusesChange([])}
|
||||
|
||||
@@ -90,46 +90,45 @@ export default function UserRoleCell({ user, onMutate }: UserRoleCellProps) {
|
||||
const currentIcon = ROLE_ICONS[user.role] ?? SvgUser;
|
||||
|
||||
return (
|
||||
<div className="[&_button]:rounded-08">
|
||||
<Disabled disabled={isUpdating}>
|
||||
<Popover open={open} onOpenChange={setOpen}>
|
||||
<Popover.Trigger asChild>
|
||||
<OpenButton
|
||||
icon={currentIcon}
|
||||
variant="select-tinted"
|
||||
width="full"
|
||||
justifyContent="between"
|
||||
>
|
||||
{USER_ROLE_LABELS[user.role]}
|
||||
</OpenButton>
|
||||
</Popover.Trigger>
|
||||
<Popover.Content align="start">
|
||||
<div className="flex flex-col gap-1 p-1 min-w-[160px]">
|
||||
{SELECTABLE_ROLES.map((role) => {
|
||||
if (
|
||||
role === UserRole.GLOBAL_CURATOR &&
|
||||
!isPaidEnterpriseFeaturesEnabled
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
const isSelected = user.role === role;
|
||||
const icon = ROLE_ICONS[role] ?? SvgUser;
|
||||
return (
|
||||
<LineItem
|
||||
key={role}
|
||||
icon={isSelected ? SvgCheck : icon}
|
||||
selected={isSelected}
|
||||
emphasized={isSelected}
|
||||
onClick={() => handleSelect(role)}
|
||||
>
|
||||
{USER_ROLE_LABELS[role]}
|
||||
</LineItem>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</Popover.Content>
|
||||
</Popover>
|
||||
</Disabled>
|
||||
</div>
|
||||
<Disabled disabled={isUpdating}>
|
||||
<Popover open={open} onOpenChange={setOpen}>
|
||||
<Popover.Trigger asChild>
|
||||
<OpenButton
|
||||
icon={currentIcon}
|
||||
variant="select-tinted"
|
||||
width="full"
|
||||
justifyContent="between"
|
||||
roundingVariant="compact"
|
||||
>
|
||||
{USER_ROLE_LABELS[user.role]}
|
||||
</OpenButton>
|
||||
</Popover.Trigger>
|
||||
<Popover.Content align="start">
|
||||
<div className="flex flex-col gap-1 p-1 min-w-[160px]">
|
||||
{SELECTABLE_ROLES.map((role) => {
|
||||
if (
|
||||
role === UserRole.GLOBAL_CURATOR &&
|
||||
!isPaidEnterpriseFeaturesEnabled
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
const isSelected = user.role === role;
|
||||
const icon = ROLE_ICONS[role] ?? SvgUser;
|
||||
return (
|
||||
<LineItem
|
||||
key={role}
|
||||
icon={isSelected ? SvgCheck : icon}
|
||||
selected={isSelected}
|
||||
emphasized={isSelected}
|
||||
onClick={() => handleSelect(role)}
|
||||
>
|
||||
{USER_ROLE_LABELS[role]}
|
||||
</LineItem>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</Popover.Content>
|
||||
</Popover>
|
||||
</Disabled>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -208,7 +208,7 @@ export default function UserRowActions({
|
||||
);
|
||||
}}
|
||||
>
|
||||
Cancel
|
||||
Cancel Invite
|
||||
</Button>
|
||||
</Disabled>
|
||||
}
|
||||
|
||||
@@ -237,6 +237,7 @@ export default function UsersTable({
|
||||
prominence="tertiary"
|
||||
size="sm"
|
||||
tooltip="Download CSV"
|
||||
aria-label="Download CSV"
|
||||
onClick={() => {
|
||||
downloadUsersCsv().catch((err) => {
|
||||
toast.error(
|
||||
|
||||
@@ -43,7 +43,6 @@ import {
|
||||
SvgArrowUp,
|
||||
SvgGlobe,
|
||||
SvgHourglass,
|
||||
SvgMicrophone,
|
||||
SvgPlus,
|
||||
SvgPlusCircle,
|
||||
SvgSearch,
|
||||
@@ -56,10 +55,6 @@ import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
import MicrophoneButton from "@/sections/input/MicrophoneButton";
|
||||
import Waveform from "@/components/voice/Waveform";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import { useVoiceStatus } from "@/hooks/useVoiceStatus";
|
||||
|
||||
const MIN_INPUT_HEIGHT = 44;
|
||||
const MAX_INPUT_HEIGHT = 200;
|
||||
@@ -118,72 +113,16 @@ const AppInputBar = React.memo(
|
||||
}: AppInputBarProps) => {
|
||||
// Internal message state - kept local to avoid parent re-renders on every keystroke
|
||||
const [message, setMessage] = useState(initialMessage);
|
||||
const [isRecording, setIsRecording] = useState(false);
|
||||
const [recordingCycleCount, setRecordingCycleCount] = useState(0);
|
||||
const [isMuted, setIsMuted] = useState(false);
|
||||
const [audioLevel, setAudioLevel] = useState(0);
|
||||
const stopRecordingRef = useRef<(() => Promise<string | null>) | null>(
|
||||
null
|
||||
);
|
||||
const setMutedRef = useRef<((muted: boolean) => void) | null>(null);
|
||||
const textAreaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const textAreaWrapperRef = useRef<HTMLDivElement>(null);
|
||||
const filesWrapperRef = useRef<HTMLDivElement>(null);
|
||||
const filesContentRef = useRef<HTMLDivElement>(null);
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const { user, isAdmin } = useUser();
|
||||
const { user } = useUser();
|
||||
const { state } = useQueryController();
|
||||
const isClassifying = state.phase === "classifying";
|
||||
const isSearchActive =
|
||||
state.phase === "searching" || state.phase === "search-results";
|
||||
const {
|
||||
stopTTS,
|
||||
isTTSPlaying,
|
||||
isManualTTSPlaying,
|
||||
isTTSLoading,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
isTTSMuted,
|
||||
toggleTTSMute,
|
||||
} = useVoiceMode();
|
||||
const { sttEnabled } = useVoiceStatus();
|
||||
// Show mic button: always if STT configured, or greyed-out for admins to prompt setup
|
||||
const showMicButton = sttEnabled || isAdmin;
|
||||
const isVoicePlaybackActive =
|
||||
isTTSPlaying || isTTSLoading || isAwaitingAutoPlaybackStart;
|
||||
const isVoicePlaybackControllable = isVoicePlaybackActive && !isRecording;
|
||||
const isTTSActuallySpeaking = isTTSPlaying || isManualTTSPlaying;
|
||||
const appFocus = useAppFocus();
|
||||
const isNewSession = appFocus.isNewSession();
|
||||
const appMode = state.phase === "idle" ? state.appMode : undefined;
|
||||
const isSearchMode =
|
||||
(isNewSession && appMode === "search") || isSearchActive;
|
||||
|
||||
const handleRecordingChange = useCallback((nextIsRecording: boolean) => {
|
||||
setIsRecording((prevIsRecording) => {
|
||||
if (!prevIsRecording && nextIsRecording) {
|
||||
setRecordingCycleCount((count) => count + 1);
|
||||
}
|
||||
return nextIsRecording;
|
||||
});
|
||||
}, []);
|
||||
|
||||
// Wrapper for onSubmit that stops TTS first to prevent overlapping voices
|
||||
const handleSubmit = useCallback(
|
||||
(text: string) => {
|
||||
stopTTS();
|
||||
onSubmit(text);
|
||||
},
|
||||
[stopTTS, onSubmit]
|
||||
);
|
||||
const submitMessage = useCallback(
|
||||
(text: string) => {
|
||||
if (!text.trim()) {
|
||||
return;
|
||||
}
|
||||
handleSubmit(text);
|
||||
},
|
||||
[handleSubmit]
|
||||
);
|
||||
|
||||
// Expose reset and focus methods to parent via ref
|
||||
React.useImperativeHandle(ref, () => ({
|
||||
@@ -203,16 +142,10 @@ const AppInputBar = React.memo(
|
||||
setMessage(initialMessage);
|
||||
}
|
||||
}, [initialMessage]);
|
||||
const shouldShowRecordingWaveformBelow =
|
||||
isRecording &&
|
||||
!isVoicePlaybackActive &&
|
||||
(isNewSession || recordingCycleCount === 1);
|
||||
|
||||
useEffect(() => {
|
||||
if (isNewSession && !initialMessage) {
|
||||
setMessage("");
|
||||
}
|
||||
}, [isNewSession, initialMessage]);
|
||||
const appFocus = useAppFocus();
|
||||
const appMode = state.phase === "idle" ? state.appMode : undefined;
|
||||
const isSearchMode =
|
||||
(appFocus.isNewSession() && appMode === "search") || isSearchActive;
|
||||
|
||||
const { forcedToolIds, setForcedToolIds } = useForcedTools();
|
||||
const { currentMessageFiles, setCurrentMessageFiles, currentProjectId } =
|
||||
@@ -625,42 +558,9 @@ const AppInputBar = React.memo(
|
||||
disabled={disabled}
|
||||
/>
|
||||
</div>
|
||||
{showMicButton &&
|
||||
(sttEnabled ? (
|
||||
<MicrophoneButton
|
||||
onTranscription={(text) => setMessage(text)}
|
||||
disabled={disabled || chatState === "streaming"}
|
||||
autoSend={user?.preferences?.voice_auto_send ?? false}
|
||||
autoListen={user?.preferences?.voice_auto_playback ?? false}
|
||||
isNewSession={isNewSession}
|
||||
chatState={chatState}
|
||||
onRecordingChange={handleRecordingChange}
|
||||
stopRecordingRef={stopRecordingRef}
|
||||
currentMessage={message}
|
||||
onRecordingStart={() => {}}
|
||||
onAutoSend={(text) => {
|
||||
submitMessage(text);
|
||||
}}
|
||||
onMuteChange={setIsMuted}
|
||||
setMutedRef={setMutedRef}
|
||||
onAudioLevel={setAudioLevel}
|
||||
/>
|
||||
) : (
|
||||
<Disabled disabled>
|
||||
<Button
|
||||
icon={SvgMicrophone}
|
||||
aria-label="Set up voice"
|
||||
prominence="tertiary"
|
||||
tooltip="Voice not configured. Set up in admin settings."
|
||||
/>
|
||||
</Disabled>
|
||||
))}
|
||||
|
||||
<Disabled
|
||||
disabled={
|
||||
(chatState === "input" &&
|
||||
!isVoicePlaybackControllable &&
|
||||
!message) ||
|
||||
(chatState === "input" && !message) ||
|
||||
hasUploadingFiles ||
|
||||
isClassifying
|
||||
}
|
||||
@@ -670,18 +570,15 @@ const AppInputBar = React.memo(
|
||||
icon={
|
||||
isClassifying
|
||||
? SimpleLoader
|
||||
: chatState === "streaming" || isVoicePlaybackControllable
|
||||
? SvgStop
|
||||
: SvgArrowUp
|
||||
: chatState === "input"
|
||||
? SvgArrowUp
|
||||
: SvgStop
|
||||
}
|
||||
onClick={() => {
|
||||
if (chatState == "streaming") {
|
||||
stopTTS({ manual: true });
|
||||
stopGenerating();
|
||||
} else if (isVoicePlaybackControllable) {
|
||||
stopTTS({ manual: true });
|
||||
} else if (message) {
|
||||
submitMessage(message);
|
||||
onSubmit(message);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
@@ -696,7 +593,7 @@ const AppInputBar = React.memo(
|
||||
ref={containerRef}
|
||||
id="onyx-chat-input"
|
||||
className={cn(
|
||||
"relative w-full flex flex-col shadow-01 bg-background-neutral-00 rounded-16"
|
||||
"w-full flex flex-col shadow-01 bg-background-neutral-00 rounded-16"
|
||||
// # Note (from @raunakab):
|
||||
//
|
||||
// `shadow-01` extends ~14px below the element (2px offset + 12px blur).
|
||||
@@ -709,32 +606,6 @@ const AppInputBar = React.memo(
|
||||
// modes. See the corresponding note there for details.
|
||||
)}
|
||||
>
|
||||
{/* Voice waveform overlay (positioned outside normal flow to avoid resizing input) */}
|
||||
{isTTSActuallySpeaking ? (
|
||||
<div className="absolute bottom-full mb-1 left-1 z-10">
|
||||
<Waveform
|
||||
variant="speaking"
|
||||
isActive={isTTSActuallySpeaking}
|
||||
isMuted={isTTSMuted}
|
||||
onMuteToggle={toggleTTSMute}
|
||||
/>
|
||||
</div>
|
||||
) : isRecording &&
|
||||
!isVoicePlaybackActive &&
|
||||
!shouldShowRecordingWaveformBelow ? (
|
||||
<div className="absolute bottom-full mb-1 left-1 right-1 z-10">
|
||||
<Waveform
|
||||
variant="recording"
|
||||
isActive={isRecording}
|
||||
isMuted={isMuted}
|
||||
audioLevel={audioLevel}
|
||||
onMuteToggle={() => {
|
||||
setMutedRef.current?.(!isMuted);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
) : null}
|
||||
|
||||
{/* Attached Files */}
|
||||
<div
|
||||
ref={filesWrapperRef}
|
||||
@@ -786,13 +657,9 @@ const AppInputBar = React.memo(
|
||||
style={{ scrollbarWidth: "thin" }}
|
||||
aria-multiline={true}
|
||||
placeholder={
|
||||
isRecording
|
||||
? "Listening..."
|
||||
: isVoicePlaybackActive
|
||||
? "Onyx is speaking..."
|
||||
: isSearchMode
|
||||
? "Search connected sources"
|
||||
: "How can I help you today?"
|
||||
isSearchMode
|
||||
? "Search connected sources"
|
||||
: "How can I help you today?"
|
||||
}
|
||||
value={message}
|
||||
onKeyDown={(event) => {
|
||||
@@ -809,7 +676,7 @@ const AppInputBar = React.memo(
|
||||
!isClassifying &&
|
||||
!hasUploadingFiles
|
||||
) {
|
||||
submitMessage(message);
|
||||
onSubmit(message);
|
||||
}
|
||||
}
|
||||
}}
|
||||
@@ -876,7 +743,7 @@ const AppInputBar = React.memo(
|
||||
if (chatState == "streaming") {
|
||||
stopGenerating();
|
||||
} else if (message) {
|
||||
submitMessage(message);
|
||||
onSubmit(message);
|
||||
}
|
||||
}}
|
||||
prominence="tertiary"
|
||||
@@ -888,21 +755,6 @@ const AppInputBar = React.memo(
|
||||
</div>
|
||||
|
||||
{chatControls}
|
||||
|
||||
{/* First recording cycle waveform below input */}
|
||||
{shouldShowRecordingWaveformBelow && (
|
||||
<div className="absolute top-full mt-1 left-1 right-1 z-10">
|
||||
<Waveform
|
||||
variant="recording"
|
||||
isActive={isRecording}
|
||||
isMuted={isMuted}
|
||||
audioLevel={audioLevel}
|
||||
onMuteToggle={() => {
|
||||
setMutedRef.current?.(!isMuted);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</Disabled>
|
||||
);
|
||||
|
||||
@@ -1,340 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useRef } from "react";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { SvgMicrophone } from "@opal/icons";
|
||||
import { useVoiceRecorder } from "@/hooks/useVoiceRecorder";
|
||||
import { useVoiceMode } from "@/providers/VoiceModeProvider";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { ChatState } from "@/app/app/interfaces";
|
||||
|
||||
interface MicrophoneButtonProps {
|
||||
onTranscription: (text: string) => void;
|
||||
disabled?: boolean;
|
||||
autoSend?: boolean;
|
||||
/** Called with transcribed text when autoSend is enabled */
|
||||
onAutoSend?: (text: string) => void;
|
||||
/**
|
||||
* Internal prop: auto-start listening when TTS finishes or chat response completes.
|
||||
* Tied to voice_auto_playback user preference.
|
||||
* Enables conversation flow: speak → AI responds → auto-listen again.
|
||||
* Note: autoSend is separate - it controls whether message auto-submits after recording.
|
||||
*/
|
||||
autoListen?: boolean;
|
||||
/** Current chat state - used to detect when response streaming finishes */
|
||||
chatState?: ChatState;
|
||||
/** Called when recording state changes */
|
||||
onRecordingChange?: (isRecording: boolean) => void;
|
||||
/** Ref to expose stop recording function to parent */
|
||||
stopRecordingRef?: React.MutableRefObject<
|
||||
(() => Promise<string | null>) | null
|
||||
>;
|
||||
/** Called when recording starts */
|
||||
onRecordingStart?: () => void;
|
||||
/** Existing message text to prepend to transcription (append mode) */
|
||||
currentMessage?: string;
|
||||
/** Called when mute state changes */
|
||||
onMuteChange?: (isMuted: boolean) => void;
|
||||
/** Ref to expose setMuted function to parent */
|
||||
setMutedRef?: React.MutableRefObject<((muted: boolean) => void) | null>;
|
||||
/** Called with current microphone audio level (0-1) for waveform visualization */
|
||||
onAudioLevel?: (level: number) => void;
|
||||
/** Whether current chat is a new session (used to reset auto-listen arming) */
|
||||
isNewSession?: boolean;
|
||||
}
|
||||
|
||||
function MicrophoneButton({
|
||||
onTranscription,
|
||||
disabled = false,
|
||||
autoSend = false,
|
||||
onAutoSend,
|
||||
autoListen = false,
|
||||
chatState,
|
||||
onRecordingChange,
|
||||
stopRecordingRef,
|
||||
onRecordingStart,
|
||||
currentMessage = "",
|
||||
onMuteChange,
|
||||
setMutedRef,
|
||||
onAudioLevel,
|
||||
isNewSession = false,
|
||||
}: MicrophoneButtonProps) {
|
||||
const {
|
||||
isTTSPlaying,
|
||||
isTTSLoading,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
manualStopCount,
|
||||
} = useVoiceMode();
|
||||
|
||||
// Refs for tracking state across renders
|
||||
// Track whether TTS was actually playing audio (not just loading)
|
||||
const wasTTSActuallyPlayingRef = useRef(false);
|
||||
const manualStopRequestedRef = useRef(false);
|
||||
const lastHandledManualStopCountRef = useRef(manualStopCount);
|
||||
const autoListenCooldownTimerRef = useRef<NodeJS.Timeout | null>(null);
|
||||
const hasManualRecordStartRef = useRef(false);
|
||||
// Prevent late transcript events from repopulating input after auto-send.
|
||||
const suppressTranscriptUpdatesRef = useRef(false);
|
||||
// Snapshot of existing message text when recording starts (for append mode)
|
||||
const messagePrefixRef = useRef("");
|
||||
const currentMessageRef = useRef(currentMessage);
|
||||
|
||||
useEffect(() => {
|
||||
currentMessageRef.current = currentMessage;
|
||||
}, [currentMessage]);
|
||||
|
||||
// Helper to combine prefix with new transcript
|
||||
const withPrefix = useCallback((text: string) => {
|
||||
const prefix = messagePrefixRef.current;
|
||||
if (!prefix) return text;
|
||||
return prefix + (prefix.endsWith(" ") ? "" : " ") + text;
|
||||
}, []);
|
||||
|
||||
// Handler for VAD (Voice Activity Detection) triggered auto-send.
|
||||
// VAD runs server-side in the STT provider and detects when the user stops speaking.
|
||||
const handleFinalTranscript = useCallback(
|
||||
(text: string) => {
|
||||
const combined = withPrefix(text);
|
||||
if (!suppressTranscriptUpdatesRef.current) {
|
||||
onTranscription(combined);
|
||||
}
|
||||
const isManualStop = manualStopRequestedRef.current;
|
||||
// Only auto-send if chat is ready for input (not streaming)
|
||||
if (!isManualStop && autoSend && onAutoSend && chatState === "input") {
|
||||
suppressTranscriptUpdatesRef.current = true;
|
||||
onAutoSend(combined);
|
||||
// Clear prefix after send to prevent stale text in next auto-listen cycle
|
||||
messagePrefixRef.current = "";
|
||||
}
|
||||
},
|
||||
[onTranscription, autoSend, onAutoSend, chatState, withPrefix]
|
||||
);
|
||||
|
||||
const {
|
||||
isRecording,
|
||||
isProcessing,
|
||||
isMuted,
|
||||
error,
|
||||
liveTranscript,
|
||||
audioLevel,
|
||||
startRecording,
|
||||
stopRecording,
|
||||
setMuted,
|
||||
} = useVoiceRecorder({ onFinalTranscript: handleFinalTranscript });
|
||||
|
||||
// Expose stopRecording to parent
|
||||
useEffect(() => {
|
||||
if (stopRecordingRef) {
|
||||
stopRecordingRef.current = stopRecording;
|
||||
}
|
||||
}, [stopRecording, stopRecordingRef]);
|
||||
|
||||
// Expose setMuted to parent
|
||||
useEffect(() => {
|
||||
if (setMutedRef) {
|
||||
setMutedRef.current = setMuted;
|
||||
}
|
||||
}, [setMuted, setMutedRef]);
|
||||
|
||||
// Notify parent when mute state changes
|
||||
useEffect(() => {
|
||||
onMuteChange?.(isMuted);
|
||||
}, [isMuted, onMuteChange]);
|
||||
|
||||
// Forward audio level to parent for waveform visualization
|
||||
useEffect(() => {
|
||||
onAudioLevel?.(audioLevel);
|
||||
}, [audioLevel, onAudioLevel]);
|
||||
|
||||
// Notify parent when recording state changes
|
||||
useEffect(() => {
|
||||
onRecordingChange?.(isRecording);
|
||||
}, [isRecording, onRecordingChange]);
|
||||
|
||||
// Update input with live transcript as user speaks (appending to existing text)
|
||||
useEffect(() => {
|
||||
if (
|
||||
isRecording &&
|
||||
liveTranscript &&
|
||||
!suppressTranscriptUpdatesRef.current
|
||||
) {
|
||||
onTranscription(withPrefix(liveTranscript));
|
||||
}
|
||||
}, [isRecording, liveTranscript, onTranscription, withPrefix]);
|
||||
|
||||
const handleClick = useCallback(async () => {
|
||||
if (isRecording) {
|
||||
// When recording, clicking the mic button stops recording
|
||||
manualStopRequestedRef.current = true;
|
||||
try {
|
||||
const finalTranscript = await stopRecording();
|
||||
if (finalTranscript) {
|
||||
const combined = withPrefix(finalTranscript);
|
||||
onTranscription(combined);
|
||||
if (
|
||||
autoSend &&
|
||||
onAutoSend &&
|
||||
chatState === "input" &&
|
||||
combined.trim()
|
||||
) {
|
||||
onAutoSend(combined);
|
||||
}
|
||||
}
|
||||
messagePrefixRef.current = "";
|
||||
} finally {
|
||||
manualStopRequestedRef.current = false;
|
||||
}
|
||||
} else {
|
||||
try {
|
||||
// Snapshot existing text so transcription can append to it
|
||||
suppressTranscriptUpdatesRef.current = false;
|
||||
messagePrefixRef.current = currentMessage;
|
||||
onRecordingStart?.();
|
||||
await startRecording();
|
||||
// Arm auto-listen only after first manual mic start in this session.
|
||||
hasManualRecordStartRef.current = true;
|
||||
} catch (err) {
|
||||
console.error("Microphone access failed:", err);
|
||||
toast.error("Could not access microphone");
|
||||
}
|
||||
}
|
||||
}, [
|
||||
isRecording,
|
||||
startRecording,
|
||||
stopRecording,
|
||||
onRecordingStart,
|
||||
onTranscription,
|
||||
autoSend,
|
||||
onAutoSend,
|
||||
chatState,
|
||||
currentMessage,
|
||||
withPrefix,
|
||||
]);
|
||||
|
||||
// Auto-start listening shortly after TTS finishes (only if autoListen is enabled).
|
||||
// Small cooldown reduces playback bleed being re-captured by the microphone.
|
||||
// IMPORTANT: Only trigger auto-listen if TTS was actually playing audio,
|
||||
// not just loading. This prevents auto-listen from triggering when TTS fails.
|
||||
useEffect(() => {
|
||||
if (autoListenCooldownTimerRef.current) {
|
||||
clearTimeout(autoListenCooldownTimerRef.current);
|
||||
autoListenCooldownTimerRef.current = null;
|
||||
}
|
||||
|
||||
const stoppedManually =
|
||||
manualStopCount !== lastHandledManualStopCountRef.current;
|
||||
|
||||
// Only trigger auto-listen if TTS was actually playing (not just loading)
|
||||
if (
|
||||
wasTTSActuallyPlayingRef.current &&
|
||||
!isTTSPlaying &&
|
||||
!isTTSLoading &&
|
||||
!isAwaitingAutoPlaybackStart &&
|
||||
autoListen &&
|
||||
hasManualRecordStartRef.current &&
|
||||
!disabled &&
|
||||
!isRecording &&
|
||||
!stoppedManually
|
||||
) {
|
||||
autoListenCooldownTimerRef.current = setTimeout(() => {
|
||||
autoListenCooldownTimerRef.current = null;
|
||||
if (
|
||||
!autoListen ||
|
||||
disabled ||
|
||||
isRecording ||
|
||||
isTTSPlaying ||
|
||||
isTTSLoading ||
|
||||
isAwaitingAutoPlaybackStart
|
||||
) {
|
||||
return;
|
||||
}
|
||||
messagePrefixRef.current = currentMessageRef.current;
|
||||
startRecording().catch((err) => {
|
||||
console.error("Auto-start microphone failed:", err);
|
||||
toast.error("Could not auto-start microphone");
|
||||
});
|
||||
}, 400);
|
||||
}
|
||||
|
||||
if (stoppedManually) {
|
||||
lastHandledManualStopCountRef.current = manualStopCount;
|
||||
}
|
||||
|
||||
// Only track actual playback - not loading states
|
||||
// This ensures auto-listen only triggers after audio actually played
|
||||
if (isTTSPlaying) {
|
||||
wasTTSActuallyPlayingRef.current = true;
|
||||
} else if (!isTTSPlaying && !isTTSLoading && !isAwaitingAutoPlaybackStart) {
|
||||
// Reset when TTS is completely done
|
||||
wasTTSActuallyPlayingRef.current = false;
|
||||
}
|
||||
}, [
|
||||
isTTSPlaying,
|
||||
isTTSLoading,
|
||||
isAwaitingAutoPlaybackStart,
|
||||
autoListen,
|
||||
disabled,
|
||||
isRecording,
|
||||
startRecording,
|
||||
manualStopCount,
|
||||
]);
|
||||
|
||||
// New sessions must start with an explicit manual mic press.
|
||||
useEffect(() => {
|
||||
if (isNewSession) {
|
||||
hasManualRecordStartRef.current = false;
|
||||
suppressTranscriptUpdatesRef.current = false;
|
||||
}
|
||||
}, [isNewSession]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isRecording) {
|
||||
suppressTranscriptUpdatesRef.current = false;
|
||||
}
|
||||
}, [isRecording]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (autoListenCooldownTimerRef.current) {
|
||||
clearTimeout(autoListenCooldownTimerRef.current);
|
||||
autoListenCooldownTimerRef.current = null;
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (error) {
|
||||
console.error("Voice recorder error:", error);
|
||||
toast.error(error);
|
||||
}
|
||||
}, [error]);
|
||||
|
||||
// Icon: show loader when processing, otherwise mic
|
||||
const icon = isProcessing ? SimpleLoader : SvgMicrophone;
|
||||
|
||||
// Disable when processing or TTS is playing (don't want to pick up TTS audio)
|
||||
const isDisabled =
|
||||
disabled ||
|
||||
isProcessing ||
|
||||
isTTSPlaying ||
|
||||
isTTSLoading ||
|
||||
isAwaitingAutoPlaybackStart;
|
||||
|
||||
// Recording = darkened (primary), not recording = light (tertiary)
|
||||
const prominence = isRecording ? "primary" : "tertiary";
|
||||
|
||||
return (
|
||||
<Disabled disabled={isDisabled}>
|
||||
<Button
|
||||
icon={icon}
|
||||
onClick={handleClick}
|
||||
aria-label={isRecording ? "Stop recording" : "Start recording"}
|
||||
prominence={prominence}
|
||||
/>
|
||||
</Disabled>
|
||||
);
|
||||
}
|
||||
|
||||
export default MicrophoneButton;
|
||||
@@ -1,5 +1,4 @@
|
||||
import {
|
||||
LLMProviderName,
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
@@ -13,11 +12,6 @@ import { toast } from "@/hooks/useToast";
|
||||
import * as Yup from "yup";
|
||||
import isEqual from "lodash/isEqual";
|
||||
import { ScopedMutator } from "swr";
|
||||
import {
|
||||
track,
|
||||
AnalyticsEvent,
|
||||
LLMProviderConfiguredSource,
|
||||
} from "@/lib/analytics";
|
||||
|
||||
// Common class names for the Form component across all LLM provider forms
|
||||
export const LLM_FORM_CLASS_NAME = "flex flex-col gap-y-4 items-stretch mt-6";
|
||||
@@ -305,12 +299,5 @@ export const submitLLMProvider = async <T extends BaseLLMFormValues>({
|
||||
toast.success(successMsg);
|
||||
}
|
||||
|
||||
const knownProviders = new Set<string>(Object.values(LLMProviderName));
|
||||
track(AnalyticsEvent.CONFIGURED_LLM_PROVIDER, {
|
||||
provider: knownProviders.has(providerName) ? providerName : "custom",
|
||||
is_creation: !existingLlmProvider,
|
||||
source: LLMProviderConfiguredSource.ADMIN_PAGE,
|
||||
});
|
||||
|
||||
setSubmitting(false);
|
||||
};
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState, useMemo, ReactNode } from "react";
|
||||
import {
|
||||
track,
|
||||
AnalyticsEvent,
|
||||
LLMProviderConfiguredSource,
|
||||
} from "@/lib/analytics";
|
||||
import { Form, Formik, FormikProps } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import ProviderModal from "@/components/modals/ProviderModal";
|
||||
@@ -269,12 +264,6 @@ export function OnboardingFormWrapper<T extends Record<string, any>>({
|
||||
}
|
||||
}
|
||||
|
||||
track(AnalyticsEvent.CONFIGURED_LLM_PROVIDER, {
|
||||
provider: isCustomProvider ? "custom" : llmDescriptor?.name ?? "",
|
||||
is_creation: true,
|
||||
source: LLMProviderConfiguredSource.CHAT_ONBOARDING,
|
||||
});
|
||||
|
||||
// Update onboarding state
|
||||
onboardingActions?.updateData({
|
||||
llmProviders: [
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user