mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-06 16:15:46 +00:00
Compare commits
51 Commits
nikg/std-e
...
voice-mode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8fead4dfbf | ||
|
|
ac4b49a7f9 | ||
|
|
fc22232f14 | ||
|
|
9ddd44bf56 | ||
|
|
8587911cf6 | ||
|
|
d7300d50d7 | ||
|
|
cc950a2da2 | ||
|
|
8d6640159a | ||
|
|
bba77749c3 | ||
|
|
3e9a66c8ff | ||
|
|
548b9d9e0e | ||
|
|
0d3967baee | ||
|
|
6ed806eebb | ||
|
|
3b6a35b2c4 | ||
|
|
62e612f85f | ||
|
|
7eabfa125c | ||
|
|
ee18114739 | ||
|
|
b375b7f0ff | ||
|
|
f7630f5648 | ||
|
|
e0d91b9ea7 | ||
|
|
2c0a4a60a5 | ||
|
|
3a7d4dad56 | ||
|
|
c5c236d098 | ||
|
|
b18baff4d0 | ||
|
|
eb3e15c195 | ||
|
|
47d9a9e1ac | ||
|
|
aca466b35d | ||
|
|
c158ae2622 | ||
|
|
5176fd7386 | ||
|
|
698494626f | ||
|
|
92538084e9 | ||
|
|
2d996e05a4 | ||
|
|
b2956f795b | ||
|
|
b272085543 | ||
|
|
8193aa4fd0 | ||
|
|
52db41a00b | ||
|
|
f1cf3c4589 | ||
|
|
5322aeed90 | ||
|
|
5da8870fd2 | ||
|
|
57d3ab3b40 | ||
|
|
649c7fe8b9 | ||
|
|
e5e2bc6149 | ||
|
|
b148065e1d | ||
|
|
367808951c | ||
|
|
93cefe7ef0 | ||
|
|
8a326c4089 | ||
|
|
0c5410f429 | ||
|
|
0b05b9b235 | ||
|
|
59d8a988bd | ||
|
|
6d08cfb25a | ||
|
|
53a5ee2a6e |
1
.github/workflows/pr-integration-tests.yml
vendored
1
.github/workflows/pr-integration-tests.yml
vendored
@@ -335,7 +335,6 @@ jobs:
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
|
||||
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
|
||||
EOF
|
||||
fi
|
||||
|
||||
|
||||
108
.github/workflows/pr-playwright-tests.yml
vendored
108
.github/workflows/pr-playwright-tests.yml
vendored
@@ -268,10 +268,11 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
cache: "npm" # zizmor: ignore[cache-poisoning]
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install node dependencies
|
||||
@@ -279,6 +280,7 @@ jobs:
|
||||
run: npm ci
|
||||
|
||||
- name: Cache playwright cache
|
||||
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: ~/.cache/ms-playwright
|
||||
@@ -590,6 +592,108 @@ jobs:
|
||||
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
playwright-tests-lite:
|
||||
needs: [build-web-image, build-backend-image]
|
||||
name: Playwright Tests (lite)
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-playwright-tests-lite"
|
||||
- "extras=ecr-cache"
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm" # zizmor: ignore[cache-poisoning]
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Cache playwright cache
|
||||
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: ~/.cache/ms-playwright
|
||||
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-playwright-npm-
|
||||
|
||||
- name: Install playwright browsers
|
||||
working-directory: ./web
|
||||
run: npx playwright install --with-deps
|
||||
|
||||
- name: Create .env file for Docker Compose
|
||||
env:
|
||||
OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
|
||||
MOCK_LLM_RESPONSE=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
|
||||
ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
|
||||
EOF
|
||||
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Start Docker containers (lite)
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Run Playwright tests (lite)
|
||||
working-directory: ./web
|
||||
run: npx playwright test --project lite
|
||||
|
||||
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-test-results-lite-${{ github.run_id }}
|
||||
path: ./web/output/playwright/
|
||||
retention-days: 30
|
||||
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
env:
|
||||
WORKSPACE: ${{ github.workspace }}
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs > docker-compose.log
|
||||
mv docker-compose.log ${WORKSPACE}/docker-compose.log
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-logs-lite-${{ github.run_id }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
# Post a single combined visual regression comment after all matrix jobs finish
|
||||
visual-regression-comment:
|
||||
needs: [playwright-tests]
|
||||
@@ -686,7 +790,7 @@ jobs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
needs: [playwright-tests]
|
||||
needs: [playwright-tests, playwright-tests-lite]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- name: Check job status
|
||||
|
||||
58
.vscode/launch.json
vendored
58
.vscode/launch.json
vendored
@@ -40,19 +40,7 @@
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Celery (lightweight mode)",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery background",
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
},
|
||||
"stopAll": true
|
||||
},
|
||||
{
|
||||
"name": "Celery (standard mode)",
|
||||
"name": "Celery",
|
||||
"configurations": [
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
@@ -253,35 +241,6 @@
|
||||
},
|
||||
"consoleTitle": "Celery light Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery background",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "celery",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "INFO",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"args": [
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.background",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=20",
|
||||
"--prefetch-multiplier=4",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=background@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery background Console"
|
||||
},
|
||||
{
|
||||
"name": "Celery heavy",
|
||||
"type": "debugpy",
|
||||
@@ -526,21 +485,6 @@
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart OpenSearch Container",
|
||||
// Generic debugger type, required arg but has no bearing on bash.
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"${workspaceFolder}/backend/scripts/restart_opensearch_container.sh"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Eval CLI",
|
||||
"type": "debugpy",
|
||||
|
||||
31
AGENTS.md
31
AGENTS.md
@@ -86,37 +86,6 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
- Monitoring tasks (every 5 minutes)
|
||||
- Cleanup tasks (hourly)
|
||||
|
||||
#### Worker Deployment Modes
|
||||
|
||||
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
|
||||
|
||||
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
|
||||
|
||||
- Runs a single consolidated `background` worker that handles all background tasks:
|
||||
- Light worker tasks (Vespa operations, permissions sync, deletion)
|
||||
- Document processing (indexing pipeline)
|
||||
- Document fetching (connector data retrieval)
|
||||
- Pruning operations (from `heavy` worker)
|
||||
- Knowledge graph processing (from `kg_processing` worker)
|
||||
- Monitoring tasks (from `monitoring` worker)
|
||||
- User file processing (from `user_file_processing` worker)
|
||||
- Lower resource footprint (fewer worker processes)
|
||||
- Suitable for smaller deployments or development environments
|
||||
- Default concurrency: 20 threads (increased to handle combined workload)
|
||||
|
||||
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
|
||||
|
||||
- Runs separate specialized workers as documented above (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
|
||||
- Better isolation and scalability
|
||||
- Can scale individual workers independently based on workload
|
||||
- Suitable for production deployments with higher load
|
||||
|
||||
The deployment mode affects:
|
||||
|
||||
- **Backend**: Worker processes spawned by supervisord or dev scripts
|
||||
- **Helm**: Which Kubernetes deployments are created
|
||||
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
|
||||
|
||||
#### Key Features
|
||||
|
||||
- **Thread-based Workers**: All workers use thread pools (not processes) for stability
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
"""add_voice_provider_and_user_voice_prefs
|
||||
|
||||
Revision ID: 93a2e195e25c
|
||||
Revises: a3b8d9e2f1c4
|
||||
Create Date: 2026-02-23 15:16:39.507304
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "93a2e195e25c"
|
||||
down_revision = "a3b8d9e2f1c4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create voice_provider table
|
||||
op.create_table(
|
||||
"voice_provider",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("name", sa.String(), unique=True, nullable=False),
|
||||
sa.Column("provider_type", sa.String(), nullable=False),
|
||||
sa.Column("api_key", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("api_base", sa.String(), nullable=True),
|
||||
sa.Column("custom_config", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("stt_model", sa.String(), nullable=True),
|
||||
sa.Column("tts_model", sa.String(), nullable=True),
|
||||
sa.Column("default_voice", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"is_default_stt", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.Column(
|
||||
"is_default_tts", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Add partial unique indexes to enforce only one default STT/TTS provider
|
||||
op.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX ix_voice_provider_one_default_stt
|
||||
ON voice_provider (is_default_stt)
|
||||
WHERE is_default_stt = true
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX ix_voice_provider_one_default_tts
|
||||
ON voice_provider (is_default_tts)
|
||||
WHERE is_default_tts = true
|
||||
"""
|
||||
)
|
||||
|
||||
# Add voice preference columns to user table
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_auto_send",
|
||||
sa.Boolean(),
|
||||
default=False,
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_auto_playback",
|
||||
sa.Boolean(),
|
||||
default=False,
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_playback_speed",
|
||||
sa.Float(),
|
||||
default=1.0,
|
||||
nullable=False,
|
||||
server_default="1.0",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("preferred_voice", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove user voice preference columns
|
||||
op.drop_column("user", "preferred_voice")
|
||||
op.drop_column("user", "voice_playback_speed")
|
||||
op.drop_column("user", "voice_auto_playback")
|
||||
op.drop_column("user", "voice_auto_send")
|
||||
|
||||
op.execute("DROP INDEX IF EXISTS ix_voice_provider_one_default_tts")
|
||||
op.execute("DROP INDEX IF EXISTS ix_voice_provider_one_default_stt")
|
||||
|
||||
# Drop voice_provider table
|
||||
op.drop_table("voice_provider")
|
||||
@@ -1,15 +0,0 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.background import celery_app
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
)
|
||||
@@ -15,6 +15,7 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
from ee.onyx.server.user_group.models import UserGroupCreate
|
||||
from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -471,7 +472,9 @@ def _add_user_group__cc_pair_relationships__no_commit(
|
||||
|
||||
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
|
||||
db_user_group = UserGroup(
|
||||
name=user_group.name, time_last_modified_by_user=func.now()
|
||||
name=user_group.name,
|
||||
time_last_modified_by_user=func.now(),
|
||||
is_up_to_date=DISABLE_VECTOR_DB,
|
||||
)
|
||||
db_session.add(db_user_group)
|
||||
db_session.flush() # give the group an ID
|
||||
@@ -774,8 +777,7 @@ def update_user_group(
|
||||
cc_pair_ids=user_group_update.cc_pair_ids,
|
||||
)
|
||||
|
||||
# only needs to sync with Vespa if the cc_pairs have been updated
|
||||
if cc_pairs_updated:
|
||||
if cc_pairs_updated and not DISABLE_VECTOR_DB:
|
||||
db_user_group.is_up_to_date = False
|
||||
|
||||
removed_users = db_session.scalars(
|
||||
|
||||
@@ -4,7 +4,6 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.server.analytics.api import router as analytics_router
|
||||
from ee.onyx.server.auth_check import check_ee_router_auth
|
||||
from ee.onyx.server.billing.api import router as billing_router
|
||||
@@ -153,12 +152,9 @@ def get_application() -> FastAPI:
|
||||
# License management
|
||||
include_router_with_global_prefix_prepended(application, license_router)
|
||||
|
||||
# Unified billing API - available when license system is enabled
|
||||
# Works for both self-hosted and cloud deployments
|
||||
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
|
||||
# primary billing API and /tenants/* billing endpoints can be removed
|
||||
if LICENSE_ENFORCEMENT_ENABLED:
|
||||
include_router_with_global_prefix_prepended(application, billing_router)
|
||||
# Unified billing API - always registered in EE.
|
||||
# Each endpoint is protected by the `current_admin_user` dependency (admin auth).
|
||||
include_router_with_global_prefix_prepended(application, billing_router)
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Tenant management
|
||||
|
||||
@@ -223,6 +223,15 @@ def get_active_scim_token(
|
||||
token = dal.get_active_token()
|
||||
if not token:
|
||||
raise HTTPException(status_code=404, detail="No active SCIM token")
|
||||
|
||||
# Derive the IdP domain from the first synced user as a heuristic.
|
||||
idp_domain: str | None = None
|
||||
mappings, _total = dal.list_user_mappings(start_index=1, count=1)
|
||||
if mappings:
|
||||
user = dal.get_user(mappings[0].user_id)
|
||||
if user and "@" in user.email:
|
||||
idp_domain = user.email.rsplit("@", 1)[1]
|
||||
|
||||
return ScimTokenResponse(
|
||||
id=token.id,
|
||||
name=token.name,
|
||||
@@ -230,6 +239,7 @@ def get_active_scim_token(
|
||||
is_active=token.is_active,
|
||||
created_at=token.created_at,
|
||||
last_used_at=token.last_used_at,
|
||||
idp_domain=idp_domain,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -365,6 +365,7 @@ class ScimTokenResponse(BaseModel):
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
last_used_at: datetime | None = None
|
||||
idp_domain: str | None = None
|
||||
|
||||
|
||||
class ScimTokenCreatedResponse(ScimTokenResponse):
|
||||
|
||||
@@ -5,6 +5,8 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.user_group import add_users_to_user_group
|
||||
from ee.onyx.db.user_group import delete_user_group as db_delete_user_group
|
||||
from ee.onyx.db.user_group import fetch_user_group
|
||||
from ee.onyx.db.user_group import fetch_user_groups
|
||||
from ee.onyx.db.user_group import fetch_user_groups_for_user
|
||||
from ee.onyx.db.user_group import insert_user_group
|
||||
@@ -20,6 +22,7 @@ from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
@@ -153,3 +156,8 @@ def delete_user_group(
|
||||
prepare_user_group_for_deletion(db_session, user_group_id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
user_group = fetch_user_group(db_session, user_group_id)
|
||||
if user_group:
|
||||
db_delete_user_group(db_session, user_group)
|
||||
|
||||
@@ -28,6 +28,7 @@ from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi import WebSocket
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi_users import BaseUserManager
|
||||
@@ -1599,6 +1600,91 @@ async def current_admin_user(user: User = Depends(current_user)) -> User:
|
||||
return user
|
||||
|
||||
|
||||
async def _get_user_from_token_data(token_data: dict) -> User | None:
|
||||
"""Shared logic: token data dict → User object.
|
||||
|
||||
Args:
|
||||
token_data: Decoded token data containing 'sub' (user ID).
|
||||
|
||||
Returns:
|
||||
User object if found and active, None otherwise.
|
||||
"""
|
||||
user_id = token_data.get("sub")
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
async with get_async_session_context_manager() as async_db_session:
|
||||
user = await async_db_session.get(User, user_uuid)
|
||||
if user is None or not user.is_active:
|
||||
return None
|
||||
return user
|
||||
|
||||
|
||||
async def current_user_from_websocket(
|
||||
_websocket: WebSocket,
|
||||
token: str = Query(..., description="WebSocket authentication token"),
|
||||
) -> User:
|
||||
"""
|
||||
WebSocket authentication dependency using query parameter.
|
||||
|
||||
Validates the WS token from query param and returns the User.
|
||||
Raises BasicAuthenticationError if authentication fails.
|
||||
|
||||
The token must be obtained from POST /voice/ws-token before connecting.
|
||||
Tokens are single-use and expire after 60 seconds.
|
||||
|
||||
Usage:
|
||||
1. POST /voice/ws-token -> {"token": "xxx"}
|
||||
2. Connect to ws://host/path?token=xxx
|
||||
|
||||
This applies the same auth checks as current_user() for HTTP endpoints.
|
||||
"""
|
||||
from onyx.redis.redis_pool import retrieve_ws_token_data
|
||||
|
||||
# Validate WS token in Redis (single-use, deleted after retrieval)
|
||||
try:
|
||||
token_data = await retrieve_ws_token_data(token)
|
||||
if token_data is None:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. Invalid or expired authentication token."
|
||||
)
|
||||
except BasicAuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"WS auth: error during token validation: {e}")
|
||||
raise BasicAuthenticationError(
|
||||
detail=f"Authentication verification failed: {str(e)}"
|
||||
) from e
|
||||
|
||||
# Get user from token data
|
||||
user = await _get_user_from_token_data(token_data)
|
||||
if user is None:
|
||||
logger.warning(f"WS auth: user not found for id={token_data.get('sub')}")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User not found or inactive."
|
||||
)
|
||||
logger.info(f"WS auth: user found: {user.email}")
|
||||
|
||||
# Apply same checks as HTTP auth (verification, OIDC expiry, role)
|
||||
user = await double_check_user(user)
|
||||
logger.info(f"WS auth: user verified: {user.email}, role={user.role}")
|
||||
|
||||
# Block LIMITED users (same as current_user)
|
||||
if user.role == UserRole.LIMITED:
|
||||
logger.warning(f"WS auth: user {user.email} has LIMITED role")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
|
||||
)
|
||||
|
||||
logger.info(f"WS auth: authentication successful for {user.email}")
|
||||
return user
|
||||
|
||||
|
||||
def get_default_admin_user_emails_() -> list[str]:
|
||||
# No default seeding available for Onyx MIT
|
||||
return []
|
||||
|
||||
@@ -1,142 +0,0 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.background")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
|
||||
|
||||
logger.info("worker_init signal received for consolidated background worker.")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
|
||||
|
||||
# Initialize Vespa httpx pool (needed for light worker tasks)
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
|
||||
ssl_cert=VESPA_CLOUD_CERT_PATH,
|
||||
ssl_key=VESPA_CLOUD_KEY_PATH,
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None: # noqa: ARG001
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
base_bootsteps = app_base.get_bootsteps()
|
||||
for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
# Original background worker tasks
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
# Light worker tasks
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.opensearch_migration",
|
||||
# Docprocessing worker tasks
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
# Docfetching worker tasks
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
)
|
||||
@@ -39,9 +39,13 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
|
||||
|
||||
class SlimConnectorExtractionResult(BaseModel):
|
||||
"""Result of extracting document IDs and hierarchy nodes from a connector."""
|
||||
"""Result of extracting document IDs and hierarchy nodes from a connector.
|
||||
|
||||
doc_ids: set[str]
|
||||
raw_id_to_parent maps document ID → parent_hierarchy_raw_node_id (or None).
|
||||
Use raw_id_to_parent.keys() wherever the old set of IDs was needed.
|
||||
"""
|
||||
|
||||
raw_id_to_parent: dict[str, str | None]
|
||||
hierarchy_nodes: list[HierarchyNode]
|
||||
|
||||
|
||||
@@ -93,30 +97,37 @@ def _get_failure_id(failure: ConnectorFailure) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
class BatchResult(BaseModel):
|
||||
raw_id_to_parent: dict[str, str | None]
|
||||
hierarchy_nodes: list[HierarchyNode]
|
||||
|
||||
|
||||
def _extract_from_batch(
|
||||
doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure],
|
||||
) -> tuple[set[str], list[HierarchyNode]]:
|
||||
"""Separate a batch into document IDs and hierarchy nodes.
|
||||
) -> BatchResult:
|
||||
"""Separate a batch into document IDs (with parent mapping) and hierarchy nodes.
|
||||
|
||||
ConnectorFailure items have their failed document/entity IDs added to the
|
||||
ID set so that failed-to-retrieve documents are not accidentally pruned.
|
||||
ID dict so that failed-to-retrieve documents are not accidentally pruned.
|
||||
"""
|
||||
ids: set[str] = set()
|
||||
ids: dict[str, str | None] = {}
|
||||
hierarchy_nodes: list[HierarchyNode] = []
|
||||
for item in doc_list:
|
||||
if isinstance(item, HierarchyNode):
|
||||
hierarchy_nodes.append(item)
|
||||
ids.add(item.raw_node_id)
|
||||
if item.raw_node_id not in ids:
|
||||
ids[item.raw_node_id] = None
|
||||
elif isinstance(item, ConnectorFailure):
|
||||
failed_id = _get_failure_id(item)
|
||||
if failed_id:
|
||||
ids.add(failed_id)
|
||||
ids[failed_id] = None
|
||||
logger.warning(
|
||||
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
|
||||
)
|
||||
else:
|
||||
ids.add(item.id)
|
||||
return ids, hierarchy_nodes
|
||||
parent_raw = getattr(item, "parent_hierarchy_raw_node_id", None)
|
||||
ids[item.id] = parent_raw
|
||||
return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes)
|
||||
|
||||
|
||||
def extract_ids_from_runnable_connector(
|
||||
@@ -132,7 +143,7 @@ def extract_ids_from_runnable_connector(
|
||||
|
||||
Optionally, a callback can be passed to handle the length of each document batch.
|
||||
"""
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
all_raw_id_to_parent: dict[str, str | None] = {}
|
||||
all_hierarchy_nodes: list[HierarchyNode] = []
|
||||
|
||||
# Sequence (covariant) lets all the specific list[...] iterator types unify here
|
||||
@@ -177,15 +188,20 @@ def extract_ids_from_runnable_connector(
|
||||
"extract_ids_from_runnable_connector: Stop signal detected"
|
||||
)
|
||||
|
||||
batch_ids, batch_nodes = _extract_from_batch(doc_list)
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(batch_ids))
|
||||
batch_result = _extract_from_batch(doc_list)
|
||||
batch_ids = batch_result.raw_id_to_parent
|
||||
batch_nodes = batch_result.hierarchy_nodes
|
||||
doc_batch_processing_func(batch_ids)
|
||||
for k, v in batch_ids.items():
|
||||
if v is not None or k not in all_raw_id_to_parent:
|
||||
all_raw_id_to_parent[k] = v
|
||||
all_hierarchy_nodes.extend(batch_nodes)
|
||||
|
||||
if callback:
|
||||
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
|
||||
|
||||
return SlimConnectorExtractionResult(
|
||||
doc_ids=all_connector_doc_ids,
|
||||
raw_id_to_parent=all_raw_id_to_parent,
|
||||
hierarchy_nodes=all_hierarchy_nodes,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
# Increased from 1 to 4 to handle fast light worker tasks more efficiently
|
||||
# This allows the worker to prefetch multiple tasks per thread
|
||||
worker_prefetch_multiplier = 4
|
||||
@@ -29,6 +29,7 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -47,6 +48,8 @@ from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
|
||||
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
@@ -57,6 +60,8 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import ensure_source_node_exists
|
||||
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
|
||||
from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
|
||||
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
@@ -113,6 +118,38 @@ class PruneCallback(IndexingCallbackBase):
|
||||
super().progress(tag, amount)
|
||||
|
||||
|
||||
def _resolve_and_update_document_parents(
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
source: DocumentSource,
|
||||
raw_id_to_parent: dict[str, str | None],
|
||||
) -> None:
|
||||
"""Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id for
|
||||
each document and bulk-update the DB. Mirrors the resolution logic in
|
||||
run_docfetching.py."""
|
||||
source_node_id = get_source_node_id_from_cache(redis_client, db_session, source)
|
||||
|
||||
resolved: dict[str, int | None] = {}
|
||||
for doc_id, raw_parent_id in raw_id_to_parent.items():
|
||||
if raw_parent_id is None:
|
||||
continue
|
||||
node_id, found = get_node_id_from_raw_id(redis_client, source, raw_parent_id)
|
||||
resolved[doc_id] = node_id if found else source_node_id
|
||||
|
||||
if not resolved:
|
||||
return
|
||||
|
||||
update_document_parent_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
doc_parent_map=resolved,
|
||||
commit=True,
|
||||
)
|
||||
task_logger.info(
|
||||
f"Pruning: resolved and updated parent hierarchy for "
|
||||
f"{len(resolved)} documents (source={source.value})"
|
||||
)
|
||||
|
||||
|
||||
"""Jobs / utils for kicking off pruning tasks."""
|
||||
|
||||
|
||||
@@ -535,22 +572,22 @@ def connector_pruning_generator_task(
|
||||
extraction_result = extract_ids_from_runnable_connector(
|
||||
runnable_connector, callback
|
||||
)
|
||||
all_connector_doc_ids = extraction_result.doc_ids
|
||||
all_connector_doc_ids = extraction_result.raw_id_to_parent
|
||||
|
||||
# Process hierarchy nodes (same as docfetching):
|
||||
# upsert to Postgres and cache in Redis
|
||||
source = cc_pair.connector.source
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if extraction_result.hierarchy_nodes:
|
||||
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
ensure_source_node_exists(
|
||||
redis_client, db_session, cc_pair.connector.source
|
||||
)
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=extraction_result.hierarchy_nodes,
|
||||
source=cc_pair.connector.source,
|
||||
source=source,
|
||||
commit=True,
|
||||
is_connector_public=is_connector_public,
|
||||
)
|
||||
@@ -561,7 +598,7 @@ def connector_pruning_generator_task(
|
||||
]
|
||||
cache_hierarchy_nodes_batch(
|
||||
redis_client=redis_client,
|
||||
source=cc_pair.connector.source,
|
||||
source=source,
|
||||
entries=cache_entries,
|
||||
)
|
||||
|
||||
@@ -570,6 +607,26 @@ def connector_pruning_generator_task(
|
||||
f"hierarchy nodes for cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
ensure_source_node_exists(redis_client, db_session, source)
|
||||
# Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id
|
||||
# and bulk-update documents, mirroring the docfetching resolution
|
||||
_resolve_and_update_document_parents(
|
||||
db_session=db_session,
|
||||
redis_client=redis_client,
|
||||
source=source,
|
||||
raw_id_to_parent=all_connector_doc_ids,
|
||||
)
|
||||
|
||||
# Link hierarchy nodes to documents for sources where pages can be
|
||||
# both hierarchy nodes AND documents (e.g. Notion, Confluence)
|
||||
all_doc_id_list = list(all_connector_doc_ids.keys())
|
||||
link_hierarchy_nodes_to_documents(
|
||||
db_session=db_session,
|
||||
document_ids=all_doc_id_list,
|
||||
source=source,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# a list of docs in our local index
|
||||
all_indexed_document_ids = {
|
||||
doc.id
|
||||
@@ -581,7 +638,9 @@ def connector_pruning_generator_task(
|
||||
}
|
||||
|
||||
# generate list of docs to remove (no longer in the source)
|
||||
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
|
||||
doc_ids_to_remove = list(
|
||||
all_indexed_document_ids - all_connector_doc_ids.keys()
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
"Pruning set collected: "
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
from celery import Celery
|
||||
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
app: Celery = fetch_versioned_implementation(
|
||||
"onyx.background.celery.apps.background",
|
||||
"celery_app",
|
||||
)
|
||||
@@ -495,14 +495,7 @@ CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
|
||||
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
|
||||
)
|
||||
|
||||
# Consolidated background worker (light, docprocessing, docfetching, heavy, monitoring, user_file_processing)
|
||||
# separate workers' defaults: light=24, docprocessing=6, docfetching=1, heavy=4, kg=2, monitoring=1, user_file=2
|
||||
# Total would be 40, but we use a more conservative default of 20 for the consolidated worker
|
||||
CELERY_WORKER_BACKGROUND_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_BACKGROUND_CONCURRENCY") or 20
|
||||
)
|
||||
|
||||
# Individual worker concurrency settings (used when USE_LIGHTWEIGHT_BACKGROUND_WORKER is False or on Kuberenetes deployments)
|
||||
# Individual worker concurrency settings
|
||||
CELERY_WORKER_HEAVY_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
|
||||
)
|
||||
|
||||
@@ -84,7 +84,6 @@ POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
|
||||
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
|
||||
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
|
||||
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (
|
||||
|
||||
@@ -943,6 +943,9 @@ class ConfluenceConnector(
|
||||
if include_permissions
|
||||
else None
|
||||
),
|
||||
parent_hierarchy_raw_node_id=self._get_parent_hierarchy_raw_id(
|
||||
page
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -992,6 +995,7 @@ class ConfluenceConnector(
|
||||
if include_permissions
|
||||
else None
|
||||
),
|
||||
parent_hierarchy_raw_node_id=page_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -781,4 +781,5 @@ def build_slim_document(
|
||||
return SlimDocument(
|
||||
id=onyx_document_id_from_drive_file(file),
|
||||
external_access=external_access,
|
||||
parent_hierarchy_raw_node_id=(file.get("parents") or [None])[0],
|
||||
)
|
||||
|
||||
@@ -902,6 +902,11 @@ class JiraConnector(
|
||||
external_access=self._get_project_permissions(
|
||||
project_key, add_prefix=False
|
||||
),
|
||||
parent_hierarchy_raw_node_id=(
|
||||
self._get_parent_hierarchy_raw_node_id(issue, project_key)
|
||||
if project_key
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
current_offset += 1
|
||||
|
||||
@@ -385,6 +385,7 @@ class IndexingDocument(Document):
|
||||
class SlimDocument(BaseModel):
|
||||
id: str
|
||||
external_access: ExternalAccess | None = None
|
||||
parent_hierarchy_raw_node_id: str | None = None
|
||||
|
||||
|
||||
class HierarchyNode(BaseModel):
|
||||
|
||||
@@ -772,6 +772,7 @@ def _convert_driveitem_to_slim_document(
|
||||
drive_name: str,
|
||||
ctx: ClientContext,
|
||||
graph_client: GraphClient,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
) -> SlimDocument:
|
||||
if driveitem.id is None:
|
||||
raise ValueError("DriveItem ID is required")
|
||||
@@ -787,11 +788,15 @@ def _convert_driveitem_to_slim_document(
|
||||
return SlimDocument(
|
||||
id=driveitem.id,
|
||||
external_access=external_access,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
|
||||
)
|
||||
|
||||
|
||||
def _convert_sitepage_to_slim_document(
|
||||
site_page: dict[str, Any], ctx: ClientContext | None, graph_client: GraphClient
|
||||
site_page: dict[str, Any],
|
||||
ctx: ClientContext | None,
|
||||
graph_client: GraphClient,
|
||||
parent_hierarchy_raw_node_id: str | None = None,
|
||||
) -> SlimDocument:
|
||||
"""Convert a SharePoint site page to a SlimDocument object."""
|
||||
if site_page.get("id") is None:
|
||||
@@ -808,6 +813,7 @@ def _convert_sitepage_to_slim_document(
|
||||
return SlimDocument(
|
||||
id=id,
|
||||
external_access=external_access,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -1594,12 +1600,22 @@ class SharepointConnector(
|
||||
)
|
||||
)
|
||||
|
||||
parent_hierarchy_url: str | None = None
|
||||
if drive_web_url:
|
||||
parent_hierarchy_url = self._get_parent_hierarchy_url(
|
||||
site_url, drive_web_url, drive_name, driveitem
|
||||
)
|
||||
|
||||
try:
|
||||
logger.debug(f"Processing: {driveitem.web_url}")
|
||||
ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
doc_batch.append(
|
||||
_convert_driveitem_to_slim_document(
|
||||
driveitem, drive_name, ctx, self.graph_client
|
||||
driveitem,
|
||||
drive_name,
|
||||
ctx,
|
||||
self.graph_client,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_url,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -1619,7 +1635,10 @@ class SharepointConnector(
|
||||
ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
doc_batch.append(
|
||||
_convert_sitepage_to_slim_document(
|
||||
site_page, ctx, self.graph_client
|
||||
site_page,
|
||||
ctx,
|
||||
self.graph_client,
|
||||
parent_hierarchy_raw_node_id=site_descriptor.url,
|
||||
)
|
||||
)
|
||||
if len(doc_batch) >= SLIM_BATCH_SIZE:
|
||||
|
||||
@@ -565,6 +565,7 @@ def _get_all_doc_ids(
|
||||
channel_id=channel_id, thread_ts=message["ts"]
|
||||
),
|
||||
external_access=external_access,
|
||||
parent_hierarchy_raw_node_id=channel_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.enums import AccessType
|
||||
@@ -246,6 +247,7 @@ def insert_document_set(
|
||||
description=document_set_creation_request.description,
|
||||
user_id=user_id,
|
||||
is_public=document_set_creation_request.is_public,
|
||||
is_up_to_date=DISABLE_VECTOR_DB,
|
||||
time_last_modified_by_user=func.now(),
|
||||
)
|
||||
db_session.add(new_document_set_row)
|
||||
@@ -336,7 +338,8 @@ def update_document_set(
|
||||
)
|
||||
|
||||
document_set_row.description = document_set_update_request.description
|
||||
document_set_row.is_up_to_date = False
|
||||
if not DISABLE_VECTOR_DB:
|
||||
document_set_row.is_up_to_date = False
|
||||
document_set_row.is_public = document_set_update_request.is_public
|
||||
document_set_row.time_last_modified_by_user = func.now()
|
||||
versioned_private_doc_set_fn = fetch_versioned_implementation(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""CRUD operations for HierarchyNode."""
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -525,6 +527,53 @@ def get_document_parent_hierarchy_node_ids(
|
||||
return {doc_id: parent_id for doc_id, parent_id in results}
|
||||
|
||||
|
||||
def update_document_parent_hierarchy_nodes(
|
||||
db_session: Session,
|
||||
doc_parent_map: dict[str, int | None],
|
||||
commit: bool = True,
|
||||
) -> int:
|
||||
"""Bulk-update Document.parent_hierarchy_node_id for multiple documents.
|
||||
|
||||
Only updates rows whose current value differs from the desired value to
|
||||
avoid unnecessary writes.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
doc_parent_map: Mapping of document_id → desired parent_hierarchy_node_id
|
||||
commit: Whether to commit the transaction
|
||||
|
||||
Returns:
|
||||
Number of documents actually updated
|
||||
"""
|
||||
if not doc_parent_map:
|
||||
return 0
|
||||
|
||||
doc_ids = list(doc_parent_map.keys())
|
||||
existing = get_document_parent_hierarchy_node_ids(db_session, doc_ids)
|
||||
|
||||
by_parent: dict[int | None, list[str]] = defaultdict(list)
|
||||
for doc_id, desired_parent_id in doc_parent_map.items():
|
||||
current = existing.get(doc_id)
|
||||
if current == desired_parent_id or doc_id not in existing:
|
||||
continue
|
||||
by_parent[desired_parent_id].append(doc_id)
|
||||
|
||||
updated = 0
|
||||
for desired_parent_id, ids in by_parent.items():
|
||||
db_session.query(Document).filter(Document.id.in_(ids)).update(
|
||||
{Document.parent_hierarchy_node_id: desired_parent_id},
|
||||
synchronize_session=False,
|
||||
)
|
||||
updated += len(ids)
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
elif updated:
|
||||
db_session.flush()
|
||||
|
||||
return updated
|
||||
|
||||
|
||||
def update_hierarchy_node_permissions(
|
||||
db_session: Session,
|
||||
raw_node_id: str,
|
||||
|
||||
@@ -284,6 +284,12 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
# organized in typical structured fashion
|
||||
# formatted as `displayName__provider__modelName`
|
||||
|
||||
# Voice preferences
|
||||
voice_auto_send: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
voice_auto_playback: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
voice_playback_speed: Mapped[float] = mapped_column(Float, default=1.0)
|
||||
preferred_voice: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# relationships
|
||||
credentials: Mapped[list["Credential"]] = relationship(
|
||||
"Credential", back_populates="user"
|
||||
@@ -2964,6 +2970,63 @@ class ImageGenerationConfig(Base):
|
||||
)
|
||||
|
||||
|
||||
class VoiceProvider(Base):
|
||||
"""Configuration for voice services (STT and TTS)."""
|
||||
|
||||
__tablename__ = "voice_provider"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
provider_type: Mapped[str] = mapped_column(
|
||||
String
|
||||
) # "openai", "azure", "elevenlabs"
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
api_base: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
custom_config: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
|
||||
# Model/voice configuration
|
||||
stt_model: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
) # e.g., "whisper-1"
|
||||
tts_model: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
) # e.g., "tts-1", "tts-1-hd"
|
||||
default_voice: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
) # e.g., "alloy", "echo"
|
||||
|
||||
# STT and TTS can use different providers - only one provider per type
|
||||
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
# Enforce only one default STT provider and one default TTS provider at DB level
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_voice_provider_one_default_stt",
|
||||
"is_default_stt",
|
||||
unique=True,
|
||||
postgresql_where=(is_default_stt == True), # noqa: E712
|
||||
),
|
||||
Index(
|
||||
"ix_voice_provider_one_default_tts",
|
||||
"is_default_tts",
|
||||
unique=True,
|
||||
postgresql_where=(is_default_tts == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(Base):
|
||||
__tablename__ = "embedding_provider"
|
||||
|
||||
|
||||
@@ -129,7 +129,7 @@ def get_current_search_settings(db_session: Session) -> SearchSettings:
|
||||
latest_settings = result.scalars().first()
|
||||
|
||||
if not latest_settings:
|
||||
raise RuntimeError("No search settings specified, DB is not in a valid state")
|
||||
raise RuntimeError("No search settings specified; DB is not in a valid state.")
|
||||
return latest_settings
|
||||
|
||||
|
||||
|
||||
224
backend/onyx/db/voice.py
Normal file
224
backend/onyx/db/voice.py
Normal file
@@ -0,0 +1,224 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import VoiceProvider
|
||||
|
||||
# Sentinel value to distinguish "not provided" from "explicitly set to None"
|
||||
_UNSET: Any = object()
|
||||
|
||||
|
||||
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
|
||||
"""Fetch all voice providers."""
|
||||
return list(
|
||||
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_id(
|
||||
db_session: Session, provider_id: int
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by ID."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default STT provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default TTS provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
|
||||
)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_type(
|
||||
db_session: Session, provider_type: str
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by type."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
|
||||
)
|
||||
|
||||
|
||||
def upsert_voice_provider(
|
||||
*,
|
||||
db_session: Session,
|
||||
provider_id: int | None,
|
||||
name: str,
|
||||
provider_type: str,
|
||||
api_key: str | None,
|
||||
api_key_changed: bool,
|
||||
api_base: str | None = None,
|
||||
custom_config: dict[str, Any] | None = None,
|
||||
stt_model: str | None = None,
|
||||
tts_model: str | None = None,
|
||||
default_voice: str | None = None,
|
||||
activate_stt: bool = False,
|
||||
activate_tts: bool = False,
|
||||
) -> VoiceProvider:
|
||||
"""Create or update a voice provider."""
|
||||
provider: VoiceProvider | None = None
|
||||
|
||||
if provider_id is not None:
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise ValueError(f"No voice provider with id {provider_id} exists.")
|
||||
else:
|
||||
provider = VoiceProvider()
|
||||
db_session.add(provider)
|
||||
|
||||
# Apply updates
|
||||
provider.name = name
|
||||
provider.provider_type = provider_type
|
||||
provider.api_base = api_base
|
||||
provider.custom_config = custom_config
|
||||
provider.stt_model = stt_model
|
||||
provider.tts_model = tts_model
|
||||
provider.default_voice = default_voice
|
||||
|
||||
# Only update API key if explicitly changed or if provider has no key
|
||||
if api_key_changed or provider.api_key is None:
|
||||
provider.api_key = api_key # type: ignore[assignment]
|
||||
|
||||
db_session.flush()
|
||||
|
||||
if activate_stt:
|
||||
set_default_stt_provider(db_session=db_session, provider_id=provider.id)
|
||||
if activate_tts:
|
||||
set_default_tts_provider(db_session=db_session, provider_id=provider.id)
|
||||
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
|
||||
"""Delete a voice provider by ID."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider:
|
||||
db_session.delete(provider)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def set_default_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
|
||||
"""Set a voice provider as the default STT provider."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise ValueError(f"No voice provider with id {provider_id} exists.")
|
||||
|
||||
# Deactivate all other STT providers
|
||||
db_session.execute(
|
||||
update(VoiceProvider)
|
||||
.where(
|
||||
VoiceProvider.is_default_stt.is_(True),
|
||||
VoiceProvider.id != provider_id,
|
||||
)
|
||||
.values(is_default_stt=False)
|
||||
)
|
||||
|
||||
# Activate this provider
|
||||
provider.is_default_stt = True
|
||||
|
||||
db_session.flush()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
def set_default_tts_provider(
|
||||
*, db_session: Session, provider_id: int, tts_model: str | None = None
|
||||
) -> VoiceProvider:
|
||||
"""Set a voice provider as the default TTS provider."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise ValueError(f"No voice provider with id {provider_id} exists.")
|
||||
|
||||
# Deactivate all other TTS providers
|
||||
db_session.execute(
|
||||
update(VoiceProvider)
|
||||
.where(
|
||||
VoiceProvider.is_default_tts.is_(True),
|
||||
VoiceProvider.id != provider_id,
|
||||
)
|
||||
.values(is_default_tts=False)
|
||||
)
|
||||
|
||||
# Activate this provider
|
||||
provider.is_default_tts = True
|
||||
|
||||
# Update the TTS model if specified
|
||||
if tts_model is not None:
|
||||
provider.tts_model = tts_model
|
||||
|
||||
db_session.flush()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
def deactivate_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
|
||||
"""Remove the default STT status from a voice provider."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise ValueError(f"No voice provider with id {provider_id} exists.")
|
||||
|
||||
provider.is_default_stt = False
|
||||
|
||||
db_session.flush()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
def deactivate_tts_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
|
||||
"""Remove the default TTS status from a voice provider."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise ValueError(f"No voice provider with id {provider_id} exists.")
|
||||
|
||||
provider.is_default_tts = False
|
||||
|
||||
db_session.flush()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
# User voice preferences
|
||||
|
||||
|
||||
def update_user_voice_settings(
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
auto_send: bool | None = None,
|
||||
auto_playback: bool | None = None,
|
||||
playback_speed: float | None = None,
|
||||
preferred_voice: str | None = _UNSET,
|
||||
) -> None:
|
||||
"""Update user's voice settings.
|
||||
|
||||
For most fields, None means "don't update this field".
|
||||
For preferred_voice, use None to clear the preference (reset to default),
|
||||
or omit the parameter to leave it unchanged.
|
||||
"""
|
||||
values: dict[str, Any] = {}
|
||||
|
||||
if auto_send is not None:
|
||||
values["voice_auto_send"] = auto_send
|
||||
if auto_playback is not None:
|
||||
values["voice_auto_playback"] = auto_playback
|
||||
if playback_speed is not None:
|
||||
values["voice_playback_speed"] = max(0.5, min(2.0, playback_speed))
|
||||
# preferred_voice uses sentinel: _UNSET means "don't update", None means "clear"
|
||||
if preferred_voice is not _UNSET:
|
||||
values["preferred_voice"] = preferred_voice
|
||||
|
||||
if values:
|
||||
db_session.execute(update(User).where(User.id == user_id).values(**values)) # type: ignore[arg-type]
|
||||
db_session.flush()
|
||||
@@ -32,9 +32,6 @@ def get_multipass_config(search_settings: SearchSettings) -> MultipassConfig:
|
||||
Determines whether to enable multipass and large chunks by examining
|
||||
the current search settings and the embedder configuration.
|
||||
"""
|
||||
if not search_settings:
|
||||
return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)
|
||||
|
||||
multipass = should_use_multipass(search_settings)
|
||||
enable_large_chunks = SearchSettings.can_use_large_chunks(
|
||||
multipass, search_settings.model_name, search_settings.provider_type
|
||||
|
||||
@@ -26,11 +26,10 @@ def get_default_document_index(
|
||||
To be used for retrieval only. Indexing should be done through both indices
|
||||
until Vespa is deprecated.
|
||||
|
||||
Pre-existing docstring for this function, although secondary indices are not
|
||||
currently supported:
|
||||
Primary index is the index that is used for querying/updating etc. Secondary
|
||||
index is for when both the currently used index and the upcoming index both
|
||||
need to be updated, updates are applied to both indices.
|
||||
need to be updated. Updates are applied to both indices.
|
||||
WARNING: In that case, get_all_document_indices should be used.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
return DisabledDocumentIndex(
|
||||
@@ -51,11 +50,26 @@ def get_default_document_index(
|
||||
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
|
||||
if opensearch_retrieval_enabled:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
secondary_indexing_setting = (
|
||||
IndexingSetting.from_db_model(secondary_search_settings)
|
||||
if secondary_search_settings
|
||||
else None
|
||||
)
|
||||
return OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=secondary_index_name,
|
||||
secondary_embedding_dim=(
|
||||
secondary_indexing_setting.final_embedding_dim
|
||||
if secondary_indexing_setting
|
||||
else None
|
||||
),
|
||||
secondary_embedding_precision=(
|
||||
secondary_indexing_setting.embedding_precision
|
||||
if secondary_indexing_setting
|
||||
else None
|
||||
),
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
|
||||
multitenant=MULTI_TENANT,
|
||||
@@ -86,8 +100,7 @@ def get_all_document_indices(
|
||||
Used for indexing only. Until Vespa is deprecated we will index into both
|
||||
document indices. Retrieval is done through only one index however.
|
||||
|
||||
Large chunks and secondary indices are not currently supported so we
|
||||
hardcode appropriate values.
|
||||
Large chunks are not currently supported so we hardcode appropriate values.
|
||||
|
||||
NOTE: Make sure the Vespa index object is returned first. In the rare event
|
||||
that there is some conflict between indexing and the migration task, it is
|
||||
@@ -123,13 +136,36 @@ def get_all_document_indices(
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
secondary_indexing_setting = (
|
||||
IndexingSetting.from_db_model(secondary_search_settings)
|
||||
if secondary_search_settings
|
||||
else None
|
||||
)
|
||||
opensearch_document_index = OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
embedding_dim=indexing_setting.final_embedding_dim,
|
||||
embedding_precision=indexing_setting.embedding_precision,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
secondary_index_name=(
|
||||
secondary_search_settings.index_name
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
secondary_embedding_dim=(
|
||||
secondary_indexing_setting.final_embedding_dim
|
||||
if secondary_indexing_setting
|
||||
else None
|
||||
),
|
||||
secondary_embedding_precision=(
|
||||
secondary_indexing_setting.embedding_precision
|
||||
if secondary_indexing_setting
|
||||
else None
|
||||
),
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=(
|
||||
secondary_search_settings.large_chunks_enabled
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
|
||||
@@ -271,6 +271,9 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_name: str | None,
|
||||
secondary_embedding_dim: int | None,
|
||||
secondary_embedding_precision: EmbeddingPrecision | None,
|
||||
# NOTE: We do not support large chunks right now.
|
||||
large_chunks_enabled: bool, # noqa: ARG002
|
||||
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
|
||||
multitenant: bool = False,
|
||||
@@ -286,12 +289,25 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
f"Expected {MULTI_TENANT}, got {multitenant}."
|
||||
)
|
||||
tenant_id = get_current_tenant_id()
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=multitenant)
|
||||
self._real_index = OpenSearchDocumentIndex(
|
||||
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
|
||||
tenant_state=tenant_state,
|
||||
index_name=index_name,
|
||||
embedding_dim=embedding_dim,
|
||||
embedding_precision=embedding_precision,
|
||||
)
|
||||
self._secondary_real_index: OpenSearchDocumentIndex | None = None
|
||||
if self.secondary_index_name:
|
||||
if secondary_embedding_dim is None or secondary_embedding_precision is None:
|
||||
raise ValueError(
|
||||
"Bug: Secondary index embedding dimension and precision are not set."
|
||||
)
|
||||
self._secondary_real_index = OpenSearchDocumentIndex(
|
||||
tenant_state=tenant_state,
|
||||
index_name=self.secondary_index_name,
|
||||
embedding_dim=secondary_embedding_dim,
|
||||
embedding_precision=secondary_embedding_precision,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def register_multitenant_indices(
|
||||
@@ -307,19 +323,38 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
self,
|
||||
primary_embedding_dim: int,
|
||||
primary_embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_embedding_dim: int | None, # noqa: ARG002
|
||||
secondary_index_embedding_precision: EmbeddingPrecision | None, # noqa: ARG002
|
||||
secondary_index_embedding_dim: int | None,
|
||||
secondary_index_embedding_precision: EmbeddingPrecision | None,
|
||||
) -> None:
|
||||
# Only handle primary index for now, ignore secondary.
|
||||
return self._real_index.verify_and_create_index_if_necessary(
|
||||
self._real_index.verify_and_create_index_if_necessary(
|
||||
primary_embedding_dim, primary_embedding_precision
|
||||
)
|
||||
if self.secondary_index_name:
|
||||
if (
|
||||
secondary_index_embedding_dim is None
|
||||
or secondary_index_embedding_precision is None
|
||||
):
|
||||
raise ValueError(
|
||||
"Bug: Secondary index embedding dimension and precision are not set."
|
||||
)
|
||||
assert (
|
||||
self._secondary_real_index is not None
|
||||
), "Bug: Secondary index is not initialized."
|
||||
self._secondary_real_index.verify_and_create_index_if_necessary(
|
||||
secondary_index_embedding_dim, secondary_index_embedding_precision
|
||||
)
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
NOTE: Do NOT consider the secondary index here. A separate indexing
|
||||
pipeline will be responsible for indexing to the secondary index. This
|
||||
design is not ideal and we should reconsider this when revamping index
|
||||
swapping.
|
||||
"""
|
||||
# Convert IndexBatchParams to IndexingMetadata.
|
||||
chunk_counts: dict[str, IndexingMetadata.ChunkCounts] = {}
|
||||
for doc_id in index_batch_params.doc_id_to_new_chunk_cnt:
|
||||
@@ -351,7 +386,20 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
tenant_id: str, # noqa: ARG002
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
return self._real_index.delete(doc_id, chunk_count)
|
||||
"""
|
||||
NOTE: Remember to handle the secondary index here. There is no separate
|
||||
pipeline for deleting chunks in the secondary index. This design is not
|
||||
ideal and we should reconsider this when revamping index swapping.
|
||||
"""
|
||||
total_chunks_deleted = self._real_index.delete(doc_id, chunk_count)
|
||||
if self.secondary_index_name:
|
||||
assert (
|
||||
self._secondary_real_index is not None
|
||||
), "Bug: Secondary index is not initialized."
|
||||
total_chunks_deleted += self._secondary_real_index.delete(
|
||||
doc_id, chunk_count
|
||||
)
|
||||
return total_chunks_deleted
|
||||
|
||||
def update_single(
|
||||
self,
|
||||
@@ -362,6 +410,11 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
fields: VespaDocumentFields | None,
|
||||
user_fields: VespaDocumentUserFields | None,
|
||||
) -> None:
|
||||
"""
|
||||
NOTE: Remember to handle the secondary index here. There is no separate
|
||||
pipeline for updating chunks in the secondary index. This design is not
|
||||
ideal and we should reconsider this when revamping index swapping.
|
||||
"""
|
||||
if fields is None and user_fields is None:
|
||||
logger.warning(
|
||||
f"Tried to update document {doc_id} with no updated fields or user fields."
|
||||
@@ -392,6 +445,11 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
|
||||
try:
|
||||
self._real_index.update([update_request])
|
||||
if self.secondary_index_name:
|
||||
assert (
|
||||
self._secondary_real_index is not None
|
||||
), "Bug: Secondary index is not initialized."
|
||||
self._secondary_real_index.update([update_request])
|
||||
except NotFoundError:
|
||||
logger.exception(
|
||||
f"Tried to update document {doc_id} but at least one of its chunks was not found in OpenSearch. "
|
||||
|
||||
@@ -465,6 +465,12 @@ class VespaIndex(DocumentIndex):
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
NOTE: Do NOT consider the secondary index here. A separate indexing
|
||||
pipeline will be responsible for indexing to the secondary index. This
|
||||
design is not ideal and we should reconsider this when revamping index
|
||||
swapping.
|
||||
"""
|
||||
if len(index_batch_params.doc_id_to_previous_chunk_cnt) != len(
|
||||
index_batch_params.doc_id_to_new_chunk_cnt
|
||||
):
|
||||
@@ -659,6 +665,10 @@ class VespaIndex(DocumentIndex):
|
||||
"""Note: if the document id does not exist, the update will be a no-op and the
|
||||
function will complete with no errors or exceptions.
|
||||
Handle other exceptions if you wish to implement retry behavior
|
||||
|
||||
NOTE: Remember to handle the secondary index here. There is no separate
|
||||
pipeline for updating chunks in the secondary index. This design is not
|
||||
ideal and we should reconsider this when revamping index swapping.
|
||||
"""
|
||||
if fields is None and user_fields is None:
|
||||
logger.warning(
|
||||
@@ -679,13 +689,6 @@ class VespaIndex(DocumentIndex):
|
||||
f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
|
||||
)
|
||||
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=self.index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=self.large_chunks_enabled,
|
||||
httpx_client=self.httpx_client,
|
||||
)
|
||||
|
||||
project_ids: set[int] | None = None
|
||||
if user_fields is not None and user_fields.user_projects is not None:
|
||||
project_ids = set(user_fields.user_projects)
|
||||
@@ -705,7 +708,20 @@ class VespaIndex(DocumentIndex):
|
||||
persona_ids=persona_ids,
|
||||
)
|
||||
|
||||
vespa_document_index.update([update_request])
|
||||
indices = [self.index_name]
|
||||
if self.secondary_index_name:
|
||||
indices.append(self.secondary_index_name)
|
||||
|
||||
for index_name in indices:
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=self.index_to_large_chunks_enabled.get(
|
||||
index_name, False
|
||||
),
|
||||
httpx_client=self.httpx_client,
|
||||
)
|
||||
vespa_document_index.update([update_request])
|
||||
|
||||
def delete_single(
|
||||
self,
|
||||
@@ -714,6 +730,11 @@ class VespaIndex(DocumentIndex):
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
"""
|
||||
NOTE: Remember to handle the secondary index here. There is no separate
|
||||
pipeline for deleting chunks in the secondary index. This design is not
|
||||
ideal and we should reconsider this when revamping index swapping.
|
||||
"""
|
||||
tenant_state = TenantState(
|
||||
tenant_id=get_current_tenant_id(),
|
||||
multitenant=MULTI_TENANT,
|
||||
@@ -726,13 +747,25 @@ class VespaIndex(DocumentIndex):
|
||||
raise ValueError(
|
||||
f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
|
||||
)
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=self.index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=self.large_chunks_enabled,
|
||||
httpx_client=self.httpx_client,
|
||||
)
|
||||
return vespa_document_index.delete(document_id=doc_id, chunk_count=chunk_count)
|
||||
indices = [self.index_name]
|
||||
if self.secondary_index_name:
|
||||
indices.append(self.secondary_index_name)
|
||||
|
||||
total_chunks_deleted = 0
|
||||
for index_name in indices:
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=self.index_to_large_chunks_enabled.get(
|
||||
index_name, False
|
||||
),
|
||||
httpx_client=self.httpx_client,
|
||||
)
|
||||
total_chunks_deleted += vespa_document_index.delete(
|
||||
document_id=doc_id, chunk_count=chunk_count
|
||||
)
|
||||
|
||||
return total_chunks_deleted
|
||||
|
||||
def id_based_retrieval(
|
||||
self,
|
||||
|
||||
@@ -119,6 +119,9 @@ from onyx.server.manage.opensearch_migration.api import (
|
||||
from onyx.server.manage.search_settings import router as search_settings_router
|
||||
from onyx.server.manage.slack_bot import router as slack_bot_management_router
|
||||
from onyx.server.manage.users import router as user_router
|
||||
from onyx.server.manage.voice.api import admin_router as voice_admin_router
|
||||
from onyx.server.manage.voice.user_api import router as voice_router
|
||||
from onyx.server.manage.voice.websocket_api import router as voice_websocket_router
|
||||
from onyx.server.manage.web_search.api import (
|
||||
admin_router as web_search_admin_router,
|
||||
)
|
||||
@@ -497,6 +500,9 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, embedding_router)
|
||||
include_router_with_global_prefix_prepended(application, web_search_router)
|
||||
include_router_with_global_prefix_prepended(application, web_search_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_websocket_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, opensearch_migration_admin_router
|
||||
)
|
||||
|
||||
@@ -130,7 +130,7 @@ def format_slack_message(message: str | None) -> str:
|
||||
message = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
message = _convert_slack_links_to_markdown(message)
|
||||
normalized_message = _normalize_link_destinations(message)
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"])
|
||||
result = md(normalized_message)
|
||||
# With HTMLRenderer, result is always str (not AST list)
|
||||
assert isinstance(result, str)
|
||||
@@ -146,6 +146,11 @@ class SlackRenderer(HTMLRenderer):
|
||||
|
||||
SPECIALS: dict[str, str] = {"&": "&", "<": "<", ">": ">"}
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._table_headers: list[str] = []
|
||||
self._current_row_cells: list[str] = []
|
||||
|
||||
def escape_special(self, text: str) -> str:
|
||||
for special, replacement in self.SPECIALS.items():
|
||||
text = text.replace(special, replacement)
|
||||
@@ -218,5 +223,48 @@ class SlackRenderer(HTMLRenderer):
|
||||
# as literal " text since Slack doesn't recognize that entity.
|
||||
return self.escape_special(text)
|
||||
|
||||
# -- Table rendering (converts markdown tables to vertical cards) --
|
||||
|
||||
def table_cell(
|
||||
self, text: str, align: str | None = None, head: bool = False # noqa: ARG002
|
||||
) -> str:
|
||||
if head:
|
||||
self._table_headers.append(text.strip())
|
||||
else:
|
||||
self._current_row_cells.append(text.strip())
|
||||
return ""
|
||||
|
||||
def table_head(self, text: str) -> str: # noqa: ARG002
|
||||
self._current_row_cells = []
|
||||
return ""
|
||||
|
||||
def table_row(self, text: str) -> str: # noqa: ARG002
|
||||
cells = self._current_row_cells
|
||||
self._current_row_cells = []
|
||||
# First column becomes the bold title, remaining columns are bulleted fields
|
||||
lines: list[str] = []
|
||||
if cells:
|
||||
title = cells[0]
|
||||
if title:
|
||||
# Avoid double-wrapping if cell already contains bold markup
|
||||
if title.startswith("*") and title.endswith("*") and len(title) > 1:
|
||||
lines.append(title)
|
||||
else:
|
||||
lines.append(f"*{title}*")
|
||||
for i, cell in enumerate(cells[1:], start=1):
|
||||
if i < len(self._table_headers):
|
||||
lines.append(f" • {self._table_headers[i]}: {cell}")
|
||||
else:
|
||||
lines.append(f" • {cell}")
|
||||
return "\n".join(lines) + "\n\n"
|
||||
|
||||
def table_body(self, text: str) -> str:
|
||||
return text
|
||||
|
||||
def table(self, text: str) -> str:
|
||||
self._table_headers = []
|
||||
self._current_row_cells = []
|
||||
return text + "\n"
|
||||
|
||||
def paragraph(self, text: str) -> str:
|
||||
return f"{text}\n\n"
|
||||
|
||||
@@ -419,12 +419,15 @@ async def get_async_redis_connection() -> aioredis.Redis:
|
||||
return _async_redis_connection
|
||||
|
||||
|
||||
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
|
||||
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
if not token:
|
||||
logger.debug("No auth token cookie found")
|
||||
return None
|
||||
async def retrieve_auth_token_data(token: str) -> dict | None:
|
||||
"""Validate auth token against Redis and return token data.
|
||||
|
||||
Args:
|
||||
token: The raw authentication token string.
|
||||
|
||||
Returns:
|
||||
Token data dict if valid, None if invalid/expired.
|
||||
"""
|
||||
try:
|
||||
redis = await get_async_redis_connection()
|
||||
redis_key = REDIS_AUTH_KEY_PREFIX + token
|
||||
@@ -439,12 +442,67 @@ async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
|
||||
logger.error("Error decoding token data from Redis")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
|
||||
)
|
||||
logger.error(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
|
||||
raise ValueError(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
|
||||
|
||||
|
||||
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
|
||||
"""Validate auth token from request cookie. Wrapper for backwards compatibility."""
|
||||
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
if not token:
|
||||
logger.debug("No auth token cookie found")
|
||||
return None
|
||||
return await retrieve_auth_token_data(token)
|
||||
|
||||
|
||||
# WebSocket token prefix (separate from regular auth tokens)
|
||||
REDIS_WS_TOKEN_PREFIX = "ws_token:"
|
||||
# WebSocket tokens expire after 60 seconds
|
||||
WS_TOKEN_TTL_SECONDS = 60
|
||||
|
||||
|
||||
async def store_ws_token(token: str, user_id: str) -> None:
|
||||
"""Store a short-lived WebSocket authentication token in Redis.
|
||||
|
||||
Args:
|
||||
token: The generated WS token.
|
||||
user_id: The user ID to associate with this token.
|
||||
"""
|
||||
redis = await get_async_redis_connection()
|
||||
redis_key = REDIS_WS_TOKEN_PREFIX + token
|
||||
token_data = json.dumps({"sub": user_id})
|
||||
await redis.set(redis_key, token_data, ex=WS_TOKEN_TTL_SECONDS)
|
||||
|
||||
|
||||
async def retrieve_ws_token_data(token: str) -> dict | None:
|
||||
"""Validate a WebSocket token and return the token data.
|
||||
|
||||
This uses GETDEL for atomic get-and-delete to prevent race conditions
|
||||
where the same token could be used twice.
|
||||
|
||||
Args:
|
||||
token: The WS token to validate.
|
||||
|
||||
Returns:
|
||||
Token data dict with 'sub' (user ID) if valid, None if invalid/expired.
|
||||
"""
|
||||
try:
|
||||
redis = await get_async_redis_connection()
|
||||
redis_key = REDIS_WS_TOKEN_PREFIX + token
|
||||
|
||||
# Atomic get-and-delete to prevent race conditions (Redis 6.2+)
|
||||
token_data_str = await redis.getdel(redis_key)
|
||||
|
||||
if not token_data_str:
|
||||
return None
|
||||
|
||||
return json.loads(token_data_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Error decoding WS token data from Redis")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in retrieve_ws_token_data: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def redis_lock_dump(lock: RedisLock, r: Redis) -> None:
|
||||
|
||||
@@ -9,6 +9,7 @@ from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import current_user_from_websocket
|
||||
from onyx.auth.users import current_user_with_expired_token
|
||||
from onyx.configs.app_configs import APP_API_PREFIX
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
@@ -129,6 +130,7 @@ def check_router_auth(
|
||||
or depends_fn == current_curator_or_admin_user
|
||||
or depends_fn == current_user_with_expired_token
|
||||
or depends_fn == current_chat_accessible_user
|
||||
or depends_fn == current_user_from_websocket
|
||||
or depends_fn == control_plane_dep
|
||||
or depends_fn == current_cloud_superuser
|
||||
or depends_fn == verify_scim_token
|
||||
|
||||
@@ -7424,9 +7424,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/hono": {
|
||||
"version": "4.11.7",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.11.7.tgz",
|
||||
"integrity": "sha512-l7qMiNee7t82bH3SeyUCt9UF15EVmaBvsppY2zQtrbIhl/yzBTny+YUxsVjSjQ6gaqaeVtZmGocom8TzBlA4Yw==",
|
||||
"version": "4.12.5",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz",
|
||||
"integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=16.9.0"
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.document_set import check_document_sets_are_public
|
||||
from onyx.db.document_set import delete_document_set as db_delete_document_set
|
||||
from onyx.db.document_set import fetch_all_document_sets_for_user
|
||||
from onyx.db.document_set import get_document_set_by_id
|
||||
from onyx.db.document_set import insert_document_set
|
||||
@@ -142,7 +143,10 @@ def delete_document_set(
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if DISABLE_VECTOR_DB:
|
||||
db_session.refresh(document_set)
|
||||
db_delete_document_set(document_set, db_session)
|
||||
else:
|
||||
client_app.send_task(
|
||||
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -14,8 +15,6 @@ from onyx.db.llm import remove_llm_provider__no_commit
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.image_gen.exceptions import ImageProviderCredentialsError
|
||||
from onyx.image_gen.factory import get_image_generation_provider
|
||||
from onyx.image_gen.factory import validate_credentials
|
||||
@@ -75,9 +74,9 @@ def _build_llm_provider_request(
|
||||
# Clone mode: Only use API key from source provider
|
||||
source_provider = db_session.get(LLMProviderModel, source_llm_provider_id)
|
||||
if not source_provider:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"Source LLM provider with id {source_llm_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Source LLM provider with id {source_llm_provider_id} not found",
|
||||
)
|
||||
|
||||
_validate_llm_provider_change(
|
||||
@@ -111,9 +110,9 @@ def _build_llm_provider_request(
|
||||
)
|
||||
|
||||
if not provider:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No provider or source llm provided",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No provider or source llm provided",
|
||||
)
|
||||
|
||||
credentials = ImageGenerationProviderCredentials(
|
||||
@@ -125,9 +124,9 @@ def _build_llm_provider_request(
|
||||
)
|
||||
|
||||
if not validate_credentials(provider, credentials):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CREDENTIAL_INVALID,
|
||||
f"Incorrect credentials for {provider}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Incorrect credentials for {provider}",
|
||||
)
|
||||
|
||||
return LLMProviderUpsertRequest(
|
||||
@@ -216,9 +215,9 @@ def test_image_generation(
|
||||
LLMProviderModel, test_request.source_llm_provider_id
|
||||
)
|
||||
if not source_provider:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"Source LLM provider with id {test_request.source_llm_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Source LLM provider with id {test_request.source_llm_provider_id} not found",
|
||||
)
|
||||
|
||||
_validate_llm_provider_change(
|
||||
@@ -237,9 +236,9 @@ def test_image_generation(
|
||||
provider = source_provider.provider
|
||||
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No provider or source llm provided",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No provider or source llm provided",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -258,14 +257,14 @@ def test_image_generation(
|
||||
),
|
||||
)
|
||||
except ValueError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"Invalid image generation provider: {provider}",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Invalid image generation provider: {provider}",
|
||||
)
|
||||
except ImageProviderCredentialsError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CREDENTIAL_INVALID,
|
||||
"Invalid image generation credentials",
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid image generation credentials",
|
||||
)
|
||||
|
||||
quality = _get_test_quality_for_model(test_request.model_name)
|
||||
@@ -277,15 +276,15 @@ def test_image_generation(
|
||||
n=1,
|
||||
quality=quality,
|
||||
)
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
# Log only exception type to avoid exposing sensitive data
|
||||
# (LiteLLM errors may contain URLs with API keys or auth tokens)
|
||||
logger.warning(f"Image generation test failed: {type(e).__name__}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Image generation test failed: {type(e).__name__}",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Image generation test failed: {type(e).__name__}",
|
||||
)
|
||||
|
||||
|
||||
@@ -310,9 +309,9 @@ def create_config(
|
||||
db_session, config_create.image_provider_id
|
||||
)
|
||||
if existing_config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.DUPLICATE_RESOURCE,
|
||||
f"ImageGenerationConfig with image_provider_id '{config_create.image_provider_id}' already exists",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"ImageGenerationConfig with image_provider_id '{config_create.image_provider_id}' already exists",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -346,10 +345,10 @@ def create_config(
|
||||
db_session.commit()
|
||||
db_session.refresh(config)
|
||||
return ImageGenerationConfigView.from_model(config)
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.get("/config")
|
||||
@@ -374,9 +373,9 @@ def get_config_credentials(
|
||||
"""
|
||||
config = get_image_generation_config(db_session, image_provider_id)
|
||||
if not config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
)
|
||||
|
||||
return ImageGenerationCredentials.from_model(config)
|
||||
@@ -402,9 +401,9 @@ def update_config(
|
||||
# 1. Get existing config
|
||||
existing_config = get_image_generation_config(db_session, image_provider_id)
|
||||
if not existing_config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
)
|
||||
|
||||
old_llm_provider_id = existing_config.model_configuration.llm_provider_id
|
||||
@@ -473,10 +472,10 @@ def update_config(
|
||||
db_session.refresh(existing_config)
|
||||
return ImageGenerationConfigView.from_model(existing_config)
|
||||
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.delete("/config/{image_provider_id}")
|
||||
@@ -490,9 +489,9 @@ def delete_config(
|
||||
# Get the config first to find the associated LLM provider
|
||||
existing_config = get_image_generation_config(db_session, image_provider_id)
|
||||
if not existing_config:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
|
||||
)
|
||||
|
||||
llm_provider_id = existing_config.model_configuration.llm_provider_id
|
||||
@@ -504,10 +503,10 @@ def delete_config(
|
||||
remove_llm_provider__no_commit(db_session, llm_provider_id)
|
||||
|
||||
db_session.commit()
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/config/{image_provider_id}/default")
|
||||
@@ -520,7 +519,7 @@ def set_config_as_default(
|
||||
try:
|
||||
set_default_image_generation_config(db_session, image_provider_id)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.delete("/config/{image_provider_id}/default")
|
||||
@@ -533,4 +532,4 @@ def unset_config_as_default(
|
||||
try:
|
||||
unset_default_image_generation_config(db_session, image_provider_id)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -85,6 +85,12 @@ class UserPreferences(BaseModel):
|
||||
chat_background: str | None = None
|
||||
default_app_mode: DefaultAppMode = DefaultAppMode.CHAT
|
||||
|
||||
# Voice preferences
|
||||
voice_auto_send: bool | None = None
|
||||
voice_auto_playback: bool | None = None
|
||||
voice_playback_speed: float | None = None
|
||||
preferred_voice: str | None = None
|
||||
|
||||
# controls which tools are enabled for the user for a specific assistant
|
||||
assistant_specific_configs: UserSpecificAssistantPreferences | None = None
|
||||
|
||||
@@ -164,6 +170,10 @@ class UserInfo(BaseModel):
|
||||
theme_preference=user.theme_preference,
|
||||
chat_background=user.chat_background,
|
||||
default_app_mode=user.default_app_mode,
|
||||
voice_auto_send=user.voice_auto_send,
|
||||
voice_auto_playback=user.voice_auto_playback,
|
||||
voice_playback_speed=user.voice_playback_speed,
|
||||
preferred_voice=user.preferred_voice,
|
||||
assistant_specific_configs=assistant_specific_configs,
|
||||
)
|
||||
),
|
||||
@@ -240,6 +250,13 @@ class ChatBackgroundRequest(BaseModel):
|
||||
chat_background: str | None
|
||||
|
||||
|
||||
class VoiceSettingsUpdateRequest(BaseModel):
|
||||
auto_send: bool | None = None
|
||||
auto_playback: bool | None = None
|
||||
playback_speed: float | None = Field(default=None, ge=0.5, le=2.0)
|
||||
preferred_voice: str | None = None
|
||||
|
||||
|
||||
class PersonalizationUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
role: str | None = None
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.context.search.models import SavedSearchSettings
|
||||
from onyx.context.search.models import SearchSettingsCreationRequest
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.connector_credential_pair import resync_cc_pair
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.index_attempt import expire_index_attempts
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
@@ -13,22 +18,25 @@ from onyx.db.llm import update_default_contextual_model
|
||||
from onyx.db.llm import update_no_default_contextual_rag_provider
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import create_search_settings
|
||||
from onyx.db.search_settings import delete_search_settings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_embedding_provider_from_provider_type
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.search_settings import update_current_search_settings
|
||||
from onyx.db.search_settings import update_search_settings_status
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_processing.unstructured import delete_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import get_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import update_unstructured_api_key
|
||||
from onyx.natural_language_processing.search_nlp_models import clean_model_name
|
||||
from onyx.server.manage.embedding.models import SearchSettingsDeleteRequest
|
||||
from onyx.server.manage.models import FullModelVersionResponse
|
||||
from onyx.server.models import IdReturn
|
||||
from onyx.server.utils_vector_db import require_vector_db
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import ALT_INDEX_SUFFIX
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
router = APIRouter(prefix="/search-settings")
|
||||
@@ -41,110 +49,99 @@ def set_new_search_settings(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session), # noqa: ARG001
|
||||
) -> IdReturn:
|
||||
"""Creates a new EmbeddingModel row and cancels the previous secondary indexing if any
|
||||
Gives an error if the same model name is used as the current or secondary index
|
||||
"""
|
||||
# TODO(andrei): Re-enable.
|
||||
# NOTE Enable integration external dependency tests in test_search_settings.py
|
||||
# when this is reenabled. They are currently skipped
|
||||
logger.error("Setting new search settings is temporarily disabled.")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_IMPLEMENTED,
|
||||
"Setting new search settings is temporarily disabled.",
|
||||
Creates a new SearchSettings row and cancels the previous secondary indexing
|
||||
if any exists.
|
||||
"""
|
||||
if search_settings_new.index_name:
|
||||
logger.warning("Index name was specified by request, this is not suggested")
|
||||
|
||||
# Disallow contextual RAG for cloud deployments.
|
||||
if MULTI_TENANT and search_settings_new.enable_contextual_rag:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Contextual RAG disabled in Onyx Cloud",
|
||||
)
|
||||
|
||||
# Validate cloud provider exists or create new LiteLLM provider.
|
||||
if search_settings_new.provider_type is not None:
|
||||
cloud_provider = get_embedding_provider_from_provider_type(
|
||||
db_session, provider_type=search_settings_new.provider_type
|
||||
)
|
||||
|
||||
if cloud_provider is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
|
||||
)
|
||||
|
||||
validate_contextual_rag_model(
|
||||
provider_name=search_settings_new.contextual_rag_llm_provider,
|
||||
model_name=search_settings_new.contextual_rag_llm_name,
|
||||
db_session=db_session,
|
||||
)
|
||||
# if search_settings_new.index_name:
|
||||
# logger.warning("Index name was specified by request, this is not suggested")
|
||||
|
||||
# # Disallow contextual RAG for cloud deployments
|
||||
# if MULTI_TENANT and search_settings_new.enable_contextual_rag:
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
||||
# detail="Contextual RAG disabled in Onyx Cloud",
|
||||
# )
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# # Validate cloud provider exists or create new LiteLLM provider
|
||||
# if search_settings_new.provider_type is not None:
|
||||
# cloud_provider = get_embedding_provider_from_provider_type(
|
||||
# db_session, provider_type=search_settings_new.provider_type
|
||||
# )
|
||||
if search_settings_new.index_name is None:
|
||||
# We define index name here.
|
||||
index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
|
||||
if (
|
||||
search_settings_new.model_name == search_settings.model_name
|
||||
and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
|
||||
):
|
||||
index_name += ALT_INDEX_SUFFIX
|
||||
search_values = search_settings_new.model_dump()
|
||||
search_values["index_name"] = index_name
|
||||
new_search_settings_request = SavedSearchSettings(**search_values)
|
||||
else:
|
||||
new_search_settings_request = SavedSearchSettings(
|
||||
**search_settings_new.model_dump()
|
||||
)
|
||||
|
||||
# if cloud_provider is None:
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
||||
# detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
|
||||
# )
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
# validate_contextual_rag_model(
|
||||
# provider_name=search_settings_new.contextual_rag_llm_provider,
|
||||
# model_name=search_settings_new.contextual_rag_llm_name,
|
||||
# db_session=db_session,
|
||||
# )
|
||||
if secondary_search_settings:
|
||||
# Cancel any background indexing jobs.
|
||||
expire_index_attempts(
|
||||
search_settings_id=secondary_search_settings.id, db_session=db_session
|
||||
)
|
||||
|
||||
# search_settings = get_current_search_settings(db_session)
|
||||
# Mark previous model as a past model directly.
|
||||
update_search_settings_status(
|
||||
search_settings=secondary_search_settings,
|
||||
new_status=IndexModelStatus.PAST,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# if search_settings_new.index_name is None:
|
||||
# # We define index name here
|
||||
# index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
|
||||
# if (
|
||||
# search_settings_new.model_name == search_settings.model_name
|
||||
# and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
|
||||
# ):
|
||||
# index_name += ALT_INDEX_SUFFIX
|
||||
# search_values = search_settings_new.model_dump()
|
||||
# search_values["index_name"] = index_name
|
||||
# new_search_settings_request = SavedSearchSettings(**search_values)
|
||||
# else:
|
||||
# new_search_settings_request = SavedSearchSettings(
|
||||
# **search_settings_new.model_dump()
|
||||
# )
|
||||
new_search_settings = create_search_settings(
|
||||
search_settings=new_search_settings_request, db_session=db_session
|
||||
)
|
||||
|
||||
# secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
# Ensure the document indices have the new index immediately.
|
||||
document_indices = get_all_document_indices(search_settings, new_search_settings)
|
||||
for document_index in document_indices:
|
||||
document_index.ensure_indices_exist(
|
||||
primary_embedding_dim=search_settings.final_embedding_dim,
|
||||
primary_embedding_precision=search_settings.embedding_precision,
|
||||
secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
|
||||
secondary_index_embedding_precision=new_search_settings.embedding_precision,
|
||||
)
|
||||
|
||||
# if secondary_search_settings:
|
||||
# # Cancel any background indexing jobs
|
||||
# expire_index_attempts(
|
||||
# search_settings_id=secondary_search_settings.id, db_session=db_session
|
||||
# )
|
||||
# Pause index attempts for the currently in-use index to preserve resources.
|
||||
if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
expire_index_attempts(
|
||||
search_settings_id=search_settings.id, db_session=db_session
|
||||
)
|
||||
for cc_pair in get_connector_credential_pairs(db_session):
|
||||
resync_cc_pair(
|
||||
cc_pair=cc_pair,
|
||||
search_settings_id=new_search_settings.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# # Mark previous model as a past model directly
|
||||
# update_search_settings_status(
|
||||
# search_settings=secondary_search_settings,
|
||||
# new_status=IndexModelStatus.PAST,
|
||||
# db_session=db_session,
|
||||
# )
|
||||
|
||||
# new_search_settings = create_search_settings(
|
||||
# search_settings=new_search_settings_request, db_session=db_session
|
||||
# )
|
||||
|
||||
# # Ensure Vespa has the new index immediately
|
||||
# get_multipass_config(search_settings)
|
||||
# get_multipass_config(new_search_settings)
|
||||
# document_index = get_default_document_index(
|
||||
# search_settings, new_search_settings, db_session
|
||||
# )
|
||||
|
||||
# document_index.ensure_indices_exist(
|
||||
# primary_embedding_dim=search_settings.final_embedding_dim,
|
||||
# primary_embedding_precision=search_settings.embedding_precision,
|
||||
# secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
|
||||
# secondary_index_embedding_precision=new_search_settings.embedding_precision,
|
||||
# )
|
||||
|
||||
# # Pause index attempts for the currently in use index to preserve resources
|
||||
# if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
# expire_index_attempts(
|
||||
# search_settings_id=search_settings.id, db_session=db_session
|
||||
# )
|
||||
# for cc_pair in get_connector_credential_pairs(db_session):
|
||||
# resync_cc_pair(
|
||||
# cc_pair=cc_pair,
|
||||
# search_settings_id=new_search_settings.id,
|
||||
# db_session=db_session,
|
||||
# )
|
||||
|
||||
# db_session.commit()
|
||||
# return IdReturn(id=new_search_settings.id)
|
||||
db_session.commit()
|
||||
return IdReturn(id=new_search_settings.id)
|
||||
|
||||
|
||||
@router.post("/cancel-new-embedding", dependencies=[Depends(require_vector_db)])
|
||||
@@ -191,7 +188,7 @@ def delete_search_settings_endpoint(
|
||||
search_settings_id=deletion_request.search_settings_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/get-current-search-settings")
|
||||
@@ -241,9 +238,9 @@ def update_saved_search_settings(
|
||||
) -> None:
|
||||
# Disallow contextual RAG for cloud deployments
|
||||
if MULTI_TENANT and search_settings.enable_contextual_rag:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Contextual RAG disabled in Onyx Cloud",
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Contextual RAG disabled in Onyx Cloud",
|
||||
)
|
||||
|
||||
validate_contextual_rag_model(
|
||||
@@ -297,7 +294,7 @@ def validate_contextual_rag_model(
|
||||
model_name=model_name,
|
||||
db_session=db_session,
|
||||
):
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=error_msg)
|
||||
|
||||
|
||||
def _validate_contextual_rag_model(
|
||||
|
||||
0
backend/onyx/server/manage/voice/__init__.py
Normal file
0
backend/onyx/server/manage/voice/__init__.py
Normal file
282
backend/onyx/server/manage/voice/api.py
Normal file
282
backend/onyx/server/manage/voice/api.py
Normal file
@@ -0,0 +1,282 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.db.voice import deactivate_stt_provider
|
||||
from onyx.db.voice import deactivate_tts_provider
|
||||
from onyx.db.voice import delete_voice_provider
|
||||
from onyx.db.voice import fetch_voice_provider_by_id
|
||||
from onyx.db.voice import fetch_voice_provider_by_type
|
||||
from onyx.db.voice import fetch_voice_providers
|
||||
from onyx.db.voice import set_default_stt_provider
|
||||
from onyx.db.voice import set_default_tts_provider
|
||||
from onyx.db.voice import upsert_voice_provider
|
||||
from onyx.server.manage.voice.models import VoiceProviderTestRequest
|
||||
from onyx.server.manage.voice.models import VoiceProviderUpsertRequest
|
||||
from onyx.server.manage.voice.models import VoiceProviderView
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/voice")
|
||||
|
||||
|
||||
def _provider_to_view(provider: VoiceProvider) -> VoiceProviderView:
|
||||
"""Convert a VoiceProvider model to a VoiceProviderView."""
|
||||
return VoiceProviderView(
|
||||
id=provider.id,
|
||||
name=provider.name,
|
||||
provider_type=provider.provider_type,
|
||||
is_default_stt=provider.is_default_stt,
|
||||
is_default_tts=provider.is_default_tts,
|
||||
stt_model=provider.stt_model,
|
||||
tts_model=provider.tts_model,
|
||||
default_voice=provider.default_voice,
|
||||
has_api_key=bool(provider.api_key),
|
||||
target_uri=provider.api_base, # api_base stores the target URI for Azure
|
||||
)
|
||||
|
||||
|
||||
@admin_router.get("/providers")
|
||||
def list_voice_providers(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[VoiceProviderView]:
|
||||
"""List all configured voice providers."""
|
||||
providers = fetch_voice_providers(db_session)
|
||||
return [_provider_to_view(provider) for provider in providers]
|
||||
|
||||
|
||||
@admin_router.post("/providers")
|
||||
def upsert_voice_provider_endpoint(
|
||||
request: VoiceProviderUpsertRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderView:
|
||||
"""Create or update a voice provider."""
|
||||
api_key = request.api_key
|
||||
api_key_changed = request.api_key_changed
|
||||
|
||||
# If llm_provider_id is specified, copy the API key from that LLM provider
|
||||
if request.llm_provider_id is not None:
|
||||
llm_provider = db_session.get(LLMProviderModel, request.llm_provider_id)
|
||||
if llm_provider is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"LLM provider with id {request.llm_provider_id} not found.",
|
||||
)
|
||||
if llm_provider.api_key is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Selected LLM provider has no API key configured.",
|
||||
)
|
||||
api_key = llm_provider.api_key.get_value(apply_mask=False)
|
||||
api_key_changed = True
|
||||
|
||||
# Use target_uri if provided, otherwise fall back to api_base
|
||||
api_base = request.target_uri or request.api_base
|
||||
|
||||
provider = upsert_voice_provider(
|
||||
db_session=db_session,
|
||||
provider_id=request.id,
|
||||
name=request.name,
|
||||
provider_type=request.provider_type,
|
||||
api_key=api_key,
|
||||
api_key_changed=api_key_changed,
|
||||
api_base=api_base,
|
||||
custom_config=request.custom_config,
|
||||
stt_model=request.stt_model,
|
||||
tts_model=request.tts_model,
|
||||
default_voice=request.default_voice,
|
||||
activate_stt=request.activate_stt,
|
||||
activate_tts=request.activate_tts,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return _provider_to_view(provider)
|
||||
|
||||
|
||||
@admin_router.delete(
|
||||
"/providers/{provider_id}", status_code=204, response_class=Response
|
||||
)
|
||||
def delete_voice_provider_endpoint(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Delete a voice provider."""
|
||||
delete_voice_provider(db_session, provider_id)
|
||||
db_session.commit()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/activate-stt")
|
||||
def activate_stt_provider_endpoint(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderView:
|
||||
"""Set a voice provider as the default STT provider."""
|
||||
provider = set_default_stt_provider(db_session=db_session, provider_id=provider_id)
|
||||
db_session.commit()
|
||||
return _provider_to_view(provider)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/deactivate-stt")
|
||||
def deactivate_stt_provider_endpoint(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Remove the default STT status from a voice provider."""
|
||||
deactivate_stt_provider(db_session=db_session, provider_id=provider_id)
|
||||
db_session.commit()
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/activate-tts")
|
||||
def activate_tts_provider_endpoint(
|
||||
provider_id: int,
|
||||
tts_model: str | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderView:
|
||||
"""Set a voice provider as the default TTS provider."""
|
||||
provider = set_default_tts_provider(
|
||||
db_session=db_session, provider_id=provider_id, tts_model=tts_model
|
||||
)
|
||||
db_session.commit()
|
||||
return _provider_to_view(provider)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/deactivate-tts")
|
||||
def deactivate_tts_provider_endpoint(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Remove the default TTS status from a voice provider."""
|
||||
deactivate_tts_provider(db_session=db_session, provider_id=provider_id)
|
||||
db_session.commit()
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@admin_router.post("/providers/test")
|
||||
def test_voice_provider(
|
||||
request: VoiceProviderTestRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Test a voice provider connection."""
|
||||
api_key = request.api_key
|
||||
|
||||
if request.use_stored_key:
|
||||
existing_provider = fetch_voice_provider_by_type(
|
||||
db_session, request.provider_type
|
||||
)
|
||||
if existing_provider is None or not existing_provider.api_key:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No stored API key found for this provider type.",
|
||||
)
|
||||
api_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API key is required. Either provide api_key or set use_stored_key to true.",
|
||||
)
|
||||
|
||||
# Use target_uri if provided, otherwise fall back to api_base
|
||||
api_base = request.target_uri or request.api_base
|
||||
|
||||
# Create a temporary VoiceProvider for testing (not saved to DB)
|
||||
temp_provider = VoiceProvider(
|
||||
name="__test__",
|
||||
provider_type=request.provider_type,
|
||||
api_base=api_base,
|
||||
custom_config=request.custom_config or {},
|
||||
)
|
||||
temp_provider.api_key = api_key # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(temp_provider)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
# Test the provider by getting available voices (lightweight check)
|
||||
try:
|
||||
voices = provider.get_available_voices()
|
||||
if not voices:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Provider returned no available voices.",
|
||||
)
|
||||
except NotImplementedError:
|
||||
# Provider not fully implemented yet (Azure, ElevenLabs placeholders)
|
||||
pass
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Connection test failed: {str(e)}",
|
||||
) from e
|
||||
|
||||
logger.info(f"Voice provider test succeeded for {request.provider_type}.")
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@admin_router.get("/providers/{provider_id}/voices")
|
||||
def get_provider_voices(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[dict[str, str]]:
|
||||
"""Get available voices for a provider."""
|
||||
provider_db = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider_db is None:
|
||||
raise HTTPException(status_code=404, detail="Voice provider not found.")
|
||||
|
||||
if not provider_db.api_key:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Provider has no API key configured."
|
||||
)
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
return provider.get_available_voices()
|
||||
|
||||
|
||||
@admin_router.get("/voices")
|
||||
def get_voices_by_type(
|
||||
provider_type: str,
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> list[dict[str, str]]:
|
||||
"""Get available voices for a provider type.
|
||||
|
||||
For providers like ElevenLabs and OpenAI, this fetches voices
|
||||
without requiring an existing provider configuration.
|
||||
"""
|
||||
# Create a temporary VoiceProvider to get static voice list
|
||||
temp_provider = VoiceProvider(
|
||||
name="__temp__",
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(temp_provider)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
return provider.get_available_voices()
|
||||
90
backend/onyx/server/manage/voice/models.py
Normal file
90
backend/onyx/server/manage/voice/models.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class VoiceProviderView(BaseModel):
|
||||
"""Response model for voice provider listing."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
provider_type: str # "openai", "azure", "elevenlabs"
|
||||
is_default_stt: bool
|
||||
is_default_tts: bool
|
||||
stt_model: str | None
|
||||
tts_model: str | None
|
||||
default_voice: str | None
|
||||
has_api_key: bool = Field(
|
||||
default=False,
|
||||
description="Indicates whether an API key is stored for this provider.",
|
||||
)
|
||||
target_uri: str | None = Field(
|
||||
default=None,
|
||||
description="Target URI for Azure Speech Services.",
|
||||
)
|
||||
|
||||
|
||||
class VoiceProviderUpsertRequest(BaseModel):
|
||||
"""Request model for creating or updating a voice provider."""
|
||||
|
||||
id: int | None = Field(default=None, description="Existing provider ID to update.")
|
||||
name: str
|
||||
provider_type: str # "openai", "azure", "elevenlabs"
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for the provider.",
|
||||
)
|
||||
api_key_changed: bool = Field(
|
||||
default=False,
|
||||
description="Set to true when providing a new API key for an existing provider.",
|
||||
)
|
||||
llm_provider_id: int | None = Field(
|
||||
default=None,
|
||||
description="If set, copies the API key from the specified LLM provider.",
|
||||
)
|
||||
api_base: str | None = None
|
||||
target_uri: str | None = Field(
|
||||
default=None,
|
||||
description="Target URI for Azure Speech Services (maps to api_base).",
|
||||
)
|
||||
custom_config: dict[str, Any] | None = None
|
||||
stt_model: str | None = None
|
||||
tts_model: str | None = None
|
||||
default_voice: str | None = None
|
||||
activate_stt: bool = Field(
|
||||
default=False,
|
||||
description="If true, sets this provider as the default STT provider after upsert.",
|
||||
)
|
||||
activate_tts: bool = Field(
|
||||
default=False,
|
||||
description="If true, sets this provider as the default TTS provider after upsert.",
|
||||
)
|
||||
|
||||
|
||||
class VoiceProviderTestRequest(BaseModel):
|
||||
"""Request model for testing a voice provider connection."""
|
||||
|
||||
provider_type: str
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for testing. If not provided, use_stored_key must be true.",
|
||||
)
|
||||
use_stored_key: bool = Field(
|
||||
default=False,
|
||||
description="If true, use the stored API key for this provider type.",
|
||||
)
|
||||
api_base: str | None = None
|
||||
target_uri: str | None = Field(
|
||||
default=None,
|
||||
description="Target URI for Azure Speech Services (maps to api_base).",
|
||||
)
|
||||
custom_config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SynthesizeRequest(BaseModel):
|
||||
"""Request model for text-to-speech synthesis."""
|
||||
|
||||
text: str = Field(..., min_length=1, max_length=4096)
|
||||
voice: str | None = None
|
||||
speed: float = Field(default=1.0, ge=0.5, le=2.0)
|
||||
226
backend/onyx/server/manage/voice/user_api.py
Normal file
226
backend/onyx/server/manage/voice/user_api.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import secrets
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
from onyx.db.models import User
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.db.voice import update_user_voice_settings
|
||||
from onyx.redis.redis_pool import store_ws_token
|
||||
from onyx.server.manage.models import VoiceSettingsUpdateRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/voice")
|
||||
|
||||
# Max audio file size: 25MB (Whisper limit)
|
||||
MAX_AUDIO_SIZE = 25 * 1024 * 1024
|
||||
|
||||
|
||||
@router.post("/transcribe")
|
||||
async def transcribe_audio(
|
||||
audio: UploadFile = File(...),
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Transcribe audio to text using the default STT provider."""
|
||||
provider_db = fetch_default_stt_provider(db_session)
|
||||
if provider_db is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No speech-to-text provider configured. Please contact your administrator.",
|
||||
)
|
||||
|
||||
if not provider_db.api_key:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Voice provider API key not configured.",
|
||||
)
|
||||
|
||||
audio_data = await audio.read()
|
||||
if len(audio_data) > MAX_AUDIO_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Audio file too large. Maximum size is {MAX_AUDIO_SIZE // (1024 * 1024)}MB.",
|
||||
)
|
||||
|
||||
# Extract format from filename
|
||||
filename = audio.filename or "audio.webm"
|
||||
audio_format = filename.rsplit(".", 1)[-1] if "." in filename else "webm"
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
try:
|
||||
text = await provider.transcribe(audio_data, audio_format)
|
||||
return {"text": text}
|
||||
except NotImplementedError as exc:
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail=f"Speech-to-text not implemented for {provider_db.provider_type}.",
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
logger.error(f"Transcription failed: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Transcription failed: {str(exc)}",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post("/synthesize")
|
||||
async def synthesize_speech(
|
||||
text: str | None = Query(
|
||||
default=None, description="Text to synthesize", max_length=4096
|
||||
),
|
||||
voice: str | None = Query(default=None, description="Voice ID to use"),
|
||||
speed: float | None = Query(
|
||||
default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
|
||||
),
|
||||
user: User = Depends(current_user),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Synthesize text to speech using the default TTS provider.
|
||||
|
||||
Accepts parameters via query string for streaming compatibility.
|
||||
"""
|
||||
logger.info(
|
||||
f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
|
||||
)
|
||||
|
||||
if not text:
|
||||
raise HTTPException(status_code=400, detail="Text is required")
|
||||
|
||||
# Use short-lived session to fetch provider config, then release connection
|
||||
# before starting the long-running streaming response
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
provider_db = fetch_default_tts_provider(db_session)
|
||||
if provider_db is None:
|
||||
logger.error("No TTS provider configured")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No text-to-speech provider configured. Please contact your administrator.",
|
||||
)
|
||||
|
||||
if not provider_db.api_key:
|
||||
logger.error("TTS provider has no API key")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Voice provider API key not configured.",
|
||||
)
|
||||
|
||||
# Use request voice, or user's preferred voice, or provider default
|
||||
final_voice = voice or user.preferred_voice or provider_db.default_voice
|
||||
# Use explicit None checks to avoid falsy float issues (0.0 would be skipped with `or`)
|
||||
final_speed = (
|
||||
speed
|
||||
if speed is not None
|
||||
else (
|
||||
user.voice_playback_speed
|
||||
if user.voice_playback_speed is not None
|
||||
else 1.0
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"TTS using provider: {provider_db.provider_type}, voice: {final_voice}, speed: {final_speed}"
|
||||
)
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
except ValueError as exc:
|
||||
logger.error(f"Failed to get voice provider: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
# Session is now closed - streaming response won't hold DB connection
|
||||
async def audio_stream() -> AsyncIterator[bytes]:
|
||||
try:
|
||||
chunk_count = 0
|
||||
async for chunk in provider.synthesize_stream(
|
||||
text=text, voice=final_voice, speed=final_speed
|
||||
):
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
|
||||
except NotImplementedError as exc:
|
||||
logger.error(f"TTS not implemented: {exc}")
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Synthesis failed: {exc}")
|
||||
raise
|
||||
|
||||
return StreamingResponse(
|
||||
audio_stream(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Content-Disposition": "inline; filename=speech.mp3",
|
||||
# Allow streaming by not setting content-length
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/settings")
|
||||
def update_voice_settings(
|
||||
request: VoiceSettingsUpdateRequest,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Update user's voice settings.
|
||||
|
||||
To clear preferred_voice back to default, explicitly send {"preferred_voice": null}.
|
||||
Omitting the field leaves it unchanged.
|
||||
"""
|
||||
# Build kwargs, only including preferred_voice if explicitly provided
|
||||
# This allows distinguishing between "not provided" and "set to null"
|
||||
kwargs: dict[str, Any] = {
|
||||
"db_session": db_session,
|
||||
"user_id": user.id,
|
||||
"auto_send": request.auto_send,
|
||||
"auto_playback": request.auto_playback,
|
||||
"playback_speed": request.playback_speed,
|
||||
}
|
||||
if "preferred_voice" in request.model_fields_set:
|
||||
kwargs["preferred_voice"] = request.preferred_voice
|
||||
|
||||
update_user_voice_settings(**kwargs)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
class WSTokenResponse(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
@router.post("/ws-token")
|
||||
async def get_ws_token(
|
||||
user: User = Depends(current_user),
|
||||
) -> WSTokenResponse:
|
||||
"""
|
||||
Generate a short-lived token for WebSocket authentication.
|
||||
|
||||
This token should be passed as a query parameter when connecting
|
||||
to voice WebSocket endpoints (e.g., /voice/transcribe/stream?token=xxx).
|
||||
|
||||
The token expires after 60 seconds and is single-use.
|
||||
"""
|
||||
token = secrets.token_urlsafe(32)
|
||||
await store_ws_token(token, str(user.id))
|
||||
return WSTokenResponse(token=token)
|
||||
782
backend/onyx/server/manage/voice/websocket_api.py
Normal file
782
backend/onyx/server/manage/voice/websocket_api.py
Normal file
@@ -0,0 +1,782 @@
|
||||
"""WebSocket API for streaming speech-to-text and text-to-speech."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import WebSocket
|
||||
from fastapi import WebSocketDisconnect
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user_from_websocket
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
from onyx.db.models import User
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/voice")
|
||||
|
||||
|
||||
# Transcribe every ~0.5 seconds of audio (webm/opus is ~2-4KB/s, so ~1-2KB per 0.5s)
|
||||
MIN_CHUNK_BYTES = 1500
|
||||
VOICE_DISABLE_STREAMING_FALLBACK = (
|
||||
os.environ.get("VOICE_DISABLE_STREAMING_FALLBACK", "").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
class ChunkedTranscriber:
|
||||
"""Fallback transcriber for providers without streaming support."""
|
||||
|
||||
def __init__(self, provider: Any, audio_format: str = "webm"):
|
||||
self.provider = provider
|
||||
self.audio_format = audio_format
|
||||
self.chunk_buffer = io.BytesIO()
|
||||
self.full_audio = io.BytesIO()
|
||||
self.chunk_bytes = 0
|
||||
self.transcripts: list[str] = []
|
||||
|
||||
async def add_chunk(self, chunk: bytes) -> str | None:
|
||||
"""Add audio chunk. Returns transcript if enough audio accumulated."""
|
||||
self.chunk_buffer.write(chunk)
|
||||
self.full_audio.write(chunk)
|
||||
self.chunk_bytes += len(chunk)
|
||||
|
||||
if self.chunk_bytes >= MIN_CHUNK_BYTES:
|
||||
return await self._transcribe_chunk()
|
||||
return None
|
||||
|
||||
async def _transcribe_chunk(self) -> str | None:
|
||||
"""Transcribe current chunk and append to running transcript."""
|
||||
audio_data = self.chunk_buffer.getvalue()
|
||||
if not audio_data:
|
||||
return None
|
||||
|
||||
try:
|
||||
transcript = await self.provider.transcribe(audio_data, self.audio_format)
|
||||
self.chunk_buffer = io.BytesIO()
|
||||
self.chunk_bytes = 0
|
||||
|
||||
if transcript and transcript.strip():
|
||||
self.transcripts.append(transcript.strip())
|
||||
return " ".join(self.transcripts)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription error: {e}")
|
||||
self.chunk_buffer = io.BytesIO()
|
||||
self.chunk_bytes = 0
|
||||
return None
|
||||
|
||||
async def flush(self) -> str:
|
||||
"""Get final transcript from full audio for best accuracy."""
|
||||
full_audio_data = self.full_audio.getvalue()
|
||||
if full_audio_data:
|
||||
try:
|
||||
transcript = await self.provider.transcribe(
|
||||
full_audio_data, self.audio_format
|
||||
)
|
||||
if transcript and transcript.strip():
|
||||
return transcript.strip()
|
||||
except Exception as e:
|
||||
logger.error(f"Final transcription error: {e}")
|
||||
return " ".join(self.transcripts)
|
||||
|
||||
|
||||
async def handle_streaming_transcription(
|
||||
websocket: WebSocket,
|
||||
transcriber: StreamingTranscriberProtocol,
|
||||
) -> None:
|
||||
"""Handle transcription using native streaming API."""
|
||||
logger.info("Streaming transcription: starting handler")
|
||||
last_transcript = ""
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
|
||||
async def receive_transcripts() -> None:
|
||||
"""Background task to receive and send transcripts."""
|
||||
nonlocal last_transcript
|
||||
logger.info("Streaming transcription: starting transcript receiver")
|
||||
while True:
|
||||
result: TranscriptResult | None = await transcriber.receive_transcript()
|
||||
if result is None: # End of stream
|
||||
logger.info("Streaming transcription: transcript stream ended")
|
||||
break
|
||||
# Send if text changed OR if VAD detected end of speech (for auto-send trigger)
|
||||
if result.text and (result.text != last_transcript or result.is_vad_end):
|
||||
last_transcript = result.text
|
||||
logger.debug(
|
||||
f"Streaming transcription: got transcript: {result.text[:50]}... "
|
||||
f"(is_vad_end={result.is_vad_end})"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "transcript",
|
||||
"text": result.text,
|
||||
"is_final": result.is_vad_end,
|
||||
}
|
||||
)
|
||||
|
||||
# Start receiving transcripts in background
|
||||
receive_task = asyncio.create_task(receive_transcripts())
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive()
|
||||
msg_type = message.get("type", "unknown")
|
||||
|
||||
if msg_type == "websocket.disconnect":
|
||||
logger.info(
|
||||
f"Streaming transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
|
||||
)
|
||||
break
|
||||
|
||||
if "bytes" in message:
|
||||
chunk_size = len(message["bytes"])
|
||||
chunk_count += 1
|
||||
total_bytes += chunk_size
|
||||
logger.debug(
|
||||
f"Streaming transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
|
||||
)
|
||||
await transcriber.send_audio(message["bytes"])
|
||||
|
||||
elif "text" in message:
|
||||
try:
|
||||
data = json.loads(message["text"])
|
||||
logger.debug(
|
||||
f"Streaming transcription: received text message: {data}"
|
||||
)
|
||||
if data.get("type") == "end":
|
||||
logger.info(
|
||||
"Streaming transcription: end signal received, closing transcriber"
|
||||
)
|
||||
final_transcript = await transcriber.close()
|
||||
receive_task.cancel()
|
||||
logger.info(
|
||||
"Streaming transcription: final transcript: "
|
||||
f"{final_transcript[:100] if final_transcript else '(empty)'}..."
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "transcript",
|
||||
"text": final_transcript,
|
||||
"is_final": True,
|
||||
}
|
||||
)
|
||||
break
|
||||
elif data.get("type") == "reset":
|
||||
# Reset accumulated transcript after auto-send
|
||||
logger.info(
|
||||
"Streaming transcription: reset signal received, clearing transcript"
|
||||
)
|
||||
transcriber.reset_transcript()
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Streaming transcription: failed to parse JSON: {message.get('text', '')[:100]}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming transcription: error: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
receive_task.cancel()
|
||||
try:
|
||||
await receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info(
|
||||
f"Streaming transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
|
||||
)
|
||||
|
||||
|
||||
async def handle_chunked_transcription(
|
||||
websocket: WebSocket,
|
||||
transcriber: ChunkedTranscriber,
|
||||
) -> None:
|
||||
"""Handle transcription using chunked batch API."""
|
||||
logger.info("Chunked transcription: starting handler")
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
|
||||
while True:
|
||||
message = await websocket.receive()
|
||||
msg_type = message.get("type", "unknown")
|
||||
|
||||
if msg_type == "websocket.disconnect":
|
||||
logger.info(
|
||||
f"Chunked transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
|
||||
)
|
||||
break
|
||||
|
||||
if "bytes" in message:
|
||||
chunk_size = len(message["bytes"])
|
||||
chunk_count += 1
|
||||
total_bytes += chunk_size
|
||||
logger.debug(
|
||||
f"Chunked transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
|
||||
)
|
||||
|
||||
transcript = await transcriber.add_chunk(message["bytes"])
|
||||
if transcript:
|
||||
logger.debug(
|
||||
f"Chunked transcription: got transcript: {transcript[:50]}..."
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "transcript",
|
||||
"text": transcript,
|
||||
"is_final": False,
|
||||
}
|
||||
)
|
||||
|
||||
elif "text" in message:
|
||||
try:
|
||||
data = json.loads(message["text"])
|
||||
logger.debug(f"Chunked transcription: received text message: {data}")
|
||||
if data.get("type") == "end":
|
||||
logger.info("Chunked transcription: end signal received, flushing")
|
||||
final_transcript = await transcriber.flush()
|
||||
logger.info(
|
||||
f"Chunked transcription: final transcript: {final_transcript[:100] if final_transcript else '(empty)'}..."
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "transcript",
|
||||
"text": final_transcript,
|
||||
"is_final": True,
|
||||
}
|
||||
)
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Chunked transcription: failed to parse JSON: {message.get('text', '')[:100]}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Chunked transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
|
||||
)
|
||||
|
||||
|
||||
@router.websocket("/transcribe/stream")
|
||||
async def websocket_transcribe(
|
||||
websocket: WebSocket,
|
||||
_user: User = Depends(current_user_from_websocket),
|
||||
) -> None:
|
||||
"""
|
||||
WebSocket endpoint for streaming speech-to-text.
|
||||
|
||||
Protocol:
|
||||
- Client sends binary audio chunks
|
||||
- Server sends JSON: {"type": "transcript", "text": "...", "is_final": false}
|
||||
- Client sends JSON {"type": "end"} to signal end
|
||||
- Server responds with final transcript and closes
|
||||
|
||||
Authentication:
|
||||
Requires `token` query parameter (e.g., /voice/transcribe/stream?token=xxx).
|
||||
Applies same auth checks as HTTP endpoints (verification, role checks).
|
||||
"""
|
||||
logger.info("WebSocket transcribe: connection request received (authenticated)")
|
||||
|
||||
try:
|
||||
await websocket.accept()
|
||||
logger.info("WebSocket transcribe: connection accepted")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket transcribe: failed to accept connection: {e}")
|
||||
return
|
||||
|
||||
streaming_transcriber = None
|
||||
provider = None
|
||||
|
||||
try:
|
||||
# Get STT provider
|
||||
logger.info("WebSocket transcribe: fetching STT provider from database")
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
provider_db = fetch_default_stt_provider(db_session)
|
||||
if provider_db is None:
|
||||
logger.warning(
|
||||
"WebSocket transcribe: no default STT provider configured"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "No speech-to-text provider configured",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
if not provider_db.api_key:
|
||||
logger.warning("WebSocket transcribe: STT provider has no API key")
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Speech-to-text provider has no API key configured",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"WebSocket transcribe: creating voice provider: {provider_db.provider_type}"
|
||||
)
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
logger.info(
|
||||
f"WebSocket transcribe: voice provider created, streaming supported: {provider.supports_streaming_stt()}"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"WebSocket transcribe: failed to create voice provider: {e}"
|
||||
)
|
||||
await websocket.send_json({"type": "error", "message": str(e)})
|
||||
return
|
||||
|
||||
# Use native streaming if provider supports it
|
||||
if provider.supports_streaming_stt():
|
||||
logger.info("WebSocket transcribe: using native streaming STT")
|
||||
try:
|
||||
streaming_transcriber = await provider.create_streaming_transcriber()
|
||||
logger.info(
|
||||
"WebSocket transcribe: streaming transcriber created successfully"
|
||||
)
|
||||
await handle_streaming_transcription(websocket, streaming_transcriber)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"WebSocket transcribe: failed to create streaming transcriber: {e}"
|
||||
)
|
||||
if VOICE_DISABLE_STREAMING_FALLBACK:
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": f"Streaming STT failed: {e}"}
|
||||
)
|
||||
return
|
||||
logger.info("WebSocket transcribe: falling back to chunked STT")
|
||||
# Browser stream provides raw PCM16 chunks over WebSocket.
|
||||
chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
|
||||
await handle_chunked_transcription(websocket, chunked_transcriber)
|
||||
else:
|
||||
# Fall back to chunked transcription
|
||||
if VOICE_DISABLE_STREAMING_FALLBACK:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Provider doesn't support streaming STT",
|
||||
}
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"WebSocket transcribe: using chunked STT (provider doesn't support streaming)"
|
||||
)
|
||||
chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
|
||||
await handle_chunked_transcription(websocket, chunked_transcriber)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.debug("WebSocket transcribe: client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket transcribe: unhandled error: {e}", exc_info=True)
|
||||
try:
|
||||
# Send generic error to avoid leaking sensitive details
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "An unexpected error occurred"}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
if streaming_transcriber:
|
||||
try:
|
||||
await streaming_transcriber.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("WebSocket transcribe: connection closed")
|
||||
|
||||
|
||||
async def handle_streaming_synthesis(
|
||||
websocket: WebSocket,
|
||||
synthesizer: StreamingSynthesizerProtocol,
|
||||
) -> None:
|
||||
"""Handle TTS using native streaming API."""
|
||||
logger.info("Streaming synthesis: starting handler")
|
||||
|
||||
async def send_audio() -> None:
|
||||
"""Background task to send audio chunks to client."""
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
try:
|
||||
while True:
|
||||
audio_chunk = await synthesizer.receive_audio()
|
||||
if audio_chunk is None:
|
||||
logger.info(
|
||||
f"Streaming synthesis: audio stream ended, sent {chunk_count} chunks, {total_bytes} bytes"
|
||||
)
|
||||
try:
|
||||
await websocket.send_json({"type": "audio_done"})
|
||||
logger.info("Streaming synthesis: sent audio_done to client")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Streaming synthesis: failed to send audio_done: {e}"
|
||||
)
|
||||
break
|
||||
if audio_chunk: # Skip empty chunks
|
||||
chunk_count += 1
|
||||
total_bytes += len(audio_chunk)
|
||||
try:
|
||||
await websocket.send_bytes(audio_chunk)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Streaming synthesis: failed to send chunk: {e}"
|
||||
)
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
logger.info(
|
||||
f"Streaming synthesis: send_audio cancelled after {chunk_count} chunks"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming synthesis: send_audio error: {e}")
|
||||
|
||||
send_task: asyncio.Task | None = None
|
||||
disconnected = False
|
||||
|
||||
try:
|
||||
while not disconnected:
|
||||
try:
|
||||
message = await websocket.receive()
|
||||
except WebSocketDisconnect:
|
||||
logger.info("Streaming synthesis: client disconnected")
|
||||
break
|
||||
|
||||
msg_type = message.get("type", "unknown") # type: ignore[possibly-undefined]
|
||||
|
||||
if msg_type == "websocket.disconnect":
|
||||
logger.info("Streaming synthesis: client disconnected")
|
||||
disconnected = True
|
||||
break
|
||||
|
||||
if "text" in message:
|
||||
try:
|
||||
data = json.loads(message["text"])
|
||||
|
||||
if data.get("type") == "synthesize":
|
||||
text = data.get("text", "")
|
||||
if not text:
|
||||
for key, value in data.items():
|
||||
if key != "type" and isinstance(value, str) and value:
|
||||
text = value
|
||||
break
|
||||
if text:
|
||||
# Start audio receiver on first text chunk so playback
|
||||
# can begin before the full assistant response completes.
|
||||
if send_task is None:
|
||||
send_task = asyncio.create_task(send_audio())
|
||||
logger.debug(
|
||||
f"Streaming synthesis: forwarding text chunk ({len(text)} chars)"
|
||||
)
|
||||
await synthesizer.send_text(text)
|
||||
|
||||
elif data.get("type") == "end":
|
||||
logger.info("Streaming synthesis: end signal received")
|
||||
|
||||
# Ensure receiver is active even if no prior text chunks arrived.
|
||||
if send_task is None:
|
||||
send_task = asyncio.create_task(send_audio())
|
||||
|
||||
# Signal end of input
|
||||
if hasattr(synthesizer, "flush"):
|
||||
await synthesizer.flush()
|
||||
|
||||
# Wait for all audio to be sent
|
||||
logger.info(
|
||||
"Streaming synthesis: waiting for audio stream to complete"
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(send_task, timeout=60.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"Streaming synthesis: timeout waiting for audio"
|
||||
)
|
||||
break
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Streaming synthesis: failed to parse JSON: {message.get('text', '')[:100]}"
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.debug("Streaming synthesis: client disconnected during synthesis")
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming synthesis: error: {e}", exc_info=True)
|
||||
finally:
|
||||
if send_task and not send_task.done():
|
||||
logger.info("Streaming synthesis: waiting for send_task to finish")
|
||||
try:
|
||||
await asyncio.wait_for(send_task, timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Streaming synthesis: timeout waiting for send_task")
|
||||
send_task.cancel()
|
||||
try:
|
||||
await send_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("Streaming synthesis: handler finished")
|
||||
|
||||
|
||||
async def handle_chunked_synthesis(
|
||||
websocket: WebSocket,
|
||||
provider: Any,
|
||||
first_message: MutableMapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Fallback TTS handler using provider.synthesize_stream.
|
||||
|
||||
Args:
|
||||
websocket: The WebSocket connection
|
||||
provider: Voice provider instance
|
||||
first_message: Optional first message already received (used when falling
|
||||
back from streaming mode, where the first message was already consumed)
|
||||
"""
|
||||
logger.info("Chunked synthesis: starting handler")
|
||||
text_buffer: list[str] = []
|
||||
voice: str | None = None
|
||||
speed = 1.0
|
||||
|
||||
# Process pre-received message if provided
|
||||
pending_message = first_message
|
||||
|
||||
try:
|
||||
while True:
|
||||
if pending_message is not None:
|
||||
message = pending_message
|
||||
pending_message = None
|
||||
else:
|
||||
message = await websocket.receive()
|
||||
msg_type = message.get("type", "unknown")
|
||||
|
||||
if msg_type == "websocket.disconnect":
|
||||
logger.info("Chunked synthesis: client disconnected")
|
||||
break
|
||||
|
||||
if "text" not in message:
|
||||
continue
|
||||
|
||||
try:
|
||||
data = json.loads(message["text"])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Chunked synthesis: failed to parse JSON: "
|
||||
f"{message.get('text', '')[:100]}"
|
||||
)
|
||||
continue
|
||||
|
||||
msg_data_type = data.get("type") # type: ignore[possibly-undefined]
|
||||
if msg_data_type == "synthesize":
|
||||
text = data.get("text", "")
|
||||
if not text:
|
||||
for key, value in data.items():
|
||||
if key != "type" and isinstance(value, str) and value:
|
||||
text = value
|
||||
break
|
||||
if text:
|
||||
text_buffer.append(text)
|
||||
logger.debug(
|
||||
f"Chunked synthesis: buffered text ({len(text)} chars), "
|
||||
f"total buffered: {len(text_buffer)} chunks"
|
||||
)
|
||||
if isinstance(data.get("voice"), str) and data["voice"]:
|
||||
voice = data["voice"]
|
||||
if isinstance(data.get("speed"), (int, float)):
|
||||
speed = float(data["speed"])
|
||||
elif msg_data_type == "end":
|
||||
logger.info("Chunked synthesis: end signal received")
|
||||
full_text = " ".join(text_buffer).strip()
|
||||
if not full_text:
|
||||
await websocket.send_json({"type": "audio_done"})
|
||||
logger.info("Chunked synthesis: no text, sent audio_done")
|
||||
break
|
||||
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
logger.info(
|
||||
f"Chunked synthesis: sending full text ({len(full_text)} chars)"
|
||||
)
|
||||
async for audio_chunk in provider.synthesize_stream(
|
||||
full_text, voice=voice, speed=speed
|
||||
):
|
||||
if not audio_chunk:
|
||||
continue
|
||||
chunk_count += 1
|
||||
total_bytes += len(audio_chunk)
|
||||
await websocket.send_bytes(audio_chunk)
|
||||
await websocket.send_json({"type": "audio_done"})
|
||||
logger.info(
|
||||
f"Chunked synthesis: sent audio_done after {chunk_count} chunks, {total_bytes} bytes"
|
||||
)
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
logger.debug("Chunked synthesis: client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"Chunked synthesis: error: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
logger.info("Chunked synthesis: handler finished")
|
||||
|
||||
|
||||
@router.websocket("/synthesize/stream")
|
||||
async def websocket_synthesize(
|
||||
websocket: WebSocket,
|
||||
_user: User = Depends(current_user_from_websocket),
|
||||
) -> None:
|
||||
"""
|
||||
WebSocket endpoint for streaming text-to-speech.
|
||||
|
||||
Protocol:
|
||||
- Client sends JSON: {"type": "synthesize", "text": "...", "voice": "...", "speed": 1.0}
|
||||
- Server sends binary audio chunks
|
||||
- Server sends JSON: {"type": "audio_done"} when synthesis completes
|
||||
- Client sends JSON {"type": "end"} to close connection
|
||||
|
||||
Authentication:
|
||||
Requires `token` query parameter (e.g., /voice/synthesize/stream?token=xxx).
|
||||
Applies same auth checks as HTTP endpoints (verification, role checks).
|
||||
"""
|
||||
logger.info("WebSocket synthesize: connection request received (authenticated)")
|
||||
|
||||
try:
|
||||
await websocket.accept()
|
||||
logger.info("WebSocket synthesize: connection accepted")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket synthesize: failed to accept connection: {e}")
|
||||
return
|
||||
|
||||
streaming_synthesizer: StreamingSynthesizerProtocol | None = None
|
||||
provider = None
|
||||
|
||||
try:
|
||||
# Get TTS provider
|
||||
logger.info("WebSocket synthesize: fetching TTS provider from database")
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
provider_db = fetch_default_tts_provider(db_session)
|
||||
if provider_db is None:
|
||||
logger.warning(
|
||||
"WebSocket synthesize: no default TTS provider configured"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "No text-to-speech provider configured",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
if not provider_db.api_key:
|
||||
logger.warning("WebSocket synthesize: TTS provider has no API key")
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Text-to-speech provider has no API key configured",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"WebSocket synthesize: creating voice provider: {provider_db.provider_type}"
|
||||
)
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
logger.info(
|
||||
f"WebSocket synthesize: voice provider created, streaming TTS supported: {provider.supports_streaming_tts()}"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"WebSocket synthesize: failed to create voice provider: {e}"
|
||||
)
|
||||
await websocket.send_json({"type": "error", "message": str(e)})
|
||||
return
|
||||
|
||||
# Use native streaming if provider supports it
|
||||
if provider.supports_streaming_tts():
|
||||
logger.info("WebSocket synthesize: using native streaming TTS")
|
||||
message = None # Initialize to avoid UnboundLocalError in except block
|
||||
try:
|
||||
# Wait for initial config message with voice/speed
|
||||
message = await websocket.receive()
|
||||
voice = None
|
||||
speed = 1.0
|
||||
if "text" in message:
|
||||
try:
|
||||
data = json.loads(message["text"])
|
||||
voice = data.get("voice")
|
||||
speed = data.get("speed", 1.0)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
streaming_synthesizer = await provider.create_streaming_synthesizer(
|
||||
voice=voice, speed=speed
|
||||
)
|
||||
logger.info(
|
||||
"WebSocket synthesize: streaming synthesizer created successfully"
|
||||
)
|
||||
await handle_streaming_synthesis(websocket, streaming_synthesizer)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"WebSocket synthesize: failed to create streaming synthesizer: {e}"
|
||||
)
|
||||
if VOICE_DISABLE_STREAMING_FALLBACK:
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": f"Streaming TTS failed: {e}"}
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"WebSocket synthesize: falling back to chunked TTS synthesis"
|
||||
)
|
||||
# Pass the first message so it's not lost in the fallback
|
||||
await handle_chunked_synthesis(
|
||||
websocket, provider, first_message=message
|
||||
)
|
||||
else:
|
||||
if VOICE_DISABLE_STREAMING_FALLBACK:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Provider doesn't support streaming TTS",
|
||||
}
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"WebSocket synthesize: using chunked TTS (provider doesn't support streaming)"
|
||||
)
|
||||
await handle_chunked_synthesis(websocket, provider)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.debug("WebSocket synthesize: client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket synthesize: unhandled error: {e}", exc_info=True)
|
||||
try:
|
||||
# Send generic error to avoid leaking sensitive details
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "An unexpected error occurred"}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
if streaming_synthesizer:
|
||||
try:
|
||||
await streaming_synthesizer.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("WebSocket synthesize: connection closed")
|
||||
@@ -1,6 +1,5 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
@@ -61,7 +60,6 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.db.user_file import get_file_id_by_user_file_id
|
||||
from onyx.file_processing.extract_file_text import docx_to_txt_filename
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
@@ -812,18 +810,6 @@ def fetch_chat_file(
|
||||
if not file_record:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
original_file_name = file_record.display_name
|
||||
if file_record.file_type.startswith(
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
):
|
||||
# Check if a converted text file exists for .docx files
|
||||
txt_file_name = docx_to_txt_filename(original_file_name)
|
||||
txt_file_id = os.path.join(os.path.dirname(file_id), txt_file_name)
|
||||
txt_file_record = file_store.read_file_record(txt_file_id)
|
||||
if txt_file_record:
|
||||
file_record = txt_file_record
|
||||
file_id = txt_file_id
|
||||
|
||||
media_type = file_record.file_type
|
||||
file_io = file_store.read_file(file_id, mode="b")
|
||||
|
||||
|
||||
@@ -60,9 +60,11 @@ class Settings(BaseModel):
|
||||
deep_research_enabled: bool | None = None
|
||||
search_ui_enabled: bool | None = None
|
||||
|
||||
# Enterprise features flag - set by license enforcement at runtime
|
||||
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status
|
||||
# When LICENSE_ENFORCEMENT_ENABLED=false, defaults to False
|
||||
# Whether EE features are unlocked for use.
|
||||
# Depends on license status: True when the user has a valid license
|
||||
# (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER), False when there's no license
|
||||
# or the license is expired (GATED_ACCESS).
|
||||
# This controls UI visibility of EE features (user groups, analytics, RBAC, etc.).
|
||||
ee_features_enabled: bool = False
|
||||
|
||||
temperature_override_enabled: bool | None = False
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
@@ -84,6 +86,19 @@ class CodeInterpreterClient:
|
||||
raise ValueError("CODE_INTERPRETER_BASE_URL not configured")
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.session = requests.Session()
|
||||
self._closed = False
|
||||
|
||||
def __enter__(self) -> CodeInterpreterClient:
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: object) -> None:
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
if self._closed:
|
||||
return
|
||||
self.session.close()
|
||||
self._closed = True
|
||||
|
||||
def _build_payload(
|
||||
self,
|
||||
@@ -177,8 +192,11 @@ class CodeInterpreterClient:
|
||||
yield from self._batch_as_stream(code, stdin, timeout_ms, files)
|
||||
return
|
||||
|
||||
response.raise_for_status()
|
||||
yield from self._parse_sse(response)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
yield from self._parse_sse(response)
|
||||
finally:
|
||||
response.close()
|
||||
|
||||
def _parse_sse(
|
||||
self, response: requests.Response
|
||||
|
||||
@@ -111,8 +111,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
if not server.server_enabled:
|
||||
return False
|
||||
|
||||
client = CodeInterpreterClient()
|
||||
return client.health(use_cache=True)
|
||||
with CodeInterpreterClient() as client:
|
||||
return client.health(use_cache=True)
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
@@ -176,196 +176,203 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
)
|
||||
)
|
||||
|
||||
# Create Code Interpreter client
|
||||
client = CodeInterpreterClient()
|
||||
# Create Code Interpreter client — context manager ensures
|
||||
# session.close() is called on every exit path.
|
||||
with CodeInterpreterClient() as client:
|
||||
# Stage chat files for execution
|
||||
files_to_stage: list[FileInput] = []
|
||||
for ind, chat_file in enumerate(chat_files):
|
||||
file_name = chat_file.filename or f"file_{ind}"
|
||||
try:
|
||||
# Upload to Code Interpreter
|
||||
ci_file_id = client.upload_file(chat_file.content, file_name)
|
||||
|
||||
# Stage for execution
|
||||
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
|
||||
|
||||
logger.info(f"Staged file for Python execution: {file_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to stage file {file_name}: {e}")
|
||||
|
||||
# Stage chat files for execution
|
||||
files_to_stage: list[FileInput] = []
|
||||
for ind, chat_file in enumerate(chat_files):
|
||||
file_name = chat_file.filename or f"file_{ind}"
|
||||
try:
|
||||
# Upload to Code Interpreter
|
||||
ci_file_id = client.upload_file(chat_file.content, file_name)
|
||||
logger.debug(f"Executing code: {code}")
|
||||
|
||||
# Stage for execution
|
||||
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
|
||||
# Execute code with streaming (falls back to batch if unavailable)
|
||||
stdout_parts: list[str] = []
|
||||
stderr_parts: list[str] = []
|
||||
result_event: StreamResultEvent | None = None
|
||||
|
||||
logger.info(f"Staged file for Python execution: {file_name}")
|
||||
for event in client.execute_streaming(
|
||||
code=code,
|
||||
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
|
||||
files=files_to_stage or None,
|
||||
):
|
||||
if isinstance(event, StreamOutputEvent):
|
||||
if event.stream == "stdout":
|
||||
stdout_parts.append(event.data)
|
||||
else:
|
||||
stderr_parts.append(event.data)
|
||||
# Emit incremental delta to frontend
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(
|
||||
stdout=(
|
||||
event.data if event.stream == "stdout" else ""
|
||||
),
|
||||
stderr=(
|
||||
event.data if event.stream == "stderr" else ""
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
elif isinstance(event, StreamResultEvent):
|
||||
result_event = event
|
||||
elif isinstance(event, StreamErrorEvent):
|
||||
raise RuntimeError(f"Code interpreter error: {event.message}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to stage file {file_name}: {e}")
|
||||
if result_event is None:
|
||||
raise RuntimeError(
|
||||
"Code interpreter stream ended without a result event"
|
||||
)
|
||||
|
||||
try:
|
||||
logger.debug(f"Executing code: {code}")
|
||||
full_stdout = "".join(stdout_parts)
|
||||
full_stderr = "".join(stderr_parts)
|
||||
|
||||
# Execute code with streaming (falls back to batch if unavailable)
|
||||
stdout_parts: list[str] = []
|
||||
stderr_parts: list[str] = []
|
||||
result_event: StreamResultEvent | None = None
|
||||
# Truncate output for LLM consumption
|
||||
truncated_stdout = _truncate_output(
|
||||
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
|
||||
)
|
||||
truncated_stderr = _truncate_output(
|
||||
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
|
||||
)
|
||||
|
||||
for event in client.execute_streaming(
|
||||
code=code,
|
||||
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
|
||||
files=files_to_stage or None,
|
||||
):
|
||||
if isinstance(event, StreamOutputEvent):
|
||||
if event.stream == "stdout":
|
||||
stdout_parts.append(event.data)
|
||||
else:
|
||||
stderr_parts.append(event.data)
|
||||
# Emit incremental delta to frontend
|
||||
# Handle generated files
|
||||
generated_files: list[PythonExecutionFile] = []
|
||||
generated_file_ids: list[str] = []
|
||||
file_ids_to_cleanup: list[str] = []
|
||||
file_store = get_default_file_store()
|
||||
|
||||
for workspace_file in result_event.files:
|
||||
if workspace_file.kind != "file" or not workspace_file.file_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Download file from Code Interpreter
|
||||
file_content = client.download_file(workspace_file.file_id)
|
||||
|
||||
# Determine MIME type from file extension
|
||||
filename = workspace_file.path.split("/")[-1]
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
# Default to binary if we can't determine the type
|
||||
mime_type = mime_type or "application/octet-stream"
|
||||
|
||||
# Save to Onyx file store
|
||||
onyx_file_id = file_store.save_file(
|
||||
content=BytesIO(file_content),
|
||||
display_name=filename,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=mime_type,
|
||||
)
|
||||
|
||||
generated_files.append(
|
||||
PythonExecutionFile(
|
||||
filename=filename,
|
||||
file_link=build_full_frontend_file_url(onyx_file_id),
|
||||
)
|
||||
)
|
||||
generated_file_ids.append(onyx_file_id)
|
||||
|
||||
# Mark for cleanup
|
||||
file_ids_to_cleanup.append(workspace_file.file_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to handle generated file "
|
||||
f"{workspace_file.path}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup Code Interpreter files (generated files)
|
||||
for ci_file_id in file_ids_to_cleanup:
|
||||
try:
|
||||
client.delete_file(ci_file_id)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter generated "
|
||||
f"file {ci_file_id}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup staged input files
|
||||
for file_mapping in files_to_stage:
|
||||
try:
|
||||
client.delete_file(file_mapping["file_id"])
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter staged "
|
||||
f"file {file_mapping['file_id']}: {e}"
|
||||
)
|
||||
|
||||
# Emit file_ids once files are processed
|
||||
if generated_file_ids:
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(
|
||||
stdout=event.data if event.stream == "stdout" else "",
|
||||
stderr=event.data if event.stream == "stderr" else "",
|
||||
),
|
||||
obj=PythonToolDelta(file_ids=generated_file_ids),
|
||||
)
|
||||
)
|
||||
elif isinstance(event, StreamResultEvent):
|
||||
result_event = event
|
||||
elif isinstance(event, StreamErrorEvent):
|
||||
raise RuntimeError(f"Code interpreter error: {event.message}")
|
||||
|
||||
if result_event is None:
|
||||
raise RuntimeError(
|
||||
"Code interpreter stream ended without a result event"
|
||||
# Build result
|
||||
result = LlmPythonExecutionResult(
|
||||
stdout=truncated_stdout,
|
||||
stderr=truncated_stderr,
|
||||
exit_code=result_event.exit_code,
|
||||
timed_out=result_event.timed_out,
|
||||
generated_files=generated_files,
|
||||
error=(None if result_event.exit_code == 0 else truncated_stderr),
|
||||
)
|
||||
|
||||
full_stdout = "".join(stdout_parts)
|
||||
full_stderr = "".join(stderr_parts)
|
||||
# Serialize result for LLM
|
||||
adapter = TypeAdapter(LlmPythonExecutionResult)
|
||||
llm_response = adapter.dump_json(result).decode()
|
||||
|
||||
# Truncate output for LLM consumption
|
||||
truncated_stdout = _truncate_output(
|
||||
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
|
||||
)
|
||||
truncated_stderr = _truncate_output(
|
||||
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
|
||||
)
|
||||
return ToolResponse(
|
||||
rich_response=PythonToolRichResponse(
|
||||
generated_files=generated_files,
|
||||
),
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
|
||||
# Handle generated files
|
||||
generated_files: list[PythonExecutionFile] = []
|
||||
generated_file_ids: list[str] = []
|
||||
file_ids_to_cleanup: list[str] = []
|
||||
file_store = get_default_file_store()
|
||||
except Exception as e:
|
||||
logger.error(f"Python execution failed: {e}")
|
||||
error_msg = str(e)
|
||||
|
||||
for workspace_file in result_event.files:
|
||||
if workspace_file.kind != "file" or not workspace_file.file_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Download file from Code Interpreter
|
||||
file_content = client.download_file(workspace_file.file_id)
|
||||
|
||||
# Determine MIME type from file extension
|
||||
filename = workspace_file.path.split("/")[-1]
|
||||
mime_type, _ = mimetypes.guess_type(filename)
|
||||
# Default to binary if we can't determine the type
|
||||
mime_type = mime_type or "application/octet-stream"
|
||||
|
||||
# Save to Onyx file store
|
||||
onyx_file_id = file_store.save_file(
|
||||
content=BytesIO(file_content),
|
||||
display_name=filename,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=mime_type,
|
||||
)
|
||||
|
||||
generated_files.append(
|
||||
PythonExecutionFile(
|
||||
filename=filename,
|
||||
file_link=build_full_frontend_file_url(onyx_file_id),
|
||||
)
|
||||
)
|
||||
generated_file_ids.append(onyx_file_id)
|
||||
|
||||
# Mark for cleanup
|
||||
file_ids_to_cleanup.append(workspace_file.file_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to handle generated file {workspace_file.path}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup Code Interpreter files (generated files)
|
||||
for ci_file_id in file_ids_to_cleanup:
|
||||
try:
|
||||
client.delete_file(ci_file_id)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter generated file {ci_file_id}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup staged input files
|
||||
for file_mapping in files_to_stage:
|
||||
try:
|
||||
client.delete_file(file_mapping["file_id"])
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
|
||||
)
|
||||
|
||||
# Emit file_ids once files are processed
|
||||
if generated_file_ids:
|
||||
# Emit error delta
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(file_ids=generated_file_ids),
|
||||
obj=PythonToolDelta(
|
||||
stdout="",
|
||||
stderr=error_msg,
|
||||
file_ids=[],
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Build result
|
||||
result = LlmPythonExecutionResult(
|
||||
stdout=truncated_stdout,
|
||||
stderr=truncated_stderr,
|
||||
exit_code=result_event.exit_code,
|
||||
timed_out=result_event.timed_out,
|
||||
generated_files=generated_files,
|
||||
error=None if result_event.exit_code == 0 else truncated_stderr,
|
||||
)
|
||||
|
||||
# Serialize result for LLM
|
||||
adapter = TypeAdapter(LlmPythonExecutionResult)
|
||||
llm_response = adapter.dump_json(result).decode()
|
||||
|
||||
return ToolResponse(
|
||||
rich_response=PythonToolRichResponse(
|
||||
generated_files=generated_files,
|
||||
),
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Python execution failed: {e}")
|
||||
error_msg = str(e)
|
||||
|
||||
# Emit error delta
|
||||
self.emitter.emit(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(
|
||||
stdout="",
|
||||
stderr=error_msg,
|
||||
file_ids=[],
|
||||
),
|
||||
# Return error result
|
||||
result = LlmPythonExecutionResult(
|
||||
stdout="",
|
||||
stderr=error_msg,
|
||||
exit_code=-1,
|
||||
timed_out=False,
|
||||
generated_files=[],
|
||||
error=error_msg,
|
||||
)
|
||||
)
|
||||
|
||||
# Return error result
|
||||
result = LlmPythonExecutionResult(
|
||||
stdout="",
|
||||
stderr=error_msg,
|
||||
exit_code=-1,
|
||||
timed_out=False,
|
||||
generated_files=[],
|
||||
error=error_msg,
|
||||
)
|
||||
adapter = TypeAdapter(LlmPythonExecutionResult)
|
||||
llm_response = adapter.dump_json(result).decode()
|
||||
|
||||
adapter = TypeAdapter(LlmPythonExecutionResult)
|
||||
llm_response = adapter.dump_json(result).decode()
|
||||
|
||||
return ToolResponse(
|
||||
rich_response=None,
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
return ToolResponse(
|
||||
rich_response=None,
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
|
||||
0
backend/onyx/voice/__init__.py
Normal file
0
backend/onyx/voice/__init__.py
Normal file
70
backend/onyx/voice/factory.py
Normal file
70
backend/onyx/voice/factory.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
|
||||
def get_voice_provider(provider: VoiceProvider) -> VoiceProviderInterface:
|
||||
"""
|
||||
Factory function to get the appropriate voice provider implementation.
|
||||
|
||||
Args:
|
||||
provider: VoiceProvider model instance (can be from DB or constructed temporarily)
|
||||
|
||||
Returns:
|
||||
VoiceProviderInterface implementation
|
||||
|
||||
Raises:
|
||||
ValueError: If provider_type is not supported
|
||||
"""
|
||||
provider_type = provider.provider_type.lower()
|
||||
|
||||
# Handle both SensitiveValue (from DB) and plain string (from temp model)
|
||||
if provider.api_key is None:
|
||||
api_key = None
|
||||
elif hasattr(provider.api_key, "get_value"):
|
||||
# SensitiveValue from database
|
||||
api_key = provider.api_key.get_value(apply_mask=False)
|
||||
else:
|
||||
# Plain string from temporary model
|
||||
api_key = provider.api_key # type: ignore[assignment]
|
||||
api_base = provider.api_base
|
||||
custom_config = provider.custom_config
|
||||
stt_model = provider.stt_model
|
||||
tts_model = provider.tts_model
|
||||
default_voice = provider.default_voice
|
||||
|
||||
if provider_type == "openai":
|
||||
from onyx.voice.providers.openai import OpenAIVoiceProvider
|
||||
|
||||
return OpenAIVoiceProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
stt_model=stt_model,
|
||||
tts_model=tts_model,
|
||||
default_voice=default_voice,
|
||||
)
|
||||
|
||||
elif provider_type == "azure":
|
||||
from onyx.voice.providers.azure import AzureVoiceProvider
|
||||
|
||||
return AzureVoiceProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
custom_config=custom_config or {},
|
||||
stt_model=stt_model,
|
||||
tts_model=tts_model,
|
||||
default_voice=default_voice,
|
||||
)
|
||||
|
||||
elif provider_type == "elevenlabs":
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsVoiceProvider
|
||||
|
||||
return ElevenLabsVoiceProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
stt_model=stt_model,
|
||||
tts_model=tts_model,
|
||||
default_voice=default_voice,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported voice provider type: {provider_type}")
|
||||
175
backend/onyx/voice/interface.py
Normal file
175
backend/onyx/voice/interface.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TranscriptResult(BaseModel):
|
||||
"""Result from streaming transcription."""
|
||||
|
||||
text: str
|
||||
"""The accumulated transcript text."""
|
||||
|
||||
is_vad_end: bool = False
|
||||
"""True if VAD detected end of speech (silence). Use for auto-send."""
|
||||
|
||||
|
||||
class StreamingTranscriberProtocol(Protocol):
|
||||
"""Protocol for streaming transcription sessions."""
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send an audio chunk for transcription."""
|
||||
...
|
||||
|
||||
async def receive_transcript(self) -> TranscriptResult | None:
|
||||
"""
|
||||
Receive next transcript update.
|
||||
|
||||
Returns:
|
||||
TranscriptResult with accumulated text and VAD status, or None when stream ends.
|
||||
"""
|
||||
...
|
||||
|
||||
async def close(self) -> str:
|
||||
"""Close the session and return final transcript."""
|
||||
...
|
||||
|
||||
def reset_transcript(self) -> None:
|
||||
"""Reset accumulated transcript. Call after auto-send to start fresh."""
|
||||
...
|
||||
|
||||
|
||||
class StreamingSynthesizerProtocol(Protocol):
|
||||
"""Protocol for streaming TTS sessions (real-time text-to-speech)."""
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish connection to TTS provider."""
|
||||
...
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Send text to be synthesized."""
|
||||
...
|
||||
|
||||
async def receive_audio(self) -> bytes | None:
|
||||
"""
|
||||
Receive next audio chunk.
|
||||
|
||||
Returns:
|
||||
Audio bytes, or None when stream ends.
|
||||
"""
|
||||
...
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Signal end of text input and wait for pending audio."""
|
||||
...
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the session."""
|
||||
...
|
||||
|
||||
|
||||
class VoiceProviderInterface(ABC):
|
||||
"""Abstract base class for voice providers (STT and TTS)."""
|
||||
|
||||
@abstractmethod
|
||||
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
|
||||
"""
|
||||
Convert audio to text (Speech-to-Text).
|
||||
|
||||
Args:
|
||||
audio_data: Raw audio bytes
|
||||
audio_format: Audio format (e.g., "webm", "wav", "mp3")
|
||||
|
||||
Returns:
|
||||
Transcribed text
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def synthesize_stream(
|
||||
self, text: str, voice: str | None = None, speed: float = 1.0
|
||||
) -> AsyncIterator[bytes]:
|
||||
"""
|
||||
Convert text to audio stream (Text-to-Speech).
|
||||
|
||||
Streams audio chunks progressively for lower latency playback.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech
|
||||
voice: Voice identifier (e.g., "alloy", "echo"), or None for default
|
||||
speed: Playback speed multiplier (0.25 to 4.0)
|
||||
|
||||
Yields:
|
||||
Audio data chunks
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_available_voices(self) -> list[dict[str, str]]:
|
||||
"""
|
||||
Get list of available voices for this provider.
|
||||
|
||||
Returns:
|
||||
List of voice dictionaries with 'id' and 'name' keys
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_available_stt_models(self) -> list[dict[str, str]]:
|
||||
"""
|
||||
Get list of available STT models for this provider.
|
||||
|
||||
Returns:
|
||||
List of model dictionaries with 'id' and 'name' keys
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_available_tts_models(self) -> list[dict[str, str]]:
|
||||
"""
|
||||
Get list of available TTS models for this provider.
|
||||
|
||||
Returns:
|
||||
List of model dictionaries with 'id' and 'name' keys
|
||||
"""
|
||||
|
||||
def supports_streaming_stt(self) -> bool:
|
||||
"""Returns True if this provider supports streaming STT."""
|
||||
return False
|
||||
|
||||
def supports_streaming_tts(self) -> bool:
|
||||
"""Returns True if this provider supports real-time streaming TTS."""
|
||||
return False
|
||||
|
||||
async def create_streaming_transcriber(
|
||||
self, audio_format: str = "webm"
|
||||
) -> StreamingTranscriberProtocol:
|
||||
"""
|
||||
Create a streaming transcription session.
|
||||
|
||||
Args:
|
||||
audio_format: Audio format being sent (e.g., "webm", "pcm16")
|
||||
|
||||
Returns:
|
||||
A streaming transcriber that can send audio chunks and receive transcripts
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If streaming STT is not supported
|
||||
"""
|
||||
raise NotImplementedError("Streaming STT not supported by this provider")
|
||||
|
||||
async def create_streaming_synthesizer(
|
||||
self, voice: str | None = None, speed: float = 1.0
|
||||
) -> "StreamingSynthesizerProtocol":
|
||||
"""
|
||||
Create a streaming TTS session for real-time audio synthesis.
|
||||
|
||||
Args:
|
||||
voice: Voice identifier
|
||||
speed: Playback speed multiplier
|
||||
|
||||
Returns:
|
||||
A streaming synthesizer that can send text and receive audio chunks
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If streaming TTS is not supported
|
||||
"""
|
||||
raise NotImplementedError("Streaming TTS not supported by this provider")
|
||||
0
backend/onyx/voice/providers/__init__.py
Normal file
0
backend/onyx/voice/providers/__init__.py
Normal file
493
backend/onyx/voice/providers/azure.py
Normal file
493
backend/onyx/voice/providers/azure.py
Normal file
@@ -0,0 +1,493 @@
|
||||
import asyncio
|
||||
import io
|
||||
import re
|
||||
import struct
|
||||
import wave
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from xml.sax.saxutils import escape
|
||||
from xml.sax.saxutils import quoteattr
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
# Common Azure Neural voices
|
||||
AZURE_VOICES = [
|
||||
{"id": "en-US-JennyNeural", "name": "Jenny (en-US, Female)"},
|
||||
{"id": "en-US-GuyNeural", "name": "Guy (en-US, Male)"},
|
||||
{"id": "en-US-AriaNeural", "name": "Aria (en-US, Female)"},
|
||||
{"id": "en-US-DavisNeural", "name": "Davis (en-US, Male)"},
|
||||
{"id": "en-US-AmberNeural", "name": "Amber (en-US, Female)"},
|
||||
{"id": "en-US-AnaNeural", "name": "Ana (en-US, Female)"},
|
||||
{"id": "en-US-BrandonNeural", "name": "Brandon (en-US, Male)"},
|
||||
{"id": "en-US-ChristopherNeural", "name": "Christopher (en-US, Male)"},
|
||||
{"id": "en-US-CoraNeural", "name": "Cora (en-US, Female)"},
|
||||
{"id": "en-GB-SoniaNeural", "name": "Sonia (en-GB, Female)"},
|
||||
{"id": "en-GB-RyanNeural", "name": "Ryan (en-GB, Male)"},
|
||||
]
|
||||
|
||||
|
||||
class AzureStreamingTranscriber(StreamingTranscriberProtocol):
|
||||
"""Streaming transcription using Azure Speech SDK."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
region: str,
|
||||
input_sample_rate: int = 24000,
|
||||
target_sample_rate: int = 16000,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.region = region
|
||||
self.input_sample_rate = input_sample_rate
|
||||
self.target_sample_rate = target_sample_rate
|
||||
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
|
||||
self._accumulated_transcript = ""
|
||||
self._recognizer: Any = None
|
||||
self._audio_stream: Any = None
|
||||
self._closed = False
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize Azure Speech recognizer with push stream."""
|
||||
try:
|
||||
import azure.cognitiveservices.speech as speechsdk # type: ignore[import-not-found]
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Azure Speech SDK is required for streaming STT. "
|
||||
"Install `azure-cognitiveservices-speech`."
|
||||
) from e
|
||||
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key,
|
||||
region=self.region,
|
||||
)
|
||||
|
||||
audio_format = speechsdk.audio.AudioStreamFormat(
|
||||
samples_per_second=16000,
|
||||
bits_per_sample=16,
|
||||
channels=1,
|
||||
)
|
||||
self._audio_stream = speechsdk.audio.PushAudioInputStream(audio_format)
|
||||
audio_config = speechsdk.audio.AudioConfig(stream=self._audio_stream)
|
||||
|
||||
self._recognizer = speechsdk.SpeechRecognizer(
|
||||
speech_config=speech_config,
|
||||
audio_config=audio_config,
|
||||
)
|
||||
|
||||
transcriber = self
|
||||
|
||||
def on_recognizing(evt: Any) -> None:
|
||||
if evt.result.text and transcriber._loop and not transcriber._closed:
|
||||
full_text = transcriber._accumulated_transcript
|
||||
if full_text:
|
||||
full_text += " " + evt.result.text
|
||||
else:
|
||||
full_text = evt.result.text
|
||||
transcriber._loop.call_soon_threadsafe(
|
||||
transcriber._transcript_queue.put_nowait,
|
||||
TranscriptResult(text=full_text, is_vad_end=False),
|
||||
)
|
||||
|
||||
def on_recognized(evt: Any) -> None:
|
||||
if evt.result.text and transcriber._loop and not transcriber._closed:
|
||||
if transcriber._accumulated_transcript:
|
||||
transcriber._accumulated_transcript += " " + evt.result.text
|
||||
else:
|
||||
transcriber._accumulated_transcript = evt.result.text
|
||||
transcriber._loop.call_soon_threadsafe(
|
||||
transcriber._transcript_queue.put_nowait,
|
||||
TranscriptResult(
|
||||
text=transcriber._accumulated_transcript, is_vad_end=True
|
||||
),
|
||||
)
|
||||
|
||||
self._recognizer.recognizing.connect(on_recognizing)
|
||||
self._recognizer.recognized.connect(on_recognized)
|
||||
self._recognizer.start_continuous_recognition_async()
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send audio chunk to Azure."""
|
||||
if self._audio_stream and not self._closed:
|
||||
self._audio_stream.write(self._resample_pcm16(chunk))
|
||||
|
||||
def _resample_pcm16(self, data: bytes) -> bytes:
|
||||
"""Resample PCM16 audio from input_sample_rate to target_sample_rate."""
|
||||
if self.input_sample_rate == self.target_sample_rate:
|
||||
return data
|
||||
|
||||
num_samples = len(data) // 2
|
||||
if num_samples == 0:
|
||||
return b""
|
||||
|
||||
samples = list(struct.unpack(f"<{num_samples}h", data))
|
||||
ratio = self.input_sample_rate / self.target_sample_rate
|
||||
new_length = int(num_samples / ratio)
|
||||
|
||||
resampled: list[int] = []
|
||||
for i in range(new_length):
|
||||
src_idx = i * ratio
|
||||
idx_floor = int(src_idx)
|
||||
idx_ceil = min(idx_floor + 1, num_samples - 1)
|
||||
frac = src_idx - idx_floor
|
||||
sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
|
||||
sample = max(-32768, min(32767, sample))
|
||||
resampled.append(sample)
|
||||
|
||||
return struct.pack(f"<{len(resampled)}h", *resampled)
|
||||
|
||||
async def receive_transcript(self) -> TranscriptResult | None:
|
||||
"""Receive next transcript."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return TranscriptResult(text="", is_vad_end=False)
|
||||
|
||||
async def close(self) -> str:
|
||||
"""Stop recognition and return final transcript."""
|
||||
self._closed = True
|
||||
if self._recognizer:
|
||||
self._recognizer.stop_continuous_recognition_async()
|
||||
if self._audio_stream:
|
||||
self._audio_stream.close()
|
||||
self._loop = None
|
||||
return self._accumulated_transcript
|
||||
|
||||
def reset_transcript(self) -> None:
|
||||
"""Reset accumulated transcript."""
|
||||
self._accumulated_transcript = ""
|
||||
|
||||
|
||||
class AzureStreamingSynthesizer(StreamingSynthesizerProtocol):
|
||||
"""Real-time streaming TTS using Azure Speech SDK."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
region: str,
|
||||
voice: str = "en-US-JennyNeural",
|
||||
speed: float = 1.0,
|
||||
):
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
self._logger = setup_logger()
|
||||
self.api_key = api_key
|
||||
self.region = region
|
||||
self.voice = voice
|
||||
self.speed = max(0.5, min(2.0, speed))
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._synthesizer: Any = None
|
||||
self._closed = False
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize Azure Speech synthesizer with push stream."""
|
||||
try:
|
||||
import azure.cognitiveservices.speech as speechsdk
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Azure Speech SDK is required for streaming TTS. "
|
||||
"Install `azure-cognitiveservices-speech`."
|
||||
) from e
|
||||
|
||||
self._logger.info("AzureStreamingSynthesizer: connecting")
|
||||
|
||||
# Store the event loop for thread-safe queue operations
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key,
|
||||
region=self.region,
|
||||
)
|
||||
speech_config.speech_synthesis_voice_name = self.voice
|
||||
# Use MP3 format for streaming - compatible with MediaSource Extensions
|
||||
speech_config.set_speech_synthesis_output_format(
|
||||
speechsdk.SpeechSynthesisOutputFormat.Audio16Khz64KBitRateMonoMp3
|
||||
)
|
||||
|
||||
# Create synthesizer with pull audio output stream
|
||||
self._synthesizer = speechsdk.SpeechSynthesizer(
|
||||
speech_config=speech_config,
|
||||
audio_config=None, # We'll manually handle audio
|
||||
)
|
||||
|
||||
# Connect to synthesis events
|
||||
self._synthesizer.synthesizing.connect(self._on_synthesizing)
|
||||
self._synthesizer.synthesis_completed.connect(self._on_completed)
|
||||
|
||||
self._logger.info("AzureStreamingSynthesizer: connected")
|
||||
|
||||
def _on_synthesizing(self, evt: Any) -> None:
|
||||
"""Called when audio chunk is available (runs in Azure SDK thread)."""
|
||||
if evt.result.audio_data and self._loop and not self._closed:
|
||||
# Thread-safe way to put item in async queue
|
||||
self._loop.call_soon_threadsafe(
|
||||
self._audio_queue.put_nowait, evt.result.audio_data
|
||||
)
|
||||
|
||||
def _on_completed(self, _evt: Any) -> None:
|
||||
"""Called when synthesis is complete (runs in Azure SDK thread)."""
|
||||
if self._loop and not self._closed:
|
||||
self._loop.call_soon_threadsafe(self._audio_queue.put_nowait, None)
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Send text to be synthesized using SSML for prosody control."""
|
||||
if self._synthesizer and not self._closed:
|
||||
# Build SSML with prosody for speed control
|
||||
rate = f"{int((self.speed - 1) * 100):+d}%"
|
||||
escaped_text = escape(text)
|
||||
ssml = f"""<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>
|
||||
<voice name={quoteattr(self.voice)}>
|
||||
<prosody rate='{rate}'>{escaped_text}</prosody>
|
||||
</voice>
|
||||
</speak>"""
|
||||
# Use speak_ssml_async for SSML support (includes speed/prosody)
|
||||
self._synthesizer.speak_ssml_async(ssml)
|
||||
|
||||
async def receive_audio(self) -> bytes | None:
|
||||
"""Receive next audio chunk."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return b"" # No audio yet, but not done
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Signal end of text input - wait for pending audio."""
|
||||
# Azure SDK handles flushing automatically
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the session."""
|
||||
self._closed = True
|
||||
if self._synthesizer:
|
||||
self._synthesizer.synthesis_completed.disconnect_all()
|
||||
self._synthesizer.synthesizing.disconnect_all()
|
||||
self._loop = None
|
||||
|
||||
|
||||
class AzureVoiceProvider(VoiceProviderInterface):
|
||||
"""Azure Speech Services voice provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None,
|
||||
api_base: str | None,
|
||||
custom_config: dict[str, Any],
|
||||
stt_model: str | None = None,
|
||||
tts_model: str | None = None,
|
||||
default_voice: str | None = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.custom_config = custom_config
|
||||
self.speech_region = (
|
||||
custom_config.get("speech_region")
|
||||
or self._extract_speech_region_from_uri(api_base)
|
||||
or ""
|
||||
)
|
||||
self.stt_model = stt_model
|
||||
self.tts_model = tts_model
|
||||
self.default_voice = default_voice or "en-US-JennyNeural"
|
||||
|
||||
@staticmethod
|
||||
def _extract_speech_region_from_uri(uri: str | None) -> str | None:
|
||||
"""Extract Azure speech region from endpoint URI.
|
||||
|
||||
Note: Custom domains (*.cognitiveservices.azure.com) contain the resource
|
||||
name, not the region. For custom domains, the region must be specified
|
||||
explicitly via custom_config["speech_region"].
|
||||
"""
|
||||
if not uri:
|
||||
return None
|
||||
# Accepted examples:
|
||||
# - https://eastus.tts.speech.microsoft.com/cognitiveservices/v1
|
||||
# - https://eastus.stt.speech.microsoft.com/speech/recognition/...
|
||||
# - https://westus.api.cognitive.microsoft.com/
|
||||
#
|
||||
# NOT supported (requires explicit speech_region config):
|
||||
# - https://<resource>.cognitiveservices.azure.com/ (resource name != region)
|
||||
patterns = [
|
||||
r"https?://([^.]+)\.(?:tts|stt)\.speech\.microsoft\.com",
|
||||
r"https?://([^.]+)\.api\.cognitive\.microsoft\.com",
|
||||
]
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, uri)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _pcm16_to_wav(pcm_data: bytes, sample_rate: int = 24000) -> bytes:
|
||||
"""Wrap raw PCM16 mono bytes into a WAV container."""
|
||||
buffer = io.BytesIO()
|
||||
with wave.open(buffer, "wb") as wav_file:
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setframerate(sample_rate)
|
||||
wav_file.writeframes(pcm_data)
|
||||
return buffer.getvalue()
|
||||
|
||||
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure API key required for STT")
|
||||
if not self.speech_region:
|
||||
raise ValueError("Azure speech region required for STT")
|
||||
|
||||
normalized_format = audio_format.lower()
|
||||
payload = audio_data
|
||||
content_type = f"audio/{normalized_format}"
|
||||
|
||||
# WebSocket chunked fallback sends raw PCM16 bytes.
|
||||
if normalized_format in {"pcm", "pcm16", "raw"}:
|
||||
payload = self._pcm16_to_wav(audio_data, sample_rate=24000)
|
||||
content_type = "audio/wav"
|
||||
elif normalized_format in {"wav", "wave"}:
|
||||
content_type = "audio/wav"
|
||||
elif normalized_format == "webm":
|
||||
content_type = "audio/webm; codecs=opus"
|
||||
|
||||
url = (
|
||||
f"https://{self.speech_region}.stt.speech.microsoft.com/"
|
||||
"speech/recognition/conversation/cognitiveservices/v1"
|
||||
)
|
||||
params = {"language": "en-US", "format": "detailed"}
|
||||
headers = {
|
||||
"Ocp-Apim-Subscription-Key": self.api_key,
|
||||
"Content-Type": content_type,
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url, params=params, headers=headers, data=payload
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(f"Azure STT failed: {error_text}")
|
||||
result = await response.json()
|
||||
|
||||
if result.get("RecognitionStatus") != "Success":
|
||||
return ""
|
||||
nbest = result.get("NBest") or []
|
||||
if nbest and isinstance(nbest, list):
|
||||
display = nbest[0].get("Display")
|
||||
if isinstance(display, str):
|
||||
return display
|
||||
display_text = result.get("DisplayText", "")
|
||||
return display_text if isinstance(display_text, str) else ""
|
||||
|
||||
async def synthesize_stream(
|
||||
self, text: str, voice: str | None = None, speed: float = 1.0
|
||||
) -> AsyncIterator[bytes]:
|
||||
"""
|
||||
Convert text to audio using Azure TTS with streaming.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech
|
||||
voice: Voice name (defaults to provider's default voice)
|
||||
speed: Playback speed multiplier (0.5 to 2.0)
|
||||
|
||||
Yields:
|
||||
Audio data chunks (mp3 format)
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure API key required for TTS")
|
||||
|
||||
if not self.speech_region:
|
||||
raise ValueError("Azure speech region required for TTS")
|
||||
|
||||
voice_name = voice or self.default_voice
|
||||
|
||||
# Clamp speed to valid range and convert to rate format
|
||||
speed = max(0.5, min(2.0, speed))
|
||||
rate = f"{int((speed - 1) * 100):+d}%" # e.g., 1.0 -> "+0%", 1.5 -> "+50%"
|
||||
|
||||
# Build SSML with escaped text and quoted attributes to prevent injection
|
||||
escaped_text = escape(text)
|
||||
ssml = f"""<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>
|
||||
<voice name={quoteattr(voice_name)}>
|
||||
<prosody rate='{rate}'>{escaped_text}</prosody>
|
||||
</voice>
|
||||
</speak>"""
|
||||
|
||||
url = f"https://{self.speech_region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
|
||||
headers = {
|
||||
"Ocp-Apim-Subscription-Key": self.api_key,
|
||||
"Content-Type": "application/ssml+xml",
|
||||
"X-Microsoft-OutputFormat": "audio-16khz-128kbitrate-mono-mp3",
|
||||
"User-Agent": "Onyx",
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, data=ssml) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(f"Azure TTS failed: {error_text}")
|
||||
|
||||
# Use 8192 byte chunks for smoother streaming
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
def get_available_voices(self) -> list[dict[str, str]]:
|
||||
"""Return common Azure Neural voices."""
|
||||
return AZURE_VOICES.copy()
|
||||
|
||||
def get_available_stt_models(self) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"id": "default", "name": "Azure Speech Recognition"},
|
||||
]
|
||||
|
||||
def get_available_tts_models(self) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"id": "neural", "name": "Neural TTS"},
|
||||
]
|
||||
|
||||
def supports_streaming_stt(self) -> bool:
|
||||
"""Azure supports streaming STT via Speech SDK."""
|
||||
return True
|
||||
|
||||
def supports_streaming_tts(self) -> bool:
|
||||
"""Azure supports real-time streaming TTS via Speech SDK."""
|
||||
return True
|
||||
|
||||
async def create_streaming_transcriber(
|
||||
self, _audio_format: str = "webm"
|
||||
) -> AzureStreamingTranscriber:
|
||||
"""Create a streaming transcription session."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming transcription")
|
||||
if not self.speech_region:
|
||||
raise ValueError("Speech region required for Azure streaming transcription")
|
||||
transcriber = AzureStreamingTranscriber(
|
||||
api_key=self.api_key,
|
||||
region=self.speech_region,
|
||||
input_sample_rate=24000,
|
||||
target_sample_rate=16000,
|
||||
)
|
||||
await transcriber.connect()
|
||||
return transcriber
|
||||
|
||||
async def create_streaming_synthesizer(
|
||||
self, voice: str | None = None, speed: float = 1.0
|
||||
) -> AzureStreamingSynthesizer:
|
||||
"""Create a streaming TTS session."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming TTS")
|
||||
if not self.speech_region:
|
||||
raise ValueError("Speech region required for Azure streaming TTS")
|
||||
synthesizer = AzureStreamingSynthesizer(
|
||||
api_key=self.api_key,
|
||||
region=self.speech_region,
|
||||
voice=voice or self.default_voice or "en-US-JennyNeural",
|
||||
speed=speed,
|
||||
)
|
||||
await synthesizer.connect()
|
||||
return synthesizer
|
||||
766
backend/onyx/voice/providers/elevenlabs.py
Normal file
766
backend/onyx/voice/providers/elevenlabs.py
Normal file
@@ -0,0 +1,766 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
# Common ElevenLabs voices
|
||||
ELEVENLABS_VOICES = [
|
||||
{"id": "21m00Tcm4TlvDq8ikWAM", "name": "Rachel"},
|
||||
{"id": "AZnzlk1XvdvUeBnXmlld", "name": "Domi"},
|
||||
{"id": "EXAVITQu4vr4xnSDxMaL", "name": "Bella"},
|
||||
{"id": "ErXwobaYiN019PkySvjV", "name": "Antoni"},
|
||||
{"id": "MF3mGyEYCl7XYWbV9V6O", "name": "Elli"},
|
||||
{"id": "TxGEqnHWrfWFTfGW9XjX", "name": "Josh"},
|
||||
{"id": "VR6AewLTigWG4xSOukaG", "name": "Arnold"},
|
||||
{"id": "pNInz6obpgDQGcFmaJgB", "name": "Adam"},
|
||||
{"id": "yoZ06aMxZJJ28mfd3POQ", "name": "Sam"},
|
||||
]
|
||||
|
||||
|
||||
class ElevenLabsStreamingTranscriber(StreamingTranscriberProtocol):
|
||||
"""Streaming transcription session using ElevenLabs Scribe Realtime API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
model: str = "scribe_v2_realtime",
|
||||
input_sample_rate: int = 24000, # What frontend sends
|
||||
target_sample_rate: int = 16000, # What ElevenLabs expects
|
||||
language_code: str = "en",
|
||||
):
|
||||
# Import logger first
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
self._logger = setup_logger()
|
||||
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: initializing with model {model}"
|
||||
)
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.input_sample_rate = input_sample_rate
|
||||
self.target_sample_rate = target_sample_rate
|
||||
self.language_code = language_code
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
|
||||
self._final_transcript = ""
|
||||
self._receive_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish WebSocket connection to ElevenLabs."""
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingTranscriber: connecting to ElevenLabs API"
|
||||
)
|
||||
self._session = aiohttp.ClientSession()
|
||||
|
||||
# VAD is configured via query parameters
|
||||
# commit_strategy=vad enables automatic transcript commit on silence detection
|
||||
url = (
|
||||
f"wss://api.elevenlabs.io/v1/speech-to-text/realtime"
|
||||
f"?model_id={self.model}"
|
||||
f"&sample_rate={self.target_sample_rate}"
|
||||
f"&language_code={self.language_code}"
|
||||
f"&commit_strategy=vad"
|
||||
f"&vad_silence_threshold_secs=1.0"
|
||||
f"&vad_threshold=0.4"
|
||||
f"&min_speech_duration_ms=100"
|
||||
f"&min_silence_duration_ms=300"
|
||||
)
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: connecting to {url} "
|
||||
f"(input={self.input_sample_rate}Hz, target={self.target_sample_rate}Hz)"
|
||||
)
|
||||
|
||||
try:
|
||||
self._ws = await self._session.ws_connect(
|
||||
url,
|
||||
headers={"xi-api-key": self.api_key},
|
||||
)
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: connected successfully, "
|
||||
f"ws.closed={self._ws.closed}, close_code={self._ws.close_code}"
|
||||
)
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: failed to connect: {e}"
|
||||
)
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
raise
|
||||
|
||||
# Start receiving transcripts in background
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
"""Background task to receive transcripts from WebSocket."""
|
||||
self._logger.info("ElevenLabsStreamingTranscriber: receive loop started")
|
||||
if not self._ws:
|
||||
self._logger.warning(
|
||||
"ElevenLabsStreamingTranscriber: no WebSocket connection"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
self._logger.debug(
|
||||
f"ElevenLabsStreamingTranscriber: raw message type: {msg.type}"
|
||||
)
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
parsed_data: Any = None
|
||||
data: dict[str, Any]
|
||||
try:
|
||||
parsed_data = json.loads(msg.data)
|
||||
except json.JSONDecodeError:
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: failed to parse JSON: {msg.data[:200]}"
|
||||
)
|
||||
continue
|
||||
if not isinstance(parsed_data, dict):
|
||||
self._logger.error(
|
||||
"ElevenLabsStreamingTranscriber: expected object JSON payload"
|
||||
)
|
||||
continue
|
||||
data = parsed_data
|
||||
|
||||
# ElevenLabs uses message_type field - fail fast if missing
|
||||
if "message_type" not in data and "type" not in data:
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: malformed packet missing 'message_type' field: {data}"
|
||||
)
|
||||
continue
|
||||
msg_type = data.get("message_type", data.get("type", ""))
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: received message_type: '{msg_type}', data keys: {list(data.keys())}"
|
||||
)
|
||||
# Check for error in various formats
|
||||
if "error" in data or msg_type == "error":
|
||||
error_msg = data.get("error", data.get("message", data))
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: API error: {error_msg}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Handle different message types from ElevenLabs Scribe API
|
||||
if msg_type == "session_started":
|
||||
# Session started successfully
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: session started, "
|
||||
f"id={data.get('session_id')}, config={data.get('config')}"
|
||||
)
|
||||
elif msg_type == "partial_transcript":
|
||||
# Partial transcript (interim result)
|
||||
text = data.get("text", "")
|
||||
if text:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: partial_transcript: {text[:50]}..."
|
||||
)
|
||||
self._final_transcript = text
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(text=text, is_vad_end=False)
|
||||
)
|
||||
elif msg_type == "committed_transcript":
|
||||
# Final/committed transcript (VAD detected end of utterance)
|
||||
text = data.get("text", "")
|
||||
if text:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: committed_transcript: {text[:50]}..."
|
||||
)
|
||||
self._final_transcript = text
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(text=text, is_vad_end=True)
|
||||
)
|
||||
elif msg_type == "utterance_end":
|
||||
# VAD detected end of speech
|
||||
text = data.get("text", "") or self._final_transcript
|
||||
if text:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: utterance_end: {text[:50]}..."
|
||||
)
|
||||
self._final_transcript = text
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(text=text, is_vad_end=True)
|
||||
)
|
||||
elif msg_type == "session_ended":
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingTranscriber: session ended"
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Log unhandled message types with full data for debugging
|
||||
self._logger.warning(
|
||||
f"ElevenLabsStreamingTranscriber: unhandled message_type: {msg_type}, full data: {data}"
|
||||
)
|
||||
elif msg.type == aiohttp.WSMsgType.BINARY:
|
||||
self._logger.debug(
|
||||
f"ElevenLabsStreamingTranscriber: received binary message: {len(msg.data)} bytes"
|
||||
)
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
close_code = self._ws.close_code if self._ws else "N/A"
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingTranscriber: WebSocket closed by "
|
||||
f"server, close_code={close_code}"
|
||||
)
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: WebSocket error: {self._ws.exception() if self._ws else 'N/A'}"
|
||||
)
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSE:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: WebSocket CLOSE frame received, data={msg.data}, extra={msg.extra}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: error in receive loop: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
close_code = self._ws.close_code if self._ws else "N/A"
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: receive loop ended, close_code={close_code}"
|
||||
)
|
||||
await self._transcript_queue.put(None) # Signal end
|
||||
|
||||
def _resample_pcm16(self, data: bytes) -> bytes:
|
||||
"""Resample PCM16 audio from input_sample_rate to target_sample_rate."""
|
||||
import struct
|
||||
|
||||
if self.input_sample_rate == self.target_sample_rate:
|
||||
return data
|
||||
|
||||
# Parse int16 samples
|
||||
num_samples = len(data) // 2
|
||||
samples = list(struct.unpack(f"<{num_samples}h", data))
|
||||
|
||||
# Calculate resampling ratio
|
||||
ratio = self.input_sample_rate / self.target_sample_rate
|
||||
new_length = int(num_samples / ratio)
|
||||
|
||||
# Linear interpolation resampling
|
||||
resampled = []
|
||||
for i in range(new_length):
|
||||
src_idx = i * ratio
|
||||
idx_floor = int(src_idx)
|
||||
idx_ceil = min(idx_floor + 1, num_samples - 1)
|
||||
frac = src_idx - idx_floor
|
||||
sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
|
||||
# Clamp to int16 range
|
||||
sample = max(-32768, min(32767, sample))
|
||||
resampled.append(sample)
|
||||
|
||||
return struct.pack(f"<{len(resampled)}h", *resampled)
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send an audio chunk for transcription."""
|
||||
if not self._ws:
|
||||
self._logger.warning("send_audio: no WebSocket connection")
|
||||
return
|
||||
if self._closed:
|
||||
self._logger.warning("send_audio: transcriber is closed")
|
||||
return
|
||||
if self._ws.closed:
|
||||
self._logger.warning(
|
||||
f"send_audio: WebSocket is closed, close_code={self._ws.close_code}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Resample from input rate (24kHz) to target rate (16kHz)
|
||||
resampled = self._resample_pcm16(chunk)
|
||||
# ElevenLabs expects input_audio_chunk message format with audio_base_64
|
||||
audio_b64 = base64.b64encode(resampled).decode("utf-8")
|
||||
message = {
|
||||
"message_type": "input_audio_chunk",
|
||||
"audio_base_64": audio_b64,
|
||||
"sample_rate": self.target_sample_rate,
|
||||
}
|
||||
self._logger.info(
|
||||
f"send_audio: {len(chunk)} bytes -> {len(resampled)} bytes (resampled) -> {len(audio_b64)} chars base64"
|
||||
)
|
||||
await self._ws.send_str(json.dumps(message))
|
||||
self._logger.info("send_audio: message sent successfully")
|
||||
except Exception as e:
|
||||
self._logger.error(f"send_audio: failed to send: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def receive_transcript(self) -> TranscriptResult | None:
|
||||
"""Receive next transcript. Returns None when done."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return TranscriptResult(
|
||||
text="", is_vad_end=False
|
||||
) # No transcript yet, but not done
|
||||
|
||||
async def close(self) -> str:
|
||||
"""Close the session and return final transcript."""
|
||||
self._logger.info("ElevenLabsStreamingTranscriber: closing session")
|
||||
self._closed = True
|
||||
if self._ws and not self._ws.closed:
|
||||
try:
|
||||
# Just close the WebSocket - ElevenLabs Scribe doesn't need a special end message
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingTranscriber: closing WebSocket connection"
|
||||
)
|
||||
await self._ws.close()
|
||||
except Exception as e:
|
||||
self._logger.debug(f"Error closing WebSocket: {e}")
|
||||
if self._receive_task and not self._receive_task.done():
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
return self._final_transcript
|
||||
|
||||
def reset_transcript(self) -> None:
|
||||
"""Reset accumulated transcript. Call after auto-send to start fresh."""
|
||||
self._final_transcript = ""
|
||||
|
||||
|
||||
class ElevenLabsStreamingSynthesizer(StreamingSynthesizerProtocol):
|
||||
"""Real-time streaming TTS using ElevenLabs WebSocket API.
|
||||
|
||||
Uses ElevenLabs' stream-input WebSocket which processes text as one
|
||||
continuous stream and returns audio in order.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
model_id: str = "eleven_multilingual_v2",
|
||||
output_format: str = "mp3_44100_64",
|
||||
):
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
self._logger = setup_logger()
|
||||
self.api_key = api_key
|
||||
self.voice_id = voice_id
|
||||
self.model_id = model_id
|
||||
self.output_format = output_format
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._receive_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish WebSocket connection to ElevenLabs TTS."""
|
||||
self._logger.info("ElevenLabsStreamingSynthesizer: connecting")
|
||||
self._session = aiohttp.ClientSession()
|
||||
|
||||
# WebSocket URL for streaming input TTS with output format for streaming compatibility
|
||||
# Using mp3_44100_64 for good quality with smaller chunks for real-time playback
|
||||
url = (
|
||||
f"wss://api.elevenlabs.io/v1/text-to-speech/{self.voice_id}/stream-input"
|
||||
f"?model_id={self.model_id}&output_format={self.output_format}"
|
||||
)
|
||||
|
||||
self._ws = await self._session.ws_connect(
|
||||
url,
|
||||
headers={"xi-api-key": self.api_key},
|
||||
)
|
||||
|
||||
# Send initial configuration with generation settings optimized for streaming
|
||||
await self._ws.send_str(
|
||||
json.dumps(
|
||||
{
|
||||
"text": " ", # Initial space to start the stream
|
||||
"voice_settings": {
|
||||
"stability": 0.5,
|
||||
"similarity_boost": 0.75,
|
||||
},
|
||||
"generation_config": {
|
||||
"chunk_length_schedule": [
|
||||
120,
|
||||
160,
|
||||
250,
|
||||
290,
|
||||
], # Optimized chunk sizes for streaming
|
||||
},
|
||||
"xi_api_key": self.api_key,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Start receiving audio in background
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
self._logger.info("ElevenLabsStreamingSynthesizer: connected")
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
"""Background task to receive audio chunks from WebSocket.
|
||||
|
||||
Audio is returned in order as one continuous stream.
|
||||
"""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
if self._closed:
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingSynthesizer: closed flag set, stopping "
|
||||
"receive loop"
|
||||
)
|
||||
break
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
data = json.loads(msg.data)
|
||||
# Process audio if present
|
||||
if "audio" in data and data["audio"]:
|
||||
audio_bytes = base64.b64decode(data["audio"])
|
||||
chunk_count += 1
|
||||
total_bytes += len(audio_bytes)
|
||||
await self._audio_queue.put(audio_bytes)
|
||||
|
||||
# Check isFinal separately - a message can have both audio AND isFinal
|
||||
if "isFinal" in data:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingSynthesizer: received isFinal={data['isFinal']}, "
|
||||
f"chunks so far: {chunk_count}, bytes: {total_bytes}"
|
||||
)
|
||||
if data.get("isFinal"):
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingSynthesizer: isFinal=true, signaling end of audio"
|
||||
)
|
||||
await self._audio_queue.put(None)
|
||||
|
||||
# Check for errors
|
||||
if "error" in data or data.get("type") == "error":
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingSynthesizer: received error: {data}"
|
||||
)
|
||||
elif msg.type == aiohttp.WSMsgType.BINARY:
|
||||
chunk_count += 1
|
||||
total_bytes += len(msg.data)
|
||||
await self._audio_queue.put(msg.data)
|
||||
elif msg.type in (
|
||||
aiohttp.WSMsgType.CLOSE,
|
||||
aiohttp.WSMsgType.ERROR,
|
||||
):
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingSynthesizer: WebSocket closed/error, type={msg.type}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.error(f"ElevenLabsStreamingSynthesizer receive error: {e}")
|
||||
finally:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingSynthesizer: receive loop ended, {chunk_count} chunks, {total_bytes} bytes"
|
||||
)
|
||||
await self._audio_queue.put(None) # Signal end of stream
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Send text to be synthesized.
|
||||
|
||||
ElevenLabs processes text as a continuous stream and returns
|
||||
audio in order. We let ElevenLabs handle buffering via chunk_length_schedule
|
||||
and only force generation when flush() is called at the end.
|
||||
|
||||
Args:
|
||||
text: Text to synthesize
|
||||
"""
|
||||
if self._ws and not self._closed and text.strip():
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingSynthesizer: sending text ({len(text)} chars): '{text}'"
|
||||
)
|
||||
# Let ElevenLabs buffer and auto-generate based on chunk_length_schedule
|
||||
# Don't trigger generation here - wait for flush() at the end
|
||||
await self._ws.send_str(
|
||||
json.dumps(
|
||||
{
|
||||
"text": text + " ", # Space for natural speech flow
|
||||
}
|
||||
)
|
||||
)
|
||||
self._logger.info("ElevenLabsStreamingSynthesizer: text sent successfully")
|
||||
else:
|
||||
self._logger.warning(
|
||||
f"ElevenLabsStreamingSynthesizer: skipping send_text - "
|
||||
f"ws={self._ws is not None}, closed={self._closed}, text='{text[:30] if text else ''}'"
|
||||
)
|
||||
|
||||
async def receive_audio(self) -> bytes | None:
|
||||
"""Receive next audio chunk."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return b"" # No audio yet, but not done
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Signal end of text input. ElevenLabs will generate remaining audio and close."""
|
||||
if self._ws and not self._closed:
|
||||
# Send empty string to signal end of input
|
||||
# ElevenLabs will generate any remaining buffered text,
|
||||
# send all audio chunks, send isFinal, then close the connection
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingSynthesizer: sending end-of-input (empty string)"
|
||||
)
|
||||
await self._ws.send_str(json.dumps({"text": ""}))
|
||||
self._logger.info("ElevenLabsStreamingSynthesizer: end-of-input sent")
|
||||
else:
|
||||
self._logger.warning(
|
||||
f"ElevenLabsStreamingSynthesizer: skipping flush - "
|
||||
f"ws={self._ws is not None}, closed={self._closed}"
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the session."""
|
||||
self._closed = True
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
|
||||
|
||||
# Valid ElevenLabs model IDs
|
||||
ELEVENLABS_STT_MODELS = {"scribe_v1", "scribe_v2_realtime"}
|
||||
ELEVENLABS_TTS_MODELS = {
|
||||
"eleven_multilingual_v2",
|
||||
"eleven_turbo_v2_5",
|
||||
"eleven_monolingual_v1",
|
||||
"eleven_flash_v2_5",
|
||||
"eleven_flash_v2",
|
||||
}
|
||||
|
||||
|
||||
class ElevenLabsVoiceProvider(VoiceProviderInterface):
|
||||
"""ElevenLabs voice provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None,
|
||||
api_base: str | None = None,
|
||||
stt_model: str | None = None,
|
||||
tts_model: str | None = None,
|
||||
default_voice: str | None = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or "https://api.elevenlabs.io"
|
||||
# Validate and default models - use valid ElevenLabs model IDs
|
||||
self.stt_model = (
|
||||
stt_model if stt_model in ELEVENLABS_STT_MODELS else "scribe_v1"
|
||||
)
|
||||
self.tts_model = (
|
||||
tts_model
|
||||
if tts_model in ELEVENLABS_TTS_MODELS
|
||||
else "eleven_multilingual_v2"
|
||||
)
|
||||
self.default_voice = default_voice
|
||||
|
||||
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
|
||||
"""
|
||||
Transcribe audio using ElevenLabs Speech-to-Text API.
|
||||
|
||||
Args:
|
||||
audio_data: Raw audio bytes
|
||||
audio_format: Format of the audio (e.g., 'webm', 'mp3', 'wav')
|
||||
|
||||
Returns:
|
||||
Transcribed text
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise ValueError("ElevenLabs API key required for transcription")
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
url = f"{self.api_base}/v1/speech-to-text"
|
||||
|
||||
# Map common formats to MIME types
|
||||
mime_types = {
|
||||
"webm": "audio/webm",
|
||||
"mp3": "audio/mpeg",
|
||||
"wav": "audio/wav",
|
||||
"ogg": "audio/ogg",
|
||||
"flac": "audio/flac",
|
||||
"m4a": "audio/mp4",
|
||||
}
|
||||
mime_type = mime_types.get(audio_format.lower(), f"audio/{audio_format}")
|
||||
|
||||
headers = {
|
||||
"xi-api-key": self.api_key,
|
||||
}
|
||||
|
||||
# ElevenLabs expects multipart form data
|
||||
form_data = aiohttp.FormData()
|
||||
form_data.add_field(
|
||||
"audio",
|
||||
audio_data,
|
||||
filename=f"audio.{audio_format}",
|
||||
content_type=mime_type,
|
||||
)
|
||||
# For batch STT, use scribe_v1 (not the realtime model)
|
||||
batch_model = (
|
||||
self.stt_model if self.stt_model in ("scribe_v1",) else "scribe_v1"
|
||||
)
|
||||
form_data.add_field("model_id", batch_model)
|
||||
|
||||
logger.info(
|
||||
f"ElevenLabs transcribe: sending {len(audio_data)} bytes, format={audio_format}"
|
||||
)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, data=form_data) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"ElevenLabs transcribe failed: {error_text}")
|
||||
raise RuntimeError(f"ElevenLabs transcription failed: {error_text}")
|
||||
|
||||
result = await response.json()
|
||||
text = result.get("text", "")
|
||||
logger.info(f"ElevenLabs transcribe: got result: {text[:50]}...")
|
||||
return text
|
||||
|
||||
async def synthesize_stream(
|
||||
self, text: str, voice: str | None = None, _speed: float = 1.0
|
||||
) -> AsyncIterator[bytes]:
|
||||
"""
|
||||
Convert text to audio using ElevenLabs TTS with streaming.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech
|
||||
voice: Voice ID (defaults to provider's default voice or Rachel)
|
||||
speed: Playback speed multiplier (not directly supported, ignored)
|
||||
|
||||
Yields:
|
||||
Audio data chunks (mp3 format)
|
||||
"""
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("ElevenLabs API key required for TTS")
|
||||
|
||||
voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM" # Rachel
|
||||
|
||||
url = f"{self.api_base}/v1/text-to-speech/{voice_id}/stream"
|
||||
|
||||
logger.info(
|
||||
f"ElevenLabs TTS: starting synthesis, text='{text[:50]}...', "
|
||||
f"voice={voice_id}, model={self.tts_model}"
|
||||
)
|
||||
|
||||
headers = {
|
||||
"xi-api-key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "audio/mpeg",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"text": text,
|
||||
"model_id": self.tts_model,
|
||||
"voice_settings": {
|
||||
"stability": 0.5,
|
||||
"similarity_boost": 0.75,
|
||||
},
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, json=payload) as response:
|
||||
logger.info(
|
||||
f"ElevenLabs TTS: got response status={response.status}, "
|
||||
f"content-type={response.headers.get('content-type')}"
|
||||
)
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"ElevenLabs TTS failed: {error_text}")
|
||||
raise RuntimeError(f"ElevenLabs TTS failed: {error_text}")
|
||||
|
||||
# Use 8192 byte chunks for smoother streaming
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
if chunk:
|
||||
chunk_count += 1
|
||||
total_bytes += len(chunk)
|
||||
yield chunk
|
||||
logger.info(
|
||||
f"ElevenLabs TTS: streaming complete, {chunk_count} chunks, "
|
||||
f"{total_bytes} total bytes"
|
||||
)
|
||||
|
||||
def get_available_voices(self) -> list[dict[str, str]]:
|
||||
"""Return common ElevenLabs voices."""
|
||||
return ELEVENLABS_VOICES.copy()
|
||||
|
||||
def get_available_stt_models(self) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"id": "scribe_v2_realtime", "name": "Scribe v2 Realtime (Streaming)"},
|
||||
{"id": "scribe_v1", "name": "Scribe v1 (Batch)"},
|
||||
]
|
||||
|
||||
def get_available_tts_models(self) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"id": "eleven_multilingual_v2", "name": "Multilingual v2"},
|
||||
{"id": "eleven_turbo_v2_5", "name": "Turbo v2.5"},
|
||||
{"id": "eleven_monolingual_v1", "name": "Monolingual v1"},
|
||||
]
|
||||
|
||||
def supports_streaming_stt(self) -> bool:
|
||||
"""ElevenLabs supports streaming via Scribe Realtime API."""
|
||||
return True
|
||||
|
||||
def supports_streaming_tts(self) -> bool:
|
||||
"""ElevenLabs supports real-time streaming TTS via WebSocket."""
|
||||
return True
|
||||
|
||||
async def create_streaming_transcriber(
|
||||
self, _audio_format: str = "webm"
|
||||
) -> ElevenLabsStreamingTranscriber:
|
||||
"""Create a streaming transcription session."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming transcription")
|
||||
# ElevenLabs realtime STT requires scribe_v2_realtime model
|
||||
# Frontend sends PCM16 at 24kHz, but ElevenLabs expects 16kHz
|
||||
# The transcriber will resample automatically
|
||||
transcriber = ElevenLabsStreamingTranscriber(
|
||||
api_key=self.api_key,
|
||||
model="scribe_v2_realtime",
|
||||
input_sample_rate=24000, # What frontend sends
|
||||
target_sample_rate=16000, # What ElevenLabs expects
|
||||
language_code="en",
|
||||
)
|
||||
await transcriber.connect()
|
||||
return transcriber
|
||||
|
||||
async def create_streaming_synthesizer(
|
||||
self, voice: str | None = None, _speed: float = 1.0
|
||||
) -> ElevenLabsStreamingSynthesizer:
|
||||
"""Create a streaming TTS session."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming TTS")
|
||||
voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM"
|
||||
synthesizer = ElevenLabsStreamingSynthesizer(
|
||||
api_key=self.api_key,
|
||||
voice_id=voice_id,
|
||||
model_id=self.tts_model,
|
||||
# Use mp3_44100_64 for streaming - good balance of quality and chunk size
|
||||
output_format="mp3_44100_64",
|
||||
)
|
||||
await synthesizer.connect()
|
||||
return synthesizer
|
||||
567
backend/onyx/voice/providers/openai.py
Normal file
567
backend/onyx/voice/providers/openai.py
Normal file
@@ -0,0 +1,567 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
|
||||
class OpenAIStreamingTranscriber(StreamingTranscriberProtocol):
|
||||
"""Streaming transcription using OpenAI Realtime API."""
|
||||
|
||||
def __init__(self, api_key: str, model: str = "whisper-1"):
|
||||
# Import logger first
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
self._logger = setup_logger()
|
||||
|
||||
self._logger.info(
|
||||
f"OpenAIStreamingTranscriber: initializing with model {model}"
|
||||
)
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
|
||||
self._current_turn_transcript = "" # Transcript for current VAD turn
|
||||
self._accumulated_transcript = "" # Accumulated across all turns
|
||||
self._receive_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish WebSocket connection to OpenAI Realtime API."""
|
||||
self._session = aiohttp.ClientSession()
|
||||
|
||||
# OpenAI Realtime transcription endpoint
|
||||
url = "wss://api.openai.com/v1/realtime?intent=transcription"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"OpenAI-Beta": "realtime=v1",
|
||||
}
|
||||
|
||||
try:
|
||||
self._ws = await self._session.ws_connect(url, headers=headers)
|
||||
self._logger.info("Connected to OpenAI Realtime API")
|
||||
except Exception as e:
|
||||
self._logger.error(f"Failed to connect to OpenAI Realtime API: {e}")
|
||||
raise
|
||||
|
||||
# Configure the session for transcription
|
||||
# Enable server-side VAD (Voice Activity Detection) for automatic speech detection
|
||||
config_message = {
|
||||
"type": "transcription_session.update",
|
||||
"session": {
|
||||
"input_audio_format": "pcm16", # 16-bit PCM at 24kHz mono
|
||||
"input_audio_transcription": {
|
||||
"model": self.model,
|
||||
},
|
||||
"turn_detection": {
|
||||
"type": "server_vad",
|
||||
"threshold": 0.5,
|
||||
"prefix_padding_ms": 300,
|
||||
"silence_duration_ms": 500,
|
||||
},
|
||||
},
|
||||
}
|
||||
await self._ws.send_str(json.dumps(config_message))
|
||||
self._logger.info(f"Sent config for model: {self.model} with server VAD")
|
||||
|
||||
# Start receiving transcripts
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
"""Background task to receive transcripts."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
data = json.loads(msg.data)
|
||||
msg_type = data.get("type", "")
|
||||
self._logger.debug(f"Received message type: {msg_type}")
|
||||
|
||||
# Handle errors
|
||||
if msg_type == "error":
|
||||
error = data.get("error", {})
|
||||
self._logger.error(f"OpenAI error: {error}")
|
||||
continue
|
||||
|
||||
# Handle VAD events
|
||||
if msg_type == "input_audio_buffer.speech_started":
|
||||
self._logger.info("OpenAI: Speech started")
|
||||
# Reset current turn transcript for new speech
|
||||
self._current_turn_transcript = ""
|
||||
continue
|
||||
elif msg_type == "input_audio_buffer.speech_stopped":
|
||||
self._logger.info(
|
||||
"OpenAI: Speech stopped (VAD detected silence)"
|
||||
)
|
||||
continue
|
||||
elif msg_type == "input_audio_buffer.committed":
|
||||
self._logger.info("OpenAI: Audio buffer committed")
|
||||
continue
|
||||
|
||||
# Handle transcription events
|
||||
if msg_type == "conversation.item.input_audio_transcription.delta":
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
self._logger.info(f"OpenAI: Transcription delta: {delta}")
|
||||
self._current_turn_transcript += delta
|
||||
# Show accumulated + current turn transcript
|
||||
full_transcript = self._accumulated_transcript
|
||||
if full_transcript and self._current_turn_transcript:
|
||||
full_transcript += " "
|
||||
full_transcript += self._current_turn_transcript
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(text=full_transcript, is_vad_end=False)
|
||||
)
|
||||
elif (
|
||||
msg_type
|
||||
== "conversation.item.input_audio_transcription.completed"
|
||||
):
|
||||
transcript = data.get("transcript", "")
|
||||
if transcript:
|
||||
self._logger.info(
|
||||
f"OpenAI: Transcription completed (VAD turn end): {transcript[:50]}..."
|
||||
)
|
||||
# This is the final transcript for this VAD turn
|
||||
self._current_turn_transcript = transcript
|
||||
# Accumulate this turn's transcript
|
||||
if self._accumulated_transcript:
|
||||
self._accumulated_transcript += " " + transcript
|
||||
else:
|
||||
self._accumulated_transcript = transcript
|
||||
# Send with is_vad_end=True to trigger auto-send
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(
|
||||
text=self._accumulated_transcript,
|
||||
is_vad_end=True,
|
||||
)
|
||||
)
|
||||
elif msg_type not in (
|
||||
"transcription_session.created",
|
||||
"transcription_session.updated",
|
||||
"conversation.item.created",
|
||||
):
|
||||
# Log any other message types we might be missing
|
||||
self._logger.info(
|
||||
f"OpenAI: Unhandled message type '{msg_type}': {data}"
|
||||
)
|
||||
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
self._logger.error(f"WebSocket error: {self._ws.exception()}")
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
self._logger.info("WebSocket closed by server")
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error in receive loop: {e}")
|
||||
finally:
|
||||
await self._transcript_queue.put(None)
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send audio chunk to OpenAI."""
|
||||
if self._ws and not self._closed:
|
||||
# OpenAI expects base64-encoded PCM16 audio at 24kHz mono
|
||||
# PCM16 at 24kHz: 24000 samples/sec * 2 bytes/sample = 48000 bytes/sec
|
||||
# So chunk_bytes / 48000 = duration in seconds
|
||||
duration_ms = (len(chunk) / 48000) * 1000
|
||||
self._logger.debug(
|
||||
f"Sending {len(chunk)} bytes ({duration_ms:.1f}ms) of audio to OpenAI. "
|
||||
f"First 10 bytes: {chunk[:10].hex() if len(chunk) >= 10 else chunk.hex()}"
|
||||
)
|
||||
message = {
|
||||
"type": "input_audio_buffer.append",
|
||||
"audio": base64.b64encode(chunk).decode("utf-8"),
|
||||
}
|
||||
await self._ws.send_str(json.dumps(message))
|
||||
|
||||
def reset_transcript(self) -> None:
|
||||
"""Reset accumulated transcript. Call after auto-send to start fresh."""
|
||||
self._logger.info("OpenAI: Resetting accumulated transcript")
|
||||
self._accumulated_transcript = ""
|
||||
self._current_turn_transcript = ""
|
||||
|
||||
async def receive_transcript(self) -> TranscriptResult | None:
|
||||
"""Receive next transcript."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return TranscriptResult(text="", is_vad_end=False)
|
||||
|
||||
async def close(self) -> str:
|
||||
"""Close session and return final transcript."""
|
||||
self._closed = True
|
||||
if self._ws:
|
||||
# With server VAD, the buffer is auto-committed when speech stops.
|
||||
# But we should still commit any remaining audio and wait for transcription.
|
||||
try:
|
||||
await self._ws.send_str(
|
||||
json.dumps({"type": "input_audio_buffer.commit"})
|
||||
)
|
||||
except Exception as e:
|
||||
self._logger.debug(f"Error sending commit (may be expected): {e}")
|
||||
|
||||
# Wait for transcription to arrive (up to 5 seconds)
|
||||
self._logger.info("Waiting for transcription to complete...")
|
||||
for _ in range(50): # 50 * 100ms = 5 seconds max
|
||||
await asyncio.sleep(0.1)
|
||||
if self._accumulated_transcript:
|
||||
self._logger.info(
|
||||
f"Got final transcript: {self._accumulated_transcript[:50]}..."
|
||||
)
|
||||
break
|
||||
else:
|
||||
self._logger.warning("Timed out waiting for transcription")
|
||||
|
||||
await self._ws.close()
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
return self._accumulated_transcript
|
||||
|
||||
|
||||
# OpenAI available voices for TTS
|
||||
OPENAI_VOICES = [
|
||||
{"id": "alloy", "name": "Alloy"},
|
||||
{"id": "echo", "name": "Echo"},
|
||||
{"id": "fable", "name": "Fable"},
|
||||
{"id": "onyx", "name": "Onyx"},
|
||||
{"id": "nova", "name": "Nova"},
|
||||
{"id": "shimmer", "name": "Shimmer"},
|
||||
]
|
||||
|
||||
# OpenAI available STT models (all support streaming via Realtime API)
|
||||
OPENAI_STT_MODELS = [
|
||||
{"id": "whisper-1", "name": "Whisper v1"},
|
||||
{"id": "gpt-4o-transcribe", "name": "GPT-4o Transcribe"},
|
||||
{"id": "gpt-4o-mini-transcribe", "name": "GPT-4o Mini Transcribe"},
|
||||
]
|
||||
|
||||
# OpenAI available TTS models
|
||||
OPENAI_TTS_MODELS = [
|
||||
{"id": "tts-1", "name": "TTS-1 (Standard)"},
|
||||
{"id": "tts-1-hd", "name": "TTS-1 HD (High Quality)"},
|
||||
]
|
||||
|
||||
|
||||
def _create_wav_header(
|
||||
data_length: int,
|
||||
sample_rate: int = 24000,
|
||||
channels: int = 1,
|
||||
bits_per_sample: int = 16,
|
||||
) -> bytes:
|
||||
"""Create a WAV file header for PCM audio data."""
|
||||
import struct
|
||||
|
||||
byte_rate = sample_rate * channels * bits_per_sample // 8
|
||||
block_align = channels * bits_per_sample // 8
|
||||
|
||||
# WAV header is 44 bytes
|
||||
header = struct.pack(
|
||||
"<4sI4s4sIHHIIHH4sI",
|
||||
b"RIFF", # ChunkID
|
||||
36 + data_length, # ChunkSize
|
||||
b"WAVE", # Format
|
||||
b"fmt ", # Subchunk1ID
|
||||
16, # Subchunk1Size (PCM)
|
||||
1, # AudioFormat (1 = PCM)
|
||||
channels, # NumChannels
|
||||
sample_rate, # SampleRate
|
||||
byte_rate, # ByteRate
|
||||
block_align, # BlockAlign
|
||||
bits_per_sample, # BitsPerSample
|
||||
b"data", # Subchunk2ID
|
||||
data_length, # Subchunk2Size
|
||||
)
|
||||
return header
|
||||
|
||||
|
||||
class OpenAIStreamingSynthesizer(StreamingSynthesizerProtocol):
|
||||
"""Streaming TTS using OpenAI HTTP TTS API with streaming responses."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
voice: str = "alloy",
|
||||
model: str = "tts-1",
|
||||
speed: float = 1.0,
|
||||
):
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
self._logger = setup_logger()
|
||||
self.api_key = api_key
|
||||
self.voice = voice
|
||||
self.model = model
|
||||
self.speed = max(0.25, min(4.0, speed))
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._text_queue: asyncio.Queue[str | None] = asyncio.Queue()
|
||||
self._synthesis_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
self._flushed = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize HTTP session for TTS requests."""
|
||||
self._logger.info("OpenAIStreamingSynthesizer: connecting")
|
||||
self._session = aiohttp.ClientSession()
|
||||
# Start background task to process text queue
|
||||
self._synthesis_task = asyncio.create_task(self._process_text_queue())
|
||||
self._logger.info("OpenAIStreamingSynthesizer: connected")
|
||||
|
||||
async def _process_text_queue(self) -> None:
|
||||
"""Background task to process queued text for synthesis."""
|
||||
while not self._closed:
|
||||
try:
|
||||
text = await asyncio.wait_for(self._text_queue.get(), timeout=0.1)
|
||||
if text is None:
|
||||
break
|
||||
await self._synthesize_text(text)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error processing text queue: {e}")
|
||||
|
||||
async def _synthesize_text(self, text: str) -> None:
|
||||
"""Make HTTP TTS request and stream audio to queue."""
|
||||
if not self._session or self._closed:
|
||||
return
|
||||
|
||||
url = "https://api.openai.com/v1/audio/speech"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"voice": self.voice,
|
||||
"input": text,
|
||||
"speed": self.speed,
|
||||
"response_format": "mp3",
|
||||
}
|
||||
|
||||
try:
|
||||
async with self._session.post(
|
||||
url, headers=headers, json=payload
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
self._logger.error(f"OpenAI TTS error: {error_text}")
|
||||
return
|
||||
|
||||
# Use 8192 byte chunks for smoother streaming
|
||||
# (larger chunks = more complete MP3 frames, better playback)
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
if self._closed:
|
||||
break
|
||||
if chunk:
|
||||
await self._audio_queue.put(chunk)
|
||||
except Exception as e:
|
||||
self._logger.error(f"OpenAIStreamingSynthesizer synthesis error: {e}")
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Queue text to be synthesized via HTTP streaming."""
|
||||
if not text.strip() or self._closed:
|
||||
return
|
||||
await self._text_queue.put(text)
|
||||
|
||||
async def receive_audio(self) -> bytes | None:
|
||||
"""Receive next audio chunk (MP3 format)."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return b"" # No audio yet, but not done
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Signal end of text input - wait for synthesis to complete."""
|
||||
if self._flushed:
|
||||
return
|
||||
self._flushed = True
|
||||
|
||||
# Signal end of text input
|
||||
await self._text_queue.put(None)
|
||||
|
||||
# Wait for synthesis task to complete processing all text
|
||||
if self._synthesis_task and not self._synthesis_task.done():
|
||||
try:
|
||||
await asyncio.wait_for(self._synthesis_task, timeout=60.0)
|
||||
except asyncio.TimeoutError:
|
||||
self._logger.warning("OpenAIStreamingSynthesizer: flush timeout")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Signal end of audio stream
|
||||
await self._audio_queue.put(None)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the session."""
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
|
||||
# Signal end of queues only if flush wasn't already called
|
||||
if not self._flushed:
|
||||
await self._text_queue.put(None)
|
||||
await self._audio_queue.put(None)
|
||||
|
||||
if self._synthesis_task and not self._synthesis_task.done():
|
||||
self._synthesis_task.cancel()
|
||||
try:
|
||||
await self._synthesis_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
|
||||
|
||||
class OpenAIVoiceProvider(VoiceProviderInterface):
|
||||
"""OpenAI voice provider using Whisper for STT and TTS API for speech synthesis."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None,
|
||||
api_base: str | None = None,
|
||||
stt_model: str | None = None,
|
||||
tts_model: str | None = None,
|
||||
default_voice: str | None = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.stt_model = stt_model or "whisper-1"
|
||||
self.tts_model = tts_model or "tts-1"
|
||||
self.default_voice = default_voice or "alloy"
|
||||
|
||||
self._client: "AsyncOpenAI | None" = None
|
||||
|
||||
def _get_client(self) -> "AsyncOpenAI":
|
||||
if self._client is None:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.api_base,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
|
||||
"""
|
||||
Transcribe audio using OpenAI Whisper.
|
||||
|
||||
Args:
|
||||
audio_data: Raw audio bytes
|
||||
audio_format: Audio format (e.g., "webm", "wav", "mp3")
|
||||
|
||||
Returns:
|
||||
Transcribed text
|
||||
"""
|
||||
client = self._get_client()
|
||||
|
||||
# Create a file-like object from the audio bytes
|
||||
audio_file = io.BytesIO(audio_data)
|
||||
audio_file.name = f"audio.{audio_format}"
|
||||
|
||||
response = await client.audio.transcriptions.create(
|
||||
model=self.stt_model,
|
||||
file=audio_file,
|
||||
)
|
||||
|
||||
return response.text
|
||||
|
||||
async def synthesize_stream(
|
||||
self, text: str, voice: str | None = None, speed: float = 1.0
|
||||
) -> AsyncIterator[bytes]:
|
||||
"""
|
||||
Convert text to audio using OpenAI TTS with streaming.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech
|
||||
voice: Voice identifier (defaults to provider's default voice)
|
||||
speed: Playback speed multiplier (0.25 to 4.0)
|
||||
|
||||
Yields:
|
||||
Audio data chunks (mp3 format)
|
||||
"""
|
||||
client = self._get_client()
|
||||
|
||||
# Clamp speed to valid range
|
||||
speed = max(0.25, min(4.0, speed))
|
||||
|
||||
# Use with_streaming_response for proper async streaming
|
||||
# Using 8192 byte chunks for better streaming performance
|
||||
# (larger chunks = fewer round-trips, more complete MP3 frames)
|
||||
async with client.audio.speech.with_streaming_response.create(
|
||||
model=self.tts_model,
|
||||
voice=voice or self.default_voice,
|
||||
input=text,
|
||||
speed=speed,
|
||||
response_format="mp3",
|
||||
) as response:
|
||||
async for chunk in response.iter_bytes(chunk_size=8192):
|
||||
yield chunk
|
||||
|
||||
def get_available_voices(self) -> list[dict[str, str]]:
|
||||
"""Get available OpenAI TTS voices."""
|
||||
return OPENAI_VOICES.copy()
|
||||
|
||||
def get_available_stt_models(self) -> list[dict[str, str]]:
|
||||
"""Get available OpenAI STT models."""
|
||||
return OPENAI_STT_MODELS.copy()
|
||||
|
||||
def get_available_tts_models(self) -> list[dict[str, str]]:
|
||||
"""Get available OpenAI TTS models."""
|
||||
return OPENAI_TTS_MODELS.copy()
|
||||
|
||||
def supports_streaming_stt(self) -> bool:
|
||||
"""OpenAI supports streaming via Realtime API for all STT models."""
|
||||
return True
|
||||
|
||||
def supports_streaming_tts(self) -> bool:
|
||||
"""OpenAI supports real-time streaming TTS via Realtime API."""
|
||||
return True
|
||||
|
||||
async def create_streaming_transcriber(
|
||||
self, _audio_format: str = "webm"
|
||||
) -> OpenAIStreamingTranscriber:
|
||||
"""Create a streaming transcription session using Realtime API."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming transcription")
|
||||
transcriber = OpenAIStreamingTranscriber(
|
||||
api_key=self.api_key,
|
||||
model=self.stt_model,
|
||||
)
|
||||
await transcriber.connect()
|
||||
return transcriber
|
||||
|
||||
async def create_streaming_synthesizer(
|
||||
self, voice: str | None = None, speed: float = 1.0
|
||||
) -> OpenAIStreamingSynthesizer:
|
||||
"""Create a streaming TTS session using HTTP streaming API."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming TTS")
|
||||
synthesizer = OpenAIStreamingSynthesizer(
|
||||
api_key=self.api_key,
|
||||
voice=voice or self.default_voice or "alloy",
|
||||
model=self.tts_model or "tts-1",
|
||||
speed=speed,
|
||||
)
|
||||
await synthesizer.connect()
|
||||
return synthesizer
|
||||
@@ -596,7 +596,7 @@ mypy-extensions==1.0.0
|
||||
# typing-inspect
|
||||
nest-asyncio==1.6.0
|
||||
# via onyx
|
||||
nltk==3.9.1
|
||||
nltk==3.9.3
|
||||
# via unstructured
|
||||
numpy==2.4.1
|
||||
# via
|
||||
|
||||
@@ -16,10 +16,6 @@ def monitor_process(process_name: str, process: subprocess.Popen) -> None:
|
||||
|
||||
|
||||
def run_jobs() -> None:
|
||||
# Check if we should use lightweight mode, defaults to True, change to False to use separate background workers
|
||||
use_lightweight = True
|
||||
|
||||
# command setup
|
||||
cmd_worker_primary = [
|
||||
"celery",
|
||||
"-A",
|
||||
@@ -74,6 +70,48 @@ def run_jobs() -> None:
|
||||
"--queues=connector_doc_fetching",
|
||||
]
|
||||
|
||||
cmd_worker_heavy = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,sandbox",
|
||||
]
|
||||
|
||||
cmd_worker_monitoring = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring",
|
||||
]
|
||||
|
||||
cmd_worker_user_file_processing = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.user_file_processing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=2",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_file_processing@%n",
|
||||
"-Q",
|
||||
"user_file_processing,user_file_project_sync,user_file_delete",
|
||||
]
|
||||
|
||||
cmd_beat = [
|
||||
"celery",
|
||||
"-A",
|
||||
@@ -82,144 +120,31 @@ def run_jobs() -> None:
|
||||
"--loglevel=INFO",
|
||||
]
|
||||
|
||||
# Prepare background worker commands based on mode
|
||||
if use_lightweight:
|
||||
print("Starting workers in LIGHTWEIGHT mode (single background worker)")
|
||||
cmd_worker_background = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.background",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=6",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=background@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration",
|
||||
]
|
||||
background_workers = [("BACKGROUND", cmd_worker_background)]
|
||||
else:
|
||||
print("Starting workers in STANDARD mode (separate background workers)")
|
||||
cmd_worker_heavy = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=4",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,sandbox",
|
||||
]
|
||||
cmd_worker_monitoring = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring",
|
||||
]
|
||||
cmd_worker_user_file_processing = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.user_file_processing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=2",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=user_file_processing@%n",
|
||||
"-Q",
|
||||
"user_file_processing,user_file_project_sync,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,user_file_delete",
|
||||
]
|
||||
background_workers = [
|
||||
("HEAVY", cmd_worker_heavy),
|
||||
("MONITORING", cmd_worker_monitoring),
|
||||
("USER_FILE_PROCESSING", cmd_worker_user_file_processing),
|
||||
]
|
||||
all_workers = [
|
||||
("PRIMARY", cmd_worker_primary),
|
||||
("LIGHT", cmd_worker_light),
|
||||
("DOCPROCESSING", cmd_worker_docprocessing),
|
||||
("DOCFETCHING", cmd_worker_docfetching),
|
||||
("HEAVY", cmd_worker_heavy),
|
||||
("MONITORING", cmd_worker_monitoring),
|
||||
("USER_FILE_PROCESSING", cmd_worker_user_file_processing),
|
||||
("BEAT", cmd_beat),
|
||||
]
|
||||
|
||||
# spawn processes
|
||||
worker_primary_process = subprocess.Popen(
|
||||
cmd_worker_primary, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
|
||||
worker_light_process = subprocess.Popen(
|
||||
cmd_worker_light, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
|
||||
worker_docprocessing_process = subprocess.Popen(
|
||||
cmd_worker_docprocessing,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
|
||||
worker_docfetching_process = subprocess.Popen(
|
||||
cmd_worker_docfetching,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
|
||||
beat_process = subprocess.Popen(
|
||||
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
|
||||
# Spawn background worker processes based on mode
|
||||
background_processes = []
|
||||
for name, cmd in background_workers:
|
||||
processes = []
|
||||
for name, cmd in all_workers:
|
||||
process = subprocess.Popen(
|
||||
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
background_processes.append((name, process))
|
||||
processes.append((name, process))
|
||||
|
||||
# monitor threads
|
||||
worker_primary_thread = threading.Thread(
|
||||
target=monitor_process, args=("PRIMARY", worker_primary_process)
|
||||
)
|
||||
worker_light_thread = threading.Thread(
|
||||
target=monitor_process, args=("LIGHT", worker_light_process)
|
||||
)
|
||||
worker_docprocessing_thread = threading.Thread(
|
||||
target=monitor_process, args=("DOCPROCESSING", worker_docprocessing_process)
|
||||
)
|
||||
worker_docfetching_thread = threading.Thread(
|
||||
target=monitor_process, args=("DOCFETCHING", worker_docfetching_process)
|
||||
)
|
||||
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
|
||||
|
||||
# Create monitor threads for background workers
|
||||
background_threads = []
|
||||
for name, process in background_processes:
|
||||
threads = []
|
||||
for name, process in processes:
|
||||
thread = threading.Thread(target=monitor_process, args=(name, process))
|
||||
background_threads.append(thread)
|
||||
|
||||
# Start all threads
|
||||
worker_primary_thread.start()
|
||||
worker_light_thread.start()
|
||||
worker_docprocessing_thread.start()
|
||||
worker_docfetching_thread.start()
|
||||
beat_thread.start()
|
||||
|
||||
for thread in background_threads:
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for all threads
|
||||
worker_primary_thread.join()
|
||||
worker_light_thread.join()
|
||||
worker_docprocessing_thread.join()
|
||||
worker_docfetching_thread.join()
|
||||
beat_thread.join()
|
||||
|
||||
for thread in background_threads:
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,20 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
cleanup() {
|
||||
echo "Error occurred. Cleaning up..."
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
COMPOSE_FILE="$SCRIPT_DIR/../../deployment/docker_compose/docker-compose.yml"
|
||||
COMPOSE_DEV_FILE="$SCRIPT_DIR/../../deployment/docker_compose/docker-compose.dev.yml"
|
||||
|
||||
stop_and_remove_containers() {
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled stop opensearch 2>/dev/null || true
|
||||
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled rm -f opensearch 2>/dev/null || true
|
||||
}
|
||||
|
||||
cleanup() {
|
||||
echo "Error occurred. Cleaning up..."
|
||||
stop_and_remove_containers
|
||||
}
|
||||
|
||||
# Trap errors and output a message, then cleanup
|
||||
@@ -12,16 +22,26 @@ trap 'echo "Error occurred on line $LINENO. Exiting script." >&2; cleanup' ERR
|
||||
|
||||
# Usage of the script with optional volume arguments
|
||||
# ./restart_containers.sh [vespa_volume] [postgres_volume] [redis_volume]
|
||||
# [minio_volume] [--keep-opensearch-data]
|
||||
|
||||
VESPA_VOLUME=${1:-""} # Default is empty if not provided
|
||||
POSTGRES_VOLUME=${2:-""} # Default is empty if not provided
|
||||
REDIS_VOLUME=${3:-""} # Default is empty if not provided
|
||||
MINIO_VOLUME=${4:-""} # Default is empty if not provided
|
||||
KEEP_OPENSEARCH_DATA=false
|
||||
POSITIONAL_ARGS=()
|
||||
for arg in "$@"; do
|
||||
if [[ "$arg" == "--keep-opensearch-data" ]]; then
|
||||
KEEP_OPENSEARCH_DATA=true
|
||||
else
|
||||
POSITIONAL_ARGS+=("$arg")
|
||||
fi
|
||||
done
|
||||
|
||||
VESPA_VOLUME=${POSITIONAL_ARGS[0]:-""}
|
||||
POSTGRES_VOLUME=${POSITIONAL_ARGS[1]:-""}
|
||||
REDIS_VOLUME=${POSITIONAL_ARGS[2]:-""}
|
||||
MINIO_VOLUME=${POSITIONAL_ARGS[3]:-""}
|
||||
|
||||
# Stop and remove the existing containers
|
||||
echo "Stopping and removing existing containers..."
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
stop_and_remove_containers
|
||||
|
||||
# Start the PostgreSQL container with optional volume
|
||||
echo "Starting PostgreSQL container..."
|
||||
@@ -39,6 +59,29 @@ else
|
||||
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 vespaengine/vespa:8
|
||||
fi
|
||||
|
||||
# If OPENSEARCH_ADMIN_PASSWORD is not already set, try loading it from
|
||||
# .vscode/.env so existing dev setups that stored it there aren't silently
|
||||
# broken.
|
||||
VSCODE_ENV="$SCRIPT_DIR/../../.vscode/.env"
|
||||
if [[ -z "${OPENSEARCH_ADMIN_PASSWORD:-}" && -f "$VSCODE_ENV" ]]; then
|
||||
set -a
|
||||
# shellcheck source=/dev/null
|
||||
source "$VSCODE_ENV"
|
||||
set +a
|
||||
fi
|
||||
|
||||
# Start the OpenSearch container using the same service from docker-compose that
|
||||
# our users use, setting OPENSEARCH_INITIAL_ADMIN_PASSWORD from the env's
|
||||
# OPENSEARCH_ADMIN_PASSWORD if it exists, else defaulting to StrongPassword123!.
|
||||
# Pass --keep-opensearch-data to preserve the opensearch-data volume across
|
||||
# restarts, else the volume is deleted so the container starts fresh.
|
||||
if [[ "$KEEP_OPENSEARCH_DATA" == "false" ]]; then
|
||||
echo "Deleting opensearch-data volume..."
|
||||
docker volume rm onyx_opensearch-data 2>/dev/null || true
|
||||
fi
|
||||
echo "Starting OpenSearch container..."
|
||||
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled up --force-recreate -d opensearch
|
||||
|
||||
# Start the Redis container with optional volume
|
||||
echo "Starting Redis container..."
|
||||
if [[ -n "$REDIS_VOLUME" ]]; then
|
||||
@@ -60,7 +103,6 @@ echo "Starting Code Interpreter container..."
|
||||
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api
|
||||
|
||||
# Ensure alembic runs in the correct directory (backend/)
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
PARENT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||
cd "$PARENT_DIR"
|
||||
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# We get OPENSEARCH_ADMIN_PASSWORD from the repo .env file.
|
||||
source "$(dirname "$0")/../../.vscode/.env"
|
||||
|
||||
cd "$(dirname "$0")/../../deployment/docker_compose"
|
||||
|
||||
# Start OpenSearch.
|
||||
echo "Forcefully starting fresh OpenSearch container..."
|
||||
docker compose -f docker-compose.opensearch.yml up --force-recreate -d opensearch
|
||||
@@ -1,23 +1,5 @@
|
||||
#!/bin/sh
|
||||
# Entrypoint script for supervisord that sets environment variables
|
||||
# for controlling which celery workers to start
|
||||
|
||||
# Default to lightweight mode if not set
|
||||
if [ -z "$USE_LIGHTWEIGHT_BACKGROUND_WORKER" ]; then
|
||||
export USE_LIGHTWEIGHT_BACKGROUND_WORKER="true"
|
||||
fi
|
||||
|
||||
# Set the complementary variable for supervisord
|
||||
# because it doesn't support %(not ENV_USE_LIGHTWEIGHT_BACKGROUND_WORKER) syntax
|
||||
if [ "$USE_LIGHTWEIGHT_BACKGROUND_WORKER" = "true" ]; then
|
||||
export USE_SEPARATE_BACKGROUND_WORKERS="false"
|
||||
else
|
||||
export USE_SEPARATE_BACKGROUND_WORKERS="true"
|
||||
fi
|
||||
|
||||
echo "Worker mode configuration:"
|
||||
echo " USE_LIGHTWEIGHT_BACKGROUND_WORKER=$USE_LIGHTWEIGHT_BACKGROUND_WORKER"
|
||||
echo " USE_SEPARATE_BACKGROUND_WORKERS=$USE_SEPARATE_BACKGROUND_WORKERS"
|
||||
# Entrypoint script for supervisord
|
||||
|
||||
# Launch supervisord with environment variables available
|
||||
exec /usr/bin/supervisord -c /etc/supervisor/conf.d/supervisord.conf
|
||||
|
||||
@@ -39,7 +39,6 @@ autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
|
||||
# Standard mode: Light worker for fast operations
|
||||
# NOTE: only allowing configuration here and not in the other celery workers,
|
||||
# since this is often the bottleneck for "sync" jobs (e.g. document set syncing,
|
||||
# user group syncing, deletion, etc.)
|
||||
@@ -54,26 +53,7 @@ redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
|
||||
|
||||
# Lightweight mode: single consolidated background worker
|
||||
# Used when USE_LIGHTWEIGHT_BACKGROUND_WORKER=true (default)
|
||||
# Consolidates: light, docprocessing, docfetching, heavy, monitoring, user_file_processing
|
||||
[program:celery_worker_background]
|
||||
command=celery -A onyx.background.celery.versioned_apps.background worker
|
||||
--loglevel=INFO
|
||||
--hostname=background@%%n
|
||||
-Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,sandbox,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,monitoring,user_file_processing,user_file_project_sync,opensearch_migration
|
||||
stdout_logfile=/var/log/celery_worker_background.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
autostart=%(ENV_USE_LIGHTWEIGHT_BACKGROUND_WORKER)s
|
||||
|
||||
# Standard mode: separate workers for different background tasks
|
||||
# Used when USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
|
||||
[program:celery_worker_heavy]
|
||||
command=celery -A onyx.background.celery.versioned_apps.heavy worker
|
||||
--loglevel=INFO
|
||||
@@ -85,9 +65,7 @@ redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
|
||||
|
||||
# Standard mode: Document processing worker
|
||||
[program:celery_worker_docprocessing]
|
||||
command=celery -A onyx.background.celery.versioned_apps.docprocessing worker
|
||||
--loglevel=INFO
|
||||
@@ -99,7 +77,6 @@ redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
|
||||
|
||||
[program:celery_worker_user_file_processing]
|
||||
command=celery -A onyx.background.celery.versioned_apps.user_file_processing worker
|
||||
@@ -112,9 +89,7 @@ redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
|
||||
|
||||
# Standard mode: Document fetching worker
|
||||
[program:celery_worker_docfetching]
|
||||
command=celery -A onyx.background.celery.versioned_apps.docfetching worker
|
||||
--loglevel=INFO
|
||||
@@ -126,7 +101,6 @@ redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
|
||||
|
||||
[program:celery_worker_monitoring]
|
||||
command=celery -A onyx.background.celery.versioned_apps.monitoring worker
|
||||
@@ -139,7 +113,6 @@ redirect_stderr=true
|
||||
autorestart=true
|
||||
startsecs=10
|
||||
stopasgroup=true
|
||||
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
|
||||
|
||||
|
||||
# Job scheduler for periodic tasks
|
||||
@@ -197,7 +170,6 @@ command=tail -qF
|
||||
/var/log/celery_beat.log
|
||||
/var/log/celery_worker_primary.log
|
||||
/var/log/celery_worker_light.log
|
||||
/var/log/celery_worker_background.log
|
||||
/var/log/celery_worker_heavy.log
|
||||
/var/log/celery_worker_docprocessing.log
|
||||
/var/log/celery_worker_monitoring.log
|
||||
|
||||
@@ -5,6 +5,8 @@ Verifies that:
|
||||
1. extract_ids_from_runnable_connector correctly separates hierarchy nodes from doc IDs
|
||||
2. Extracted hierarchy nodes are correctly upserted to Postgres via upsert_hierarchy_nodes_batch
|
||||
3. Upserting is idempotent (running twice doesn't duplicate nodes)
|
||||
4. Document-to-hierarchy-node linkage is updated during pruning
|
||||
5. link_hierarchy_nodes_to_documents links nodes that are also documents
|
||||
|
||||
Uses a mock SlimConnectorWithPermSync that yields known hierarchy nodes and slim documents,
|
||||
combined with a real PostgreSQL database for verifying persistence.
|
||||
@@ -27,9 +29,13 @@ from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.db.hierarchy import ensure_source_node_exists
|
||||
from onyx.db.hierarchy import get_all_hierarchy_nodes_for_source
|
||||
from onyx.db.hierarchy import get_hierarchy_node_by_raw_id
|
||||
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
|
||||
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import Document as DbDocument
|
||||
from onyx.db.models import HierarchyNode as DBHierarchyNode
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.kg.models import KGStage
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
@@ -89,8 +95,18 @@ def _make_hierarchy_nodes() -> list[PydanticHierarchyNode]:
|
||||
]
|
||||
|
||||
|
||||
DOC_PARENT_MAP = {
|
||||
"msg-001": CHANNEL_A_ID,
|
||||
"msg-002": CHANNEL_A_ID,
|
||||
"msg-003": CHANNEL_B_ID,
|
||||
}
|
||||
|
||||
|
||||
def _make_slim_docs() -> list[SlimDocument | PydanticHierarchyNode]:
|
||||
return [SlimDocument(id=doc_id) for doc_id in SLIM_DOC_IDS]
|
||||
return [
|
||||
SlimDocument(id=doc_id, parent_hierarchy_raw_node_id=DOC_PARENT_MAP.get(doc_id))
|
||||
for doc_id in SLIM_DOC_IDS
|
||||
]
|
||||
|
||||
|
||||
class MockSlimConnectorWithPermSync(SlimConnectorWithPermSync):
|
||||
@@ -126,14 +142,31 @@ class MockSlimConnectorWithPermSync(SlimConnectorWithPermSync):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _cleanup_test_hierarchy_nodes(db_session: Session) -> None:
|
||||
"""Remove all hierarchy nodes for TEST_SOURCE to isolate tests."""
|
||||
def _cleanup_test_data(db_session: Session) -> None:
|
||||
"""Remove all test hierarchy nodes and documents to isolate tests."""
|
||||
for doc_id in SLIM_DOC_IDS:
|
||||
db_session.query(DbDocument).filter(DbDocument.id == doc_id).delete()
|
||||
db_session.query(DBHierarchyNode).filter(
|
||||
DBHierarchyNode.source == TEST_SOURCE
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _create_test_documents(db_session: Session) -> list[DbDocument]:
|
||||
"""Insert minimal Document rows for our test doc IDs."""
|
||||
docs = []
|
||||
for doc_id in SLIM_DOC_IDS:
|
||||
doc = DbDocument(
|
||||
id=doc_id,
|
||||
semantic_id=doc_id,
|
||||
kg_stage=KGStage.NOT_STARTED,
|
||||
)
|
||||
db_session.add(doc)
|
||||
docs.append(doc)
|
||||
db_session.commit()
|
||||
return docs
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -147,14 +180,14 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None: # noqa:
|
||||
result = extract_ids_from_runnable_connector(connector, callback=None)
|
||||
|
||||
# Doc IDs should include both slim doc IDs and hierarchy node raw_node_ids
|
||||
# (hierarchy node IDs are added to doc_ids so they aren't pruned)
|
||||
# (hierarchy node IDs are added to raw_id_to_parent so they aren't pruned)
|
||||
expected_ids = {
|
||||
CHANNEL_A_ID,
|
||||
CHANNEL_B_ID,
|
||||
CHANNEL_C_ID,
|
||||
*SLIM_DOC_IDS,
|
||||
}
|
||||
assert result.doc_ids == expected_ids
|
||||
assert result.raw_id_to_parent.keys() == expected_ids
|
||||
|
||||
# Hierarchy nodes should be the 3 channels
|
||||
assert len(result.hierarchy_nodes) == 3
|
||||
@@ -165,7 +198,7 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None: # noqa:
|
||||
def test_pruning_upserts_hierarchy_nodes_to_db(db_session: Session) -> None:
|
||||
"""Full flow: extract hierarchy nodes from mock connector, upsert to Postgres,
|
||||
then verify the DB state (node count, parent relationships, permissions)."""
|
||||
_cleanup_test_hierarchy_nodes(db_session)
|
||||
_cleanup_test_data(db_session)
|
||||
|
||||
# Step 1: ensure the SOURCE node exists (mirrors what the pruning task does)
|
||||
source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
@@ -230,7 +263,7 @@ def test_pruning_upserts_hierarchy_nodes_public_connector(
|
||||
) -> None:
|
||||
"""When the connector's access type is PUBLIC, all hierarchy nodes must be
|
||||
marked is_public=True regardless of their external_access settings."""
|
||||
_cleanup_test_hierarchy_nodes(db_session)
|
||||
_cleanup_test_data(db_session)
|
||||
|
||||
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
|
||||
@@ -257,7 +290,7 @@ def test_pruning_upserts_hierarchy_nodes_public_connector(
|
||||
def test_pruning_hierarchy_node_upsert_idempotency(db_session: Session) -> None:
|
||||
"""Upserting the same hierarchy nodes twice must not create duplicates.
|
||||
The second call should update existing rows in place."""
|
||||
_cleanup_test_hierarchy_nodes(db_session)
|
||||
_cleanup_test_data(db_session)
|
||||
|
||||
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
|
||||
@@ -295,7 +328,7 @@ def test_pruning_hierarchy_node_upsert_idempotency(db_session: Session) -> None:
|
||||
|
||||
def test_pruning_hierarchy_node_upsert_updates_fields(db_session: Session) -> None:
|
||||
"""Upserting a hierarchy node with changed fields should update the existing row."""
|
||||
_cleanup_test_hierarchy_nodes(db_session)
|
||||
_cleanup_test_data(db_session)
|
||||
|
||||
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
|
||||
@@ -342,3 +375,193 @@ def test_pruning_hierarchy_node_upsert_updates_fields(db_session: Session) -> No
|
||||
assert db_node.is_public is True
|
||||
assert db_node.external_user_emails is not None
|
||||
assert set(db_node.external_user_emails) == {"new_user@example.com"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Document-to-hierarchy-node linkage tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_extraction_preserves_parent_hierarchy_raw_node_id(
|
||||
db_session: Session, # noqa: ARG001
|
||||
) -> None:
|
||||
"""extract_ids_from_runnable_connector should carry the
|
||||
parent_hierarchy_raw_node_id from SlimDocument into the raw_id_to_parent dict."""
|
||||
connector = MockSlimConnectorWithPermSync()
|
||||
result = extract_ids_from_runnable_connector(connector, callback=None)
|
||||
|
||||
for doc_id, expected_parent in DOC_PARENT_MAP.items():
|
||||
assert (
|
||||
result.raw_id_to_parent[doc_id] == expected_parent
|
||||
), f"raw_id_to_parent[{doc_id}] should be {expected_parent}"
|
||||
|
||||
# Hierarchy node entries have None parent (they aren't documents)
|
||||
for channel_id in [CHANNEL_A_ID, CHANNEL_B_ID, CHANNEL_C_ID]:
|
||||
assert result.raw_id_to_parent[channel_id] is None
|
||||
|
||||
|
||||
def test_update_document_parent_hierarchy_nodes(db_session: Session) -> None:
|
||||
"""update_document_parent_hierarchy_nodes should set
|
||||
Document.parent_hierarchy_node_id for each document in the mapping."""
|
||||
_cleanup_test_data(db_session)
|
||||
|
||||
source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
upserted = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=_make_hierarchy_nodes(),
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
is_connector_public=False,
|
||||
)
|
||||
node_id_by_raw = {n.raw_node_id: n.id for n in upserted}
|
||||
|
||||
# Create documents with no parent set
|
||||
docs = _create_test_documents(db_session)
|
||||
for doc in docs:
|
||||
assert doc.parent_hierarchy_node_id is None
|
||||
|
||||
# Build resolved map (same logic as _resolve_and_update_document_parents)
|
||||
resolved: dict[str, int | None] = {}
|
||||
for doc_id, raw_parent in DOC_PARENT_MAP.items():
|
||||
resolved[doc_id] = node_id_by_raw.get(raw_parent, source_node.id)
|
||||
|
||||
updated = update_document_parent_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
doc_parent_map=resolved,
|
||||
commit=True,
|
||||
)
|
||||
assert updated == len(SLIM_DOC_IDS)
|
||||
|
||||
# Verify each document now points to the correct hierarchy node
|
||||
db_session.expire_all()
|
||||
for doc_id, raw_parent in DOC_PARENT_MAP.items():
|
||||
tmp_doc = db_session.get(DbDocument, doc_id)
|
||||
assert tmp_doc is not None
|
||||
doc = tmp_doc
|
||||
expected_node_id = node_id_by_raw[raw_parent]
|
||||
assert (
|
||||
doc.parent_hierarchy_node_id == expected_node_id
|
||||
), f"Document {doc_id} should point to node for {raw_parent}"
|
||||
|
||||
|
||||
def test_update_document_parent_is_idempotent(db_session: Session) -> None:
|
||||
"""Running update_document_parent_hierarchy_nodes a second time with the
|
||||
same mapping should update zero rows."""
|
||||
_cleanup_test_data(db_session)
|
||||
|
||||
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
|
||||
upserted = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=_make_hierarchy_nodes(),
|
||||
source=TEST_SOURCE,
|
||||
commit=True,
|
||||
is_connector_public=False,
|
||||
)
|
||||
node_id_by_raw = {n.raw_node_id: n.id for n in upserted}
|
||||
_create_test_documents(db_session)
|
||||
|
||||
resolved: dict[str, int | None] = {
|
||||
doc_id: node_id_by_raw[raw_parent]
|
||||
for doc_id, raw_parent in DOC_PARENT_MAP.items()
|
||||
}
|
||||
|
||||
first_updated = update_document_parent_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
doc_parent_map=resolved,
|
||||
commit=True,
|
||||
)
|
||||
assert first_updated == len(SLIM_DOC_IDS)
|
||||
|
||||
second_updated = update_document_parent_hierarchy_nodes(
|
||||
db_session=db_session,
|
||||
doc_parent_map=resolved,
|
||||
commit=True,
|
||||
)
|
||||
assert second_updated == 0
|
||||
|
||||
|
||||
def test_link_hierarchy_nodes_to_documents_for_confluence(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""For sources in SOURCES_WITH_HIERARCHY_NODE_DOCUMENTS (e.g. Confluence),
|
||||
link_hierarchy_nodes_to_documents should set HierarchyNode.document_id
|
||||
when a hierarchy node's raw_node_id matches a document ID."""
|
||||
_cleanup_test_data(db_session)
|
||||
confluence_source = DocumentSource.CONFLUENCE
|
||||
|
||||
# Clean up any existing Confluence hierarchy nodes
|
||||
db_session.query(DBHierarchyNode).filter(
|
||||
DBHierarchyNode.source == confluence_source
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
ensure_source_node_exists(db_session, confluence_source, commit=True)
|
||||
|
||||
# Create a hierarchy node whose raw_node_id matches a document ID
|
||||
page_node_id = "confluence-page-123"
|
||||
nodes = [
|
||||
PydanticHierarchyNode(
|
||||
raw_node_id=page_node_id,
|
||||
raw_parent_id=None,
|
||||
display_name="Test Page",
|
||||
link="https://wiki.example.com/page/123",
|
||||
node_type=HierarchyNodeType.PAGE,
|
||||
),
|
||||
]
|
||||
upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=nodes,
|
||||
source=confluence_source,
|
||||
commit=True,
|
||||
is_connector_public=False,
|
||||
)
|
||||
|
||||
# Verify the node exists but has no document_id yet
|
||||
db_node = get_hierarchy_node_by_raw_id(db_session, page_node_id, confluence_source)
|
||||
assert db_node is not None
|
||||
assert db_node.document_id is None
|
||||
|
||||
# Create a document with the same ID as the hierarchy node
|
||||
doc = DbDocument(
|
||||
id=page_node_id,
|
||||
semantic_id="Test Page",
|
||||
kg_stage=KGStage.NOT_STARTED,
|
||||
)
|
||||
db_session.add(doc)
|
||||
db_session.commit()
|
||||
|
||||
# Link nodes to documents
|
||||
linked = link_hierarchy_nodes_to_documents(
|
||||
db_session=db_session,
|
||||
document_ids=[page_node_id],
|
||||
source=confluence_source,
|
||||
commit=True,
|
||||
)
|
||||
assert linked == 1
|
||||
|
||||
# Verify the hierarchy node now has document_id set
|
||||
db_session.expire_all()
|
||||
db_node = get_hierarchy_node_by_raw_id(db_session, page_node_id, confluence_source)
|
||||
assert db_node is not None
|
||||
assert db_node.document_id == page_node_id
|
||||
|
||||
# Cleanup
|
||||
db_session.query(DbDocument).filter(DbDocument.id == page_node_id).delete()
|
||||
db_session.query(DBHierarchyNode).filter(
|
||||
DBHierarchyNode.source == confluence_source
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def test_link_hierarchy_nodes_skips_non_hierarchy_sources(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""link_hierarchy_nodes_to_documents should return 0 for sources that
|
||||
don't support hierarchy-node-as-document (e.g. Slack, Google Drive)."""
|
||||
linked = link_hierarchy_nodes_to_documents(
|
||||
db_session=db_session,
|
||||
document_ids=SLIM_DOC_IDS,
|
||||
source=TEST_SOURCE, # Slack — not in SOURCES_WITH_HIERARCHY_NODE_DOCUMENTS
|
||||
commit=False,
|
||||
)
|
||||
assert linked == 0
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.context.search.models import SavedSearchSettings
|
||||
from onyx.context.search.models import SearchSettingsCreationRequest
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.llm import fetch_default_contextual_rag_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_contextual_model
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import IndexModelStatus
|
||||
@@ -37,6 +38,8 @@ def _create_llm_provider_and_model(
|
||||
model_name: str,
|
||||
) -> None:
|
||||
"""Insert an LLM provider with a single visible model configuration."""
|
||||
if fetch_existing_llm_provider(name=provider_name, db_session=db_session):
|
||||
return
|
||||
upsert_llm_provider(
|
||||
LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
@@ -146,8 +149,8 @@ def baseline_search_settings(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
@patch("onyx.db.swap_index.get_all_document_indices")
|
||||
@patch("onyx.server.manage.search_settings.get_all_document_indices")
|
||||
@patch("onyx.server.manage.search_settings.get_default_document_index")
|
||||
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
|
||||
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
|
||||
@@ -155,6 +158,7 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
|
||||
mock_index_handler: MagicMock,
|
||||
mock_get_llm: MagicMock,
|
||||
mock_get_doc_index: MagicMock, # noqa: ARG001
|
||||
mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001
|
||||
mock_get_all_doc_indices: MagicMock,
|
||||
baseline_search_settings: None, # noqa: ARG001
|
||||
db_session: Session,
|
||||
@@ -196,8 +200,8 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
@patch("onyx.db.swap_index.get_all_document_indices")
|
||||
@patch("onyx.server.manage.search_settings.get_all_document_indices")
|
||||
@patch("onyx.server.manage.search_settings.get_default_document_index")
|
||||
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
|
||||
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
|
||||
@@ -205,6 +209,7 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
|
||||
mock_index_handler: MagicMock,
|
||||
mock_get_llm: MagicMock,
|
||||
mock_get_doc_index: MagicMock, # noqa: ARG001
|
||||
mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001
|
||||
mock_get_all_doc_indices: MagicMock,
|
||||
baseline_search_settings: None, # noqa: ARG001
|
||||
db_session: Session,
|
||||
@@ -266,7 +271,7 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
@patch("onyx.server.manage.search_settings.get_all_document_indices")
|
||||
@patch("onyx.server.manage.search_settings.get_default_document_index")
|
||||
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
|
||||
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
|
||||
@@ -274,6 +279,7 @@ def test_indexing_pipeline_skips_llm_when_contextual_rag_disabled(
|
||||
mock_index_handler: MagicMock,
|
||||
mock_get_llm: MagicMock,
|
||||
mock_get_doc_index: MagicMock, # noqa: ARG001
|
||||
mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001
|
||||
baseline_search_settings: None, # noqa: ARG001
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
|
||||
@@ -114,8 +114,8 @@ def test_create_duplicate_config_fails(
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
|
||||
assert response.status_code == 409
|
||||
assert "already exists" in response.json()["message"]
|
||||
assert response.status_code == 400
|
||||
assert "already exists" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_get_all_configs(
|
||||
@@ -292,7 +292,7 @@ def test_update_config_source_provider_not_found(
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["message"]
|
||||
assert "not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_delete_config(
|
||||
@@ -468,7 +468,7 @@ def test_create_config_missing_credentials(
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "No provider or source llm provided" in response.json()["message"]
|
||||
assert "No provider or source llm provided" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_create_config_source_provider_not_found(
|
||||
@@ -488,4 +488,4 @@ def test_create_config_source_provider_not_found(
|
||||
)
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "not found" in response.json()["message"]
|
||||
assert "not found" in response.json()["detail"]
|
||||
|
||||
@@ -42,6 +42,78 @@ class NightlyProviderConfig(BaseModel):
|
||||
strict: bool
|
||||
|
||||
|
||||
def _stringify_custom_config_value(value: object) -> str:
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, (dict, list)):
|
||||
return json.dumps(value)
|
||||
return str(value)
|
||||
|
||||
|
||||
def _looks_like_vertex_credentials_payload(
|
||||
raw_custom_config: dict[object, object],
|
||||
) -> bool:
|
||||
normalized_keys = {str(key).strip().lower() for key in raw_custom_config}
|
||||
provider_specific_keys = {
|
||||
"vertex_credentials",
|
||||
"credentials_file",
|
||||
"vertex_credentials_file",
|
||||
"google_application_credentials",
|
||||
"vertex_location",
|
||||
"location",
|
||||
"vertex_region",
|
||||
"region",
|
||||
}
|
||||
if normalized_keys & provider_specific_keys:
|
||||
return False
|
||||
|
||||
normalized_type = str(raw_custom_config.get("type", "")).strip().lower()
|
||||
if normalized_type not in {"service_account", "external_account"}:
|
||||
return False
|
||||
|
||||
# Service account JSON usually includes private_key/client_email, while external
|
||||
# account JSON includes credential_source. Either shape should be accepted.
|
||||
has_service_account_markers = any(
|
||||
key in normalized_keys for key in {"private_key", "client_email"}
|
||||
)
|
||||
has_external_account_markers = "credential_source" in normalized_keys
|
||||
return has_service_account_markers or has_external_account_markers
|
||||
|
||||
|
||||
def _normalize_custom_config(
|
||||
provider: str, raw_custom_config: dict[object, object]
|
||||
) -> dict[str, str]:
|
||||
if provider == "vertex_ai" and _looks_like_vertex_credentials_payload(
|
||||
raw_custom_config
|
||||
):
|
||||
return {"vertex_credentials": json.dumps(raw_custom_config)}
|
||||
|
||||
normalized: dict[str, str] = {}
|
||||
for raw_key, raw_value in raw_custom_config.items():
|
||||
key = str(raw_key).strip()
|
||||
key_lower = key.lower()
|
||||
|
||||
if provider == "vertex_ai":
|
||||
if key_lower in {
|
||||
"vertex_credentials",
|
||||
"credentials_file",
|
||||
"vertex_credentials_file",
|
||||
"google_application_credentials",
|
||||
}:
|
||||
key = "vertex_credentials"
|
||||
elif key_lower in {
|
||||
"vertex_location",
|
||||
"location",
|
||||
"vertex_region",
|
||||
"region",
|
||||
}:
|
||||
key = "vertex_location"
|
||||
|
||||
normalized[key] = _stringify_custom_config_value(raw_value)
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def _env_true(env_var: str, default: bool = False) -> bool:
|
||||
value = os.environ.get(env_var)
|
||||
if value is None:
|
||||
@@ -80,7 +152,9 @@ def _load_provider_config() -> NightlyProviderConfig:
|
||||
parsed = json.loads(custom_config_json)
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError(f"{_ENV_CUSTOM_CONFIG_JSON} must be a JSON object")
|
||||
custom_config = {str(key): str(value) for key, value in parsed.items()}
|
||||
custom_config = _normalize_custom_config(
|
||||
provider=provider, raw_custom_config=parsed
|
||||
)
|
||||
|
||||
if provider == "ollama_chat" and api_key and not custom_config:
|
||||
custom_config = {"OLLAMA_API_KEY": api_key}
|
||||
@@ -148,6 +222,23 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
|
||||
),
|
||||
)
|
||||
|
||||
if config.provider == "vertex_ai":
|
||||
has_vertex_credentials = bool(
|
||||
config.custom_config and config.custom_config.get("vertex_credentials")
|
||||
)
|
||||
if not has_vertex_credentials:
|
||||
configured_keys = (
|
||||
sorted(config.custom_config.keys()) if config.custom_config else []
|
||||
)
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(
|
||||
f"{_ENV_CUSTOM_CONFIG_JSON} must include 'vertex_credentials' "
|
||||
f"for provider '{config.provider}'. "
|
||||
f"Found keys: {configured_keys}"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _assert_integration_mode_enabled() -> None:
|
||||
assert (
|
||||
@@ -193,6 +284,7 @@ def _create_provider_payload(
|
||||
return {
|
||||
"name": provider_name,
|
||||
"provider": provider,
|
||||
"model": model_name,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"api_version": api_version,
|
||||
@@ -208,24 +300,23 @@ def _create_provider_payload(
|
||||
}
|
||||
|
||||
|
||||
def _ensure_provider_is_default(provider_id: int, admin_user: DATestUser) -> None:
|
||||
def _ensure_provider_is_default(
|
||||
provider_id: int, model_name: str, admin_user: DATestUser
|
||||
) -> None:
|
||||
list_response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
list_response.raise_for_status()
|
||||
providers = list_response.json()
|
||||
|
||||
current_default = next(
|
||||
(provider for provider in providers if provider.get("is_default_provider")),
|
||||
None,
|
||||
default_text = list_response.json().get("default_text")
|
||||
assert default_text is not None, "Expected a default provider after setting default"
|
||||
assert default_text.get("provider_id") == provider_id, (
|
||||
f"Expected provider {provider_id} to be default, "
|
||||
f"found {default_text.get('provider_id')}"
|
||||
)
|
||||
assert (
|
||||
current_default is not None
|
||||
), "Expected a default provider after setting provider as default"
|
||||
assert (
|
||||
current_default["id"] == provider_id
|
||||
), f"Expected provider {provider_id} to be default, found {current_default['id']}"
|
||||
default_text.get("model_name") == model_name
|
||||
), f"Expected default model {model_name}, found {default_text.get('model_name')}"
|
||||
|
||||
|
||||
def _run_chat_assertions(
|
||||
@@ -326,8 +417,9 @@ def _create_and_test_provider_for_model(
|
||||
|
||||
try:
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{provider_id}/default",
|
||||
f"{API_SERVER_URL}/admin/llm/default",
|
||||
headers=admin_user.headers,
|
||||
json={"provider_id": provider_id, "model_name": model_name},
|
||||
)
|
||||
assert set_default_response.status_code == 200, (
|
||||
f"Setting default provider failed for provider={config.provider} "
|
||||
@@ -335,7 +427,9 @@ def _create_and_test_provider_for_model(
|
||||
f"{set_default_response.text}"
|
||||
)
|
||||
|
||||
_ensure_provider_is_default(provider_id=provider_id, admin_user=admin_user)
|
||||
_ensure_provider_is_default(
|
||||
provider_id=provider_id, model_name=model_name, admin_user=admin_user
|
||||
)
|
||||
_run_chat_assertions(
|
||||
admin_user=admin_user,
|
||||
search_tool_id=search_tool_id,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
@@ -300,7 +299,7 @@ def test_update_contextual_rag_nonexistent_provider(
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Provider nonexistent-provider not found" in response.json()["message"]
|
||||
assert "Provider nonexistent-provider not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_update_contextual_rag_nonexistent_model(
|
||||
@@ -322,7 +321,7 @@ def test_update_contextual_rag_nonexistent_model(
|
||||
assert response.status_code == 400
|
||||
assert (
|
||||
f"Model nonexistent-model not found in provider {llm_provider.name}"
|
||||
in response.json()["message"]
|
||||
in response.json()["detail"]
|
||||
)
|
||||
|
||||
|
||||
@@ -342,7 +341,7 @@ def test_update_contextual_rag_missing_provider_name(
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Provider name and model name are required" in response.json()["message"]
|
||||
assert "Provider name and model name are required" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_update_contextual_rag_missing_model_name(
|
||||
@@ -362,10 +361,9 @@ def test_update_contextual_rag_missing_model_name(
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert "Provider name and model name are required" in response.json()["message"]
|
||||
assert "Provider name and model name are required" in response.json()["detail"]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
def test_set_new_search_settings_with_contextual_rag(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
@@ -394,7 +392,6 @@ def test_set_new_search_settings_with_contextual_rag(
|
||||
_cancel_new_embedding(admin_user)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
def test_set_new_search_settings_without_contextual_rag(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
@@ -419,7 +416,6 @@ def test_set_new_search_settings_without_contextual_rag(
|
||||
_cancel_new_embedding(admin_user)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
def test_set_new_then_update_inference_settings(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
@@ -457,7 +453,6 @@ def test_set_new_then_update_inference_settings(
|
||||
_cancel_new_embedding(admin_user)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
|
||||
def test_set_new_search_settings_replaces_previous_secondary(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
|
||||
@@ -281,9 +281,10 @@ class TestApplyLicenseStatusToSettings:
|
||||
}
|
||||
|
||||
|
||||
class TestSettingsDefaultEEDisabled:
|
||||
"""Verify the Settings model defaults ee_features_enabled to False."""
|
||||
class TestSettingsDefaults:
|
||||
"""Verify Settings model defaults for CE deployments."""
|
||||
|
||||
def test_default_ee_features_disabled(self) -> None:
|
||||
"""CE default: ee_features_enabled is False."""
|
||||
settings = Settings()
|
||||
assert settings.ee_features_enabled is False
|
||||
|
||||
@@ -104,3 +104,102 @@ def test_format_slack_message_ampersand_not_double_escaped() -> None:
|
||||
|
||||
assert "&" in formatted
|
||||
assert """ not in formatted
|
||||
|
||||
|
||||
# -- Table rendering tests --
|
||||
|
||||
|
||||
def test_table_renders_as_vertical_cards() -> None:
|
||||
message = (
|
||||
"| Feature | Status | Owner |\n"
|
||||
"|---------|--------|-------|\n"
|
||||
"| Auth | Done | Alice |\n"
|
||||
"| Search | In Progress | Bob |\n"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "*Auth*\n • Status: Done\n • Owner: Alice" in formatted
|
||||
assert "*Search*\n • Status: In Progress\n • Owner: Bob" in formatted
|
||||
# Cards separated by blank line
|
||||
assert "Owner: Alice\n\n*Search*" in formatted
|
||||
# No raw pipe-and-dash table syntax
|
||||
assert "---|" not in formatted
|
||||
|
||||
|
||||
def test_table_single_column() -> None:
|
||||
message = "| Name |\n|------|\n| Alice |\n| Bob |\n"
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "*Alice*" in formatted
|
||||
assert "*Bob*" in formatted
|
||||
|
||||
|
||||
def test_table_embedded_in_text() -> None:
|
||||
message = (
|
||||
"Here are the results:\n\n"
|
||||
"| Item | Count |\n"
|
||||
"|------|-------|\n"
|
||||
"| Apples | 5 |\n"
|
||||
"\n"
|
||||
"That's all."
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "Here are the results:" in formatted
|
||||
assert "*Apples*\n • Count: 5" in formatted
|
||||
assert "That's all." in formatted
|
||||
|
||||
|
||||
def test_table_with_formatted_cells() -> None:
|
||||
message = (
|
||||
"| Name | Link |\n"
|
||||
"|------|------|\n"
|
||||
"| **Alice** | [profile](https://example.com) |\n"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
# Bold cell should not double-wrap: *Alice* not **Alice**
|
||||
assert "*Alice*" in formatted
|
||||
assert "**Alice**" not in formatted
|
||||
assert "<https://example.com|profile>" in formatted
|
||||
|
||||
|
||||
def test_table_with_alignment_specifiers() -> None:
|
||||
message = (
|
||||
"| Left | Center | Right |\n" "|:-----|:------:|------:|\n" "| a | b | c |\n"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "*a*\n • Center: b\n • Right: c" in formatted
|
||||
|
||||
|
||||
def test_two_tables_in_same_message_use_independent_headers() -> None:
|
||||
message = (
|
||||
"| A | B |\n"
|
||||
"|---|---|\n"
|
||||
"| 1 | 2 |\n"
|
||||
"\n"
|
||||
"| X | Y | Z |\n"
|
||||
"|---|---|---|\n"
|
||||
"| p | q | r |\n"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "*1*\n • B: 2" in formatted
|
||||
assert "*p*\n • Y: q\n • Z: r" in formatted
|
||||
|
||||
|
||||
def test_table_empty_first_column_no_bare_asterisks() -> None:
|
||||
message = "| Name | Status |\n" "|------|--------|\n" "| | Done |\n"
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
# Empty title should not produce "**" (bare asterisks)
|
||||
assert "**" not in formatted
|
||||
assert " • Status: Done" in formatted
|
||||
|
||||
@@ -87,7 +87,8 @@ def test_python_tool_available_when_health_check_passes(
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.health.return_value = True
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is True
|
||||
@@ -109,7 +110,8 @@ def test_python_tool_unavailable_when_health_check_fails(
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.health.return_value = False
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
|
||||
@@ -138,7 +138,6 @@ services:
|
||||
- indexing_model_server
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
|
||||
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
- MULTI_TENANT=true
|
||||
- LOG_LEVEL=DEBUG
|
||||
|
||||
@@ -52,7 +52,6 @@ services:
|
||||
- indexing_model_server
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
|
||||
- AUTH_TYPE=${AUTH_TYPE:-oidc}
|
||||
- POSTGRES_HOST=relational_db
|
||||
- VESPA_HOST=index
|
||||
|
||||
@@ -65,7 +65,6 @@ services:
|
||||
- indexing_model_server
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
|
||||
- AUTH_TYPE=${AUTH_TYPE:-oidc}
|
||||
- POSTGRES_HOST=relational_db
|
||||
- VESPA_HOST=index
|
||||
|
||||
@@ -70,7 +70,6 @@ services:
|
||||
- indexing_model_server
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
|
||||
- AUTH_TYPE=${AUTH_TYPE:-oidc}
|
||||
- POSTGRES_HOST=relational_db
|
||||
- VESPA_HOST=index
|
||||
|
||||
@@ -58,7 +58,6 @@ services:
|
||||
env_file:
|
||||
- .env_eval
|
||||
environment:
|
||||
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
|
||||
- AUTH_TYPE=disabled
|
||||
- POSTGRES_HOST=relational_db
|
||||
- VESPA_HOST=index
|
||||
|
||||
@@ -146,7 +146,6 @@ services:
|
||||
- indexing_model_server
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
|
||||
- FILE_STORE_BACKEND=${FILE_STORE_BACKEND:-s3}
|
||||
- POSTGRES_HOST=${POSTGRES_HOST:-relational_db}
|
||||
- VESPA_HOST=${VESPA_HOST:-index}
|
||||
|
||||
@@ -14,30 +14,32 @@ Built with [Tauri](https://tauri.app) for minimal bundle size (~10MB vs Electron
|
||||
|
||||
## Keyboard Shortcuts
|
||||
|
||||
| Shortcut | Action |
|
||||
|----------|--------|
|
||||
| `⌘ N` | New Chat |
|
||||
| `⌘ ⇧ N` | New Window |
|
||||
| `⌘ R` | Reload |
|
||||
| `⌘ [` | Go Back |
|
||||
| `⌘ ]` | Go Forward |
|
||||
| `⌘ ,` | Open Config File |
|
||||
| `⌘ W` | Close Window |
|
||||
| `⌘ Q` | Quit |
|
||||
| Shortcut | Action |
|
||||
| -------- | ---------------- |
|
||||
| `⌘ N` | New Chat |
|
||||
| `⌘ ⇧ N` | New Window |
|
||||
| `⌘ R` | Reload |
|
||||
| `⌘ [` | Go Back |
|
||||
| `⌘ ]` | Go Forward |
|
||||
| `⌘ ,` | Open Config File |
|
||||
| `⌘ W` | Close Window |
|
||||
| `⌘ Q` | Quit |
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. **Rust** (latest stable)
|
||||
|
||||
```bash
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
|
||||
source $HOME/.cargo/env
|
||||
```
|
||||
|
||||
2. **Node.js** (18+)
|
||||
|
||||
```bash
|
||||
# Using homebrew
|
||||
brew install node
|
||||
|
||||
|
||||
# Or using nvm
|
||||
nvm install 18
|
||||
```
|
||||
@@ -55,16 +57,21 @@ npm install
|
||||
|
||||
# Run in development mode
|
||||
npm run dev
|
||||
|
||||
# Run in debug mode
|
||||
npm run debug
|
||||
```
|
||||
|
||||
## Building
|
||||
|
||||
### Build for current architecture
|
||||
|
||||
```bash
|
||||
npm run build
|
||||
```
|
||||
|
||||
### Build Universal Binary (Intel + Apple Silicon)
|
||||
|
||||
```bash
|
||||
# First, add the targets
|
||||
rustup target add x86_64-apple-darwin
|
||||
@@ -103,6 +110,7 @@ Before building, add your app icons to `src-tauri/icons/`:
|
||||
- `icon.ico` (Windows, optional)
|
||||
|
||||
You can generate these from a 1024x1024 source image using:
|
||||
|
||||
```bash
|
||||
# Using tauri's icon generator
|
||||
npm run tauri icon path/to/your-icon.png
|
||||
@@ -115,6 +123,7 @@ npm run tauri icon path/to/your-icon.png
|
||||
The app defaults to `https://cloud.onyx.app` but supports any Onyx instance.
|
||||
|
||||
**Config file location:**
|
||||
|
||||
- macOS: `~/Library/Application Support/app.onyx.desktop/config.json`
|
||||
- Linux: `~/.config/app.onyx.desktop/config.json`
|
||||
- Windows: `%APPDATA%/app.onyx.desktop/config.json`
|
||||
@@ -135,6 +144,7 @@ The app defaults to `https://cloud.onyx.app` but supports any Onyx instance.
|
||||
4. Restart the app
|
||||
|
||||
**Quick edit via terminal:**
|
||||
|
||||
```bash
|
||||
# macOS
|
||||
open -t ~/Library/Application\ Support/app.onyx.desktop/config.json
|
||||
@@ -146,6 +156,7 @@ code ~/Library/Application\ Support/app.onyx.desktop/config.json
|
||||
### Change the default URL in build
|
||||
|
||||
Edit `src-tauri/tauri.conf.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"app": {
|
||||
@@ -165,6 +176,7 @@ Edit `src-tauri/src/main.rs` in the `setup_shortcuts` function.
|
||||
### Window appearance
|
||||
|
||||
Modify the window configuration in `src-tauri/tauri.conf.json`:
|
||||
|
||||
- `titleBarStyle`: `"Overlay"` (macOS native) or `"Visible"`
|
||||
- `decorations`: Window chrome
|
||||
- `transparent`: For custom backgrounds
|
||||
@@ -172,16 +184,20 @@ Modify the window configuration in `src-tauri/tauri.conf.json`:
|
||||
## Troubleshooting
|
||||
|
||||
### "Unable to resolve host"
|
||||
|
||||
Make sure you have an internet connection. The app loads content from `cloud.onyx.app`.
|
||||
|
||||
### Build fails on M1/M2 Mac
|
||||
|
||||
```bash
|
||||
# Ensure you have the right target
|
||||
rustup target add aarch64-apple-darwin
|
||||
```
|
||||
|
||||
### Code signing for distribution
|
||||
|
||||
For distributing outside the App Store, you'll need to:
|
||||
|
||||
1. Get an Apple Developer certificate
|
||||
2. Sign the app: `codesign --deep --force --sign "Developer ID" target/release/bundle/macos/Onyx.app`
|
||||
3. Notarize with Apple
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"description": "Lightweight desktop app for Onyx Cloud",
|
||||
"scripts": {
|
||||
"dev": "tauri dev",
|
||||
"debug": "tauri dev -- -- --debug",
|
||||
"build": "tauri build",
|
||||
"build:dmg": "tauri build --target universal-apple-darwin",
|
||||
"build:linux": "tauri build --bundles deb,rpm"
|
||||
|
||||
@@ -23,3 +23,4 @@ url = "2.5"
|
||||
[features]
|
||||
default = ["custom-protocol"]
|
||||
custom-protocol = ["tauri/custom-protocol"]
|
||||
devtools = ["tauri/devtools"]
|
||||
|
||||
@@ -6,7 +6,9 @@ use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Command;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::{Mutex, RwLock};
|
||||
use std::io::Write as IoWrite;
|
||||
use std::time::SystemTime;
|
||||
#[cfg(target_os = "macos")]
|
||||
use std::time::Duration;
|
||||
use tauri::image::Image;
|
||||
@@ -230,6 +232,63 @@ const MENU_KEY_HANDLER_SCRIPT: &str = r#"
|
||||
})();
|
||||
"#;
|
||||
|
||||
const CONSOLE_CAPTURE_SCRIPT: &str = r#"
|
||||
(() => {
|
||||
if (window.__ONYX_CONSOLE_CAPTURE__) return;
|
||||
window.__ONYX_CONSOLE_CAPTURE__ = true;
|
||||
|
||||
const levels = ['log', 'warn', 'error', 'info', 'debug'];
|
||||
const originals = {};
|
||||
|
||||
levels.forEach(level => {
|
||||
originals[level] = console[level];
|
||||
console[level] = function(...args) {
|
||||
originals[level].apply(console, args);
|
||||
try {
|
||||
const invoke =
|
||||
window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke;
|
||||
if (typeof invoke === 'function') {
|
||||
const message = args.map(a => {
|
||||
try { return typeof a === 'string' ? a : JSON.stringify(a); }
|
||||
catch { return String(a); }
|
||||
}).join(' ');
|
||||
invoke('log_from_frontend', { level, message });
|
||||
}
|
||||
} catch {}
|
||||
};
|
||||
});
|
||||
|
||||
window.addEventListener('error', (event) => {
|
||||
try {
|
||||
const invoke =
|
||||
window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke;
|
||||
if (typeof invoke === 'function') {
|
||||
invoke('log_from_frontend', {
|
||||
level: 'error',
|
||||
message: `[uncaught] ${event.message} at ${event.filename}:${event.lineno}:${event.colno}`
|
||||
});
|
||||
}
|
||||
} catch {}
|
||||
});
|
||||
|
||||
window.addEventListener('unhandledrejection', (event) => {
|
||||
try {
|
||||
const invoke =
|
||||
window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke;
|
||||
if (typeof invoke === 'function') {
|
||||
invoke('log_from_frontend', {
|
||||
level: 'error',
|
||||
message: `[unhandled rejection] ${event.reason}`
|
||||
});
|
||||
}
|
||||
} catch {}
|
||||
});
|
||||
})();
|
||||
"#;
|
||||
|
||||
const MENU_TOGGLE_DEVTOOLS_ID: &str = "toggle_devtools";
|
||||
const MENU_OPEN_DEBUG_LOG_ID: &str = "open_debug_log";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AppConfig {
|
||||
pub server_url: String,
|
||||
@@ -311,12 +370,87 @@ fn save_config(config: &AppConfig) -> Result<(), String> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Debug Mode
|
||||
// ============================================================================
|
||||
|
||||
fn is_debug_mode() -> bool {
|
||||
std::env::args().any(|arg| arg == "--debug") || std::env::var("ONYX_DEBUG").is_ok()
|
||||
}
|
||||
|
||||
fn get_debug_log_path() -> Option<PathBuf> {
|
||||
get_config_dir().map(|dir| dir.join("frontend_debug.log"))
|
||||
}
|
||||
|
||||
fn init_debug_log_file() -> Option<fs::File> {
|
||||
let log_path = get_debug_log_path()?;
|
||||
if let Some(parent) = log_path.parent() {
|
||||
let _ = fs::create_dir_all(parent);
|
||||
}
|
||||
fs::OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&log_path)
|
||||
.ok()
|
||||
}
|
||||
|
||||
fn format_utc_timestamp() -> String {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap_or_default();
|
||||
let total_secs = now.as_secs();
|
||||
let millis = now.subsec_millis();
|
||||
|
||||
let days = total_secs / 86400;
|
||||
let secs_of_day = total_secs % 86400;
|
||||
let hours = secs_of_day / 3600;
|
||||
let mins = (secs_of_day % 3600) / 60;
|
||||
let secs = secs_of_day % 60;
|
||||
|
||||
// Days since Unix epoch -> Y/M/D via civil calendar arithmetic
|
||||
let z = days as i64 + 719468;
|
||||
let era = z / 146097;
|
||||
let doe = z - era * 146097;
|
||||
let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365;
|
||||
let y = yoe + era * 400;
|
||||
let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
|
||||
let mp = (5 * doy + 2) / 153;
|
||||
let d = doy - (153 * mp + 2) / 5 + 1;
|
||||
let m = if mp < 10 { mp + 3 } else { mp - 9 };
|
||||
let y = if m <= 2 { y + 1 } else { y };
|
||||
|
||||
format!(
|
||||
"{:04}-{:02}-{:02}T{:02}:{:02}:{:02}.{:03}Z",
|
||||
y, m, d, hours, mins, secs, millis
|
||||
)
|
||||
}
|
||||
|
||||
fn inject_console_capture(webview: &Webview) {
|
||||
let _ = webview.eval(CONSOLE_CAPTURE_SCRIPT);
|
||||
}
|
||||
|
||||
fn maybe_open_devtools(app: &AppHandle, window: &tauri::WebviewWindow) {
|
||||
#[cfg(any(debug_assertions, feature = "devtools"))]
|
||||
{
|
||||
let state = app.state::<ConfigState>();
|
||||
if state.debug_mode {
|
||||
window.open_devtools();
|
||||
}
|
||||
}
|
||||
#[cfg(not(any(debug_assertions, feature = "devtools")))]
|
||||
{
|
||||
let _ = (app, window);
|
||||
}
|
||||
}
|
||||
|
||||
// Global config state
|
||||
struct ConfigState {
|
||||
config: RwLock<AppConfig>,
|
||||
config_initialized: RwLock<bool>,
|
||||
app_base_url: RwLock<Option<Url>>,
|
||||
menu_temporarily_visible: RwLock<bool>,
|
||||
debug_mode: bool,
|
||||
debug_log_file: Mutex<Option<fs::File>>,
|
||||
}
|
||||
|
||||
fn focus_main_window(app: &AppHandle) {
|
||||
@@ -372,6 +506,7 @@ fn trigger_new_window(app: &AppHandle) {
|
||||
}
|
||||
|
||||
apply_settings_to_window(&handle, &window);
|
||||
maybe_open_devtools(&handle, &window);
|
||||
let _ = window.set_focus();
|
||||
}
|
||||
});
|
||||
@@ -467,10 +602,65 @@ fn inject_chat_link_intercept(webview: &Webview) {
|
||||
let _ = webview.eval(CHAT_LINK_INTERCEPT_SCRIPT);
|
||||
}
|
||||
|
||||
fn handle_toggle_devtools(app: &AppHandle) {
|
||||
#[cfg(any(debug_assertions, feature = "devtools"))]
|
||||
{
|
||||
let windows: Vec<_> = app.webview_windows().into_values().collect();
|
||||
let any_open = windows.iter().any(|w| w.is_devtools_open());
|
||||
for window in &windows {
|
||||
if any_open {
|
||||
window.close_devtools();
|
||||
} else {
|
||||
window.open_devtools();
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(not(any(debug_assertions, feature = "devtools")))]
|
||||
{
|
||||
let _ = app;
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_open_debug_log() {
|
||||
let log_path = match get_debug_log_path() {
|
||||
Some(p) => p,
|
||||
None => return,
|
||||
};
|
||||
|
||||
if !log_path.exists() {
|
||||
eprintln!("[ONYX DEBUG] Log file does not exist yet: {:?}", log_path);
|
||||
return;
|
||||
}
|
||||
|
||||
let url_path = log_path.to_string_lossy().replace('\\', "/");
|
||||
let _ = open_in_default_browser(&format!(
|
||||
"file:///{}",
|
||||
url_path.trim_start_matches('/')
|
||||
));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tauri Commands
|
||||
// ============================================================================
|
||||
|
||||
#[tauri::command]
|
||||
fn log_from_frontend(level: String, message: String, state: tauri::State<ConfigState>) {
|
||||
if !state.debug_mode {
|
||||
return;
|
||||
}
|
||||
let timestamp = format_utc_timestamp();
|
||||
let log_line = format!("[{}] [{}] {}", timestamp, level.to_uppercase(), message);
|
||||
|
||||
eprintln!("{}", log_line);
|
||||
|
||||
if let Ok(mut guard) = state.debug_log_file.lock() {
|
||||
if let Some(ref mut file) = *guard {
|
||||
let _ = writeln!(file, "{}", log_line);
|
||||
let _ = file.flush();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the current server URL
|
||||
#[tauri::command]
|
||||
fn get_server_url(state: tauri::State<ConfigState>) -> String {
|
||||
@@ -657,6 +847,7 @@ async fn new_window(app: AppHandle, state: tauri::State<'_, ConfigState>) -> Res
|
||||
}
|
||||
|
||||
apply_settings_to_window(&app, &window);
|
||||
maybe_open_devtools(&app, &window);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -936,6 +1127,30 @@ fn setup_app_menu(app: &AppHandle) -> tauri::Result<()> {
|
||||
menu.append(&help_menu)?;
|
||||
}
|
||||
|
||||
let state = app.state::<ConfigState>();
|
||||
if state.debug_mode {
|
||||
let toggle_devtools_item = MenuItem::with_id(
|
||||
app,
|
||||
MENU_TOGGLE_DEVTOOLS_ID,
|
||||
"Toggle DevTools",
|
||||
true,
|
||||
Some("F12"),
|
||||
)?;
|
||||
let open_log_item = MenuItem::with_id(
|
||||
app,
|
||||
MENU_OPEN_DEBUG_LOG_ID,
|
||||
"Open Debug Log",
|
||||
true,
|
||||
None::<&str>,
|
||||
)?;
|
||||
|
||||
let debug_menu = SubmenuBuilder::new(app, "Debug")
|
||||
.item(&toggle_devtools_item)
|
||||
.item(&open_log_item)
|
||||
.build()?;
|
||||
menu.append(&debug_menu)?;
|
||||
}
|
||||
|
||||
app.set_menu(menu)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1027,8 +1242,20 @@ fn setup_tray_icon(app: &AppHandle) -> tauri::Result<()> {
|
||||
// ============================================================================
|
||||
|
||||
fn main() {
|
||||
// Load config at startup
|
||||
let (config, config_initialized) = load_config();
|
||||
let debug_mode = is_debug_mode();
|
||||
|
||||
let debug_log_file = if debug_mode {
|
||||
eprintln!("[ONYX DEBUG] Debug mode enabled");
|
||||
if let Some(path) = get_debug_log_path() {
|
||||
eprintln!("[ONYX DEBUG] Frontend logs: {}", path.display());
|
||||
}
|
||||
eprintln!("[ONYX DEBUG] DevTools will open automatically");
|
||||
eprintln!("[ONYX DEBUG] Capturing console.log/warn/error/info/debug from webview");
|
||||
init_debug_log_file()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
tauri::Builder::default()
|
||||
.plugin(tauri_plugin_shell::init())
|
||||
@@ -1059,6 +1286,8 @@ fn main() {
|
||||
config_initialized: RwLock::new(config_initialized),
|
||||
app_base_url: RwLock::new(None),
|
||||
menu_temporarily_visible: RwLock::new(false),
|
||||
debug_mode,
|
||||
debug_log_file: Mutex::new(debug_log_file),
|
||||
})
|
||||
.invoke_handler(tauri::generate_handler![
|
||||
get_server_url,
|
||||
@@ -1077,7 +1306,8 @@ fn main() {
|
||||
start_drag_window,
|
||||
toggle_menu_bar,
|
||||
show_menu_bar_temporarily,
|
||||
hide_menu_bar_temporary
|
||||
hide_menu_bar_temporary,
|
||||
log_from_frontend
|
||||
])
|
||||
.on_menu_event(|app, event| match event.id().as_ref() {
|
||||
"open_docs" => open_docs(),
|
||||
@@ -1086,6 +1316,8 @@ fn main() {
|
||||
"open_settings" => open_settings(app),
|
||||
"show_menu_bar" => handle_menu_bar_toggle(app),
|
||||
"hide_window_decorations" => handle_decorations_toggle(app),
|
||||
MENU_TOGGLE_DEVTOOLS_ID => handle_toggle_devtools(app),
|
||||
MENU_OPEN_DEBUG_LOG_ID => handle_open_debug_log(),
|
||||
_ => {}
|
||||
})
|
||||
.setup(move |app| {
|
||||
@@ -1119,6 +1351,7 @@ fn main() {
|
||||
inject_titlebar(window.clone());
|
||||
|
||||
apply_settings_to_window(&app_handle, &window);
|
||||
maybe_open_devtools(&app_handle, &window);
|
||||
|
||||
let _ = window.set_focus();
|
||||
}
|
||||
@@ -1128,6 +1361,14 @@ fn main() {
|
||||
.on_page_load(|webview: &Webview, _payload: &PageLoadPayload| {
|
||||
inject_chat_link_intercept(webview);
|
||||
|
||||
{
|
||||
let app = webview.app_handle();
|
||||
let state = app.state::<ConfigState>();
|
||||
if state.debug_mode {
|
||||
inject_console_capture(webview);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
let _ = webview.eval(MENU_KEY_HANDLER_SCRIPT);
|
||||
|
||||
6
uv.lock
generated
6
uv.lock
generated
@@ -4106,7 +4106,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "nltk"
|
||||
version = "3.9.1"
|
||||
version = "3.9.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
@@ -4114,9 +4114,9 @@ dependencies = [
|
||||
{ name = "regex" },
|
||||
{ name = "tqdm" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/3c/87/db8be88ad32c2d042420b6fd9ffd4a149f9a0d7f0e86b3f543be2eeeedd2/nltk-3.9.1.tar.gz", hash = "sha256:87d127bd3de4bd89a4f81265e5fa59cb1b199b27440175370f7417d2bc7ae868", size = 2904691, upload-time = "2024-08-18T19:48:37.769Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e1/8f/915e1c12df07c70ed779d18ab83d065718a926e70d3ea33eb0cd66ffb7c0/nltk-3.9.3.tar.gz", hash = "sha256:cb5945d6424a98d694c2b9a0264519fab4363711065a46aa0ae7a2195b92e71f", size = 2923673, upload-time = "2026-02-24T12:05:53.833Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/66/7d9e26593edda06e8cb531874633f7c2372279c3b0f46235539fe546df8b/nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1", size = 1505442, upload-time = "2024-08-18T19:48:21.909Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/7e/9af5a710a1236e4772de8dfcc6af942a561327bb9f42b5b4a24d0cf100fd/nltk-3.9.3-py3-none-any.whl", hash = "sha256:60b3db6e9995b3dd976b1f0fa7dec22069b2677e759c28eb69b62ddd44870522", size = 1525385, upload-time = "2026-02-24T12:05:46.54Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
20
web/lib/opal/src/icons/audio.tsx
Normal file
20
web/lib/opal/src/icons/audio.tsx
Normal file
@@ -0,0 +1,20 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgAudio = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 32 32"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M4 20V12M10 28V4M22 22V10M28 18V14M16 20V12"
|
||||
strokeWidth={2.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgAudio;
|
||||
22
web/lib/opal/src/icons/column.tsx
Normal file
22
web/lib/opal/src/icons/column.tsx
Normal file
@@ -0,0 +1,22 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgColumn = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M6 14H3.33333C2.59695 14 2 13.403 2 12.6667V3.33333C2 2.59695 2.59695 2 3.33333 2H6M6 14V2M6 14H10M6 2H10M10 2H12.6667C13.403 2 14 2.59695 14 3.33333V12.6667C14 13.403 13.403 14 12.6667 14H10M10 2V14"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
export default SvgColumn;
|
||||
20
web/lib/opal/src/icons/handle.tsx
Normal file
20
web/lib/opal/src/icons/handle.tsx
Normal file
@@ -0,0 +1,20 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgHandle = ({ size = 16, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={Math.round((size * 3) / 17)}
|
||||
height={size}
|
||||
viewBox="0 0 3 17"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M0.5 0.5V16.5M2.5 0.5V16.5"
|
||||
stroke="currentColor"
|
||||
strokeLinecap="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
export default SvgHandle;
|
||||
@@ -17,6 +17,7 @@ export { default as SvgArrowUpDown } from "@opal/icons/arrow-up-down";
|
||||
export { default as SvgArrowUpDot } from "@opal/icons/arrow-up-dot";
|
||||
export { default as SvgArrowUpRight } from "@opal/icons/arrow-up-right";
|
||||
export { default as SvgArrowWallRight } from "@opal/icons/arrow-wall-right";
|
||||
export { default as SvgAudio } from "@opal/icons/audio";
|
||||
export { default as SvgAudioEqSmall } from "@opal/icons/audio-eq-small";
|
||||
export { default as SvgAws } from "@opal/icons/aws";
|
||||
export { default as SvgAzure } from "@opal/icons/azure";
|
||||
@@ -49,6 +50,7 @@ export { default as SvgClock } from "@opal/icons/clock";
|
||||
export { default as SvgClockHandsSmall } from "@opal/icons/clock-hands-small";
|
||||
export { default as SvgCloud } from "@opal/icons/cloud";
|
||||
export { default as SvgCode } from "@opal/icons/code";
|
||||
export { default as SvgColumn } from "@opal/icons/column";
|
||||
export { default as SvgCopy } from "@opal/icons/copy";
|
||||
export { default as SvgCornerRightUpDot } from "@opal/icons/corner-right-up-dot";
|
||||
export { default as SvgCpu } from "@opal/icons/cpu";
|
||||
@@ -79,6 +81,7 @@ export { default as SvgFolderPartialOpen } from "@opal/icons/folder-partial-open
|
||||
export { default as SvgFolderPlus } from "@opal/icons/folder-plus";
|
||||
export { default as SvgGemini } from "@opal/icons/gemini";
|
||||
export { default as SvgGlobe } from "@opal/icons/globe";
|
||||
export { default as SvgHandle } from "@opal/icons/handle";
|
||||
export { default as SvgHardDrive } from "@opal/icons/hard-drive";
|
||||
export { default as SvgHashSmall } from "@opal/icons/hash-small";
|
||||
export { default as SvgHash } from "@opal/icons/hash";
|
||||
@@ -103,6 +106,8 @@ export { default as SvgLogOut } from "@opal/icons/log-out";
|
||||
export { default as SvgMaximize2 } from "@opal/icons/maximize-2";
|
||||
export { default as SvgMcp } from "@opal/icons/mcp";
|
||||
export { default as SvgMenu } from "@opal/icons/menu";
|
||||
export { default as SvgMicrophone } from "@opal/icons/microphone";
|
||||
export { default as SvgMicrophoneOff } from "@opal/icons/microphone-off";
|
||||
export { default as SvgMinus } from "@opal/icons/minus";
|
||||
export { default as SvgMinusCircle } from "@opal/icons/minus-circle";
|
||||
export { default as SvgMoon } from "@opal/icons/moon";
|
||||
@@ -146,6 +151,8 @@ export { default as SvgSlack } from "@opal/icons/slack";
|
||||
export { default as SvgSlash } from "@opal/icons/slash";
|
||||
export { default as SvgSliders } from "@opal/icons/sliders";
|
||||
export { default as SvgSlidersSmall } from "@opal/icons/sliders-small";
|
||||
export { default as SvgSort } from "@opal/icons/sort";
|
||||
export { default as SvgSortOrder } from "@opal/icons/sort-order";
|
||||
export { default as SvgSparkle } from "@opal/icons/sparkle";
|
||||
export { default as SvgStar } from "@opal/icons/star";
|
||||
export { default as SvgStep1 } from "@opal/icons/step1";
|
||||
@@ -169,7 +176,10 @@ export { default as SvgUploadCloud } from "@opal/icons/upload-cloud";
|
||||
export { default as SvgUser } from "@opal/icons/user";
|
||||
export { default as SvgUserManage } from "@opal/icons/user-manage";
|
||||
export { default as SvgUserPlus } from "@opal/icons/user-plus";
|
||||
export { default as SvgUserSync } from "@opal/icons/user-sync";
|
||||
export { default as SvgUsers } from "@opal/icons/users";
|
||||
export { default as SvgVolume } from "@opal/icons/volume";
|
||||
export { default as SvgVolumeOff } from "@opal/icons/volume-off";
|
||||
export { default as SvgWallet } from "@opal/icons/wallet";
|
||||
export { default as SvgWorkflow } from "@opal/icons/workflow";
|
||||
export { default as SvgX } from "@opal/icons/x";
|
||||
|
||||
29
web/lib/opal/src/icons/microphone-off.tsx
Normal file
29
web/lib/opal/src/icons/microphone-off.tsx
Normal file
@@ -0,0 +1,29 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgMicrophoneOff = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
{/* Microphone body */}
|
||||
<path
|
||||
d="M12.5 7V7.5C12.5 9.98528 10.4853 12 8 12M3.5 7V7.5C3.5 9.98528 5.51472 12 8 12M8 12V14.5M8 14.5H5M8 14.5H11M8 9.5C6.89543 9.5 6 8.60457 6 7.5V3.5C6 2.39543 6.89543 1.5 8 1.5C9.10457 1.5 10 2.39543 10 3.5V7.5C10 8.60457 9.10457 9.5 8 9.5Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
{/* Diagonal slash */}
|
||||
<path
|
||||
d="M2 2L14 14"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgMicrophoneOff;
|
||||
21
web/lib/opal/src/icons/microphone.tsx
Normal file
21
web/lib/opal/src/icons/microphone.tsx
Normal file
@@ -0,0 +1,21 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgMicrophone = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M12.5 7V7.5C12.5 9.98528 10.4853 12 8 12M3.5 7V7.5C3.5 9.98528 5.51472 12 8 12M8 12V14.5M8 14.5H5M8 14.5H11M8 9.5C6.89543 9.5 6 8.60457 6 7.5V3.5C6 2.39543 6.89543 1.5 8 1.5C9.10457 1.5 10 2.39543 10 3.5V7.5C10 8.60457 9.10457 9.5 8 9.5Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgMicrophone;
|
||||
21
web/lib/opal/src/icons/sort-order.tsx
Normal file
21
web/lib/opal/src/icons/sort-order.tsx
Normal file
@@ -0,0 +1,21 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgSortOrder = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M2.66675 12L7.67009 12.0001M2.66675 8H10.5001M2.66675 4H13.3334"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgSortOrder;
|
||||
27
web/lib/opal/src/icons/sort.tsx
Normal file
27
web/lib/opal/src/icons/sort.tsx
Normal file
@@ -0,0 +1,27 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgSort = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M2 4.5H10M2 8H7M2 11.5H5"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M12 5V12M12 12L14 10M12 12L10 10"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgSort;
|
||||
22
web/lib/opal/src/icons/user-sync.tsx
Normal file
22
web/lib/opal/src/icons/user-sync.tsx
Normal file
@@ -0,0 +1,22 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
|
||||
const SvgUserSync = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M1 14C1 13.6667 1 13.3333 1 13C1 11.3431 2.34316 10 4.00002 10H7M11 8.5L9.5 10L14.5 9.99985M13 14L14.5 12.5L9.5 12.5M8.75 4.75C8.75 6.26878 7.51878 7.5 6 7.5C4.48122 7.5 3.25 6.26878 3.25 4.75C3.25 3.23122 4.48122 2 6 2C7.51878 2 8.75 3.23122 8.75 4.75Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
|
||||
export default SvgUserSync;
|
||||
26
web/lib/opal/src/icons/volume-off.tsx
Normal file
26
web/lib/opal/src/icons/volume-off.tsx
Normal file
@@ -0,0 +1,26 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgVolumeOff = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M2 6V10H5L9 13V3L5 6H2Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M14 6L11 9M11 6L14 9"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgVolumeOff;
|
||||
26
web/lib/opal/src/icons/volume.tsx
Normal file
26
web/lib/opal/src/icons/volume.tsx
Normal file
@@ -0,0 +1,26 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgVolume = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M2 6V10H5L9 13V3L5 6H2Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M11.5 5.5C12.3 6.3 12.8 7.4 12.8 8.5C12.8 9.6 12.3 10.7 11.5 11.5"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
export default SvgVolume;
|
||||
@@ -58,7 +58,7 @@ const nextConfig = {
|
||||
{
|
||||
key: "Permissions-Policy",
|
||||
value:
|
||||
"accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), cross-origin-isolated=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), navigation-override=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()",
|
||||
"accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), cross-origin-isolated=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(self), midi=(), navigation-override=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()",
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
8
web/package-lock.json
generated
8
web/package-lock.json
generated
@@ -105,6 +105,7 @@
|
||||
"@types/node": "18.15.11",
|
||||
"@types/react": "19.2.10",
|
||||
"@types/react-dom": "19.2.3",
|
||||
"@types/sbd": "^1.0.5",
|
||||
"@types/stats.js": "^0.17.4",
|
||||
"@types/uuid": "^9.0.8",
|
||||
"@typescript/native-preview": "7.0.0-dev.20251222.1",
|
||||
@@ -5543,6 +5544,13 @@
|
||||
"integrity": "sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@types/sbd": {
|
||||
"version": "1.0.5",
|
||||
"resolved": "https://registry.npmjs.org/@types/sbd/-/sbd-1.0.5.tgz",
|
||||
"integrity": "sha512-60PxBBWhg0C3yb5bTP+wwWYGTKMcuB0S6mTEa1sedMC79tYY0Ei7YjU4qsWzGn++lWscLQde16SnElJrf5/aTw==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@types/stack-utils": {
|
||||
"version": "2.0.3",
|
||||
"dev": true,
|
||||
|
||||
@@ -121,6 +121,7 @@
|
||||
"@types/node": "18.15.11",
|
||||
"@types/react": "19.2.10",
|
||||
"@types/react-dom": "19.2.3",
|
||||
"@types/sbd": "^1.0.5",
|
||||
"@types/stats.js": "^0.17.4",
|
||||
"@types/uuid": "^9.0.8",
|
||||
"@typescript/native-preview": "7.0.0-dev.20251222.1",
|
||||
|
||||
@@ -41,7 +41,7 @@ export default defineConfig({
|
||||
viewport: { width: 1280, height: 720 },
|
||||
storageState: "admin_auth.json",
|
||||
},
|
||||
grepInvert: /@exclusive/,
|
||||
grepInvert: [/@exclusive/, /@lite/],
|
||||
},
|
||||
{
|
||||
// this suite runs independently and serially + slower
|
||||
@@ -55,5 +55,15 @@ export default defineConfig({
|
||||
grep: /@exclusive/,
|
||||
workers: 1,
|
||||
},
|
||||
{
|
||||
// runs against the Onyx Lite stack (DISABLE_VECTOR_DB=true, no Vespa/Redis)
|
||||
name: "lite",
|
||||
use: {
|
||||
...devices["Desktop Chrome"],
|
||||
viewport: { width: 1280, height: 720 },
|
||||
storageState: "admin_auth.json",
|
||||
},
|
||||
grep: /@lite/,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
4
web/public/ElevenLabs.svg
Normal file
4
web/public/ElevenLabs.svg
Normal file
@@ -0,0 +1,4 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M10.5 2H13V14H10.5V2Z" fill="black"/>
|
||||
<path d="M3 2H5.5V14H3V2Z" fill="black"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 192 B |
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user