Compare commits

51 Commits

Author SHA1 Message Date
Jessica Singh
8fead4dfbf jest test 2026-03-05 21:55:58 -08:00
Jessica Singh
ac4b49a7f9 sync to autoplayback text generation 2026-03-05 21:50:07 -08:00
Jessica Singh
fc22232f14 fix: mypy type errors in websocket_api 2026-03-05 17:41:15 -08:00
Jessica Singh
9ddd44bf56 fix: prevent UnboundLocalError in TTS fallback 2026-03-05 17:27:00 -08:00
Jessica Singh
8587911cf6 fix: critical bugs from PR review 2026-03-05 17:25:52 -08:00
Jessica Singh
d7300d50d7 fix: address PR review feedback 2026-03-05 17:19:21 -08:00
Jessica Singh
cc950a2da2 fix: session lifecycle, atomic WS token, and clearable voice prefs 2026-03-05 16:32:38 -08:00
Jessica Singh
8d6640159a fix: address PR bot review feedback (security, dead code, logging) 2026-03-05 16:09:31 -08:00
Jessica Singh
bba77749c3 fix: position recording bar above input on subsequent turns 2026-03-05 14:33:23 -08:00
Jessica Singh
3e9a66c8ff chore: add @types/sbd for TypeScript support 2026-03-05 14:25:12 -08:00
Jessica Singh
548b9d9e0e fix: remove unused type ignore in azure.py 2026-03-05 14:13:42 -08:00
Jessica Singh
0d3967baee mypy 2026-03-05 13:50:27 -08:00
Jessica Singh
6ed806eebb migration 2026-03-05 09:32:35 -08:00
Jessica Singh
3b6a35b2c4 recording bar + bug fixes 2026-03-04 21:55:42 -08:00
Jessica Singh
62e612f85f Merge branch 'main' into voice-mode 2026-03-04 21:25:37 -08:00
Jamison Lahman
7eabfa125c fix(fe): properly wrap copy and edit buttons on mobile (#9073) 2026-03-05 04:36:11 +00:00
SubashMohan
ee18114739 feat(table): add DataTable config-driven wrapper component (#9020)
Co-authored-by: Nik <nikolas.garza5@gmail.com>
2026-03-05 04:21:38 +00:00
Jessica Singh
b375b7f0ff azure 2026-03-04 20:14:18 -08:00
Nikolas Garza
f7630f5648 fix: EE route gating for upgrading CE users (#9026) 2026-03-05 03:44:16 +00:00
Jamison Lahman
e0d91b9ea7 chore(fe): rm unreachable code (#9069) 2026-03-05 03:26:50 +00:00
Raunak Bhagat
2c0a4a60a5 refactor: consolidate AppInputBar search/chat rendering with animated transitions (#9054) 2026-03-05 03:16:36 +00:00
Justin Tahara
3a7d4dad56 fix(ui): Improve text truncation and overflow handling in FileCard layout (#9061) 2026-03-05 03:11:53 +00:00
acaprau
c5c236d098 chore(opensearch): Fix and consolidate the dev script used to start OpenSearch locally (#9036) 2026-03-05 01:54:02 +00:00
Danelegend
b18baff4d0 fix: Correct file_id for docs (#9058) 2026-03-05 01:43:58 +00:00
SubashMohan
eb3e15c195 feat(table): add ColumnVisibilityPopover, Footer, Pagination, and SortingPopover components (#9019)
Co-authored-by: Nik <nikolas.garza5@gmail.com>
2026-03-05 01:43:37 +00:00
acaprau
47d9a9e1ac feat(document index): Re-enable search settings swap (#9005) 2026-03-05 01:41:03 +00:00
Evan Lohn
aca466b35d fix: doc to hierarchynode connection in pruning (#9046) 2026-03-05 01:30:36 +00:00
Jessica Singh
c158ae2622 remove logs 2026-03-04 17:02:44 -08:00
Justin Tahara
5176fd7386 fix(llm): Final LLM Cleanup for Nightly Tests (#9055) 2026-03-05 01:00:45 +00:00
Jessica Singh
698494626f eleven labs and bug fixes 2026-03-04 16:40:26 -08:00
SubashMohan
92538084e9 feat(table): add useColumnWidths, useDataTable, and useDraggableRows hooks (#9018)
Co-authored-by: Nik <nikolas.garza5@gmail.com>
2026-03-05 00:00:06 +00:00
Bo-Onyx
2d996e05a4 chore(fe): opal button migration (#8864) 2026-03-04 22:52:49 +00:00
Nikolas Garza
b2956f795b refactor: migrate LLM & embedding management to OnyxError (#9025) 2026-03-04 22:09:25 +00:00
Danelegend
b272085543 fix: Code Interpreter Client session clean up (#9028)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-04 21:58:00 +00:00
Justin Tahara
8193aa4fd0 fix(ui): Persist agent sharing changes immediately for existing agents (#9024) 2026-03-04 21:34:50 +00:00
dependabot[bot]
52db41a00b chore(deps): bump nltk from 3.9.1 to 3.9.3 (#9045)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-03-04 21:21:37 +00:00
SubashMohan
f1cf3c4589 feat(table): add table primitive components and styles (#9017) 2026-03-04 21:06:53 +00:00
dependabot[bot]
5322aeed90 chore(deps): bump hono from 4.11.7 to 4.12.5 in /backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web (#9044)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-04 12:51:05 -08:00
Evan Lohn
5da8870fd2 fix: stop calling unsupported endpoints no vectordb (#9012) 2026-03-04 20:18:09 +00:00
Nikolas Garza
57d3ab3b40 feat: add SCIM token management page (#9001) 2026-03-04 19:48:37 +00:00
Nikolas Garza
649c7fe8b9 feat(slack): convert markdown tables to Slack-friendly format (#8999) 2026-03-04 19:16:50 +00:00
Jamison Lahman
e5e2bc6149 chore(fe): "Share Chat"->"Share" (#9022) 2026-03-04 11:08:14 -08:00
Jamison Lahman
b148065e1d chore(devtools): --debug mode for desktop (#9027) 2026-03-04 11:07:52 -08:00
Evan Lohn
367808951c chore: remove lightweight mode (#9014) 2026-03-04 18:26:05 +00:00
Jessica Singh
93cefe7ef0 chore: trigger Greptile review 2026-03-03 23:04:53 -08:00
Jessica Singh
8a326c4089 address greptile review feedback (greploop iteration 2)
- Narrow WebSocket auth bypass to only voice endpoints in auth_check.py
- Add query param validation (max_length, ge/le) for TTS synthesize endpoint
- Fix ObjectURL memory leak in useVoicePlayback.ts

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-03-03 22:39:12 -08:00
Jessica Singh
0c5410f429 address greptile review feedback (greploop iteration 1)
- Add WebSocket authentication to /voice/transcribe/stream and /voice/synthesize/stream endpoints
- Fix useVoicePlayback.ts to use query params instead of JSON body (matches API signature)
- Fix delete_voice_provider to use flush() instead of commit() for consistency
- Disable Azure streaming STT until audio resampling is implemented
- Add SSML escaping to prevent injection in Azure TTS
- Remove debug console.log statements from voice components
- Fix blob URL memory leak in VoiceModeProvider

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-03-03 22:31:59 -08:00
Jessica Singh
0b05b9b235 fix(voice): move error toast to useEffect 2026-03-03 17:26:20 -08:00
Jessica Singh
59d8a988bd streaming tts 2026-03-03 03:37:18 -08:00
Jessica Singh
6d08cfb25a all changes 2026-03-02 13:16:39 -08:00
Jessica Singh
53a5ee2a6e stt and tts 2026-02-23 18:27:37 -08:00
376 changed files with 16761 additions and 2398 deletions

View File

@@ -335,7 +335,6 @@ jobs:
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
EOF
fi

View File

@@ -268,10 +268,11 @@ jobs:
persist-credentials: false
- name: Setup node
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm"
cache: "npm" # zizmor: ignore[cache-poisoning]
cache-dependency-path: ./web/package-lock.json
- name: Install node dependencies
@@ -279,6 +280,7 @@ jobs:
run: npm ci
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
@@ -590,6 +592,108 @@ jobs:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
playwright-tests-lite:
needs: [build-web-image, build-backend-image]
name: Playwright Tests (lite)
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-playwright-tests-lite"
- "extras=ecr-cache"
timeout-minutes: 30
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Setup node
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm" # zizmor: ignore[cache-poisoning]
cache-dependency-path: ./web/package-lock.json
- name: Install node dependencies
working-directory: ./web
run: npm ci
- name: Cache playwright cache
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: ~/.cache/ms-playwright
key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
restore-keys: |
${{ runner.os }}-playwright-npm-
- name: Install playwright browsers
working-directory: ./web
run: npx playwright install --with-deps
- name: Create .env file for Docker Compose
env:
OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
INTEGRATION_TESTS_MODE=true
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
MOCK_LLM_RESPONSE=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
EOF
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Start Docker containers (lite)
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up -d
id: start_docker
- name: Run Playwright tests (lite)
working-directory: ./web
run: npx playwright test --project lite
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
if: always()
with:
name: playwright-test-results-lite-${{ github.run_id }}
path: ./web/output/playwright/
retention-days: 30
- name: Save Docker logs
if: success() || failure()
env:
WORKSPACE: ${{ github.workspace }}
run: |
cd deployment/docker_compose
docker compose logs > docker-compose.log
mv docker-compose.log ${WORKSPACE}/docker-compose.log
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-logs-lite-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
# Post a single combined visual regression comment after all matrix jobs finish
visual-regression-comment:
needs: [playwright-tests]
@@ -686,7 +790,7 @@ jobs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [playwright-tests]
needs: [playwright-tests, playwright-tests-lite]
if: ${{ always() }}
steps:
- name: Check job status

58
.vscode/launch.json vendored
View File

@@ -40,19 +40,7 @@
}
},
{
"name": "Celery (lightweight mode)",
"configurations": [
"Celery primary",
"Celery background",
"Celery beat"
],
"presentation": {
"group": "1"
},
"stopAll": true
},
{
"name": "Celery (standard mode)",
"name": "Celery",
"configurations": [
"Celery primary",
"Celery light",
@@ -253,35 +241,6 @@
},
"consoleTitle": "Celery light Console"
},
{
"name": "Celery background",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "INFO",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-A",
"onyx.background.celery.versioned_apps.background",
"worker",
"--pool=threads",
"--concurrency=20",
"--prefetch-multiplier=4",
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
],
"presentation": {
"group": "2"
},
"consoleTitle": "Celery background Console"
},
{
"name": "Celery heavy",
"type": "debugpy",
@@ -526,21 +485,6 @@
"group": "3"
}
},
{
"name": "Clear and Restart OpenSearch Container",
// Generic debugger type, required arg but has no bearing on bash.
"type": "node",
"request": "launch",
"runtimeExecutable": "bash",
"runtimeArgs": [
"${workspaceFolder}/backend/scripts/restart_opensearch_container.sh"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "3"
}
},
{
"name": "Eval CLI",
"type": "debugpy",

View File

@@ -86,37 +86,6 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
- Monitoring tasks (every 5 minutes)
- Cleanup tasks (hourly)
#### Worker Deployment Modes
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
- Runs a single consolidated `background` worker that handles all background tasks:
- Light worker tasks (Vespa operations, permissions sync, deletion)
- Document processing (indexing pipeline)
- Document fetching (connector data retrieval)
- Pruning operations (from `heavy` worker)
- Knowledge graph processing (from `kg_processing` worker)
- Monitoring tasks (from `monitoring` worker)
- User file processing (from `user_file_processing` worker)
- Lower resource footprint (fewer worker processes)
- Suitable for smaller deployments or development environments
- Default concurrency: 20 threads (increased to handle combined workload)
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
- Runs separate specialized workers as documented above (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
- Better isolation and scalability
- Can scale individual workers independently based on workload
- Suitable for production deployments with higher load
The deployment mode affects:
- **Backend**: Worker processes spawned by supervisord or dev scripts
- **Helm**: Which Kubernetes deployments are created
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
#### Key Features
- **Thread-based Workers**: All workers use thread pools (not processes) for stability

View File

@@ -0,0 +1,119 @@
"""add_voice_provider_and_user_voice_prefs
Revision ID: 93a2e195e25c
Revises: a3b8d9e2f1c4
Create Date: 2026-02-23 15:16:39.507304
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "93a2e195e25c"
down_revision = "a3b8d9e2f1c4"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create voice_provider table
op.create_table(
"voice_provider",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("name", sa.String(), unique=True, nullable=False),
sa.Column("provider_type", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("api_base", sa.String(), nullable=True),
sa.Column("custom_config", postgresql.JSONB(), nullable=True),
sa.Column("stt_model", sa.String(), nullable=True),
sa.Column("tts_model", sa.String(), nullable=True),
sa.Column("default_voice", sa.String(), nullable=True),
sa.Column(
"is_default_stt", sa.Boolean(), nullable=False, server_default="false"
),
sa.Column(
"is_default_tts", sa.Boolean(), nullable=False, server_default="false"
),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
onupdate=sa.func.now(),
nullable=False,
),
)
# Add partial unique indexes to enforce only one default STT/TTS provider
op.execute(
"""
CREATE UNIQUE INDEX ix_voice_provider_one_default_stt
ON voice_provider (is_default_stt)
WHERE is_default_stt = true
"""
)
op.execute(
"""
CREATE UNIQUE INDEX ix_voice_provider_one_default_tts
ON voice_provider (is_default_tts)
WHERE is_default_tts = true
"""
)
# Add voice preference columns to user table
op.add_column(
"user",
sa.Column(
"voice_auto_send",
sa.Boolean(),
default=False,
nullable=False,
server_default="false",
),
)
op.add_column(
"user",
sa.Column(
"voice_auto_playback",
sa.Boolean(),
default=False,
nullable=False,
server_default="false",
),
)
op.add_column(
"user",
sa.Column(
"voice_playback_speed",
sa.Float(),
default=1.0,
nullable=False,
server_default="1.0",
),
)
op.add_column(
"user",
sa.Column("preferred_voice", sa.String(), nullable=True),
)
def downgrade() -> None:
# Remove user voice preference columns
op.drop_column("user", "preferred_voice")
op.drop_column("user", "voice_playback_speed")
op.drop_column("user", "voice_auto_playback")
op.drop_column("user", "voice_auto_send")
op.execute("DROP INDEX IF EXISTS ix_voice_provider_one_default_tts")
op.execute("DROP INDEX IF EXISTS ix_voice_provider_one_default_stt")
# Drop voice_provider table
op.drop_table("voice_provider")
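
The two partial unique indexes in this migration allow at most one row with `is_default_stt = true` (and likewise one with `is_default_tts = true`) while leaving any number of non-default rows. A minimal sketch of that behaviour, assuming SQLite through the standard-library `sqlite3` module as a stand-in for PostgreSQL; the table is a cut-down, hypothetical version of `voice_provider`, and `add_provider` is an illustrative helper that is not part of the PR:

```python
import sqlite3

conn = sqlite3.connect(":memory:")

# Cut-down stand-in for the voice_provider table (assumption: only the
# columns needed to show the constraint are included).
conn.execute(
    "CREATE TABLE voice_provider ("
    " id INTEGER PRIMARY KEY,"
    " name TEXT UNIQUE NOT NULL,"
    " is_default_stt INTEGER NOT NULL DEFAULT 0)"
)

# Partial unique index: only rows where is_default_stt is true are indexed,
# so uniqueness is enforced among default rows only.
conn.execute(
    "CREATE UNIQUE INDEX ix_voice_provider_one_default_stt"
    " ON voice_provider (is_default_stt)"
    " WHERE is_default_stt = 1"
)


def add_provider(name: str, is_default_stt: bool) -> bool:
    """Insert a provider; return False if the index rejects the row."""
    try:
        with conn:  # commits on success, rolls back on error
            conn.execute(
                "INSERT INTO voice_provider (name, is_default_stt) VALUES (?, ?)",
                (name, int(is_default_stt)),
            )
        return True
    except sqlite3.IntegrityError:
        return False


results = [
    add_provider("openai", True),           # first default: accepted
    add_provider("azure", False),           # non-default: accepted
    add_provider("elevenlabs", False),      # another non-default: accepted
    add_provider("second-default", True),   # second default: rejected
]
```

One consequence of this design is that switching the default provider has to clear the old default before (or in the same statement as) setting the new one, otherwise the update hits the unique index.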

View File

@@ -1,15 +0,0 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.background import celery_app
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
"ee.onyx.background.celery.tasks.doc_permission_syncing",
"ee.onyx.background.celery.tasks.external_group_syncing",
"ee.onyx.background.celery.tasks.cleanup",
"ee.onyx.background.celery.tasks.tenant_provisioning",
"ee.onyx.background.celery.tasks.query_history",
]
)
)

View File

@@ -15,6 +15,7 @@ from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -471,7 +472,9 @@ def _add_user_group__cc_pair_relationships__no_commit(
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
db_user_group = UserGroup(
name=user_group.name, time_last_modified_by_user=func.now()
name=user_group.name,
time_last_modified_by_user=func.now(),
is_up_to_date=DISABLE_VECTOR_DB,
)
db_session.add(db_user_group)
db_session.flush() # give the group an ID
@@ -774,8 +777,7 @@ def update_user_group(
cc_pair_ids=user_group_update.cc_pair_ids,
)
# only needs to sync with Vespa if the cc_pairs have been updated
if cc_pairs_updated:
if cc_pairs_updated and not DISABLE_VECTOR_DB:
db_user_group.is_up_to_date = False
removed_users = db_session.scalars(

View File

@@ -4,7 +4,6 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
from ee.onyx.server.billing.api import router as billing_router
@@ -153,12 +152,9 @@ def get_application() -> FastAPI:
# License management
include_router_with_global_prefix_prepended(application, license_router)
# Unified billing API - available when license system is enabled
# Works for both self-hosted and cloud deployments
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
# primary billing API and /tenants/* billing endpoints can be removed
if LICENSE_ENFORCEMENT_ENABLED:
include_router_with_global_prefix_prepended(application, billing_router)
# Unified billing API - always registered in EE.
# Each endpoint is protected by the `current_admin_user` dependency (admin auth).
include_router_with_global_prefix_prepended(application, billing_router)
if MULTI_TENANT:
# Tenant management

View File

@@ -223,6 +223,15 @@ def get_active_scim_token(
token = dal.get_active_token()
if not token:
raise HTTPException(status_code=404, detail="No active SCIM token")
# Derive the IdP domain from the first synced user as a heuristic.
idp_domain: str | None = None
mappings, _total = dal.list_user_mappings(start_index=1, count=1)
if mappings:
user = dal.get_user(mappings[0].user_id)
if user and "@" in user.email:
idp_domain = user.email.rsplit("@", 1)[1]
return ScimTokenResponse(
id=token.id,
name=token.name,
@@ -230,6 +239,7 @@ def get_active_scim_token(
is_active=token.is_active,
created_at=token.created_at,
last_used_at=token.last_used_at,
idp_domain=idp_domain,
)
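
The `idp_domain` heuristic above takes the part of the first synced user's email after the last `@`. A small sketch of just that string handling, with `derive_idp_domain` as a hypothetical helper name that does not appear in the PR:

```python
from typing import Optional


def derive_idp_domain(email: Optional[str]) -> Optional[str]:
    """Return the text after the last '@', or None if there is no '@'."""
    if email and "@" in email:
        # rsplit with maxsplit=1 splits on the last '@' only, so a local
        # part that itself contains '@' does not confuse the result.
        return email.rsplit("@", 1)[1]
    return None
```

Because only one mapping is inspected, the value is a best guess: a tenant whose synced users span several email domains would report whichever domain the first mapping happens to have.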

View File

@@ -365,6 +365,7 @@ class ScimTokenResponse(BaseModel):
is_active: bool
created_at: datetime
last_used_at: datetime | None = None
idp_domain: str | None = None
class ScimTokenCreatedResponse(ScimTokenResponse):

View File

@@ -5,6 +5,8 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.user_group import add_users_to_user_group
from ee.onyx.db.user_group import delete_user_group as db_delete_user_group
from ee.onyx.db.user_group import fetch_user_group
from ee.onyx.db.user_group import fetch_user_groups
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.db.user_group import insert_user_group
@@ -20,6 +22,7 @@ from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
@@ -153,3 +156,8 @@ def delete_user_group(
prepare_user_group_for_deletion(db_session, user_group_id)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
if DISABLE_VECTOR_DB:
user_group = fetch_user_group(db_session, user_group_id)
if user_group:
db_delete_user_group(db_session, user_group)

View File

@@ -28,6 +28,7 @@ from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi import WebSocket
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager
@@ -1599,6 +1600,91 @@ async def current_admin_user(user: User = Depends(current_user)) -> User:
return user
async def _get_user_from_token_data(token_data: dict) -> User | None:
"""Shared logic: token data dict → User object.
Args:
token_data: Decoded token data containing 'sub' (user ID).
Returns:
User object if found and active, None otherwise.
"""
user_id = token_data.get("sub")
if not user_id:
return None
try:
user_uuid = uuid.UUID(user_id)
except ValueError:
return None
async with get_async_session_context_manager() as async_db_session:
user = await async_db_session.get(User, user_uuid)
if user is None or not user.is_active:
return None
return user
async def current_user_from_websocket(
_websocket: WebSocket,
token: str = Query(..., description="WebSocket authentication token"),
) -> User:
"""
WebSocket authentication dependency using query parameter.
Validates the WS token from query param and returns the User.
Raises BasicAuthenticationError if authentication fails.
The token must be obtained from POST /voice/ws-token before connecting.
Tokens are single-use and expire after 60 seconds.
Usage:
1. POST /voice/ws-token -> {"token": "xxx"}
2. Connect to ws://host/path?token=xxx
This applies the same auth checks as current_user() for HTTP endpoints.
"""
from onyx.redis.redis_pool import retrieve_ws_token_data
# Validate WS token in Redis (single-use, deleted after retrieval)
try:
token_data = await retrieve_ws_token_data(token)
if token_data is None:
raise BasicAuthenticationError(
detail="Access denied. Invalid or expired authentication token."
)
except BasicAuthenticationError:
raise
except Exception as e:
logger.error(f"WS auth: error during token validation: {e}")
raise BasicAuthenticationError(
detail=f"Authentication verification failed: {str(e)}"
) from e
# Get user from token data
user = await _get_user_from_token_data(token_data)
if user is None:
logger.warning(f"WS auth: user not found for id={token_data.get('sub')}")
raise BasicAuthenticationError(
detail="Access denied. User not found or inactive."
)
logger.info(f"WS auth: user found: {user.email}")
# Apply same checks as HTTP auth (verification, OIDC expiry, role)
user = await double_check_user(user)
logger.info(f"WS auth: user verified: {user.email}, role={user.role}")
# Block LIMITED users (same as current_user)
if user.role == UserRole.LIMITED:
logger.warning(f"WS auth: user {user.email} has LIMITED role")
raise BasicAuthenticationError(
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
)
logger.info(f"WS auth: authentication successful for {user.email}")
return user
def get_default_admin_user_emails_() -> list[str]:
# No default seeding available for Onyx MIT
return []
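
The docstring of `current_user_from_websocket` above describes tokens that are single-use and expire after 60 seconds. The implementation of `retrieve_ws_token_data` is not shown in this diff, so the following is only a sketch of those two properties, assuming an in-memory dict in place of Redis; `issue_ws_token` and `consume_ws_token` are hypothetical names:

```python
import secrets
import time
from typing import Dict, Optional, Tuple

WS_TOKEN_TTL_SECONDS = 60  # the expiry stated in the docstring

# token -> (user_id, expiry time); a plain dict stands in for Redis here
_tokens: Dict[str, Tuple[str, float]] = {}


def issue_ws_token(user_id: str, now: Optional[float] = None) -> str:
    """Create a random token bound to user_id with a fixed TTL."""
    now = time.monotonic() if now is None else now
    token = secrets.token_urlsafe(32)
    _tokens[token] = (user_id, now + WS_TOKEN_TTL_SECONDS)
    return token


def consume_ws_token(token: str, now: Optional[float] = None) -> Optional[dict]:
    """Return {'sub': user_id} once per token; None if unknown, used, or expired."""
    now = time.monotonic() if now is None else now
    # pop() removes the entry on first retrieval, which makes the token single-use
    entry = _tokens.pop(token, None)
    if entry is None:
        return None
    user_id, expires_at = entry
    if now >= expires_at:
        return None
    return {"sub": user_id}
```

With a shared store such as Redis, the read and the delete need to happen as one atomic operation (for example with the `GETDEL` command), otherwise two concurrent connections could both redeem the same token.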

View File

@@ -1,142 +0,0 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.background")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
logger.info("worker_init signal received for consolidated background worker.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
# Initialize Vespa httpx pool (needed for light worker tasks)
if MANAGED_VESPA:
httpx_init_vespa_pool(
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
ssl_cert=VESPA_CLOUD_CERT_PATH,
ssl_key=VESPA_CLOUD_KEY_PATH,
)
else:
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None: # noqa: ARG001
SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
app_base.filter_task_modules(
[
# Original background worker tasks
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.user_file_processing",
"onyx.background.celery.tasks.llm_model_update",
# Light worker tasks
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.opensearch_migration",
# Docprocessing worker tasks
"onyx.background.celery.tasks.docprocessing",
# Docfetching worker tasks
"onyx.background.celery.tasks.docfetching",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)
)

View File

@@ -39,9 +39,13 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)
class SlimConnectorExtractionResult(BaseModel):
"""Result of extracting document IDs and hierarchy nodes from a connector."""
"""Result of extracting document IDs and hierarchy nodes from a connector.
doc_ids: set[str]
raw_id_to_parent maps document ID → parent_hierarchy_raw_node_id (or None).
Use raw_id_to_parent.keys() wherever the old set of IDs was needed.
"""
raw_id_to_parent: dict[str, str | None]
hierarchy_nodes: list[HierarchyNode]
@@ -93,30 +97,37 @@ def _get_failure_id(failure: ConnectorFailure) -> str | None:
return None
class BatchResult(BaseModel):
raw_id_to_parent: dict[str, str | None]
hierarchy_nodes: list[HierarchyNode]
def _extract_from_batch(
doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure],
) -> tuple[set[str], list[HierarchyNode]]:
"""Separate a batch into document IDs and hierarchy nodes.
) -> BatchResult:
"""Separate a batch into document IDs (with parent mapping) and hierarchy nodes.
ConnectorFailure items have their failed document/entity IDs added to the
ID set so that failed-to-retrieve documents are not accidentally pruned.
ID dict so that failed-to-retrieve documents are not accidentally pruned.
"""
ids: set[str] = set()
ids: dict[str, str | None] = {}
hierarchy_nodes: list[HierarchyNode] = []
for item in doc_list:
if isinstance(item, HierarchyNode):
hierarchy_nodes.append(item)
ids.add(item.raw_node_id)
if item.raw_node_id not in ids:
ids[item.raw_node_id] = None
elif isinstance(item, ConnectorFailure):
failed_id = _get_failure_id(item)
if failed_id:
ids.add(failed_id)
ids[failed_id] = None
logger.warning(
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
)
else:
ids.add(item.id)
return ids, hierarchy_nodes
parent_raw = getattr(item, "parent_hierarchy_raw_node_id", None)
ids[item.id] = parent_raw
return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes)
def extract_ids_from_runnable_connector(
@@ -132,7 +143,7 @@ def extract_ids_from_runnable_connector(
Optionally, a callback can be passed to handle the length of each document batch.
"""
all_connector_doc_ids: set[str] = set()
all_raw_id_to_parent: dict[str, str | None] = {}
all_hierarchy_nodes: list[HierarchyNode] = []
# Sequence (covariant) lets all the specific list[...] iterator types unify here
@@ -177,15 +188,20 @@ def extract_ids_from_runnable_connector(
"extract_ids_from_runnable_connector: Stop signal detected"
)
batch_ids, batch_nodes = _extract_from_batch(doc_list)
all_connector_doc_ids.update(doc_batch_processing_func(batch_ids))
batch_result = _extract_from_batch(doc_list)
batch_ids = batch_result.raw_id_to_parent
batch_nodes = batch_result.hierarchy_nodes
doc_batch_processing_func(batch_ids)
for k, v in batch_ids.items():
if v is not None or k not in all_raw_id_to_parent:
all_raw_id_to_parent[k] = v
all_hierarchy_nodes.extend(batch_nodes)
if callback:
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
return SlimConnectorExtractionResult(
doc_ids=all_connector_doc_ids,
raw_id_to_parent=all_raw_id_to_parent,
hierarchy_nodes=all_hierarchy_nodes,
)
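Review note: the merge rule in the hunk above (a non-None parent always wins, a None entry only fills a gap) can be sketched in isolation. This is a minimal illustration, not the code in the PR; `merge_parent_map` and the sample IDs are hypothetical names introduced here.

```python
from __future__ import annotations


def merge_parent_map(
    accumulated: dict[str, str | None],
    batch: dict[str, str | None],
) -> None:
    """Merge one batch into the accumulated map in place.

    A known (non-None) parent always overwrites; a None entry is only
    recorded when the document ID has not been seen before.
    """
    for doc_id, parent in batch.items():
        if parent is not None or doc_id not in accumulated:
            accumulated[doc_id] = parent


merged: dict[str, str | None] = {}
merge_parent_map(merged, {"doc-1": "folder-a", "doc-2": None})
# A later batch may report doc-1 without a parent (e.g. as a failure entry);
# that must not erase the parent learned earlier.
merge_parent_map(merged, {"doc-1": None, "doc-2": "folder-b", "doc-3": None})
```

Under this rule a document that first appears as a hierarchy node or a failure (parent None) can still pick up its real parent from a later batch, while the reverse order does not lose it.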


@@ -1,23 +0,0 @@
import onyx.background.celery.configs.base as shared_config
from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
worker_pool = "threads"
# Increased from 1 to 4 to handle fast light worker tasks more efficiently
# This allows the worker to prefetch multiple tasks per thread
worker_prefetch_multiplier = 4


@@ -29,6 +29,7 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -47,6 +48,8 @@ from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
@@ -57,6 +60,8 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
@@ -113,6 +118,38 @@ class PruneCallback(IndexingCallbackBase):
super().progress(tag, amount)
def _resolve_and_update_document_parents(
db_session: Session,
redis_client: Redis,
source: DocumentSource,
raw_id_to_parent: dict[str, str | None],
) -> None:
"""Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id for
each document and bulk-update the DB. Mirrors the resolution logic in
run_docfetching.py."""
source_node_id = get_source_node_id_from_cache(redis_client, db_session, source)
resolved: dict[str, int | None] = {}
for doc_id, raw_parent_id in raw_id_to_parent.items():
if raw_parent_id is None:
continue
node_id, found = get_node_id_from_raw_id(redis_client, source, raw_parent_id)
resolved[doc_id] = node_id if found else source_node_id
if not resolved:
return
update_document_parent_hierarchy_nodes(
db_session=db_session,
doc_parent_map=resolved,
commit=True,
)
task_logger.info(
f"Pruning: resolved and updated parent hierarchy for "
f"{len(resolved)} documents (source={source.value})"
)
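Review note: the resolution step in `_resolve_and_update_document_parents` can be read as "look up the raw parent ID, fall back to the source node when it is unknown, and skip documents with no reported parent". The sketch below assumes a plain dict in place of the Redis cache; `node_cache`, `SOURCE_NODE_ID`, and `resolve_parents` are hypothetical names, and the real helper also handles a missing source node.

```python
from __future__ import annotations

SOURCE_NODE_ID = 1  # hypothetical ID of the source's root hierarchy node

# Hypothetical stand-in for the Redis raw-ID -> node-ID cache.
node_cache: dict[str, int] = {"folder-a": 10, "folder-b": 11}


def resolve_parents(raw_id_to_parent: dict[str, str | None]) -> dict[str, int]:
    """Map each document to a hierarchy node ID, falling back to the source node."""
    resolved: dict[str, int] = {}
    for doc_id, raw_parent in raw_id_to_parent.items():
        if raw_parent is None:
            # No parent reported by the connector: leave the stored value alone.
            continue
        # Unknown raw parent IDs resolve to the source node.
        resolved[doc_id] = node_cache.get(raw_parent, SOURCE_NODE_ID)
    return resolved


connector_docs = {"doc-1": "folder-a", "doc-2": None, "doc-3": "folder-missing"}
resolved = resolve_parents(connector_docs)
```

Note that documents with a None parent are absent from the result, so the bulk update never resets an existing parent to NULL.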
"""Jobs / utils for kicking off pruning tasks."""
@@ -535,22 +572,22 @@ def connector_pruning_generator_task(
extraction_result = extract_ids_from_runnable_connector(
runnable_connector, callback
)
all_connector_doc_ids = extraction_result.doc_ids
all_connector_doc_ids = extraction_result.raw_id_to_parent
# Process hierarchy nodes (same as docfetching):
# upsert to Postgres and cache in Redis
source = cc_pair.connector.source
redis_client = get_redis_client(tenant_id=tenant_id)
if extraction_result.hierarchy_nodes:
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
redis_client = get_redis_client(tenant_id=tenant_id)
ensure_source_node_exists(
redis_client, db_session, cc_pair.connector.source
)
ensure_source_node_exists(redis_client, db_session, source)
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=extraction_result.hierarchy_nodes,
source=cc_pair.connector.source,
source=source,
commit=True,
is_connector_public=is_connector_public,
)
@@ -561,7 +598,7 @@ def connector_pruning_generator_task(
]
cache_hierarchy_nodes_batch(
redis_client=redis_client,
source=cc_pair.connector.source,
source=source,
entries=cache_entries,
)
@@ -570,6 +607,26 @@ def connector_pruning_generator_task(
f"hierarchy nodes for cc_pair={cc_pair_id}"
)
ensure_source_node_exists(redis_client, db_session, source)
# Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id
# and bulk-update documents, mirroring the docfetching resolution
_resolve_and_update_document_parents(
db_session=db_session,
redis_client=redis_client,
source=source,
raw_id_to_parent=all_connector_doc_ids,
)
# Link hierarchy nodes to documents for sources where pages can be
# both hierarchy nodes AND documents (e.g. Notion, Confluence)
all_doc_id_list = list(all_connector_doc_ids.keys())
link_hierarchy_nodes_to_documents(
db_session=db_session,
document_ids=all_doc_id_list,
source=source,
commit=True,
)
# a list of docs in our local index
all_indexed_document_ids = {
doc.id
@@ -581,7 +638,9 @@ def connector_pruning_generator_task(
}
# generate list of docs to remove (no longer in the source)
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
doc_ids_to_remove = list(
all_indexed_document_ids - all_connector_doc_ids.keys()
)
task_logger.info(
"Pruning set collected: "


@@ -1,10 +0,0 @@
from celery import Celery
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app: Celery = fetch_versioned_implementation(
"onyx.background.celery.apps.background",
"celery_app",
)


@@ -495,14 +495,7 @@ CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
)
# Consolidated background worker (light, docprocessing, docfetching, heavy, monitoring, user_file_processing)
# separate workers' defaults: light=24, docprocessing=6, docfetching=1, heavy=4, kg=2, monitoring=1, user_file=2
# Total would be 40, but we use a more conservative default of 20 for the consolidated worker
CELERY_WORKER_BACKGROUND_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_BACKGROUND_CONCURRENCY") or 20
)
# Individual worker concurrency settings (used when USE_LIGHTWEIGHT_BACKGROUND_WORKER is False or on Kuberenetes deployments)
# Individual worker concurrency settings
CELERY_WORKER_HEAVY_CONCURRENCY = int(
os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
)


@@ -84,7 +84,6 @@ POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing"
POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (


@@ -943,6 +943,9 @@ class ConfluenceConnector(
if include_permissions
else None
),
parent_hierarchy_raw_node_id=self._get_parent_hierarchy_raw_id(
page
),
)
)
@@ -992,6 +995,7 @@ class ConfluenceConnector(
if include_permissions
else None
),
parent_hierarchy_raw_node_id=page_id,
)
)


@@ -781,4 +781,5 @@ def build_slim_document(
return SlimDocument(
id=onyx_document_id_from_drive_file(file),
external_access=external_access,
parent_hierarchy_raw_node_id=(file.get("parents") or [None])[0],
)


@@ -902,6 +902,11 @@ class JiraConnector(
external_access=self._get_project_permissions(
project_key, add_prefix=False
),
parent_hierarchy_raw_node_id=(
self._get_parent_hierarchy_raw_node_id(issue, project_key)
if project_key
else None
),
)
)
current_offset += 1


@@ -385,6 +385,7 @@ class IndexingDocument(Document):
class SlimDocument(BaseModel):
id: str
external_access: ExternalAccess | None = None
parent_hierarchy_raw_node_id: str | None = None
class HierarchyNode(BaseModel):


@@ -772,6 +772,7 @@ def _convert_driveitem_to_slim_document(
drive_name: str,
ctx: ClientContext,
graph_client: GraphClient,
parent_hierarchy_raw_node_id: str | None = None,
) -> SlimDocument:
if driveitem.id is None:
raise ValueError("DriveItem ID is required")
@@ -787,11 +788,15 @@ def _convert_driveitem_to_slim_document(
return SlimDocument(
id=driveitem.id,
external_access=external_access,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
)
def _convert_sitepage_to_slim_document(
site_page: dict[str, Any], ctx: ClientContext | None, graph_client: GraphClient
site_page: dict[str, Any],
ctx: ClientContext | None,
graph_client: GraphClient,
parent_hierarchy_raw_node_id: str | None = None,
) -> SlimDocument:
"""Convert a SharePoint site page to a SlimDocument object."""
if site_page.get("id") is None:
@@ -808,6 +813,7 @@ def _convert_sitepage_to_slim_document(
return SlimDocument(
id=id,
external_access=external_access,
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
)
@@ -1594,12 +1600,22 @@ class SharepointConnector(
)
)
parent_hierarchy_url: str | None = None
if drive_web_url:
parent_hierarchy_url = self._get_parent_hierarchy_url(
site_url, drive_web_url, drive_name, driveitem
)
try:
logger.debug(f"Processing: {driveitem.web_url}")
ctx = self._create_rest_client_context(site_descriptor.url)
doc_batch.append(
_convert_driveitem_to_slim_document(
driveitem, drive_name, ctx, self.graph_client
driveitem,
drive_name,
ctx,
self.graph_client,
parent_hierarchy_raw_node_id=parent_hierarchy_url,
)
)
except Exception as e:
@@ -1619,7 +1635,10 @@ class SharepointConnector(
ctx = self._create_rest_client_context(site_descriptor.url)
doc_batch.append(
_convert_sitepage_to_slim_document(
site_page, ctx, self.graph_client
site_page,
ctx,
self.graph_client,
parent_hierarchy_raw_node_id=site_descriptor.url,
)
)
if len(doc_batch) >= SLIM_BATCH_SIZE:


@@ -565,6 +565,7 @@ def _get_all_doc_ids(
channel_id=channel_id, thread_ts=message["ts"]
),
external_access=external_access,
parent_hierarchy_raw_node_id=channel_id,
)
)


@@ -13,6 +13,7 @@ from sqlalchemy.orm import aliased
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.enums import AccessType
@@ -246,6 +247,7 @@ def insert_document_set(
description=document_set_creation_request.description,
user_id=user_id,
is_public=document_set_creation_request.is_public,
is_up_to_date=DISABLE_VECTOR_DB,
time_last_modified_by_user=func.now(),
)
db_session.add(new_document_set_row)
@@ -336,7 +338,8 @@ def update_document_set(
)
document_set_row.description = document_set_update_request.description
document_set_row.is_up_to_date = False
if not DISABLE_VECTOR_DB:
document_set_row.is_up_to_date = False
document_set_row.is_public = document_set_update_request.is_public
document_set_row.time_last_modified_by_user = func.now()
versioned_private_doc_set_fn = fetch_versioned_implementation(


@@ -1,5 +1,7 @@
"""CRUD operations for HierarchyNode."""
from collections import defaultdict
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -525,6 +527,53 @@ def get_document_parent_hierarchy_node_ids(
return {doc_id: parent_id for doc_id, parent_id in results}
def update_document_parent_hierarchy_nodes(
db_session: Session,
doc_parent_map: dict[str, int | None],
commit: bool = True,
) -> int:
"""Bulk-update Document.parent_hierarchy_node_id for multiple documents.
Only updates rows whose current value differs from the desired value to
avoid unnecessary writes.
Args:
db_session: SQLAlchemy session
doc_parent_map: Mapping of document_id → desired parent_hierarchy_node_id
commit: Whether to commit the transaction
Returns:
Number of documents actually updated
"""
if not doc_parent_map:
return 0
doc_ids = list(doc_parent_map.keys())
existing = get_document_parent_hierarchy_node_ids(db_session, doc_ids)
by_parent: dict[int | None, list[str]] = defaultdict(list)
for doc_id, desired_parent_id in doc_parent_map.items():
current = existing.get(doc_id)
if current == desired_parent_id or doc_id not in existing:
continue
by_parent[desired_parent_id].append(doc_id)
updated = 0
for desired_parent_id, ids in by_parent.items():
db_session.query(Document).filter(Document.id.in_(ids)).update(
{Document.parent_hierarchy_node_id: desired_parent_id},
synchronize_session=False,
)
updated += len(ids)
if commit:
db_session.commit()
elif updated:
db_session.flush()
return updated
def update_hierarchy_node_permissions(
db_session: Session,
raw_node_id: str,
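Review note: `update_document_parent_hierarchy_nodes` avoids one UPDATE per document by grouping document IDs by their target parent and skipping rows that are unchanged or unknown. A minimal sketch of that planning step, with plain dicts standing in for the DB lookup (`plan_parent_updates` is a hypothetical name):

```python
from __future__ import annotations

from collections import defaultdict


def plan_parent_updates(
    existing: dict[str, int | None],
    desired: dict[str, int | None],
) -> dict[int | None, list[str]]:
    """Group document IDs by the parent they should be moved to.

    Documents that are not in `existing` (not in the DB) or that already
    have the desired parent are skipped, so no write is issued for them.
    """
    by_parent: dict[int | None, list[str]] = defaultdict(list)
    for doc_id, parent in desired.items():
        if doc_id not in existing or existing[doc_id] == parent:
            continue
        by_parent[parent].append(doc_id)
    return dict(by_parent)


plan = plan_parent_updates(
    existing={"doc-1": 10, "doc-2": None, "doc-3": 11},
    desired={"doc-1": 10, "doc-2": 10, "doc-3": 10, "doc-4": 12},
)
```

Each key in the plan then corresponds to a single bulk UPDATE with an `IN (...)` filter, so the number of statements scales with the number of distinct parents rather than the number of documents.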


@@ -284,6 +284,12 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# organized in typical structured fashion
# formatted as `displayName__provider__modelName`
# Voice preferences
voice_auto_send: Mapped[bool] = mapped_column(Boolean, default=False)
voice_auto_playback: Mapped[bool] = mapped_column(Boolean, default=False)
voice_playback_speed: Mapped[float] = mapped_column(Float, default=1.0)
preferred_voice: Mapped[str | None] = mapped_column(String, nullable=True)
# relationships
credentials: Mapped[list["Credential"]] = relationship(
"Credential", back_populates="user"
@@ -2964,6 +2970,63 @@ class ImageGenerationConfig(Base):
)
class VoiceProvider(Base):
"""Configuration for voice services (STT and TTS)."""
__tablename__ = "voice_provider"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
provider_type: Mapped[str] = mapped_column(
String
) # "openai", "azure", "elevenlabs"
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=True
)
api_base: Mapped[str | None] = mapped_column(String, nullable=True)
custom_config: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# Model/voice configuration
stt_model: Mapped[str | None] = mapped_column(
String, nullable=True
) # e.g., "whisper-1"
tts_model: Mapped[str | None] = mapped_column(
String, nullable=True
) # e.g., "tts-1", "tts-1-hd"
default_voice: Mapped[str | None] = mapped_column(
String, nullable=True
) # e.g., "alloy", "echo"
# STT and TTS can use different providers - only one provider per type
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
# Enforce only one default STT provider and one default TTS provider at DB level
__table_args__ = (
Index(
"ix_voice_provider_one_default_stt",
"is_default_stt",
unique=True,
postgresql_where=(is_default_stt == True), # noqa: E712
),
Index(
"ix_voice_provider_one_default_tts",
"is_default_tts",
unique=True,
postgresql_where=(is_default_tts == True), # noqa: E712
),
)
class CloudEmbeddingProvider(Base):
__tablename__ = "embedding_provider"
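Review note: the two partial unique indexes on `VoiceProvider` allow many rows with the flag off but at most one row with it on. The sketch below reproduces that constraint with the standard-library `sqlite3` module as a stand-in for PostgreSQL (an assumption made only so the example runs without a database server); the table is reduced to the columns needed here.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE voice_provider ("
    "id INTEGER PRIMARY KEY, "
    "name TEXT UNIQUE, "
    "is_default_stt INTEGER NOT NULL DEFAULT 0)"
)
# Partial unique index: only rows with is_default_stt = 1 are indexed,
# so at most one such row can exist.
conn.execute(
    "CREATE UNIQUE INDEX ix_voice_provider_one_default_stt "
    "ON voice_provider (is_default_stt) WHERE is_default_stt = 1"
)
conn.execute("INSERT INTO voice_provider (name, is_default_stt) VALUES ('openai', 1)")
conn.execute("INSERT INTO voice_provider (name, is_default_stt) VALUES ('azure', 0)")
conn.execute("INSERT INTO voice_provider (name, is_default_stt) VALUES ('elevenlabs', 0)")

# Marking a second provider as default without clearing the first is rejected.
try:
    conn.execute("UPDATE voice_provider SET is_default_stt = 1 WHERE name = 'azure'")
    second_default_rejected = False
except sqlite3.IntegrityError:
    second_default_rejected = True

# Clearing the old default first, then setting the new one, succeeds.
conn.execute(
    "UPDATE voice_provider SET is_default_stt = 0 "
    "WHERE is_default_stt = 1 AND name != 'azure'"
)
conn.execute("UPDATE voice_provider SET is_default_stt = 1 WHERE name = 'azure'")
defaults = [
    row[0]
    for row in conn.execute("SELECT name FROM voice_provider WHERE is_default_stt = 1")
]
```

The clear-then-set order is the one `set_default_stt_provider` and `set_default_tts_provider` follow in `backend/onyx/db/voice.py`.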


@@ -129,7 +129,7 @@ def get_current_search_settings(db_session: Session) -> SearchSettings:
latest_settings = result.scalars().first()
if not latest_settings:
raise RuntimeError("No search settings specified, DB is not in a valid state")
raise RuntimeError("No search settings specified; DB is not in a valid state.")
return latest_settings

backend/onyx/db/voice.py (new file, 224 lines)

@@ -0,0 +1,224 @@
from typing import Any
from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.db.models import User
from onyx.db.models import VoiceProvider
# Sentinel value to distinguish "not provided" from "explicitly set to None"
_UNSET: Any = object()
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
"""Fetch all voice providers."""
return list(
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
)
def fetch_voice_provider_by_id(
db_session: Session, provider_id: int
) -> VoiceProvider | None:
"""Fetch a voice provider by ID."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.id == provider_id)
)
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default STT provider."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
)
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
"""Fetch the default TTS provider."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
)
def fetch_voice_provider_by_type(
db_session: Session, provider_type: str
) -> VoiceProvider | None:
"""Fetch a voice provider by type."""
return db_session.scalar(
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
)
def upsert_voice_provider(
*,
db_session: Session,
provider_id: int | None,
name: str,
provider_type: str,
api_key: str | None,
api_key_changed: bool,
api_base: str | None = None,
custom_config: dict[str, Any] | None = None,
stt_model: str | None = None,
tts_model: str | None = None,
default_voice: str | None = None,
activate_stt: bool = False,
activate_tts: bool = False,
) -> VoiceProvider:
"""Create or update a voice provider."""
provider: VoiceProvider | None = None
if provider_id is not None:
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider is None:
raise ValueError(f"No voice provider with id {provider_id} exists.")
else:
provider = VoiceProvider()
db_session.add(provider)
# Apply updates
provider.name = name
provider.provider_type = provider_type
provider.api_base = api_base
provider.custom_config = custom_config
provider.stt_model = stt_model
provider.tts_model = tts_model
provider.default_voice = default_voice
# Only update API key if explicitly changed or if provider has no key
if api_key_changed or provider.api_key is None:
provider.api_key = api_key # type: ignore[assignment]
db_session.flush()
if activate_stt:
set_default_stt_provider(db_session=db_session, provider_id=provider.id)
if activate_tts:
set_default_tts_provider(db_session=db_session, provider_id=provider.id)
db_session.refresh(provider)
return provider
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
"""Delete a voice provider by ID."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider:
db_session.delete(provider)
db_session.flush()
def set_default_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
"""Set a voice provider as the default STT provider."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider is None:
raise ValueError(f"No voice provider with id {provider_id} exists.")
# Deactivate all other STT providers
db_session.execute(
update(VoiceProvider)
.where(
VoiceProvider.is_default_stt.is_(True),
VoiceProvider.id != provider_id,
)
.values(is_default_stt=False)
)
# Activate this provider
provider.is_default_stt = True
db_session.flush()
db_session.refresh(provider)
return provider
def set_default_tts_provider(
*, db_session: Session, provider_id: int, tts_model: str | None = None
) -> VoiceProvider:
"""Set a voice provider as the default TTS provider."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider is None:
raise ValueError(f"No voice provider with id {provider_id} exists.")
# Deactivate all other TTS providers
db_session.execute(
update(VoiceProvider)
.where(
VoiceProvider.is_default_tts.is_(True),
VoiceProvider.id != provider_id,
)
.values(is_default_tts=False)
)
# Activate this provider
provider.is_default_tts = True
# Update the TTS model if specified
if tts_model is not None:
provider.tts_model = tts_model
db_session.flush()
db_session.refresh(provider)
return provider
def deactivate_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
"""Remove the default STT status from a voice provider."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider is None:
raise ValueError(f"No voice provider with id {provider_id} exists.")
provider.is_default_stt = False
db_session.flush()
db_session.refresh(provider)
return provider
def deactivate_tts_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
"""Remove the default TTS status from a voice provider."""
provider = fetch_voice_provider_by_id(db_session, provider_id)
if provider is None:
raise ValueError(f"No voice provider with id {provider_id} exists.")
provider.is_default_tts = False
db_session.flush()
db_session.refresh(provider)
return provider
# User voice preferences
def update_user_voice_settings(
db_session: Session,
user_id: UUID,
auto_send: bool | None = None,
auto_playback: bool | None = None,
playback_speed: float | None = None,
preferred_voice: str | None = _UNSET,
) -> None:
"""Update user's voice settings.
For most fields, None means "don't update this field".
For preferred_voice, use None to clear the preference (reset to default),
or omit the parameter to leave it unchanged.
"""
values: dict[str, Any] = {}
if auto_send is not None:
values["voice_auto_send"] = auto_send
if auto_playback is not None:
values["voice_auto_playback"] = auto_playback
if playback_speed is not None:
values["voice_playback_speed"] = max(0.5, min(2.0, playback_speed))
# preferred_voice uses sentinel: _UNSET means "don't update", None means "clear"
if preferred_voice is not _UNSET:
values["preferred_voice"] = preferred_voice
if values:
db_session.execute(update(User).where(User.id == user_id).values(**values)) # type: ignore[arg-type]
db_session.flush()
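Review note: `update_user_voice_settings` uses a module-level sentinel so that `preferred_voice=None` can mean "clear the preference" while omitting the argument means "leave it unchanged". A minimal sketch of the same pattern without the database call (`build_voice_updates` is a hypothetical helper that only builds the values dict):

```python
from __future__ import annotations

from typing import Any

# Sentinel distinguishing "argument omitted" from "explicitly None".
_UNSET: Any = object()


def build_voice_updates(
    playback_speed: float | None = None,
    preferred_voice: str | None = _UNSET,
) -> dict[str, Any]:
    """Return only the columns that should be written."""
    values: dict[str, Any] = {}
    if playback_speed is not None:
        # Clamp to the supported range, as in the diff above.
        values["voice_playback_speed"] = max(0.5, min(2.0, playback_speed))
    if preferred_voice is not _UNSET:
        # None is a real value here: it clears the stored preference.
        values["preferred_voice"] = preferred_voice
    return values


speed_only = build_voice_updates(playback_speed=3.0)
cleared = build_voice_updates(preferred_voice=None)
nothing = build_voice_updates()
```

The identity check (`is not _UNSET`) matters: an equality check against None could not tell the two cases apart.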


@@ -32,9 +32,6 @@ def get_multipass_config(search_settings: SearchSettings) -> MultipassConfig:
Determines whether to enable multipass and large chunks by examining
the current search settings and the embedder configuration.
"""
if not search_settings:
return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)
multipass = should_use_multipass(search_settings)
enable_large_chunks = SearchSettings.can_use_large_chunks(
multipass, search_settings.model_name, search_settings.provider_type


@@ -26,11 +26,10 @@ def get_default_document_index(
To be used for retrieval only. Indexing should be done through both indices
until Vespa is deprecated.
Pre-existing docstring for this function, although secondary indices are not
currently supported:
Primary index is the index that is used for querying/updating etc. Secondary
index is for when both the currently used index and the upcoming index both
need to be updated, updates are applied to both indices.
need to be updated. Updates are applied to both indices.
WARNING: In that case, get_all_document_indices should be used.
"""
if DISABLE_VECTOR_DB:
return DisabledDocumentIndex(
@@ -51,11 +50,26 @@ def get_default_document_index(
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
if opensearch_retrieval_enabled:
indexing_setting = IndexingSetting.from_db_model(search_settings)
secondary_indexing_setting = (
IndexingSetting.from_db_model(secondary_search_settings)
if secondary_search_settings
else None
)
return OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=secondary_index_name,
secondary_embedding_dim=(
secondary_indexing_setting.final_embedding_dim
if secondary_indexing_setting
else None
),
secondary_embedding_precision=(
secondary_indexing_setting.embedding_precision
if secondary_indexing_setting
else None
),
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
multitenant=MULTI_TENANT,
@@ -86,8 +100,7 @@ def get_all_document_indices(
Used for indexing only. Until Vespa is deprecated we will index into both
document indices. Retrieval is done through only one index however.
Large chunks and secondary indices are not currently supported so we
hardcode appropriate values.
Large chunks are not currently supported so we hardcode appropriate values.
NOTE: Make sure the Vespa index object is returned first. In the rare event
that there is some conflict between indexing and the migration task, it is
@@ -123,13 +136,36 @@ def get_all_document_indices(
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
indexing_setting = IndexingSetting.from_db_model(search_settings)
secondary_indexing_setting = (
IndexingSetting.from_db_model(secondary_search_settings)
if secondary_search_settings
else None
)
opensearch_document_index = OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=None,
secondary_index_name=(
secondary_search_settings.index_name
if secondary_search_settings
else None
),
secondary_embedding_dim=(
secondary_indexing_setting.final_embedding_dim
if secondary_indexing_setting
else None
),
secondary_embedding_precision=(
secondary_indexing_setting.embedding_precision
if secondary_indexing_setting
else None
),
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=(
secondary_search_settings.large_chunks_enabled
if secondary_search_settings
else None
),
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
)


@@ -271,6 +271,9 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
secondary_index_name: str | None,
secondary_embedding_dim: int | None,
secondary_embedding_precision: EmbeddingPrecision | None,
# NOTE: We do not support large chunks right now.
large_chunks_enabled: bool, # noqa: ARG002
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
multitenant: bool = False,
@@ -286,12 +289,25 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
f"Expected {MULTI_TENANT}, got {multitenant}."
)
tenant_id = get_current_tenant_id()
tenant_state = TenantState(tenant_id=tenant_id, multitenant=multitenant)
self._real_index = OpenSearchDocumentIndex(
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
tenant_state=tenant_state,
index_name=index_name,
embedding_dim=embedding_dim,
embedding_precision=embedding_precision,
)
self._secondary_real_index: OpenSearchDocumentIndex | None = None
if self.secondary_index_name:
if secondary_embedding_dim is None or secondary_embedding_precision is None:
raise ValueError(
"Bug: Secondary index embedding dimension and precision are not set."
)
self._secondary_real_index = OpenSearchDocumentIndex(
tenant_state=tenant_state,
index_name=self.secondary_index_name,
embedding_dim=secondary_embedding_dim,
embedding_precision=secondary_embedding_precision,
)
@staticmethod
def register_multitenant_indices(
@@ -307,19 +323,38 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
self,
primary_embedding_dim: int,
primary_embedding_precision: EmbeddingPrecision,
secondary_index_embedding_dim: int | None, # noqa: ARG002
secondary_index_embedding_precision: EmbeddingPrecision | None, # noqa: ARG002
secondary_index_embedding_dim: int | None,
secondary_index_embedding_precision: EmbeddingPrecision | None,
) -> None:
# Only handle primary index for now, ignore secondary.
return self._real_index.verify_and_create_index_if_necessary(
self._real_index.verify_and_create_index_if_necessary(
primary_embedding_dim, primary_embedding_precision
)
if self.secondary_index_name:
if (
secondary_index_embedding_dim is None
or secondary_index_embedding_precision is None
):
raise ValueError(
"Bug: Secondary index embedding dimension and precision are not set."
)
assert (
self._secondary_real_index is not None
), "Bug: Secondary index is not initialized."
self._secondary_real_index.verify_and_create_index_if_necessary(
secondary_index_embedding_dim, secondary_index_embedding_precision
)
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
NOTE: Do NOT consider the secondary index here. A separate indexing
pipeline will be responsible for indexing to the secondary index. This
design is not ideal and we should reconsider this when revamping index
swapping.
"""
# Convert IndexBatchParams to IndexingMetadata.
chunk_counts: dict[str, IndexingMetadata.ChunkCounts] = {}
for doc_id in index_batch_params.doc_id_to_new_chunk_cnt:
@@ -351,7 +386,20 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
tenant_id: str, # noqa: ARG002
chunk_count: int | None,
) -> int:
return self._real_index.delete(doc_id, chunk_count)
"""
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for deleting chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
total_chunks_deleted = self._real_index.delete(doc_id, chunk_count)
if self.secondary_index_name:
assert (
self._secondary_real_index is not None
), "Bug: Secondary index is not initialized."
total_chunks_deleted += self._secondary_real_index.delete(
doc_id, chunk_count
)
return total_chunks_deleted
def update_single(
self,
@@ -362,6 +410,11 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
fields: VespaDocumentFields | None,
user_fields: VespaDocumentUserFields | None,
) -> None:
"""
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for updating chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
if fields is None and user_fields is None:
logger.warning(
f"Tried to update document {doc_id} with no updated fields or user fields."
@@ -392,6 +445,11 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
try:
self._real_index.update([update_request])
if self.secondary_index_name:
assert (
self._secondary_real_index is not None
), "Bug: Secondary index is not initialized."
self._secondary_real_index.update([update_request])
except NotFoundError:
logger.exception(
f"Tried to update document {doc_id} but at least one of its chunks was not found in OpenSearch. "

@@ -465,6 +465,12 @@ class VespaIndex(DocumentIndex):
chunks: list[DocMetadataAwareIndexChunk],
index_batch_params: IndexBatchParams,
) -> set[OldDocumentInsertionRecord]:
"""
NOTE: Do NOT consider the secondary index here. A separate indexing
pipeline will be responsible for indexing to the secondary index. This
design is not ideal and we should reconsider this when revamping index
swapping.
"""
if len(index_batch_params.doc_id_to_previous_chunk_cnt) != len(
index_batch_params.doc_id_to_new_chunk_cnt
):
@@ -659,6 +665,10 @@ class VespaIndex(DocumentIndex):
"""Note: if the document id does not exist, the update will be a no-op and the
function will complete with no errors or exceptions.
Handle other exceptions if you wish to implement retry behavior
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for updating chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
if fields is None and user_fields is None:
logger.warning(
@@ -679,13 +689,6 @@ class VespaIndex(DocumentIndex):
f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
)
vespa_document_index = VespaDocumentIndex(
index_name=self.index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.large_chunks_enabled,
httpx_client=self.httpx_client,
)
project_ids: set[int] | None = None
if user_fields is not None and user_fields.user_projects is not None:
project_ids = set(user_fields.user_projects)
@@ -705,7 +708,20 @@ class VespaIndex(DocumentIndex):
persona_ids=persona_ids,
)
vespa_document_index.update([update_request])
indices = [self.index_name]
if self.secondary_index_name:
indices.append(self.secondary_index_name)
for index_name in indices:
vespa_document_index = VespaDocumentIndex(
index_name=index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.index_to_large_chunks_enabled.get(
index_name, False
),
httpx_client=self.httpx_client,
)
vespa_document_index.update([update_request])
def delete_single(
self,
@@ -714,6 +730,11 @@ class VespaIndex(DocumentIndex):
tenant_id: str,
chunk_count: int | None,
) -> int:
"""
NOTE: Remember to handle the secondary index here. There is no separate
pipeline for deleting chunks in the secondary index. This design is not
ideal and we should reconsider this when revamping index swapping.
"""
tenant_state = TenantState(
tenant_id=get_current_tenant_id(),
multitenant=MULTI_TENANT,
@@ -726,13 +747,25 @@ class VespaIndex(DocumentIndex):
raise ValueError(
f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
)
vespa_document_index = VespaDocumentIndex(
index_name=self.index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.large_chunks_enabled,
httpx_client=self.httpx_client,
)
return vespa_document_index.delete(document_id=doc_id, chunk_count=chunk_count)
indices = [self.index_name]
if self.secondary_index_name:
indices.append(self.secondary_index_name)
total_chunks_deleted = 0
for index_name in indices:
vespa_document_index = VespaDocumentIndex(
index_name=index_name,
tenant_state=tenant_state,
large_chunks_enabled=self.index_to_large_chunks_enabled.get(
index_name, False
),
httpx_client=self.httpx_client,
)
total_chunks_deleted += vespa_document_index.delete(
document_id=doc_id, chunk_count=chunk_count
)
return total_chunks_deleted
def id_based_retrieval(
self,
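
The `delete_single` change above fans a delete out to the primary index and, when one is configured, the secondary index, then sums the deleted chunk counts. A minimal sketch of that pattern, assuming a hypothetical in-memory `InMemoryIndex` in place of `VespaDocumentIndex` (the class and `delete_from_all` are illustrative names, not part of the codebase):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class InMemoryIndex:
    """Hypothetical stand-in for a per-index client; not the real Vespa client."""

    name: str
    # Maps a document ID to the number of chunks stored for it.
    chunks: dict = field(default_factory=dict)

    def delete(self, document_id: str) -> int:
        # Return how many chunks were removed; 0 if the document is unknown.
        return self.chunks.pop(document_id, 0)


def delete_from_all(
    doc_id: str, primary: InMemoryIndex, secondary: Optional[InMemoryIndex]
) -> int:
    """Delete doc_id from every configured index and return the total chunk count."""
    indices = [primary]
    if secondary is not None:
        indices.append(secondary)

    total_chunks_deleted = 0
    for index in indices:
        total_chunks_deleted += index.delete(doc_id)
    return total_chunks_deleted
```

Note that the returned total counts chunks across both indices, so callers that compare it against a single index's expected chunk count may need to account for the secondary index.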

@@ -119,6 +119,9 @@ from onyx.server.manage.opensearch_migration.api import (
from onyx.server.manage.search_settings import router as search_settings_router
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router
from onyx.server.manage.voice.api import admin_router as voice_admin_router
from onyx.server.manage.voice.user_api import router as voice_router
from onyx.server.manage.voice.websocket_api import router as voice_websocket_router
from onyx.server.manage.web_search.api import (
admin_router as web_search_admin_router,
)
@@ -497,6 +500,9 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, embedding_router)
include_router_with_global_prefix_prepended(application, web_search_router)
include_router_with_global_prefix_prepended(application, web_search_admin_router)
include_router_with_global_prefix_prepended(application, voice_admin_router)
include_router_with_global_prefix_prepended(application, voice_router)
include_router_with_global_prefix_prepended(application, voice_websocket_router)
include_router_with_global_prefix_prepended(
application, opensearch_migration_admin_router
)

@@ -130,7 +130,7 @@ def format_slack_message(message: str | None) -> str:
message = _transform_outside_code_blocks(message, _sanitize_html)
message = _convert_slack_links_to_markdown(message)
normalized_message = _normalize_link_destinations(message)
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"])
result = md(normalized_message)
# With HTMLRenderer, result is always str (not AST list)
assert isinstance(result, str)
@@ -146,6 +146,11 @@ class SlackRenderer(HTMLRenderer):
SPECIALS: dict[str, str] = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}
def __init__(self) -> None:
super().__init__()
self._table_headers: list[str] = []
self._current_row_cells: list[str] = []
def escape_special(self, text: str) -> str:
for special, replacement in self.SPECIALS.items():
text = text.replace(special, replacement)
@@ -218,5 +223,48 @@ class SlackRenderer(HTMLRenderer):
# as literal &quot; text since Slack doesn't recognize that entity.
return self.escape_special(text)
# -- Table rendering (converts markdown tables to vertical cards) --
def table_cell(
self, text: str, align: str | None = None, head: bool = False # noqa: ARG002
) -> str:
if head:
self._table_headers.append(text.strip())
else:
self._current_row_cells.append(text.strip())
return ""
def table_head(self, text: str) -> str: # noqa: ARG002
self._current_row_cells = []
return ""
def table_row(self, text: str) -> str: # noqa: ARG002
cells = self._current_row_cells
self._current_row_cells = []
# First column becomes the bold title, remaining columns are bulleted fields
lines: list[str] = []
if cells:
title = cells[0]
if title:
# Avoid double-wrapping if cell already contains bold markup
if title.startswith("*") and title.endswith("*") and len(title) > 1:
lines.append(title)
else:
lines.append(f"*{title}*")
for i, cell in enumerate(cells[1:], start=1):
if i < len(self._table_headers):
lines.append(f"{self._table_headers[i]}: {cell}")
else:
lines.append(f"{cell}")
return "\n".join(lines) + "\n\n"
def table_body(self, text: str) -> str:
return text
def table(self, text: str) -> str:
self._table_headers = []
self._current_row_cells = []
return text + "\n"
def paragraph(self, text: str) -> str:
return f"{text}\n\n"
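
The table hooks above turn each markdown table row into a vertical "card": the first cell becomes a bold title and each remaining cell is paired with its column header. A minimal standalone sketch of the row logic, assuming the headers and cells have already been collected (`row_to_card` is a hypothetical helper name, not a method of `SlackRenderer`):

```python
def row_to_card(headers: list, cells: list) -> str:
    """Render one table row as a bold title followed by 'Header: value' lines."""
    lines = []
    if cells:
        title = cells[0]
        if title:
            # Avoid double-wrapping a title that is already bold.
            already_bold = (
                title.startswith("*") and title.endswith("*") and len(title) > 1
            )
            lines.append(title if already_bold else f"*{title}*")
        for i, cell in enumerate(cells[1:], start=1):
            if i < len(headers):
                lines.append(f"{headers[i]}: {cell}")
            else:
                # More cells than headers: emit the bare value.
                lines.append(cell)
    return "\n".join(lines) + "\n\n"


card = row_to_card(["Name", "Role", "Team"], ["Ada", "Engineer", "Search"])
```

Here `card` is `"*Ada*\nRole: Engineer\nTeam: Search\n\n"`. The first header is never printed because the first column serves as the title.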

@@ -419,12 +419,15 @@ async def get_async_redis_connection() -> aioredis.Redis:
return _async_redis_connection
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
if not token:
logger.debug("No auth token cookie found")
return None
async def retrieve_auth_token_data(token: str) -> dict | None:
"""Validate auth token against Redis and return token data.
Args:
token: The raw authentication token string.
Returns:
Token data dict if valid, None if invalid/expired.
"""
try:
redis = await get_async_redis_connection()
redis_key = REDIS_AUTH_KEY_PREFIX + token
@@ -439,12 +442,67 @@ async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
logger.error("Error decoding token data from Redis")
return None
except Exception as e:
logger.error(
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
)
raise ValueError(
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
)
logger.error(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
raise ValueError(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
"""Validate auth token from request cookie. Wrapper for backwards compatibility."""
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
if not token:
logger.debug("No auth token cookie found")
return None
return await retrieve_auth_token_data(token)
# WebSocket token prefix (separate from regular auth tokens)
REDIS_WS_TOKEN_PREFIX = "ws_token:"
# WebSocket tokens expire after 60 seconds
WS_TOKEN_TTL_SECONDS = 60
async def store_ws_token(token: str, user_id: str) -> None:
"""Store a short-lived WebSocket authentication token in Redis.
Args:
token: The generated WS token.
user_id: The user ID to associate with this token.
"""
redis = await get_async_redis_connection()
redis_key = REDIS_WS_TOKEN_PREFIX + token
token_data = json.dumps({"sub": user_id})
await redis.set(redis_key, token_data, ex=WS_TOKEN_TTL_SECONDS)
async def retrieve_ws_token_data(token: str) -> dict | None:
"""Validate a WebSocket token and return the token data.
This uses GETDEL for atomic get-and-delete to prevent race conditions
where the same token could be used twice.
Args:
token: The WS token to validate.
Returns:
Token data dict with 'sub' (user ID) if valid, None if invalid/expired.
"""
try:
redis = await get_async_redis_connection()
redis_key = REDIS_WS_TOKEN_PREFIX + token
# Atomic get-and-delete to prevent race conditions (Redis 6.2+)
token_data_str = await redis.getdel(redis_key)
if not token_data_str:
return None
return json.loads(token_data_str)
except json.JSONDecodeError:
logger.error("Error decoding WS token data from Redis")
return None
except Exception as e:
logger.error(f"Unexpected error in retrieve_ws_token_data: {str(e)}")
return None
def redis_lock_dump(lock: RedisLock, r: Redis) -> None:
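
The WebSocket token helpers above rely on two properties: a short TTL and an atomic get-and-delete so a token can be redeemed only once. A minimal sketch of those semantics, assuming an in-memory dictionary in place of Redis (`OneTimeTokenStore` is a hypothetical class for illustration; it is single-process only and does not replace `GETDEL`):

```python
import json
import secrets
import time
from typing import Optional


class OneTimeTokenStore:
    """Hypothetical in-memory stand-in for the Redis-backed WS token store."""

    def __init__(self, ttl_seconds: float = 60.0) -> None:
        self._ttl = ttl_seconds
        # Maps token -> (expiry on the monotonic clock, JSON payload).
        self._data = {}

    def issue(self, user_id: str) -> str:
        token = secrets.token_urlsafe(32)
        payload = json.dumps({"sub": user_id})
        self._data[token] = (time.monotonic() + self._ttl, payload)
        return token

    def consume(self, token: str) -> Optional[dict]:
        # pop() reads and removes in one step, so a second call finds nothing.
        entry = self._data.pop(token, None)
        if entry is None:
            return None
        expires_at, payload = entry
        if time.monotonic() >= expires_at:
            return None
        return json.loads(payload)
```

A first `consume` of a fresh token returns the payload with the `sub` field; any later `consume` of the same token returns `None`, as does consuming a token after its TTL has passed.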

@@ -9,6 +9,7 @@ from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
from onyx.auth.users import current_user_from_websocket
from onyx.auth.users import current_user_with_expired_token
from onyx.configs.app_configs import APP_API_PREFIX
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -129,6 +130,7 @@ def check_router_auth(
or depends_fn == current_curator_or_admin_user
or depends_fn == current_user_with_expired_token
or depends_fn == current_chat_accessible_user
or depends_fn == current_user_from_websocket
or depends_fn == control_plane_dep
or depends_fn == current_cloud_superuser
or depends_fn == verify_scim_token

@@ -7424,9 +7424,9 @@
}
},
"node_modules/hono": {
"version": "4.11.7",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.11.7.tgz",
"integrity": "sha512-l7qMiNee7t82bH3SeyUCt9UF15EVmaBvsppY2zQtrbIhl/yzBTny+YUxsVjSjQ6gaqaeVtZmGocom8TzBlA4Yw==",
"version": "4.12.5",
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz",
"integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==",
"license": "MIT",
"engines": {
"node": ">=16.9.0"

@@ -11,6 +11,7 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.document_set import check_document_sets_are_public
from onyx.db.document_set import delete_document_set as db_delete_document_set
from onyx.db.document_set import fetch_all_document_sets_for_user
from onyx.db.document_set import get_document_set_by_id
from onyx.db.document_set import insert_document_set
@@ -142,7 +143,10 @@ def delete_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
if not DISABLE_VECTOR_DB:
if DISABLE_VECTOR_DB:
db_session.refresh(document_set)
db_delete_document_set(document_set, db_session)
else:
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},

@@ -1,5 +1,6 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
@@ -14,8 +15,6 @@ from onyx.db.llm import remove_llm_provider__no_commit
from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import ModelConfiguration
from onyx.db.models import User
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.image_gen.exceptions import ImageProviderCredentialsError
from onyx.image_gen.factory import get_image_generation_provider
from onyx.image_gen.factory import validate_credentials
@@ -75,9 +74,9 @@ def _build_llm_provider_request(
# Clone mode: Only use API key from source provider
source_provider = db_session.get(LLMProviderModel, source_llm_provider_id)
if not source_provider:
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"Source LLM provider with id {source_llm_provider_id} not found",
raise HTTPException(
status_code=404,
detail=f"Source LLM provider with id {source_llm_provider_id} not found",
)
_validate_llm_provider_change(
@@ -111,9 +110,9 @@ def _build_llm_provider_request(
)
if not provider:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No provider or source llm provided",
raise HTTPException(
status_code=400,
detail="No provider or source llm provided",
)
credentials = ImageGenerationProviderCredentials(
@@ -125,9 +124,9 @@ def _build_llm_provider_request(
)
if not validate_credentials(provider, credentials):
raise OnyxError(
OnyxErrorCode.CREDENTIAL_INVALID,
f"Incorrect credentials for {provider}",
raise HTTPException(
status_code=400,
detail=f"Incorrect credentials for {provider}",
)
return LLMProviderUpsertRequest(
@@ -216,9 +215,9 @@ def test_image_generation(
LLMProviderModel, test_request.source_llm_provider_id
)
if not source_provider:
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"Source LLM provider with id {test_request.source_llm_provider_id} not found",
raise HTTPException(
status_code=404,
detail=f"Source LLM provider with id {test_request.source_llm_provider_id} not found",
)
_validate_llm_provider_change(
@@ -237,9 +236,9 @@ def test_image_generation(
provider = source_provider.provider
if provider is None:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"No provider or source llm provided",
raise HTTPException(
status_code=400,
detail="No provider or source llm provided",
)
try:
@@ -258,14 +257,14 @@ def test_image_generation(
),
)
except ValueError:
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"Invalid image generation provider: {provider}",
raise HTTPException(
status_code=404,
detail=f"Invalid image generation provider: {provider}",
)
except ImageProviderCredentialsError:
raise OnyxError(
OnyxErrorCode.CREDENTIAL_INVALID,
"Invalid image generation credentials",
raise HTTPException(
status_code=401,
detail="Invalid image generation credentials",
)
quality = _get_test_quality_for_model(test_request.model_name)
@@ -277,15 +276,15 @@ def test_image_generation(
n=1,
quality=quality,
)
except OnyxError:
except HTTPException:
raise
except Exception as e:
# Log only exception type to avoid exposing sensitive data
# (LiteLLM errors may contain URLs with API keys or auth tokens)
logger.warning(f"Image generation test failed: {type(e).__name__}")
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
f"Image generation test failed: {type(e).__name__}",
raise HTTPException(
status_code=400,
detail=f"Image generation test failed: {type(e).__name__}",
)
@@ -310,9 +309,9 @@ def create_config(
db_session, config_create.image_provider_id
)
if existing_config:
raise OnyxError(
OnyxErrorCode.DUPLICATE_RESOURCE,
f"ImageGenerationConfig with image_provider_id '{config_create.image_provider_id}' already exists",
raise HTTPException(
status_code=400,
detail=f"ImageGenerationConfig with image_provider_id '{config_create.image_provider_id}' already exists",
)
try:
@@ -346,10 +345,10 @@ def create_config(
db_session.commit()
db_session.refresh(config)
return ImageGenerationConfigView.from_model(config)
except OnyxError:
except HTTPException:
raise
except Exception as e:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
raise HTTPException(status_code=400, detail=str(e))
@admin_router.get("/config")
@@ -374,9 +373,9 @@ def get_config_credentials(
"""
config = get_image_generation_config(db_session, image_provider_id)
if not config:
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
raise HTTPException(
status_code=404,
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
)
return ImageGenerationCredentials.from_model(config)
@@ -402,9 +401,9 @@ def update_config(
# 1. Get existing config
existing_config = get_image_generation_config(db_session, image_provider_id)
if not existing_config:
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
raise HTTPException(
status_code=404,
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
)
old_llm_provider_id = existing_config.model_configuration.llm_provider_id
@@ -473,10 +472,10 @@ def update_config(
db_session.refresh(existing_config)
return ImageGenerationConfigView.from_model(existing_config)
except OnyxError:
except HTTPException:
raise
except Exception as e:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
raise HTTPException(status_code=400, detail=str(e))
@admin_router.delete("/config/{image_provider_id}")
@@ -490,9 +489,9 @@ def delete_config(
# Get the config first to find the associated LLM provider
existing_config = get_image_generation_config(db_session, image_provider_id)
if not existing_config:
raise OnyxError(
OnyxErrorCode.NOT_FOUND,
f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
raise HTTPException(
status_code=404,
detail=f"ImageGenerationConfig with image_provider_id {image_provider_id} not found",
)
llm_provider_id = existing_config.model_configuration.llm_provider_id
@@ -504,10 +503,10 @@ def delete_config(
remove_llm_provider__no_commit(db_session, llm_provider_id)
db_session.commit()
except OnyxError:
except HTTPException:
raise
except ValueError as e:
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
raise HTTPException(status_code=404, detail=str(e))
@admin_router.post("/config/{image_provider_id}/default")
@@ -520,7 +519,7 @@ def set_config_as_default(
try:
set_default_image_generation_config(db_session, image_provider_id)
except ValueError as e:
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
raise HTTPException(status_code=404, detail=str(e))
@admin_router.delete("/config/{image_provider_id}/default")
@@ -533,4 +532,4 @@ def unset_config_as_default(
try:
unset_default_image_generation_config(db_session, image_provider_id)
except ValueError as e:
raise OnyxError(OnyxErrorCode.NOT_FOUND, str(e))
raise HTTPException(status_code=404, detail=str(e))

@@ -85,6 +85,12 @@ class UserPreferences(BaseModel):
chat_background: str | None = None
default_app_mode: DefaultAppMode = DefaultAppMode.CHAT
# Voice preferences
voice_auto_send: bool | None = None
voice_auto_playback: bool | None = None
voice_playback_speed: float | None = None
preferred_voice: str | None = None
# controls which tools are enabled for the user for a specific assistant
assistant_specific_configs: UserSpecificAssistantPreferences | None = None
@@ -164,6 +170,10 @@ class UserInfo(BaseModel):
theme_preference=user.theme_preference,
chat_background=user.chat_background,
default_app_mode=user.default_app_mode,
voice_auto_send=user.voice_auto_send,
voice_auto_playback=user.voice_auto_playback,
voice_playback_speed=user.voice_playback_speed,
preferred_voice=user.preferred_voice,
assistant_specific_configs=assistant_specific_configs,
)
),
@@ -240,6 +250,13 @@ class ChatBackgroundRequest(BaseModel):
chat_background: str | None
class VoiceSettingsUpdateRequest(BaseModel):
auto_send: bool | None = None
auto_playback: bool | None = None
playback_speed: float | None = Field(default=None, ge=0.5, le=2.0)
preferred_voice: str | None = None
class PersonalizationUpdateRequest(BaseModel):
name: str | None = None
role: str | None = None

@@ -1,11 +1,16 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import status
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.context.search.models import SavedSearchSettings
from onyx.context.search.models import SearchSettingsCreationRequest
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.connector_credential_pair import resync_cc_pair
from onyx.db.engine.sql_engine import get_session
from onyx.db.index_attempt import expire_index_attempts
from onyx.db.llm import fetch_existing_llm_provider
@@ -13,22 +18,25 @@ from onyx.db.llm import update_default_contextual_model
from onyx.db.llm import update_no_default_contextual_rag_provider
from onyx.db.models import IndexModelStatus
from onyx.db.models import User
from onyx.db.search_settings import create_search_settings
from onyx.db.search_settings import delete_search_settings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_embedding_provider_from_provider_type
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.search_settings import update_current_search_settings
from onyx.db.search_settings import update_search_settings_status
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.factory import get_default_document_index
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.unstructured import delete_unstructured_api_key
from onyx.file_processing.unstructured import get_unstructured_api_key
from onyx.file_processing.unstructured import update_unstructured_api_key
from onyx.natural_language_processing.search_nlp_models import clean_model_name
from onyx.server.manage.embedding.models import SearchSettingsDeleteRequest
from onyx.server.manage.models import FullModelVersionResponse
from onyx.server.models import IdReturn
from onyx.server.utils_vector_db import require_vector_db
from onyx.utils.logger import setup_logger
from shared_configs.configs import ALT_INDEX_SUFFIX
from shared_configs.configs import MULTI_TENANT
router = APIRouter(prefix="/search-settings")
@@ -41,110 +49,99 @@ def set_new_search_settings(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session), # noqa: ARG001
) -> IdReturn:
"""Creates a new EmbeddingModel row and cancels the previous secondary indexing if any
Gives an error if the same model name is used as the current or secondary index
"""
# TODO(andrei): Re-enable.
# NOTE Enable integration external dependency tests in test_search_settings.py
# when this is reenabled. They are currently skipped
logger.error("Setting new search settings is temporarily disabled.")
raise OnyxError(
OnyxErrorCode.NOT_IMPLEMENTED,
"Setting new search settings is temporarily disabled.",
Creates a new SearchSettings row and cancels the previous secondary indexing
if any exists.
"""
if search_settings_new.index_name:
logger.warning("Index name was specified by request, this is not suggested")
# Disallow contextual RAG for cloud deployments.
if MULTI_TENANT and search_settings_new.enable_contextual_rag:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Contextual RAG disabled in Onyx Cloud",
)
# Validate cloud provider exists or create new LiteLLM provider.
if search_settings_new.provider_type is not None:
cloud_provider = get_embedding_provider_from_provider_type(
db_session, provider_type=search_settings_new.provider_type
)
if cloud_provider is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
)
validate_contextual_rag_model(
provider_name=search_settings_new.contextual_rag_llm_provider,
model_name=search_settings_new.contextual_rag_llm_name,
db_session=db_session,
)
# if search_settings_new.index_name:
# logger.warning("Index name was specified by request, this is not suggested")
# # Disallow contextual RAG for cloud deployments
# if MULTI_TENANT and search_settings_new.enable_contextual_rag:
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Contextual RAG disabled in Onyx Cloud",
# )
search_settings = get_current_search_settings(db_session)
# # Validate cloud provider exists or create new LiteLLM provider
# if search_settings_new.provider_type is not None:
# cloud_provider = get_embedding_provider_from_provider_type(
# db_session, provider_type=search_settings_new.provider_type
# )
if search_settings_new.index_name is None:
# We define index name here.
index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
if (
search_settings_new.model_name == search_settings.model_name
and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
):
index_name += ALT_INDEX_SUFFIX
search_values = search_settings_new.model_dump()
search_values["index_name"] = index_name
new_search_settings_request = SavedSearchSettings(**search_values)
else:
new_search_settings_request = SavedSearchSettings(
**search_settings_new.model_dump()
)
# if cloud_provider is None:
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
# )
secondary_search_settings = get_secondary_search_settings(db_session)
# validate_contextual_rag_model(
# provider_name=search_settings_new.contextual_rag_llm_provider,
# model_name=search_settings_new.contextual_rag_llm_name,
# db_session=db_session,
# )
if secondary_search_settings:
# Cancel any background indexing jobs.
expire_index_attempts(
search_settings_id=secondary_search_settings.id, db_session=db_session
)
# search_settings = get_current_search_settings(db_session)
# Mark previous model as a past model directly.
update_search_settings_status(
search_settings=secondary_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
# if search_settings_new.index_name is None:
# # We define index name here
# index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
# if (
# search_settings_new.model_name == search_settings.model_name
# and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
# ):
# index_name += ALT_INDEX_SUFFIX
# search_values = search_settings_new.model_dump()
# search_values["index_name"] = index_name
# new_search_settings_request = SavedSearchSettings(**search_values)
# else:
# new_search_settings_request = SavedSearchSettings(
# **search_settings_new.model_dump()
# )
new_search_settings = create_search_settings(
search_settings=new_search_settings_request, db_session=db_session
)
# secondary_search_settings = get_secondary_search_settings(db_session)
# Ensure the document indices have the new index immediately.
document_indices = get_all_document_indices(search_settings, new_search_settings)
for document_index in document_indices:
document_index.ensure_indices_exist(
primary_embedding_dim=search_settings.final_embedding_dim,
primary_embedding_precision=search_settings.embedding_precision,
secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
secondary_index_embedding_precision=new_search_settings.embedding_precision,
)
# if secondary_search_settings:
# # Cancel any background indexing jobs
# expire_index_attempts(
# search_settings_id=secondary_search_settings.id, db_session=db_session
# )
# Pause index attempts for the currently in-use index to preserve resources.
if DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(
cc_pair=cc_pair,
search_settings_id=new_search_settings.id,
db_session=db_session,
)
# # Mark previous model as a past model directly
# update_search_settings_status(
# search_settings=secondary_search_settings,
# new_status=IndexModelStatus.PAST,
# db_session=db_session,
# )
# new_search_settings = create_search_settings(
# search_settings=new_search_settings_request, db_session=db_session
# )
# # Ensure Vespa has the new index immediately
# get_multipass_config(search_settings)
# get_multipass_config(new_search_settings)
# document_index = get_default_document_index(
# search_settings, new_search_settings, db_session
# )
# document_index.ensure_indices_exist(
# primary_embedding_dim=search_settings.final_embedding_dim,
# primary_embedding_precision=search_settings.embedding_precision,
# secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
# secondary_index_embedding_precision=new_search_settings.embedding_precision,
# )
# # Pause index attempts for the currently in use index to preserve resources
# if DISABLE_INDEX_UPDATE_ON_SWAP:
# expire_index_attempts(
# search_settings_id=search_settings.id, db_session=db_session
# )
# for cc_pair in get_connector_credential_pairs(db_session):
# resync_cc_pair(
# cc_pair=cc_pair,
# search_settings_id=new_search_settings.id,
# db_session=db_session,
# )
# db_session.commit()
# return IdReturn(id=new_search_settings.id)
db_session.commit()
return IdReturn(id=new_search_settings.id)
@router.post("/cancel-new-embedding", dependencies=[Depends(require_vector_db)])
@@ -191,7 +188,7 @@ def delete_search_settings_endpoint(
search_settings_id=deletion_request.search_settings_id,
)
except ValueError as e:
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@router.get("/get-current-search-settings")
@@ -241,9 +238,9 @@ def update_saved_search_settings(
) -> None:
# Disallow contextual RAG for cloud deployments
if MULTI_TENANT and search_settings.enable_contextual_rag:
raise OnyxError(
OnyxErrorCode.VALIDATION_ERROR,
"Contextual RAG disabled in Onyx Cloud",
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Contextual RAG disabled in Onyx Cloud",
)
validate_contextual_rag_model(
@@ -297,7 +294,7 @@ def validate_contextual_rag_model(
model_name=model_name,
db_session=db_session,
):
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, error_msg)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=error_msg)
def _validate_contextual_rag_model(
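
In `set_new_search_settings` above, when the request does not specify an index name, one is derived from the model name, and the alternate suffix is appended only if the new model matches the current one and the current index does not already carry that suffix. A minimal sketch of that rule, with loudly stated assumptions: the suffix value and the `clean_model_name_sketch` helper below are hypothetical stand-ins, not the real `ALT_INDEX_SUFFIX` or `clean_model_name`:

```python
import re

# Hypothetical value; the real ALT_INDEX_SUFFIX lives in shared_configs.configs.
ALT_INDEX_SUFFIX = "__alt"


def clean_model_name_sketch(model_name: str) -> str:
    # Assumption: replace characters that are unsafe in an index name with "_".
    return re.sub(r"[^A-Za-z0-9_]", "_", model_name)


def derive_index_name(new_model: str, current_model: str, current_index: str) -> str:
    """Derive an index name that does not collide with the current index."""
    index_name = f"danswer_chunk_{clean_model_name_sketch(new_model)}"
    # Same model as the live index: alternate between suffixed and unsuffixed names.
    if new_model == current_model and not current_index.endswith(ALT_INDEX_SUFFIX):
        index_name += ALT_INDEX_SUFFIX
    return index_name
```

The alternation means that re-indexing with the same model flips between the suffixed and unsuffixed names, so the new secondary index never shares a name with the index currently serving queries.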

@@ -0,0 +1,282 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import User
from onyx.db.models import VoiceProvider
from onyx.db.voice import deactivate_stt_provider
from onyx.db.voice import deactivate_tts_provider
from onyx.db.voice import delete_voice_provider
from onyx.db.voice import fetch_voice_provider_by_id
from onyx.db.voice import fetch_voice_provider_by_type
from onyx.db.voice import fetch_voice_providers
from onyx.db.voice import set_default_stt_provider
from onyx.db.voice import set_default_tts_provider
from onyx.db.voice import upsert_voice_provider
from onyx.server.manage.voice.models import VoiceProviderTestRequest
from onyx.server.manage.voice.models import VoiceProviderUpsertRequest
from onyx.server.manage.voice.models import VoiceProviderView
from onyx.utils.logger import setup_logger
from onyx.voice.factory import get_voice_provider
logger = setup_logger()
admin_router = APIRouter(prefix="/admin/voice")
def _provider_to_view(provider: VoiceProvider) -> VoiceProviderView:
"""Convert a VoiceProvider model to a VoiceProviderView."""
return VoiceProviderView(
id=provider.id,
name=provider.name,
provider_type=provider.provider_type,
is_default_stt=provider.is_default_stt,
is_default_tts=provider.is_default_tts,
stt_model=provider.stt_model,
tts_model=provider.tts_model,
default_voice=provider.default_voice,
has_api_key=bool(provider.api_key),
target_uri=provider.api_base, # api_base stores the target URI for Azure
)
@admin_router.get("/providers")
def list_voice_providers(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[VoiceProviderView]:
"""List all configured voice providers."""
providers = fetch_voice_providers(db_session)
return [_provider_to_view(provider) for provider in providers]
@admin_router.post("/providers")
def upsert_voice_provider_endpoint(
request: VoiceProviderUpsertRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> VoiceProviderView:
"""Create or update a voice provider."""
api_key = request.api_key
api_key_changed = request.api_key_changed
# If llm_provider_id is specified, copy the API key from that LLM provider
if request.llm_provider_id is not None:
llm_provider = db_session.get(LLMProviderModel, request.llm_provider_id)
if llm_provider is None:
raise HTTPException(
status_code=404,
detail=f"LLM provider with id {request.llm_provider_id} not found.",
)
if llm_provider.api_key is None:
raise HTTPException(
status_code=400,
detail="Selected LLM provider has no API key configured.",
)
api_key = llm_provider.api_key.get_value(apply_mask=False)
api_key_changed = True
# Use target_uri if provided, otherwise fall back to api_base
api_base = request.target_uri or request.api_base
provider = upsert_voice_provider(
db_session=db_session,
provider_id=request.id,
name=request.name,
provider_type=request.provider_type,
api_key=api_key,
api_key_changed=api_key_changed,
api_base=api_base,
custom_config=request.custom_config,
stt_model=request.stt_model,
tts_model=request.tts_model,
default_voice=request.default_voice,
activate_stt=request.activate_stt,
activate_tts=request.activate_tts,
)
db_session.commit()
return _provider_to_view(provider)
@admin_router.delete(
"/providers/{provider_id}", status_code=204, response_class=Response
)
def delete_voice_provider_endpoint(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> Response:
"""Delete a voice provider."""
delete_voice_provider(db_session, provider_id)
db_session.commit()
return Response(status_code=204)
@admin_router.post("/providers/{provider_id}/activate-stt")
def activate_stt_provider_endpoint(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> VoiceProviderView:
"""Set a voice provider as the default STT provider."""
provider = set_default_stt_provider(db_session=db_session, provider_id=provider_id)
db_session.commit()
return _provider_to_view(provider)
@admin_router.post("/providers/{provider_id}/deactivate-stt")
def deactivate_stt_provider_endpoint(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Remove the default STT status from a voice provider."""
deactivate_stt_provider(db_session=db_session, provider_id=provider_id)
db_session.commit()
return {"status": "ok"}
@admin_router.post("/providers/{provider_id}/activate-tts")
def activate_tts_provider_endpoint(
provider_id: int,
tts_model: str | None = None,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> VoiceProviderView:
"""Set a voice provider as the default TTS provider."""
provider = set_default_tts_provider(
db_session=db_session, provider_id=provider_id, tts_model=tts_model
)
db_session.commit()
return _provider_to_view(provider)
@admin_router.post("/providers/{provider_id}/deactivate-tts")
def deactivate_tts_provider_endpoint(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Remove the default TTS status from a voice provider."""
deactivate_tts_provider(db_session=db_session, provider_id=provider_id)
db_session.commit()
return {"status": "ok"}
@admin_router.post("/providers/test")
def test_voice_provider(
request: VoiceProviderTestRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Test a voice provider connection."""
api_key = request.api_key
if request.use_stored_key:
existing_provider = fetch_voice_provider_by_type(
db_session, request.provider_type
)
if existing_provider is None or not existing_provider.api_key:
raise HTTPException(
status_code=400,
detail="No stored API key found for this provider type.",
)
api_key = existing_provider.api_key.get_value(apply_mask=False)
if not api_key:
raise HTTPException(
status_code=400,
detail="API key is required. Either provide api_key or set use_stored_key to true.",
)
# Use target_uri if provided, otherwise fall back to api_base
api_base = request.target_uri or request.api_base
# Create a temporary VoiceProvider for testing (not saved to DB)
temp_provider = VoiceProvider(
name="__test__",
provider_type=request.provider_type,
api_base=api_base,
custom_config=request.custom_config or {},
)
temp_provider.api_key = api_key # type: ignore[assignment]
try:
provider = get_voice_provider(temp_provider)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
# Test the provider by getting available voices (lightweight check)
try:
voices = provider.get_available_voices()
if not voices:
raise HTTPException(
status_code=400,
detail="Provider returned no available voices.",
)
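except HTTPException:
# Re-raise as-is so the "no available voices" error is not rewrapped by the generic handler below
raise
except NotImplementedError: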
# Provider not fully implemented yet (Azure, ElevenLabs placeholders)
pass
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Connection test failed: {str(e)}",
) from e
logger.info(f"Voice provider test succeeded for {request.provider_type}.")
return {"status": "ok"}
@admin_router.get("/providers/{provider_id}/voices")
def get_provider_voices(
provider_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[dict[str, str]]:
"""Get available voices for a provider."""
provider_db = fetch_voice_provider_by_id(db_session, provider_id)
if provider_db is None:
raise HTTPException(status_code=404, detail="Voice provider not found.")
if not provider_db.api_key:
raise HTTPException(
status_code=400, detail="Provider has no API key configured."
)
try:
provider = get_voice_provider(provider_db)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return provider.get_available_voices()
@admin_router.get("/voices")
def get_voices_by_type(
provider_type: str,
_: User = Depends(current_admin_user),
) -> list[dict[str, str]]:
"""Get available voices for a provider type.
For providers like ElevenLabs and OpenAI, this fetches voices
without requiring an existing provider configuration.
"""
# Create a temporary VoiceProvider to get static voice list
temp_provider = VoiceProvider(
name="__temp__",
provider_type=provider_type,
)
try:
provider = get_voice_provider(temp_provider)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
return provider.get_available_voices()


@@ -0,0 +1,90 @@
from typing import Any
from pydantic import BaseModel
from pydantic import Field
class VoiceProviderView(BaseModel):
"""Response model for voice provider listing."""
id: int
name: str
provider_type: str # "openai", "azure", "elevenlabs"
is_default_stt: bool
is_default_tts: bool
stt_model: str | None
tts_model: str | None
default_voice: str | None
has_api_key: bool = Field(
default=False,
description="Indicates whether an API key is stored for this provider.",
)
target_uri: str | None = Field(
default=None,
description="Target URI for Azure Speech Services.",
)
class VoiceProviderUpsertRequest(BaseModel):
"""Request model for creating or updating a voice provider."""
id: int | None = Field(default=None, description="Existing provider ID to update.")
name: str
provider_type: str # "openai", "azure", "elevenlabs"
api_key: str | None = Field(
default=None,
description="API key for the provider.",
)
api_key_changed: bool = Field(
default=False,
description="Set to true when providing a new API key for an existing provider.",
)
llm_provider_id: int | None = Field(
default=None,
description="If set, copies the API key from the specified LLM provider.",
)
api_base: str | None = None
target_uri: str | None = Field(
default=None,
description="Target URI for Azure Speech Services (maps to api_base).",
)
custom_config: dict[str, Any] | None = None
stt_model: str | None = None
tts_model: str | None = None
default_voice: str | None = None
activate_stt: bool = Field(
default=False,
description="If true, sets this provider as the default STT provider after upsert.",
)
activate_tts: bool = Field(
default=False,
description="If true, sets this provider as the default TTS provider after upsert.",
)
class VoiceProviderTestRequest(BaseModel):
"""Request model for testing a voice provider connection."""
provider_type: str
api_key: str | None = Field(
default=None,
description="API key for testing. If not provided, use_stored_key must be true.",
)
use_stored_key: bool = Field(
default=False,
description="If true, use the stored API key for this provider type.",
)
api_base: str | None = None
target_uri: str | None = Field(
default=None,
description="Target URI for Azure Speech Services (maps to api_base).",
)
custom_config: dict[str, Any] | None = None
class SynthesizeRequest(BaseModel):
"""Request model for text-to-speech synthesis."""
text: str = Field(..., min_length=1, max_length=4096)
voice: str | None = None
speed: float = Field(default=1.0, ge=0.5, le=2.0)


@@ -0,0 +1,226 @@
import secrets
from collections.abc import AsyncIterator
from typing import Any
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import HTTPException
from fastapi import Query
from fastapi import UploadFile
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.db.models import User
from onyx.db.voice import fetch_default_stt_provider
from onyx.db.voice import fetch_default_tts_provider
from onyx.db.voice import update_user_voice_settings
from onyx.redis.redis_pool import store_ws_token
from onyx.server.manage.models import VoiceSettingsUpdateRequest
from onyx.utils.logger import setup_logger
from onyx.voice.factory import get_voice_provider
logger = setup_logger()
router = APIRouter(prefix="/voice")
# Max audio file size: 25MB (Whisper limit)
MAX_AUDIO_SIZE = 25 * 1024 * 1024
@router.post("/transcribe")
async def transcribe_audio(
audio: UploadFile = File(...),
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Transcribe audio to text using the default STT provider."""
provider_db = fetch_default_stt_provider(db_session)
if provider_db is None:
raise HTTPException(
status_code=400,
detail="No speech-to-text provider configured. Please contact your administrator.",
)
if not provider_db.api_key:
raise HTTPException(
status_code=400,
detail="Voice provider API key not configured.",
)
audio_data = await audio.read()
if len(audio_data) > MAX_AUDIO_SIZE:
raise HTTPException(
status_code=400,
detail=f"Audio file too large. Maximum size is {MAX_AUDIO_SIZE // (1024 * 1024)}MB.",
)
# Extract format from filename
filename = audio.filename or "audio.webm"
audio_format = filename.rsplit(".", 1)[-1] if "." in filename else "webm"
try:
provider = get_voice_provider(provider_db)
except ValueError as exc:
raise HTTPException(status_code=500, detail=str(exc)) from exc
try:
text = await provider.transcribe(audio_data, audio_format)
return {"text": text}
except NotImplementedError as exc:
raise HTTPException(
status_code=501,
detail=f"Speech-to-text not implemented for {provider_db.provider_type}.",
) from exc
except Exception as exc:
logger.error(f"Transcription failed: {exc}")
raise HTTPException(
status_code=500,
detail=f"Transcription failed: {str(exc)}",
) from exc
@router.post("/synthesize")
async def synthesize_speech(
text: str | None = Query(
default=None, description="Text to synthesize", max_length=4096
),
voice: str | None = Query(default=None, description="Voice ID to use"),
speed: float | None = Query(
default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
),
user: User = Depends(current_user),
) -> StreamingResponse:
"""
Synthesize text to speech using the default TTS provider.
Accepts parameters via query string for streaming compatibility.
"""
logger.info(
f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
)
if not text:
raise HTTPException(status_code=400, detail="Text is required")
# Use short-lived session to fetch provider config, then release connection
# before starting the long-running streaming response
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
provider_db = fetch_default_tts_provider(db_session)
if provider_db is None:
logger.error("No TTS provider configured")
raise HTTPException(
status_code=400,
detail="No text-to-speech provider configured. Please contact your administrator.",
)
if not provider_db.api_key:
logger.error("TTS provider has no API key")
raise HTTPException(
status_code=400,
detail="Voice provider API key not configured.",
)
# Use request voice, or user's preferred voice, or provider default
final_voice = voice or user.preferred_voice or provider_db.default_voice
# Use explicit None checks to avoid falsy float issues (0.0 would be skipped with `or`)
final_speed = (
speed
if speed is not None
else (
user.voice_playback_speed
if user.voice_playback_speed is not None
else 1.0
)
)
logger.info(
f"TTS using provider: {provider_db.provider_type}, voice: {final_voice}, speed: {final_speed}"
)
try:
provider = get_voice_provider(provider_db)
except ValueError as exc:
logger.error(f"Failed to get voice provider: {exc}")
raise HTTPException(status_code=500, detail=str(exc)) from exc
# Session is now closed - streaming response won't hold DB connection
async def audio_stream() -> AsyncIterator[bytes]:
try:
chunk_count = 0
async for chunk in provider.synthesize_stream(
text=text, voice=final_voice, speed=final_speed
):
chunk_count += 1
yield chunk
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
except NotImplementedError as exc:
logger.error(f"TTS not implemented: {exc}")
raise
except Exception as exc:
logger.error(f"Synthesis failed: {exc}")
raise
return StreamingResponse(
audio_stream(),
media_type="audio/mpeg",
headers={
"Content-Disposition": "inline; filename=speech.mp3",
# Allow streaming by not setting content-length
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no", # Disable nginx buffering
},
)
@router.patch("/settings")
def update_voice_settings(
request: VoiceSettingsUpdateRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> dict[str, str]:
"""Update user's voice settings.
To clear preferred_voice back to default, explicitly send {"preferred_voice": null}.
Omitting the field leaves it unchanged.
"""
# Build kwargs, only including preferred_voice if explicitly provided
# This allows distinguishing between "not provided" and "set to null"
kwargs: dict[str, Any] = {
"db_session": db_session,
"user_id": user.id,
"auto_send": request.auto_send,
"auto_playback": request.auto_playback,
"playback_speed": request.playback_speed,
}
if "preferred_voice" in request.model_fields_set:
kwargs["preferred_voice"] = request.preferred_voice
update_user_voice_settings(**kwargs)
return {"status": "ok"}
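Review note: the "omitted versus explicit null" distinction in `update_voice_settings` can be illustrated without Pydantic. In this sketch a plain dict stands in for the parsed request, and key presence plays the role of `model_fields_set`; `build_update_kwargs` is a hypothetical helper, not code from the PR.

```python
from typing import Any


def build_update_kwargs(payload: dict[str, Any]) -> dict[str, Any]:
    """Forward preferred_voice only when the client actually sent the field."""
    kwargs: dict[str, Any] = {
        "auto_send": payload.get("auto_send"),
        "auto_playback": payload.get("auto_playback"),
        "playback_speed": payload.get("playback_speed"),
    }
    # Key presence, not the value, decides whether preferred_voice is forwarded,
    # so an explicit null can clear the preference while omission leaves it alone.
    if "preferred_voice" in payload:
        kwargs["preferred_voice"] = payload["preferred_voice"]
    return kwargs
```

With this shape, `{"preferred_voice": None}` produces a `preferred_voice` entry set to `None`, while `{"auto_send": True}` produces no `preferred_voice` entry at all.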
class WSTokenResponse(BaseModel):
token: str
@router.post("/ws-token")
async def get_ws_token(
user: User = Depends(current_user),
) -> WSTokenResponse:
"""
Generate a short-lived token for WebSocket authentication.
This token should be passed as a query parameter when connecting
to voice WebSocket endpoints (e.g., /voice/transcribe/stream?token=xxx).
The token expires after 60 seconds and is single-use.
"""
token = secrets.token_urlsafe(32)
await store_ws_token(token, str(user.id))
return WSTokenResponse(token=token)


@@ -0,0 +1,782 @@
"""WebSocket API for streaming speech-to-text and text-to-speech."""
import asyncio
import io
import json
import os
from collections.abc import MutableMapping
from typing import Any
from fastapi import APIRouter
from fastapi import Depends
from fastapi import WebSocket
from fastapi import WebSocketDisconnect
from sqlalchemy.orm import Session
from onyx.auth.users import current_user_from_websocket
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.db.models import User
from onyx.db.voice import fetch_default_stt_provider
from onyx.db.voice import fetch_default_tts_provider
from onyx.utils.logger import setup_logger
from onyx.voice.factory import get_voice_provider
from onyx.voice.interface import StreamingSynthesizerProtocol
from onyx.voice.interface import StreamingTranscriberProtocol
from onyx.voice.interface import TranscriptResult
logger = setup_logger()
router = APIRouter(prefix="/voice")
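# Transcribe roughly every 0.5 seconds of audio, assuming webm/opus at ~2-4KB/s (~1-2KB per 0.5s).
# Note: the chunked fallback below passes audio_format="pcm16"; raw PCM16 reaches this threshold much sooner.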
MIN_CHUNK_BYTES = 1500
VOICE_DISABLE_STREAMING_FALLBACK = (
os.environ.get("VOICE_DISABLE_STREAMING_FALLBACK", "").lower() == "true"
)
class ChunkedTranscriber:
"""Fallback transcriber for providers without streaming support."""
def __init__(self, provider: Any, audio_format: str = "webm"):
self.provider = provider
self.audio_format = audio_format
self.chunk_buffer = io.BytesIO()
self.full_audio = io.BytesIO()
self.chunk_bytes = 0
self.transcripts: list[str] = []
async def add_chunk(self, chunk: bytes) -> str | None:
"""Add audio chunk. Returns transcript if enough audio accumulated."""
self.chunk_buffer.write(chunk)
self.full_audio.write(chunk)
self.chunk_bytes += len(chunk)
if self.chunk_bytes >= MIN_CHUNK_BYTES:
return await self._transcribe_chunk()
return None
async def _transcribe_chunk(self) -> str | None:
"""Transcribe current chunk and append to running transcript."""
audio_data = self.chunk_buffer.getvalue()
if not audio_data:
return None
try:
transcript = await self.provider.transcribe(audio_data, self.audio_format)
self.chunk_buffer = io.BytesIO()
self.chunk_bytes = 0
if transcript and transcript.strip():
self.transcripts.append(transcript.strip())
return " ".join(self.transcripts)
return None
except Exception as e:
logger.error(f"Transcription error: {e}")
self.chunk_buffer = io.BytesIO()
self.chunk_bytes = 0
return None
async def flush(self) -> str:
"""Get final transcript from full audio for best accuracy."""
full_audio_data = self.full_audio.getvalue()
if full_audio_data:
try:
transcript = await self.provider.transcribe(
full_audio_data, self.audio_format
)
if transcript and transcript.strip():
return transcript.strip()
except Exception as e:
logger.error(f"Final transcription error: {e}")
return " ".join(self.transcripts)
async def handle_streaming_transcription(
websocket: WebSocket,
transcriber: StreamingTranscriberProtocol,
) -> None:
"""Handle transcription using native streaming API."""
logger.info("Streaming transcription: starting handler")
last_transcript = ""
chunk_count = 0
total_bytes = 0
async def receive_transcripts() -> None:
"""Background task to receive and send transcripts."""
nonlocal last_transcript
logger.info("Streaming transcription: starting transcript receiver")
while True:
result: TranscriptResult | None = await transcriber.receive_transcript()
if result is None: # End of stream
logger.info("Streaming transcription: transcript stream ended")
break
# Send if text changed OR if VAD detected end of speech (for auto-send trigger)
if result.text and (result.text != last_transcript or result.is_vad_end):
last_transcript = result.text
logger.debug(
f"Streaming transcription: got transcript: {result.text[:50]}... "
f"(is_vad_end={result.is_vad_end})"
)
await websocket.send_json(
{
"type": "transcript",
"text": result.text,
"is_final": result.is_vad_end,
}
)
# Start receiving transcripts in background
receive_task = asyncio.create_task(receive_transcripts())
try:
while True:
message = await websocket.receive()
msg_type = message.get("type", "unknown")
if msg_type == "websocket.disconnect":
logger.info(
f"Streaming transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
)
break
if "bytes" in message:
chunk_size = len(message["bytes"])
chunk_count += 1
total_bytes += chunk_size
logger.debug(
f"Streaming transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
)
await transcriber.send_audio(message["bytes"])
elif "text" in message:
try:
data = json.loads(message["text"])
logger.debug(
f"Streaming transcription: received text message: {data}"
)
if data.get("type") == "end":
logger.info(
"Streaming transcription: end signal received, closing transcriber"
)
final_transcript = await transcriber.close()
receive_task.cancel()
logger.info(
"Streaming transcription: final transcript: "
f"{final_transcript[:100] if final_transcript else '(empty)'}..."
)
await websocket.send_json(
{
"type": "transcript",
"text": final_transcript,
"is_final": True,
}
)
break
elif data.get("type") == "reset":
# Reset accumulated transcript after auto-send
logger.info(
"Streaming transcription: reset signal received, clearing transcript"
)
transcriber.reset_transcript()
except json.JSONDecodeError:
logger.warning(
f"Streaming transcription: failed to parse JSON: {message.get('text', '')[:100]}"
)
except Exception as e:
logger.error(f"Streaming transcription: error: {e}", exc_info=True)
raise
finally:
receive_task.cancel()
try:
await receive_task
except asyncio.CancelledError:
pass
logger.info(
f"Streaming transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
)
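Review note: the receive loop in `handle_streaming_transcription` dispatches on the shape of each incoming message. A small pure function makes that dispatch explicit; `classify_frame` is a hypothetical helper for illustration, and the message dicts follow the shape the handler already reads (`type`, `bytes`, `text`).

```python
import json
from typing import Any


def classify_frame(message: dict[str, Any]) -> str:
    """Map a raw WebSocket message to the action the handler takes."""
    if message.get("type") == "websocket.disconnect":
        return "disconnect"
    if "bytes" in message:
        return "audio"  # forwarded to the transcriber
    if "text" in message:
        try:
            data = json.loads(message["text"])
        except json.JSONDecodeError:
            return "ignore"  # the handler logs a warning and keeps reading
        if isinstance(data, dict) and data.get("type") in ("end", "reset"):
            return data["type"]
    return "ignore"
```

Only `end` terminates the session with a final transcript; `reset` clears the accumulated text and the loop continues.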
async def handle_chunked_transcription(
websocket: WebSocket,
transcriber: ChunkedTranscriber,
) -> None:
"""Handle transcription using chunked batch API."""
logger.info("Chunked transcription: starting handler")
chunk_count = 0
total_bytes = 0
while True:
message = await websocket.receive()
msg_type = message.get("type", "unknown")
if msg_type == "websocket.disconnect":
logger.info(
f"Chunked transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
)
break
if "bytes" in message:
chunk_size = len(message["bytes"])
chunk_count += 1
total_bytes += chunk_size
logger.debug(
f"Chunked transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
)
transcript = await transcriber.add_chunk(message["bytes"])
if transcript:
logger.debug(
f"Chunked transcription: got transcript: {transcript[:50]}..."
)
await websocket.send_json(
{
"type": "transcript",
"text": transcript,
"is_final": False,
}
)
elif "text" in message:
try:
data = json.loads(message["text"])
logger.debug(f"Chunked transcription: received text message: {data}")
if data.get("type") == "end":
logger.info("Chunked transcription: end signal received, flushing")
final_transcript = await transcriber.flush()
logger.info(
f"Chunked transcription: final transcript: {final_transcript[:100] if final_transcript else '(empty)'}..."
)
await websocket.send_json(
{
"type": "transcript",
"text": final_transcript,
"is_final": True,
}
)
break
except json.JSONDecodeError:
logger.warning(
f"Chunked transcription: failed to parse JSON: {message.get('text', '')[:100]}"
)
logger.info(
f"Chunked transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
)
@router.websocket("/transcribe/stream")
async def websocket_transcribe(
websocket: WebSocket,
_user: User = Depends(current_user_from_websocket),
) -> None:
"""
WebSocket endpoint for streaming speech-to-text.
Protocol:
- Client sends binary audio chunks
- Server sends JSON: {"type": "transcript", "text": "...", "is_final": false}
- Client sends JSON {"type": "end"} to signal end
- Server responds with final transcript and closes
Authentication:
Requires `token` query parameter (e.g., /voice/transcribe/stream?token=xxx).
Applies same auth checks as HTTP endpoints (verification, role checks).
"""
logger.info("WebSocket transcribe: connection request received (authenticated)")
try:
await websocket.accept()
logger.info("WebSocket transcribe: connection accepted")
except Exception as e:
logger.error(f"WebSocket transcribe: failed to accept connection: {e}")
return
streaming_transcriber = None
provider = None
try:
# Get STT provider
logger.info("WebSocket transcribe: fetching STT provider from database")
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
provider_db = fetch_default_stt_provider(db_session)
if provider_db is None:
logger.warning(
"WebSocket transcribe: no default STT provider configured"
)
await websocket.send_json(
{
"type": "error",
"message": "No speech-to-text provider configured",
}
)
return
if not provider_db.api_key:
logger.warning("WebSocket transcribe: STT provider has no API key")
await websocket.send_json(
{
"type": "error",
"message": "Speech-to-text provider has no API key configured",
}
)
return
logger.info(
f"WebSocket transcribe: creating voice provider: {provider_db.provider_type}"
)
try:
provider = get_voice_provider(provider_db)
logger.info(
f"WebSocket transcribe: voice provider created, streaming supported: {provider.supports_streaming_stt()}"
)
except ValueError as e:
logger.error(
f"WebSocket transcribe: failed to create voice provider: {e}"
)
await websocket.send_json({"type": "error", "message": str(e)})
return
# Use native streaming if provider supports it
if provider.supports_streaming_stt():
logger.info("WebSocket transcribe: using native streaming STT")
try:
streaming_transcriber = await provider.create_streaming_transcriber()
logger.info(
"WebSocket transcribe: streaming transcriber created successfully"
)
await handle_streaming_transcription(websocket, streaming_transcriber)
except Exception as e:
logger.error(
f"WebSocket transcribe: streaming STT failed: {e}"
)
if VOICE_DISABLE_STREAMING_FALLBACK:
await websocket.send_json(
{"type": "error", "message": f"Streaming STT failed: {e}"}
)
return
logger.info("WebSocket transcribe: falling back to chunked STT")
# Browser stream provides raw PCM16 chunks over WebSocket.
chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
await handle_chunked_transcription(websocket, chunked_transcriber)
else:
# Fall back to chunked transcription
if VOICE_DISABLE_STREAMING_FALLBACK:
await websocket.send_json(
{
"type": "error",
"message": "Provider doesn't support streaming STT",
}
)
return
logger.info(
"WebSocket transcribe: using chunked STT (provider doesn't support streaming)"
)
chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
await handle_chunked_transcription(websocket, chunked_transcriber)
except WebSocketDisconnect:
logger.debug("WebSocket transcribe: client disconnected")
except Exception as e:
logger.error(f"WebSocket transcribe: unhandled error: {e}", exc_info=True)
try:
# Send generic error to avoid leaking sensitive details
await websocket.send_json(
{"type": "error", "message": "An unexpected error occurred"}
)
except Exception:
pass
finally:
if streaming_transcriber:
try:
await streaming_transcriber.close()
except Exception:
pass
try:
await websocket.close()
except Exception:
pass
logger.info("WebSocket transcribe: connection closed")
async def handle_streaming_synthesis(
websocket: WebSocket,
synthesizer: StreamingSynthesizerProtocol,
) -> None:
"""Handle TTS using native streaming API."""
logger.info("Streaming synthesis: starting handler")
async def send_audio() -> None:
"""Background task to send audio chunks to client."""
chunk_count = 0
total_bytes = 0
try:
while True:
audio_chunk = await synthesizer.receive_audio()
if audio_chunk is None:
logger.info(
f"Streaming synthesis: audio stream ended, sent {chunk_count} chunks, {total_bytes} bytes"
)
try:
await websocket.send_json({"type": "audio_done"})
logger.info("Streaming synthesis: sent audio_done to client")
except Exception as e:
logger.warning(
f"Streaming synthesis: failed to send audio_done: {e}"
)
break
if audio_chunk: # Skip empty chunks
chunk_count += 1
total_bytes += len(audio_chunk)
try:
await websocket.send_bytes(audio_chunk)
except Exception as e:
logger.warning(
f"Streaming synthesis: failed to send chunk: {e}"
)
break
except asyncio.CancelledError:
logger.info(
f"Streaming synthesis: send_audio cancelled after {chunk_count} chunks"
)
except Exception as e:
logger.error(f"Streaming synthesis: send_audio error: {e}")
send_task: asyncio.Task | None = None
disconnected = False
try:
while not disconnected:
try:
message = await websocket.receive()
except WebSocketDisconnect:
logger.info("Streaming synthesis: client disconnected")
break
msg_type = message.get("type", "unknown") # type: ignore[possibly-undefined]
if msg_type == "websocket.disconnect":
logger.info("Streaming synthesis: client disconnected")
disconnected = True
break
if "text" in message:
try:
data = json.loads(message["text"])
if data.get("type") == "synthesize":
text = data.get("text", "")
if not text:
for key, value in data.items():
if key != "type" and isinstance(value, str) and value:
text = value
break
if text:
# Start audio receiver on first text chunk so playback
# can begin before the full assistant response completes.
if send_task is None:
send_task = asyncio.create_task(send_audio())
logger.debug(
f"Streaming synthesis: forwarding text chunk ({len(text)} chars)"
)
await synthesizer.send_text(text)
elif data.get("type") == "end":
logger.info("Streaming synthesis: end signal received")
# Ensure receiver is active even if no prior text chunks arrived.
if send_task is None:
send_task = asyncio.create_task(send_audio())
# Signal end of input
if hasattr(synthesizer, "flush"):
await synthesizer.flush()
# Wait for all audio to be sent
logger.info(
"Streaming synthesis: waiting for audio stream to complete"
)
try:
await asyncio.wait_for(send_task, timeout=60.0)
except asyncio.TimeoutError:
logger.warning(
"Streaming synthesis: timeout waiting for audio"
)
break
except json.JSONDecodeError:
logger.warning(
f"Streaming synthesis: failed to parse JSON: {message.get('text', '')[:100]}"
)
except WebSocketDisconnect:
logger.debug("Streaming synthesis: client disconnected during synthesis")
except Exception as e:
logger.error(f"Streaming synthesis: error: {e}", exc_info=True)
finally:
if send_task and not send_task.done():
logger.info("Streaming synthesis: waiting for send_task to finish")
try:
await asyncio.wait_for(send_task, timeout=30.0)
except asyncio.TimeoutError:
logger.warning("Streaming synthesis: timeout waiting for send_task")
send_task.cancel()
try:
await send_task
except asyncio.CancelledError:
pass
except asyncio.CancelledError:
pass
logger.info("Streaming synthesis: handler finished")
async def handle_chunked_synthesis(
websocket: WebSocket,
provider: Any,
first_message: MutableMapping[str, Any] | None = None,
) -> None:
"""Fallback TTS handler using provider.synthesize_stream.
Args:
websocket: The WebSocket connection
provider: Voice provider instance
first_message: Optional first message already received (used when falling
back from streaming mode, where the first message was already consumed)
"""
logger.info("Chunked synthesis: starting handler")
text_buffer: list[str] = []
voice: str | None = None
speed = 1.0
# Process pre-received message if provided
pending_message = first_message
try:
while True:
if pending_message is not None:
message = pending_message
pending_message = None
else:
message = await websocket.receive()
msg_type = message.get("type", "unknown")
if msg_type == "websocket.disconnect":
logger.info("Chunked synthesis: client disconnected")
break
if "text" not in message:
continue
try:
data = json.loads(message["text"])
except json.JSONDecodeError:
logger.warning(
"Chunked synthesis: failed to parse JSON: "
f"{message.get('text', '')[:100]}"
)
continue
msg_data_type = data.get("type") # type: ignore[possibly-undefined]
if msg_data_type == "synthesize":
text = data.get("text", "")
if not text:
for key, value in data.items():
if key != "type" and isinstance(value, str) and value:
text = value
break
if text:
text_buffer.append(text)
logger.debug(
f"Chunked synthesis: buffered text ({len(text)} chars), "
f"total buffered: {len(text_buffer)} chunks"
)
if isinstance(data.get("voice"), str) and data["voice"]:
voice = data["voice"]
if isinstance(data.get("speed"), (int, float)):
speed = float(data["speed"])
elif msg_data_type == "end":
logger.info("Chunked synthesis: end signal received")
full_text = " ".join(text_buffer).strip()
if not full_text:
await websocket.send_json({"type": "audio_done"})
logger.info("Chunked synthesis: no text, sent audio_done")
break
chunk_count = 0
total_bytes = 0
logger.info(
f"Chunked synthesis: sending full text ({len(full_text)} chars)"
)
async for audio_chunk in provider.synthesize_stream(
full_text, voice=voice, speed=speed
):
if not audio_chunk:
continue
chunk_count += 1
total_bytes += len(audio_chunk)
await websocket.send_bytes(audio_chunk)
await websocket.send_json({"type": "audio_done"})
logger.info(
f"Chunked synthesis: sent audio_done after {chunk_count} chunks, {total_bytes} bytes"
)
break
except WebSocketDisconnect:
logger.debug("Chunked synthesis: client disconnected")
except Exception as e:
logger.error(f"Chunked synthesis: error: {e}", exc_info=True)
raise
finally:
logger.info("Chunked synthesis: handler finished")
@router.websocket("/synthesize/stream")
async def websocket_synthesize(
websocket: WebSocket,
_user: User = Depends(current_user_from_websocket),
) -> None:
"""
WebSocket endpoint for streaming text-to-speech.
Protocol:
- Client sends JSON: {"type": "synthesize", "text": "...", "voice": "...", "speed": 1.0}
- Server sends binary audio chunks
- Server sends JSON: {"type": "audio_done"} when synthesis completes
- Client sends JSON: {"type": "end"} to signal end of input; the server finishes sending audio and closes the connection
Authentication:
Requires `token` query parameter (e.g., /voice/synthesize/stream?token=xxx).
Applies same auth checks as HTTP endpoints (verification, role checks).
"""
logger.info("WebSocket synthesize: connection request received (authenticated)")
try:
await websocket.accept()
logger.info("WebSocket synthesize: connection accepted")
except Exception as e:
logger.error(f"WebSocket synthesize: failed to accept connection: {e}")
return
streaming_synthesizer: StreamingSynthesizerProtocol | None = None
provider = None
try:
# Get TTS provider
logger.info("WebSocket synthesize: fetching TTS provider from database")
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
provider_db = fetch_default_tts_provider(db_session)
if provider_db is None:
logger.warning(
"WebSocket synthesize: no default TTS provider configured"
)
await websocket.send_json(
{
"type": "error",
"message": "No text-to-speech provider configured",
}
)
return
if not provider_db.api_key:
logger.warning("WebSocket synthesize: TTS provider has no API key")
await websocket.send_json(
{
"type": "error",
"message": "Text-to-speech provider has no API key configured",
}
)
return
logger.info(
f"WebSocket synthesize: creating voice provider: {provider_db.provider_type}"
)
try:
provider = get_voice_provider(provider_db)
logger.info(
f"WebSocket synthesize: voice provider created, streaming TTS supported: {provider.supports_streaming_tts()}"
)
except ValueError as e:
logger.error(
f"WebSocket synthesize: failed to create voice provider: {e}"
)
await websocket.send_json({"type": "error", "message": str(e)})
return
# Use native streaming if provider supports it
if provider.supports_streaming_tts():
logger.info("WebSocket synthesize: using native streaming TTS")
message = None # Initialize to avoid UnboundLocalError in except block
try:
# Wait for initial config message with voice/speed
message = await websocket.receive()
voice = None
speed = 1.0
if "text" in message:
try:
data = json.loads(message["text"])
voice = data.get("voice")
speed = float(data["speed"]) if isinstance(data.get("speed"), (int, float)) else 1.0
except json.JSONDecodeError:
pass
streaming_synthesizer = await provider.create_streaming_synthesizer(
voice=voice, speed=speed
)
logger.info(
"WebSocket synthesize: streaming synthesizer created successfully"
)
await handle_streaming_synthesis(websocket, streaming_synthesizer)
except Exception as e:
logger.error(
f"WebSocket synthesize: failed to create streaming synthesizer: {e}"
)
if VOICE_DISABLE_STREAMING_FALLBACK:
await websocket.send_json(
{"type": "error", "message": f"Streaming TTS failed: {e}"}
)
return
logger.info(
"WebSocket synthesize: falling back to chunked TTS synthesis"
)
# Pass the first message so it's not lost in the fallback
await handle_chunked_synthesis(
websocket, provider, first_message=message
)
else:
if VOICE_DISABLE_STREAMING_FALLBACK:
await websocket.send_json(
{
"type": "error",
"message": "Provider doesn't support streaming TTS",
}
)
return
logger.info(
"WebSocket synthesize: using chunked TTS (provider doesn't support streaming)"
)
await handle_chunked_synthesis(websocket, provider)
except WebSocketDisconnect:
logger.debug("WebSocket synthesize: client disconnected")
except Exception as e:
logger.error(f"WebSocket synthesize: unhandled error: {e}", exc_info=True)
try:
# Send generic error to avoid leaking sensitive details
await websocket.send_json(
{"type": "error", "message": "An unexpected error occurred"}
)
except Exception:
pass
finally:
if streaming_synthesizer:
try:
await streaming_synthesizer.close()
except Exception:
pass
try:
await websocket.close()
except Exception:
pass
logger.info("WebSocket synthesize: connection closed")
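
Review note: the buffering done by `handle_chunked_synthesis` can be read as a fold over the client's JSON frames. The sketch below is a minimal, standalone illustration of that logic, not the handler itself. The function name `collect_synthesis_request` and its list-of-strings input are assumptions made for the example, and it omits the handler's fallback that scans other string fields when `text` is empty.

```python
import json
from typing import Any, Optional


def collect_synthesis_request(raw_messages: list[str]) -> dict[str, Any]:
    """Fold client frames into one request, stopping at the first "end" frame.

    Hypothetical helper for illustration; it mirrors the buffering rules of
    the chunked handler but works on plain strings instead of a WebSocket.
    """
    text_buffer: list[str] = []
    voice: Optional[str] = None
    speed = 1.0
    for raw in raw_messages:
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            continue  # malformed frames are skipped
        if not isinstance(data, dict):
            continue
        if data.get("type") == "synthesize":
            text = data.get("text", "")
            if isinstance(text, str) and text:
                text_buffer.append(text)
            # The latest valid voice/speed wins, as in the handler.
            if isinstance(data.get("voice"), str) and data["voice"]:
                voice = data["voice"]
            if isinstance(data.get("speed"), (int, float)):
                speed = float(data["speed"])
        elif data.get("type") == "end":
            break  # frames after "end" are never read
    return {"text": " ".join(text_buffer).strip(), "voice": voice, "speed": speed}
```

An empty `text` in the result corresponds to the case where the handler sends `audio_done` without calling the provider.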

View File

@@ -1,6 +1,5 @@
import datetime
import json
import os
from collections.abc import Generator
from datetime import timedelta
from uuid import UUID
@@ -61,7 +60,6 @@ from onyx.db.persona import get_persona_by_id
from onyx.db.usage import increment_usage
from onyx.db.usage import UsageType
from onyx.db.user_file import get_file_id_by_user_file_id
from onyx.file_processing.extract_file_text import docx_to_txt_filename
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_default_llm
@@ -812,18 +810,6 @@ def fetch_chat_file(
if not file_record:
raise HTTPException(status_code=404, detail="File not found")
original_file_name = file_record.display_name
if file_record.file_type.startswith(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
# Check if a converted text file exists for .docx files
txt_file_name = docx_to_txt_filename(original_file_name)
txt_file_id = os.path.join(os.path.dirname(file_id), txt_file_name)
txt_file_record = file_store.read_file_record(txt_file_id)
if txt_file_record:
file_record = txt_file_record
file_id = txt_file_id
media_type = file_record.file_type
file_io = file_store.read_file(file_id, mode="b")

View File

@@ -60,9 +60,11 @@ class Settings(BaseModel):
deep_research_enabled: bool | None = None
search_ui_enabled: bool | None = None
# Enterprise features flag - set by license enforcement at runtime
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status
# When LICENSE_ENFORCEMENT_ENABLED=false, defaults to False
# Whether EE features are unlocked for use.
# Depends on license status: True when the user has a valid license
# (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER), False when there's no license
# or the license is expired (GATED_ACCESS).
# This controls UI visibility of EE features (user groups, analytics, RBAC, etc.).
ee_features_enabled: bool = False
temperature_override_enabled: bool | None = False
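
Review note: the new comment describes a mapping from license status to `ee_features_enabled`. A minimal sketch of that mapping follows. The enum name `LicenseStatus`, its string values, and the helper `compute_ee_features_enabled` are assumptions for illustration; only the four status names come from the comment.

```python
from enum import Enum
from typing import Optional


class LicenseStatus(str, Enum):
    """Hypothetical enum; member names are taken from the Settings comment."""

    ACTIVE = "active"
    GRACE_PERIOD = "grace_period"
    PAYMENT_REMINDER = "payment_reminder"
    GATED_ACCESS = "gated_access"


# Statuses that count as a valid license according to the comment.
_UNLOCKED_STATUSES = frozenset(
    {
        LicenseStatus.ACTIVE,
        LicenseStatus.GRACE_PERIOD,
        LicenseStatus.PAYMENT_REMINDER,
    }
)


def compute_ee_features_enabled(status: Optional[LicenseStatus]) -> bool:
    """Return True only for a valid license; None means no license at all."""
    return status in _UNLOCKED_STATUSES
```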

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import json
import time
from collections.abc import Generator
@@ -84,6 +86,19 @@ class CodeInterpreterClient:
raise ValueError("CODE_INTERPRETER_BASE_URL not configured")
self.base_url = base_url.rstrip("/")
self.session = requests.Session()
self._closed = False
def __enter__(self) -> CodeInterpreterClient:
return self
def __exit__(self, *args: object) -> None:
self.close()
def close(self) -> None:
if self._closed:
return
self.session.close()
self._closed = True
def _build_payload(
self,
@@ -177,8 +192,11 @@ class CodeInterpreterClient:
yield from self._batch_as_stream(code, stdin, timeout_ms, files)
return
response.raise_for_status()
yield from self._parse_sse(response)
try:
response.raise_for_status()
yield from self._parse_sse(response)
finally:
response.close()
def _parse_sse(
self, response: requests.Response
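
Review note: this hunk adds two cleanup guarantees: an idempotent `close()` reachable through `with`, and a `try/finally` around the streamed response. The sketch below reproduces both patterns with stand-in classes so they can be exercised without a network. `FakeSession`, `FakeResponse`, `ClosableClient`, and `stream_lines` are hypothetical names, not part of the PR.

```python
from collections.abc import Iterator


class FakeSession:
    """Stand-in for requests.Session that counts close() calls."""

    def __init__(self) -> None:
        self.close_calls = 0

    def close(self) -> None:
        self.close_calls += 1


class FakeResponse:
    """Stand-in for a streamed HTTP response."""

    def __init__(self, lines: list[str]) -> None:
        self.lines = lines
        self.closed = False

    def close(self) -> None:
        self.closed = True


class ClosableClient:
    """Same close/enter/exit shape as the client in the diff."""

    def __init__(self, session: FakeSession) -> None:
        self.session = session
        self._closed = False

    def __enter__(self) -> "ClosableClient":
        return self

    def __exit__(self, *args: object) -> None:
        self.close()

    def close(self) -> None:
        if self._closed:
            return  # second and later calls are no-ops
        self.session.close()
        self._closed = True


def stream_lines(response: FakeResponse) -> Iterator[str]:
    """Yield lines and close the response on normal exit, error, or early stop."""
    try:
        yield from response.lines
    finally:
        response.close()
```

Because the cleanup sits in a `finally` inside a generator, it also runs when the consumer stops early and the generator is closed.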

View File

@@ -111,8 +111,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
if not server.server_enabled:
return False
client = CodeInterpreterClient()
return client.health(use_cache=True)
with CodeInterpreterClient() as client:
return client.health(use_cache=True)
def tool_definition(self) -> dict:
return {
@@ -176,196 +176,203 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
)
)
# Create Code Interpreter client
client = CodeInterpreterClient()
# Create Code Interpreter client — context manager ensures
# session.close() is called on every exit path.
with CodeInterpreterClient() as client:
# Stage chat files for execution
files_to_stage: list[FileInput] = []
for ind, chat_file in enumerate(chat_files):
file_name = chat_file.filename or f"file_{ind}"
try:
# Upload to Code Interpreter
ci_file_id = client.upload_file(chat_file.content, file_name)
# Stage for execution
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
logger.info(f"Staged file for Python execution: {file_name}")
except Exception as e:
logger.warning(f"Failed to stage file {file_name}: {e}")
# Stage chat files for execution
files_to_stage: list[FileInput] = []
for ind, chat_file in enumerate(chat_files):
file_name = chat_file.filename or f"file_{ind}"
try:
# Upload to Code Interpreter
ci_file_id = client.upload_file(chat_file.content, file_name)
logger.debug(f"Executing code: {code}")
# Stage for execution
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
# Execute code with streaming (falls back to batch if unavailable)
stdout_parts: list[str] = []
stderr_parts: list[str] = []
result_event: StreamResultEvent | None = None
logger.info(f"Staged file for Python execution: {file_name}")
for event in client.execute_streaming(
code=code,
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
files=files_to_stage or None,
):
if isinstance(event, StreamOutputEvent):
if event.stream == "stdout":
stdout_parts.append(event.data)
else:
stderr_parts.append(event.data)
# Emit incremental delta to frontend
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout=(
event.data if event.stream == "stdout" else ""
),
stderr=(
event.data if event.stream == "stderr" else ""
),
),
)
)
elif isinstance(event, StreamResultEvent):
result_event = event
elif isinstance(event, StreamErrorEvent):
raise RuntimeError(f"Code interpreter error: {event.message}")
except Exception as e:
logger.warning(f"Failed to stage file {file_name}: {e}")
if result_event is None:
raise RuntimeError(
"Code interpreter stream ended without a result event"
)
try:
logger.debug(f"Executing code: {code}")
full_stdout = "".join(stdout_parts)
full_stderr = "".join(stderr_parts)
# Execute code with streaming (falls back to batch if unavailable)
stdout_parts: list[str] = []
stderr_parts: list[str] = []
result_event: StreamResultEvent | None = None
# Truncate output for LLM consumption
truncated_stdout = _truncate_output(
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
)
truncated_stderr = _truncate_output(
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
)
for event in client.execute_streaming(
code=code,
timeout_ms=CODE_INTERPRETER_DEFAULT_TIMEOUT_MS,
files=files_to_stage or None,
):
if isinstance(event, StreamOutputEvent):
if event.stream == "stdout":
stdout_parts.append(event.data)
else:
stderr_parts.append(event.data)
# Emit incremental delta to frontend
# Handle generated files
generated_files: list[PythonExecutionFile] = []
generated_file_ids: list[str] = []
file_ids_to_cleanup: list[str] = []
file_store = get_default_file_store()
for workspace_file in result_event.files:
if workspace_file.kind != "file" or not workspace_file.file_id:
continue
try:
# Download file from Code Interpreter
file_content = client.download_file(workspace_file.file_id)
# Determine MIME type from file extension
filename = workspace_file.path.split("/")[-1]
mime_type, _ = mimetypes.guess_type(filename)
# Default to binary if we can't determine the type
mime_type = mime_type or "application/octet-stream"
# Save to Onyx file store
onyx_file_id = file_store.save_file(
content=BytesIO(file_content),
display_name=filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=mime_type,
)
generated_files.append(
PythonExecutionFile(
filename=filename,
file_link=build_full_frontend_file_url(onyx_file_id),
)
)
generated_file_ids.append(onyx_file_id)
# Mark for cleanup
file_ids_to_cleanup.append(workspace_file.file_id)
except Exception as e:
logger.error(
f"Failed to handle generated file "
f"{workspace_file.path}: {e}"
)
# Cleanup Code Interpreter files (generated files)
for ci_file_id in file_ids_to_cleanup:
try:
client.delete_file(ci_file_id)
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter generated "
f"file {ci_file_id}: {e}"
)
# Cleanup staged input files
for file_mapping in files_to_stage:
try:
client.delete_file(file_mapping["file_id"])
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter staged "
f"file {file_mapping['file_id']}: {e}"
)
# Emit file_ids once files are processed
if generated_file_ids:
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout=event.data if event.stream == "stdout" else "",
stderr=event.data if event.stream == "stderr" else "",
),
obj=PythonToolDelta(file_ids=generated_file_ids),
)
)
elif isinstance(event, StreamResultEvent):
result_event = event
elif isinstance(event, StreamErrorEvent):
raise RuntimeError(f"Code interpreter error: {event.message}")
if result_event is None:
raise RuntimeError(
"Code interpreter stream ended without a result event"
# Build result
result = LlmPythonExecutionResult(
stdout=truncated_stdout,
stderr=truncated_stderr,
exit_code=result_event.exit_code,
timed_out=result_event.timed_out,
generated_files=generated_files,
error=(None if result_event.exit_code == 0 else truncated_stderr),
)
full_stdout = "".join(stdout_parts)
full_stderr = "".join(stderr_parts)
# Serialize result for LLM
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
# Truncate output for LLM consumption
truncated_stdout = _truncate_output(
full_stdout, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stdout"
)
truncated_stderr = _truncate_output(
full_stderr, CODE_INTERPRETER_MAX_OUTPUT_LENGTH, "stderr"
)
return ToolResponse(
rich_response=PythonToolRichResponse(
generated_files=generated_files,
),
llm_facing_response=llm_response,
)
# Handle generated files
generated_files: list[PythonExecutionFile] = []
generated_file_ids: list[str] = []
file_ids_to_cleanup: list[str] = []
file_store = get_default_file_store()
except Exception as e:
logger.error(f"Python execution failed: {e}")
error_msg = str(e)
for workspace_file in result_event.files:
if workspace_file.kind != "file" or not workspace_file.file_id:
continue
try:
# Download file from Code Interpreter
file_content = client.download_file(workspace_file.file_id)
# Determine MIME type from file extension
filename = workspace_file.path.split("/")[-1]
mime_type, _ = mimetypes.guess_type(filename)
# Default to binary if we can't determine the type
mime_type = mime_type or "application/octet-stream"
# Save to Onyx file store
onyx_file_id = file_store.save_file(
content=BytesIO(file_content),
display_name=filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=mime_type,
)
generated_files.append(
PythonExecutionFile(
filename=filename,
file_link=build_full_frontend_file_url(onyx_file_id),
)
)
generated_file_ids.append(onyx_file_id)
# Mark for cleanup
file_ids_to_cleanup.append(workspace_file.file_id)
except Exception as e:
logger.error(
f"Failed to handle generated file {workspace_file.path}: {e}"
)
# Cleanup Code Interpreter files (generated files)
for ci_file_id in file_ids_to_cleanup:
try:
client.delete_file(ci_file_id)
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter generated file {ci_file_id}: {e}"
)
# Cleanup staged input files
for file_mapping in files_to_stage:
try:
client.delete_file(file_mapping["file_id"])
except Exception as e:
logger.error(
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
)
# Emit file_ids once files are processed
if generated_file_ids:
# Emit error delta
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(file_ids=generated_file_ids),
obj=PythonToolDelta(
stdout="",
stderr=error_msg,
file_ids=[],
),
)
)
# Build result
result = LlmPythonExecutionResult(
stdout=truncated_stdout,
stderr=truncated_stderr,
exit_code=result_event.exit_code,
timed_out=result_event.timed_out,
generated_files=generated_files,
error=None if result_event.exit_code == 0 else truncated_stderr,
)
# Serialize result for LLM
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
return ToolResponse(
rich_response=PythonToolRichResponse(
generated_files=generated_files,
),
llm_facing_response=llm_response,
)
except Exception as e:
logger.error(f"Python execution failed: {e}")
error_msg = str(e)
# Emit error delta
self.emitter.emit(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout="",
stderr=error_msg,
file_ids=[],
),
# Return error result
result = LlmPythonExecutionResult(
stdout="",
stderr=error_msg,
exit_code=-1,
timed_out=False,
generated_files=[],
error=error_msg,
)
)
# Return error result
result = LlmPythonExecutionResult(
stdout="",
stderr=error_msg,
exit_code=-1,
timed_out=False,
generated_files=[],
error=error_msg,
)
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
adapter = TypeAdapter(LlmPythonExecutionResult)
llm_response = adapter.dump_json(result).decode()
return ToolResponse(
rich_response=None,
llm_facing_response=llm_response,
)
return ToolResponse(
rich_response=None,
llm_facing_response=llm_response,
)
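
Review note: the generated-file handling derives a MIME type from the workspace path and falls back to a binary type. A minimal standalone sketch of that step is below; the function name `guess_mime_type` is an assumption for the example, while the `mimetypes.guess_type` call and the fallback value match the diff.

```python
import mimetypes


def guess_mime_type(workspace_path: str) -> str:
    """Guess a MIME type from the last path segment, defaulting to binary."""
    filename = workspace_path.split("/")[-1]
    mime_type, _encoding = mimetypes.guess_type(filename)
    # guess_type returns None for unknown extensions.
    return mime_type or "application/octet-stream"
```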

View File

@@ -0,0 +1,70 @@
from onyx.db.models import VoiceProvider
from onyx.voice.interface import VoiceProviderInterface
def get_voice_provider(provider: VoiceProvider) -> VoiceProviderInterface:
"""
Factory function to get the appropriate voice provider implementation.
Args:
provider: VoiceProvider model instance (can be from DB or constructed temporarily)
Returns:
VoiceProviderInterface implementation
Raises:
ValueError: If provider_type is not supported
"""
provider_type = provider.provider_type.lower()
# Handle both SensitiveValue (from DB) and plain string (from temp model)
if provider.api_key is None:
api_key = None
elif hasattr(provider.api_key, "get_value"):
# SensitiveValue from database
api_key = provider.api_key.get_value(apply_mask=False)
else:
# Plain string from temporary model
api_key = provider.api_key # type: ignore[assignment]
api_base = provider.api_base
custom_config = provider.custom_config
stt_model = provider.stt_model
tts_model = provider.tts_model
default_voice = provider.default_voice
if provider_type == "openai":
from onyx.voice.providers.openai import OpenAIVoiceProvider
return OpenAIVoiceProvider(
api_key=api_key,
api_base=api_base,
stt_model=stt_model,
tts_model=tts_model,
default_voice=default_voice,
)
elif provider_type == "azure":
from onyx.voice.providers.azure import AzureVoiceProvider
return AzureVoiceProvider(
api_key=api_key,
api_base=api_base,
custom_config=custom_config or {},
stt_model=stt_model,
tts_model=tts_model,
default_voice=default_voice,
)
elif provider_type == "elevenlabs":
from onyx.voice.providers.elevenlabs import ElevenLabsVoiceProvider
return ElevenLabsVoiceProvider(
api_key=api_key,
api_base=api_base,
stt_model=stt_model,
tts_model=tts_model,
default_voice=default_voice,
)
else:
raise ValueError(f"Unsupported voice provider type: {provider_type}")
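
Review note: the factory accepts an API key that is either a wrapped secret from the database or a plain string from a temporary model. The sketch below isolates that branch. `MaskedSecret` is a hypothetical stand-in for the database wrapper and `resolve_api_key` is an illustrative name; only the `get_value(apply_mask=False)` call shape is taken from the diff.

```python
from typing import Optional, Union


class MaskedSecret:
    """Hypothetical stand-in for the wrapped secret stored in the database."""

    def __init__(self, value: str) -> None:
        self._value = value

    def get_value(self, apply_mask: bool = True) -> str:
        return "****" if apply_mask else self._value


def resolve_api_key(api_key: Union[MaskedSecret, str, None]) -> Optional[str]:
    """Return the raw key for either representation, or None if unset."""
    if api_key is None:
        return None
    if hasattr(api_key, "get_value"):
        # Wrapped secret: ask explicitly for the unmasked value.
        return api_key.get_value(apply_mask=False)
    return api_key  # already a plain string
```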

View File

@@ -0,0 +1,175 @@
from abc import ABC
from abc import abstractmethod
from collections.abc import AsyncIterator
from typing import Protocol
from pydantic import BaseModel
class TranscriptResult(BaseModel):
"""Result from streaming transcription."""
text: str
"""The accumulated transcript text."""
is_vad_end: bool = False
"""True if VAD detected end of speech (silence). Use for auto-send."""
class StreamingTranscriberProtocol(Protocol):
"""Protocol for streaming transcription sessions."""
async def send_audio(self, chunk: bytes) -> None:
"""Send an audio chunk for transcription."""
...
async def receive_transcript(self) -> TranscriptResult | None:
"""
Receive next transcript update.
Returns:
TranscriptResult with accumulated text and VAD status, or None when stream ends.
"""
...
async def close(self) -> str:
"""Close the session and return final transcript."""
...
def reset_transcript(self) -> None:
"""Reset accumulated transcript. Call after auto-send to start fresh."""
...
class StreamingSynthesizerProtocol(Protocol):
"""Protocol for streaming TTS sessions (real-time text-to-speech)."""
async def connect(self) -> None:
"""Establish connection to TTS provider."""
...
async def send_text(self, text: str) -> None:
"""Send text to be synthesized."""
...
async def receive_audio(self) -> bytes | None:
"""
Receive next audio chunk.
Returns:
Audio bytes, or None when stream ends.
"""
...
async def flush(self) -> None:
"""Signal end of text input and wait for pending audio."""
...
async def close(self) -> None:
"""Close the session."""
...
class VoiceProviderInterface(ABC):
"""Abstract base class for voice providers (STT and TTS)."""
@abstractmethod
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
"""
Convert audio to text (Speech-to-Text).
Args:
audio_data: Raw audio bytes
audio_format: Audio format (e.g., "webm", "wav", "mp3")
Returns:
Transcribed text
"""
@abstractmethod
def synthesize_stream(
self, text: str, voice: str | None = None, speed: float = 1.0
) -> AsyncIterator[bytes]:
"""
Convert text to audio stream (Text-to-Speech).
Streams audio chunks progressively for lower latency playback.
Args:
text: Text to convert to speech
voice: Voice identifier (e.g., "alloy", "echo"), or None for default
speed: Playback speed multiplier (0.25 to 4.0)
Yields:
Audio data chunks
"""
@abstractmethod
def get_available_voices(self) -> list[dict[str, str]]:
"""
Get list of available voices for this provider.
Returns:
List of voice dictionaries with 'id' and 'name' keys
"""
@abstractmethod
def get_available_stt_models(self) -> list[dict[str, str]]:
"""
Get list of available STT models for this provider.
Returns:
List of model dictionaries with 'id' and 'name' keys
"""
@abstractmethod
def get_available_tts_models(self) -> list[dict[str, str]]:
"""
Get list of available TTS models for this provider.
Returns:
List of model dictionaries with 'id' and 'name' keys
"""
def supports_streaming_stt(self) -> bool:
"""Returns True if this provider supports streaming STT."""
return False
def supports_streaming_tts(self) -> bool:
"""Returns True if this provider supports real-time streaming TTS."""
return False
async def create_streaming_transcriber(
self, audio_format: str = "webm"
) -> StreamingTranscriberProtocol:
"""
Create a streaming transcription session.
Args:
audio_format: Audio format being sent (e.g., "webm", "pcm16")
Returns:
A streaming transcriber that can send audio chunks and receive transcripts
Raises:
NotImplementedError: If streaming STT is not supported
"""
raise NotImplementedError("Streaming STT not supported by this provider")
async def create_streaming_synthesizer(
self, voice: str | None = None, speed: float = 1.0
) -> "StreamingSynthesizerProtocol":
"""
Create a streaming TTS session for real-time audio synthesis.
Args:
voice: Voice identifier
speed: Playback speed multiplier
Returns:
A streaming synthesizer that can send text and receive audio chunks
Raises:
NotImplementedError: If streaming TTS is not supported
"""
raise NotImplementedError("Streaming TTS not supported by this provider")
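
Review note: `StreamingSynthesizerProtocol` implies a consumer loop that reads until `receive_audio()` returns `None`. The sketch below shows one way such a loop could look, using a hypothetical in-memory `EchoSynthesizer` whose "audio" is just the UTF-8 bytes of the text. Skipping empty chunks is included because the Azure implementation in this PR returns `b""` when no audio is ready yet.

```python
import asyncio
from typing import Optional


class EchoSynthesizer:
    """Hypothetical synthesizer that satisfies the protocol structurally."""

    def __init__(self) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()

    async def connect(self) -> None:
        return None

    async def send_text(self, text: str) -> None:
        await self._queue.put(text.encode("utf-8"))

    async def receive_audio(self) -> Optional[bytes]:
        return await self._queue.get()

    async def flush(self) -> None:
        # None is the end-of-stream sentinel described by the protocol.
        await self._queue.put(None)

    async def close(self) -> None:
        return None


async def drain(synth: EchoSynthesizer) -> bytes:
    """Collect audio until the None sentinel, ignoring empty chunks."""
    chunks: list[bytes] = []
    while True:
        chunk = await synth.receive_audio()
        if chunk is None:
            break
        if chunk:
            chunks.append(chunk)
    return b"".join(chunks)


async def demo() -> bytes:
    synth = EchoSynthesizer()
    await synth.connect()
    await synth.send_text("hello ")
    await synth.send_text("world")
    await synth.flush()
    audio = await drain(synth)
    await synth.close()
    return audio
```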

View File

@@ -0,0 +1,493 @@
import asyncio
import io
import re
import struct
import wave
from collections.abc import AsyncIterator
from typing import Any
from xml.sax.saxutils import escape
from xml.sax.saxutils import quoteattr
import aiohttp
from onyx.voice.interface import StreamingSynthesizerProtocol
from onyx.voice.interface import StreamingTranscriberProtocol
from onyx.voice.interface import TranscriptResult
from onyx.voice.interface import VoiceProviderInterface
# Common Azure Neural voices
AZURE_VOICES = [
{"id": "en-US-JennyNeural", "name": "Jenny (en-US, Female)"},
{"id": "en-US-GuyNeural", "name": "Guy (en-US, Male)"},
{"id": "en-US-AriaNeural", "name": "Aria (en-US, Female)"},
{"id": "en-US-DavisNeural", "name": "Davis (en-US, Male)"},
{"id": "en-US-AmberNeural", "name": "Amber (en-US, Female)"},
{"id": "en-US-AnaNeural", "name": "Ana (en-US, Female)"},
{"id": "en-US-BrandonNeural", "name": "Brandon (en-US, Male)"},
{"id": "en-US-ChristopherNeural", "name": "Christopher (en-US, Male)"},
{"id": "en-US-CoraNeural", "name": "Cora (en-US, Female)"},
{"id": "en-GB-SoniaNeural", "name": "Sonia (en-GB, Female)"},
{"id": "en-GB-RyanNeural", "name": "Ryan (en-GB, Male)"},
]
class AzureStreamingTranscriber(StreamingTranscriberProtocol):
"""Streaming transcription using Azure Speech SDK."""
def __init__(
self,
api_key: str,
region: str,
input_sample_rate: int = 24000,
target_sample_rate: int = 16000,
):
self.api_key = api_key
self.region = region
self.input_sample_rate = input_sample_rate
self.target_sample_rate = target_sample_rate
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
self._accumulated_transcript = ""
self._recognizer: Any = None
self._audio_stream: Any = None
self._closed = False
self._loop: asyncio.AbstractEventLoop | None = None
async def connect(self) -> None:
"""Initialize Azure Speech recognizer with push stream."""
try:
import azure.cognitiveservices.speech as speechsdk # type: ignore[import-not-found]
except ImportError as e:
raise RuntimeError(
"Azure Speech SDK is required for streaming STT. "
"Install `azure-cognitiveservices-speech`."
) from e
self._loop = asyncio.get_running_loop()
speech_config = speechsdk.SpeechConfig(
subscription=self.api_key,
region=self.region,
)
audio_format = speechsdk.audio.AudioStreamFormat(
samples_per_second=self.target_sample_rate,
bits_per_sample=16,
channels=1,
)
self._audio_stream = speechsdk.audio.PushAudioInputStream(audio_format)
audio_config = speechsdk.audio.AudioConfig(stream=self._audio_stream)
self._recognizer = speechsdk.SpeechRecognizer(
speech_config=speech_config,
audio_config=audio_config,
)
transcriber = self
def on_recognizing(evt: Any) -> None:
if evt.result.text and transcriber._loop and not transcriber._closed:
full_text = transcriber._accumulated_transcript
if full_text:
full_text += " " + evt.result.text
else:
full_text = evt.result.text
transcriber._loop.call_soon_threadsafe(
transcriber._transcript_queue.put_nowait,
TranscriptResult(text=full_text, is_vad_end=False),
)
def on_recognized(evt: Any) -> None:
if evt.result.text and transcriber._loop and not transcriber._closed:
if transcriber._accumulated_transcript:
transcriber._accumulated_transcript += " " + evt.result.text
else:
transcriber._accumulated_transcript = evt.result.text
transcriber._loop.call_soon_threadsafe(
transcriber._transcript_queue.put_nowait,
TranscriptResult(
text=transcriber._accumulated_transcript, is_vad_end=True
),
)
self._recognizer.recognizing.connect(on_recognizing)
self._recognizer.recognized.connect(on_recognized)
self._recognizer.start_continuous_recognition_async()
async def send_audio(self, chunk: bytes) -> None:
"""Send audio chunk to Azure."""
if self._audio_stream and not self._closed:
self._audio_stream.write(self._resample_pcm16(chunk))
def _resample_pcm16(self, data: bytes) -> bytes:
"""Resample PCM16 audio from input_sample_rate to target_sample_rate."""
if self.input_sample_rate == self.target_sample_rate:
return data
num_samples = len(data) // 2
if num_samples == 0:
return b""
samples = list(struct.unpack(f"<{num_samples}h", data))
ratio = self.input_sample_rate / self.target_sample_rate
new_length = int(num_samples / ratio)
resampled: list[int] = []
for i in range(new_length):
src_idx = i * ratio
idx_floor = int(src_idx)
idx_ceil = min(idx_floor + 1, num_samples - 1)
frac = src_idx - idx_floor
sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
sample = max(-32768, min(32767, sample))
resampled.append(sample)
return struct.pack(f"<{len(resampled)}h", *resampled)
async def receive_transcript(self) -> TranscriptResult | None:
"""Receive next transcript; an empty result means none arrived within 0.1 s."""
try:
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return TranscriptResult(text="", is_vad_end=False)
async def close(self) -> str:
"""Stop recognition and return final transcript."""
self._closed = True
if self._recognizer:
self._recognizer.stop_continuous_recognition_async()
if self._audio_stream:
self._audio_stream.close()
self._loop = None
return self._accumulated_transcript
def reset_transcript(self) -> None:
"""Reset accumulated transcript."""
self._accumulated_transcript = ""
class AzureStreamingSynthesizer(StreamingSynthesizerProtocol):
"""Real-time streaming TTS using Azure Speech SDK."""
def __init__(
self,
api_key: str,
region: str,
voice: str = "en-US-JennyNeural",
speed: float = 1.0,
):
from onyx.utils.logger import setup_logger
self._logger = setup_logger()
self.api_key = api_key
self.region = region
self.voice = voice
self.speed = max(0.5, min(2.0, speed))
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
self._synthesizer: Any = None
self._closed = False
self._loop: asyncio.AbstractEventLoop | None = None
async def connect(self) -> None:
"""Initialize the Azure Speech synthesizer and subscribe to its audio events."""
try:
import azure.cognitiveservices.speech as speechsdk
except ImportError as e:
raise RuntimeError(
"Azure Speech SDK is required for streaming TTS. "
"Install `azure-cognitiveservices-speech`."
) from e
self._logger.info("AzureStreamingSynthesizer: connecting")
# Store the event loop for thread-safe queue operations
self._loop = asyncio.get_running_loop()
speech_config = speechsdk.SpeechConfig(
subscription=self.api_key,
region=self.region,
)
speech_config.speech_synthesis_voice_name = self.voice
# Use MP3 format for streaming - compatible with MediaSource Extensions
speech_config.set_speech_synthesis_output_format(
speechsdk.SpeechSynthesisOutputFormat.Audio16Khz64KBitRateMonoMp3
)
# Create synthesizer without an audio output config; audio is collected from events
self._synthesizer = speechsdk.SpeechSynthesizer(
speech_config=speech_config,
audio_config=None, # We'll manually handle audio
)
# Connect to synthesis events
self._synthesizer.synthesizing.connect(self._on_synthesizing)
self._synthesizer.synthesis_completed.connect(self._on_completed)
self._logger.info("AzureStreamingSynthesizer: connected")
def _on_synthesizing(self, evt: Any) -> None:
"""Called when audio chunk is available (runs in Azure SDK thread)."""
if evt.result.audio_data and self._loop and not self._closed:
# Thread-safe way to put item in async queue
self._loop.call_soon_threadsafe(
self._audio_queue.put_nowait, evt.result.audio_data
)
def _on_completed(self, _evt: Any) -> None:
"""Called when synthesis is complete (runs in Azure SDK thread)."""
if self._loop and not self._closed:
self._loop.call_soon_threadsafe(self._audio_queue.put_nowait, None)
async def send_text(self, text: str) -> None:
"""Send text to be synthesized using SSML for prosody control."""
if self._synthesizer and not self._closed:
# Build SSML with prosody for speed control
rate = f"{round((self.speed - 1) * 100):+d}%"
escaped_text = escape(text)
ssml = f"""<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>
<voice name={quoteattr(self.voice)}>
<prosody rate='{rate}'>{escaped_text}</prosody>
</voice>
</speak>"""
# Use speak_ssml_async for SSML support (includes speed/prosody)
self._synthesizer.speak_ssml_async(ssml)
async def receive_audio(self) -> bytes | None:
"""Receive next audio chunk."""
try:
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return b"" # No audio yet, but not done
async def flush(self) -> None:
"""Signal end of text input - wait for pending audio."""
# Azure SDK handles flushing automatically
async def close(self) -> None:
"""Close the session."""
self._closed = True
if self._synthesizer:
self._synthesizer.synthesis_completed.disconnect_all()
self._synthesizer.synthesizing.disconnect_all()
self._loop = None
class AzureVoiceProvider(VoiceProviderInterface):
"""Azure Speech Services voice provider."""
def __init__(
self,
api_key: str | None,
api_base: str | None,
custom_config: dict[str, Any],
stt_model: str | None = None,
tts_model: str | None = None,
default_voice: str | None = None,
):
self.api_key = api_key
self.api_base = api_base
self.custom_config = custom_config
self.speech_region = (
custom_config.get("speech_region")
or self._extract_speech_region_from_uri(api_base)
or ""
)
self.stt_model = stt_model
self.tts_model = tts_model
self.default_voice = default_voice or "en-US-JennyNeural"
@staticmethod
def _extract_speech_region_from_uri(uri: str | None) -> str | None:
"""Extract Azure speech region from endpoint URI.
Note: Custom domains (*.cognitiveservices.azure.com) contain the resource
name, not the region. For custom domains, the region must be specified
explicitly via custom_config["speech_region"].
"""
if not uri:
return None
# Accepted examples:
# - https://eastus.tts.speech.microsoft.com/cognitiveservices/v1
# - https://eastus.stt.speech.microsoft.com/speech/recognition/...
# - https://westus.api.cognitive.microsoft.com/
#
# NOT supported (requires explicit speech_region config):
# - https://<resource>.cognitiveservices.azure.com/ (resource name != region)
patterns = [
r"https?://([^.]+)\.(?:tts|stt)\.speech\.microsoft\.com",
r"https?://([^.]+)\.api\.cognitive\.microsoft\.com",
]
for pattern in patterns:
match = re.search(pattern, uri)
if match:
return match.group(1)
return None
@staticmethod
def _pcm16_to_wav(pcm_data: bytes, sample_rate: int = 24000) -> bytes:
"""Wrap raw PCM16 mono bytes into a WAV container."""
buffer = io.BytesIO()
with wave.open(buffer, "wb") as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(sample_rate)
wav_file.writeframes(pcm_data)
return buffer.getvalue()
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
if not self.api_key:
raise ValueError("Azure API key required for STT")
if not self.speech_region:
raise ValueError("Azure speech region required for STT")
normalized_format = audio_format.lower()
payload = audio_data
content_type = f"audio/{normalized_format}"
# WebSocket chunked fallback sends raw PCM16 bytes.
if normalized_format in {"pcm", "pcm16", "raw"}:
payload = self._pcm16_to_wav(audio_data, sample_rate=24000)
content_type = "audio/wav"
elif normalized_format in {"wav", "wave"}:
content_type = "audio/wav"
elif normalized_format == "webm":
content_type = "audio/webm; codecs=opus"
url = (
f"https://{self.speech_region}.stt.speech.microsoft.com/"
"speech/recognition/conversation/cognitiveservices/v1"
)
params = {"language": "en-US", "format": "detailed"}
headers = {
"Ocp-Apim-Subscription-Key": self.api_key,
"Content-Type": content_type,
"Accept": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.post(
url, params=params, headers=headers, data=payload
) as response:
if response.status != 200:
error_text = await response.text()
raise RuntimeError(f"Azure STT failed: {error_text}")
result = await response.json()
if result.get("RecognitionStatus") != "Success":
return ""
nbest = result.get("NBest") or []
if nbest and isinstance(nbest, list):
display = nbest[0].get("Display")
if isinstance(display, str):
return display
display_text = result.get("DisplayText", "")
return display_text if isinstance(display_text, str) else ""
async def synthesize_stream(
self, text: str, voice: str | None = None, speed: float = 1.0
) -> AsyncIterator[bytes]:
"""
Convert text to audio using Azure TTS with streaming.
Args:
text: Text to convert to speech
voice: Voice name (defaults to provider's default voice)
speed: Playback speed multiplier (0.5 to 2.0)
Yields:
Audio data chunks (mp3 format)
"""
if not self.api_key:
raise ValueError("Azure API key required for TTS")
if not self.speech_region:
raise ValueError("Azure speech region required for TTS")
voice_name = voice or self.default_voice
# Clamp speed to valid range and convert to rate format
speed = max(0.5, min(2.0, speed))
rate = f"{round((speed - 1) * 100):+d}%" # e.g., 1.0 -> "+0%", 1.5 -> "+50%"
# Build SSML with escaped text and quoted attributes to prevent injection
escaped_text = escape(text)
ssml = f"""<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>
<voice name={quoteattr(voice_name)}>
<prosody rate='{rate}'>{escaped_text}</prosody>
</voice>
</speak>"""
url = f"https://{self.speech_region}.tts.speech.microsoft.com/cognitiveservices/v1"
headers = {
"Ocp-Apim-Subscription-Key": self.api_key,
"Content-Type": "application/ssml+xml",
"X-Microsoft-OutputFormat": "audio-16khz-128kbitrate-mono-mp3",
"User-Agent": "Onyx",
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, data=ssml) as response:
if response.status != 200:
error_text = await response.text()
raise RuntimeError(f"Azure TTS failed: {error_text}")
# Use 8192 byte chunks for smoother streaming
async for chunk in response.content.iter_chunked(8192):
if chunk:
yield chunk
def get_available_voices(self) -> list[dict[str, str]]:
"""Return common Azure Neural voices."""
return AZURE_VOICES.copy()
def get_available_stt_models(self) -> list[dict[str, str]]:
return [
{"id": "default", "name": "Azure Speech Recognition"},
]
def get_available_tts_models(self) -> list[dict[str, str]]:
return [
{"id": "neural", "name": "Neural TTS"},
]
def supports_streaming_stt(self) -> bool:
"""Azure supports streaming STT via Speech SDK."""
return True
def supports_streaming_tts(self) -> bool:
"""Azure supports real-time streaming TTS via Speech SDK."""
return True
async def create_streaming_transcriber(
self, _audio_format: str = "webm"
) -> AzureStreamingTranscriber:
"""Create a streaming transcription session."""
if not self.api_key:
raise ValueError("API key required for streaming transcription")
if not self.speech_region:
raise ValueError("Speech region required for Azure streaming transcription")
transcriber = AzureStreamingTranscriber(
api_key=self.api_key,
region=self.speech_region,
input_sample_rate=24000,
target_sample_rate=16000,
)
await transcriber.connect()
return transcriber
async def create_streaming_synthesizer(
self, voice: str | None = None, speed: float = 1.0
) -> AzureStreamingSynthesizer:
"""Create a streaming TTS session."""
if not self.api_key:
raise ValueError("API key required for streaming TTS")
if not self.speech_region:
raise ValueError("Speech region required for Azure streaming TTS")
synthesizer = AzureStreamingSynthesizer(
api_key=self.api_key,
region=self.speech_region,
voice=voice or self.default_voice or "en-US-JennyNeural",
speed=speed,
)
await synthesizer.connect()
return synthesizer
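
Both Azure TTS paths above build the same SSML: the speed multiplier is clamped to 0.5–2.0, converted to a signed percentage for `<prosody rate=...>`, and the text and voice name are escaped before interpolation. The standalone sketch below isolates that step. The helper names `speed_to_rate` and `build_ssml` are hypothetical and not part of the diff, and it assumes `escape` and `quoteattr` come from `xml.sax.saxutils`, which this excerpt does not show being imported.

```python
from xml.sax.saxutils import escape, quoteattr


def speed_to_rate(speed: float) -> str:
    """Convert a speed multiplier to an SSML prosody rate such as '+50%'."""
    clamped = max(0.5, min(2.0, speed))
    # round() avoids truncating values like 14.999999999999991 down to 14
    return f"{round((clamped - 1) * 100):+d}%"


def build_ssml(text: str, voice: str, speed: float = 1.0) -> str:
    """Build an SSML document with escaped text and a quoted voice attribute."""
    rate = speed_to_rate(speed)
    return (
        "<speak version='1.0' "
        "xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>"
        f"<voice name={quoteattr(voice)}>"
        f"<prosody rate='{rate}'>{escape(text)}</prosody>"
        "</voice></speak>"
    )


if __name__ == "__main__":
    print(build_ssml("Tom & Jerry <live>", "en-US-JennyNeural", speed=1.5))
```

Keeping the conversion in one helper would also keep the WebSocket and HTTP paths from drifting apart, since both currently duplicate the formatting inline.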

@@ -0,0 +1,766 @@
import asyncio
import base64
import json
from collections.abc import AsyncIterator
from typing import Any
import aiohttp
from onyx.voice.interface import StreamingSynthesizerProtocol
from onyx.voice.interface import StreamingTranscriberProtocol
from onyx.voice.interface import TranscriptResult
from onyx.voice.interface import VoiceProviderInterface
# Common ElevenLabs voices
ELEVENLABS_VOICES = [
{"id": "21m00Tcm4TlvDq8ikWAM", "name": "Rachel"},
{"id": "AZnzlk1XvdvUeBnXmlld", "name": "Domi"},
{"id": "EXAVITQu4vr4xnSDxMaL", "name": "Bella"},
{"id": "ErXwobaYiN019PkySvjV", "name": "Antoni"},
{"id": "MF3mGyEYCl7XYWbV9V6O", "name": "Elli"},
{"id": "TxGEqnHWrfWFTfGW9XjX", "name": "Josh"},
{"id": "VR6AewLTigWG4xSOukaG", "name": "Arnold"},
{"id": "pNInz6obpgDQGcFmaJgB", "name": "Adam"},
{"id": "yoZ06aMxZJJ28mfd3POQ", "name": "Sam"},
]
class ElevenLabsStreamingTranscriber(StreamingTranscriberProtocol):
"""Streaming transcription session using ElevenLabs Scribe Realtime API."""
def __init__(
self,
api_key: str,
model: str = "scribe_v2_realtime",
input_sample_rate: int = 24000, # What frontend sends
target_sample_rate: int = 16000, # What ElevenLabs expects
language_code: str = "en",
):
# Import logger first
from onyx.utils.logger import setup_logger
self._logger = setup_logger()
self._logger.info(
f"ElevenLabsStreamingTranscriber: initializing with model {model}"
)
self.api_key = api_key
self.model = model
self.input_sample_rate = input_sample_rate
self.target_sample_rate = target_sample_rate
self.language_code = language_code
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._session: aiohttp.ClientSession | None = None
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
self._final_transcript = ""
self._receive_task: asyncio.Task | None = None
self._closed = False
async def connect(self) -> None:
"""Establish WebSocket connection to ElevenLabs."""
self._logger.info(
"ElevenLabsStreamingTranscriber: connecting to ElevenLabs API"
)
self._session = aiohttp.ClientSession()
# VAD is configured via query parameters
# commit_strategy=vad enables automatic transcript commit on silence detection
url = (
f"wss://api.elevenlabs.io/v1/speech-to-text/realtime"
f"?model_id={self.model}"
f"&sample_rate={self.target_sample_rate}"
f"&language_code={self.language_code}"
f"&commit_strategy=vad"
f"&vad_silence_threshold_secs=1.0"
f"&vad_threshold=0.4"
f"&min_speech_duration_ms=100"
f"&min_silence_duration_ms=300"
)
self._logger.info(
f"ElevenLabsStreamingTranscriber: connecting to {url} "
f"(input={self.input_sample_rate}Hz, target={self.target_sample_rate}Hz)"
)
try:
self._ws = await self._session.ws_connect(
url,
headers={"xi-api-key": self.api_key},
)
self._logger.info(
f"ElevenLabsStreamingTranscriber: connected successfully, "
f"ws.closed={self._ws.closed}, close_code={self._ws.close_code}"
)
except Exception as e:
self._logger.error(
f"ElevenLabsStreamingTranscriber: failed to connect: {e}"
)
if self._session:
await self._session.close()
raise
# Start receiving transcripts in background
self._receive_task = asyncio.create_task(self._receive_loop())
async def _receive_loop(self) -> None:
"""Background task to receive transcripts from WebSocket."""
self._logger.info("ElevenLabsStreamingTranscriber: receive loop started")
if not self._ws:
self._logger.warning(
"ElevenLabsStreamingTranscriber: no WebSocket connection"
)
return
try:
async for msg in self._ws:
self._logger.debug(
f"ElevenLabsStreamingTranscriber: raw message type: {msg.type}"
)
if msg.type == aiohttp.WSMsgType.TEXT:
parsed_data: Any = None
data: dict[str, Any]
try:
parsed_data = json.loads(msg.data)
except json.JSONDecodeError:
self._logger.error(
f"ElevenLabsStreamingTranscriber: failed to parse JSON: {msg.data[:200]}"
)
continue
if not isinstance(parsed_data, dict):
self._logger.error(
"ElevenLabsStreamingTranscriber: expected object JSON payload"
)
continue
data = parsed_data
# ElevenLabs uses the message_type field (falling back to type); skip packets that carry neither
if "message_type" not in data and "type" not in data:
self._logger.error(
f"ElevenLabsStreamingTranscriber: malformed packet missing 'message_type' field: {data}"
)
continue
msg_type = data.get("message_type", data.get("type", ""))
self._logger.info(
f"ElevenLabsStreamingTranscriber: received message_type: '{msg_type}', data keys: {list(data.keys())}"
)
# Check for error in various formats
if "error" in data or msg_type == "error":
error_msg = data.get("error", data.get("message", data))
self._logger.error(
f"ElevenLabsStreamingTranscriber: API error: {error_msg}"
)
continue
# Handle different message types from ElevenLabs Scribe API
if msg_type == "session_started":
# Session started successfully
self._logger.info(
f"ElevenLabsStreamingTranscriber: session started, "
f"id={data.get('session_id')}, config={data.get('config')}"
)
elif msg_type == "partial_transcript":
# Partial transcript (interim result)
text = data.get("text", "")
if text:
self._logger.info(
f"ElevenLabsStreamingTranscriber: partial_transcript: {text[:50]}..."
)
self._final_transcript = text
await self._transcript_queue.put(
TranscriptResult(text=text, is_vad_end=False)
)
elif msg_type == "committed_transcript":
# Final/committed transcript (VAD detected end of utterance)
text = data.get("text", "")
if text:
self._logger.info(
f"ElevenLabsStreamingTranscriber: committed_transcript: {text[:50]}..."
)
self._final_transcript = text
await self._transcript_queue.put(
TranscriptResult(text=text, is_vad_end=True)
)
elif msg_type == "utterance_end":
# VAD detected end of speech
text = data.get("text", "") or self._final_transcript
if text:
self._logger.info(
f"ElevenLabsStreamingTranscriber: utterance_end: {text[:50]}..."
)
self._final_transcript = text
await self._transcript_queue.put(
TranscriptResult(text=text, is_vad_end=True)
)
elif msg_type == "session_ended":
self._logger.info(
"ElevenLabsStreamingTranscriber: session ended"
)
break
else:
# Log unhandled message types with full data for debugging
self._logger.warning(
f"ElevenLabsStreamingTranscriber: unhandled message_type: {msg_type}, full data: {data}"
)
elif msg.type == aiohttp.WSMsgType.BINARY:
self._logger.debug(
f"ElevenLabsStreamingTranscriber: received binary message: {len(msg.data)} bytes"
)
elif msg.type == aiohttp.WSMsgType.CLOSED:
close_code = self._ws.close_code if self._ws else "N/A"
self._logger.info(
"ElevenLabsStreamingTranscriber: WebSocket closed by "
f"server, close_code={close_code}"
)
break
elif msg.type == aiohttp.WSMsgType.ERROR:
self._logger.error(
f"ElevenLabsStreamingTranscriber: WebSocket error: {self._ws.exception() if self._ws else 'N/A'}"
)
break
elif msg.type == aiohttp.WSMsgType.CLOSE:
self._logger.info(
f"ElevenLabsStreamingTranscriber: WebSocket CLOSE frame received, data={msg.data}, extra={msg.extra}"
)
break
except Exception as e:
self._logger.error(
f"ElevenLabsStreamingTranscriber: error in receive loop: {e}",
exc_info=True,
)
finally:
close_code = self._ws.close_code if self._ws else "N/A"
self._logger.info(
f"ElevenLabsStreamingTranscriber: receive loop ended, close_code={close_code}"
)
await self._transcript_queue.put(None) # Signal end
def _resample_pcm16(self, data: bytes) -> bytes:
"""Resample PCM16 audio from input_sample_rate to target_sample_rate."""
import struct
if self.input_sample_rate == self.target_sample_rate:
return data
# Parse int16 samples, dropping a trailing odd byte that struct.unpack would reject
num_samples = len(data) // 2
samples = list(struct.unpack(f"<{num_samples}h", data[: num_samples * 2]))
# Calculate resampling ratio
ratio = self.input_sample_rate / self.target_sample_rate
new_length = int(num_samples / ratio)
# Linear interpolation resampling
resampled = []
for i in range(new_length):
src_idx = i * ratio
idx_floor = int(src_idx)
idx_ceil = min(idx_floor + 1, num_samples - 1)
frac = src_idx - idx_floor
sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
# Clamp to int16 range
sample = max(-32768, min(32767, sample))
resampled.append(sample)
return struct.pack(f"<{len(resampled)}h", *resampled)
async def send_audio(self, chunk: bytes) -> None:
"""Send an audio chunk for transcription."""
if not self._ws:
self._logger.warning("send_audio: no WebSocket connection")
return
if self._closed:
self._logger.warning("send_audio: transcriber is closed")
return
if self._ws.closed:
self._logger.warning(
f"send_audio: WebSocket is closed, close_code={self._ws.close_code}"
)
return
try:
# Resample from input rate (24kHz) to target rate (16kHz)
resampled = self._resample_pcm16(chunk)
# ElevenLabs expects input_audio_chunk message format with audio_base_64
audio_b64 = base64.b64encode(resampled).decode("utf-8")
message = {
"message_type": "input_audio_chunk",
"audio_base_64": audio_b64,
"sample_rate": self.target_sample_rate,
}
self._logger.debug(
f"send_audio: {len(chunk)} bytes -> {len(resampled)} bytes (resampled) -> {len(audio_b64)} chars base64"
)
await self._ws.send_str(json.dumps(message))
self._logger.debug("send_audio: message sent successfully")
except Exception as e:
self._logger.error(f"send_audio: failed to send: {e}", exc_info=True)
raise
async def receive_transcript(self) -> TranscriptResult | None:
"""Receive next transcript. Returns None when done."""
try:
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return TranscriptResult(
text="", is_vad_end=False
) # No transcript yet, but not done
async def close(self) -> str:
"""Close the session and return final transcript."""
self._logger.info("ElevenLabsStreamingTranscriber: closing session")
self._closed = True
if self._ws and not self._ws.closed:
try:
# Just close the WebSocket - ElevenLabs Scribe doesn't need a special end message
self._logger.info(
"ElevenLabsStreamingTranscriber: closing WebSocket connection"
)
await self._ws.close()
except Exception as e:
self._logger.debug(f"Error closing WebSocket: {e}")
if self._receive_task and not self._receive_task.done():
self._receive_task.cancel()
try:
await self._receive_task
except asyncio.CancelledError:
pass
if self._session and not self._session.closed:
await self._session.close()
return self._final_transcript
def reset_transcript(self) -> None:
"""Reset accumulated transcript. Call after auto-send to start fresh."""
self._final_transcript = ""
class ElevenLabsStreamingSynthesizer(StreamingSynthesizerProtocol):
"""Real-time streaming TTS using ElevenLabs WebSocket API.
Uses ElevenLabs' stream-input WebSocket which processes text as one
continuous stream and returns audio in order.
"""
def __init__(
self,
api_key: str,
voice_id: str,
model_id: str = "eleven_multilingual_v2",
output_format: str = "mp3_44100_64",
):
from onyx.utils.logger import setup_logger
self._logger = setup_logger()
self.api_key = api_key
self.voice_id = voice_id
self.model_id = model_id
self.output_format = output_format
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._session: aiohttp.ClientSession | None = None
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
self._receive_task: asyncio.Task | None = None
self._closed = False
async def connect(self) -> None:
"""Establish WebSocket connection to ElevenLabs TTS."""
self._logger.info("ElevenLabsStreamingSynthesizer: connecting")
self._session = aiohttp.ClientSession()
# WebSocket URL for streaming input TTS with output format for streaming compatibility
# Using mp3_44100_64 for good quality with smaller chunks for real-time playback
url = (
f"wss://api.elevenlabs.io/v1/text-to-speech/{self.voice_id}/stream-input"
f"?model_id={self.model_id}&output_format={self.output_format}"
)
self._ws = await self._session.ws_connect(
url,
headers={"xi-api-key": self.api_key},
)
# Send initial configuration with generation settings optimized for streaming
await self._ws.send_str(
json.dumps(
{
"text": " ", # Initial space to start the stream
"voice_settings": {
"stability": 0.5,
"similarity_boost": 0.75,
},
"generation_config": {
"chunk_length_schedule": [
120,
160,
250,
290,
], # Optimized chunk sizes for streaming
},
"xi_api_key": self.api_key,
}
)
)
# Start receiving audio in background
self._receive_task = asyncio.create_task(self._receive_loop())
self._logger.info("ElevenLabsStreamingSynthesizer: connected")
async def _receive_loop(self) -> None:
"""Background task to receive audio chunks from WebSocket.
Audio is returned in order as one continuous stream.
"""
if not self._ws:
return
chunk_count = 0
total_bytes = 0
try:
async for msg in self._ws:
if self._closed:
self._logger.info(
"ElevenLabsStreamingSynthesizer: closed flag set, stopping "
"receive loop"
)
break
if msg.type == aiohttp.WSMsgType.TEXT:
data = json.loads(msg.data)
# Process audio if present
if "audio" in data and data["audio"]:
audio_bytes = base64.b64decode(data["audio"])
chunk_count += 1
total_bytes += len(audio_bytes)
await self._audio_queue.put(audio_bytes)
# Check isFinal separately - a message can have both audio AND isFinal
if "isFinal" in data:
self._logger.info(
f"ElevenLabsStreamingSynthesizer: received isFinal={data['isFinal']}, "
f"chunks so far: {chunk_count}, bytes: {total_bytes}"
)
if data.get("isFinal"):
self._logger.info(
"ElevenLabsStreamingSynthesizer: isFinal=true, signaling end of audio"
)
await self._audio_queue.put(None)
# Check for errors
if "error" in data or data.get("type") == "error":
self._logger.error(
f"ElevenLabsStreamingSynthesizer: received error: {data}"
)
elif msg.type == aiohttp.WSMsgType.BINARY:
chunk_count += 1
total_bytes += len(msg.data)
await self._audio_queue.put(msg.data)
elif msg.type in (
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.ERROR,
):
self._logger.info(
f"ElevenLabsStreamingSynthesizer: WebSocket closed/error, type={msg.type}"
)
break
except Exception as e:
self._logger.error(f"ElevenLabsStreamingSynthesizer receive error: {e}")
finally:
self._logger.info(
f"ElevenLabsStreamingSynthesizer: receive loop ended, {chunk_count} chunks, {total_bytes} bytes"
)
await self._audio_queue.put(None) # Signal end of stream
async def send_text(self, text: str) -> None:
"""Send text to be synthesized.
ElevenLabs processes text as a continuous stream and returns
audio in order. We let ElevenLabs handle buffering via chunk_length_schedule
and only force generation when flush() is called at the end.
Args:
text: Text to synthesize
"""
if self._ws and not self._closed and text.strip():
self._logger.info(
f"ElevenLabsStreamingSynthesizer: sending text ({len(text)} chars): '{text[:50]}...'"
)
# Let ElevenLabs buffer and auto-generate based on chunk_length_schedule
# Don't trigger generation here - wait for flush() at the end
await self._ws.send_str(
json.dumps(
{
"text": text + " ", # Space for natural speech flow
}
)
)
self._logger.info("ElevenLabsStreamingSynthesizer: text sent successfully")
else:
self._logger.warning(
f"ElevenLabsStreamingSynthesizer: skipping send_text - "
f"ws={self._ws is not None}, closed={self._closed}, text='{text[:30] if text else ''}'"
)
async def receive_audio(self) -> bytes | None:
"""Receive next audio chunk."""
try:
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return b"" # No audio yet, but not done
async def flush(self) -> None:
"""Signal end of text input. ElevenLabs will generate remaining audio and close."""
if self._ws and not self._closed:
# Send empty string to signal end of input
# ElevenLabs will generate any remaining buffered text,
# send all audio chunks, send isFinal, then close the connection
self._logger.info(
"ElevenLabsStreamingSynthesizer: sending end-of-input (empty string)"
)
await self._ws.send_str(json.dumps({"text": ""}))
self._logger.info("ElevenLabsStreamingSynthesizer: end-of-input sent")
else:
self._logger.warning(
f"ElevenLabsStreamingSynthesizer: skipping flush - "
f"ws={self._ws is not None}, closed={self._closed}"
)
async def close(self) -> None:
"""Close the session."""
self._closed = True
if self._ws:
await self._ws.close()
if self._receive_task:
self._receive_task.cancel()
try:
await self._receive_task
except asyncio.CancelledError:
pass
if self._session:
await self._session.close()
# Valid ElevenLabs model IDs
ELEVENLABS_STT_MODELS = {"scribe_v1", "scribe_v2_realtime"}
ELEVENLABS_TTS_MODELS = {
"eleven_multilingual_v2",
"eleven_turbo_v2_5",
"eleven_monolingual_v1",
"eleven_flash_v2_5",
"eleven_flash_v2",
}
class ElevenLabsVoiceProvider(VoiceProviderInterface):
"""ElevenLabs voice provider."""
def __init__(
self,
api_key: str | None,
api_base: str | None = None,
stt_model: str | None = None,
tts_model: str | None = None,
default_voice: str | None = None,
):
self.api_key = api_key
self.api_base = api_base or "https://api.elevenlabs.io"
# Validate and default models - use valid ElevenLabs model IDs
self.stt_model = (
stt_model if stt_model in ELEVENLABS_STT_MODELS else "scribe_v1"
)
self.tts_model = (
tts_model
if tts_model in ELEVENLABS_TTS_MODELS
else "eleven_multilingual_v2"
)
self.default_voice = default_voice
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
"""
Transcribe audio using ElevenLabs Speech-to-Text API.
Args:
audio_data: Raw audio bytes
audio_format: Format of the audio (e.g., 'webm', 'mp3', 'wav')
Returns:
Transcribed text
"""
if not self.api_key:
raise ValueError("ElevenLabs API key required for transcription")
from onyx.utils.logger import setup_logger
logger = setup_logger()
url = f"{self.api_base}/v1/speech-to-text"
# Map common formats to MIME types
mime_types = {
"webm": "audio/webm",
"mp3": "audio/mpeg",
"wav": "audio/wav",
"ogg": "audio/ogg",
"flac": "audio/flac",
"m4a": "audio/mp4",
}
mime_type = mime_types.get(audio_format.lower(), f"audio/{audio_format}")
headers = {
"xi-api-key": self.api_key,
}
# ElevenLabs expects multipart form data with the audio under the "file" field
form_data = aiohttp.FormData()
form_data.add_field(
"file",
audio_data,
filename=f"audio.{audio_format}",
content_type=mime_type,
)
# Batch STT always uses scribe_v1 (not the realtime model)
batch_model = "scribe_v1"
form_data.add_field("model_id", batch_model)
logger.info(
f"ElevenLabs transcribe: sending {len(audio_data)} bytes, format={audio_format}"
)
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, data=form_data) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"ElevenLabs transcribe failed: {error_text}")
raise RuntimeError(f"ElevenLabs transcription failed: {error_text}")
result = await response.json()
text = result.get("text", "")
logger.info(f"ElevenLabs transcribe: got result: {text[:50]}...")
return text
async def synthesize_stream(
self, text: str, voice: str | None = None, _speed: float = 1.0
) -> AsyncIterator[bytes]:
"""
Convert text to audio using ElevenLabs TTS with streaming.
Args:
text: Text to convert to speech
voice: Voice ID (defaults to provider's default voice or Rachel)
_speed: Playback speed multiplier (not directly supported, ignored)
Yields:
Audio data chunks (mp3 format)
"""
from onyx.utils.logger import setup_logger
logger = setup_logger()
if not self.api_key:
raise ValueError("ElevenLabs API key required for TTS")
voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM" # Rachel
url = f"{self.api_base}/v1/text-to-speech/{voice_id}/stream"
logger.info(
f"ElevenLabs TTS: starting synthesis, text='{text[:50]}...', "
f"voice={voice_id}, model={self.tts_model}"
)
headers = {
"xi-api-key": self.api_key,
"Content-Type": "application/json",
"Accept": "audio/mpeg",
}
payload = {
"text": text,
"model_id": self.tts_model,
"voice_settings": {
"stability": 0.5,
"similarity_boost": 0.75,
},
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers, json=payload) as response:
logger.info(
f"ElevenLabs TTS: got response status={response.status}, "
f"content-type={response.headers.get('content-type')}"
)
if response.status != 200:
error_text = await response.text()
logger.error(f"ElevenLabs TTS failed: {error_text}")
raise RuntimeError(f"ElevenLabs TTS failed: {error_text}")
# Use 8192 byte chunks for smoother streaming
chunk_count = 0
total_bytes = 0
async for chunk in response.content.iter_chunked(8192):
if chunk:
chunk_count += 1
total_bytes += len(chunk)
yield chunk
logger.info(
f"ElevenLabs TTS: streaming complete, {chunk_count} chunks, "
f"{total_bytes} total bytes"
)
def get_available_voices(self) -> list[dict[str, str]]:
"""Return common ElevenLabs voices."""
return ELEVENLABS_VOICES.copy()
def get_available_stt_models(self) -> list[dict[str, str]]:
return [
{"id": "scribe_v2_realtime", "name": "Scribe v2 Realtime (Streaming)"},
{"id": "scribe_v1", "name": "Scribe v1 (Batch)"},
]
def get_available_tts_models(self) -> list[dict[str, str]]:
return [
{"id": "eleven_multilingual_v2", "name": "Multilingual v2"},
{"id": "eleven_turbo_v2_5", "name": "Turbo v2.5"},
{"id": "eleven_monolingual_v1", "name": "Monolingual v1"},
]
def supports_streaming_stt(self) -> bool:
"""ElevenLabs supports streaming via Scribe Realtime API."""
return True
def supports_streaming_tts(self) -> bool:
"""ElevenLabs supports real-time streaming TTS via WebSocket."""
return True
async def create_streaming_transcriber(
self, _audio_format: str = "webm"
) -> ElevenLabsStreamingTranscriber:
"""Create a streaming transcription session."""
if not self.api_key:
raise ValueError("API key required for streaming transcription")
# ElevenLabs realtime STT requires scribe_v2_realtime model
# Frontend sends PCM16 at 24kHz, but ElevenLabs expects 16kHz
# The transcriber will resample automatically
transcriber = ElevenLabsStreamingTranscriber(
api_key=self.api_key,
model="scribe_v2_realtime",
input_sample_rate=24000, # What frontend sends
target_sample_rate=16000, # What ElevenLabs expects
language_code="en",
)
await transcriber.connect()
return transcriber
async def create_streaming_synthesizer(
self, voice: str | None = None, _speed: float = 1.0
) -> ElevenLabsStreamingSynthesizer:
"""Create a streaming TTS session."""
if not self.api_key:
raise ValueError("API key required for streaming TTS")
voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM"
synthesizer = ElevenLabsStreamingSynthesizer(
api_key=self.api_key,
voice_id=voice_id,
model_id=self.tts_model,
# Use mp3_44100_64 for streaming - good balance of quality and chunk size
output_format="mp3_44100_64",
)
await synthesizer.connect()
return synthesizer
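
The transcriber above downsamples the frontend's 24 kHz PCM16 audio to the 16 kHz that the realtime endpoint is configured for, using linear interpolation between neighbouring samples. The sketch below is a standalone version of that step for experimenting outside the class. The free function `resample_pcm16` and its explicit rate parameters are assumptions made for illustration, and the guard against an odd trailing byte is an addition not present in the method as shown.

```python
import struct


def resample_pcm16(data: bytes, src_rate: int, dst_rate: int) -> bytes:
    """Resample little-endian mono PCM16 audio by linear interpolation."""
    if src_rate == dst_rate:
        return data

    # Each sample is 2 bytes; ignore a dangling odd byte instead of raising
    num_samples = len(data) // 2
    samples = struct.unpack(f"<{num_samples}h", data[: num_samples * 2])

    ratio = src_rate / dst_rate
    new_length = int(num_samples / ratio)

    out = []
    for i in range(new_length):
        pos = i * ratio                      # fractional position in the source
        lo = int(pos)
        hi = min(lo + 1, num_samples - 1)    # clamp at the last sample
        frac = pos - lo
        value = int(samples[lo] * (1 - frac) + samples[hi] * frac)
        out.append(max(-32768, min(32767, value)))  # clamp to int16 range

    return struct.pack(f"<{len(out)}h", *out)


if __name__ == "__main__":
    pcm = struct.pack("<6h", 0, 300, 600, 900, 1200, 1500)
    print(struct.unpack("<4h", resample_pcm16(pcm, 24000, 16000)))
```

Linear interpolation applies no low-pass filter, so content above 8 kHz can alias when downsampling to 16 kHz. That is usually tolerable for speech recognition, but a filtered resampler would be the safer choice if audio quality matters.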

@@ -0,0 +1,567 @@
import asyncio
import base64
import io
import json
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING
import aiohttp
from onyx.voice.interface import StreamingSynthesizerProtocol
from onyx.voice.interface import StreamingTranscriberProtocol
from onyx.voice.interface import TranscriptResult
from onyx.voice.interface import VoiceProviderInterface
if TYPE_CHECKING:
from openai import AsyncOpenAI
class OpenAIStreamingTranscriber(StreamingTranscriberProtocol):
    """Streaming transcription using OpenAI Realtime API."""

    def __init__(self, api_key: str, model: str = "whisper-1"):
        # Import logger first
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()
        self._logger.info(
            f"OpenAIStreamingTranscriber: initializing with model {model}"
        )
        self.api_key = api_key
        self.model = model
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._session: aiohttp.ClientSession | None = None
        self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
        self._current_turn_transcript = ""  # Transcript for current VAD turn
        self._accumulated_transcript = ""  # Accumulated across all turns
        self._receive_task: asyncio.Task | None = None
        self._closed = False

    async def connect(self) -> None:
        """Establish WebSocket connection to OpenAI Realtime API."""
        self._session = aiohttp.ClientSession()

        # OpenAI Realtime transcription endpoint
        url = "wss://api.openai.com/v1/realtime?intent=transcription"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "OpenAI-Beta": "realtime=v1",
        }

        try:
            self._ws = await self._session.ws_connect(url, headers=headers)
            self._logger.info("Connected to OpenAI Realtime API")
        except Exception as e:
            self._logger.error(f"Failed to connect to OpenAI Realtime API: {e}")
            raise

        # Configure the session for transcription
        # Enable server-side VAD (Voice Activity Detection) for automatic speech detection
        config_message = {
            "type": "transcription_session.update",
            "session": {
                "input_audio_format": "pcm16",  # 16-bit PCM at 24kHz mono
                "input_audio_transcription": {
                    "model": self.model,
                },
                "turn_detection": {
                    "type": "server_vad",
                    "threshold": 0.5,
                    "prefix_padding_ms": 300,
                    "silence_duration_ms": 500,
                },
            },
        }
        await self._ws.send_str(json.dumps(config_message))
        self._logger.info(f"Sent config for model: {self.model} with server VAD")

        # Start receiving transcripts
        self._receive_task = asyncio.create_task(self._receive_loop())
    async def _receive_loop(self) -> None:
        """Background task to receive transcripts."""
        if not self._ws:
            return

        try:
            async for msg in self._ws:
                if msg.type == aiohttp.WSMsgType.TEXT:
                    data = json.loads(msg.data)
                    msg_type = data.get("type", "")
                    self._logger.debug(f"Received message type: {msg_type}")

                    # Handle errors
                    if msg_type == "error":
                        error = data.get("error", {})
                        self._logger.error(f"OpenAI error: {error}")
                        continue

                    # Handle VAD events
                    if msg_type == "input_audio_buffer.speech_started":
                        self._logger.info("OpenAI: Speech started")
                        # Reset current turn transcript for new speech
                        self._current_turn_transcript = ""
                        continue
                    elif msg_type == "input_audio_buffer.speech_stopped":
                        self._logger.info(
                            "OpenAI: Speech stopped (VAD detected silence)"
                        )
                        continue
                    elif msg_type == "input_audio_buffer.committed":
                        self._logger.info("OpenAI: Audio buffer committed")
                        continue

                    # Handle transcription events
                    if msg_type == "conversation.item.input_audio_transcription.delta":
                        delta = data.get("delta", "")
                        if delta:
                            self._logger.info(f"OpenAI: Transcription delta: {delta}")
                            self._current_turn_transcript += delta
                            # Show accumulated + current turn transcript
                            full_transcript = self._accumulated_transcript
                            if full_transcript and self._current_turn_transcript:
                                full_transcript += " "
                            full_transcript += self._current_turn_transcript
                            await self._transcript_queue.put(
                                TranscriptResult(text=full_transcript, is_vad_end=False)
                            )
                    elif (
                        msg_type
                        == "conversation.item.input_audio_transcription.completed"
                    ):
                        transcript = data.get("transcript", "")
                        if transcript:
                            self._logger.info(
                                f"OpenAI: Transcription completed (VAD turn end): {transcript[:50]}..."
                            )
                            # This is the final transcript for this VAD turn
                            self._current_turn_transcript = transcript
                            # Accumulate this turn's transcript
                            if self._accumulated_transcript:
                                self._accumulated_transcript += " " + transcript
                            else:
                                self._accumulated_transcript = transcript
                            # Send with is_vad_end=True to trigger auto-send
                            await self._transcript_queue.put(
                                TranscriptResult(
                                    text=self._accumulated_transcript,
                                    is_vad_end=True,
                                )
                            )
                    elif msg_type not in (
                        "transcription_session.created",
                        "transcription_session.updated",
                        "conversation.item.created",
                    ):
                        # Log any other message types we might be missing
                        self._logger.info(
                            f"OpenAI: Unhandled message type '{msg_type}': {data}"
                        )
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    self._logger.error(f"WebSocket error: {self._ws.exception()}")
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    self._logger.info("WebSocket closed by server")
                    break
        except Exception as e:
            self._logger.error(f"Error in receive loop: {e}")
        finally:
            await self._transcript_queue.put(None)
    async def send_audio(self, chunk: bytes) -> None:
        """Send audio chunk to OpenAI."""
        if self._ws and not self._closed:
            # OpenAI expects base64-encoded PCM16 audio at 24kHz mono
            # PCM16 at 24kHz: 24000 samples/sec * 2 bytes/sample = 48000 bytes/sec
            # So chunk_bytes / 48000 = duration in seconds
            duration_ms = (len(chunk) / 48000) * 1000
            self._logger.debug(
                f"Sending {len(chunk)} bytes ({duration_ms:.1f}ms) of audio to OpenAI. "
                f"First 10 bytes: {chunk[:10].hex() if len(chunk) >= 10 else chunk.hex()}"
            )
            message = {
                "type": "input_audio_buffer.append",
                "audio": base64.b64encode(chunk).decode("utf-8"),
            }
            await self._ws.send_str(json.dumps(message))

    def reset_transcript(self) -> None:
        """Reset accumulated transcript. Call after auto-send to start fresh."""
        self._logger.info("OpenAI: Resetting accumulated transcript")
        self._accumulated_transcript = ""
        self._current_turn_transcript = ""

    async def receive_transcript(self) -> TranscriptResult | None:
        """Receive next transcript."""
        try:
            return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return TranscriptResult(text="", is_vad_end=False)

    async def close(self) -> str:
        """Close session and return final transcript."""
        self._closed = True
        if self._ws:
            # With server VAD, the buffer is auto-committed when speech stops.
            # But we should still commit any remaining audio and wait for transcription.
            try:
                await self._ws.send_str(
                    json.dumps({"type": "input_audio_buffer.commit"})
                )
            except Exception as e:
                self._logger.debug(f"Error sending commit (may be expected): {e}")

            # Wait for transcription to arrive (up to 5 seconds)
            self._logger.info("Waiting for transcription to complete...")
            for _ in range(50):  # 50 * 100ms = 5 seconds max
                await asyncio.sleep(0.1)
                if self._accumulated_transcript:
                    self._logger.info(
                        f"Got final transcript: {self._accumulated_transcript[:50]}..."
                    )
                    break
            else:
                self._logger.warning("Timed out waiting for transcription")

            await self._ws.close()

        if self._receive_task:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass

        if self._session:
            await self._session.close()

        return self._accumulated_transcript
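The transcriber above keeps two strings: the partial text for the VAD turn in progress and the text accumulated across completed turns. A minimal sketch of that bookkeeping, isolated from the WebSocket code, might look like the following. `TurnAccumulator` and its method names are hypothetical and exist only for illustration; they are not part of the PR.

```python
from dataclasses import dataclass


@dataclass
class TurnAccumulator:
    """Hypothetical helper mirroring the transcriber's two transcript fields."""

    accumulated: str = ""  # text from all completed VAD turns
    current: str = ""  # partial text for the turn in progress

    def on_speech_started(self) -> None:
        # A new VAD turn begins, so the partial text starts over.
        self.current = ""

    def on_delta(self, delta: str) -> str:
        # Append the partial delta and return what would be displayed:
        # completed turns, a space if both parts are non-empty, then the partial.
        self.current += delta
        if self.accumulated and self.current:
            return f"{self.accumulated} {self.current}"
        return self.accumulated + self.current

    def on_completed(self, transcript: str) -> str:
        # The final text for the turn replaces the partial and is folded
        # into the accumulated transcript.
        self.current = transcript
        if self.accumulated:
            self.accumulated += " " + transcript
        else:
            self.accumulated = transcript
        return self.accumulated


acc = TurnAccumulator()
acc.on_speech_started()
first_partial = acc.on_delta("Hel")
second_partial = acc.on_delta("lo")
first_final = acc.on_completed("Hello.")
acc.on_speech_started()
next_partial = acc.on_delta("How")
next_final = acc.on_completed("How are you?")
```

Note that the final text of a turn replaces the streamed partial rather than being appended to it, so any difference between the deltas and the completed transcript resolves in favour of the completed one.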
# OpenAI available voices for TTS
OPENAI_VOICES = [
    {"id": "alloy", "name": "Alloy"},
    {"id": "echo", "name": "Echo"},
    {"id": "fable", "name": "Fable"},
    {"id": "onyx", "name": "Onyx"},
    {"id": "nova", "name": "Nova"},
    {"id": "shimmer", "name": "Shimmer"},
]

# OpenAI available STT models (all support streaming via Realtime API)
OPENAI_STT_MODELS = [
    {"id": "whisper-1", "name": "Whisper v1"},
    {"id": "gpt-4o-transcribe", "name": "GPT-4o Transcribe"},
    {"id": "gpt-4o-mini-transcribe", "name": "GPT-4o Mini Transcribe"},
]

# OpenAI available TTS models
OPENAI_TTS_MODELS = [
    {"id": "tts-1", "name": "TTS-1 (Standard)"},
    {"id": "tts-1-hd", "name": "TTS-1 HD (High Quality)"},
]


def _create_wav_header(
    data_length: int,
    sample_rate: int = 24000,
    channels: int = 1,
    bits_per_sample: int = 16,
) -> bytes:
    """Create a WAV file header for PCM audio data."""
    import struct

    byte_rate = sample_rate * channels * bits_per_sample // 8
    block_align = channels * bits_per_sample // 8

    # WAV header is 44 bytes
    header = struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF",  # ChunkID
        36 + data_length,  # ChunkSize
        b"WAVE",  # Format
        b"fmt ",  # Subchunk1ID
        16,  # Subchunk1Size (PCM)
        1,  # AudioFormat (1 = PCM)
        channels,  # NumChannels
        sample_rate,  # SampleRate
        byte_rate,  # ByteRate
        block_align,  # BlockAlign
        bits_per_sample,  # BitsPerSample
        b"data",  # Subchunk2ID
        data_length,  # Subchunk2Size
    )
    return header
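One way to sanity-check a hand-packed header like the one above is to prepend it to raw PCM and read the result back with the standard-library `wave` module. The sketch below re-declares the function under the hypothetical name `create_wav_header` so it runs on its own; the 100 ms silence buffer is an assumption chosen to match the 24 kHz, 16-bit mono arithmetic used in `send_audio`.

```python
import io
import struct
import wave


def create_wav_header(
    data_length: int,
    sample_rate: int = 24000,
    channels: int = 1,
    bits_per_sample: int = 16,
) -> bytes:
    """Standalone copy of the 44-byte PCM WAV header packing, for illustration."""
    byte_rate = sample_rate * channels * bits_per_sample // 8
    block_align = channels * bits_per_sample // 8
    return struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF", 36 + data_length, b"WAVE",
        b"fmt ", 16, 1, channels, sample_rate,
        byte_rate, block_align, bits_per_sample,
        b"data", data_length,
    )


# 100 ms of silence: 24000 samples/s * 0.1 s * 2 bytes/sample = 4800 bytes
pcm = b"\x00\x00" * 2400
header = create_wav_header(len(pcm))
wav_bytes = header + pcm

# Parse the bytes back with the stdlib reader to confirm the header fields.
with wave.open(io.BytesIO(wav_bytes), "rb") as wav:
    params = (
        wav.getnchannels(),
        wav.getsampwidth(),
        wav.getframerate(),
        wav.getnframes(),
    )
```

If the header is well formed, `params` reports one channel, a 2-byte sample width, a 24000 Hz frame rate and 2400 frames, which corresponds to 100 ms of audio.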
class OpenAIStreamingSynthesizer(StreamingSynthesizerProtocol):
    """Streaming TTS using OpenAI HTTP TTS API with streaming responses."""

    def __init__(
        self,
        api_key: str,
        voice: str = "alloy",
        model: str = "tts-1",
        speed: float = 1.0,
    ):
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()
        self.api_key = api_key
        self.voice = voice
        self.model = model
        self.speed = max(0.25, min(4.0, speed))
        self._session: aiohttp.ClientSession | None = None
        self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
        self._text_queue: asyncio.Queue[str | None] = asyncio.Queue()
        self._synthesis_task: asyncio.Task | None = None
        self._closed = False
        self._flushed = False

    async def connect(self) -> None:
        """Initialize HTTP session for TTS requests."""
        self._logger.info("OpenAIStreamingSynthesizer: connecting")
        self._session = aiohttp.ClientSession()
        # Start background task to process text queue
        self._synthesis_task = asyncio.create_task(self._process_text_queue())
        self._logger.info("OpenAIStreamingSynthesizer: connected")

    async def _process_text_queue(self) -> None:
        """Background task to process queued text for synthesis."""
        while not self._closed:
            try:
                text = await asyncio.wait_for(self._text_queue.get(), timeout=0.1)
                if text is None:
                    break
                await self._synthesize_text(text)
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break
            except Exception as e:
                self._logger.error(f"Error processing text queue: {e}")
    async def _synthesize_text(self, text: str) -> None:
        """Make HTTP TTS request and stream audio to queue."""
        if not self._session or self._closed:
            return

        url = "https://api.openai.com/v1/audio/speech"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "voice": self.voice,
            "input": text,
            "speed": self.speed,
            "response_format": "mp3",
        }

        try:
            async with self._session.post(
                url, headers=headers, json=payload
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    self._logger.error(f"OpenAI TTS error: {error_text}")
                    return

                # Use 8192 byte chunks for smoother streaming
                # (larger chunks = more complete MP3 frames, better playback)
                async for chunk in response.content.iter_chunked(8192):
                    if self._closed:
                        break
                    if chunk:
                        await self._audio_queue.put(chunk)
        except Exception as e:
            self._logger.error(f"OpenAIStreamingSynthesizer synthesis error: {e}")

    async def send_text(self, text: str) -> None:
        """Queue text to be synthesized via HTTP streaming."""
        if not text.strip() or self._closed:
            return
        await self._text_queue.put(text)

    async def receive_audio(self) -> bytes | None:
        """Receive next audio chunk (MP3 format)."""
        try:
            return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return b""  # No audio yet, but not done

    async def flush(self) -> None:
        """Signal end of text input - wait for synthesis to complete."""
        if self._flushed:
            return
        self._flushed = True

        # Signal end of text input
        await self._text_queue.put(None)

        # Wait for synthesis task to complete processing all text
        if self._synthesis_task and not self._synthesis_task.done():
            try:
                await asyncio.wait_for(self._synthesis_task, timeout=60.0)
            except asyncio.TimeoutError:
                self._logger.warning("OpenAIStreamingSynthesizer: flush timeout")
            except asyncio.CancelledError:
                pass

        # Signal end of audio stream
        await self._audio_queue.put(None)

    async def close(self) -> None:
        """Close the session."""
        if self._closed:
            return
        self._closed = True

        # Signal end of queues only if flush wasn't already called
        if not self._flushed:
            await self._text_queue.put(None)
            await self._audio_queue.put(None)

        if self._synthesis_task and not self._synthesis_task.done():
            self._synthesis_task.cancel()
            try:
                await self._synthesis_task
            except asyncio.CancelledError:
                pass

        if self._session:
            await self._session.close()
class OpenAIVoiceProvider(VoiceProviderInterface):
    """OpenAI voice provider using Whisper for STT and TTS API for speech synthesis."""

    def __init__(
        self,
        api_key: str | None,
        api_base: str | None = None,
        stt_model: str | None = None,
        tts_model: str | None = None,
        default_voice: str | None = None,
    ):
        self.api_key = api_key
        self.api_base = api_base
        self.stt_model = stt_model or "whisper-1"
        self.tts_model = tts_model or "tts-1"
        self.default_voice = default_voice or "alloy"
        self._client: "AsyncOpenAI | None" = None

    def _get_client(self) -> "AsyncOpenAI":
        if self._client is None:
            from openai import AsyncOpenAI

            self._client = AsyncOpenAI(
                api_key=self.api_key,
                base_url=self.api_base,
            )
        return self._client

    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        """
        Transcribe audio using OpenAI Whisper.

        Args:
            audio_data: Raw audio bytes
            audio_format: Audio format (e.g., "webm", "wav", "mp3")

        Returns:
            Transcribed text
        """
        client = self._get_client()

        # Create a file-like object from the audio bytes
        audio_file = io.BytesIO(audio_data)
        audio_file.name = f"audio.{audio_format}"

        response = await client.audio.transcriptions.create(
            model=self.stt_model,
            file=audio_file,
        )
        return response.text

    async def synthesize_stream(
        self, text: str, voice: str | None = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio using OpenAI TTS with streaming.

        Args:
            text: Text to convert to speech
            voice: Voice identifier (defaults to provider's default voice)
            speed: Playback speed multiplier (0.25 to 4.0)

        Yields:
            Audio data chunks (mp3 format)
        """
        client = self._get_client()

        # Clamp speed to valid range
        speed = max(0.25, min(4.0, speed))

        # Use with_streaming_response for proper async streaming
        # Using 8192 byte chunks for better streaming performance
        # (larger chunks = fewer round-trips, more complete MP3 frames)
        async with client.audio.speech.with_streaming_response.create(
            model=self.tts_model,
            voice=voice or self.default_voice,
            input=text,
            speed=speed,
            response_format="mp3",
        ) as response:
            async for chunk in response.iter_bytes(chunk_size=8192):
                yield chunk
    def get_available_voices(self) -> list[dict[str, str]]:
        """Get available OpenAI TTS voices."""
        return OPENAI_VOICES.copy()

    def get_available_stt_models(self) -> list[dict[str, str]]:
        """Get available OpenAI STT models."""
        return OPENAI_STT_MODELS.copy()

    def get_available_tts_models(self) -> list[dict[str, str]]:
        """Get available OpenAI TTS models."""
        return OPENAI_TTS_MODELS.copy()

    def supports_streaming_stt(self) -> bool:
        """OpenAI supports streaming via Realtime API for all STT models."""
        return True

    def supports_streaming_tts(self) -> bool:
        """OpenAI supports streaming TTS via the HTTP TTS API with streamed responses."""
        return True

    async def create_streaming_transcriber(
        self, _audio_format: str = "webm"
    ) -> OpenAIStreamingTranscriber:
        """Create a streaming transcription session using Realtime API."""
        if not self.api_key:
            raise ValueError("API key required for streaming transcription")
        transcriber = OpenAIStreamingTranscriber(
            api_key=self.api_key,
            model=self.stt_model,
        )
        await transcriber.connect()
        return transcriber

    async def create_streaming_synthesizer(
        self, voice: str | None = None, speed: float = 1.0
    ) -> OpenAIStreamingSynthesizer:
        """Create a streaming TTS session using HTTP streaming API."""
        if not self.api_key:
            raise ValueError("API key required for streaming TTS")
        synthesizer = OpenAIStreamingSynthesizer(
            api_key=self.api_key,
            voice=voice or self.default_voice or "alloy",
            model=self.tts_model or "tts-1",
            speed=speed,
        )
        await synthesizer.connect()
        return synthesizer
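Both streaming classes in this file use the same receive convention: a poll with a 0.1 s timeout returns an empty value when nothing is ready yet, and `None` is the end-of-stream sentinel. The sketch below shows how a caller might drain such a queue. The `receive`, `drain` and `demo` functions and the fake producer are hypothetical stand-ins, not code from the PR.

```python
from __future__ import annotations

import asyncio


async def receive(queue: asyncio.Queue) -> bytes | None:
    """Poll the queue; b"" means 'nothing yet', None means 'stream finished'."""
    try:
        return await asyncio.wait_for(queue.get(), timeout=0.1)
    except asyncio.TimeoutError:
        return b""


async def drain(queue: asyncio.Queue) -> list[bytes]:
    """Collect chunks until the None sentinel arrives, skipping empty polls."""
    chunks: list[bytes] = []
    while True:
        chunk = await receive(queue)
        if chunk is None:
            break
        if chunk:
            chunks.append(chunk)
    return chunks


async def demo() -> list[bytes]:
    queue: asyncio.Queue = asyncio.Queue()

    async def producer() -> None:
        # Sleep longer than the poll timeout so the consumer sees empty polls.
        for part in (b"ab", b"cd"):
            await asyncio.sleep(0.15)
            await queue.put(part)
        await queue.put(None)  # end-of-stream sentinel

    task = asyncio.create_task(producer())
    chunks = await drain(queue)
    await task
    return chunks


result = asyncio.run(demo())
```

The empty-bytes return keeps the caller's loop responsive (for example, to check for cancellation between polls) without treating a slow producer as the end of the stream.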


@@ -596,7 +596,7 @@ mypy-extensions==1.0.0
# typing-inspect
nest-asyncio==1.6.0
# via onyx
nltk==3.9.1
nltk==3.9.3
# via unstructured
numpy==2.4.1
# via


@@ -16,10 +16,6 @@ def monitor_process(process_name: str, process: subprocess.Popen) -> None:
def run_jobs() -> None:
# Check if we should use lightweight mode, defaults to True, change to False to use separate background workers
use_lightweight = True
# command setup
cmd_worker_primary = [
"celery",
"-A",
@@ -74,6 +70,48 @@ def run_jobs() -> None:
"--queues=connector_doc_fetching",
]
cmd_worker_heavy = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,sandbox",
]
cmd_worker_monitoring = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
]
cmd_worker_user_file_processing = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"-Q",
"user_file_processing,user_file_project_sync,user_file_delete",
]
cmd_beat = [
"celery",
"-A",
@@ -82,144 +120,31 @@ def run_jobs() -> None:
"--loglevel=INFO",
]
# Prepare background worker commands based on mode
if use_lightweight:
print("Starting workers in LIGHTWEIGHT mode (single background worker)")
cmd_worker_background = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.background",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration",
]
background_workers = [("BACKGROUND", cmd_worker_background)]
else:
print("Starting workers in STANDARD mode (separate background workers)")
cmd_worker_heavy = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning,sandbox",
]
cmd_worker_monitoring = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
]
cmd_worker_user_file_processing = [
"celery",
"-A",
"onyx.background.celery.versioned_apps.user_file_processing",
"worker",
"--pool=threads",
"--concurrency=2",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=user_file_processing@%n",
"-Q",
"user_file_processing,user_file_project_sync,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,user_file_delete",
]
background_workers = [
("HEAVY", cmd_worker_heavy),
("MONITORING", cmd_worker_monitoring),
("USER_FILE_PROCESSING", cmd_worker_user_file_processing),
]
all_workers = [
("PRIMARY", cmd_worker_primary),
("LIGHT", cmd_worker_light),
("DOCPROCESSING", cmd_worker_docprocessing),
("DOCFETCHING", cmd_worker_docfetching),
("HEAVY", cmd_worker_heavy),
("MONITORING", cmd_worker_monitoring),
("USER_FILE_PROCESSING", cmd_worker_user_file_processing),
("BEAT", cmd_beat),
]
# spawn processes
worker_primary_process = subprocess.Popen(
cmd_worker_primary, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
worker_light_process = subprocess.Popen(
cmd_worker_light, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
worker_docprocessing_process = subprocess.Popen(
cmd_worker_docprocessing,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
worker_docfetching_process = subprocess.Popen(
cmd_worker_docfetching,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
)
beat_process = subprocess.Popen(
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
# Spawn background worker processes based on mode
background_processes = []
for name, cmd in background_workers:
processes = []
for name, cmd in all_workers:
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
background_processes.append((name, process))
processes.append((name, process))
# monitor threads
worker_primary_thread = threading.Thread(
target=monitor_process, args=("PRIMARY", worker_primary_process)
)
worker_light_thread = threading.Thread(
target=monitor_process, args=("LIGHT", worker_light_process)
)
worker_docprocessing_thread = threading.Thread(
target=monitor_process, args=("DOCPROCESSING", worker_docprocessing_process)
)
worker_docfetching_thread = threading.Thread(
target=monitor_process, args=("DOCFETCHING", worker_docfetching_process)
)
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
# Create monitor threads for background workers
background_threads = []
for name, process in background_processes:
threads = []
for name, process in processes:
thread = threading.Thread(target=monitor_process, args=(name, process))
background_threads.append(thread)
# Start all threads
worker_primary_thread.start()
worker_light_thread.start()
worker_docprocessing_thread.start()
worker_docfetching_thread.start()
beat_thread.start()
for thread in background_threads:
threads.append(thread)
thread.start()
# Wait for all threads
worker_primary_thread.join()
worker_light_thread.join()
worker_docprocessing_thread.join()
worker_docfetching_thread.join()
beat_thread.join()
for thread in background_threads:
for thread in threads:
thread.join()
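The refactor in this file replaces one variable per worker with a single list of `(name, command)` pairs that is spawned, monitored and joined in loops. A minimal, self-contained sketch of that shape follows. The two commands are placeholder Python one-liners rather than Celery invocations, and `monitor_process` here takes an extra `sink` list (an assumption made so the output can be inspected), unlike the two-argument function in the script.

```python
import subprocess
import sys
import threading


def monitor_process(name: str, process: subprocess.Popen, sink: list) -> None:
    """Read the process output line by line and record it with a name prefix."""
    assert process.stdout is not None
    for line in process.stdout:
        sink.append(f"{name}: {line.rstrip()}")
    process.wait()


def run_all(workers: list) -> list:
    sink: list = []

    # Spawn every worker from the single list of (name, command) pairs.
    processes = []
    for name, cmd in workers:
        process = subprocess.Popen(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
        )
        processes.append((name, process))

    # One monitor thread per process, started as it is created.
    threads = []
    for name, process in processes:
        thread = threading.Thread(target=monitor_process, args=(name, process, sink))
        threads.append(thread)
        thread.start()

    # Block until every monitor thread has finished.
    for thread in threads:
        thread.join()
    return sink


# Placeholder commands standing in for the real Celery worker command lines.
workers = [
    ("A", [sys.executable, "-c", "print('a ready')"]),
    ("B", [sys.executable, "-c", "print('b ready')"]),
]
lines = run_all(workers)
```

Adding or removing a worker then only means editing the list, which is the main benefit of the loop-based form over per-worker variables.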


@@ -1,10 +1,20 @@
#!/bin/bash
set -e
cleanup() {
echo "Error occurred. Cleaning up..."
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
COMPOSE_FILE="$SCRIPT_DIR/../../deployment/docker_compose/docker-compose.yml"
COMPOSE_DEV_FILE="$SCRIPT_DIR/../../deployment/docker_compose/docker-compose.dev.yml"
stop_and_remove_containers() {
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled stop opensearch 2>/dev/null || true
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled rm -f opensearch 2>/dev/null || true
}
cleanup() {
echo "Error occurred. Cleaning up..."
stop_and_remove_containers
}
# Trap errors and output a message, then cleanup
@@ -12,16 +22,26 @@ trap 'echo "Error occurred on line $LINENO. Exiting script." >&2; cleanup' ERR
# Usage of the script with optional volume arguments
# ./restart_containers.sh [vespa_volume] [postgres_volume] [redis_volume]
# [minio_volume] [--keep-opensearch-data]
VESPA_VOLUME=${1:-""} # Default is empty if not provided
POSTGRES_VOLUME=${2:-""} # Default is empty if not provided
REDIS_VOLUME=${3:-""} # Default is empty if not provided
MINIO_VOLUME=${4:-""} # Default is empty if not provided
KEEP_OPENSEARCH_DATA=false
POSITIONAL_ARGS=()
for arg in "$@"; do
if [[ "$arg" == "--keep-opensearch-data" ]]; then
KEEP_OPENSEARCH_DATA=true
else
POSITIONAL_ARGS+=("$arg")
fi
done
VESPA_VOLUME=${POSITIONAL_ARGS[0]:-""}
POSTGRES_VOLUME=${POSITIONAL_ARGS[1]:-""}
REDIS_VOLUME=${POSITIONAL_ARGS[2]:-""}
MINIO_VOLUME=${POSITIONAL_ARGS[3]:-""}
# Stop and remove the existing containers
echo "Stopping and removing existing containers..."
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
stop_and_remove_containers
# Start the PostgreSQL container with optional volume
echo "Starting PostgreSQL container..."
@@ -39,6 +59,29 @@ else
docker run --detach --name onyx_vespa --hostname vespa-container --publish 8081:8081 --publish 19071:19071 vespaengine/vespa:8
fi
# If OPENSEARCH_ADMIN_PASSWORD is not already set, try loading it from
# .vscode/.env so existing dev setups that stored it there aren't silently
# broken.
VSCODE_ENV="$SCRIPT_DIR/../../.vscode/.env"
if [[ -z "${OPENSEARCH_ADMIN_PASSWORD:-}" && -f "$VSCODE_ENV" ]]; then
set -a
# shellcheck source=/dev/null
source "$VSCODE_ENV"
set +a
fi
# Start the OpenSearch container using the same service from docker-compose that
# our users use, setting OPENSEARCH_INITIAL_ADMIN_PASSWORD from the env's
# OPENSEARCH_ADMIN_PASSWORD if it exists, else defaulting to StrongPassword123!.
# Pass --keep-opensearch-data to preserve the opensearch-data volume across
# restarts, else the volume is deleted so the container starts fresh.
if [[ "$KEEP_OPENSEARCH_DATA" == "false" ]]; then
echo "Deleting opensearch-data volume..."
docker volume rm onyx_opensearch-data 2>/dev/null || true
fi
echo "Starting OpenSearch container..."
docker compose -f "$COMPOSE_FILE" -f "$COMPOSE_DEV_FILE" --profile opensearch-enabled up --force-recreate -d opensearch
# Start the Redis container with optional volume
echo "Starting Redis container..."
if [[ -n "$REDIS_VOLUME" ]]; then
@@ -60,7 +103,6 @@ echo "Starting Code Interpreter container..."
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api
# Ensure alembic runs in the correct directory (backend/)
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PARENT_DIR="$(dirname "$SCRIPT_DIR")"
cd "$PARENT_DIR"


@@ -1,10 +0,0 @@
#!/bin/bash
# We get OPENSEARCH_ADMIN_PASSWORD from the repo .env file.
source "$(dirname "$0")/../../.vscode/.env"
cd "$(dirname "$0")/../../deployment/docker_compose"
# Start OpenSearch.
echo "Forcefully starting fresh OpenSearch container..."
docker compose -f docker-compose.opensearch.yml up --force-recreate -d opensearch


@@ -1,23 +1,5 @@
#!/bin/sh
# Entrypoint script for supervisord that sets environment variables
# for controlling which celery workers to start
# Default to lightweight mode if not set
if [ -z "$USE_LIGHTWEIGHT_BACKGROUND_WORKER" ]; then
export USE_LIGHTWEIGHT_BACKGROUND_WORKER="true"
fi
# Set the complementary variable for supervisord
# because it doesn't support %(not ENV_USE_LIGHTWEIGHT_BACKGROUND_WORKER) syntax
if [ "$USE_LIGHTWEIGHT_BACKGROUND_WORKER" = "true" ]; then
export USE_SEPARATE_BACKGROUND_WORKERS="false"
else
export USE_SEPARATE_BACKGROUND_WORKERS="true"
fi
echo "Worker mode configuration:"
echo " USE_LIGHTWEIGHT_BACKGROUND_WORKER=$USE_LIGHTWEIGHT_BACKGROUND_WORKER"
echo " USE_SEPARATE_BACKGROUND_WORKERS=$USE_SEPARATE_BACKGROUND_WORKERS"
# Entrypoint script for supervisord
# Launch supervisord with environment variables available
exec /usr/bin/supervisord -c /etc/supervisor/conf.d/supervisord.conf


@@ -39,7 +39,6 @@ autorestart=true
startsecs=10
stopasgroup=true
# Standard mode: Light worker for fast operations
# NOTE: only allowing configuration here and not in the other celery workers,
# since this is often the bottleneck for "sync" jobs (e.g. document set syncing,
# user group syncing, deletion, etc.)
@@ -54,26 +53,7 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
# Lightweight mode: single consolidated background worker
# Used when USE_LIGHTWEIGHT_BACKGROUND_WORKER=true (default)
# Consolidates: light, docprocessing, docfetching, heavy, monitoring, user_file_processing
[program:celery_worker_background]
command=celery -A onyx.background.celery.versioned_apps.background worker
--loglevel=INFO
--hostname=background@%%n
-Q vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,sandbox,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,monitoring,user_file_processing,user_file_project_sync,opensearch_migration
stdout_logfile=/var/log/celery_worker_background.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_LIGHTWEIGHT_BACKGROUND_WORKER)s
# Standard mode: separate workers for different background tasks
# Used when USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
[program:celery_worker_heavy]
command=celery -A onyx.background.celery.versioned_apps.heavy worker
--loglevel=INFO
@@ -85,9 +65,7 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
# Standard mode: Document processing worker
[program:celery_worker_docprocessing]
command=celery -A onyx.background.celery.versioned_apps.docprocessing worker
--loglevel=INFO
@@ -99,7 +77,6 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
[program:celery_worker_user_file_processing]
command=celery -A onyx.background.celery.versioned_apps.user_file_processing worker
@@ -112,9 +89,7 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
# Standard mode: Document fetching worker
[program:celery_worker_docfetching]
command=celery -A onyx.background.celery.versioned_apps.docfetching worker
--loglevel=INFO
@@ -126,7 +101,6 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
[program:celery_worker_monitoring]
command=celery -A onyx.background.celery.versioned_apps.monitoring worker
@@ -139,7 +113,6 @@ redirect_stderr=true
autorestart=true
startsecs=10
stopasgroup=true
autostart=%(ENV_USE_SEPARATE_BACKGROUND_WORKERS)s
# Job scheduler for periodic tasks
@@ -197,7 +170,6 @@ command=tail -qF
/var/log/celery_beat.log
/var/log/celery_worker_primary.log
/var/log/celery_worker_light.log
/var/log/celery_worker_background.log
/var/log/celery_worker_heavy.log
/var/log/celery_worker_docprocessing.log
/var/log/celery_worker_monitoring.log


@@ -5,6 +5,8 @@ Verifies that:
1. extract_ids_from_runnable_connector correctly separates hierarchy nodes from doc IDs
2. Extracted hierarchy nodes are correctly upserted to Postgres via upsert_hierarchy_nodes_batch
3. Upserting is idempotent (running twice doesn't duplicate nodes)
4. Document-to-hierarchy-node linkage is updated during pruning
5. link_hierarchy_nodes_to_documents links nodes that are also documents
Uses a mock SlimConnectorWithPermSync that yields known hierarchy nodes and slim documents,
combined with a real PostgreSQL database for verifying persistence.
@@ -27,9 +29,13 @@ from onyx.db.enums import HierarchyNodeType
from onyx.db.hierarchy import ensure_source_node_exists
from onyx.db.hierarchy import get_all_hierarchy_nodes_for_source
from onyx.db.hierarchy import get_hierarchy_node_by_raw_id
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.hierarchy import update_document_parent_hierarchy_nodes
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import Document as DbDocument
from onyx.db.models import HierarchyNode as DBHierarchyNode
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.kg.models import KGStage
# ---------------------------------------------------------------------------
# Constants
@@ -89,8 +95,18 @@ def _make_hierarchy_nodes() -> list[PydanticHierarchyNode]:
]
DOC_PARENT_MAP = {
"msg-001": CHANNEL_A_ID,
"msg-002": CHANNEL_A_ID,
"msg-003": CHANNEL_B_ID,
}
def _make_slim_docs() -> list[SlimDocument | PydanticHierarchyNode]:
return [SlimDocument(id=doc_id) for doc_id in SLIM_DOC_IDS]
return [
SlimDocument(id=doc_id, parent_hierarchy_raw_node_id=DOC_PARENT_MAP.get(doc_id))
for doc_id in SLIM_DOC_IDS
]
class MockSlimConnectorWithPermSync(SlimConnectorWithPermSync):
@@ -126,14 +142,31 @@ class MockSlimConnectorWithPermSync(SlimConnectorWithPermSync):
# ---------------------------------------------------------------------------
def _cleanup_test_hierarchy_nodes(db_session: Session) -> None:
"""Remove all hierarchy nodes for TEST_SOURCE to isolate tests."""
def _cleanup_test_data(db_session: Session) -> None:
"""Remove all test hierarchy nodes and documents to isolate tests."""
for doc_id in SLIM_DOC_IDS:
db_session.query(DbDocument).filter(DbDocument.id == doc_id).delete()
db_session.query(DBHierarchyNode).filter(
DBHierarchyNode.source == TEST_SOURCE
).delete()
db_session.commit()
def _create_test_documents(db_session: Session) -> list[DbDocument]:
"""Insert minimal Document rows for our test doc IDs."""
docs = []
for doc_id in SLIM_DOC_IDS:
doc = DbDocument(
id=doc_id,
semantic_id=doc_id,
kg_stage=KGStage.NOT_STARTED,
)
db_session.add(doc)
docs.append(doc)
db_session.commit()
return docs
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@@ -147,14 +180,14 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None: # noqa:
result = extract_ids_from_runnable_connector(connector, callback=None)
# Doc IDs should include both slim doc IDs and hierarchy node raw_node_ids
# (hierarchy node IDs are added to doc_ids so they aren't pruned)
# (hierarchy node IDs are added to raw_id_to_parent so they aren't pruned)
expected_ids = {
CHANNEL_A_ID,
CHANNEL_B_ID,
CHANNEL_C_ID,
*SLIM_DOC_IDS,
}
assert result.doc_ids == expected_ids
assert result.raw_id_to_parent.keys() == expected_ids
# Hierarchy nodes should be the 3 channels
assert len(result.hierarchy_nodes) == 3
@@ -165,7 +198,7 @@ def test_pruning_extracts_hierarchy_nodes(db_session: Session) -> None: # noqa:
def test_pruning_upserts_hierarchy_nodes_to_db(db_session: Session) -> None:
"""Full flow: extract hierarchy nodes from mock connector, upsert to Postgres,
then verify the DB state (node count, parent relationships, permissions)."""
_cleanup_test_hierarchy_nodes(db_session)
_cleanup_test_data(db_session)
# Step 1: ensure the SOURCE node exists (mirrors what the pruning task does)
source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
@@ -230,7 +263,7 @@ def test_pruning_upserts_hierarchy_nodes_public_connector(
) -> None:
"""When the connector's access type is PUBLIC, all hierarchy nodes must be
marked is_public=True regardless of their external_access settings."""
_cleanup_test_hierarchy_nodes(db_session)
_cleanup_test_data(db_session)
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
@@ -257,7 +290,7 @@ def test_pruning_upserts_hierarchy_nodes_public_connector(
def test_pruning_hierarchy_node_upsert_idempotency(db_session: Session) -> None:
"""Upserting the same hierarchy nodes twice must not create duplicates.
The second call should update existing rows in place."""
_cleanup_test_hierarchy_nodes(db_session)
_cleanup_test_data(db_session)
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
@@ -295,7 +328,7 @@ def test_pruning_hierarchy_node_upsert_idempotency(db_session: Session) -> None:
def test_pruning_hierarchy_node_upsert_updates_fields(db_session: Session) -> None:
"""Upserting a hierarchy node with changed fields should update the existing row."""
_cleanup_test_hierarchy_nodes(db_session)
_cleanup_test_data(db_session)
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
@@ -342,3 +375,193 @@ def test_pruning_hierarchy_node_upsert_updates_fields(db_session: Session) -> No
assert db_node.is_public is True
assert db_node.external_user_emails is not None
assert set(db_node.external_user_emails) == {"new_user@example.com"}
# ---------------------------------------------------------------------------
# Document-to-hierarchy-node linkage tests
# ---------------------------------------------------------------------------
def test_extraction_preserves_parent_hierarchy_raw_node_id(
db_session: Session, # noqa: ARG001
) -> None:
"""extract_ids_from_runnable_connector should carry the
parent_hierarchy_raw_node_id from SlimDocument into the raw_id_to_parent dict."""
connector = MockSlimConnectorWithPermSync()
result = extract_ids_from_runnable_connector(connector, callback=None)
for doc_id, expected_parent in DOC_PARENT_MAP.items():
assert (
result.raw_id_to_parent[doc_id] == expected_parent
), f"raw_id_to_parent[{doc_id}] should be {expected_parent}"
# Hierarchy node entries have None parent (they aren't documents)
for channel_id in [CHANNEL_A_ID, CHANNEL_B_ID, CHANNEL_C_ID]:
assert result.raw_id_to_parent[channel_id] is None
def test_update_document_parent_hierarchy_nodes(db_session: Session) -> None:
"""update_document_parent_hierarchy_nodes should set
Document.parent_hierarchy_node_id for each document in the mapping."""
_cleanup_test_data(db_session)
source_node = ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
upserted = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=_make_hierarchy_nodes(),
source=TEST_SOURCE,
commit=True,
is_connector_public=False,
)
node_id_by_raw = {n.raw_node_id: n.id for n in upserted}
# Create documents with no parent set
docs = _create_test_documents(db_session)
for doc in docs:
assert doc.parent_hierarchy_node_id is None
# Build resolved map (same logic as _resolve_and_update_document_parents)
resolved: dict[str, int | None] = {}
for doc_id, raw_parent in DOC_PARENT_MAP.items():
resolved[doc_id] = node_id_by_raw.get(raw_parent, source_node.id)
updated = update_document_parent_hierarchy_nodes(
db_session=db_session,
doc_parent_map=resolved,
commit=True,
)
assert updated == len(SLIM_DOC_IDS)
# Verify each document now points to the correct hierarchy node
db_session.expire_all()
for doc_id, raw_parent in DOC_PARENT_MAP.items():
tmp_doc = db_session.get(DbDocument, doc_id)
assert tmp_doc is not None
doc = tmp_doc
expected_node_id = node_id_by_raw[raw_parent]
assert (
doc.parent_hierarchy_node_id == expected_node_id
), f"Document {doc_id} should point to node for {raw_parent}"
def test_update_document_parent_is_idempotent(db_session: Session) -> None:
"""Running update_document_parent_hierarchy_nodes a second time with the
same mapping should update zero rows."""
_cleanup_test_data(db_session)
ensure_source_node_exists(db_session, TEST_SOURCE, commit=True)
upserted = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=_make_hierarchy_nodes(),
source=TEST_SOURCE,
commit=True,
is_connector_public=False,
)
node_id_by_raw = {n.raw_node_id: n.id for n in upserted}
_create_test_documents(db_session)
resolved: dict[str, int | None] = {
doc_id: node_id_by_raw[raw_parent]
for doc_id, raw_parent in DOC_PARENT_MAP.items()
}
first_updated = update_document_parent_hierarchy_nodes(
db_session=db_session,
doc_parent_map=resolved,
commit=True,
)
assert first_updated == len(SLIM_DOC_IDS)
second_updated = update_document_parent_hierarchy_nodes(
db_session=db_session,
doc_parent_map=resolved,
commit=True,
)
assert second_updated == 0
def test_link_hierarchy_nodes_to_documents_for_confluence(
db_session: Session,
) -> None:
"""For sources in SOURCES_WITH_HIERARCHY_NODE_DOCUMENTS (e.g. Confluence),
link_hierarchy_nodes_to_documents should set HierarchyNode.document_id
when a hierarchy node's raw_node_id matches a document ID."""
_cleanup_test_data(db_session)
confluence_source = DocumentSource.CONFLUENCE
# Clean up any existing Confluence hierarchy nodes
db_session.query(DBHierarchyNode).filter(
DBHierarchyNode.source == confluence_source
).delete()
db_session.commit()
ensure_source_node_exists(db_session, confluence_source, commit=True)
# Create a hierarchy node whose raw_node_id matches a document ID
page_node_id = "confluence-page-123"
nodes = [
PydanticHierarchyNode(
raw_node_id=page_node_id,
raw_parent_id=None,
display_name="Test Page",
link="https://wiki.example.com/page/123",
node_type=HierarchyNodeType.PAGE,
),
]
upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=nodes,
source=confluence_source,
commit=True,
is_connector_public=False,
)
# Verify the node exists but has no document_id yet
db_node = get_hierarchy_node_by_raw_id(db_session, page_node_id, confluence_source)
assert db_node is not None
assert db_node.document_id is None
# Create a document with the same ID as the hierarchy node
doc = DbDocument(
id=page_node_id,
semantic_id="Test Page",
kg_stage=KGStage.NOT_STARTED,
)
db_session.add(doc)
db_session.commit()
# Link nodes to documents
linked = link_hierarchy_nodes_to_documents(
db_session=db_session,
document_ids=[page_node_id],
source=confluence_source,
commit=True,
)
assert linked == 1
# Verify the hierarchy node now has document_id set
db_session.expire_all()
db_node = get_hierarchy_node_by_raw_id(db_session, page_node_id, confluence_source)
assert db_node is not None
assert db_node.document_id == page_node_id
# Cleanup
db_session.query(DbDocument).filter(DbDocument.id == page_node_id).delete()
db_session.query(DBHierarchyNode).filter(
DBHierarchyNode.source == confluence_source
).delete()
db_session.commit()
def test_link_hierarchy_nodes_skips_non_hierarchy_sources(
db_session: Session,
) -> None:
"""link_hierarchy_nodes_to_documents should return 0 for sources that
don't support hierarchy-node-as-document (e.g. Slack, Google Drive)."""
linked = link_hierarchy_nodes_to_documents(
db_session=db_session,
document_ids=SLIM_DOC_IDS,
source=TEST_SOURCE, # Slack — not in SOURCES_WITH_HIERARCHY_NODE_DOCUMENTS
commit=False,
)
assert linked == 0
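
Review note: the parent-resolution step that `test_update_document_parent_hierarchy_nodes` rebuilds inline can be sketched on its own. This is a minimal illustration, not the body of `_resolve_and_update_document_parents`; the function name `resolve_document_parents` and the explicit `None` handling are assumptions made here. The only behavior taken from the test is the `.get(raw_parent, source_node.id)` fallback.

```python
from __future__ import annotations

from typing import Mapping


def resolve_document_parents(
    doc_parent_map: Mapping[str, str | None],
    node_id_by_raw: Mapping[str, int],
    source_node_id: int,
) -> dict[str, int]:
    """Map each document ID to a hierarchy node DB ID.

    Hypothetical helper: falls back to the source node when the raw parent is
    missing or unknown, mirroring the `.get(raw_parent, source_node.id)` lookup
    used in the test above.
    """
    resolved: dict[str, int] = {}
    for doc_id, raw_parent in doc_parent_map.items():
        if raw_parent is None:
            # No parent reported by the connector: attach to the source node.
            resolved[doc_id] = source_node_id
        else:
            resolved[doc_id] = node_id_by_raw.get(raw_parent, source_node_id)
    return resolved
```

With this shape, a document whose parent channel was never upserted still gets a valid parent (the source node) instead of being left unlinked.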

@@ -11,6 +11,7 @@ from onyx.context.search.models import SavedSearchSettings
from onyx.context.search.models import SearchSettingsCreationRequest
from onyx.db.enums import EmbeddingPrecision
from onyx.db.llm import fetch_default_contextual_rag_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_contextual_model
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import IndexModelStatus
@@ -37,6 +38,8 @@ def _create_llm_provider_and_model(
model_name: str,
) -> None:
"""Insert an LLM provider with a single visible model configuration."""
if fetch_existing_llm_provider(name=provider_name, db_session=db_session):
return
upsert_llm_provider(
LLMProviderUpsertRequest(
name=provider_name,
@@ -146,8 +149,8 @@ def baseline_search_settings(
)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
@patch("onyx.db.swap_index.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_default_document_index")
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
@@ -155,6 +158,7 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
mock_index_handler: MagicMock,
mock_get_llm: MagicMock,
mock_get_doc_index: MagicMock, # noqa: ARG001
mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001
mock_get_all_doc_indices: MagicMock,
baseline_search_settings: None, # noqa: ARG001
db_session: Session,
@@ -196,8 +200,8 @@ def test_indexing_pipeline_uses_contextual_rag_settings_from_create(
)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
@patch("onyx.db.swap_index.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_default_document_index")
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
@@ -205,6 +209,7 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
mock_index_handler: MagicMock,
mock_get_llm: MagicMock,
mock_get_doc_index: MagicMock, # noqa: ARG001
mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001
mock_get_all_doc_indices: MagicMock,
baseline_search_settings: None, # noqa: ARG001
db_session: Session,
@@ -266,7 +271,7 @@ def test_indexing_pipeline_uses_updated_contextual_rag_settings(
)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
@patch("onyx.server.manage.search_settings.get_all_document_indices")
@patch("onyx.server.manage.search_settings.get_default_document_index")
@patch("onyx.indexing.indexing_pipeline.get_llm_for_contextual_rag")
@patch("onyx.indexing.indexing_pipeline.index_doc_batch_with_handler")
@@ -274,6 +279,7 @@ def test_indexing_pipeline_skips_llm_when_contextual_rag_disabled(
mock_index_handler: MagicMock,
mock_get_llm: MagicMock,
mock_get_doc_index: MagicMock, # noqa: ARG001
mock_get_all_doc_indices_search_settings: MagicMock, # noqa: ARG001
baseline_search_settings: None, # noqa: ARG001
db_session: Session,
) -> None:

@@ -114,8 +114,8 @@ def test_create_duplicate_config_fails(
headers=admin_user.headers,
)
assert response.status_code == 409
assert "already exists" in response.json()["message"]
assert response.status_code == 400
assert "already exists" in response.json()["detail"]
def test_get_all_configs(
@@ -292,7 +292,7 @@ def test_update_config_source_provider_not_found(
)
assert response.status_code == 404
assert "not found" in response.json()["message"]
assert "not found" in response.json()["detail"]
def test_delete_config(
@@ -468,7 +468,7 @@ def test_create_config_missing_credentials(
)
assert response.status_code == 400
assert "No provider or source llm provided" in response.json()["message"]
assert "No provider or source llm provided" in response.json()["detail"]
def test_create_config_source_provider_not_found(
@@ -488,4 +488,4 @@ def test_create_config_source_provider_not_found(
)
assert response.status_code == 404
assert "not found" in response.json()["message"]
assert "not found" in response.json()["detail"]
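
Review note: these assertions move from the `"message"` key to the `"detail"` key. If other callers need to read error bodies during a transition between the two shapes, a tolerant accessor is one option. The helper below is hypothetical (`error_text` is not part of this change) and assumes the error body is a flat JSON object with a string under one of the two keys.

```python
from __future__ import annotations

from typing import Any


def error_text(body: dict[str, Any]) -> str:
    """Return the error string from a JSON error body.

    Hypothetical helper: prefers "detail" (the key the updated assertions read)
    and falls back to "message" (the key the previous assertions read).
    """
    for key in ("detail", "message"):
        value = body.get(key)
        if isinstance(value, str):
            return value
    return ""
```

The tests themselves should probably keep asserting on the exact key, as they do here, so that a regression in the response shape is caught.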

@@ -42,6 +42,78 @@ class NightlyProviderConfig(BaseModel):
strict: bool
def _stringify_custom_config_value(value: object) -> str:
if isinstance(value, str):
return value
if isinstance(value, (dict, list)):
return json.dumps(value)
return str(value)
def _looks_like_vertex_credentials_payload(
raw_custom_config: dict[object, object],
) -> bool:
normalized_keys = {str(key).strip().lower() for key in raw_custom_config}
provider_specific_keys = {
"vertex_credentials",
"credentials_file",
"vertex_credentials_file",
"google_application_credentials",
"vertex_location",
"location",
"vertex_region",
"region",
}
if normalized_keys & provider_specific_keys:
return False
normalized_type = str(raw_custom_config.get("type", "")).strip().lower()
if normalized_type not in {"service_account", "external_account"}:
return False
# Service account JSON usually includes private_key/client_email, while external
# account JSON includes credential_source. Either shape should be accepted.
has_service_account_markers = any(
key in normalized_keys for key in {"private_key", "client_email"}
)
has_external_account_markers = "credential_source" in normalized_keys
return has_service_account_markers or has_external_account_markers
def _normalize_custom_config(
provider: str, raw_custom_config: dict[object, object]
) -> dict[str, str]:
if provider == "vertex_ai" and _looks_like_vertex_credentials_payload(
raw_custom_config
):
return {"vertex_credentials": json.dumps(raw_custom_config)}
normalized: dict[str, str] = {}
for raw_key, raw_value in raw_custom_config.items():
key = str(raw_key).strip()
key_lower = key.lower()
if provider == "vertex_ai":
if key_lower in {
"vertex_credentials",
"credentials_file",
"vertex_credentials_file",
"google_application_credentials",
}:
key = "vertex_credentials"
elif key_lower in {
"vertex_location",
"location",
"vertex_region",
"region",
}:
key = "vertex_location"
normalized[key] = _stringify_custom_config_value(raw_value)
return normalized
def _env_true(env_var: str, default: bool = False) -> bool:
value = os.environ.get(env_var)
if value is None:
@@ -80,7 +152,9 @@ def _load_provider_config() -> NightlyProviderConfig:
parsed = json.loads(custom_config_json)
if not isinstance(parsed, dict):
raise ValueError(f"{_ENV_CUSTOM_CONFIG_JSON} must be a JSON object")
custom_config = {str(key): str(value) for key, value in parsed.items()}
custom_config = _normalize_custom_config(
provider=provider, raw_custom_config=parsed
)
if provider == "ollama_chat" and api_key and not custom_config:
custom_config = {"OLLAMA_API_KEY": api_key}
@@ -148,6 +222,23 @@ def _validate_provider_config(config: NightlyProviderConfig) -> None:
),
)
if config.provider == "vertex_ai":
has_vertex_credentials = bool(
config.custom_config and config.custom_config.get("vertex_credentials")
)
if not has_vertex_credentials:
configured_keys = (
sorted(config.custom_config.keys()) if config.custom_config else []
)
_skip_or_fail(
strict=config.strict,
message=(
f"{_ENV_CUSTOM_CONFIG_JSON} must include 'vertex_credentials' "
f"for provider '{config.provider}'. "
f"Found keys: {configured_keys}"
),
)
def _assert_integration_mode_enabled() -> None:
assert (
@@ -193,6 +284,7 @@ def _create_provider_payload(
return {
"name": provider_name,
"provider": provider,
"model": model_name,
"api_key": api_key,
"api_base": api_base,
"api_version": api_version,
@@ -208,24 +300,23 @@ def _create_provider_payload(
}
def _ensure_provider_is_default(provider_id: int, admin_user: DATestUser) -> None:
def _ensure_provider_is_default(
provider_id: int, model_name: str, admin_user: DATestUser
) -> None:
list_response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
)
list_response.raise_for_status()
providers = list_response.json()
current_default = next(
(provider for provider in providers if provider.get("is_default_provider")),
None,
default_text = list_response.json().get("default_text")
assert default_text is not None, "Expected a default provider after setting default"
assert default_text.get("provider_id") == provider_id, (
f"Expected provider {provider_id} to be default, "
f"found {default_text.get('provider_id')}"
)
assert (
current_default is not None
), "Expected a default provider after setting provider as default"
assert (
current_default["id"] == provider_id
), f"Expected provider {provider_id} to be default, found {current_default['id']}"
default_text.get("model_name") == model_name
), f"Expected default model {model_name}, found {default_text.get('model_name')}"
def _run_chat_assertions(
@@ -326,8 +417,9 @@ def _create_and_test_provider_for_model(
try:
set_default_response = requests.post(
f"{API_SERVER_URL}/admin/llm/provider/{provider_id}/default",
f"{API_SERVER_URL}/admin/llm/default",
headers=admin_user.headers,
json={"provider_id": provider_id, "model_name": model_name},
)
assert set_default_response.status_code == 200, (
f"Setting default provider failed for provider={config.provider} "
@@ -335,7 +427,9 @@ def _create_and_test_provider_for_model(
f"{set_default_response.text}"
)
_ensure_provider_is_default(provider_id=provider_id, admin_user=admin_user)
_ensure_provider_is_default(
provider_id=provider_id, model_name=model_name, admin_user=admin_user
)
_run_chat_assertions(
admin_user=admin_user,
search_tool_id=search_tool_id,
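
Review note: the two `if key_lower in {...}` branches in `_normalize_custom_config` could also be expressed as a single alias table. The sketch below is an alternative formulation for discussion, not the code in this change: `canonicalize_vertex_keys` and `_VERTEX_KEY_ALIASES` are names invented here, and the credentials-payload detection (`_looks_like_vertex_credentials_payload`) is deliberately left out. The alias sets and the stringification rules are copied from the hunk above.

```python
from __future__ import annotations

import json

# Alias table built from the key sets in _normalize_custom_config.
_VERTEX_KEY_ALIASES = {
    "vertex_credentials": "vertex_credentials",
    "credentials_file": "vertex_credentials",
    "vertex_credentials_file": "vertex_credentials",
    "google_application_credentials": "vertex_credentials",
    "vertex_location": "vertex_location",
    "location": "vertex_location",
    "vertex_region": "vertex_location",
    "region": "vertex_location",
}


def canonicalize_vertex_keys(raw: dict[object, object]) -> dict[str, str]:
    """Rename known Vertex AI aliases and stringify every value."""
    normalized: dict[str, str] = {}
    for raw_key, raw_value in raw.items():
        key = str(raw_key).strip()
        # Unknown keys pass through unchanged (apart from whitespace trimming).
        key = _VERTEX_KEY_ALIASES.get(key.lower(), key)
        if isinstance(raw_value, str):
            value = raw_value
        elif isinstance(raw_value, (dict, list)):
            value = json.dumps(raw_value)
        else:
            value = str(raw_value)
        normalized[key] = value
    return normalized
```

As in the original loop, if two aliases of the same canonical key are both present, the later one in iteration order wins.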

@@ -1,4 +1,3 @@
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
@@ -300,7 +299,7 @@ def test_update_contextual_rag_nonexistent_provider(
headers=admin_user.headers,
)
assert response.status_code == 400
assert "Provider nonexistent-provider not found" in response.json()["message"]
assert "Provider nonexistent-provider not found" in response.json()["detail"]
def test_update_contextual_rag_nonexistent_model(
@@ -322,7 +321,7 @@ def test_update_contextual_rag_nonexistent_model(
assert response.status_code == 400
assert (
f"Model nonexistent-model not found in provider {llm_provider.name}"
in response.json()["message"]
in response.json()["detail"]
)
@@ -342,7 +341,7 @@ def test_update_contextual_rag_missing_provider_name(
headers=admin_user.headers,
)
assert response.status_code == 400
assert "Provider name and model name are required" in response.json()["message"]
assert "Provider name and model name are required" in response.json()["detail"]
def test_update_contextual_rag_missing_model_name(
@@ -362,10 +361,9 @@ def test_update_contextual_rag_missing_model_name(
headers=admin_user.headers,
)
assert response.status_code == 400
assert "Provider name and model name are required" in response.json()["message"]
assert "Provider name and model name are required" in response.json()["detail"]
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
def test_set_new_search_settings_with_contextual_rag(
reset: None, # noqa: ARG001
admin_user: DATestUser,
@@ -394,7 +392,6 @@ def test_set_new_search_settings_with_contextual_rag(
_cancel_new_embedding(admin_user)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
def test_set_new_search_settings_without_contextual_rag(
reset: None, # noqa: ARG001
admin_user: DATestUser,
@@ -419,7 +416,6 @@ def test_set_new_search_settings_without_contextual_rag(
_cancel_new_embedding(admin_user)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
def test_set_new_then_update_inference_settings(
reset: None, # noqa: ARG001
admin_user: DATestUser,
@@ -457,7 +453,6 @@ def test_set_new_then_update_inference_settings(
_cancel_new_embedding(admin_user)
@pytest.mark.skip(reason="Set new search settings is temporarily disabled.")
def test_set_new_search_settings_replaces_previous_secondary(
reset: None, # noqa: ARG001
admin_user: DATestUser,

@@ -281,9 +281,10 @@ class TestApplyLicenseStatusToSettings:
}
class TestSettingsDefaultEEDisabled:
"""Verify the Settings model defaults ee_features_enabled to False."""
class TestSettingsDefaults:
"""Verify Settings model defaults for CE deployments."""
def test_default_ee_features_disabled(self) -> None:
"""CE default: ee_features_enabled is False."""
settings = Settings()
assert settings.ee_features_enabled is False

@@ -104,3 +104,102 @@ def test_format_slack_message_ampersand_not_double_escaped() -> None:
assert "&amp;" in formatted
assert "&quot;" not in formatted
# -- Table rendering tests --
def test_table_renders_as_vertical_cards() -> None:
message = (
"| Feature | Status | Owner |\n"
"|---------|--------|-------|\n"
"| Auth | Done | Alice |\n"
"| Search | In Progress | Bob |\n"
)
formatted = format_slack_message(message)
assert "*Auth*\n • Status: Done\n • Owner: Alice" in formatted
assert "*Search*\n • Status: In Progress\n • Owner: Bob" in formatted
# Cards separated by blank line
assert "Owner: Alice\n\n*Search*" in formatted
# No raw pipe-and-dash table syntax
assert "---|" not in formatted
def test_table_single_column() -> None:
message = "| Name |\n|------|\n| Alice |\n| Bob |\n"
formatted = format_slack_message(message)
assert "*Alice*" in formatted
assert "*Bob*" in formatted
def test_table_embedded_in_text() -> None:
message = (
"Here are the results:\n\n"
"| Item | Count |\n"
"|------|-------|\n"
"| Apples | 5 |\n"
"\n"
"That's all."
)
formatted = format_slack_message(message)
assert "Here are the results:" in formatted
assert "*Apples*\n • Count: 5" in formatted
assert "That's all." in formatted
def test_table_with_formatted_cells() -> None:
message = (
"| Name | Link |\n"
"|------|------|\n"
"| **Alice** | [profile](https://example.com) |\n"
)
formatted = format_slack_message(message)
# Bold cell should not double-wrap: *Alice* not **Alice**
assert "*Alice*" in formatted
assert "**Alice**" not in formatted
assert "<https://example.com|profile>" in formatted
def test_table_with_alignment_specifiers() -> None:
message = (
"| Left | Center | Right |\n" "|:-----|:------:|------:|\n" "| a | b | c |\n"
)
formatted = format_slack_message(message)
assert "*a*\n • Center: b\n • Right: c" in formatted
def test_two_tables_in_same_message_use_independent_headers() -> None:
message = (
"| A | B |\n"
"|---|---|\n"
"| 1 | 2 |\n"
"\n"
"| X | Y | Z |\n"
"|---|---|---|\n"
"| p | q | r |\n"
)
formatted = format_slack_message(message)
assert "*1*\n • B: 2" in formatted
assert "*p*\n • Y: q\n • Z: r" in formatted
def test_table_empty_first_column_no_bare_asterisks() -> None:
message = "| Name | Status |\n" "|------|--------|\n" "| | Done |\n"
formatted = format_slack_message(message)
# Empty title should not produce "**" (bare asterisks)
assert "**" not in formatted
assert " • Status: Done" in formatted
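
Review note: for readers who want to see the "vertical card" shape these tests expect without reading `format_slack_message`, here is a minimal standalone sketch. It is not the implementation under test: `table_to_cards` is a hypothetical name, it handles only a single simple pipe table (no escaped pipes, no inline Markdown in cells), and it copies the bullet layout from the expected strings in the assertions above.

```python
from __future__ import annotations


def table_to_cards(table: str) -> str:
    """Render a simple Markdown pipe table as one text card per body row."""
    rows: list[list[str]] = []
    for line in table.strip().splitlines():
        # Drop the outer pipes, then split the remaining text into cells.
        cells = [cell.strip() for cell in line.strip().strip("|").split("|")]
        rows.append(cells)
    if len(rows) < 2:
        return table  # Not a header + separator table; leave it alone.

    headers, body = rows[0], rows[2:]  # rows[1] is the |---| separator row
    cards: list[str] = []
    for cells in body:
        lines: list[str] = []
        title = cells[0] if cells else ""
        if title:
            # Skip empty titles so the output never contains bare "**".
            lines.append(f"*{title}*")
        for header, value in zip(headers[1:], cells[1:]):
            lines.append(f" • {header}: {value}")
        cards.append("\n".join(lines))
    return "\n\n".join(cards)
```

The first column becomes the bold card title and the remaining columns become `Header: value` bullets, which is the layout asserted in `test_table_renders_as_vertical_cards`.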

@@ -87,7 +87,8 @@ def test_python_tool_available_when_health_check_passes(
mock_client = MagicMock()
mock_client.health.return_value = True
mock_client_cls.return_value = mock_client
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
db_session = MagicMock(spec=Session)
assert PythonTool.is_available(db_session) is True
@@ -109,7 +110,8 @@ def test_python_tool_unavailable_when_health_check_fails(
mock_client = MagicMock()
mock_client.health.return_value = False
mock_client_cls.return_value = mock_client
mock_client_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
mock_client_cls.return_value.__exit__ = MagicMock(return_value=False)
db_session = MagicMock(spec=Session)
assert PythonTool.is_available(db_session) is False
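
Review note: the switch from `return_value = mock_client` to configuring `__enter__`/`__exit__` matters when the code under test uses the client as a context manager. The sketch below shows the difference with `unittest.mock` only. `check_health` is a hypothetical stand-in, since the body of `PythonTool.is_available` is not part of this diff; the assumption is that it does something like `with Client() as client: client.health()`.

```python
from unittest.mock import MagicMock


def check_health(client_cls):
    """Hypothetical stand-in: use the client as a context manager and ask for health."""
    with client_cls() as client:
        return client.health()


mock_client = MagicMock()
mock_client.health.return_value = False

# Old wiring: the class returns mock_client, but `with` binds whatever
# mock_client.__enter__() returns, which is an auto-created MagicMock.
old_cls = MagicMock()
old_cls.return_value = mock_client

# New wiring (as in the diff): __enter__ hands back the configured client.
new_cls = MagicMock()
new_cls.return_value.__enter__ = MagicMock(return_value=mock_client)
new_cls.return_value.__exit__ = MagicMock(return_value=False)

old_result = check_health(old_cls)  # a MagicMock, which is truthy
new_result = check_health(new_cls)  # the configured False
```

Under the old wiring the failing health check would never be observed, because the value seen inside the `with` block is a fresh, truthy `MagicMock` rather than the configured `False`.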

@@ -138,7 +138,6 @@ services:
- indexing_model_server
restart: unless-stopped
environment:
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
- MULTI_TENANT=true
- LOG_LEVEL=DEBUG

@@ -52,7 +52,6 @@ services:
- indexing_model_server
restart: unless-stopped
environment:
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
- AUTH_TYPE=${AUTH_TYPE:-oidc}
- POSTGRES_HOST=relational_db
- VESPA_HOST=index

@@ -65,7 +65,6 @@ services:
- indexing_model_server
restart: unless-stopped
environment:
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
- AUTH_TYPE=${AUTH_TYPE:-oidc}
- POSTGRES_HOST=relational_db
- VESPA_HOST=index

@@ -70,7 +70,6 @@ services:
- indexing_model_server
restart: unless-stopped
environment:
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
- AUTH_TYPE=${AUTH_TYPE:-oidc}
- POSTGRES_HOST=relational_db
- VESPA_HOST=index

@@ -58,7 +58,6 @@ services:
env_file:
- .env_eval
environment:
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
- AUTH_TYPE=disabled
- POSTGRES_HOST=relational_db
- VESPA_HOST=index

@@ -146,7 +146,6 @@ services:
- indexing_model_server
restart: unless-stopped
environment:
- USE_LIGHTWEIGHT_BACKGROUND_WORKER=${USE_LIGHTWEIGHT_BACKGROUND_WORKER:-true}
- FILE_STORE_BACKEND=${FILE_STORE_BACKEND:-s3}
- POSTGRES_HOST=${POSTGRES_HOST:-relational_db}
- VESPA_HOST=${VESPA_HOST:-index}

@@ -14,30 +14,32 @@ Built with [Tauri](https://tauri.app) for minimal bundle size (~10MB vs Electron
## Keyboard Shortcuts
| Shortcut | Action |
|----------|--------|
| `⌘ N` | New Chat |
| `⌘ ⇧ N` | New Window |
| `⌘ R` | Reload |
| `⌘ [` | Go Back |
| `⌘ ]` | Go Forward |
| `⌘ ,` | Open Config File |
| `⌘ W` | Close Window |
| `⌘ Q` | Quit |
| Shortcut | Action |
| -------- | ---------------- |
| `⌘ N` | New Chat |
| `⌘ ⇧ N` | New Window |
| `⌘ R` | Reload |
| `⌘ [` | Go Back |
| `⌘ ]` | Go Forward |
| `⌘ ,` | Open Config File |
| `⌘ W` | Close Window |
| `⌘ Q` | Quit |
## Prerequisites
1. **Rust** (latest stable)
```bash
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
source $HOME/.cargo/env
```
2. **Node.js** (18+)
```bash
# Using homebrew
brew install node
# Or using nvm
nvm install 18
```
@@ -55,16 +57,21 @@ npm install
# Run in development mode
npm run dev
# Run in debug mode
npm run debug
```
## Building
### Build for current architecture
```bash
npm run build
```
### Build Universal Binary (Intel + Apple Silicon)
```bash
# First, add the targets
rustup target add x86_64-apple-darwin
@@ -103,6 +110,7 @@ Before building, add your app icons to `src-tauri/icons/`:
- `icon.ico` (Windows, optional)
You can generate these from a 1024x1024 source image using:
```bash
# Using tauri's icon generator
npm run tauri icon path/to/your-icon.png
@@ -115,6 +123,7 @@ npm run tauri icon path/to/your-icon.png
The app defaults to `https://cloud.onyx.app` but supports any Onyx instance.
**Config file location:**
- macOS: `~/Library/Application Support/app.onyx.desktop/config.json`
- Linux: `~/.config/app.onyx.desktop/config.json`
- Windows: `%APPDATA%/app.onyx.desktop/config.json`
@@ -135,6 +144,7 @@ The app defaults to `https://cloud.onyx.app` but supports any Onyx instance.
4. Restart the app
**Quick edit via terminal:**
```bash
# macOS
open -t ~/Library/Application\ Support/app.onyx.desktop/config.json
@@ -146,6 +156,7 @@ code ~/Library/Application\ Support/app.onyx.desktop/config.json
### Change the default URL in build
Edit `src-tauri/tauri.conf.json`:
```json
{
"app": {
@@ -165,6 +176,7 @@ Edit `src-tauri/src/main.rs` in the `setup_shortcuts` function.
### Window appearance
Modify the window configuration in `src-tauri/tauri.conf.json`:
- `titleBarStyle`: `"Overlay"` (macOS native) or `"Visible"`
- `decorations`: Window chrome
- `transparent`: For custom backgrounds
@@ -172,16 +184,20 @@ Modify the window configuration in `src-tauri/tauri.conf.json`:
## Troubleshooting
### "Unable to resolve host"
Make sure you have an internet connection. The app loads content from `cloud.onyx.app`.
### Build fails on M1/M2 Mac
```bash
# Ensure you have the right target
rustup target add aarch64-apple-darwin
```
### Code signing for distribution
For distributing outside the App Store, you'll need to:
1. Get an Apple Developer certificate
2. Sign the app: `codesign --deep --force --sign "Developer ID" target/release/bundle/macos/Onyx.app`
3. Notarize with Apple

@@ -4,6 +4,7 @@
"description": "Lightweight desktop app for Onyx Cloud",
"scripts": {
"dev": "tauri dev",
"debug": "tauri dev -- -- --debug",
"build": "tauri build",
"build:dmg": "tauri build --target universal-apple-darwin",
"build:linux": "tauri build --bundles deb,rpm"

@@ -23,3 +23,4 @@ url = "2.5"
[features]
default = ["custom-protocol"]
custom-protocol = ["tauri/custom-protocol"]
devtools = ["tauri/devtools"]

@@ -6,7 +6,9 @@ use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
use std::process::Command;
use std::sync::RwLock;
use std::sync::{Mutex, RwLock};
use std::io::Write as IoWrite;
use std::time::SystemTime;
#[cfg(target_os = "macos")]
use std::time::Duration;
use tauri::image::Image;
@@ -230,6 +232,63 @@ const MENU_KEY_HANDLER_SCRIPT: &str = r#"
})();
"#;
const CONSOLE_CAPTURE_SCRIPT: &str = r#"
(() => {
if (window.__ONYX_CONSOLE_CAPTURE__) return;
window.__ONYX_CONSOLE_CAPTURE__ = true;
const levels = ['log', 'warn', 'error', 'info', 'debug'];
const originals = {};
levels.forEach(level => {
originals[level] = console[level];
console[level] = function(...args) {
originals[level].apply(console, args);
try {
const invoke =
window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke;
if (typeof invoke === 'function') {
const message = args.map(a => {
try { return typeof a === 'string' ? a : JSON.stringify(a); }
catch { return String(a); }
}).join(' ');
invoke('log_from_frontend', { level, message });
}
} catch {}
};
});
window.addEventListener('error', (event) => {
try {
const invoke =
window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke;
if (typeof invoke === 'function') {
invoke('log_from_frontend', {
level: 'error',
message: `[uncaught] ${event.message} at ${event.filename}:${event.lineno}:${event.colno}`
});
}
} catch {}
});
window.addEventListener('unhandledrejection', (event) => {
try {
const invoke =
window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke;
if (typeof invoke === 'function') {
invoke('log_from_frontend', {
level: 'error',
message: `[unhandled rejection] ${event.reason}`
});
}
} catch {}
});
})();
"#;
const MENU_TOGGLE_DEVTOOLS_ID: &str = "toggle_devtools";
const MENU_OPEN_DEBUG_LOG_ID: &str = "open_debug_log";
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppConfig {
pub server_url: String,
@@ -311,12 +370,87 @@ fn save_config(config: &AppConfig) -> Result<(), String> {
Ok(())
}
// ============================================================================
// Debug Mode
// ============================================================================
fn is_debug_mode() -> bool {
std::env::args().any(|arg| arg == "--debug") || std::env::var("ONYX_DEBUG").is_ok()
}
fn get_debug_log_path() -> Option<PathBuf> {
get_config_dir().map(|dir| dir.join("frontend_debug.log"))
}
fn init_debug_log_file() -> Option<fs::File> {
let log_path = get_debug_log_path()?;
if let Some(parent) = log_path.parent() {
let _ = fs::create_dir_all(parent);
}
fs::OpenOptions::new()
.create(true)
.append(true)
.open(&log_path)
.ok()
}
fn format_utc_timestamp() -> String {
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default();
let total_secs = now.as_secs();
let millis = now.subsec_millis();
let days = total_secs / 86400;
let secs_of_day = total_secs % 86400;
let hours = secs_of_day / 3600;
let mins = (secs_of_day % 3600) / 60;
let secs = secs_of_day % 60;
// Days since Unix epoch -> Y/M/D via civil calendar arithmetic
let z = days as i64 + 719468;
let era = z / 146097;
let doe = z - era * 146097;
let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365;
let y = yoe + era * 400;
let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
let mp = (5 * doy + 2) / 153;
let d = doy - (153 * mp + 2) / 5 + 1;
let m = if mp < 10 { mp + 3 } else { mp - 9 };
let y = if m <= 2 { y + 1 } else { y };
format!(
"{:04}-{:02}-{:02}T{:02}:{:02}:{:02}.{:03}Z",
y, m, d, hours, mins, secs, millis
)
}
fn inject_console_capture(webview: &Webview) {
let _ = webview.eval(CONSOLE_CAPTURE_SCRIPT);
}
fn maybe_open_devtools(app: &AppHandle, window: &tauri::WebviewWindow) {
#[cfg(any(debug_assertions, feature = "devtools"))]
{
let state = app.state::<ConfigState>();
if state.debug_mode {
window.open_devtools();
}
}
#[cfg(not(any(debug_assertions, feature = "devtools")))]
{
let _ = (app, window);
}
}
// Global config state
struct ConfigState {
config: RwLock<AppConfig>,
config_initialized: RwLock<bool>,
app_base_url: RwLock<Option<Url>>,
menu_temporarily_visible: RwLock<bool>,
debug_mode: bool,
debug_log_file: Mutex<Option<fs::File>>,
}
fn focus_main_window(app: &AppHandle) {
@@ -372,6 +506,7 @@ fn trigger_new_window(app: &AppHandle) {
}
apply_settings_to_window(&handle, &window);
maybe_open_devtools(&handle, &window);
let _ = window.set_focus();
}
});
@@ -467,10 +602,65 @@ fn inject_chat_link_intercept(webview: &Webview) {
let _ = webview.eval(CHAT_LINK_INTERCEPT_SCRIPT);
}
fn handle_toggle_devtools(app: &AppHandle) {
#[cfg(any(debug_assertions, feature = "devtools"))]
{
let windows: Vec<_> = app.webview_windows().into_values().collect();
let any_open = windows.iter().any(|w| w.is_devtools_open());
for window in &windows {
if any_open {
window.close_devtools();
} else {
window.open_devtools();
}
}
}
#[cfg(not(any(debug_assertions, feature = "devtools")))]
{
let _ = app;
}
}
fn handle_open_debug_log() {
let log_path = match get_debug_log_path() {
Some(p) => p,
None => return,
};
if !log_path.exists() {
eprintln!("[ONYX DEBUG] Log file does not exist yet: {:?}", log_path);
return;
}
let url_path = log_path.to_string_lossy().replace('\\', "/");
let _ = open_in_default_browser(&format!(
"file:///{}",
url_path.trim_start_matches('/')
));
}
// ============================================================================
// Tauri Commands
// ============================================================================
#[tauri::command]
fn log_from_frontend(level: String, message: String, state: tauri::State<ConfigState>) {
if !state.debug_mode {
return;
}
let timestamp = format_utc_timestamp();
let log_line = format!("[{}] [{}] {}", timestamp, level.to_uppercase(), message);
eprintln!("{}", log_line);
if let Ok(mut guard) = state.debug_log_file.lock() {
if let Some(ref mut file) = *guard {
let _ = writeln!(file, "{}", log_line);
let _ = file.flush();
}
}
}
/// Get the current server URL
#[tauri::command]
fn get_server_url(state: tauri::State<ConfigState>) -> String {
@@ -657,6 +847,7 @@ async fn new_window(app: AppHandle, state: tauri::State<'_, ConfigState>) -> Res
}
apply_settings_to_window(&app, &window);
maybe_open_devtools(&app, &window);
Ok(())
}
@@ -936,6 +1127,30 @@ fn setup_app_menu(app: &AppHandle) -> tauri::Result<()> {
menu.append(&help_menu)?;
}
let state = app.state::<ConfigState>();
if state.debug_mode {
let toggle_devtools_item = MenuItem::with_id(
app,
MENU_TOGGLE_DEVTOOLS_ID,
"Toggle DevTools",
true,
Some("F12"),
)?;
let open_log_item = MenuItem::with_id(
app,
MENU_OPEN_DEBUG_LOG_ID,
"Open Debug Log",
true,
None::<&str>,
)?;
let debug_menu = SubmenuBuilder::new(app, "Debug")
.item(&toggle_devtools_item)
.item(&open_log_item)
.build()?;
menu.append(&debug_menu)?;
}
app.set_menu(menu)?;
Ok(())
}
@@ -1027,8 +1242,20 @@ fn setup_tray_icon(app: &AppHandle) -> tauri::Result<()> {
// ============================================================================
fn main() {
// Load config at startup
let (config, config_initialized) = load_config();
let debug_mode = is_debug_mode();
let debug_log_file = if debug_mode {
eprintln!("[ONYX DEBUG] Debug mode enabled");
if let Some(path) = get_debug_log_path() {
eprintln!("[ONYX DEBUG] Frontend logs: {}", path.display());
}
eprintln!("[ONYX DEBUG] DevTools will open automatically");
eprintln!("[ONYX DEBUG] Capturing console.log/warn/error/info/debug from webview");
init_debug_log_file()
} else {
None
};
tauri::Builder::default()
.plugin(tauri_plugin_shell::init())
@@ -1059,6 +1286,8 @@ fn main() {
config_initialized: RwLock::new(config_initialized),
app_base_url: RwLock::new(None),
menu_temporarily_visible: RwLock::new(false),
debug_mode,
debug_log_file: Mutex::new(debug_log_file),
})
.invoke_handler(tauri::generate_handler![
get_server_url,
@@ -1077,7 +1306,8 @@ fn main() {
start_drag_window,
toggle_menu_bar,
show_menu_bar_temporarily,
-hide_menu_bar_temporary
+hide_menu_bar_temporary,
+log_from_frontend
])
.on_menu_event(|app, event| match event.id().as_ref() {
"open_docs" => open_docs(),
@@ -1086,6 +1316,8 @@ fn main() {
"open_settings" => open_settings(app),
"show_menu_bar" => handle_menu_bar_toggle(app),
"hide_window_decorations" => handle_decorations_toggle(app),
MENU_TOGGLE_DEVTOOLS_ID => handle_toggle_devtools(app),
MENU_OPEN_DEBUG_LOG_ID => handle_open_debug_log(),
_ => {}
})
.setup(move |app| {
@@ -1119,6 +1351,7 @@ fn main() {
inject_titlebar(window.clone());
apply_settings_to_window(&app_handle, &window);
maybe_open_devtools(&app_handle, &window);
let _ = window.set_focus();
}
@@ -1128,6 +1361,14 @@ fn main() {
.on_page_load(|webview: &Webview, _payload: &PageLoadPayload| {
inject_chat_link_intercept(webview);
{
let app = webview.app_handle();
let state = app.state::<ConfigState>();
if state.debug_mode {
inject_console_capture(webview);
}
}
#[cfg(not(target_os = "macos"))]
{
let _ = webview.eval(MENU_KEY_HANDLER_SCRIPT);
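
Review note: the calendar arithmetic in `format_utc_timestamp` above is easier to check when pulled out into a pure function. The sketch below is not part of the change; `civil_from_days` and `format_utc` are hypothetical names introduced only for illustration, and the body mirrors the arithmetic in the diff (days since the Unix epoch converted to year/month/day), with Euclidean division used so that negative day counts would also work.

```rust
/// Convert days since 1970-01-01 into (year, month, day) in the
/// proleptic Gregorian calendar. Hypothetical helper for illustration;
/// it follows the same steps as `format_utc_timestamp` in the diff.
fn civil_from_days(days: i64) -> (i64, i64, i64) {
    // Shift the epoch to 0000-03-01 so leap days fall at the end of a "year".
    let z = days + 719_468;
    let era = z.div_euclid(146_097); // 400-year cycles
    let doe = z.rem_euclid(146_097); // day of era, 0..=146_096
    let yoe = (doe - doe / 1460 + doe / 36_524 - doe / 146_096) / 365; // year of era
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year
    let mp = (5 * doy + 2) / 153; // March-based month, 0..=11
    let d = doy - (153 * mp + 2) / 5 + 1;
    let m = if mp < 10 { mp + 3 } else { mp - 9 };
    let y = yoe + era * 400;
    // January and February belong to the next civil year.
    (if m <= 2 { y + 1 } else { y }, m, d)
}

/// Format seconds and milliseconds since the Unix epoch as an
/// ISO 8601 UTC timestamp with millisecond precision.
fn format_utc(total_secs: u64, millis: u32) -> String {
    let (y, m, d) = civil_from_days((total_secs / 86_400) as i64);
    let secs_of_day = total_secs % 86_400;
    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}.{:03}Z",
        y,
        m,
        d,
        secs_of_day / 3600,
        (secs_of_day % 3600) / 60,
        secs_of_day % 60,
        millis
    )
}
```

With the logic split this way, leap-day boundaries such as 2000-02-29 can be unit-tested without depending on `SystemTime::now()`.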

uv.lock generated

@@ -4106,7 +4106,7 @@ wheels = [
[[package]]
name = "nltk"
-version = "3.9.1"
+version = "3.9.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
@@ -4114,9 +4114,9 @@ dependencies = [
{ name = "regex" },
{ name = "tqdm" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/3c/87/db8be88ad32c2d042420b6fd9ffd4a149f9a0d7f0e86b3f543be2eeeedd2/nltk-3.9.1.tar.gz", hash = "sha256:87d127bd3de4bd89a4f81265e5fa59cb1b199b27440175370f7417d2bc7ae868", size = 2904691, upload-time = "2024-08-18T19:48:37.769Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/e1/8f/915e1c12df07c70ed779d18ab83d065718a926e70d3ea33eb0cd66ffb7c0/nltk-3.9.3.tar.gz", hash = "sha256:cb5945d6424a98d694c2b9a0264519fab4363711065a46aa0ae7a2195b92e71f", size = 2923673, upload-time = "2026-02-24T12:05:53.833Z" }
wheels = [
-{ url = "https://files.pythonhosted.org/packages/4d/66/7d9e26593edda06e8cb531874633f7c2372279c3b0f46235539fe546df8b/nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1", size = 1505442, upload-time = "2024-08-18T19:48:21.909Z" },
+{ url = "https://files.pythonhosted.org/packages/c2/7e/9af5a710a1236e4772de8dfcc6af942a561327bb9f42b5b4a24d0cf100fd/nltk-3.9.3-py3-none-any.whl", hash = "sha256:60b3db6e9995b3dd976b1f0fa7dec22069b2677e759c28eb69b62ddd44870522", size = 1525385, upload-time = "2026-02-24T12:05:46.54Z" },
]
[[package]]


@@ -0,0 +1,20 @@
import type { IconProps } from "@opal/types";
const SvgAudio = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 32 32"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M4 20V12M10 28V4M22 22V10M28 18V14M16 20V12"
strokeWidth={2.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgAudio;


@@ -0,0 +1,22 @@
import type { IconProps } from "@opal/types";
const SvgColumn = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M6 14H3.33333C2.59695 14 2 13.403 2 12.6667V3.33333C2 2.59695 2.59695 2 3.33333 2H6M6 14V2M6 14H10M6 2H10M10 2H12.6667C13.403 2 14 2.59695 14 3.33333V12.6667C14 13.403 13.403 14 12.6667 14H10M10 2V14"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgColumn;


@@ -0,0 +1,20 @@
import type { IconProps } from "@opal/types";
const SvgHandle = ({ size = 16, ...props }: IconProps) => (
<svg
width={Math.round((size * 3) / 17)}
height={size}
viewBox="0 0 3 17"
fill="none"
xmlns="http://www.w3.org/2000/svg"
{...props}
>
<path
d="M0.5 0.5V16.5M2.5 0.5V16.5"
stroke="currentColor"
strokeLinecap="round"
/>
</svg>
);
export default SvgHandle;


@@ -17,6 +17,7 @@ export { default as SvgArrowUpDown } from "@opal/icons/arrow-up-down";
export { default as SvgArrowUpDot } from "@opal/icons/arrow-up-dot";
export { default as SvgArrowUpRight } from "@opal/icons/arrow-up-right";
export { default as SvgArrowWallRight } from "@opal/icons/arrow-wall-right";
export { default as SvgAudio } from "@opal/icons/audio";
export { default as SvgAudioEqSmall } from "@opal/icons/audio-eq-small";
export { default as SvgAws } from "@opal/icons/aws";
export { default as SvgAzure } from "@opal/icons/azure";
@@ -49,6 +50,7 @@ export { default as SvgClock } from "@opal/icons/clock";
export { default as SvgClockHandsSmall } from "@opal/icons/clock-hands-small";
export { default as SvgCloud } from "@opal/icons/cloud";
export { default as SvgCode } from "@opal/icons/code";
export { default as SvgColumn } from "@opal/icons/column";
export { default as SvgCopy } from "@opal/icons/copy";
export { default as SvgCornerRightUpDot } from "@opal/icons/corner-right-up-dot";
export { default as SvgCpu } from "@opal/icons/cpu";
@@ -79,6 +81,7 @@ export { default as SvgFolderPartialOpen } from "@opal/icons/folder-partial-open
export { default as SvgFolderPlus } from "@opal/icons/folder-plus";
export { default as SvgGemini } from "@opal/icons/gemini";
export { default as SvgGlobe } from "@opal/icons/globe";
export { default as SvgHandle } from "@opal/icons/handle";
export { default as SvgHardDrive } from "@opal/icons/hard-drive";
export { default as SvgHashSmall } from "@opal/icons/hash-small";
export { default as SvgHash } from "@opal/icons/hash";
@@ -103,6 +106,8 @@ export { default as SvgLogOut } from "@opal/icons/log-out";
export { default as SvgMaximize2 } from "@opal/icons/maximize-2";
export { default as SvgMcp } from "@opal/icons/mcp";
export { default as SvgMenu } from "@opal/icons/menu";
export { default as SvgMicrophone } from "@opal/icons/microphone";
export { default as SvgMicrophoneOff } from "@opal/icons/microphone-off";
export { default as SvgMinus } from "@opal/icons/minus";
export { default as SvgMinusCircle } from "@opal/icons/minus-circle";
export { default as SvgMoon } from "@opal/icons/moon";
@@ -146,6 +151,8 @@ export { default as SvgSlack } from "@opal/icons/slack";
export { default as SvgSlash } from "@opal/icons/slash";
export { default as SvgSliders } from "@opal/icons/sliders";
export { default as SvgSlidersSmall } from "@opal/icons/sliders-small";
export { default as SvgSort } from "@opal/icons/sort";
export { default as SvgSortOrder } from "@opal/icons/sort-order";
export { default as SvgSparkle } from "@opal/icons/sparkle";
export { default as SvgStar } from "@opal/icons/star";
export { default as SvgStep1 } from "@opal/icons/step1";
@@ -169,7 +176,10 @@ export { default as SvgUploadCloud } from "@opal/icons/upload-cloud";
export { default as SvgUser } from "@opal/icons/user";
export { default as SvgUserManage } from "@opal/icons/user-manage";
export { default as SvgUserPlus } from "@opal/icons/user-plus";
export { default as SvgUserSync } from "@opal/icons/user-sync";
export { default as SvgUsers } from "@opal/icons/users";
export { default as SvgVolume } from "@opal/icons/volume";
export { default as SvgVolumeOff } from "@opal/icons/volume-off";
export { default as SvgWallet } from "@opal/icons/wallet";
export { default as SvgWorkflow } from "@opal/icons/workflow";
export { default as SvgX } from "@opal/icons/x";


@@ -0,0 +1,29 @@
import type { IconProps } from "@opal/types";
const SvgMicrophoneOff = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
{/* Microphone body */}
<path
d="M12.5 7V7.5C12.5 9.98528 10.4853 12 8 12M3.5 7V7.5C3.5 9.98528 5.51472 12 8 12M8 12V14.5M8 14.5H5M8 14.5H11M8 9.5C6.89543 9.5 6 8.60457 6 7.5V3.5C6 2.39543 6.89543 1.5 8 1.5C9.10457 1.5 10 2.39543 10 3.5V7.5C10 8.60457 9.10457 9.5 8 9.5Z"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
{/* Diagonal slash */}
<path
d="M2 2L14 14"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgMicrophoneOff;


@@ -0,0 +1,21 @@
import type { IconProps } from "@opal/types";
const SvgMicrophone = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M12.5 7V7.5C12.5 9.98528 10.4853 12 8 12M3.5 7V7.5C3.5 9.98528 5.51472 12 8 12M8 12V14.5M8 14.5H5M8 14.5H11M8 9.5C6.89543 9.5 6 8.60457 6 7.5V3.5C6 2.39543 6.89543 1.5 8 1.5C9.10457 1.5 10 2.39543 10 3.5V7.5C10 8.60457 9.10457 9.5 8 9.5Z"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgMicrophone;


@@ -0,0 +1,21 @@
import type { IconProps } from "@opal/types";
const SvgSortOrder = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M2.66675 12L7.67009 12.0001M2.66675 8H10.5001M2.66675 4H13.3334"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgSortOrder;


@@ -0,0 +1,27 @@
import type { IconProps } from "@opal/types";
const SvgSort = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M2 4.5H10M2 8H7M2 11.5H5"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M12 5V12M12 12L14 10M12 12L10 10"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgSort;


@@ -0,0 +1,22 @@
import type { IconProps } from "@opal/types";
const SvgUserSync = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M1 14C1 13.6667 1 13.3333 1 13C1 11.3431 2.34316 10 4.00002 10H7M11 8.5L9.5 10L14.5 9.99985M13 14L14.5 12.5L9.5 12.5M8.75 4.75C8.75 6.26878 7.51878 7.5 6 7.5C4.48122 7.5 3.25 6.26878 3.25 4.75C3.25 3.23122 4.48122 2 6 2C7.51878 2 8.75 3.23122 8.75 4.75Z"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgUserSync;


@@ -0,0 +1,26 @@
import type { IconProps } from "@opal/types";
const SvgVolumeOff = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M2 6V10H5L9 13V3L5 6H2Z"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M14 6L11 9M11 6L14 9"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgVolumeOff;


@@ -0,0 +1,26 @@
import type { IconProps } from "@opal/types";
const SvgVolume = ({ size, ...props }: IconProps) => (
<svg
width={size}
height={size}
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
stroke="currentColor"
{...props}
>
<path
d="M2 6V10H5L9 13V3L5 6H2Z"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M11.5 5.5C12.3 6.3 12.8 7.4 12.8 8.5C12.8 9.6 12.3 10.7 11.5 11.5"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</svg>
);
export default SvgVolume;


@@ -58,7 +58,7 @@ const nextConfig = {
{
key: "Permissions-Policy",
value:
-"accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), cross-origin-isolated=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(), midi=(), navigation-override=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()",
+"accelerometer=(), ambient-light-sensor=(), autoplay=(), battery=(), camera=(), cross-origin-isolated=(), display-capture=(), document-domain=(), encrypted-media=(), execution-while-not-rendered=(), execution-while-out-of-viewport=(), fullscreen=(), geolocation=(), gyroscope=(), keyboard-map=(), magnetometer=(), microphone=(self), midi=(), navigation-override=(), payment=(), picture-in-picture=(), publickey-credentials-get=(), screen-wake-lock=(), sync-xhr=(), usb=(), web-share=(), xr-spatial-tracking=()",
},
],
},

web/package-lock.json generated

@@ -105,6 +105,7 @@
"@types/node": "18.15.11",
"@types/react": "19.2.10",
"@types/react-dom": "19.2.3",
"@types/sbd": "^1.0.5",
"@types/stats.js": "^0.17.4",
"@types/uuid": "^9.0.8",
"@typescript/native-preview": "7.0.0-dev.20251222.1",
@@ -5543,6 +5544,13 @@
"integrity": "sha512-z1HGKcYy2xA8AGQfwrn0PAy+PB7X/GSj3UVJW9qKyn43xWa+gl5nXmU4qqLMRzWVLFC8KusUX8T/0kCiOYpAIQ==",
"license": "MIT"
},
"node_modules/@types/sbd": {
"version": "1.0.5",
"resolved": "https://registry.npmjs.org/@types/sbd/-/sbd-1.0.5.tgz",
"integrity": "sha512-60PxBBWhg0C3yb5bTP+wwWYGTKMcuB0S6mTEa1sedMC79tYY0Ei7YjU4qsWzGn++lWscLQde16SnElJrf5/aTw==",
"dev": true,
"license": "MIT"
},
"node_modules/@types/stack-utils": {
"version": "2.0.3",
"dev": true,


@@ -121,6 +121,7 @@
"@types/node": "18.15.11",
"@types/react": "19.2.10",
"@types/react-dom": "19.2.3",
"@types/sbd": "^1.0.5",
"@types/stats.js": "^0.17.4",
"@types/uuid": "^9.0.8",
"@typescript/native-preview": "7.0.0-dev.20251222.1",


@@ -41,7 +41,7 @@ export default defineConfig({
viewport: { width: 1280, height: 720 },
storageState: "admin_auth.json",
},
-grepInvert: /@exclusive/,
+grepInvert: [/@exclusive/, /@lite/],
},
{
// this suite runs independently and serially + slower
@@ -55,5 +55,15 @@ export default defineConfig({
grep: /@exclusive/,
workers: 1,
},
{
// runs against the Onyx Lite stack (DISABLE_VECTOR_DB=true, no Vespa/Redis)
name: "lite",
use: {
...devices["Desktop Chrome"],
viewport: { width: 1280, height: 720 },
storageState: "admin_auth.json",
},
grep: /@lite/,
},
],
});


@@ -0,0 +1,4 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M10.5 2H13V14H10.5V2Z" fill="black"/>
<path d="M3 2H5.5V14H3V2Z" fill="black"/>
</svg>


Some files were not shown because too many files have changed in this diff.