mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-10 10:12:40 +00:00
Compare commits
6 Commits
voice-mode
...
nikg/admin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f3c3e29fb7 | ||
|
|
eaa1e94bed | ||
|
|
a5bb2d8130 | ||
|
|
dbbca1c4e5 | ||
|
|
ff92928b31 | ||
|
|
f452a777cc |
@@ -106,34 +106,13 @@ onyx-cli ask --json "What authentication methods do we support?"
|
||||
|
||||
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
|
||||
|
||||
Each line is a JSON object with this envelope:
|
||||
|
||||
```json
|
||||
{"type": "<event_type>", "event": { ... }}
|
||||
```
|
||||
|
||||
| Event Type | Description |
|
||||
|------------|-------------|
|
||||
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
|
||||
| `stop` | Stream complete |
|
||||
| `error` | Error with `error` message field |
|
||||
| `search_tool_start` | Onyx started searching documents |
|
||||
| `citation_info` | Source citation — see shape below |
|
||||
|
||||
`citation_info` event shape:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "citation_info",
|
||||
"event": {
|
||||
"citation_number": 1,
|
||||
"document_id": "abc123def456",
|
||||
"placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
|
||||
| `citation_info` | Source citation with `citation_number` and `document_id` |
|
||||
|
||||
### Specify an agent
|
||||
|
||||
@@ -150,10 +129,6 @@ Uses a specific Onyx agent/persona instead of the default.
|
||||
| `--agent-id` | int | Agent ID to use (overrides default) |
|
||||
| `--json` | bool | Output raw NDJSON events instead of plain text |
|
||||
|
||||
## Statelessness
|
||||
|
||||
Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
|
||||
|
||||
## When to Use
|
||||
|
||||
Use `onyx-cli ask` when:
|
||||
|
||||
2
.github/workflows/deployment.yml
vendored
2
.github/workflows/deployment.yml
vendored
@@ -151,7 +151,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
|
||||
|
||||
@@ -70,7 +70,7 @@ jobs:
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/pr-desktop-build.yml
vendored
2
.github/workflows/pr-desktop-build.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
cache-dependency-path: ./desktop/package-lock.json
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9
|
||||
uses: dtolnay/rust-toolchain@4be9e76fd7c4901c61fb841f559994984270fce7
|
||||
with:
|
||||
toolchain: stable
|
||||
targets: ${{ matrix.target }}
|
||||
|
||||
4
.github/workflows/pr-playwright-tests.yml
vendored
4
.github/workflows/pr-playwright-tests.yml
vendored
@@ -468,7 +468,7 @@ jobs:
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: always()
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
@@ -707,7 +707,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3
|
||||
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
2
.github/workflows/pr-quality-checks.yml
vendored
2
.github/workflows/pr-quality-checks.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Setup Terraform
|
||||
uses: hashicorp/setup-terraform@5e8dbf3c6d9deaf4193ca7a8fb23f2ac83bb6c85 # ratchet:hashicorp/setup-terraform@v4.0.0
|
||||
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6
|
||||
with: # zizmor: ignore[cache-poisoning]
|
||||
|
||||
178
.github/workflows/release-cli.yml
vendored
178
.github/workflows/release-cli.yml
vendored
@@ -26,7 +26,8 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
fetch-depth: 0
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
@@ -37,178 +38,3 @@ jobs:
|
||||
working-directory: cli
|
||||
- run: uv publish
|
||||
working-directory: cli
|
||||
|
||||
docker-amd64:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-cli-amd64
|
||||
- extras=ecr-cache
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 30
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./cli
|
||||
file: ./cli/Dockerfile
|
||||
platforms: linux/amd64
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
docker-arm64:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-arm64
|
||||
- run-id=${{ github.run_id }}-cli-arm64
|
||||
- extras=ecr-cache
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 30
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./cli
|
||||
file: ./cli/Dockerfile
|
||||
platforms: linux/arm64
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
merge-docker:
|
||||
needs:
|
||||
- docker-amd64
|
||||
- docker-arm64
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-cli-merge
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 10
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
AMD64_DIGEST: ${{ needs.docker-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.docker-arm64.outputs.digest }}
|
||||
TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
SANITIZED_TAG="${TAG#cli/}"
|
||||
IMAGES=(
|
||||
"${REGISTRY_IMAGE}@${AMD64_DIGEST}"
|
||||
"${REGISTRY_IMAGE}@${ARM64_DIGEST}"
|
||||
)
|
||||
|
||||
if [[ "$TAG" =~ ^cli/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
docker buildx imagetools create \
|
||||
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
|
||||
-t "${REGISTRY_IMAGE}:latest" \
|
||||
"${IMAGES[@]}"
|
||||
else
|
||||
docker buildx imagetools create \
|
||||
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
|
||||
"${IMAGES[@]}"
|
||||
fi
|
||||
|
||||
4
.github/workflows/release-devtools.yml
vendored
4
.github/workflows/release-devtools.yml
vendored
@@ -22,11 +22,13 @@ jobs:
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
- { goos: "", goarch: "" }
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
fetch-depth: 0
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/zizmor.yml
vendored
2
.github/workflows/zizmor.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
@@ -544,8 +544,6 @@ To run them:
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
For shared fixtures, best practices, and detailed guidance, see `backend/tests/README.md`.
|
||||
|
||||
## Logs
|
||||
|
||||
When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access
|
||||
|
||||
@@ -141,7 +141,6 @@ COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
|
||||
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
|
||||
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
|
||||
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
|
||||
COPY --chown=onyx:onyx ./scripts/reencrypt_secrets.py /app/scripts/reencrypt_secrets.py
|
||||
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh
|
||||
|
||||
# Run Craft template setup at build time when ENABLE_CRAFT=true
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
"""add_voice_provider_and_user_voice_prefs
|
||||
|
||||
Revision ID: 93a2e195e25c
|
||||
Revises: a3b8d9e2f1c4
|
||||
Create Date: 2026-02-23 15:16:39.507304
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import column
|
||||
from sqlalchemy import true
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "93a2e195e25c"
|
||||
down_revision = "a3b8d9e2f1c4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create voice_provider table
|
||||
op.create_table(
|
||||
"voice_provider",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("name", sa.String(), unique=True, nullable=False),
|
||||
sa.Column("provider_type", sa.String(), nullable=False),
|
||||
sa.Column("api_key", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("api_base", sa.String(), nullable=True),
|
||||
sa.Column("custom_config", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("stt_model", sa.String(), nullable=True),
|
||||
sa.Column("tts_model", sa.String(), nullable=True),
|
||||
sa.Column("default_voice", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"is_default_stt", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.Column(
|
||||
"is_default_tts", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Add partial unique indexes to enforce only one default STT/TTS provider
|
||||
op.create_index(
|
||||
"ix_voice_provider_one_default_stt",
|
||||
"voice_provider",
|
||||
["is_default_stt"],
|
||||
unique=True,
|
||||
postgresql_where=column("is_default_stt") == true(),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_voice_provider_one_default_tts",
|
||||
"voice_provider",
|
||||
["is_default_tts"],
|
||||
unique=True,
|
||||
postgresql_where=column("is_default_tts") == true(),
|
||||
)
|
||||
|
||||
# Add voice preference columns to user table
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_auto_send",
|
||||
sa.Boolean(),
|
||||
default=False,
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_auto_playback",
|
||||
sa.Boolean(),
|
||||
default=False,
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_playback_speed",
|
||||
sa.Float(),
|
||||
default=1.0,
|
||||
nullable=False,
|
||||
server_default="1.0",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove user voice preference columns
|
||||
op.drop_column("user", "voice_playback_speed")
|
||||
op.drop_column("user", "voice_auto_playback")
|
||||
op.drop_column("user", "voice_auto_send")
|
||||
|
||||
op.drop_index("ix_voice_provider_one_default_tts", table_name="voice_provider")
|
||||
op.drop_index("ix_voice_provider_one_default_stt", table_name="voice_provider")
|
||||
|
||||
# Drop voice_provider table
|
||||
op.drop_table("voice_provider")
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy import text
|
||||
from alembic import op
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
from onyx.configs.app_configs import DB_READONLY_USER
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
@@ -21,52 +22,59 @@ depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
if MULTI_TENANT:
|
||||
|
||||
# Create the read-only db user if it does not already exist.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
-- Check if the read-only user already exists
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- Create the read-only user with the specified password
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- First revoke all privileges to ensure a clean slate
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege to allow the user to connect to the database
|
||||
-- but not perform any operations without additional specific grants
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
# Create read-only db user here only in multi-tenant mode. For single-tenant mode,
|
||||
# the user is created in the standard migration.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
-- Check if the read-only user already exists
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- Create the read-only user with the specified password
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- First revoke all privileges to ensure a clean slate
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege to allow the user to connect to the database
|
||||
-- but not perform any operations without additional specific grants
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then revoke all privileges from the public schema
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
# Drop read-only db user here only in single tenant mode. For multi-tenant mode,
|
||||
# the user is dropped in the alembic_tenants migration.
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then revoke all privileges from the public schema
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
)
|
||||
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
|
||||
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
|
||||
|
||||
@@ -68,7 +68,6 @@ def get_external_access_for_raw_gdrive_file(
|
||||
company_domain: str,
|
||||
retriever_drive_service: GoogleDriveService | None,
|
||||
admin_drive_service: GoogleDriveService,
|
||||
fallback_user_email: str,
|
||||
add_prefix: bool = False,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
@@ -80,11 +79,6 @@ def get_external_access_for_raw_gdrive_file(
|
||||
set add_prefix to True so group IDs are prefixed with the source type.
|
||||
When invoked from doc_sync (permission sync), use the default (False)
|
||||
since upsert_document_external_perms handles prefixing.
|
||||
fallback_user_email: When we cannot retrieve any permission info for a file
|
||||
(e.g. externally-owned files where the API returns no permissions
|
||||
and permissions.list returns 403), fall back to granting access
|
||||
to this user. This is typically the impersonated org user whose
|
||||
drive contained the file.
|
||||
"""
|
||||
doc_id = file.get("id")
|
||||
if not doc_id:
|
||||
@@ -123,26 +117,6 @@ def get_external_access_for_raw_gdrive_file(
|
||||
[permissions_list, backup_permissions_list]
|
||||
)
|
||||
|
||||
# For externally-owned files, the Drive API may return no permissions
|
||||
# and permissions.list may return 403. In this case, fall back to
|
||||
# granting access to the user who found the file in their drive.
|
||||
# Note, even if other users also have access to this file,
|
||||
# they will not be granted access in Onyx.
|
||||
# We check permissions_list (the final result after all fetch attempts)
|
||||
# rather than the raw fields, because permission_ids may be present
|
||||
# but the actual fetch can still return empty due to a 403.
|
||||
if not permissions_list:
|
||||
logger.info(
|
||||
f"No permission info available for file {doc_id} "
|
||||
f"(likely owned by a user outside of your organization). "
|
||||
f"Falling back to granting access to retriever user: {fallback_user_email}"
|
||||
)
|
||||
return ExternalAccess(
|
||||
external_user_emails={fallback_user_email},
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
folder_ids_to_inherit_permissions_from: set[str] = set()
|
||||
user_emails: set[str] = set()
|
||||
group_emails: set[str] = set()
|
||||
|
||||
@@ -14,91 +14,67 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@lru_cache(maxsize=2)
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_trimmed_key(key: str) -> bytes:
|
||||
encoded_key = key.encode()
|
||||
key_length = len(encoded_key)
|
||||
if key_length < 16:
|
||||
raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
|
||||
elif key_length > 32:
|
||||
key = key[:32]
|
||||
elif key_length not in (16, 24, 32):
|
||||
valid_lengths = [16, 24, 32]
|
||||
key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))]
|
||||
|
||||
# Trim to the largest valid AES key size that fits
|
||||
valid_lengths = [32, 24, 16]
|
||||
for size in valid_lengths:
|
||||
if key_length >= size:
|
||||
return encoded_key[:size]
|
||||
|
||||
raise AssertionError("unreachable")
|
||||
return encoded_key
|
||||
|
||||
|
||||
def _encrypt_string(input_str: str, key: str | None = None) -> bytes:
|
||||
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
|
||||
if not effective_key:
|
||||
def _encrypt_string(input_str: str) -> bytes:
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
return input_str.encode()
|
||||
|
||||
trimmed = _get_trimmed_key(effective_key)
|
||||
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
|
||||
iv = urandom(16)
|
||||
padder = padding.PKCS7(algorithms.AES.block_size).padder()
|
||||
padded_data = padder.update(input_str.encode()) + padder.finalize()
|
||||
|
||||
cipher = Cipher(algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend())
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
encryptor = cipher.encryptor()
|
||||
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
|
||||
|
||||
return iv + encrypted_data
|
||||
|
||||
|
||||
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str:
|
||||
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
|
||||
if not effective_key:
|
||||
def _decrypt_bytes(input_bytes: bytes) -> str:
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
return input_bytes.decode()
|
||||
|
||||
trimmed = _get_trimmed_key(effective_key)
|
||||
try:
|
||||
iv = input_bytes[:16]
|
||||
encrypted_data = input_bytes[16:]
|
||||
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
|
||||
iv = input_bytes[:16]
|
||||
encrypted_data = input_bytes[16:]
|
||||
|
||||
cipher = Cipher(
|
||||
algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend()
|
||||
)
|
||||
decryptor = cipher.decryptor()
|
||||
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
decryptor = cipher.decryptor()
|
||||
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||
|
||||
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
|
||||
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
|
||||
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
|
||||
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
|
||||
|
||||
return decrypted_data.decode()
|
||||
except (ValueError, UnicodeDecodeError):
|
||||
if key is not None:
|
||||
# Explicit key was provided — don't fall back silently
|
||||
raise
|
||||
# Read path: attempt raw UTF-8 decode as a fallback for legacy data.
|
||||
# Does NOT handle data encrypted with a different key — that
|
||||
# ciphertext is not valid UTF-8 and will raise below.
|
||||
logger.warning(
|
||||
"AES decryption failed — falling back to raw decode. "
|
||||
"Run the re-encrypt secrets script to rotate to the current key."
|
||||
)
|
||||
try:
|
||||
return input_bytes.decode()
|
||||
except UnicodeDecodeError:
|
||||
raise ValueError(
|
||||
"Data is not valid UTF-8 — likely encrypted with a different key. "
|
||||
"Run the re-encrypt secrets script to rotate to the current key."
|
||||
) from None
|
||||
return decrypted_data.decode()
|
||||
|
||||
|
||||
def encrypt_string_to_bytes(input_str: str, key: str | None = None) -> bytes:
|
||||
def encrypt_string_to_bytes(input_str: str) -> bytes:
|
||||
versioned_encryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_encrypt_string"
|
||||
)
|
||||
return versioned_encryption_fn(input_str, key=key)
|
||||
return versioned_encryption_fn(input_str)
|
||||
|
||||
|
||||
def decrypt_bytes_to_string(input_bytes: bytes, key: str | None = None) -> str:
|
||||
def decrypt_bytes_to_string(input_bytes: bytes) -> str:
|
||||
versioned_decryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_decrypt_bytes"
|
||||
)
|
||||
return versioned_decryption_fn(input_bytes, key=key)
|
||||
return versioned_decryption_fn(input_bytes)
|
||||
|
||||
|
||||
def test_encryption() -> None:
|
||||
|
||||
@@ -28,7 +28,6 @@ from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi import WebSocket
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi_users import BaseUserManager
|
||||
@@ -121,7 +120,6 @@ from onyx.db.models import User
|
||||
from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import retrieve_ws_token_data
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -1601,102 +1599,6 @@ async def current_admin_user(user: User = Depends(current_user)) -> User:
|
||||
return user
|
||||
|
||||
|
||||
async def _get_user_from_token_data(token_data: dict) -> User | None:
|
||||
"""Shared logic: token data dict → User object.
|
||||
|
||||
Args:
|
||||
token_data: Decoded token data containing 'sub' (user ID).
|
||||
|
||||
Returns:
|
||||
User object if found and active, None otherwise.
|
||||
"""
|
||||
user_id = token_data.get("sub")
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
async with get_async_session_context_manager() as async_db_session:
|
||||
user = await async_db_session.get(User, user_uuid)
|
||||
if user is None or not user.is_active:
|
||||
return None
|
||||
return user
|
||||
|
||||
|
||||
async def current_user_from_websocket(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(..., description="WebSocket authentication token"),
|
||||
) -> User:
|
||||
"""
|
||||
WebSocket authentication dependency using query parameter.
|
||||
|
||||
Validates the WS token from query param and returns the User.
|
||||
Raises BasicAuthenticationError if authentication fails.
|
||||
|
||||
The token must be obtained from POST /voice/ws-token before connecting.
|
||||
Tokens are single-use and expire after 60 seconds.
|
||||
|
||||
Usage:
|
||||
1. POST /voice/ws-token -> {"token": "xxx"}
|
||||
2. Connect to ws://host/path?token=xxx
|
||||
|
||||
This applies the same auth checks as current_user() for HTTP endpoints.
|
||||
"""
|
||||
# Check Origin header to prevent Cross-Site WebSocket Hijacking (CSWSH)
|
||||
# Browsers always send Origin on WebSocket connections
|
||||
origin = websocket.headers.get("origin")
|
||||
expected_origin = WEB_DOMAIN.rstrip("/")
|
||||
if not origin:
|
||||
logger.warning("WS auth: missing Origin header")
|
||||
raise BasicAuthenticationError(detail="Access denied. Missing origin.")
|
||||
|
||||
actual_origin = origin.rstrip("/")
|
||||
if actual_origin != expected_origin:
|
||||
logger.warning(
|
||||
f"WS auth: origin mismatch. Expected {expected_origin}, got {actual_origin}"
|
||||
)
|
||||
raise BasicAuthenticationError(detail="Access denied. Invalid origin.")
|
||||
|
||||
# Validate WS token in Redis (single-use, deleted after retrieval)
|
||||
try:
|
||||
token_data = await retrieve_ws_token_data(token)
|
||||
if token_data is None:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. Invalid or expired authentication token."
|
||||
)
|
||||
except BasicAuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"WS auth: error during token validation: {e}")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Authentication verification failed."
|
||||
) from e
|
||||
|
||||
# Get user from token data
|
||||
user = await _get_user_from_token_data(token_data)
|
||||
if user is None:
|
||||
logger.warning(f"WS auth: user not found for id={token_data.get('sub')}")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User not found or inactive."
|
||||
)
|
||||
|
||||
# Apply same checks as HTTP auth (verification, OIDC expiry, role)
|
||||
user = await double_check_user(user)
|
||||
|
||||
# Block LIMITED users (same as current_user)
|
||||
if user.role == UserRole.LIMITED:
|
||||
logger.warning(f"WS auth: user {user.email} has LIMITED role")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
|
||||
)
|
||||
|
||||
logger.debug(f"WS auth: authenticated {user.email}")
|
||||
return user
|
||||
|
||||
|
||||
def get_default_admin_user_emails_() -> list[str]:
|
||||
# No default seeding available for Onyx MIT
|
||||
return []
|
||||
|
||||
@@ -15,7 +15,6 @@ from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -55,7 +54,6 @@ from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.jsonriver import Parser
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
from onyx.utils.text_processing import find_all_json_objects
|
||||
@@ -1011,7 +1009,6 @@ def run_llm_step_pkt_generator(
|
||||
)
|
||||
|
||||
id_to_tool_call_map: dict[int, dict[str, Any]] = {}
|
||||
arg_parsers: dict[int, Parser] = {}
|
||||
reasoning_start = False
|
||||
answer_start = False
|
||||
accumulated_reasoning = ""
|
||||
@@ -1218,14 +1215,7 @@ def run_llm_step_pkt_generator(
|
||||
yield from _close_reasoning_if_active()
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
# maybe_emit depends and update being called first and attaching the delta
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
yield from maybe_emit_argument_delta(
|
||||
tool_calls_in_progress=id_to_tool_call_map,
|
||||
tool_call_delta=tool_call_delta,
|
||||
placement=_current_placement(),
|
||||
parsers=arg_parsers,
|
||||
)
|
||||
|
||||
# Flush any tail text buffered while checking for split "<function_calls" markers.
|
||||
filtered_content_tail = xml_tool_call_content_filter.flush()
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Type
|
||||
|
||||
from onyx.llm.model_response import ChatCompletionDeltaToolCall
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.tools.built_in_tools import TOOL_NAME_TO_CLASS
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.utils.jsonriver import Parser
|
||||
|
||||
|
||||
def _get_tool_class(
|
||||
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
|
||||
tool_call_delta: ChatCompletionDeltaToolCall,
|
||||
) -> Type[Tool] | None:
|
||||
"""Look up the Tool subclass for a streaming tool call delta."""
|
||||
tool_name = tool_calls_in_progress.get(tool_call_delta.index, {}).get("name")
|
||||
if not tool_name:
|
||||
return None
|
||||
return TOOL_NAME_TO_CLASS.get(tool_name)
|
||||
|
||||
|
||||
def maybe_emit_argument_delta(
|
||||
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
|
||||
tool_call_delta: ChatCompletionDeltaToolCall,
|
||||
placement: Placement,
|
||||
parsers: dict[int, Parser],
|
||||
) -> Generator[Packet, None, None]:
|
||||
"""Emit decoded tool-call argument deltas to the frontend.
|
||||
|
||||
Uses a ``jsonriver.Parser`` per tool-call index to incrementally parse
|
||||
the JSON argument string and extract only the newly-appended content
|
||||
for each string-valued argument.
|
||||
|
||||
NOTE: Non-string arguments (numbers, booleans, null, arrays, objects)
|
||||
are skipped — they are available in the final tool-call kickoff packet.
|
||||
|
||||
``parsers`` is a mutable dict keyed by tool-call index. A new
|
||||
``Parser`` is created automatically for each new index.
|
||||
"""
|
||||
tool_cls = _get_tool_class(tool_calls_in_progress, tool_call_delta)
|
||||
if not tool_cls or not tool_cls.should_emit_argument_deltas():
|
||||
return
|
||||
|
||||
fn = tool_call_delta.function
|
||||
delta_fragment = fn.arguments if fn else None
|
||||
if not delta_fragment:
|
||||
return
|
||||
|
||||
idx = tool_call_delta.index
|
||||
if idx not in parsers:
|
||||
parsers[idx] = Parser()
|
||||
parser = parsers[idx]
|
||||
|
||||
deltas = parser.feed(delta_fragment)
|
||||
|
||||
argument_deltas: dict[str, str] = {}
|
||||
for delta in deltas:
|
||||
if isinstance(delta, dict):
|
||||
for key, value in delta.items():
|
||||
if isinstance(value, str):
|
||||
argument_deltas[key] = argument_deltas.get(key, "") + value
|
||||
|
||||
if not argument_deltas:
|
||||
return
|
||||
|
||||
tc_data = tool_calls_in_progress[tool_call_delta.index]
|
||||
yield Packet(
|
||||
placement=placement,
|
||||
obj=ToolCallArgumentDelta(
|
||||
tool_type=tc_data.get("name", ""),
|
||||
argument_deltas=argument_deltas,
|
||||
),
|
||||
)
|
||||
@@ -68,10 +68,6 @@ FILE_TOKEN_COUNT_THRESHOLD = int(
|
||||
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
|
||||
)
|
||||
|
||||
# Maximum upload size for a single user file (chat/projects) in MB.
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
|
||||
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
|
||||
|
||||
# If set to true, will show extra/uncommon connectors in the "Other" category
|
||||
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
@@ -205,7 +204,7 @@ def _manage_async_retrieval(
|
||||
|
||||
end_time: datetime | None = end
|
||||
|
||||
async def _async_fetch() -> AsyncGenerator[Document, None]:
|
||||
async def _async_fetch() -> AsyncIterable[Document]:
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
async with Client(intents=intents) as discord_client:
|
||||
@@ -228,23 +227,22 @@ def _manage_async_retrieval(
|
||||
|
||||
def run_and_yield() -> Iterable[Document]:
|
||||
loop = asyncio.new_event_loop()
|
||||
async_gen = _async_fetch()
|
||||
try:
|
||||
# Get the async generator
|
||||
async_gen = _async_fetch()
|
||||
# Convert to AsyncIterator
|
||||
async_iter = async_gen.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
doc = loop.run_until_complete(anext(async_gen))
|
||||
# Create a coroutine by calling anext with the async iterator
|
||||
next_coro = anext(async_iter)
|
||||
# Run the coroutine to get the next document
|
||||
doc = loop.run_until_complete(next_coro)
|
||||
yield doc
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
# Must close the async generator before the loop so the Discord
|
||||
# client's `async with` block can await its shutdown coroutine.
|
||||
# The nested try/finally ensures the loop always closes even if
|
||||
# aclose() raises (same pattern as cursor.close() before conn.close()).
|
||||
try:
|
||||
loop.run_until_complete(async_gen.aclose())
|
||||
finally:
|
||||
loop.close()
|
||||
loop.close()
|
||||
|
||||
return run_and_yield()
|
||||
|
||||
|
||||
@@ -1722,7 +1722,6 @@ class GoogleDriveConnector(
|
||||
primary_admin_email=self.primary_admin_email,
|
||||
google_domain=self.google_domain,
|
||||
),
|
||||
retriever_email=file.user_email,
|
||||
):
|
||||
slim_batch.append(doc)
|
||||
|
||||
|
||||
@@ -476,7 +476,6 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
company_domain: str,
|
||||
retriever_drive_service: GoogleDriveService | None,
|
||||
admin_drive_service: GoogleDriveService,
|
||||
fallback_user_email: str,
|
||||
add_prefix: bool = False,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
@@ -485,8 +484,6 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
add_prefix: When True, prefix group IDs with source type (for indexing path).
|
||||
When False (default), leave unprefixed (for permission sync path
|
||||
where upsert_document_external_perms handles prefixing).
|
||||
fallback_user_email: When permission info can't be retrieved (e.g. externally-owned
|
||||
files), fall back to granting access to this user.
|
||||
"""
|
||||
external_access_fn = cast(
|
||||
Callable[
|
||||
@@ -495,7 +492,6 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
str,
|
||||
GoogleDriveService | None,
|
||||
GoogleDriveService,
|
||||
str,
|
||||
bool,
|
||||
],
|
||||
ExternalAccess,
|
||||
@@ -511,7 +507,6 @@ def _get_external_access_for_raw_gdrive_file(
|
||||
company_domain,
|
||||
retriever_drive_service,
|
||||
admin_drive_service,
|
||||
fallback_user_email,
|
||||
add_prefix,
|
||||
)
|
||||
|
||||
@@ -677,7 +672,6 @@ def _convert_drive_item_to_document(
|
||||
creds, user_email=permission_sync_context.primary_admin_email
|
||||
),
|
||||
add_prefix=True, # Indexing path - prefix here
|
||||
fallback_user_email=retriever_email,
|
||||
)
|
||||
if permission_sync_context
|
||||
else None
|
||||
@@ -759,7 +753,6 @@ def build_slim_document(
|
||||
# if not specified, we will not sync permissions
|
||||
# will also be a no-op if EE is not enabled
|
||||
permission_sync_context: PermissionSyncContext | None,
|
||||
retriever_email: str,
|
||||
) -> SlimDocument | None:
|
||||
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
|
||||
return None
|
||||
@@ -781,7 +774,6 @@ def build_slim_document(
|
||||
creds,
|
||||
user_email=permission_sync_context.primary_admin_email,
|
||||
),
|
||||
fallback_user_email=retriever_email,
|
||||
)
|
||||
if permission_sync_context
|
||||
else None
|
||||
|
||||
@@ -36,11 +36,9 @@ from sqlalchemy import Text
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import UniqueConstraint
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import Mapper
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.types import LargeBinary
|
||||
@@ -119,50 +117,10 @@ class Base(DeclarativeBase):
|
||||
__abstract__ = True
|
||||
|
||||
|
||||
class _EncryptedBase(TypeDecorator):
|
||||
"""Base for encrypted column types that wrap values in SensitiveValue."""
|
||||
|
||||
class EncryptedString(TypeDecorator):
|
||||
impl = LargeBinary
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
_is_json: bool = False
|
||||
|
||||
def wrap_raw(self, value: Any) -> SensitiveValue:
|
||||
"""Encrypt a raw value and wrap it in SensitiveValue.
|
||||
|
||||
Called by the attribute set event so the Python-side type is always
|
||||
SensitiveValue, regardless of whether the value was loaded from the DB
|
||||
or assigned in application code.
|
||||
"""
|
||||
if self._is_json:
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(
|
||||
f"EncryptedJson column expected dict, got {type(value).__name__}"
|
||||
)
|
||||
raw_str = json.dumps(value)
|
||||
else:
|
||||
if not isinstance(value, str):
|
||||
raise TypeError(
|
||||
f"EncryptedString column expected str, got {type(value).__name__}"
|
||||
)
|
||||
raw_str = value
|
||||
return SensitiveValue(
|
||||
encrypted_bytes=encrypt_string_to_bytes(raw_str),
|
||||
decrypt_fn=decrypt_bytes_to_string,
|
||||
is_json=self._is_json,
|
||||
)
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
|
||||
class EncryptedString(_EncryptedBase):
|
||||
_is_json: bool = False
|
||||
|
||||
def process_bind_param(
|
||||
self, value: str | SensitiveValue[str] | None, dialect: Dialect # noqa: ARG002
|
||||
@@ -186,9 +144,20 @@ class EncryptedString(_EncryptedBase):
|
||||
)
|
||||
return None
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
class EncryptedJson(_EncryptedBase):
|
||||
_is_json: bool = True
|
||||
|
||||
class EncryptedJson(TypeDecorator):
|
||||
impl = LargeBinary
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(
|
||||
self,
|
||||
@@ -196,7 +165,9 @@ class EncryptedJson(_EncryptedBase):
|
||||
dialect: Dialect, # noqa: ARG002
|
||||
) -> bytes | None:
|
||||
if value is not None:
|
||||
# Handle both raw dicts and SensitiveValue wrappers
|
||||
if isinstance(value, SensitiveValue):
|
||||
# Get raw value for storage
|
||||
value = value.get_value(apply_mask=False)
|
||||
json_str = json.dumps(value)
|
||||
return encrypt_string_to_bytes(json_str)
|
||||
@@ -213,40 +184,14 @@ class EncryptedJson(_EncryptedBase):
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
_REGISTERED_ATTRS: set[str] = set()
|
||||
|
||||
|
||||
@event.listens_for(Mapper, "mapper_configured")
|
||||
def _register_sensitive_value_set_events(
|
||||
mapper: Mapper,
|
||||
class_: type,
|
||||
) -> None:
|
||||
"""Auto-wrap raw values in SensitiveValue when assigned to encrypted columns."""
|
||||
for prop in mapper.column_attrs:
|
||||
for col in prop.columns:
|
||||
if isinstance(col.type, _EncryptedBase):
|
||||
col_type = col.type
|
||||
attr = getattr(class_, prop.key)
|
||||
|
||||
# Guard against double-registration (e.g. if mapper is
|
||||
# re-configured in test setups)
|
||||
attr_key = f"{class_.__qualname__}.{prop.key}"
|
||||
if attr_key in _REGISTERED_ATTRS:
|
||||
continue
|
||||
_REGISTERED_ATTRS.add(attr_key)
|
||||
|
||||
@event.listens_for(attr, "set", retval=True)
|
||||
def _wrap_value(
|
||||
target: Any, # noqa: ARG001
|
||||
value: Any,
|
||||
oldvalue: Any, # noqa: ARG001
|
||||
initiator: Any, # noqa: ARG001
|
||||
_col_type: _EncryptedBase = col_type,
|
||||
) -> Any:
|
||||
if value is not None and not isinstance(value, SensitiveValue):
|
||||
return _col_type.wrap_raw(value)
|
||||
return value
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
|
||||
class NullFilteredString(TypeDecorator):
|
||||
@@ -339,11 +284,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
# organized in typical structured fashion
|
||||
# formatted as `displayName__provider__modelName`
|
||||
|
||||
# Voice preferences
|
||||
voice_auto_send: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
voice_auto_playback: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
voice_playback_speed: Mapped[float] = mapped_column(Float, default=1.0)
|
||||
|
||||
# relationships
|
||||
credentials: Mapped[list["Credential"]] = relationship(
|
||||
"Credential", back_populates="user"
|
||||
@@ -3024,63 +2964,6 @@ class ImageGenerationConfig(Base):
|
||||
)
|
||||
|
||||
|
||||
class VoiceProvider(Base):
|
||||
"""Configuration for voice services (STT and TTS)."""
|
||||
|
||||
__tablename__ = "voice_provider"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
provider_type: Mapped[str] = mapped_column(
|
||||
String
|
||||
) # "openai", "azure", "elevenlabs"
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
api_base: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
custom_config: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
|
||||
# Model/voice configuration
|
||||
stt_model: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
) # e.g., "whisper-1"
|
||||
tts_model: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
) # e.g., "tts-1", "tts-1-hd"
|
||||
default_voice: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
) # e.g., "alloy", "echo"
|
||||
|
||||
# STT and TTS can use different providers - only one provider per type
|
||||
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
# Enforce only one default STT provider and one default TTS provider at DB level
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_voice_provider_one_default_stt",
|
||||
"is_default_stt",
|
||||
unique=True,
|
||||
postgresql_where=(is_default_stt == True), # noqa: E712
|
||||
),
|
||||
Index(
|
||||
"ix_voice_provider_one_default_tts",
|
||||
"is_default_tts",
|
||||
unique=True,
|
||||
postgresql_where=(is_default_tts == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(Base):
|
||||
__tablename__ = "embedding_provider"
|
||||
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
"""Rotate encryption key for all encrypted columns.
|
||||
|
||||
Dynamically discovers all columns using EncryptedString / EncryptedJson,
|
||||
decrypts each value with the old key, and re-encrypts with the current
|
||||
ENCRYPTION_KEY_SECRET.
|
||||
|
||||
The operation is idempotent: rows already encrypted with the current key
|
||||
are skipped. Commits are made in batches so a crash mid-rotation can be
|
||||
safely resumed by re-running.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import LargeBinary
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET
|
||||
from onyx.db.models import Base
|
||||
from onyx.db.models import EncryptedJson
|
||||
from onyx.db.models import EncryptedString
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_BATCH_SIZE = 500
|
||||
|
||||
|
||||
def _can_decrypt_with_current_key(data: bytes) -> bool:
|
||||
"""Check if data is already encrypted with the current key.
|
||||
|
||||
Passes the key explicitly so the fallback-to-raw-decode path in
|
||||
_decrypt_bytes is NOT triggered — a clean success/failure signal.
|
||||
"""
|
||||
try:
|
||||
decrypt_bytes_to_string(data, key=ENCRYPTION_KEY_SECRET)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _discover_encrypted_columns() -> list[tuple[type, str, list[str], bool]]:
|
||||
"""Walk all ORM models and find columns using EncryptedString/EncryptedJson.
|
||||
|
||||
Returns list of (ModelClass, column_attr_name, [pk_attr_names], is_json).
|
||||
"""
|
||||
results: list[tuple[type, str, list[str], bool]] = []
|
||||
|
||||
for mapper in Base.registry.mappers:
|
||||
model_cls = mapper.class_
|
||||
pk_names = [col.key for col in mapper.primary_key]
|
||||
|
||||
for prop in mapper.column_attrs:
|
||||
for col in prop.columns:
|
||||
if isinstance(col.type, EncryptedJson):
|
||||
results.append((model_cls, prop.key, pk_names, True))
|
||||
elif isinstance(col.type, EncryptedString):
|
||||
results.append((model_cls, prop.key, pk_names, False))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def rotate_encryption_key(
|
||||
db_session: Session,
|
||||
old_key: str | None,
|
||||
dry_run: bool = False,
|
||||
) -> dict[str, int]:
|
||||
"""Decrypt all encrypted columns with old_key and re-encrypt with the current key.
|
||||
|
||||
Args:
|
||||
db_session: Active database session.
|
||||
old_key: The previous encryption key. Pass None or "" if values were
|
||||
not previously encrypted with a key.
|
||||
dry_run: If True, count rows that need rotation without modifying data.
|
||||
|
||||
Returns:
|
||||
Dict of "table.column" -> number of rows re-encrypted (or would be).
|
||||
|
||||
Commits every _BATCH_SIZE rows so that locks are held briefly and progress
|
||||
is preserved on crash. Already-rotated rows are detected and skipped,
|
||||
making the operation safe to re-run.
|
||||
"""
|
||||
if not global_version.is_ee_version():
|
||||
raise RuntimeError("EE mode is not enabled — rotation requires EE encryption.")
|
||||
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
raise RuntimeError(
|
||||
"ENCRYPTION_KEY_SECRET is not set — cannot rotate. "
|
||||
"Set the target encryption key in the environment before running."
|
||||
)
|
||||
|
||||
encrypted_columns = _discover_encrypted_columns()
|
||||
totals: dict[str, int] = {}
|
||||
|
||||
for model_cls, col_name, pk_names, is_json in encrypted_columns:
|
||||
table_name: str = model_cls.__tablename__ # type: ignore[attr-defined]
|
||||
col_attr = getattr(model_cls, col_name)
|
||||
pk_attrs = [getattr(model_cls, pk) for pk in pk_names]
|
||||
|
||||
# Read raw bytes directly, bypassing the TypeDecorator
|
||||
raw_col = col_attr.property.columns[0]
|
||||
|
||||
stmt = select(*pk_attrs, raw_col.cast(LargeBinary)).where(col_attr.is_not(None))
|
||||
rows = db_session.execute(stmt).all()
|
||||
|
||||
reencrypted = 0
|
||||
batch_pending = 0
|
||||
for row in rows:
|
||||
raw_bytes: bytes | None = row[-1]
|
||||
if raw_bytes is None:
|
||||
continue
|
||||
|
||||
if _can_decrypt_with_current_key(raw_bytes):
|
||||
continue
|
||||
|
||||
try:
|
||||
if not old_key:
|
||||
decrypted_str = raw_bytes.decode("utf-8")
|
||||
else:
|
||||
decrypted_str = decrypt_bytes_to_string(raw_bytes, key=old_key)
|
||||
|
||||
# For EncryptedJson, parse back to dict so the TypeDecorator
|
||||
# can json.dumps() it cleanly (avoids double-encoding).
|
||||
value: Any = json.loads(decrypted_str) if is_json else decrypted_str
|
||||
except (ValueError, UnicodeDecodeError) as e:
|
||||
pk_vals = [row[i] for i in range(len(pk_names))]
|
||||
logger.warning(
|
||||
f"Could not decrypt/parse {table_name}.{col_name} "
|
||||
f"row {pk_vals} — skipping: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not dry_run:
|
||||
pk_filters = [pk_attr == row[i] for i, pk_attr in enumerate(pk_attrs)]
|
||||
update_stmt = (
|
||||
update(model_cls).where(*pk_filters).values({col_name: value})
|
||||
)
|
||||
db_session.execute(update_stmt)
|
||||
batch_pending += 1
|
||||
|
||||
if batch_pending >= _BATCH_SIZE:
|
||||
db_session.commit()
|
||||
batch_pending = 0
|
||||
reencrypted += 1
|
||||
|
||||
# Flush remaining rows in this column
|
||||
if batch_pending > 0:
|
||||
db_session.commit()
|
||||
|
||||
if reencrypted > 0:
|
||||
totals[f"{table_name}.{col_name}"] = reencrypted
|
||||
logger.info(
|
||||
f"{'[DRY RUN] Would re-encrypt' if dry_run else 'Re-encrypted'} "
|
||||
f"{reencrypted} value(s) in {table_name}.{col_name}"
|
||||
)
|
||||
|
||||
return totals
|
||||
@@ -1,220 +0,0 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import VoiceProvider
|
||||
|
||||
MIN_VOICE_PLAYBACK_SPEED = 0.5
|
||||
MAX_VOICE_PLAYBACK_SPEED = 2.0
|
||||
|
||||
|
||||
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
|
||||
"""Fetch all voice providers."""
|
||||
return list(
|
||||
db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_id(
|
||||
db_session: Session, provider_id: int
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by ID."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default STT provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True))
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default TTS provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True))
|
||||
)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_type(
|
||||
db_session: Session, provider_type: str
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by type."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider).where(VoiceProvider.provider_type == provider_type)
|
||||
)
|
||||
|
||||
|
||||
def upsert_voice_provider(
|
||||
*,
|
||||
db_session: Session,
|
||||
provider_id: int | None,
|
||||
name: str,
|
||||
provider_type: str,
|
||||
api_key: str | None,
|
||||
api_key_changed: bool,
|
||||
api_base: str | None = None,
|
||||
custom_config: dict[str, Any] | None = None,
|
||||
stt_model: str | None = None,
|
||||
tts_model: str | None = None,
|
||||
default_voice: str | None = None,
|
||||
activate_stt: bool = False,
|
||||
activate_tts: bool = False,
|
||||
) -> VoiceProvider:
|
||||
"""Create or update a voice provider."""
|
||||
provider: VoiceProvider | None = None
|
||||
|
||||
if provider_id is not None:
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise ValueError(f"No voice provider with id {provider_id} exists.")
|
||||
else:
|
||||
provider = VoiceProvider()
|
||||
db_session.add(provider)
|
||||
|
||||
# Apply updates
|
||||
provider.name = name
|
||||
provider.provider_type = provider_type
|
||||
provider.api_base = api_base
|
||||
provider.custom_config = custom_config
|
||||
provider.stt_model = stt_model
|
||||
provider.tts_model = tts_model
|
||||
provider.default_voice = default_voice
|
||||
|
||||
# Only update API key if explicitly changed or if provider has no key
|
||||
if api_key_changed or provider.api_key is None:
|
||||
provider.api_key = api_key # type: ignore[assignment]
|
||||
|
||||
db_session.flush()
|
||||
|
||||
if activate_stt:
|
||||
set_default_stt_provider(db_session=db_session, provider_id=provider.id)
|
||||
if activate_tts:
|
||||
set_default_tts_provider(db_session=db_session, provider_id=provider.id)
|
||||
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
|
||||
"""Delete a voice provider by ID."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider:
|
||||
db_session.delete(provider)
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def set_default_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
|
||||
"""Set a voice provider as the default STT provider."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise ValueError(f"No voice provider with id {provider_id} exists.")
|
||||
|
||||
# Deactivate all other STT providers
|
||||
db_session.execute(
|
||||
update(VoiceProvider)
|
||||
.where(
|
||||
VoiceProvider.is_default_stt.is_(True),
|
||||
VoiceProvider.id != provider_id,
|
||||
)
|
||||
.values(is_default_stt=False)
|
||||
)
|
||||
|
||||
# Activate this provider
|
||||
provider.is_default_stt = True
|
||||
|
||||
db_session.flush()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
def set_default_tts_provider(
|
||||
*, db_session: Session, provider_id: int, tts_model: str | None = None
|
||||
) -> VoiceProvider:
|
||||
"""Set a voice provider as the default TTS provider."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise ValueError(f"No voice provider with id {provider_id} exists.")
|
||||
|
||||
# Deactivate all other TTS providers
|
||||
db_session.execute(
|
||||
update(VoiceProvider)
|
||||
.where(
|
||||
VoiceProvider.is_default_tts.is_(True),
|
||||
VoiceProvider.id != provider_id,
|
||||
)
|
||||
.values(is_default_tts=False)
|
||||
)
|
||||
|
||||
# Activate this provider
|
||||
provider.is_default_tts = True
|
||||
|
||||
# Update the TTS model if specified
|
||||
if tts_model is not None:
|
||||
provider.tts_model = tts_model
|
||||
|
||||
db_session.flush()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
def deactivate_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
    """Clear the default-STT flag on the given voice provider.

    Raises:
        ValueError: If no voice provider with ``provider_id`` exists.
    """
    target = fetch_voice_provider_by_id(db_session, provider_id)
    if target is None:
        raise ValueError(f"No voice provider with id {provider_id} exists.")

    target.is_default_stt = False

    db_session.flush()
    db_session.refresh(target)
    return target
|
||||
|
||||
|
||||
def deactivate_tts_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
    """Clear the default-TTS flag on the given voice provider.

    Raises:
        ValueError: If no voice provider with ``provider_id`` exists.
    """
    target = fetch_voice_provider_by_id(db_session, provider_id)
    if target is None:
        raise ValueError(f"No voice provider with id {provider_id} exists.")

    target.is_default_tts = False

    db_session.flush()
    db_session.refresh(target)
    return target
|
||||
|
||||
|
||||
# User voice preferences
|
||||
|
||||
|
||||
def update_user_voice_settings(
    db_session: Session,
    user_id: UUID,
    auto_send: bool | None = None,
    auto_playback: bool | None = None,
    playback_speed: float | None = None,
) -> None:
    """Update a user's voice settings.

    For every field, ``None`` means "don't update this field". Playback
    speed is clamped to [MIN_VOICE_PLAYBACK_SPEED, MAX_VOICE_PLAYBACK_SPEED].
    """
    updates: dict[str, bool | float] = {}

    if auto_send is not None:
        updates["voice_auto_send"] = auto_send
    if auto_playback is not None:
        updates["voice_auto_playback"] = auto_playback
    if playback_speed is not None:
        # Clamp into the allowed playback-speed range.
        clamped = min(
            MAX_VOICE_PLAYBACK_SPEED, max(MIN_VOICE_PLAYBACK_SPEED, playback_speed)
        )
        updates["voice_playback_speed"] = clamped

    # Nothing to change — skip the round trip entirely.
    if not updates:
        return

    db_session.execute(update(User).where(User.id == user_id).values(**updates))  # type: ignore[arg-type]
    db_session.flush()
|
||||
@@ -1,103 +0,0 @@
|
||||
# Vector DB Filter Semantics
|
||||
|
||||
How `IndexFilters` fields combine into the final query filter. Applies to both Vespa and OpenSearch.
|
||||
|
||||
## Filter categories
|
||||
|
||||
| Category | Fields | Join logic |
|
||||
|---|---|---|
|
||||
| **Visibility** | `hidden` | Always applied (unless `include_hidden`) |
|
||||
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
|
||||
| **ACL** | `access_control_list` | OR within, AND with rest |
|
||||
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
|
||||
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
|
||||
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
|
||||
|
||||
## How filters combine
|
||||
|
||||
All categories are AND'd together. Within the knowledge scope category, individual filters are OR'd.
|
||||
|
||||
```
|
||||
NOT hidden
|
||||
AND tenant = T -- if multi-tenant
|
||||
AND (acl contains A1 OR acl contains A2)
|
||||
AND (source_type = S1 OR ...) -- if set
|
||||
AND (tag = T1 OR ...) -- if set
|
||||
AND <knowledge scope> -- see below
|
||||
AND time >= cutoff -- if set
|
||||
```
|
||||
|
||||
## Knowledge scope rules
|
||||
|
||||
The knowledge scope filter controls **what knowledge an assistant can access**.
|
||||
|
||||
### No explicit knowledge attached
|
||||
|
||||
When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
|
||||
|
||||
- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
|
||||
- `project_id` and `persona_id` are ignored — they never restrict on their own.
|
||||
|
||||
### One explicit knowledge type
|
||||
|
||||
```
|
||||
-- Only document sets
|
||||
AND (document_sets contains "Engineering" OR document_sets contains "Legal")
|
||||
|
||||
-- Only user files
|
||||
AND (document_id = "uuid-1" OR document_id = "uuid-2")
|
||||
```
|
||||
|
||||
### Multiple explicit knowledge types (OR'd)
|
||||
|
||||
```
|
||||
-- Document sets + user files
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR document_id = "uuid-1"
|
||||
)
|
||||
```
|
||||
|
||||
### Explicit knowledge + overflowing user files
|
||||
|
||||
When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:
|
||||
|
||||
```
|
||||
-- Document sets + persona user files overflowed
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR personas contains 42
|
||||
)
|
||||
|
||||
-- User files + project files overflowed
|
||||
AND (
|
||||
document_id = "uuid-1"
|
||||
OR user_project contains 7
|
||||
)
|
||||
```
|
||||
|
||||
### Only project_id or persona_id (no explicit knowledge)
|
||||
|
||||
No knowledge scope filter. The assistant searches everything.
|
||||
|
||||
```
|
||||
-- Just ACL, no restriction
|
||||
NOT hidden
|
||||
AND (acl contains ...)
|
||||
```
|
||||
|
||||
## Field reference
|
||||
|
||||
| Filter field | Vespa field | Vespa type | Purpose |
|
||||
|---|---|---|---|
|
||||
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
|
||||
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
|
||||
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
|
||||
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
|
||||
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
|
||||
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
|
||||
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
|
||||
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
|
||||
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |
|
||||
| `time_cutoff` | `doc_updated_at` | `long` | Minimum document update timestamp |
|
||||
| `tenant_id` | `tenant_id` | `string` | Tenant isolation (multi-tenant) |
|
||||
@@ -698,6 +698,41 @@ class DocumentQuery:
|
||||
"""
|
||||
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}}
|
||||
|
||||
def _get_assistant_knowledge_filter(
|
||||
attached_doc_ids: list[str] | None,
|
||||
node_ids: list[int] | None,
|
||||
file_ids: list[UUID] | None,
|
||||
document_sets: list[str] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Combined filter for assistant knowledge.
|
||||
|
||||
When an assistant has attached knowledge, search should be scoped to:
|
||||
- Documents explicitly attached (by document ID), OR
|
||||
- Documents under attached hierarchy nodes (by ancestor node IDs), OR
|
||||
- User-uploaded files attached to the assistant, OR
|
||||
- Documents in the assistant's document sets (if any)
|
||||
"""
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
if attached_doc_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_attached_document_id_filter(attached_doc_ids)
|
||||
)
|
||||
if node_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(node_ids)
|
||||
)
|
||||
if file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
return knowledge_filter
|
||||
|
||||
filter_clauses: list[dict[str, Any]] = []
|
||||
|
||||
if not include_hidden:
|
||||
@@ -723,53 +758,41 @@ class DocumentQuery:
|
||||
# document's metadata list.
|
||||
filter_clauses.append(_get_tag_filter(tags))
|
||||
|
||||
# Knowledge scope: explicit knowledge attachments restrict what
|
||||
# an assistant can see. When none are set the assistant
|
||||
# searches everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing
|
||||
# user files findable but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
has_knowledge_scope = (
|
||||
# Check if this is an assistant knowledge search (has any assistant-scoped knowledge)
|
||||
has_assistant_knowledge = (
|
||||
attached_document_ids
|
||||
or hierarchy_node_ids
|
||||
or user_file_ids
|
||||
or document_sets
|
||||
)
|
||||
|
||||
if has_knowledge_scope:
|
||||
knowledge_filter: dict[str, Any] = {
|
||||
"bool": {"should": [], "minimum_should_match": 1}
|
||||
}
|
||||
if attached_document_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_attached_document_id_filter(attached_document_ids)
|
||||
if has_assistant_knowledge:
|
||||
# If assistant has attached knowledge, scope search to that knowledge.
|
||||
# Document sets are included in the OR filter so directly attached
|
||||
# docs are always findable even if not in the document sets.
|
||||
filter_clauses.append(
|
||||
_get_assistant_knowledge_filter(
|
||||
attached_document_ids,
|
||||
hierarchy_node_ids,
|
||||
user_file_ids,
|
||||
document_sets,
|
||||
)
|
||||
if hierarchy_node_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(hierarchy_node_ids)
|
||||
)
|
||||
if user_file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(user_file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
# Additive: widen scope to also cover overflowing user
|
||||
# files, but only when an explicit restriction is already
|
||||
# in effect.
|
||||
if project_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_project_filter(project_id)
|
||||
)
|
||||
if persona_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_persona_filter(persona_id)
|
||||
)
|
||||
filter_clauses.append(knowledge_filter)
|
||||
)
|
||||
elif user_file_ids:
|
||||
# Fallback for non-assistant user file searches (e.g., project searches)
|
||||
# If at least one user file ID is provided, the caller will only
|
||||
# retrieve documents where the document ID is in this input list of
|
||||
# file IDs.
|
||||
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
|
||||
|
||||
if project_id is not None:
|
||||
# If a project ID is provided, the caller will only retrieve
|
||||
# documents where the project ID provided here is present in the
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if persona_id is not None:
|
||||
filter_clauses.append(_get_persona_filter(persona_id))
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
|
||||
@@ -23,8 +23,11 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_tenant_id_filter(tenant_id: str) -> str:
|
||||
return f'({TENANT_ID} contains "{tenant_id}")'
|
||||
def build_tenant_id_filter(tenant_id: str, include_trailing_and: bool = False) -> str:
|
||||
filter_str = f'({TENANT_ID} contains "{tenant_id}")'
|
||||
if include_trailing_and:
|
||||
filter_str += " and "
|
||||
return filter_str
|
||||
|
||||
|
||||
def build_vespa_filters(
|
||||
@@ -34,22 +37,30 @@ def build_vespa_filters(
|
||||
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
|
||||
) -> str:
|
||||
def _build_or_filters(key: str, vals: list[str] | None) -> str:
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields.
|
||||
Returns a bare clause like '(key contains "v1" or key contains "v2")' or ""."""
|
||||
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields."""
|
||||
if not key or not vals:
|
||||
return ""
|
||||
eq_elems = [f'{key} contains "{val}"' for val in vals if val]
|
||||
if not eq_elems:
|
||||
return ""
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
or_clause = " or ".join(eq_elems)
|
||||
return f"({or_clause}) and "
|
||||
|
||||
def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
|
||||
"""For an integer field filter.
|
||||
Returns a bare clause or ""."""
|
||||
"""
|
||||
For an integer field filter.
|
||||
If vals is not None, we want *only* docs whose key matches one of vals.
|
||||
"""
|
||||
# If `vals` is None => skip the filter entirely
|
||||
if vals is None or not vals:
|
||||
return ""
|
||||
|
||||
# Otherwise build the OR filter
|
||||
eq_elems = [f"{key} = {val}" for val in vals]
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
or_clause = " or ".join(eq_elems)
|
||||
result = f"({or_clause}) and "
|
||||
|
||||
return result
|
||||
|
||||
def _build_kg_filter(
|
||||
kg_entities: list[str] | None,
|
||||
@@ -62,12 +73,16 @@ def build_vespa_filters(
|
||||
combined_filter_parts = []
|
||||
|
||||
def _build_kge(entity: str) -> str:
|
||||
# TYPE-SUBTYPE::ID -> "TYPE-SUBTYPE::ID"
|
||||
# TYPE-SUBTYPE::* -> ({prefix: true}"TYPE-SUBTYPE")
|
||||
# TYPE::* -> ({prefix: true}"TYPE")
|
||||
GENERAL = "::*"
|
||||
if entity.endswith(GENERAL):
|
||||
return f'({{prefix: true}}"{entity.split(GENERAL, 1)[0]}")'
|
||||
else:
|
||||
return f'"{entity}"'
|
||||
|
||||
# OR the entities (give new design)
|
||||
if kg_entities:
|
||||
filter_parts = []
|
||||
for kg_entity in kg_entities:
|
||||
@@ -89,7 +104,8 @@ def build_vespa_filters(
|
||||
|
||||
# TODO: remove kg terms entirely from prompts and codebase
|
||||
|
||||
return f"({' and '.join(combined_filter_parts)})"
|
||||
# AND the combined filter parts
|
||||
return f"({' and '.join(combined_filter_parts)}) and "
|
||||
|
||||
def _build_kg_source_filters(
|
||||
kg_sources: list[str] | None,
|
||||
@@ -98,14 +114,16 @@ def build_vespa_filters(
|
||||
return ""
|
||||
|
||||
source_phrases = [f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources]
|
||||
return f"({' or '.join(source_phrases)})"
|
||||
|
||||
return f"({' or '.join(source_phrases)}) and "
|
||||
|
||||
def _build_kg_chunk_id_zero_only_filter(
|
||||
kg_chunk_id_zero_only: bool,
|
||||
) -> str:
|
||||
if not kg_chunk_id_zero_only:
|
||||
return ""
|
||||
return "(chunk_id = 0)"
|
||||
|
||||
return "(chunk_id = 0 ) and "
|
||||
|
||||
def _build_time_filter(
|
||||
cutoff: datetime | None,
|
||||
@@ -117,8 +135,8 @@ def build_vespa_filters(
|
||||
cutoff_secs = int(cutoff.timestamp())
|
||||
|
||||
if include_untimed:
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs})"
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs})"
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
|
||||
|
||||
def _build_user_project_filter(
|
||||
project_id: int | None,
|
||||
@@ -129,7 +147,8 @@ def build_vespa_filters(
|
||||
pid = int(project_id)
|
||||
except Exception:
|
||||
return ""
|
||||
return f'({USER_PROJECT} contains "{pid}")'
|
||||
# Vespa YQL 'contains' expects a string literal; quote the integer
|
||||
return f'({USER_PROJECT} contains "{pid}") and '
|
||||
|
||||
def _build_persona_filter(
|
||||
persona_id: int | None,
|
||||
@@ -141,94 +160,73 @@ def build_vespa_filters(
|
||||
except Exception:
|
||||
logger.warning(f"Invalid persona ID: {persona_id}")
|
||||
return ""
|
||||
return f'({PERSONAS} contains "{pid}")'
|
||||
return f'({PERSONAS} contains "{pid}") and '
|
||||
|
||||
def _append(parts: list[str], clause: str) -> None:
|
||||
if clause:
|
||||
parts.append(clause)
|
||||
|
||||
# Collect all top-level filter clauses, then join with " and " at the end.
|
||||
filter_parts: list[str] = []
|
||||
|
||||
if not include_hidden:
|
||||
filter_parts.append(f"!({HIDDEN}=true)")
|
||||
# Start building the filter string
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
|
||||
# TODO: add error condition if MULTI_TENANT and no tenant_id filter is set
|
||||
# If running in multi-tenant mode
|
||||
if filters.tenant_id and MULTI_TENANT:
|
||||
filter_parts.append(build_tenant_id_filter(filters.tenant_id))
|
||||
filter_str += build_tenant_id_filter(
|
||||
filters.tenant_id, include_trailing_and=True
|
||||
)
|
||||
|
||||
# ACL filters
|
||||
if filters.access_control_list is not None:
|
||||
_append(
|
||||
filter_parts,
|
||||
_build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list),
|
||||
filter_str += _build_or_filters(
|
||||
ACCESS_CONTROL_LIST, filters.access_control_list
|
||||
)
|
||||
|
||||
# Source type filters
|
||||
source_strs = (
|
||||
[s.value for s in filters.source_type] if filters.source_type else None
|
||||
)
|
||||
_append(filter_parts, _build_or_filters(SOURCE_TYPE, source_strs))
|
||||
filter_str += _build_or_filters(SOURCE_TYPE, source_strs)
|
||||
|
||||
# Tag filters
|
||||
tag_attributes = None
|
||||
if filters.tags:
|
||||
# build e.g. "tag_key|tag_value"
|
||||
tag_attributes = [
|
||||
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in filters.tags
|
||||
]
|
||||
_append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
|
||||
filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
|
||||
|
||||
# Knowledge scope: explicit knowledge attachments (document_sets,
|
||||
# user_file_ids) restrict what an assistant can see. When none are
|
||||
# set, the assistant can see everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing user
|
||||
# files findable in Vespa but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
knowledge_scope_parts: list[str] = []
|
||||
|
||||
_append(
|
||||
knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
)
|
||||
# Document sets
|
||||
filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
|
||||
# Convert UUIDs to strings for user_file_ids
|
||||
user_file_ids_str = (
|
||||
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
|
||||
)
|
||||
_append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
|
||||
filter_str += _build_or_filters(DOCUMENT_ID, user_file_ids_str)
|
||||
|
||||
# Only include project/persona scopes when an explicit knowledge
|
||||
# restriction is already in effect — they widen the scope to also
|
||||
# cover overflowing user files but never restrict on their own.
|
||||
if knowledge_scope_parts:
|
||||
_append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
|
||||
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
|
||||
# User project filter (array<int> attribute membership)
|
||||
filter_str += _build_user_project_filter(filters.project_id)
|
||||
|
||||
if len(knowledge_scope_parts) > 1:
|
||||
filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
|
||||
elif len(knowledge_scope_parts) == 1:
|
||||
filter_parts.append(knowledge_scope_parts[0])
|
||||
# Persona filter (array<int> attribute membership)
|
||||
filter_str += _build_persona_filter(filters.persona_id)
|
||||
|
||||
# Time filter
|
||||
_append(filter_parts, _build_time_filter(filters.time_cutoff))
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
|
||||
# # Knowledge Graph Filters
|
||||
# _append(filter_parts, _build_kg_filter(
|
||||
# filter_str += _build_kg_filter(
|
||||
# kg_entities=filters.kg_entities,
|
||||
# kg_relationships=filters.kg_relationships,
|
||||
# kg_terms=filters.kg_terms,
|
||||
# ))
|
||||
# )
|
||||
|
||||
# _append(filter_parts, _build_kg_source_filters(filters.kg_sources))
|
||||
# filter_str += _build_kg_source_filters(filters.kg_sources)
|
||||
|
||||
# _append(filter_parts, _build_kg_chunk_id_zero_only_filter(
|
||||
# filter_str += _build_kg_chunk_id_zero_only_filter(
|
||||
# filters.kg_chunk_id_zero_only or False
|
||||
# ))
|
||||
# )
|
||||
|
||||
filter_str = " and ".join(filter_parts)
|
||||
|
||||
if filter_str and not remove_trailing_and:
|
||||
filter_str += " and "
|
||||
# Trim trailing " and "
|
||||
if remove_trailing_and and filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5]
|
||||
|
||||
return filter_str
|
||||
|
||||
|
||||
@@ -66,11 +66,6 @@ class OnyxErrorCode(Enum):
|
||||
RATE_LIMITED = ("RATE_LIMITED", 429)
|
||||
SEAT_LIMIT_EXCEEDED = ("SEAT_LIMIT_EXCEEDED", 402)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Payload (413)
|
||||
# ------------------------------------------------------------------
|
||||
PAYLOAD_TOO_LARGE = ("PAYLOAD_TOO_LARGE", 413)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connector / Credential Errors (400-range)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -119,9 +119,6 @@ from onyx.server.manage.opensearch_migration.api import (
|
||||
from onyx.server.manage.search_settings import router as search_settings_router
|
||||
from onyx.server.manage.slack_bot import router as slack_bot_management_router
|
||||
from onyx.server.manage.users import router as user_router
|
||||
from onyx.server.manage.voice.api import admin_router as voice_admin_router
|
||||
from onyx.server.manage.voice.user_api import router as voice_router
|
||||
from onyx.server.manage.voice.websocket_api import router as voice_websocket_router
|
||||
from onyx.server.manage.web_search.api import (
|
||||
admin_router as web_search_admin_router,
|
||||
)
|
||||
@@ -500,9 +497,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, embedding_router)
|
||||
include_router_with_global_prefix_prepended(application, web_search_router)
|
||||
include_router_with_global_prefix_prepended(application, web_search_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_websocket_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, opensearch_migration_admin_router
|
||||
)
|
||||
|
||||
@@ -10,7 +10,6 @@ from onyx.mcp_server.utils import get_indexed_sources
|
||||
from onyx.mcp_server.utils import require_access_token
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -27,14 +26,6 @@ async def search_indexed_documents(
|
||||
Use this tool for information that is not public knowledge and specific to the user,
|
||||
their team, their work, or their organization/company.
|
||||
|
||||
Note: In CE mode, this tool uses the chat endpoint internally which invokes an LLM
|
||||
on every call, consuming tokens and adding latency.
|
||||
Additionally, CE callers receive a truncated snippet (blurb) instead of a full document chunk,
|
||||
but this should still be sufficient for most use cases. CE mode functionality should be swapped
|
||||
when a dedicated CE search endpoint is implemented.
|
||||
|
||||
In EE mode, the dedicated search endpoint is used instead.
|
||||
|
||||
To find a list of available sources, use the `indexed_sources` resource.
|
||||
Returns chunks of text as search results with snippets, scores, and metadata.
|
||||
|
||||
@@ -120,72 +111,47 @@ async def search_indexed_documents(
|
||||
if time_cutoff_dt:
|
||||
filters["time_cutoff"] = time_cutoff_dt.isoformat()
|
||||
|
||||
is_ee = global_version.is_ee_version()
|
||||
base_url = build_api_server_url_for_http_requests(respect_env_override_if_set=True)
|
||||
auth_headers = {"Authorization": f"Bearer {access_token.token}"}
|
||||
|
||||
search_request: dict[str, Any]
|
||||
if is_ee:
|
||||
# EE: use the dedicated search endpoint (no LLM invocation)
|
||||
search_request = {
|
||||
"search_query": query,
|
||||
"filters": filters,
|
||||
"num_docs_fed_to_llm_selection": limit,
|
||||
"run_query_expansion": False,
|
||||
"include_content": True,
|
||||
"stream": False,
|
||||
}
|
||||
endpoint = f"{base_url}/search/send-search-message"
|
||||
error_key = "error"
|
||||
docs_key = "search_docs"
|
||||
content_field = "content"
|
||||
else:
|
||||
# CE: fall back to the chat endpoint (invokes LLM, consumes tokens)
|
||||
search_request = {
|
||||
"message": query,
|
||||
"stream": False,
|
||||
"chat_session_info": {},
|
||||
}
|
||||
if filters:
|
||||
search_request["internal_search_filters"] = filters
|
||||
endpoint = f"{base_url}/chat/send-chat-message"
|
||||
error_key = "error_msg"
|
||||
docs_key = "top_documents"
|
||||
content_field = "blurb"
|
||||
# Build the search request using the new SendSearchQueryRequest format
|
||||
search_request = {
|
||||
"search_query": query,
|
||||
"filters": filters,
|
||||
"num_docs_fed_to_llm_selection": limit,
|
||||
"run_query_expansion": False,
|
||||
"include_content": True,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
# Call the API server using the new send-search-message route
|
||||
try:
|
||||
response = await get_http_client().post(
|
||||
endpoint,
|
||||
f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/search/send-search-message",
|
||||
json=search_request,
|
||||
headers=auth_headers,
|
||||
headers={"Authorization": f"Bearer {access_token.token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
# Check for error in response
|
||||
if result.get(error_key):
|
||||
if result.get("error"):
|
||||
return {
|
||||
"documents": [],
|
||||
"total_results": 0,
|
||||
"query": query,
|
||||
"error": result.get(error_key),
|
||||
"error": result.get("error"),
|
||||
}
|
||||
|
||||
documents = [
|
||||
{
|
||||
"semantic_identifier": doc.get("semantic_identifier"),
|
||||
"content": doc.get(content_field),
|
||||
"source_type": doc.get("source_type"),
|
||||
"link": doc.get("link"),
|
||||
"score": doc.get("score"),
|
||||
}
|
||||
for doc in result.get(docs_key, [])
|
||||
# Return simplified format for MCP clients
|
||||
fields_to_return = [
|
||||
"semantic_identifier",
|
||||
"content",
|
||||
"source_type",
|
||||
"link",
|
||||
"score",
|
||||
]
|
||||
documents = [
|
||||
{key: doc.get(key) for key in fields_to_return}
|
||||
for doc in result.get("search_docs", [])
|
||||
]
|
||||
|
||||
# NOTE: search depth is controlled by the backend persona defaults, not `limit`.
|
||||
# `limit` only caps the returned list; fewer results may be returned if the
|
||||
# backend retrieves fewer documents than requested.
|
||||
documents = documents[:limit]
|
||||
|
||||
logger.info(
|
||||
f"Onyx MCP Server: Internal search returned {len(documents)} results"
|
||||
@@ -194,6 +160,7 @@ async def search_indexed_documents(
|
||||
"documents": documents,
|
||||
"total_results": len(documents),
|
||||
"query": query,
|
||||
"executed_queries": result.get("all_executed_queries", [query]),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Onyx MCP Server: Document search error: {e}", exc_info=True)
|
||||
|
||||
@@ -419,15 +419,12 @@ async def get_async_redis_connection() -> aioredis.Redis:
|
||||
return _async_redis_connection
|
||||
|
||||
|
||||
async def retrieve_auth_token_data(token: str) -> dict | None:
|
||||
"""Validate auth token against Redis and return token data.
|
||||
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
|
||||
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
if not token:
|
||||
logger.debug("No auth token cookie found")
|
||||
return None
|
||||
|
||||
Args:
|
||||
token: The raw authentication token string.
|
||||
|
||||
Returns:
|
||||
Token data dict if valid, None if invalid/expired.
|
||||
"""
|
||||
try:
|
||||
redis = await get_async_redis_connection()
|
||||
redis_key = REDIS_AUTH_KEY_PREFIX + token
|
||||
@@ -442,96 +439,12 @@ async def retrieve_auth_token_data(token: str) -> dict | None:
|
||||
logger.error("Error decoding token data from Redis")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
|
||||
raise ValueError(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
|
||||
|
||||
|
||||
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
|
||||
"""Validate auth token from request cookie. Wrapper for backwards compatibility."""
|
||||
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
if not token:
|
||||
logger.debug("No auth token cookie found")
|
||||
return None
|
||||
return await retrieve_auth_token_data(token)
|
||||
|
||||
|
||||
# WebSocket token prefix (separate from regular auth tokens)
|
||||
REDIS_WS_TOKEN_PREFIX = "ws_token:"
|
||||
# WebSocket tokens expire after 60 seconds
|
||||
WS_TOKEN_TTL_SECONDS = 60
|
||||
# Rate limit: max tokens per user per window
|
||||
WS_TOKEN_RATE_LIMIT_MAX = 10
|
||||
WS_TOKEN_RATE_LIMIT_WINDOW_SECONDS = 60
|
||||
REDIS_WS_TOKEN_RATE_LIMIT_PREFIX = "ws_token_rate:"
|
||||
|
||||
|
||||
class WsTokenRateLimitExceeded(Exception):
|
||||
"""Raised when a user exceeds the WS token generation rate limit."""
|
||||
|
||||
|
||||
async def store_ws_token(token: str, user_id: str) -> None:
|
||||
"""Store a short-lived WebSocket authentication token in Redis.
|
||||
|
||||
Args:
|
||||
token: The generated WS token.
|
||||
user_id: The user ID to associate with this token.
|
||||
|
||||
Raises:
|
||||
WsTokenRateLimitExceeded: If the user has exceeded the rate limit.
|
||||
"""
|
||||
redis = await get_async_redis_connection()
|
||||
|
||||
# Check rate limit: count tokens generated by this user in the current window
|
||||
rate_limit_key = REDIS_WS_TOKEN_RATE_LIMIT_PREFIX + user_id
|
||||
current_count = await redis.get(rate_limit_key)
|
||||
|
||||
if current_count and int(current_count) >= WS_TOKEN_RATE_LIMIT_MAX:
|
||||
logger.warning(f"WS token rate limit exceeded for user {user_id}")
|
||||
raise WsTokenRateLimitExceeded(
|
||||
f"Rate limit exceeded. Maximum {WS_TOKEN_RATE_LIMIT_MAX} tokens per minute."
|
||||
logger.error(
|
||||
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
|
||||
)
|
||||
|
||||
# Increment the counter (or create with TTL if new)
|
||||
pipe = redis.pipeline()
|
||||
pipe.incr(rate_limit_key)
|
||||
pipe.expire(rate_limit_key, WS_TOKEN_RATE_LIMIT_WINDOW_SECONDS)
|
||||
await pipe.execute()
|
||||
|
||||
# Store the actual token
|
||||
redis_key = REDIS_WS_TOKEN_PREFIX + token
|
||||
token_data = json.dumps({"sub": user_id})
|
||||
await redis.set(redis_key, token_data, ex=WS_TOKEN_TTL_SECONDS)
|
||||
|
||||
|
||||
async def retrieve_ws_token_data(token: str) -> dict | None:
|
||||
"""Validate a WebSocket token and return the token data.
|
||||
|
||||
This uses GETDEL for atomic get-and-delete to prevent race conditions
|
||||
where the same token could be used twice.
|
||||
|
||||
Args:
|
||||
token: The WS token to validate.
|
||||
|
||||
Returns:
|
||||
Token data dict with 'sub' (user ID) if valid, None if invalid/expired.
|
||||
"""
|
||||
try:
|
||||
redis = await get_async_redis_connection()
|
||||
redis_key = REDIS_WS_TOKEN_PREFIX + token
|
||||
|
||||
# Atomic get-and-delete to prevent race conditions (Redis 6.2+)
|
||||
token_data_str = await redis.getdel(redis_key)
|
||||
|
||||
if not token_data_str:
|
||||
return None
|
||||
|
||||
return json.loads(token_data_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Error decoding WS token data from Redis")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in retrieve_ws_token_data: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def redis_lock_dump(lock: RedisLock, r: Redis) -> None:
|
||||
|
||||
@@ -9,7 +9,6 @@ from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import current_user_from_websocket
|
||||
from onyx.auth.users import current_user_with_expired_token
|
||||
from onyx.configs.app_configs import APP_API_PREFIX
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
@@ -130,7 +129,6 @@ def check_router_auth(
|
||||
or depends_fn == current_curator_or_admin_user
|
||||
or depends_fn == current_user_with_expired_token
|
||||
or depends_fn == current_chat_accessible_user
|
||||
or depends_fn == current_user_from_websocket
|
||||
or depends_fn == control_plane_dep
|
||||
or depends_fn == current_cloud_superuser
|
||||
or depends_fn == verify_scim_token
|
||||
|
||||
@@ -10,8 +10,6 @@ from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
@@ -37,38 +35,6 @@ def get_safe_filename(upload: UploadFile) -> str:
|
||||
return upload.filename
|
||||
|
||||
|
||||
def get_upload_size_bytes(upload: UploadFile) -> int | None:
|
||||
"""Best-effort file size in bytes without consuming the stream."""
|
||||
if upload.size is not None:
|
||||
return upload.size
|
||||
|
||||
try:
|
||||
current_pos = upload.file.tell()
|
||||
upload.file.seek(0, 2)
|
||||
size = upload.file.tell()
|
||||
upload.file.seek(current_pos)
|
||||
return size
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Could not determine upload size via stream seek "
|
||||
f"(filename='{get_safe_filename(upload)}', "
|
||||
f"error_type={type(e).__name__}, error={e})"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def is_upload_too_large(upload: UploadFile, max_bytes: int) -> bool:
|
||||
"""Return True when upload size is known and exceeds max_bytes."""
|
||||
size_bytes = get_upload_size_bytes(upload)
|
||||
if size_bytes is None:
|
||||
logger.warning(
|
||||
"Could not determine upload size; skipping size-limit check for "
|
||||
f"'{get_safe_filename(upload)}'"
|
||||
)
|
||||
return False
|
||||
return size_bytes > max_bytes
|
||||
|
||||
|
||||
# Guard against extremely large images
|
||||
Image.MAX_IMAGE_PIXELS = 12000 * 12000
|
||||
|
||||
@@ -193,18 +159,6 @@ def categorize_uploaded_files(
|
||||
for upload in files:
|
||||
try:
|
||||
filename = get_safe_filename(upload)
|
||||
|
||||
# Size limit is a hard safety cap and is enforced even when token
|
||||
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
|
||||
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
extension = get_file_ext(filename)
|
||||
|
||||
# If image, estimate tokens via dedicated method first
|
||||
|
||||
@@ -85,11 +85,6 @@ class UserPreferences(BaseModel):
|
||||
chat_background: str | None = None
|
||||
default_app_mode: DefaultAppMode = DefaultAppMode.CHAT
|
||||
|
||||
# Voice preferences
|
||||
voice_auto_send: bool | None = None
|
||||
voice_auto_playback: bool | None = None
|
||||
voice_playback_speed: float | None = None
|
||||
|
||||
# controls which tools are enabled for the user for a specific assistant
|
||||
assistant_specific_configs: UserSpecificAssistantPreferences | None = None
|
||||
|
||||
@@ -169,9 +164,6 @@ class UserInfo(BaseModel):
|
||||
theme_preference=user.theme_preference,
|
||||
chat_background=user.chat_background,
|
||||
default_app_mode=user.default_app_mode,
|
||||
voice_auto_send=user.voice_auto_send,
|
||||
voice_auto_playback=user.voice_auto_playback,
|
||||
voice_playback_speed=user.voice_playback_speed,
|
||||
assistant_specific_configs=assistant_specific_configs,
|
||||
)
|
||||
),
|
||||
@@ -248,12 +240,6 @@ class ChatBackgroundRequest(BaseModel):
|
||||
chat_background: str | None
|
||||
|
||||
|
||||
class VoiceSettingsUpdateRequest(BaseModel):
|
||||
auto_send: bool | None = None
|
||||
auto_playback: bool | None = None
|
||||
playback_speed: float | None = Field(default=None, ge=0.5, le=2.0)
|
||||
|
||||
|
||||
class PersonalizationUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
role: str | None = None
|
||||
|
||||
@@ -1,308 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.db.voice import deactivate_stt_provider
|
||||
from onyx.db.voice import deactivate_tts_provider
|
||||
from onyx.db.voice import delete_voice_provider
|
||||
from onyx.db.voice import fetch_voice_provider_by_id
|
||||
from onyx.db.voice import fetch_voice_provider_by_type
|
||||
from onyx.db.voice import fetch_voice_providers
|
||||
from onyx.db.voice import set_default_stt_provider
|
||||
from onyx.db.voice import set_default_tts_provider
|
||||
from onyx.db.voice import upsert_voice_provider
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.voice.models import VoiceOption
|
||||
from onyx.server.manage.voice.models import VoiceProviderTestRequest
|
||||
from onyx.server.manage.voice.models import VoiceProviderUpdateSuccess
|
||||
from onyx.server.manage.voice.models import VoiceProviderUpsertRequest
|
||||
from onyx.server.manage.voice.models import VoiceProviderView
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.url import SSRFException
|
||||
from onyx.utils.url import validate_outbound_http_url
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/voice")
|
||||
|
||||
|
||||
def _validate_voice_api_base(provider_type: str, api_base: str | None) -> str | None:
|
||||
"""Validate and normalize provider api_base / target URI."""
|
||||
if api_base is None:
|
||||
return None
|
||||
|
||||
allow_private_network = provider_type.lower() == "azure"
|
||||
try:
|
||||
return validate_outbound_http_url(
|
||||
api_base, allow_private_network=allow_private_network
|
||||
)
|
||||
except (ValueError, SSRFException) as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid target URI: {str(e)}",
|
||||
) from e
|
||||
|
||||
|
||||
def _provider_to_view(provider: VoiceProvider) -> VoiceProviderView:
|
||||
"""Convert a VoiceProvider model to a VoiceProviderView."""
|
||||
return VoiceProviderView(
|
||||
id=provider.id,
|
||||
name=provider.name,
|
||||
provider_type=provider.provider_type,
|
||||
is_default_stt=provider.is_default_stt,
|
||||
is_default_tts=provider.is_default_tts,
|
||||
stt_model=provider.stt_model,
|
||||
tts_model=provider.tts_model,
|
||||
default_voice=provider.default_voice,
|
||||
has_api_key=bool(provider.api_key),
|
||||
target_uri=provider.api_base, # api_base stores the target URI for Azure
|
||||
)
|
||||
|
||||
|
||||
@admin_router.get("/providers")
|
||||
def list_voice_providers(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[VoiceProviderView]:
|
||||
"""List all configured voice providers."""
|
||||
providers = fetch_voice_providers(db_session)
|
||||
return [_provider_to_view(provider) for provider in providers]
|
||||
|
||||
|
||||
@admin_router.post("/providers")
|
||||
def upsert_voice_provider_endpoint(
|
||||
request: VoiceProviderUpsertRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderView:
|
||||
"""Create or update a voice provider."""
|
||||
api_key = request.api_key
|
||||
api_key_changed = request.api_key_changed
|
||||
|
||||
# If llm_provider_id is specified, copy the API key from that LLM provider
|
||||
if request.llm_provider_id is not None:
|
||||
llm_provider = db_session.get(LLMProviderModel, request.llm_provider_id)
|
||||
if llm_provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"LLM provider with id {request.llm_provider_id} not found.",
|
||||
)
|
||||
if llm_provider.api_key is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Selected LLM provider has no API key configured.",
|
||||
)
|
||||
api_key = llm_provider.api_key.get_value(apply_mask=False)
|
||||
api_key_changed = True
|
||||
|
||||
# Use target_uri if provided, otherwise fall back to api_base
|
||||
api_base = _validate_voice_api_base(
|
||||
request.provider_type, request.target_uri or request.api_base
|
||||
)
|
||||
|
||||
provider = upsert_voice_provider(
|
||||
db_session=db_session,
|
||||
provider_id=request.id,
|
||||
name=request.name,
|
||||
provider_type=request.provider_type,
|
||||
api_key=api_key,
|
||||
api_key_changed=api_key_changed,
|
||||
api_base=api_base,
|
||||
custom_config=request.custom_config,
|
||||
stt_model=request.stt_model,
|
||||
tts_model=request.tts_model,
|
||||
default_voice=request.default_voice,
|
||||
activate_stt=request.activate_stt,
|
||||
activate_tts=request.activate_tts,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return _provider_to_view(provider)
|
||||
|
||||
|
||||
@admin_router.delete(
|
||||
"/providers/{provider_id}", status_code=204, response_class=Response
|
||||
)
|
||||
def delete_voice_provider_endpoint(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Delete a voice provider."""
|
||||
delete_voice_provider(db_session, provider_id)
|
||||
db_session.commit()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/activate-stt")
|
||||
def activate_stt_provider_endpoint(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderView:
|
||||
"""Set a voice provider as the default STT provider."""
|
||||
provider = set_default_stt_provider(db_session=db_session, provider_id=provider_id)
|
||||
db_session.commit()
|
||||
return _provider_to_view(provider)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/deactivate-stt")
|
||||
def deactivate_stt_provider_endpoint(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderUpdateSuccess:
|
||||
"""Remove the default STT status from a voice provider."""
|
||||
deactivate_stt_provider(db_session=db_session, provider_id=provider_id)
|
||||
db_session.commit()
|
||||
return VoiceProviderUpdateSuccess()
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/activate-tts")
|
||||
def activate_tts_provider_endpoint(
|
||||
provider_id: int,
|
||||
tts_model: str | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderView:
|
||||
"""Set a voice provider as the default TTS provider."""
|
||||
provider = set_default_tts_provider(
|
||||
db_session=db_session, provider_id=provider_id, tts_model=tts_model
|
||||
)
|
||||
db_session.commit()
|
||||
return _provider_to_view(provider)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/deactivate-tts")
|
||||
def deactivate_tts_provider_endpoint(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderUpdateSuccess:
|
||||
"""Remove the default TTS status from a voice provider."""
|
||||
deactivate_tts_provider(db_session=db_session, provider_id=provider_id)
|
||||
db_session.commit()
|
||||
return VoiceProviderUpdateSuccess()
|
||||
|
||||
|
||||
@admin_router.post("/providers/test")
|
||||
def test_voice_provider(
|
||||
request: VoiceProviderTestRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderUpdateSuccess:
|
||||
"""Test a voice provider connection."""
|
||||
api_key = request.api_key
|
||||
|
||||
if request.use_stored_key:
|
||||
existing_provider = fetch_voice_provider_by_type(
|
||||
db_session, request.provider_type
|
||||
)
|
||||
if existing_provider is None or not existing_provider.api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No stored API key found for this provider type.",
|
||||
)
|
||||
api_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
|
||||
if not api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API key is required. Either provide api_key or set use_stored_key to true.",
|
||||
)
|
||||
|
||||
# Use target_uri if provided, otherwise fall back to api_base
|
||||
api_base = _validate_voice_api_base(
|
||||
request.provider_type, request.target_uri or request.api_base
|
||||
)
|
||||
|
||||
# Create a temporary VoiceProvider for testing (not saved to DB)
|
||||
temp_provider = VoiceProvider(
|
||||
name="__test__",
|
||||
provider_type=request.provider_type,
|
||||
api_base=api_base,
|
||||
custom_config=request.custom_config or {},
|
||||
)
|
||||
temp_provider.api_key = api_key # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(temp_provider)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
|
||||
|
||||
# Test the provider by getting available voices (lightweight check)
|
||||
try:
|
||||
voices = provider.get_available_voices()
|
||||
if not voices:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Provider returned no available voices.",
|
||||
)
|
||||
except OnyxError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Voice provider connection test failed: {e}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Connection test failed. Please verify your API key and settings.",
|
||||
) from e
|
||||
|
||||
logger.info(f"Voice provider test succeeded for {request.provider_type}.")
|
||||
return VoiceProviderUpdateSuccess()
|
||||
|
||||
|
||||
@admin_router.get("/providers/{provider_id}/voices")
|
||||
def get_provider_voices(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[VoiceOption]:
|
||||
"""Get available voices for a provider."""
|
||||
provider_db = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider_db is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Voice provider not found.")
|
||||
|
||||
if not provider_db.api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR, "Provider has no API key configured."
|
||||
)
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
|
||||
|
||||
return [VoiceOption(**voice) for voice in provider.get_available_voices()]
|
||||
|
||||
|
||||
@admin_router.get("/voices")
|
||||
def get_voices_by_type(
|
||||
provider_type: str,
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> list[VoiceOption]:
|
||||
"""Get available voices for a provider type.
|
||||
|
||||
For providers like ElevenLabs and OpenAI, this fetches voices
|
||||
without requiring an existing provider configuration.
|
||||
"""
|
||||
# Create a temporary VoiceProvider to get static voice list
|
||||
temp_provider = VoiceProvider(
|
||||
name="__temp__",
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(temp_provider)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
|
||||
|
||||
return [VoiceOption(**voice) for voice in provider.get_available_voices()]
|
||||
@@ -1,95 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class VoiceProviderView(BaseModel):
|
||||
"""Response model for voice provider listing."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
provider_type: str # "openai", "azure", "elevenlabs"
|
||||
is_default_stt: bool
|
||||
is_default_tts: bool
|
||||
stt_model: str | None
|
||||
tts_model: str | None
|
||||
default_voice: str | None
|
||||
has_api_key: bool = Field(
|
||||
default=False,
|
||||
description="Indicates whether an API key is stored for this provider.",
|
||||
)
|
||||
target_uri: str | None = Field(
|
||||
default=None,
|
||||
description="Target URI for Azure Speech Services.",
|
||||
)
|
||||
|
||||
|
||||
class VoiceProviderUpdateSuccess(BaseModel):
|
||||
"""Simple status response for voice provider actions."""
|
||||
|
||||
status: str = "ok"
|
||||
|
||||
|
||||
class VoiceOption(BaseModel):
|
||||
"""Voice option returned by voice providers."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class VoiceProviderUpsertRequest(BaseModel):
|
||||
"""Request model for creating or updating a voice provider."""
|
||||
|
||||
id: int | None = Field(default=None, description="Existing provider ID to update.")
|
||||
name: str
|
||||
provider_type: str # "openai", "azure", "elevenlabs"
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for the provider.",
|
||||
)
|
||||
api_key_changed: bool = Field(
|
||||
default=False,
|
||||
description="Set to true when providing a new API key for an existing provider.",
|
||||
)
|
||||
llm_provider_id: int | None = Field(
|
||||
default=None,
|
||||
description="If set, copies the API key from the specified LLM provider.",
|
||||
)
|
||||
api_base: str | None = None
|
||||
target_uri: str | None = Field(
|
||||
default=None,
|
||||
description="Target URI for Azure Speech Services (maps to api_base).",
|
||||
)
|
||||
custom_config: dict[str, Any] | None = None
|
||||
stt_model: str | None = None
|
||||
tts_model: str | None = None
|
||||
default_voice: str | None = None
|
||||
activate_stt: bool = Field(
|
||||
default=False,
|
||||
description="If true, sets this provider as the default STT provider after upsert.",
|
||||
)
|
||||
activate_tts: bool = Field(
|
||||
default=False,
|
||||
description="If true, sets this provider as the default TTS provider after upsert.",
|
||||
)
|
||||
|
||||
|
||||
class VoiceProviderTestRequest(BaseModel):
|
||||
"""Request model for testing a voice provider connection."""
|
||||
|
||||
provider_type: str
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for testing. If not provided, use_stored_key must be true.",
|
||||
)
|
||||
use_stored_key: bool = Field(
|
||||
default=False,
|
||||
description="If true, use the stored API key for this provider type.",
|
||||
)
|
||||
api_base: str | None = None
|
||||
target_uri: str | None = Field(
|
||||
default=None,
|
||||
description="Target URI for Azure Speech Services (maps to api_base).",
|
||||
)
|
||||
custom_config: dict[str, Any] | None = None
|
||||
@@ -1,231 +0,0 @@
|
||||
import secrets
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import User
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.db.voice import update_user_voice_settings
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.redis.redis_pool import store_ws_token
|
||||
from onyx.redis.redis_pool import WsTokenRateLimitExceeded
|
||||
from onyx.server.manage.models import VoiceSettingsUpdateRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/voice")
|
||||
|
||||
# Max audio file size: 25MB (Whisper limit)
|
||||
MAX_AUDIO_SIZE = 25 * 1024 * 1024
|
||||
# Chunk size for streaming uploads (8KB)
|
||||
UPLOAD_READ_CHUNK_SIZE = 8192
|
||||
|
||||
|
||||
@router.post("/transcribe")
|
||||
async def transcribe_audio(
|
||||
audio: UploadFile = File(...),
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Transcribe audio to text using the default STT provider."""
|
||||
provider_db = fetch_default_stt_provider(db_session)
|
||||
if provider_db is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No speech-to-text provider configured. Please contact your administrator.",
|
||||
)
|
||||
|
||||
if not provider_db.api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Voice provider API key not configured.",
|
||||
)
|
||||
|
||||
# Read in chunks to enforce size limit during streaming (prevents OOM attacks)
|
||||
chunks: list[bytes] = []
|
||||
total = 0
|
||||
while chunk := await audio.read(UPLOAD_READ_CHUNK_SIZE):
|
||||
total += len(chunk)
|
||||
if total > MAX_AUDIO_SIZE:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.PAYLOAD_TOO_LARGE,
|
||||
f"Audio file too large. Maximum size is {MAX_AUDIO_SIZE // (1024 * 1024)}MB.",
|
||||
)
|
||||
chunks.append(chunk)
|
||||
audio_data = b"".join(chunks)
|
||||
|
||||
# Extract format from filename
|
||||
filename = audio.filename or "audio.webm"
|
||||
audio_format = filename.rsplit(".", 1)[-1] if "." in filename else "webm"
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
|
||||
|
||||
try:
|
||||
text = await provider.transcribe(audio_data, audio_format)
|
||||
return {"text": text}
|
||||
except NotImplementedError as exc:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_IMPLEMENTED,
|
||||
f"Speech-to-text not implemented for {provider_db.provider_type}.",
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
logger.error(f"Transcription failed: {exc}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Transcription failed. Please try again.",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post("/synthesize")
|
||||
async def synthesize_speech(
|
||||
text: str | None = Query(
|
||||
default=None, description="Text to synthesize", max_length=4096
|
||||
),
|
||||
voice: str | None = Query(default=None, description="Voice ID to use"),
|
||||
speed: float | None = Query(
|
||||
default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
|
||||
),
|
||||
user: User = Depends(current_user),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Synthesize text to speech using the default TTS provider.
|
||||
|
||||
Accepts parameters via query string for streaming compatibility.
|
||||
"""
|
||||
logger.info(
|
||||
f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
|
||||
)
|
||||
|
||||
if not text:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Text is required")
|
||||
|
||||
# Use short-lived session to fetch provider config, then release connection
|
||||
# before starting the long-running streaming response
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
provider_db = fetch_default_tts_provider(db_session)
|
||||
if provider_db is None:
|
||||
logger.error("No TTS provider configured")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No text-to-speech provider configured. Please contact your administrator.",
|
||||
)
|
||||
|
||||
if not provider_db.api_key:
|
||||
logger.error("TTS provider has no API key")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Voice provider API key not configured.",
|
||||
)
|
||||
|
||||
# Use request voice or provider default
|
||||
final_voice = voice or provider_db.default_voice
|
||||
# Use explicit None checks to avoid falsy float issues (0.0 would be skipped with `or`)
|
||||
final_speed = (
|
||||
speed
|
||||
if speed is not None
|
||||
else (
|
||||
user.voice_playback_speed
|
||||
if user.voice_playback_speed is not None
|
||||
else 1.0
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"TTS using provider: {provider_db.provider_type}, voice: {final_voice}, speed: {final_speed}"
|
||||
)
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
except ValueError as exc:
|
||||
logger.error(f"Failed to get voice provider: {exc}")
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
|
||||
|
||||
# Session is now closed - streaming response won't hold DB connection
|
||||
async def audio_stream() -> AsyncIterator[bytes]:
|
||||
try:
|
||||
chunk_count = 0
|
||||
async for chunk in provider.synthesize_stream(
|
||||
text=text, voice=final_voice, speed=final_speed
|
||||
):
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
|
||||
except NotImplementedError as exc:
|
||||
logger.error(f"TTS not implemented: {exc}")
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Synthesis failed: {exc}")
|
||||
raise
|
||||
|
||||
return StreamingResponse(
|
||||
audio_stream(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Content-Disposition": "inline; filename=speech.mp3",
|
||||
# Allow streaming by not setting content-length
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/settings")
|
||||
def update_voice_settings(
|
||||
request: VoiceSettingsUpdateRequest,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Update user's voice settings."""
|
||||
update_user_voice_settings(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
auto_send=request.auto_send,
|
||||
auto_playback=request.auto_playback,
|
||||
playback_speed=request.playback_speed,
|
||||
)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
class WSTokenResponse(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
@router.post("/ws-token")
|
||||
async def get_ws_token(
|
||||
user: User = Depends(current_user),
|
||||
) -> WSTokenResponse:
|
||||
"""
|
||||
Generate a short-lived token for WebSocket authentication.
|
||||
|
||||
This token should be passed as a query parameter when connecting
|
||||
to voice WebSocket endpoints (e.g., /voice/transcribe/stream?token=xxx).
|
||||
|
||||
The token expires after 60 seconds and is single-use.
|
||||
Rate limited to 10 tokens per minute per user.
|
||||
"""
|
||||
token = secrets.token_urlsafe(32)
|
||||
try:
|
||||
await store_ws_token(token, str(user.id))
|
||||
except WsTokenRateLimitExceeded:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.RATE_LIMITED,
|
||||
"Too many token requests. Please wait before requesting another.",
|
||||
)
|
||||
return WSTokenResponse(token=token)
|
||||
@@ -1,860 +0,0 @@
|
||||
"""WebSocket API for streaming speech-to-text and text-to-speech."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import WebSocket
|
||||
from fastapi import WebSocketDisconnect
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user_from_websocket
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
from onyx.db.models import User
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/voice")
|
||||
|
||||
|
||||
# Transcribe every ~0.5 seconds of audio (webm/opus is ~2-4KB/s, so ~1-2KB per 0.5s)
|
||||
MIN_CHUNK_BYTES = 1500
|
||||
VOICE_DISABLE_STREAMING_FALLBACK = (
|
||||
os.environ.get("VOICE_DISABLE_STREAMING_FALLBACK", "").lower() == "true"
|
||||
)
|
||||
|
||||
# WebSocket size limits to prevent memory exhaustion attacks
|
||||
WS_MAX_MESSAGE_SIZE = 64 * 1024 # 64KB per message (OWASP recommendation)
|
||||
WS_MAX_TOTAL_BYTES = 25 * 1024 * 1024 # 25MB total per connection (matches REST API)
|
||||
WS_MAX_TEXT_MESSAGE_SIZE = 16 * 1024 # 16KB for text/JSON messages
|
||||
WS_MAX_TTS_TEXT_LENGTH = 4096 # Max text length per synthesize call (matches REST API)
|
||||
|
||||
|
||||
class ChunkedTranscriber:
|
||||
"""Fallback transcriber for providers without streaming support."""
|
||||
|
||||
def __init__(self, provider: Any, audio_format: str = "webm"):
|
||||
self.provider = provider
|
||||
self.audio_format = audio_format
|
||||
self.chunk_buffer = io.BytesIO()
|
||||
self.full_audio = io.BytesIO()
|
||||
self.chunk_bytes = 0
|
||||
self.transcripts: list[str] = []
|
||||
|
||||
async def add_chunk(self, chunk: bytes) -> str | None:
|
||||
"""Add audio chunk. Returns transcript if enough audio accumulated."""
|
||||
self.chunk_buffer.write(chunk)
|
||||
self.full_audio.write(chunk)
|
||||
self.chunk_bytes += len(chunk)
|
||||
|
||||
if self.chunk_bytes >= MIN_CHUNK_BYTES:
|
||||
return await self._transcribe_chunk()
|
||||
return None
|
||||
|
||||
async def _transcribe_chunk(self) -> str | None:
|
||||
"""Transcribe current chunk and append to running transcript."""
|
||||
audio_data = self.chunk_buffer.getvalue()
|
||||
if not audio_data:
|
||||
return None
|
||||
|
||||
try:
|
||||
transcript = await self.provider.transcribe(audio_data, self.audio_format)
|
||||
self.chunk_buffer = io.BytesIO()
|
||||
self.chunk_bytes = 0
|
||||
|
||||
if transcript and transcript.strip():
|
||||
self.transcripts.append(transcript.strip())
|
||||
return " ".join(self.transcripts)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription error: {e}")
|
||||
self.chunk_buffer = io.BytesIO()
|
||||
self.chunk_bytes = 0
|
||||
return None
|
||||
|
||||
async def flush(self) -> str:
|
||||
"""Get final transcript from full audio for best accuracy."""
|
||||
full_audio_data = self.full_audio.getvalue()
|
||||
if full_audio_data:
|
||||
try:
|
||||
transcript = await self.provider.transcribe(
|
||||
full_audio_data, self.audio_format
|
||||
)
|
||||
if transcript and transcript.strip():
|
||||
return transcript.strip()
|
||||
except Exception as e:
|
||||
logger.error(f"Final transcription error: {e}")
|
||||
return " ".join(self.transcripts)
|
||||
|
||||
|
||||
async def handle_streaming_transcription(
    websocket: WebSocket,
    transcriber: StreamingTranscriberProtocol,
) -> None:
    """Handle transcription using native streaming API.

    Runs two concurrent flows over one WebSocket:
      * a background task (``receive_transcripts``) that relays transcripts
        from the provider to the client as ``{"type": "transcript", ...}``
        JSON messages, and
      * the main loop, which forwards the client's binary audio frames to the
        provider and reacts to JSON control messages ``{"type": "end"}`` and
        ``{"type": "reset"}``.

    Size limits (``WS_MAX_MESSAGE_SIZE`` / ``WS_MAX_TOTAL_BYTES``) are
    enforced per frame and per connection; a violation sends an error to the
    client and ends the handler.

    Args:
        websocket: Accepted client connection carrying audio + control frames.
        transcriber: Provider-side streaming session to pump audio into.
    """
    logger.info("Streaming transcription: starting handler")
    last_transcript = ""  # last text relayed to the client; used to de-duplicate
    chunk_count = 0
    total_bytes = 0

    async def receive_transcripts() -> None:
        """Background task to receive and send transcripts."""
        nonlocal last_transcript
        logger.info("Streaming transcription: starting transcript receiver")
        while True:
            result: TranscriptResult | None = await transcriber.receive_transcript()
            if result is None:  # End of stream
                logger.info("Streaming transcription: transcript stream ended")
                break
            # Send if text changed OR if VAD detected end of speech (for auto-send trigger)
            if result.text and (result.text != last_transcript or result.is_vad_end):
                last_transcript = result.text
                logger.debug(
                    f"Streaming transcription: got transcript: {result.text[:50]}... "
                    f"(is_vad_end={result.is_vad_end})"
                )
                await websocket.send_json(
                    {
                        "type": "transcript",
                        "text": result.text,
                        "is_final": result.is_vad_end,
                    }
                )

    # Start receiving transcripts in background
    receive_task = asyncio.create_task(receive_transcripts())

    try:
        while True:
            message = await websocket.receive()
            msg_type = message.get("type", "unknown")

            if msg_type == "websocket.disconnect":
                logger.info(
                    f"Streaming transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
                )
                break

            if "bytes" in message:
                chunk_size = len(message["bytes"])

                # Enforce per-message size limit
                if chunk_size > WS_MAX_MESSAGE_SIZE:
                    logger.warning(
                        f"Streaming transcription: message too large ({chunk_size} bytes)"
                    )
                    await websocket.send_json(
                        {"type": "error", "message": "Message too large"}
                    )
                    break

                # Enforce total connection size limit
                if total_bytes + chunk_size > WS_MAX_TOTAL_BYTES:
                    logger.warning(
                        f"Streaming transcription: total size limit exceeded ({total_bytes + chunk_size} bytes)"
                    )
                    await websocket.send_json(
                        {"type": "error", "message": "Total size limit exceeded"}
                    )
                    break

                chunk_count += 1
                total_bytes += chunk_size
                logger.debug(
                    f"Streaming transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
                )
                await transcriber.send_audio(message["bytes"])

            elif "text" in message:
                try:
                    data = json.loads(message["text"])
                    logger.debug(
                        f"Streaming transcription: received text message: {data}"
                    )
                    if data.get("type") == "end":
                        logger.info(
                            "Streaming transcription: end signal received, closing transcriber"
                        )
                        # Close the provider session first so it can produce
                        # its final transcript; the receiver task is no longer
                        # needed once the session is closed.
                        final_transcript = await transcriber.close()
                        receive_task.cancel()
                        logger.info(
                            "Streaming transcription: final transcript: "
                            f"{final_transcript[:100] if final_transcript else '(empty)'}..."
                        )
                        await websocket.send_json(
                            {
                                "type": "transcript",
                                "text": final_transcript,
                                "is_final": True,
                            }
                        )
                        break
                    elif data.get("type") == "reset":
                        # Reset accumulated transcript after auto-send
                        logger.info(
                            "Streaming transcription: reset signal received, clearing transcript"
                        )
                        transcriber.reset_transcript()
                except json.JSONDecodeError:
                    # Malformed control message: log and keep the stream alive.
                    logger.warning(
                        f"Streaming transcription: failed to parse JSON: {message.get('text', '')[:100]}"
                    )
    except Exception as e:
        logger.error(f"Streaming transcription: error: {e}", exc_info=True)
        raise
    finally:
        # Cancelling twice (after "end") is harmless; awaiting the task here
        # guarantees it has fully unwound before the handler returns.
        receive_task.cancel()
        try:
            await receive_task
        except asyncio.CancelledError:
            pass
        logger.info(
            f"Streaming transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
        )
|
||||
|
||||
|
||||
async def handle_chunked_transcription(
    websocket: WebSocket,
    transcriber: ChunkedTranscriber,
) -> None:
    """Handle transcription using chunked batch API.

    Reads binary audio frames off the socket, feeds them into the
    ``ChunkedTranscriber``, and relays interim transcripts back to the
    client. A JSON ``{"type": "end"}`` control message triggers a final
    flush and ends the handler. Per-frame and per-connection size limits
    are enforced before any audio is buffered.
    """
    logger.info("Chunked transcription: starting handler")
    chunk_count = 0
    total_bytes = 0

    async def _consume_audio(payload: bytes) -> bool:
        """Process one binary frame; return True when the loop should stop."""
        nonlocal chunk_count, total_bytes
        chunk_size = len(payload)

        # Enforce per-message size limit
        if chunk_size > WS_MAX_MESSAGE_SIZE:
            logger.warning(
                f"Chunked transcription: message too large ({chunk_size} bytes)"
            )
            await websocket.send_json(
                {"type": "error", "message": "Message too large"}
            )
            return True

        # Enforce total connection size limit
        if total_bytes + chunk_size > WS_MAX_TOTAL_BYTES:
            logger.warning(
                f"Chunked transcription: total size limit exceeded ({total_bytes + chunk_size} bytes)"
            )
            await websocket.send_json(
                {"type": "error", "message": "Total size limit exceeded"}
            )
            return True

        chunk_count += 1
        total_bytes += chunk_size
        logger.debug(
            f"Chunked transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
        )

        # The transcriber only returns text once enough audio accumulated.
        transcript = await transcriber.add_chunk(payload)
        if transcript:
            logger.debug(
                f"Chunked transcription: got transcript: {transcript[:50]}..."
            )
            await websocket.send_json(
                {
                    "type": "transcript",
                    "text": transcript,
                    "is_final": False,
                }
            )
        return False

    async def _consume_text(raw: str) -> bool:
        """Process one JSON control frame; return True when the loop should stop."""
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            # Malformed control message: log and keep listening.
            logger.warning(
                f"Chunked transcription: failed to parse JSON: {raw[:100]}"
            )
            return False

        logger.debug(f"Chunked transcription: received text message: {data}")
        if data.get("type") != "end":
            return False

        logger.info("Chunked transcription: end signal received, flushing")
        # Final pass over the full audio for best accuracy.
        final_transcript = await transcriber.flush()
        logger.info(
            f"Chunked transcription: final transcript: {final_transcript[:100] if final_transcript else '(empty)'}..."
        )
        await websocket.send_json(
            {
                "type": "transcript",
                "text": final_transcript,
                "is_final": True,
            }
        )
        return True

    while True:
        message = await websocket.receive()
        if message.get("type", "unknown") == "websocket.disconnect":
            logger.info(
                f"Chunked transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
            )
            break

        if "bytes" in message:
            if await _consume_audio(message["bytes"]):
                break
        elif "text" in message:
            if await _consume_text(message["text"]):
                break

    logger.info(
        f"Chunked transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
    )
|
||||
|
||||
|
||||
@router.websocket("/transcribe/stream")
async def websocket_transcribe(
    websocket: WebSocket,
    _user: User = Depends(current_user_from_websocket),
) -> None:
    """
    WebSocket endpoint for streaming speech-to-text.

    Protocol:
    - Client sends binary audio chunks
    - Server sends JSON: {"type": "transcript", "text": "...", "is_final": false}
    - Client sends JSON {"type": "end"} to signal end
    - Server responds with final transcript and closes

    Dispatch:
        Uses the provider's native streaming STT when supported, otherwise
        (or on streaming failure, unless VOICE_DISABLE_STREAMING_FALLBACK)
        falls back to chunked batch transcription.

    Authentication:
        Requires `token` query parameter (e.g., /voice/transcribe/stream?token=xxx).
        Applies same auth checks as HTTP endpoints (verification, role checks).
    """
    logger.info("WebSocket transcribe: connection request received (authenticated)")

    try:
        await websocket.accept()
        logger.info("WebSocket transcribe: connection accepted")
    except Exception as e:
        logger.error(f"WebSocket transcribe: failed to accept connection: {e}")
        return

    streaming_transcriber = None
    provider = None

    try:
        # Get STT provider
        logger.info("WebSocket transcribe: fetching STT provider from database")
        engine = get_sqlalchemy_engine()
        with Session(engine) as db_session:
            provider_db = fetch_default_stt_provider(db_session)
            if provider_db is None:
                logger.warning(
                    "WebSocket transcribe: no default STT provider configured"
                )
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "No speech-to-text provider configured",
                    }
                )
                return

            if not provider_db.api_key:
                logger.warning("WebSocket transcribe: STT provider has no API key")
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "Speech-to-text provider has no API key configured",
                    }
                )
                return

            logger.info(
                f"WebSocket transcribe: creating voice provider: {provider_db.provider_type}"
            )
            try:
                provider = get_voice_provider(provider_db)
                logger.info(
                    f"WebSocket transcribe: voice provider created, streaming supported: {provider.supports_streaming_stt()}"
                )
            except ValueError as e:
                logger.error(
                    f"WebSocket transcribe: failed to create voice provider: {e}"
                )
                await websocket.send_json({"type": "error", "message": str(e)})
                return

        # Use native streaming if provider supports it
        if provider.supports_streaming_stt():
            logger.info("WebSocket transcribe: using native streaming STT")
            try:
                streaming_transcriber = await provider.create_streaming_transcriber()
                logger.info(
                    "WebSocket transcribe: streaming transcriber created successfully"
                )
                await handle_streaming_transcription(websocket, streaming_transcriber)
            except Exception as e:
                # NOTE: this try wraps both creation AND the streaming handler,
                # so a mid-stream failure also lands here and triggers fallback.
                # NOTE(review): audio already sent to the failed streaming
                # session is not replayed into the chunked fallback — confirm
                # whether the client restarts its stream in this case.
                logger.error(
                    f"WebSocket transcribe: failed to create streaming transcriber: {e}"
                )
                if VOICE_DISABLE_STREAMING_FALLBACK:
                    await websocket.send_json(
                        {"type": "error", "message": f"Streaming STT failed: {e}"}
                    )
                    return
                logger.info("WebSocket transcribe: falling back to chunked STT")
                # Browser stream provides raw PCM16 chunks over WebSocket.
                chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
                await handle_chunked_transcription(websocket, chunked_transcriber)
        else:
            # Fall back to chunked transcription
            if VOICE_DISABLE_STREAMING_FALLBACK:
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "Provider doesn't support streaming STT",
                    }
                )
                return
            logger.info(
                "WebSocket transcribe: using chunked STT (provider doesn't support streaming)"
            )
            chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
            await handle_chunked_transcription(websocket, chunked_transcriber)

    except WebSocketDisconnect:
        logger.debug("WebSocket transcribe: client disconnected")
    except Exception as e:
        logger.error(f"WebSocket transcribe: unhandled error: {e}", exc_info=True)
        try:
            # Send generic error to avoid leaking sensitive details
            await websocket.send_json(
                {"type": "error", "message": "An unexpected error occurred"}
            )
        except Exception:
            pass
    finally:
        # Best-effort cleanup: the socket/provider session may already be closed.
        if streaming_transcriber:
            try:
                await streaming_transcriber.close()
            except Exception:
                pass
        try:
            await websocket.close()
        except Exception:
            pass
        logger.info("WebSocket transcribe: connection closed")
|
||||
|
||||
|
||||
async def handle_streaming_synthesis(
    websocket: WebSocket,
    synthesizer: StreamingSynthesizerProtocol,
) -> None:
    """Handle TTS using native streaming API.

    The main loop forwards the client's ``{"type": "synthesize", "text": ...}``
    chunks to the provider; a background task (``send_audio``) pumps the
    provider's audio bytes back to the client, finishing with a JSON
    ``{"type": "audio_done"}``. A ``{"type": "end"}`` message flushes the
    synthesizer and waits for the audio stream to drain.

    Args:
        websocket: Accepted client connection.
        synthesizer: Provider-side streaming TTS session.
    """
    logger.info("Streaming synthesis: starting handler")

    async def send_audio() -> None:
        """Background task to send audio chunks to client."""
        chunk_count = 0
        total_bytes = 0
        try:
            while True:
                audio_chunk = await synthesizer.receive_audio()
                if audio_chunk is None:
                    # None marks end of the provider's audio stream.
                    logger.info(
                        f"Streaming synthesis: audio stream ended, sent {chunk_count} chunks, {total_bytes} bytes"
                    )
                    try:
                        await websocket.send_json({"type": "audio_done"})
                        logger.info("Streaming synthesis: sent audio_done to client")
                    except Exception as e:
                        # Client may already be gone; don't crash the task.
                        logger.warning(
                            f"Streaming synthesis: failed to send audio_done: {e}"
                        )
                    break
                if audio_chunk:  # Skip empty chunks
                    chunk_count += 1
                    total_bytes += len(audio_chunk)
                    try:
                        await websocket.send_bytes(audio_chunk)
                    except Exception as e:
                        logger.warning(
                            f"Streaming synthesis: failed to send chunk: {e}"
                        )
                        break
        except asyncio.CancelledError:
            logger.info(
                f"Streaming synthesis: send_audio cancelled after {chunk_count} chunks"
            )
        except Exception as e:
            logger.error(f"Streaming synthesis: send_audio error: {e}")

    # Started lazily on the first text chunk so playback can begin early.
    send_task: asyncio.Task | None = None
    disconnected = False

    try:
        while not disconnected:
            try:
                message = await websocket.receive()
            except WebSocketDisconnect:
                logger.info("Streaming synthesis: client disconnected")
                break

            msg_type = message.get("type", "unknown")  # type: ignore[possibly-undefined]

            if msg_type == "websocket.disconnect":
                logger.info("Streaming synthesis: client disconnected")
                disconnected = True
                break

            if "text" in message:
                # Enforce text message size limit
                msg_size = len(message["text"])
                if msg_size > WS_MAX_TEXT_MESSAGE_SIZE:
                    logger.warning(
                        f"Streaming synthesis: text message too large ({msg_size} bytes)"
                    )
                    await websocket.send_json(
                        {"type": "error", "message": "Message too large"}
                    )
                    break

                try:
                    data = json.loads(message["text"])

                    if data.get("type") == "synthesize":
                        text = data.get("text", "")
                        # Enforce per-text size limit
                        if len(text) > WS_MAX_TTS_TEXT_LENGTH:
                            logger.warning(
                                f"Streaming synthesis: text too long ({len(text)} chars)"
                            )
                            await websocket.send_json(
                                {"type": "error", "message": "Text too long"}
                            )
                            continue
                        if text:
                            # Start audio receiver on first text chunk so playback
                            # can begin before the full assistant response completes.
                            if send_task is None:
                                send_task = asyncio.create_task(send_audio())
                            logger.debug(
                                f"Streaming synthesis: forwarding text chunk ({len(text)} chars)"
                            )
                            await synthesizer.send_text(text)

                    elif data.get("type") == "end":
                        logger.info("Streaming synthesis: end signal received")

                        # Ensure receiver is active even if no prior text chunks arrived.
                        if send_task is None:
                            send_task = asyncio.create_task(send_audio())

                        # Signal end of input
                        # (flush is optional on the synthesizer protocol).
                        if hasattr(synthesizer, "flush"):
                            await synthesizer.flush()

                        # Wait for all audio to be sent
                        logger.info(
                            "Streaming synthesis: waiting for audio stream to complete"
                        )
                        try:
                            await asyncio.wait_for(send_task, timeout=60.0)
                        except asyncio.TimeoutError:
                            logger.warning(
                                "Streaming synthesis: timeout waiting for audio"
                            )
                        break

                except json.JSONDecodeError:
                    # Malformed control message: log and keep listening.
                    logger.warning(
                        f"Streaming synthesis: failed to parse JSON: {message.get('text', '')[:100]}"
                    )

    except WebSocketDisconnect:
        logger.debug("Streaming synthesis: client disconnected during synthesis")
    except Exception as e:
        logger.error(f"Streaming synthesis: error: {e}", exc_info=True)
    finally:
        # Give the audio sender a bounded grace period, then force-cancel so
        # the handler can never hang on a stuck provider stream.
        if send_task and not send_task.done():
            logger.info("Streaming synthesis: waiting for send_task to finish")
            try:
                await asyncio.wait_for(send_task, timeout=30.0)
            except asyncio.TimeoutError:
                logger.warning("Streaming synthesis: timeout waiting for send_task")
                send_task.cancel()
                try:
                    await send_task
                except asyncio.CancelledError:
                    pass
            except asyncio.CancelledError:
                pass
        logger.info("Streaming synthesis: handler finished")
|
||||
|
||||
|
||||
async def handle_chunked_synthesis(
    websocket: WebSocket,
    provider: Any,
    first_message: MutableMapping[str, Any] | None = None,
) -> None:
    """Fallback TTS handler using provider.synthesize_stream.

    Unlike the streaming handler, text chunks are buffered until the client
    sends ``{"type": "end"}``; the full text is then synthesized in one call
    and streamed back, followed by ``{"type": "audio_done"}``.

    Args:
        websocket: The WebSocket connection
        provider: Voice provider instance
        first_message: Optional first message already received (used when falling
            back from streaming mode, where the first message was already consumed)
    """
    logger.info("Chunked synthesis: starting handler")
    text_buffer: list[str] = []  # accumulated "synthesize" text chunks
    voice: str | None = None  # last voice requested by the client, if any
    speed = 1.0

    # Process pre-received message if provided
    pending_message = first_message

    try:
        while True:
            # Replay the pre-consumed message exactly once, then read fresh.
            if pending_message is not None:
                message = pending_message
                pending_message = None
            else:
                message = await websocket.receive()
            msg_type = message.get("type", "unknown")

            if msg_type == "websocket.disconnect":
                logger.info("Chunked synthesis: client disconnected")
                break

            # This handler only consumes JSON control frames; ignore binary.
            if "text" not in message:
                continue

            # Enforce text message size limit
            msg_size = len(message["text"])
            if msg_size > WS_MAX_TEXT_MESSAGE_SIZE:
                logger.warning(
                    f"Chunked synthesis: text message too large ({msg_size} bytes)"
                )
                await websocket.send_json(
                    {"type": "error", "message": "Message too large"}
                )
                break

            try:
                data = json.loads(message["text"])
            except json.JSONDecodeError:
                logger.warning(
                    "Chunked synthesis: failed to parse JSON: "
                    f"{message.get('text', '')[:100]}"
                )
                continue

            msg_data_type = data.get("type")  # type: ignore[possibly-undefined]
            if msg_data_type == "synthesize":
                text = data.get("text", "")
                # Enforce per-text size limit
                if len(text) > WS_MAX_TTS_TEXT_LENGTH:
                    logger.warning(
                        f"Chunked synthesis: text too long ({len(text)} chars)"
                    )
                    await websocket.send_json(
                        {"type": "error", "message": "Text too long"}
                    )
                    continue
                if text:
                    text_buffer.append(text)
                    logger.debug(
                        f"Chunked synthesis: buffered text ({len(text)} chars), "
                        f"total buffered: {len(text_buffer)} chunks"
                    )
                # voice/speed may arrive on any synthesize message; last wins.
                if isinstance(data.get("voice"), str) and data["voice"]:
                    voice = data["voice"]
                if isinstance(data.get("speed"), (int, float)):
                    speed = float(data["speed"])
            elif msg_data_type == "end":
                logger.info("Chunked synthesis: end signal received")
                full_text = " ".join(text_buffer).strip()
                if not full_text:
                    # Nothing to synthesize — still complete the protocol.
                    await websocket.send_json({"type": "audio_done"})
                    logger.info("Chunked synthesis: no text, sent audio_done")
                    break

                chunk_count = 0
                total_bytes = 0
                logger.info(
                    f"Chunked synthesis: sending full text ({len(full_text)} chars)"
                )
                async for audio_chunk in provider.synthesize_stream(
                    full_text, voice=voice, speed=speed
                ):
                    if not audio_chunk:
                        continue
                    chunk_count += 1
                    total_bytes += len(audio_chunk)
                    await websocket.send_bytes(audio_chunk)
                await websocket.send_json({"type": "audio_done"})
                logger.info(
                    f"Chunked synthesis: sent audio_done after {chunk_count} chunks, {total_bytes} bytes"
                )
                break
    except WebSocketDisconnect:
        logger.debug("Chunked synthesis: client disconnected")
    except Exception as e:
        logger.error(f"Chunked synthesis: error: {e}", exc_info=True)
        raise
    finally:
        logger.info("Chunked synthesis: handler finished")
|
||||
|
||||
|
||||
@router.websocket("/synthesize/stream")
async def websocket_synthesize(
    websocket: WebSocket,
    _user: User = Depends(current_user_from_websocket),
) -> None:
    """
    WebSocket endpoint for streaming text-to-speech.

    Protocol:
    - Client sends JSON: {"type": "synthesize", "text": "...", "voice": "...", "speed": 1.0}
    - Server sends binary audio chunks
    - Server sends JSON: {"type": "audio_done"} when synthesis completes
    - Client sends JSON {"type": "end"} to close connection

    Dispatch:
        Uses the provider's native streaming TTS when supported, otherwise
        (or on streaming failure, unless VOICE_DISABLE_STREAMING_FALLBACK)
        falls back to chunked synthesis of the buffered text.

    Authentication:
        Requires `token` query parameter (e.g., /voice/synthesize/stream?token=xxx).
        Applies same auth checks as HTTP endpoints (verification, role checks).
    """
    logger.info("WebSocket synthesize: connection request received (authenticated)")

    try:
        await websocket.accept()
        logger.info("WebSocket synthesize: connection accepted")
    except Exception as e:
        logger.error(f"WebSocket synthesize: failed to accept connection: {e}")
        return

    streaming_synthesizer: StreamingSynthesizerProtocol | None = None
    provider = None

    try:
        # Get TTS provider
        logger.info("WebSocket synthesize: fetching TTS provider from database")
        engine = get_sqlalchemy_engine()
        with Session(engine) as db_session:
            provider_db = fetch_default_tts_provider(db_session)
            if provider_db is None:
                logger.warning(
                    "WebSocket synthesize: no default TTS provider configured"
                )
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "No text-to-speech provider configured",
                    }
                )
                return

            if not provider_db.api_key:
                logger.warning("WebSocket synthesize: TTS provider has no API key")
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "Text-to-speech provider has no API key configured",
                    }
                )
                return

            logger.info(
                f"WebSocket synthesize: creating voice provider: {provider_db.provider_type}"
            )
            try:
                provider = get_voice_provider(provider_db)
                logger.info(
                    f"WebSocket synthesize: voice provider created, streaming TTS supported: {provider.supports_streaming_tts()}"
                )
            except ValueError as e:
                logger.error(
                    f"WebSocket synthesize: failed to create voice provider: {e}"
                )
                await websocket.send_json({"type": "error", "message": str(e)})
                return

        # Use native streaming if provider supports it
        if provider.supports_streaming_tts():
            logger.info("WebSocket synthesize: using native streaming TTS")
            message = None  # Initialize to avoid UnboundLocalError in except block
            try:
                # Wait for initial config message with voice/speed
                message = await websocket.receive()
                voice = None
                speed = 1.0
                if "text" in message:
                    try:
                        data = json.loads(message["text"])
                        voice = data.get("voice")
                        speed = data.get("speed", 1.0)
                    except json.JSONDecodeError:
                        # Not a config message; defaults are used and the
                        # message is still replayed on fallback below.
                        pass

                streaming_synthesizer = await provider.create_streaming_synthesizer(
                    voice=voice, speed=speed
                )
                logger.info(
                    "WebSocket synthesize: streaming synthesizer created successfully"
                )
                await handle_streaming_synthesis(websocket, streaming_synthesizer)
            except Exception as e:
                # NOTE: this try wraps both creation AND the streaming handler,
                # so a mid-stream failure also lands here and triggers fallback.
                logger.error(
                    f"WebSocket synthesize: failed to create streaming synthesizer: {e}"
                )
                if VOICE_DISABLE_STREAMING_FALLBACK:
                    await websocket.send_json(
                        {"type": "error", "message": f"Streaming TTS failed: {e}"}
                    )
                    return
                logger.info(
                    "WebSocket synthesize: falling back to chunked TTS synthesis"
                )
                # Pass the first message so it's not lost in the fallback
                await handle_chunked_synthesis(
                    websocket, provider, first_message=message
                )
        else:
            if VOICE_DISABLE_STREAMING_FALLBACK:
                await websocket.send_json(
                    {
                        "type": "error",
                        "message": "Provider doesn't support streaming TTS",
                    }
                )
                return
            logger.info(
                "WebSocket synthesize: using chunked TTS (provider doesn't support streaming)"
            )
            await handle_chunked_synthesis(websocket, provider)

    except WebSocketDisconnect:
        logger.debug("WebSocket synthesize: client disconnected")
    except Exception as e:
        logger.error(f"WebSocket synthesize: unhandled error: {e}", exc_info=True)
        try:
            # Send generic error to avoid leaking sensitive details
            await websocket.send_json(
                {"type": "error", "message": "An unexpected error occurred"}
            )
        except Exception:
            pass
    finally:
        # Best-effort cleanup: both the provider session and the socket may
        # already be closed by the time we get here.
        if streaming_synthesizer:
            try:
                await streaming_synthesizer.close()
            except Exception:
                pass
        try:
            await websocket.close()
        except Exception:
            pass
        logger.info("WebSocket synthesize: connection closed")
|
||||
@@ -41,7 +41,6 @@ class StreamingType(Enum):
|
||||
REASONING_DONE = "reasoning_done"
|
||||
CITATION_INFO = "citation_info"
|
||||
TOOL_CALL_DEBUG = "tool_call_debug"
|
||||
TOOL_CALL_ARGUMENT_DELTA = "tool_call_argument_delta"
|
||||
|
||||
MEMORY_TOOL_START = "memory_tool_start"
|
||||
MEMORY_TOOL_DELTA = "memory_tool_delta"
|
||||
@@ -260,15 +259,6 @@ class CustomToolDelta(BaseObj):
|
||||
file_ids: list[str] | None = None
|
||||
|
||||
|
||||
class ToolCallArgumentDelta(BaseObj):
|
||||
type: Literal["tool_call_argument_delta"] = (
|
||||
StreamingType.TOOL_CALL_ARGUMENT_DELTA.value
|
||||
)
|
||||
|
||||
tool_type: str
|
||||
argument_deltas: dict[str, Any]
|
||||
|
||||
|
||||
################################################
|
||||
# File Reader Packets
|
||||
################################################
|
||||
@@ -389,7 +379,6 @@ PacketObj = Union[
|
||||
# Citation Packets
|
||||
CitationInfo,
|
||||
ToolCallDebug,
|
||||
ToolCallArgumentDelta,
|
||||
# Deep Research Packets
|
||||
DeepResearchPlanStart,
|
||||
DeepResearchPlanDelta,
|
||||
|
||||
@@ -78,7 +78,6 @@ class Settings(BaseModel):
|
||||
|
||||
# User Knowledge settings
|
||||
user_knowledge_enabled: bool | None = True
|
||||
user_file_max_upload_size_mb: int | None = None
|
||||
|
||||
# Connector settings
|
||||
show_extra_connectors: bool | None = True
|
||||
|
||||
@@ -3,7 +3,6 @@ from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -51,7 +50,6 @@ def load_settings() -> Settings:
|
||||
if DISABLE_USER_KNOWLEDGE:
|
||||
settings.user_knowledge_enabled = False
|
||||
|
||||
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
|
||||
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
return settings
|
||||
|
||||
@@ -56,23 +56,3 @@ def get_built_in_tool_ids() -> list[str]:
|
||||
|
||||
def get_built_in_tool_by_id(in_code_tool_id: str) -> Type[BUILT_IN_TOOL_TYPES]:
|
||||
return BUILT_IN_TOOL_MAP[in_code_tool_id]
|
||||
|
||||
|
||||
def _build_tool_name_to_class() -> dict[str, Type[BUILT_IN_TOOL_TYPES]]:
|
||||
"""Build a mapping from LLM-facing tool name to tool class."""
|
||||
result: dict[str, Type[BUILT_IN_TOOL_TYPES]] = {}
|
||||
for cls in BUILT_IN_TOOL_MAP.values():
|
||||
name_attr = cls.__dict__.get("name")
|
||||
if isinstance(name_attr, property) and name_attr.fget is not None:
|
||||
tool_name = name_attr.fget(cls)
|
||||
elif isinstance(name_attr, str):
|
||||
tool_name = name_attr
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Built-in tool {cls.__name__} must define a valid LLM-facing tool name"
|
||||
)
|
||||
result[tool_name] = cls
|
||||
return result
|
||||
|
||||
|
||||
TOOL_NAME_TO_CLASS: dict[str, Type[BUILT_IN_TOOL_TYPES]] = _build_tool_name_to_class()
|
||||
|
||||
@@ -92,7 +92,3 @@ class Tool(abc.ABC, Generic[TOverride]):
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def should_emit_argument_deltas(cls) -> bool:
|
||||
return False
|
||||
|
||||
@@ -376,8 +376,3 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
rich_response=None,
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def should_emit_argument_deltas(cls) -> bool:
|
||||
return True
|
||||
|
||||
@@ -11,20 +11,16 @@ logger = setup_logger()
|
||||
|
||||
|
||||
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
|
||||
def _encrypt_string(input_str: str, key: str | None = None) -> bytes: # noqa: ARG001
|
||||
def _encrypt_string(input_str: str) -> bytes:
|
||||
if ENCRYPTION_KEY_SECRET:
|
||||
logger.warning("MIT version of Onyx does not support encryption of secrets.")
|
||||
elif key is not None:
|
||||
logger.debug("MIT encrypt called with explicit key — key ignored.")
|
||||
return input_str.encode()
|
||||
|
||||
|
||||
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
|
||||
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str: # noqa: ARG001
|
||||
if ENCRYPTION_KEY_SECRET:
|
||||
logger.warning("MIT version of Onyx does not support decryption of secrets.")
|
||||
elif key is not None:
|
||||
logger.debug("MIT decrypt called with explicit key — key ignored.")
|
||||
def _decrypt_bytes(input_bytes: bytes) -> str:
|
||||
# No need to double warn. If you wish to learn more about encryption features
|
||||
# refer to the Onyx EE code
|
||||
return input_bytes.decode()
|
||||
|
||||
|
||||
@@ -90,15 +86,15 @@ def _mask_list(items: list[Any]) -> list[Any]:
|
||||
return masked
|
||||
|
||||
|
||||
def encrypt_string_to_bytes(intput_str: str, key: str | None = None) -> bytes:
|
||||
def encrypt_string_to_bytes(intput_str: str) -> bytes:
|
||||
versioned_encryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_encrypt_string"
|
||||
)
|
||||
return versioned_encryption_fn(intput_str, key=key)
|
||||
return versioned_encryption_fn(intput_str)
|
||||
|
||||
|
||||
def decrypt_bytes_to_string(intput_bytes: bytes, key: str | None = None) -> str:
|
||||
def decrypt_bytes_to_string(intput_bytes: bytes) -> str:
|
||||
versioned_decryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_decrypt_bytes"
|
||||
)
|
||||
return versioned_decryption_fn(intput_bytes, key=key)
|
||||
return versioned_decryption_fn(intput_bytes)
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
"""
|
||||
jsonriver - A streaming JSON parser for Python
|
||||
|
||||
Parse JSON incrementally as it streams in, e.g. from a network request or a language model.
|
||||
Gives you a sequence of increasingly complete values.
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from .parse import _Parser as Parser
|
||||
from .parse import JsonObject
|
||||
from .parse import JsonValue
|
||||
|
||||
__all__ = ["Parser", "JsonValue", "JsonObject"]
|
||||
__version__ = "0.0.1"
|
||||
@@ -1,427 +0,0 @@
|
||||
"""
|
||||
JSON parser for streaming incremental parsing
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from enum import IntEnum
|
||||
from typing import cast
|
||||
from typing import Union
|
||||
|
||||
from .tokenize import _Input
|
||||
from .tokenize import json_token_type_to_string
|
||||
from .tokenize import JsonTokenType
|
||||
from .tokenize import Tokenizer
|
||||
|
||||
|
||||
# Type definitions for JSON values
|
||||
JsonValue = Union[None, bool, float, str, list["JsonValue"], dict[str, "JsonValue"]]
|
||||
JsonObject = dict[str, JsonValue]
|
||||
|
||||
|
||||
class _StateEnum(IntEnum):
|
||||
"""Parser state machine states"""
|
||||
|
||||
Initial = 0
|
||||
InString = 1
|
||||
InArray = 2
|
||||
InObjectExpectingKey = 3
|
||||
InObjectExpectingValue = 4
|
||||
|
||||
|
||||
class _State:
|
||||
"""Base class for parser states"""
|
||||
|
||||
type: _StateEnum
|
||||
value: JsonValue | tuple[str, JsonObject] | None
|
||||
|
||||
|
||||
class _InitialState(_State):
|
||||
"""Initial state before any parsing"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.Initial
|
||||
self.value = None
|
||||
|
||||
|
||||
class _InStringState(_State):
|
||||
"""State while parsing a string"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InString
|
||||
self.value = ""
|
||||
|
||||
|
||||
class _InArrayState(_State):
|
||||
"""State while parsing an array"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InArray
|
||||
self.value: list[JsonValue] = []
|
||||
|
||||
|
||||
class _InObjectExpectingKeyState(_State):
|
||||
"""State while parsing an object, expecting a key"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type = _StateEnum.InObjectExpectingKey
|
||||
self.value: JsonObject = {}
|
||||
|
||||
|
||||
class _InObjectExpectingValueState(_State):
|
||||
"""State while parsing an object, expecting a value"""
|
||||
|
||||
def __init__(self, key: str, obj: JsonObject) -> None:
|
||||
self.type = _StateEnum.InObjectExpectingValue
|
||||
self.value = (key, obj)
|
||||
|
||||
|
||||
# Sentinel value to distinguish "not set" from "set to None/null"
|
||||
class _Unset:
|
||||
pass
|
||||
|
||||
|
||||
_UNSET = _Unset()
|
||||
|
||||
|
||||
class _Parser:
|
||||
"""
|
||||
Incremental JSON parser
|
||||
|
||||
Feed chunks of JSON text via feed() and get back progressively
|
||||
more complete JSON values.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._state_stack: list[_State] = [_InitialState()]
|
||||
self._toplevel_value: JsonValue | _Unset = _UNSET
|
||||
self._input = _Input()
|
||||
self.tokenizer = Tokenizer(self._input, self)
|
||||
self._finished = False
|
||||
self._progressed = False
|
||||
self._prev_snapshot: JsonValue | _Unset = _UNSET
|
||||
|
||||
def feed(self, chunk: str) -> list[JsonValue]:
|
||||
"""
|
||||
Feed a chunk of JSON text and return deltas from the previous state.
|
||||
|
||||
Each element in the returned list represents what changed since the
|
||||
last yielded value. For dicts, only changed/new keys are included,
|
||||
with string values containing only the newly appended characters.
|
||||
"""
|
||||
if self._finished:
|
||||
return []
|
||||
|
||||
self._input.feed(chunk)
|
||||
return self._collect_deltas()
|
||||
|
||||
@staticmethod
|
||||
def _compute_delta(prev: JsonValue | None, current: JsonValue) -> JsonValue | None:
|
||||
if prev is None:
|
||||
return current
|
||||
|
||||
if isinstance(current, dict) and isinstance(prev, dict):
|
||||
result: JsonObject = {}
|
||||
for key in current:
|
||||
cur_val = current[key]
|
||||
prev_val = prev.get(key)
|
||||
if key not in prev:
|
||||
result[key] = cur_val
|
||||
elif isinstance(cur_val, str) and isinstance(prev_val, str):
|
||||
if cur_val != prev_val:
|
||||
result[key] = cur_val[len(prev_val) :]
|
||||
elif isinstance(cur_val, list) and isinstance(prev_val, list):
|
||||
if cur_val != prev_val:
|
||||
new_items = cur_val[len(prev_val) :]
|
||||
# check if the last existing element was updated
|
||||
if (
|
||||
prev_val
|
||||
and len(cur_val) >= len(prev_val)
|
||||
and cur_val[len(prev_val) - 1] != prev_val[-1]
|
||||
):
|
||||
result[key] = [cur_val[len(prev_val) - 1]] + new_items
|
||||
elif new_items:
|
||||
result[key] = new_items
|
||||
elif cur_val != prev_val:
|
||||
result[key] = cur_val
|
||||
return result if result else None
|
||||
|
||||
if isinstance(current, str) and isinstance(prev, str):
|
||||
delta = current[len(prev) :]
|
||||
return delta if delta else None
|
||||
|
||||
if isinstance(current, list) and isinstance(prev, list):
|
||||
if current != prev:
|
||||
new_items = current[len(prev) :]
|
||||
if (
|
||||
prev
|
||||
and len(current) >= len(prev)
|
||||
and current[len(prev) - 1] != prev[-1]
|
||||
):
|
||||
return [current[len(prev) - 1]] + new_items
|
||||
return new_items if new_items else None
|
||||
return None
|
||||
|
||||
if current != prev:
|
||||
return current
|
||||
return None
|
||||
|
||||
def finish(self) -> list[JsonValue]:
|
||||
"""Signal that no more chunks will be fed. Validates trailing content.
|
||||
|
||||
Returns any final deltas produced by flushing pending tokens (e.g.
|
||||
numbers, which have no terminator and wait for more input).
|
||||
"""
|
||||
self._input.mark_complete()
|
||||
# Pump once more so the tokenizer can emit tokens that were waiting
|
||||
# for more input (e.g. numbers need buffer_complete to finalize).
|
||||
results = self._collect_deltas()
|
||||
self._input.expect_end_of_content()
|
||||
return results
|
||||
|
||||
def _collect_deltas(self) -> list[JsonValue]:
|
||||
"""Run one pump cycle and return any deltas produced."""
|
||||
results: list[JsonValue] = []
|
||||
while True:
|
||||
self._progressed = False
|
||||
self.tokenizer.pump()
|
||||
|
||||
if self._progressed:
|
||||
if self._toplevel_value is _UNSET:
|
||||
raise RuntimeError(
|
||||
"Internal error: toplevel_value should not be unset "
|
||||
"after progressing"
|
||||
)
|
||||
current = copy.deepcopy(cast(JsonValue, self._toplevel_value))
|
||||
if isinstance(self._prev_snapshot, _Unset):
|
||||
results.append(current)
|
||||
else:
|
||||
delta = self._compute_delta(self._prev_snapshot, current)
|
||||
if delta is not None:
|
||||
results.append(delta)
|
||||
self._prev_snapshot = current
|
||||
else:
|
||||
if not self._state_stack:
|
||||
self._finished = True
|
||||
break
|
||||
return results
|
||||
|
||||
# TokenHandler protocol implementation
|
||||
|
||||
def handle_null(self) -> None:
|
||||
"""Handle null token"""
|
||||
self._handle_value_token(JsonTokenType.Null, None)
|
||||
|
||||
def handle_boolean(self, value: bool) -> None:
|
||||
"""Handle boolean token"""
|
||||
self._handle_value_token(JsonTokenType.Boolean, value)
|
||||
|
||||
def handle_number(self, value: float) -> None:
|
||||
"""Handle number token"""
|
||||
self._handle_value_token(JsonTokenType.Number, value)
|
||||
|
||||
def handle_string_start(self) -> None:
|
||||
"""Handle string start token"""
|
||||
state = self._current_state()
|
||||
if not self._progressed and state.type != _StateEnum.InObjectExpectingKey:
|
||||
self._progressed = True
|
||||
|
||||
if state.type == _StateEnum.Initial:
|
||||
self._state_stack.pop()
|
||||
self._toplevel_value = self._progress_value(JsonTokenType.StringStart, None)
|
||||
|
||||
elif state.type == _StateEnum.InArray:
|
||||
v = self._progress_value(JsonTokenType.StringStart, None)
|
||||
arr = cast(list[JsonValue], state.value)
|
||||
arr.append(v)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingKey:
|
||||
self._state_stack.append(_InStringState())
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], state.value)
|
||||
sv = self._progress_value(JsonTokenType.StringStart, None)
|
||||
obj[key] = sv
|
||||
|
||||
elif state.type == _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringStart)} "
|
||||
f"token in the middle of string"
|
||||
)
|
||||
|
||||
def handle_string_middle(self, value: str) -> None:
|
||||
"""Handle string middle token"""
|
||||
state = self._current_state()
|
||||
|
||||
if not self._progressed:
|
||||
if len(self._state_stack) >= 2:
|
||||
prev = self._state_stack[-2]
|
||||
if prev.type != _StateEnum.InObjectExpectingKey:
|
||||
self._progressed = True
|
||||
else:
|
||||
self._progressed = True
|
||||
|
||||
if state.type != _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringMiddle)} "
|
||||
f"token when not in string"
|
||||
)
|
||||
|
||||
assert isinstance(state.value, str)
|
||||
state.value += value
|
||||
|
||||
parent_state = self._state_stack[-2] if len(self._state_stack) >= 2 else None
|
||||
self._update_string_parent(state.value, parent_state)
|
||||
|
||||
def handle_string_end(self) -> None:
|
||||
"""Handle string end token"""
|
||||
state = self._current_state()
|
||||
|
||||
if state.type != _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.StringEnd)} "
|
||||
f"token when not in string"
|
||||
)
|
||||
|
||||
self._state_stack.pop()
|
||||
parent_state = self._state_stack[-1] if self._state_stack else None
|
||||
assert isinstance(state.value, str)
|
||||
self._update_string_parent(state.value, parent_state)
|
||||
|
||||
def handle_array_start(self) -> None:
|
||||
"""Handle array start token"""
|
||||
self._handle_value_token(JsonTokenType.ArrayStart, None)
|
||||
|
||||
def handle_array_end(self) -> None:
|
||||
"""Handle array end token"""
|
||||
state = self._current_state()
|
||||
if state.type != _StateEnum.InArray:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.ArrayEnd)} token"
|
||||
)
|
||||
self._state_stack.pop()
|
||||
|
||||
def handle_object_start(self) -> None:
|
||||
"""Handle object start token"""
|
||||
self._handle_value_token(JsonTokenType.ObjectStart, None)
|
||||
|
||||
def handle_object_end(self) -> None:
|
||||
"""Handle object end token"""
|
||||
state = self._current_state()
|
||||
|
||||
if state.type in (
|
||||
_StateEnum.InObjectExpectingKey,
|
||||
_StateEnum.InObjectExpectingValue,
|
||||
):
|
||||
self._state_stack.pop()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(JsonTokenType.ObjectEnd)} token"
|
||||
)
|
||||
|
||||
# Private helper methods
|
||||
|
||||
def _current_state(self) -> _State:
|
||||
"""Get current parser state"""
|
||||
if not self._state_stack:
|
||||
raise ValueError("Unexpected trailing input")
|
||||
return self._state_stack[-1]
|
||||
|
||||
def _handle_value_token(self, token_type: JsonTokenType, value: JsonValue) -> None:
|
||||
"""Handle a complete value token"""
|
||||
state = self._current_state()
|
||||
|
||||
if not self._progressed:
|
||||
self._progressed = True
|
||||
|
||||
if state.type == _StateEnum.Initial:
|
||||
self._state_stack.pop()
|
||||
self._toplevel_value = self._progress_value(token_type, value)
|
||||
|
||||
elif state.type == _StateEnum.InArray:
|
||||
v = self._progress_value(token_type, value)
|
||||
arr = cast(list[JsonValue], state.value)
|
||||
arr.append(v)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], state.value)
|
||||
if token_type != JsonTokenType.StringStart:
|
||||
self._state_stack.pop()
|
||||
new_state = _InObjectExpectingKeyState()
|
||||
new_state.value = obj
|
||||
self._state_stack.append(new_state)
|
||||
|
||||
v = self._progress_value(token_type, value)
|
||||
obj[key] = v
|
||||
|
||||
elif state.type == _StateEnum.InString:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(token_type)} "
|
||||
f"token in the middle of string"
|
||||
)
|
||||
|
||||
elif state.type == _StateEnum.InObjectExpectingKey:
|
||||
raise ValueError(
|
||||
f"Unexpected {json_token_type_to_string(token_type)} "
|
||||
f"token in the middle of object expecting key"
|
||||
)
|
||||
|
||||
def _update_string_parent(self, updated: str, parent_state: _State | None) -> None:
|
||||
"""Update parent container with updated string value"""
|
||||
if parent_state is None:
|
||||
self._toplevel_value = updated
|
||||
|
||||
elif parent_state.type == _StateEnum.InArray:
|
||||
arr = cast(list[JsonValue], parent_state.value)
|
||||
arr[-1] = updated
|
||||
|
||||
elif parent_state.type == _StateEnum.InObjectExpectingValue:
|
||||
key, obj = cast(tuple[str, JsonObject], parent_state.value)
|
||||
obj[key] = updated
|
||||
if self._state_stack and self._state_stack[-1] == parent_state:
|
||||
self._state_stack.pop()
|
||||
new_state = _InObjectExpectingKeyState()
|
||||
new_state.value = obj
|
||||
self._state_stack.append(new_state)
|
||||
|
||||
elif parent_state.type == _StateEnum.InObjectExpectingKey:
|
||||
if self._state_stack and self._state_stack[-1] == parent_state:
|
||||
self._state_stack.pop()
|
||||
obj = cast(JsonObject, parent_state.value)
|
||||
self._state_stack.append(_InObjectExpectingValueState(updated, obj))
|
||||
|
||||
def _progress_value(self, token_type: JsonTokenType, value: JsonValue) -> JsonValue:
|
||||
"""Create initial value for a token and push appropriate state"""
|
||||
if token_type == JsonTokenType.Null:
|
||||
return None
|
||||
|
||||
elif token_type == JsonTokenType.Boolean:
|
||||
return value
|
||||
|
||||
elif token_type == JsonTokenType.Number:
|
||||
return value
|
||||
|
||||
elif token_type == JsonTokenType.StringStart:
|
||||
string_state = _InStringState()
|
||||
self._state_stack.append(string_state)
|
||||
return ""
|
||||
|
||||
elif token_type == JsonTokenType.ArrayStart:
|
||||
array_state = _InArrayState()
|
||||
self._state_stack.append(array_state)
|
||||
return array_state.value
|
||||
|
||||
elif token_type == JsonTokenType.ObjectStart:
|
||||
object_state = _InObjectExpectingKeyState()
|
||||
self._state_stack.append(object_state)
|
||||
return object_state.value
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected token type: {json_token_type_to_string(token_type)}"
|
||||
)
|
||||
@@ -1,514 +0,0 @@
|
||||
"""
|
||||
JSON tokenizer for streaming incremental parsing
|
||||
|
||||
Copyright (c) 2023 Google LLC (original TypeScript implementation)
|
||||
Copyright (c) 2024 jsonriver-python contributors (Python port)
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from enum import IntEnum
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class TokenHandler(Protocol):
|
||||
"""Protocol for handling JSON tokens"""
|
||||
|
||||
def handle_null(self) -> None: ...
|
||||
def handle_boolean(self, value: bool) -> None: ...
|
||||
def handle_number(self, value: float) -> None: ...
|
||||
def handle_string_start(self) -> None: ...
|
||||
def handle_string_middle(self, value: str) -> None: ...
|
||||
def handle_string_end(self) -> None: ...
|
||||
def handle_array_start(self) -> None: ...
|
||||
def handle_array_end(self) -> None: ...
|
||||
def handle_object_start(self) -> None: ...
|
||||
def handle_object_end(self) -> None: ...
|
||||
|
||||
|
||||
class JsonTokenType(IntEnum):
|
||||
"""Types of JSON tokens"""
|
||||
|
||||
Null = 0
|
||||
Boolean = 1
|
||||
Number = 2
|
||||
StringStart = 3
|
||||
StringMiddle = 4
|
||||
StringEnd = 5
|
||||
ArrayStart = 6
|
||||
ArrayEnd = 7
|
||||
ObjectStart = 8
|
||||
ObjectEnd = 9
|
||||
|
||||
|
||||
def json_token_type_to_string(token_type: JsonTokenType) -> str:
|
||||
"""Convert token type to readable string"""
|
||||
names = {
|
||||
JsonTokenType.Null: "null",
|
||||
JsonTokenType.Boolean: "boolean",
|
||||
JsonTokenType.Number: "number",
|
||||
JsonTokenType.StringStart: "string start",
|
||||
JsonTokenType.StringMiddle: "string middle",
|
||||
JsonTokenType.StringEnd: "string end",
|
||||
JsonTokenType.ArrayStart: "array start",
|
||||
JsonTokenType.ArrayEnd: "array end",
|
||||
JsonTokenType.ObjectStart: "object start",
|
||||
JsonTokenType.ObjectEnd: "object end",
|
||||
}
|
||||
return names[token_type]
|
||||
|
||||
|
||||
class _State(IntEnum):
|
||||
"""Internal tokenizer states"""
|
||||
|
||||
ExpectingValue = 0
|
||||
InString = 1
|
||||
StartArray = 2
|
||||
AfterArrayValue = 3
|
||||
StartObject = 4
|
||||
AfterObjectKey = 5
|
||||
AfterObjectValue = 6
|
||||
BeforeObjectKey = 7
|
||||
|
||||
|
||||
# Regex for validating JSON numbers
|
||||
_JSON_NUMBER_PATTERN = re.compile(r"^-?(0|[1-9]\d*)(\.\d+)?([eE][+-]?\d+)?$")
|
||||
|
||||
|
||||
def _parse_json_number(s: str) -> float:
|
||||
"""Parse a JSON number string, validating format"""
|
||||
if not _JSON_NUMBER_PATTERN.match(s):
|
||||
raise ValueError("Invalid number")
|
||||
return float(s)
|
||||
|
||||
|
||||
class _Input:
|
||||
"""
|
||||
Input buffer for chunk-based JSON parsing
|
||||
|
||||
Manages buffering of input chunks and provides methods for
|
||||
consuming and inspecting the buffer.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._buffer = ""
|
||||
self._start_index = 0
|
||||
self.buffer_complete = False
|
||||
|
||||
def feed(self, chunk: str) -> None:
|
||||
"""Add a chunk of data to the buffer"""
|
||||
self._buffer += chunk
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Signal that no more chunks will be fed"""
|
||||
self.buffer_complete = True
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
"""Number of characters remaining in buffer"""
|
||||
return len(self._buffer) - self._start_index
|
||||
|
||||
def advance(self, length: int) -> None:
|
||||
"""Advance the start position by length characters"""
|
||||
self._start_index += length
|
||||
|
||||
def peek(self, offset: int) -> str | None:
|
||||
"""Peek at character at offset, or None if not available"""
|
||||
idx = self._start_index + offset
|
||||
if idx < len(self._buffer):
|
||||
return self._buffer[idx]
|
||||
return None
|
||||
|
||||
def peek_char_code(self, offset: int) -> int:
|
||||
"""Get character code at offset"""
|
||||
return ord(self._buffer[self._start_index + offset])
|
||||
|
||||
def slice(self, start: int, end: int) -> str:
|
||||
"""Slice buffer from start to end (relative to current position)"""
|
||||
return self._buffer[self._start_index + start : self._start_index + end]
|
||||
|
||||
def commit(self) -> None:
|
||||
"""Commit consumed content, removing it from buffer"""
|
||||
if self._start_index > 0:
|
||||
self._buffer = self._buffer[self._start_index :]
|
||||
self._start_index = 0
|
||||
|
||||
def remaining(self) -> str:
|
||||
"""Get all remaining content in buffer"""
|
||||
return self._buffer[self._start_index :]
|
||||
|
||||
def expect_end_of_content(self) -> None:
|
||||
"""Verify no non-whitespace content remains"""
|
||||
self.commit()
|
||||
self.skip_past_whitespace()
|
||||
if self.length != 0:
|
||||
raise ValueError(f"Unexpected trailing content {self.remaining()!r}")
|
||||
|
||||
def skip_past_whitespace(self) -> None:
|
||||
"""Skip whitespace characters"""
|
||||
i = self._start_index
|
||||
while i < len(self._buffer):
|
||||
c = ord(self._buffer[i])
|
||||
if c in (32, 9, 10, 13): # space, tab, \n, \r
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
self._start_index = i
|
||||
|
||||
def try_to_take_prefix(self, prefix: str) -> bool:
|
||||
"""Try to consume prefix from buffer, return True if successful"""
|
||||
if self._buffer.startswith(prefix, self._start_index):
|
||||
self._start_index += len(prefix)
|
||||
return True
|
||||
return False
|
||||
|
||||
def try_to_take(self, length: int) -> str | None:
|
||||
"""Try to take length characters, or None if not enough available"""
|
||||
if self.length < length:
|
||||
return None
|
||||
result = self._buffer[self._start_index : self._start_index + length]
|
||||
self._start_index += length
|
||||
return result
|
||||
|
||||
def try_to_take_char_code(self) -> int | None:
|
||||
"""Try to take a single character as char code, or None if buffer empty"""
|
||||
if self.length == 0:
|
||||
return None
|
||||
code = ord(self._buffer[self._start_index])
|
||||
self._start_index += 1
|
||||
return code
|
||||
|
||||
def take_until_quote_or_backslash(self) -> tuple[str, bool]:
|
||||
"""
|
||||
Consume input up to first quote or backslash
|
||||
|
||||
Returns tuple of (consumed_content, pattern_found)
|
||||
"""
|
||||
buf = self._buffer
|
||||
i = self._start_index
|
||||
while i < len(buf):
|
||||
c = ord(buf[i])
|
||||
if c <= 0x1F:
|
||||
raise ValueError("Unescaped control character in string")
|
||||
if c == 34 or c == 92: # " or \
|
||||
result = buf[self._start_index : i]
|
||||
self._start_index = i
|
||||
return (result, True)
|
||||
i += 1
|
||||
|
||||
result = buf[self._start_index :]
|
||||
self._start_index = len(buf)
|
||||
return (result, False)
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""
|
||||
Tokenizer for chunk-based JSON parsing
|
||||
|
||||
Processes chunks fed into its input buffer and calls handler methods
|
||||
as JSON tokens are recognized.
|
||||
"""
|
||||
|
||||
def __init__(self, input: _Input, handler: TokenHandler) -> None:
|
||||
self.input = input
|
||||
self._handler = handler
|
||||
self._stack: list[_State] = [_State.ExpectingValue]
|
||||
self._emitted_tokens = 0
|
||||
|
||||
def is_done(self) -> bool:
|
||||
"""Check if tokenization is complete"""
|
||||
return len(self._stack) == 0 and self.input.length == 0
|
||||
|
||||
def pump(self) -> None:
|
||||
"""Process all available tokens in the buffer"""
|
||||
while True:
|
||||
before = self._emitted_tokens
|
||||
self._tokenize_more()
|
||||
if self._emitted_tokens == before:
|
||||
self.input.commit()
|
||||
return
|
||||
|
||||
def _tokenize_more(self) -> None:
|
||||
"""Process one step of tokenization based on current state"""
|
||||
if not self._stack:
|
||||
return
|
||||
|
||||
state = self._stack[-1]
|
||||
|
||||
if state == _State.ExpectingValue:
|
||||
self._tokenize_value()
|
||||
elif state == _State.InString:
|
||||
self._tokenize_string()
|
||||
elif state == _State.StartArray:
|
||||
self._tokenize_array_start()
|
||||
elif state == _State.AfterArrayValue:
|
||||
self._tokenize_after_array_value()
|
||||
elif state == _State.StartObject:
|
||||
self._tokenize_object_start()
|
||||
elif state == _State.AfterObjectKey:
|
||||
self._tokenize_after_object_key()
|
||||
elif state == _State.AfterObjectValue:
|
||||
self._tokenize_after_object_value()
|
||||
elif state == _State.BeforeObjectKey:
|
||||
self._tokenize_before_object_key()
|
||||
|
||||
def _tokenize_value(self) -> None:
|
||||
"""Tokenize a JSON value"""
|
||||
self.input.skip_past_whitespace()
|
||||
|
||||
if self.input.try_to_take_prefix("null"):
|
||||
self._handler.handle_null()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("true"):
|
||||
self._handler.handle_boolean(True)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("false"):
|
||||
self._handler.handle_boolean(False)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.length > 0:
|
||||
ch = self.input.peek_char_code(0)
|
||||
if (48 <= ch <= 57) or ch == 45: # 0-9 or -
|
||||
# Scan for end of number
|
||||
i = 0
|
||||
while i < self.input.length:
|
||||
c = self.input.peek_char_code(i)
|
||||
if (48 <= c <= 57) or c in (45, 43, 46, 101, 69): # 0-9 - + . e E
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
|
||||
if i == self.input.length and not self.input.buffer_complete:
|
||||
# Need more input (numbers have no terminator)
|
||||
return
|
||||
|
||||
number_chars = self.input.slice(0, i)
|
||||
self.input.advance(i)
|
||||
number = _parse_json_number(number_chars)
|
||||
self._handler.handle_number(number)
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix('"'):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("["):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.StartArray)
|
||||
self._handler.handle_array_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_array_start()
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("{"):
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.StartObject)
|
||||
self._handler.handle_object_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_object_start()
|
||||
return
|
||||
|
||||
def _tokenize_string(self) -> None:
|
||||
"""Tokenize string content"""
|
||||
while True:
|
||||
chunk, interrupted = self.input.take_until_quote_or_backslash()
|
||||
if chunk:
|
||||
self._handler.handle_string_middle(chunk)
|
||||
self._emitted_tokens += 1
|
||||
elif not interrupted:
|
||||
return
|
||||
|
||||
if interrupted:
|
||||
if self.input.length == 0:
|
||||
return
|
||||
|
||||
next_char = self.input.peek(0)
|
||||
if next_char == '"':
|
||||
self.input.advance(1)
|
||||
self._handler.handle_string_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
# Handle escape sequences
|
||||
next_char2 = self.input.peek(1)
|
||||
if next_char2 is None:
|
||||
return
|
||||
|
||||
value: str
|
||||
if next_char2 == "u":
|
||||
# Unicode escape: need 4 hex digits
|
||||
if self.input.length < 6:
|
||||
return
|
||||
|
||||
code = 0
|
||||
for j in range(2, 6):
|
||||
c = self.input.peek_char_code(j)
|
||||
if 48 <= c <= 57: # 0-9
|
||||
digit = c - 48
|
||||
elif 65 <= c <= 70: # A-F
|
||||
digit = c - 55
|
||||
elif 97 <= c <= 102: # a-f
|
||||
digit = c - 87
|
||||
else:
|
||||
raise ValueError("Bad Unicode escape in JSON")
|
||||
code = (code << 4) | digit
|
||||
|
||||
self.input.advance(6)
|
||||
self._handler.handle_string_middle(chr(code))
|
||||
self._emitted_tokens += 1
|
||||
continue
|
||||
|
||||
elif next_char2 == "n":
|
||||
value = "\n"
|
||||
elif next_char2 == "r":
|
||||
value = "\r"
|
||||
elif next_char2 == "t":
|
||||
value = "\t"
|
||||
elif next_char2 == "b":
|
||||
value = "\b"
|
||||
elif next_char2 == "f":
|
||||
value = "\f"
|
||||
elif next_char2 == "\\":
|
||||
value = "\\"
|
||||
elif next_char2 == "/":
|
||||
value = "/"
|
||||
elif next_char2 == '"':
|
||||
value = '"'
|
||||
else:
|
||||
raise ValueError("Bad escape in string")
|
||||
|
||||
self.input.advance(2)
|
||||
self._handler.handle_string_middle(value)
|
||||
self._emitted_tokens += 1
|
||||
|
||||
def _tokenize_array_start(self) -> None:
|
||||
"""Tokenize start of array (check for empty or first element)"""
|
||||
self.input.skip_past_whitespace()
|
||||
if self.input.length == 0:
|
||||
return
|
||||
|
||||
if self.input.try_to_take_prefix("]"):
|
||||
self._handler.handle_array_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterArrayValue)
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
|
||||
def _tokenize_after_array_value(self) -> None:
|
||||
"""Tokenize after an array value (expect , or ])"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x5D: # ]
|
||||
self._handler.handle_array_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x2C: # ,
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected , or ], got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_object_start(self) -> None:
|
||||
"""Tokenize start of object (check for empty or first key)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x7D: # }
|
||||
self._handler.handle_object_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x22: # "
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectKey)
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected start of object key, got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_after_object_key(self) -> None:
|
||||
"""Tokenize after object key (expect :)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x3A: # :
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectValue)
|
||||
self._stack.append(_State.ExpectingValue)
|
||||
self._tokenize_value()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected colon after object key, got {chr(next_char)!r}")
|
||||
|
||||
def _tokenize_after_object_value(self) -> None:
|
||||
"""Tokenize after object value (expect , or })"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x7D: # }
|
||||
self._handler.handle_object_end()
|
||||
self._emitted_tokens += 1
|
||||
self._stack.pop()
|
||||
return
|
||||
elif next_char == 0x2C: # ,
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.BeforeObjectKey)
|
||||
self._tokenize_before_object_key()
|
||||
return
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected , or }} after object value, got {chr(next_char)!r}"
|
||||
)
|
||||
|
||||
def _tokenize_before_object_key(self) -> None:
|
||||
"""Tokenize before object key (after comma)"""
|
||||
self.input.skip_past_whitespace()
|
||||
next_char = self.input.try_to_take_char_code()
|
||||
|
||||
if next_char is None:
|
||||
return
|
||||
elif next_char == 0x22: # "
|
||||
self._stack.pop()
|
||||
self._stack.append(_State.AfterObjectKey)
|
||||
self._stack.append(_State.InString)
|
||||
self._handler.handle_string_start()
|
||||
self._emitted_tokens += 1
|
||||
self._tokenize_string()
|
||||
return
|
||||
else:
|
||||
raise ValueError(f"Expected start of object key, got {chr(next_char)!r}")
|
||||
@@ -128,8 +128,6 @@ class SensitiveValue(Generic[T]):
|
||||
value = self._decrypt()
|
||||
|
||||
if not apply_mask:
|
||||
# Callers must not mutate the returned dict — doing so would
|
||||
# desync the cache from the encrypted bytes and the DB.
|
||||
return value
|
||||
|
||||
# Apply masking
|
||||
@@ -176,20 +174,18 @@ class SensitiveValue(Generic[T]):
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Compare SensitiveValues by their decrypted content."""
|
||||
# NOTE: if you attempt to compare a string/dict to a SensitiveValue,
|
||||
# this comparison will return NotImplemented, which then evaluates to False.
|
||||
# This is the convention and required for SQLAlchemy's attribute tracking.
|
||||
if not isinstance(other, SensitiveValue):
|
||||
return NotImplemented
|
||||
return self._decrypt() == other._decrypt()
|
||||
"""Prevent direct comparison which might expose value."""
|
||||
if isinstance(other, SensitiveValue):
|
||||
# Compare encrypted bytes for equality check
|
||||
return self._encrypted_bytes == other._encrypted_bytes
|
||||
raise SensitiveAccessError(
|
||||
"Cannot compare SensitiveValue with non-SensitiveValue. "
|
||||
"Use .get_value(apply_mask=True/False) to access the value for comparison."
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Hash based on decrypted content."""
|
||||
value = self._decrypt()
|
||||
if isinstance(value, dict):
|
||||
return hash(json.dumps(value, sort_keys=True))
|
||||
return hash(value)
|
||||
"""Allow hashing based on encrypted bytes."""
|
||||
return hash(self._encrypted_bytes)
|
||||
|
||||
# Prevent JSON serialization
|
||||
def __json__(self) -> Any:
|
||||
|
||||
@@ -140,44 +140,6 @@ def _validate_and_resolve_url(url: str) -> tuple[str, str, int]:
|
||||
return validated_ip, hostname, port
|
||||
|
||||
|
||||
def validate_outbound_http_url(url: str, *, allow_private_network: bool = False) -> str:
|
||||
"""
|
||||
Validate a URL that will be used by backend outbound HTTP calls.
|
||||
|
||||
Returns:
|
||||
A normalized URL string with surrounding whitespace removed.
|
||||
|
||||
Raises:
|
||||
ValueError: If URL is malformed.
|
||||
SSRFException: If URL fails SSRF checks.
|
||||
"""
|
||||
normalized_url = url.strip()
|
||||
if not normalized_url:
|
||||
raise ValueError("URL cannot be empty")
|
||||
|
||||
parsed = urlparse(normalized_url)
|
||||
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise SSRFException(
|
||||
f"Invalid URL scheme '{parsed.scheme}'. Only http and https are allowed."
|
||||
)
|
||||
|
||||
if not parsed.hostname:
|
||||
raise ValueError("URL must contain a hostname")
|
||||
|
||||
if parsed.username or parsed.password:
|
||||
raise SSRFException("URLs with embedded credentials are not allowed.")
|
||||
|
||||
hostname = parsed.hostname.lower()
|
||||
if hostname in BLOCKED_HOSTNAMES:
|
||||
raise SSRFException(f"Access to hostname '{parsed.hostname}' is not allowed.")
|
||||
|
||||
if not allow_private_network:
|
||||
_validate_and_resolve_url(normalized_url)
|
||||
|
||||
return normalized_url
|
||||
|
||||
|
||||
MAX_REDIRECTS = 10
|
||||
|
||||
|
||||
|
||||
@@ -24,9 +24,6 @@ class OnyxVersion:
|
||||
def set_ee(self) -> None:
|
||||
self._is_ee = True
|
||||
|
||||
def unset_ee(self) -> None:
|
||||
self._is_ee = False
|
||||
|
||||
def is_ee_version(self) -> bool:
|
||||
return self._is_ee
|
||||
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
|
||||
def get_voice_provider(provider: VoiceProvider) -> VoiceProviderInterface:
|
||||
"""
|
||||
Factory function to get the appropriate voice provider implementation.
|
||||
|
||||
Args:
|
||||
provider: VoiceProvider model instance (can be from DB or constructed temporarily)
|
||||
|
||||
Returns:
|
||||
VoiceProviderInterface implementation
|
||||
|
||||
Raises:
|
||||
ValueError: If provider_type is not supported
|
||||
"""
|
||||
provider_type = provider.provider_type.lower()
|
||||
|
||||
# Handle both SensitiveValue (from DB) and plain string (from temp model)
|
||||
if provider.api_key is None:
|
||||
api_key = None
|
||||
elif hasattr(provider.api_key, "get_value"):
|
||||
# SensitiveValue from database
|
||||
api_key = provider.api_key.get_value(apply_mask=False)
|
||||
else:
|
||||
# Plain string from temporary model
|
||||
api_key = provider.api_key # type: ignore[assignment]
|
||||
api_base = provider.api_base
|
||||
custom_config = provider.custom_config
|
||||
stt_model = provider.stt_model
|
||||
tts_model = provider.tts_model
|
||||
default_voice = provider.default_voice
|
||||
|
||||
if provider_type == "openai":
|
||||
from onyx.voice.providers.openai import OpenAIVoiceProvider
|
||||
|
||||
return OpenAIVoiceProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
stt_model=stt_model,
|
||||
tts_model=tts_model,
|
||||
default_voice=default_voice,
|
||||
)
|
||||
|
||||
elif provider_type == "azure":
|
||||
from onyx.voice.providers.azure import AzureVoiceProvider
|
||||
|
||||
return AzureVoiceProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
custom_config=custom_config or {},
|
||||
stt_model=stt_model,
|
||||
tts_model=tts_model,
|
||||
default_voice=default_voice,
|
||||
)
|
||||
|
||||
elif provider_type == "elevenlabs":
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsVoiceProvider
|
||||
|
||||
return ElevenLabsVoiceProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
stt_model=stt_model,
|
||||
tts_model=tts_model,
|
||||
default_voice=default_voice,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported voice provider type: {provider_type}")
|
||||
@@ -1,175 +0,0 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TranscriptResult(BaseModel):
|
||||
"""Result from streaming transcription."""
|
||||
|
||||
text: str
|
||||
"""The accumulated transcript text."""
|
||||
|
||||
is_vad_end: bool = False
|
||||
"""True if VAD detected end of speech (silence). Use for auto-send."""
|
||||
|
||||
|
||||
class StreamingTranscriberProtocol(Protocol):
|
||||
"""Protocol for streaming transcription sessions."""
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send an audio chunk for transcription."""
|
||||
...
|
||||
|
||||
async def receive_transcript(self) -> TranscriptResult | None:
|
||||
"""
|
||||
Receive next transcript update.
|
||||
|
||||
Returns:
|
||||
TranscriptResult with accumulated text and VAD status, or None when stream ends.
|
||||
"""
|
||||
...
|
||||
|
||||
async def close(self) -> str:
|
||||
"""Close the session and return final transcript."""
|
||||
...
|
||||
|
||||
def reset_transcript(self) -> None:
|
||||
"""Reset accumulated transcript. Call after auto-send to start fresh."""
|
||||
...
|
||||
|
||||
|
||||
class StreamingSynthesizerProtocol(Protocol):
|
||||
"""Protocol for streaming TTS sessions (real-time text-to-speech)."""
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish connection to TTS provider."""
|
||||
...
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Send text to be synthesized."""
|
||||
...
|
||||
|
||||
async def receive_audio(self) -> bytes | None:
|
||||
"""
|
||||
Receive next audio chunk.
|
||||
|
||||
Returns:
|
||||
Audio bytes, or None when stream ends.
|
||||
"""
|
||||
...
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Signal end of text input and wait for pending audio."""
|
||||
...
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the session."""
|
||||
...
|
||||
|
||||
|
||||
class VoiceProviderInterface(ABC):
|
||||
"""Abstract base class for voice providers (STT and TTS)."""
|
||||
|
||||
@abstractmethod
|
||||
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
|
||||
"""
|
||||
Convert audio to text (Speech-to-Text).
|
||||
|
||||
Args:
|
||||
audio_data: Raw audio bytes
|
||||
audio_format: Audio format (e.g., "webm", "wav", "mp3")
|
||||
|
||||
Returns:
|
||||
Transcribed text
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def synthesize_stream(
|
||||
self, text: str, voice: str | None = None, speed: float = 1.0
|
||||
) -> AsyncIterator[bytes]:
|
||||
"""
|
||||
Convert text to audio stream (Text-to-Speech).
|
||||
|
||||
Streams audio chunks progressively for lower latency playback.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech
|
||||
voice: Voice identifier (e.g., "alloy", "echo"), or None for default
|
||||
speed: Playback speed multiplier (0.25 to 4.0)
|
||||
|
||||
Yields:
|
||||
Audio data chunks
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_available_voices(self) -> list[dict[str, str]]:
|
||||
"""
|
||||
Get list of available voices for this provider.
|
||||
|
||||
Returns:
|
||||
List of voice dictionaries with 'id' and 'name' keys
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_available_stt_models(self) -> list[dict[str, str]]:
|
||||
"""
|
||||
Get list of available STT models for this provider.
|
||||
|
||||
Returns:
|
||||
List of model dictionaries with 'id' and 'name' keys
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_available_tts_models(self) -> list[dict[str, str]]:
|
||||
"""
|
||||
Get list of available TTS models for this provider.
|
||||
|
||||
Returns:
|
||||
List of model dictionaries with 'id' and 'name' keys
|
||||
"""
|
||||
|
||||
def supports_streaming_stt(self) -> bool:
|
||||
"""Returns True if this provider supports streaming STT."""
|
||||
return False
|
||||
|
||||
def supports_streaming_tts(self) -> bool:
|
||||
"""Returns True if this provider supports real-time streaming TTS."""
|
||||
return False
|
||||
|
||||
async def create_streaming_transcriber(
|
||||
self, audio_format: str = "webm"
|
||||
) -> StreamingTranscriberProtocol:
|
||||
"""
|
||||
Create a streaming transcription session.
|
||||
|
||||
Args:
|
||||
audio_format: Audio format being sent (e.g., "webm", "pcm16")
|
||||
|
||||
Returns:
|
||||
A streaming transcriber that can send audio chunks and receive transcripts
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If streaming STT is not supported
|
||||
"""
|
||||
raise NotImplementedError("Streaming STT not supported by this provider")
|
||||
|
||||
async def create_streaming_synthesizer(
|
||||
self, voice: str | None = None, speed: float = 1.0
|
||||
) -> "StreamingSynthesizerProtocol":
|
||||
"""
|
||||
Create a streaming TTS session for real-time audio synthesis.
|
||||
|
||||
Args:
|
||||
voice: Voice identifier
|
||||
speed: Playback speed multiplier
|
||||
|
||||
Returns:
|
||||
A streaming synthesizer that can send text and receive audio chunks
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If streaming TTS is not supported
|
||||
"""
|
||||
raise NotImplementedError("Streaming TTS not supported by this provider")
|
||||
@@ -1,579 +0,0 @@
|
||||
import asyncio
|
||||
import io
|
||||
import re
|
||||
import struct
|
||||
import wave
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
from xml.sax.saxutils import escape
|
||||
from xml.sax.saxutils import quoteattr
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
# Common Azure Neural voices
|
||||
AZURE_VOICES = [
|
||||
{"id": "en-US-JennyNeural", "name": "Jenny (en-US, Female)"},
|
||||
{"id": "en-US-GuyNeural", "name": "Guy (en-US, Male)"},
|
||||
{"id": "en-US-AriaNeural", "name": "Aria (en-US, Female)"},
|
||||
{"id": "en-US-DavisNeural", "name": "Davis (en-US, Male)"},
|
||||
{"id": "en-US-AmberNeural", "name": "Amber (en-US, Female)"},
|
||||
{"id": "en-US-AnaNeural", "name": "Ana (en-US, Female)"},
|
||||
{"id": "en-US-BrandonNeural", "name": "Brandon (en-US, Male)"},
|
||||
{"id": "en-US-ChristopherNeural", "name": "Christopher (en-US, Male)"},
|
||||
{"id": "en-US-CoraNeural", "name": "Cora (en-US, Female)"},
|
||||
{"id": "en-GB-SoniaNeural", "name": "Sonia (en-GB, Female)"},
|
||||
{"id": "en-GB-RyanNeural", "name": "Ryan (en-GB, Male)"},
|
||||
]
|
||||
|
||||
|
||||
class AzureStreamingTranscriber(StreamingTranscriberProtocol):
|
||||
"""Streaming transcription using Azure Speech SDK."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
region: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
input_sample_rate: int = 24000,
|
||||
target_sample_rate: int = 16000,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.region = region
|
||||
self.endpoint = endpoint
|
||||
self.input_sample_rate = input_sample_rate
|
||||
self.target_sample_rate = target_sample_rate
|
||||
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
|
||||
self._accumulated_transcript = ""
|
||||
self._recognizer: Any = None
|
||||
self._audio_stream: Any = None
|
||||
self._closed = False
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize Azure Speech recognizer with push stream."""
|
||||
try:
|
||||
import azure.cognitiveservices.speech as speechsdk # type: ignore
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Azure Speech SDK is required for streaming STT. "
|
||||
"Install `azure-cognitiveservices-speech`."
|
||||
) from e
|
||||
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
# Use endpoint for self-hosted containers, region for Azure cloud
|
||||
if self.endpoint:
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key,
|
||||
endpoint=self.endpoint,
|
||||
)
|
||||
else:
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key,
|
||||
region=self.region,
|
||||
)
|
||||
|
||||
audio_format = speechsdk.audio.AudioStreamFormat(
|
||||
samples_per_second=16000,
|
||||
bits_per_sample=16,
|
||||
channels=1,
|
||||
)
|
||||
self._audio_stream = speechsdk.audio.PushAudioInputStream(audio_format)
|
||||
audio_config = speechsdk.audio.AudioConfig(stream=self._audio_stream)
|
||||
|
||||
self._recognizer = speechsdk.SpeechRecognizer(
|
||||
speech_config=speech_config,
|
||||
audio_config=audio_config,
|
||||
)
|
||||
|
||||
transcriber = self
|
||||
|
||||
def on_recognizing(evt: Any) -> None:
|
||||
if evt.result.text and transcriber._loop and not transcriber._closed:
|
||||
full_text = transcriber._accumulated_transcript
|
||||
if full_text:
|
||||
full_text += " " + evt.result.text
|
||||
else:
|
||||
full_text = evt.result.text
|
||||
transcriber._loop.call_soon_threadsafe(
|
||||
transcriber._transcript_queue.put_nowait,
|
||||
TranscriptResult(text=full_text, is_vad_end=False),
|
||||
)
|
||||
|
||||
def on_recognized(evt: Any) -> None:
|
||||
if evt.result.text and transcriber._loop and not transcriber._closed:
|
||||
if transcriber._accumulated_transcript:
|
||||
transcriber._accumulated_transcript += " " + evt.result.text
|
||||
else:
|
||||
transcriber._accumulated_transcript = evt.result.text
|
||||
transcriber._loop.call_soon_threadsafe(
|
||||
transcriber._transcript_queue.put_nowait,
|
||||
TranscriptResult(
|
||||
text=transcriber._accumulated_transcript, is_vad_end=True
|
||||
),
|
||||
)
|
||||
|
||||
self._recognizer.recognizing.connect(on_recognizing)
|
||||
self._recognizer.recognized.connect(on_recognized)
|
||||
self._recognizer.start_continuous_recognition_async()
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send audio chunk to Azure."""
|
||||
if self._audio_stream and not self._closed:
|
||||
self._audio_stream.write(self._resample_pcm16(chunk))
|
||||
|
||||
def _resample_pcm16(self, data: bytes) -> bytes:
|
||||
"""Resample PCM16 audio from input_sample_rate to target_sample_rate."""
|
||||
if self.input_sample_rate == self.target_sample_rate:
|
||||
return data
|
||||
|
||||
num_samples = len(data) // 2
|
||||
if num_samples == 0:
|
||||
return b""
|
||||
|
||||
samples = list(struct.unpack(f"<{num_samples}h", data))
|
||||
ratio = self.input_sample_rate / self.target_sample_rate
|
||||
new_length = int(num_samples / ratio)
|
||||
|
||||
resampled: list[int] = []
|
||||
for i in range(new_length):
|
||||
src_idx = i * ratio
|
||||
idx_floor = int(src_idx)
|
||||
idx_ceil = min(idx_floor + 1, num_samples - 1)
|
||||
frac = src_idx - idx_floor
|
||||
sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
|
||||
sample = max(-32768, min(32767, sample))
|
||||
resampled.append(sample)
|
||||
|
||||
return struct.pack(f"<{len(resampled)}h", *resampled)
|
||||
|
||||
async def receive_transcript(self) -> TranscriptResult | None:
|
||||
"""Receive next transcript."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return TranscriptResult(text="", is_vad_end=False)
|
||||
|
||||
async def close(self) -> str:
|
||||
"""Stop recognition and return final transcript."""
|
||||
self._closed = True
|
||||
if self._recognizer:
|
||||
self._recognizer.stop_continuous_recognition_async()
|
||||
if self._audio_stream:
|
||||
self._audio_stream.close()
|
||||
self._loop = None
|
||||
return self._accumulated_transcript
|
||||
|
||||
def reset_transcript(self) -> None:
|
||||
"""Reset accumulated transcript."""
|
||||
self._accumulated_transcript = ""
|
||||
|
||||
|
||||
class AzureStreamingSynthesizer(StreamingSynthesizerProtocol):
|
||||
"""Real-time streaming TTS using Azure Speech SDK."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
region: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
voice: str = "en-US-JennyNeural",
|
||||
speed: float = 1.0,
|
||||
):
|
||||
self._logger = setup_logger()
|
||||
self.api_key = api_key
|
||||
self.region = region
|
||||
self.endpoint = endpoint
|
||||
self.voice = voice
|
||||
self.speed = max(0.5, min(2.0, speed))
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._synthesizer: Any = None
|
||||
self._closed = False
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize Azure Speech synthesizer with push stream."""
|
||||
try:
|
||||
import azure.cognitiveservices.speech as speechsdk
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Azure Speech SDK is required for streaming TTS. "
|
||||
"Install `azure-cognitiveservices-speech`."
|
||||
) from e
|
||||
|
||||
self._logger.info("AzureStreamingSynthesizer: connecting")
|
||||
|
||||
# Store the event loop for thread-safe queue operations
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
# Use endpoint for self-hosted containers, region for Azure cloud
|
||||
if self.endpoint:
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key,
|
||||
endpoint=self.endpoint,
|
||||
)
|
||||
else:
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key,
|
||||
region=self.region,
|
||||
)
|
||||
speech_config.speech_synthesis_voice_name = self.voice
|
||||
# Use MP3 format for streaming - compatible with MediaSource Extensions
|
||||
speech_config.set_speech_synthesis_output_format(
|
||||
speechsdk.SpeechSynthesisOutputFormat.Audio16Khz64KBitRateMonoMp3
|
||||
)
|
||||
|
||||
# Create synthesizer with pull audio output stream
|
||||
self._synthesizer = speechsdk.SpeechSynthesizer(
|
||||
speech_config=speech_config,
|
||||
audio_config=None, # We'll manually handle audio
|
||||
)
|
||||
|
||||
# Connect to synthesis events
|
||||
self._synthesizer.synthesizing.connect(self._on_synthesizing)
|
||||
self._synthesizer.synthesis_completed.connect(self._on_completed)
|
||||
|
||||
self._logger.info("AzureStreamingSynthesizer: connected")
|
||||
|
||||
def _on_synthesizing(self, evt: Any) -> None:
|
||||
"""Called when audio chunk is available (runs in Azure SDK thread)."""
|
||||
if evt.result.audio_data and self._loop and not self._closed:
|
||||
# Thread-safe way to put item in async queue
|
||||
self._loop.call_soon_threadsafe(
|
||||
self._audio_queue.put_nowait, evt.result.audio_data
|
||||
)
|
||||
|
||||
def _on_completed(self, _evt: Any) -> None:
|
||||
"""Called when synthesis is complete (runs in Azure SDK thread)."""
|
||||
if self._loop and not self._closed:
|
||||
self._loop.call_soon_threadsafe(self._audio_queue.put_nowait, None)
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Send text to be synthesized using SSML for prosody control."""
|
||||
if self._synthesizer and not self._closed:
|
||||
# Build SSML with prosody for speed control
|
||||
rate = f"{int((self.speed - 1) * 100):+d}%"
|
||||
escaped_text = escape(text)
|
||||
ssml = f"""<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>
|
||||
<voice name={quoteattr(self.voice)}>
|
||||
<prosody rate='{rate}'>{escaped_text}</prosody>
|
||||
</voice>
|
||||
</speak>"""
|
||||
# Use speak_ssml_async for SSML support (includes speed/prosody)
|
||||
self._synthesizer.speak_ssml_async(ssml)
|
||||
|
||||
async def receive_audio(self) -> bytes | None:
|
||||
"""Receive next audio chunk."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return b"" # No audio yet, but not done
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Signal end of text input - wait for pending audio."""
|
||||
# Azure SDK handles flushing automatically
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the session."""
|
||||
self._closed = True
|
||||
if self._synthesizer:
|
||||
self._synthesizer.synthesis_completed.disconnect_all()
|
||||
self._synthesizer.synthesizing.disconnect_all()
|
||||
self._loop = None
|
||||
|
||||
|
||||
class AzureVoiceProvider(VoiceProviderInterface):
|
||||
"""Azure Speech Services voice provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None,
|
||||
api_base: str | None,
|
||||
custom_config: dict[str, Any],
|
||||
stt_model: str | None = None,
|
||||
tts_model: str | None = None,
|
||||
default_voice: str | None = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.custom_config = custom_config
|
||||
raw_speech_region = (
|
||||
custom_config.get("speech_region")
|
||||
or self._extract_speech_region_from_uri(api_base)
|
||||
or ""
|
||||
)
|
||||
self.speech_region = self._validate_speech_region(raw_speech_region)
|
||||
self.stt_model = stt_model
|
||||
self.tts_model = tts_model
|
||||
self.default_voice = default_voice or "en-US-JennyNeural"
|
||||
|
||||
@staticmethod
|
||||
def _is_azure_cloud_url(uri: str | None) -> bool:
|
||||
"""Check if URI is an Azure cloud endpoint (vs custom/self-hosted)."""
|
||||
if not uri:
|
||||
return False
|
||||
try:
|
||||
hostname = (urlparse(uri).hostname or "").lower()
|
||||
except ValueError:
|
||||
return False
|
||||
return hostname.endswith(
|
||||
(
|
||||
".speech.microsoft.com",
|
||||
".api.cognitive.microsoft.com",
|
||||
".cognitiveservices.azure.com",
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_speech_region_from_uri(uri: str | None) -> str | None:
|
||||
"""Extract Azure speech region from endpoint URI.
|
||||
|
||||
Note: Custom domains (*.cognitiveservices.azure.com) contain the resource
|
||||
name, not the region. For custom domains, the region must be specified
|
||||
explicitly via custom_config["speech_region"].
|
||||
"""
|
||||
if not uri:
|
||||
return None
|
||||
# Accepted examples:
|
||||
# - https://eastus.tts.speech.microsoft.com/cognitiveservices/v1
|
||||
# - https://eastus.stt.speech.microsoft.com/speech/recognition/...
|
||||
# - https://westus.api.cognitive.microsoft.com/
|
||||
#
|
||||
# NOT supported (requires explicit speech_region config):
|
||||
# - https://<resource>.cognitiveservices.azure.com/ (resource name != region)
|
||||
try:
|
||||
hostname = (urlparse(uri).hostname or "").lower()
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
stt_tts_match = re.match(
|
||||
r"^([a-z0-9-]+)\.(?:tts|stt)\.speech\.microsoft\.com$", hostname
|
||||
)
|
||||
if stt_tts_match:
|
||||
return stt_tts_match.group(1)
|
||||
|
||||
api_match = re.match(
|
||||
r"^([a-z0-9-]+)\.api\.cognitive\.microsoft\.com$", hostname
|
||||
)
|
||||
if api_match:
|
||||
return api_match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _validate_speech_region(speech_region: str) -> str:
|
||||
normalized_region = speech_region.strip().lower()
|
||||
if not normalized_region:
|
||||
return ""
|
||||
if not re.fullmatch(r"[a-z0-9-]+", normalized_region):
|
||||
raise ValueError(
|
||||
"Invalid Azure speech_region. Use lowercase letters, digits, and hyphens only."
|
||||
)
|
||||
return normalized_region
|
||||
|
||||
def _get_stt_url(self) -> str:
|
||||
"""Get the STT endpoint URL (auto-detects cloud vs self-hosted)."""
|
||||
if self.api_base and not self._is_azure_cloud_url(self.api_base):
|
||||
# Self-hosted container endpoint
|
||||
return f"{self.api_base.rstrip('/')}/speech/recognition/conversation/cognitiveservices/v1"
|
||||
# Azure cloud endpoint
|
||||
return (
|
||||
f"https://{self.speech_region}.stt.speech.microsoft.com/"
|
||||
"speech/recognition/conversation/cognitiveservices/v1"
|
||||
)
|
||||
|
||||
def _get_tts_url(self) -> str:
|
||||
"""Get the TTS endpoint URL (auto-detects cloud vs self-hosted)."""
|
||||
if self.api_base and not self._is_azure_cloud_url(self.api_base):
|
||||
# Self-hosted container endpoint
|
||||
return f"{self.api_base.rstrip('/')}/cognitiveservices/v1"
|
||||
# Azure cloud endpoint
|
||||
return f"https://{self.speech_region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
|
||||
def _is_self_hosted(self) -> bool:
|
||||
"""Check if using self-hosted container vs Azure cloud."""
|
||||
return bool(self.api_base and not self._is_azure_cloud_url(self.api_base))
|
||||
|
||||
@staticmethod
|
||||
def _pcm16_to_wav(pcm_data: bytes, sample_rate: int = 24000) -> bytes:
|
||||
"""Wrap raw PCM16 mono bytes into a WAV container."""
|
||||
buffer = io.BytesIO()
|
||||
with wave.open(buffer, "wb") as wav_file:
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setframerate(sample_rate)
|
||||
wav_file.writeframes(pcm_data)
|
||||
return buffer.getvalue()
|
||||
|
||||
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure API key required for STT")
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError("Azure speech region required for STT (cloud mode)")
|
||||
|
||||
normalized_format = audio_format.lower()
|
||||
payload = audio_data
|
||||
content_type = f"audio/{normalized_format}"
|
||||
|
||||
# WebSocket chunked fallback sends raw PCM16 bytes.
|
||||
if normalized_format in {"pcm", "pcm16", "raw"}:
|
||||
payload = self._pcm16_to_wav(audio_data, sample_rate=24000)
|
||||
content_type = "audio/wav"
|
||||
elif normalized_format in {"wav", "wave"}:
|
||||
content_type = "audio/wav"
|
||||
elif normalized_format == "webm":
|
||||
content_type = "audio/webm; codecs=opus"
|
||||
|
||||
url = self._get_stt_url()
|
||||
params = {"language": "en-US", "format": "detailed"}
|
||||
headers = {
|
||||
"Ocp-Apim-Subscription-Key": self.api_key,
|
||||
"Content-Type": content_type,
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url, params=params, headers=headers, data=payload
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(f"Azure STT failed: {error_text}")
|
||||
result = await response.json()
|
||||
|
||||
if result.get("RecognitionStatus") != "Success":
|
||||
return ""
|
||||
nbest = result.get("NBest") or []
|
||||
if nbest and isinstance(nbest, list):
|
||||
display = nbest[0].get("Display")
|
||||
if isinstance(display, str):
|
||||
return display
|
||||
display_text = result.get("DisplayText", "")
|
||||
return display_text if isinstance(display_text, str) else ""
|
||||
|
||||
async def synthesize_stream(
|
||||
self, text: str, voice: str | None = None, speed: float = 1.0
|
||||
) -> AsyncIterator[bytes]:
|
||||
"""
|
||||
Convert text to audio using Azure TTS with streaming.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech
|
||||
voice: Voice name (defaults to provider's default voice)
|
||||
speed: Playback speed multiplier (0.5 to 2.0)
|
||||
|
||||
Yields:
|
||||
Audio data chunks (mp3 format)
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure API key required for TTS")
|
||||
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError("Azure speech region required for TTS (cloud mode)")
|
||||
|
||||
voice_name = voice or self.default_voice
|
||||
|
||||
# Clamp speed to valid range and convert to rate format
|
||||
speed = max(0.5, min(2.0, speed))
|
||||
rate = f"{int((speed - 1) * 100):+d}%" # e.g., 1.0 -> "+0%", 1.5 -> "+50%"
|
||||
|
||||
# Build SSML with escaped text and quoted attributes to prevent injection
|
||||
escaped_text = escape(text)
|
||||
ssml = f"""<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='en-US'>
|
||||
<voice name={quoteattr(voice_name)}>
|
||||
<prosody rate='{rate}'>{escaped_text}</prosody>
|
||||
</voice>
|
||||
</speak>"""
|
||||
|
||||
url = self._get_tts_url()
|
||||
|
||||
headers = {
|
||||
"Ocp-Apim-Subscription-Key": self.api_key,
|
||||
"Content-Type": "application/ssml+xml",
|
||||
"X-Microsoft-OutputFormat": "audio-16khz-128kbitrate-mono-mp3",
|
||||
"User-Agent": "Onyx",
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, data=ssml) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(f"Azure TTS failed: {error_text}")
|
||||
|
||||
# Use 8192 byte chunks for smoother streaming
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
def get_available_voices(self) -> list[dict[str, str]]:
|
||||
"""Return common Azure Neural voices."""
|
||||
return AZURE_VOICES.copy()
|
||||
|
||||
def get_available_stt_models(self) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"id": "default", "name": "Azure Speech Recognition"},
|
||||
]
|
||||
|
||||
def get_available_tts_models(self) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"id": "neural", "name": "Neural TTS"},
|
||||
]
|
||||
|
||||
def supports_streaming_stt(self) -> bool:
|
||||
"""Azure supports streaming STT via Speech SDK."""
|
||||
return True
|
||||
|
||||
def supports_streaming_tts(self) -> bool:
|
||||
"""Azure supports real-time streaming TTS via Speech SDK."""
|
||||
return True
|
||||
|
||||
async def create_streaming_transcriber(
|
||||
self, _audio_format: str = "webm"
|
||||
) -> AzureStreamingTranscriber:
|
||||
"""Create a streaming transcription session."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming transcription")
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError(
|
||||
"Speech region required for Azure streaming transcription (cloud mode)"
|
||||
)
|
||||
|
||||
# Use endpoint for self-hosted, region for cloud
|
||||
transcriber = AzureStreamingTranscriber(
|
||||
api_key=self.api_key,
|
||||
region=self.speech_region if not self._is_self_hosted() else None,
|
||||
endpoint=self.api_base if self._is_self_hosted() else None,
|
||||
input_sample_rate=24000,
|
||||
target_sample_rate=16000,
|
||||
)
|
||||
await transcriber.connect()
|
||||
return transcriber
|
||||
|
||||
async def create_streaming_synthesizer(
    self, voice: str | None = None, speed: float = 1.0
) -> AzureStreamingSynthesizer:
    """Open and connect a realtime text-to-speech session.

    Args:
        voice: Voice name; falls back to the provider default, then
            "en-US-JennyNeural".
        speed: Playback speed multiplier.

    Raises:
        ValueError: if no API key is configured, or (in cloud mode)
            no speech region is set.
    """
    if not self.api_key:
        raise ValueError("API key required for streaming TTS")
    self_hosted = self._is_self_hosted()
    if not self_hosted and not self.speech_region:
        raise ValueError(
            "Speech region required for Azure streaming TTS (cloud mode)"
        )

    # Self-hosted deployments are addressed by endpoint; cloud by region.
    tts_session = AzureStreamingSynthesizer(
        api_key=self.api_key,
        region=None if self_hosted else self.speech_region,
        endpoint=self.api_base if self_hosted else None,
        voice=voice or self.default_voice or "en-US-JennyNeural",
        speed=speed,
    )
    await tts_session.connect()
    return tts_session
|
||||
@@ -1,792 +0,0 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
# Default ElevenLabs API base URL
|
||||
DEFAULT_ELEVENLABS_API_BASE = "https://api.elevenlabs.io"
|
||||
|
||||
|
||||
def _http_to_ws_url(http_url: str) -> str:
|
||||
"""Convert http(s) URL to ws(s) URL for WebSocket connections."""
|
||||
if http_url.startswith("https://"):
|
||||
return "wss://" + http_url[8:]
|
||||
elif http_url.startswith("http://"):
|
||||
return "ws://" + http_url[7:]
|
||||
return http_url
|
||||
|
||||
|
||||
# Common ElevenLabs voices
# Static catalogue of premade ElevenLabs voice IDs surfaced by
# get_available_voices(). The "id" values are the opaque voice identifiers
# the API expects in text-to-speech URLs.
ELEVENLABS_VOICES = [
    {"id": "21m00Tcm4TlvDq8ikWAM", "name": "Rachel"},
    {"id": "AZnzlk1XvdvUeBnXmlld", "name": "Domi"},
    {"id": "EXAVITQu4vr4xnSDxMaL", "name": "Bella"},
    {"id": "ErXwobaYiN019PkySvjV", "name": "Antoni"},
    {"id": "MF3mGyEYCl7XYWbV9V6O", "name": "Elli"},
    {"id": "TxGEqnHWrfWFTfGW9XjX", "name": "Josh"},
    {"id": "VR6AewLTigWG4xSOukaG", "name": "Arnold"},
    {"id": "pNInz6obpgDQGcFmaJgB", "name": "Adam"},
    {"id": "yoZ06aMxZJJ28mfd3POQ", "name": "Sam"},
]
|
||||
|
||||
|
||||
class ElevenLabsStreamingTranscriber(StreamingTranscriberProtocol):
    """Streaming transcription session using ElevenLabs Scribe Realtime API.

    Lifecycle: construct, await connect(), feed PCM16 audio via send_audio(),
    poll receive_transcript(), then await close() for the final transcript.
    A background task reads WebSocket messages into an asyncio queue.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "scribe_v2_realtime",
        input_sample_rate: int = 24000,  # What frontend sends
        target_sample_rate: int = 16000,  # What ElevenLabs expects
        language_code: str = "en",
        api_base: str | None = None,
    ):
        """Store configuration; no I/O happens until connect().

        Args:
            api_key: ElevenLabs API key (sent as the xi-api-key header).
            model: Realtime STT model ID.
            input_sample_rate: Sample rate of incoming PCM16 chunks.
            target_sample_rate: Sample rate the API expects; chunks are
                resampled from input to target in send_audio().
            language_code: BCP-47-style language code for recognition.
            api_base: Override for the API base URL (self-hosted/testing).
        """
        # Import logger first
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()

        self._logger.info(
            f"ElevenLabsStreamingTranscriber: initializing with model {model}"
        )
        self.api_key = api_key
        self.model = model
        self.input_sample_rate = input_sample_rate
        self.target_sample_rate = target_sample_rate
        self.language_code = language_code
        self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
        # WebSocket / HTTP session handles, populated by connect()
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._session: aiohttp.ClientSession | None = None
        # Queue of parsed results; a None sentinel marks end-of-stream
        self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
        self._final_transcript = ""
        self._receive_task: asyncio.Task | None = None
        self._closed = False

    async def connect(self) -> None:
        """Establish WebSocket connection to ElevenLabs.

        Closes the HTTP session and re-raises on connection failure;
        on success, starts the background receive loop.
        """
        self._logger.info(
            "ElevenLabsStreamingTranscriber: connecting to ElevenLabs API"
        )
        self._session = aiohttp.ClientSession()

        # VAD is configured via query parameters
        # commit_strategy=vad enables automatic transcript commit on silence detection
        ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
        url = (
            f"{ws_base}/v1/speech-to-text/realtime"
            f"?model_id={self.model}"
            f"&sample_rate={self.target_sample_rate}"
            f"&language_code={self.language_code}"
            f"&commit_strategy=vad"
            f"&vad_silence_threshold_secs=1.0"
            f"&vad_threshold=0.4"
            f"&min_speech_duration_ms=100"
            f"&min_silence_duration_ms=300"
        )
        self._logger.info(
            f"ElevenLabsStreamingTranscriber: connecting to {url} "
            f"(input={self.input_sample_rate}Hz, target={self.target_sample_rate}Hz)"
        )

        try:
            self._ws = await self._session.ws_connect(
                url,
                headers={"xi-api-key": self.api_key},
            )
            self._logger.info(
                f"ElevenLabsStreamingTranscriber: connected successfully, "
                f"ws.closed={self._ws.closed}, close_code={self._ws.close_code}"
            )
        except Exception as e:
            self._logger.error(
                f"ElevenLabsStreamingTranscriber: failed to connect: {e}"
            )
            # Don't leak the ClientSession when the handshake fails
            if self._session:
                await self._session.close()
            raise

        # Start receiving transcripts in background
        self._receive_task = asyncio.create_task(self._receive_loop())

    async def _receive_loop(self) -> None:
        """Background task to receive transcripts from WebSocket.

        Parses each TEXT frame as JSON, dispatches on the message_type
        field, and pushes TranscriptResult objects onto the queue. Always
        enqueues a final None sentinel so consumers can detect shutdown.
        """
        self._logger.info("ElevenLabsStreamingTranscriber: receive loop started")
        if not self._ws:
            self._logger.warning(
                "ElevenLabsStreamingTranscriber: no WebSocket connection"
            )
            return

        try:
            async for msg in self._ws:
                self._logger.debug(
                    f"ElevenLabsStreamingTranscriber: raw message type: {msg.type}"
                )
                if msg.type == aiohttp.WSMsgType.TEXT:
                    parsed_data: Any = None
                    data: dict[str, Any]
                    try:
                        parsed_data = json.loads(msg.data)
                    except json.JSONDecodeError:
                        self._logger.error(
                            f"ElevenLabsStreamingTranscriber: failed to parse JSON: {msg.data[:200]}"
                        )
                        continue
                    if not isinstance(parsed_data, dict):
                        self._logger.error(
                            "ElevenLabsStreamingTranscriber: expected object JSON payload"
                        )
                        continue
                    data = parsed_data

                    # ElevenLabs uses message_type field - fail fast if missing
                    if "message_type" not in data and "type" not in data:
                        self._logger.error(
                            f"ElevenLabsStreamingTranscriber: malformed packet missing 'message_type' field: {data}"
                        )
                        continue
                    # "message_type" preferred; "type" accepted as a fallback
                    msg_type = data.get("message_type", data.get("type", ""))
                    self._logger.info(
                        f"ElevenLabsStreamingTranscriber: received message_type: '{msg_type}', data keys: {list(data.keys())}"
                    )
                    # Check for error in various formats
                    if "error" in data or msg_type == "error":
                        error_msg = data.get("error", data.get("message", data))
                        self._logger.error(
                            f"ElevenLabsStreamingTranscriber: API error: {error_msg}"
                        )
                        continue

                    # Handle different message types from ElevenLabs Scribe API
                    if msg_type == "session_started":
                        # Session started successfully
                        self._logger.info(
                            f"ElevenLabsStreamingTranscriber: session started, "
                            f"id={data.get('session_id')}, config={data.get('config')}"
                        )
                    elif msg_type == "partial_transcript":
                        # Partial transcript (interim result)
                        text = data.get("text", "")
                        if text:
                            self._logger.info(
                                f"ElevenLabsStreamingTranscriber: partial_transcript: {text[:50]}..."
                            )
                            self._final_transcript = text
                            await self._transcript_queue.put(
                                TranscriptResult(text=text, is_vad_end=False)
                            )
                    elif msg_type == "committed_transcript":
                        # Final/committed transcript (VAD detected end of utterance)
                        text = data.get("text", "")
                        if text:
                            self._logger.info(
                                f"ElevenLabsStreamingTranscriber: committed_transcript: {text[:50]}..."
                            )
                            self._final_transcript = text
                            await self._transcript_queue.put(
                                TranscriptResult(text=text, is_vad_end=True)
                            )
                    elif msg_type == "utterance_end":
                        # VAD detected end of speech
                        # Fall back to the last seen transcript if this packet has no text
                        text = data.get("text", "") or self._final_transcript
                        if text:
                            self._logger.info(
                                f"ElevenLabsStreamingTranscriber: utterance_end: {text[:50]}..."
                            )
                            self._final_transcript = text
                            await self._transcript_queue.put(
                                TranscriptResult(text=text, is_vad_end=True)
                            )
                    elif msg_type == "session_ended":
                        self._logger.info(
                            "ElevenLabsStreamingTranscriber: session ended"
                        )
                        break
                    else:
                        # Log unhandled message types with full data for debugging
                        self._logger.warning(
                            f"ElevenLabsStreamingTranscriber: unhandled message_type: {msg_type}, full data: {data}"
                        )
                elif msg.type == aiohttp.WSMsgType.BINARY:
                    self._logger.debug(
                        f"ElevenLabsStreamingTranscriber: received binary message: {len(msg.data)} bytes"
                    )
                elif msg.type == aiohttp.WSMsgType.CLOSED:
                    close_code = self._ws.close_code if self._ws else "N/A"
                    self._logger.info(
                        "ElevenLabsStreamingTranscriber: WebSocket closed by "
                        f"server, close_code={close_code}"
                    )
                    break
                elif msg.type == aiohttp.WSMsgType.ERROR:
                    self._logger.error(
                        f"ElevenLabsStreamingTranscriber: WebSocket error: {self._ws.exception() if self._ws else 'N/A'}"
                    )
                    break
                elif msg.type == aiohttp.WSMsgType.CLOSE:
                    self._logger.info(
                        f"ElevenLabsStreamingTranscriber: WebSocket CLOSE frame received, data={msg.data}, extra={msg.extra}"
                    )
                    break
        except Exception as e:
            self._logger.error(
                f"ElevenLabsStreamingTranscriber: error in receive loop: {e}",
                exc_info=True,
            )
        finally:
            close_code = self._ws.close_code if self._ws else "N/A"
            self._logger.info(
                f"ElevenLabsStreamingTranscriber: receive loop ended, close_code={close_code}"
            )
            await self._transcript_queue.put(None)  # Signal end

    def _resample_pcm16(self, data: bytes) -> bytes:
        """Resample PCM16 audio from input_sample_rate to target_sample_rate.

        Pure-Python linear interpolation over little-endian int16 samples;
        returns the input unchanged when the rates already match.
        Assumes mono audio — TODO confirm against the frontend capture path.
        """
        import struct

        if self.input_sample_rate == self.target_sample_rate:
            return data

        # Parse int16 samples
        num_samples = len(data) // 2
        samples = list(struct.unpack(f"<{num_samples}h", data))

        # Calculate resampling ratio
        ratio = self.input_sample_rate / self.target_sample_rate
        new_length = int(num_samples / ratio)

        # Linear interpolation resampling
        resampled = []
        for i in range(new_length):
            src_idx = i * ratio
            idx_floor = int(src_idx)
            idx_ceil = min(idx_floor + 1, num_samples - 1)
            frac = src_idx - idx_floor
            sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
            # Clamp to int16 range
            sample = max(-32768, min(32767, sample))
            resampled.append(sample)

        return struct.pack(f"<{len(resampled)}h", *resampled)

    async def send_audio(self, chunk: bytes) -> None:
        """Send an audio chunk for transcription.

        Silently drops the chunk (with a warning) when not connected or
        already closed; resamples, base64-encodes, and forwards otherwise.
        """
        if not self._ws:
            self._logger.warning("send_audio: no WebSocket connection")
            return
        if self._closed:
            self._logger.warning("send_audio: transcriber is closed")
            return
        if self._ws.closed:
            self._logger.warning(
                f"send_audio: WebSocket is closed, close_code={self._ws.close_code}"
            )
            return

        try:
            # Resample from input rate (24kHz) to target rate (16kHz)
            resampled = self._resample_pcm16(chunk)
            # ElevenLabs expects input_audio_chunk message format with audio_base_64
            audio_b64 = base64.b64encode(resampled).decode("utf-8")
            message = {
                "message_type": "input_audio_chunk",
                "audio_base_64": audio_b64,
                "sample_rate": self.target_sample_rate,
            }
            self._logger.info(
                f"send_audio: {len(chunk)} bytes -> {len(resampled)} bytes (resampled) -> {len(audio_b64)} chars base64"
            )
            await self._ws.send_str(json.dumps(message))
            self._logger.info("send_audio: message sent successfully")
        except Exception as e:
            self._logger.error(f"send_audio: failed to send: {e}", exc_info=True)
            raise

    async def receive_transcript(self) -> TranscriptResult | None:
        """Receive next transcript. Returns None when done.

        Polls the queue with a short timeout; on timeout returns an empty
        TranscriptResult so callers can distinguish "nothing yet" from
        the None end-of-stream sentinel.
        """
        try:
            return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return TranscriptResult(
                text="", is_vad_end=False
            )  # No transcript yet, but not done

    async def close(self) -> str:
        """Close the session and return final transcript.

        Shuts down the WebSocket, cancels the receive task, and closes the
        HTTP session; safe to call when partially connected.
        """
        self._logger.info("ElevenLabsStreamingTranscriber: closing session")
        self._closed = True
        if self._ws and not self._ws.closed:
            try:
                # Just close the WebSocket - ElevenLabs Scribe doesn't need a special end message
                self._logger.info(
                    "ElevenLabsStreamingTranscriber: closing WebSocket connection"
                )
                await self._ws.close()
            except Exception as e:
                self._logger.debug(f"Error closing WebSocket: {e}")
        if self._receive_task and not self._receive_task.done():
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass
        if self._session and not self._session.closed:
            await self._session.close()
        return self._final_transcript

    def reset_transcript(self) -> None:
        """Reset accumulated transcript. Call after auto-send to start fresh."""
        self._final_transcript = ""
|
||||
|
||||
|
||||
class ElevenLabsStreamingSynthesizer(StreamingSynthesizerProtocol):
    """Real-time streaming TTS using ElevenLabs WebSocket API.

    Uses ElevenLabs' stream-input WebSocket which processes text as one
    continuous stream and returns audio in order.

    Lifecycle: construct, await connect(), push text via send_text(),
    call flush() when all text is in, drain receive_audio() until it
    yields the None sentinel, then await close().
    """

    def __init__(
        self,
        api_key: str,
        voice_id: str,
        model_id: str = "eleven_multilingual_v2",
        output_format: str = "mp3_44100_64",
        api_base: str | None = None,
        speed: float = 1.0,
    ):
        """Store configuration; no I/O happens until connect().

        Args:
            api_key: ElevenLabs API key (sent as the xi-api-key header).
            voice_id: Opaque ElevenLabs voice identifier.
            model_id: TTS model to use.
            output_format: Audio container/bitrate requested from the API.
            api_base: Override for the API base URL (self-hosted/testing).
            speed: Playback speed multiplier passed in voice_settings.
        """
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()
        self.api_key = api_key
        self.voice_id = voice_id
        self.model_id = model_id
        self.output_format = output_format
        self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
        self.speed = speed
        # WebSocket / HTTP session handles, populated by connect()
        self._ws: aiohttp.ClientWebSocketResponse | None = None
        self._session: aiohttp.ClientSession | None = None
        # Queue of decoded audio chunks; a None sentinel marks end-of-stream
        self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
        self._receive_task: asyncio.Task | None = None
        self._closed = False

    async def connect(self) -> None:
        """Establish WebSocket connection to ElevenLabs TTS.

        Sends the initial voice/generation configuration frame and starts
        the background audio receive loop.
        """
        self._logger.info("ElevenLabsStreamingSynthesizer: connecting")
        self._session = aiohttp.ClientSession()

        # WebSocket URL for streaming input TTS with output format for streaming compatibility
        # Using mp3_44100_64 for good quality with smaller chunks for real-time playback
        ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
        url = (
            f"{ws_base}/v1/text-to-speech/{self.voice_id}/stream-input"
            f"?model_id={self.model_id}&output_format={self.output_format}"
        )

        self._ws = await self._session.ws_connect(
            url,
            headers={"xi-api-key": self.api_key},
        )

        # Send initial configuration with generation settings optimized for streaming
        # Note: API key is sent via header only (not in body to avoid log exposure)
        await self._ws.send_str(
            json.dumps(
                {
                    "text": " ",  # Initial space to start the stream
                    "voice_settings": {
                        "stability": 0.5,
                        "similarity_boost": 0.75,
                        "speed": self.speed,
                    },
                    "generation_config": {
                        "chunk_length_schedule": [
                            120,
                            160,
                            250,
                            290,
                        ],  # Optimized chunk sizes for streaming
                    },
                }
            )
        )

        # Start receiving audio in background
        self._receive_task = asyncio.create_task(self._receive_loop())
        self._logger.info("ElevenLabsStreamingSynthesizer: connected")

    async def _receive_loop(self) -> None:
        """Background task to receive audio chunks from WebSocket.

        Audio is returned in order as one continuous stream.

        TEXT frames carry base64 audio plus an optional isFinal flag;
        BINARY frames are raw audio. A None sentinel is always enqueued
        when the loop ends, whatever the cause.
        """
        if not self._ws:
            return

        chunk_count = 0
        total_bytes = 0
        try:
            async for msg in self._ws:
                if self._closed:
                    self._logger.info(
                        "ElevenLabsStreamingSynthesizer: closed flag set, stopping "
                        "receive loop"
                    )
                    break
                if msg.type == aiohttp.WSMsgType.TEXT:
                    data = json.loads(msg.data)
                    # Process audio if present
                    if "audio" in data and data["audio"]:
                        audio_bytes = base64.b64decode(data["audio"])
                        chunk_count += 1
                        total_bytes += len(audio_bytes)
                        await self._audio_queue.put(audio_bytes)

                    # Check isFinal separately - a message can have both audio AND isFinal
                    if "isFinal" in data:
                        self._logger.info(
                            f"ElevenLabsStreamingSynthesizer: received isFinal={data['isFinal']}, "
                            f"chunks so far: {chunk_count}, bytes: {total_bytes}"
                        )
                        if data.get("isFinal"):
                            self._logger.info(
                                "ElevenLabsStreamingSynthesizer: isFinal=true, signaling end of audio"
                            )
                            await self._audio_queue.put(None)

                    # Check for errors
                    if "error" in data or data.get("type") == "error":
                        self._logger.error(
                            f"ElevenLabsStreamingSynthesizer: received error: {data}"
                        )
                elif msg.type == aiohttp.WSMsgType.BINARY:
                    chunk_count += 1
                    total_bytes += len(msg.data)
                    await self._audio_queue.put(msg.data)
                elif msg.type in (
                    aiohttp.WSMsgType.CLOSE,
                    aiohttp.WSMsgType.ERROR,
                ):
                    self._logger.info(
                        f"ElevenLabsStreamingSynthesizer: WebSocket closed/error, type={msg.type}"
                    )
                    break
        except Exception as e:
            self._logger.error(f"ElevenLabsStreamingSynthesizer receive error: {e}")
        finally:
            self._logger.info(
                f"ElevenLabsStreamingSynthesizer: receive loop ended, {chunk_count} chunks, {total_bytes} bytes"
            )
            await self._audio_queue.put(None)  # Signal end of stream

    async def send_text(self, text: str) -> None:
        """Send text to be synthesized.

        ElevenLabs processes text as a continuous stream and returns
        audio in order. We let ElevenLabs handle buffering via chunk_length_schedule
        and only force generation when flush() is called at the end.

        Whitespace-only text, a missing connection, or a closed session
        are skipped with a warning rather than raising.

        Args:
            text: Text to synthesize
        """
        if self._ws and not self._closed and text.strip():
            self._logger.info(
                f"ElevenLabsStreamingSynthesizer: sending text ({len(text)} chars): '{text}'"
            )
            # Let ElevenLabs buffer and auto-generate based on chunk_length_schedule
            # Don't trigger generation here - wait for flush() at the end
            await self._ws.send_str(
                json.dumps(
                    {
                        "text": text + " ",  # Space for natural speech flow
                    }
                )
            )
            self._logger.info("ElevenLabsStreamingSynthesizer: text sent successfully")
        else:
            self._logger.warning(
                f"ElevenLabsStreamingSynthesizer: skipping send_text - "
                f"ws={self._ws is not None}, closed={self._closed}, text='{text[:30] if text else ''}'"
            )

    async def receive_audio(self) -> bytes | None:
        """Receive next audio chunk.

        Returns b"" on timeout ("nothing yet"), a non-empty chunk when
        audio is available, or None once the stream has ended.
        """
        try:
            return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return b""  # No audio yet, but not done

    async def flush(self) -> None:
        """Signal end of text input. ElevenLabs will generate remaining audio and close."""
        if self._ws and not self._closed:
            # Send empty string to signal end of input
            # ElevenLabs will generate any remaining buffered text,
            # send all audio chunks, send isFinal, then close the connection
            self._logger.info(
                "ElevenLabsStreamingSynthesizer: sending end-of-input (empty string)"
            )
            await self._ws.send_str(json.dumps({"text": ""}))
            self._logger.info("ElevenLabsStreamingSynthesizer: end-of-input sent")
        else:
            self._logger.warning(
                f"ElevenLabsStreamingSynthesizer: skipping flush - "
                f"ws={self._ws is not None}, closed={self._closed}"
            )

    async def close(self) -> None:
        """Close the session.

        Closes the WebSocket, cancels the receive task, and closes the
        HTTP session; safe to call when partially connected.
        """
        self._closed = True
        if self._ws:
            await self._ws.close()
        if self._receive_task:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass
        if self._session:
            await self._session.close()
|
||||
|
||||
|
||||
# Valid ElevenLabs model IDs
# Allow-lists used by ElevenLabsVoiceProvider.__init__ to validate the
# configured model names; values outside these sets fall back to the
# provider defaults (scribe_v1 / eleven_multilingual_v2).
ELEVENLABS_STT_MODELS = {"scribe_v1", "scribe_v2_realtime"}
ELEVENLABS_TTS_MODELS = {
    "eleven_multilingual_v2",
    "eleven_turbo_v2_5",
    "eleven_monolingual_v1",
    "eleven_flash_v2_5",
    "eleven_flash_v2",
}
|
||||
|
||||
|
||||
class ElevenLabsVoiceProvider(VoiceProviderInterface):
    """ElevenLabs voice provider.

    Offers batch speech-to-text, chunked HTTP text-to-speech, and factory
    methods for the realtime WebSocket transcriber/synthesizer sessions.
    """

    def __init__(
        self,
        api_key: str | None,
        api_base: str | None = None,
        stt_model: str | None = None,
        tts_model: str | None = None,
        default_voice: str | None = None,
    ):
        """Store configuration.

        Args:
            api_key: ElevenLabs API key; required for all operations but
                may be None at construction time.
            api_base: Override for the API base URL.
            stt_model: Desired STT model; silently replaced by "scribe_v1"
                when not in ELEVENLABS_STT_MODELS.
            tts_model: Desired TTS model; silently replaced by
                "eleven_multilingual_v2" when not in ELEVENLABS_TTS_MODELS.
            default_voice: Voice ID used when a call does not specify one.
        """
        self.api_key = api_key
        # Consistency fix: use the shared module-level default rather than
        # duplicating the URL literal (the streaming classes already use it).
        self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
        # Validate and default models - use valid ElevenLabs model IDs
        self.stt_model = (
            stt_model if stt_model in ELEVENLABS_STT_MODELS else "scribe_v1"
        )
        self.tts_model = (
            tts_model
            if tts_model in ELEVENLABS_TTS_MODELS
            else "eleven_multilingual_v2"
        )
        self.default_voice = default_voice

    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        """
        Transcribe audio using ElevenLabs Speech-to-Text API.

        Args:
            audio_data: Raw audio bytes
            audio_format: Format of the audio (e.g., 'webm', 'mp3', 'wav')

        Returns:
            Transcribed text

        Raises:
            ValueError: if no API key is configured.
            RuntimeError: if the API responds with a non-200 status.
        """
        if not self.api_key:
            raise ValueError("ElevenLabs API key required for transcription")

        from onyx.utils.logger import setup_logger

        logger = setup_logger()

        url = f"{self.api_base}/v1/speech-to-text"

        # Map common formats to MIME types
        mime_types = {
            "webm": "audio/webm",
            "mp3": "audio/mpeg",
            "wav": "audio/wav",
            "ogg": "audio/ogg",
            "flac": "audio/flac",
            "m4a": "audio/mp4",
        }
        # Unknown formats fall back to a best-effort "audio/<ext>" guess
        mime_type = mime_types.get(audio_format.lower(), f"audio/{audio_format}")

        headers = {
            "xi-api-key": self.api_key,
        }

        # ElevenLabs expects multipart form data
        form_data = aiohttp.FormData()
        form_data.add_field(
            "audio",
            audio_data,
            filename=f"audio.{audio_format}",
            content_type=mime_type,
        )
        # For batch STT, use scribe_v1 (not the realtime model)
        batch_model = (
            self.stt_model if self.stt_model in ("scribe_v1",) else "scribe_v1"
        )
        form_data.add_field("model_id", batch_model)

        logger.info(
            f"ElevenLabs transcribe: sending {len(audio_data)} bytes, format={audio_format}"
        )

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, data=form_data) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"ElevenLabs transcribe failed: {error_text}")
                    raise RuntimeError(f"ElevenLabs transcription failed: {error_text}")

                result = await response.json()
                text = result.get("text", "")
                logger.info(f"ElevenLabs transcribe: got result: {text[:50]}...")
                return text

    async def synthesize_stream(
        self, text: str, voice: str | None = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio using ElevenLabs TTS with streaming.

        Args:
            text: Text to convert to speech
            voice: Voice ID (defaults to provider's default voice or Rachel)
            speed: Playback speed multiplier

        Yields:
            Audio data chunks (mp3 format)

        Raises:
            ValueError: if no API key is configured.
            RuntimeError: if the API responds with a non-200 status.
        """
        from onyx.utils.logger import setup_logger

        logger = setup_logger()

        if not self.api_key:
            raise ValueError("ElevenLabs API key required for TTS")

        voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM"  # Rachel

        url = f"{self.api_base}/v1/text-to-speech/{voice_id}/stream"

        logger.info(
            f"ElevenLabs TTS: starting synthesis, text='{text[:50]}...', "
            f"voice={voice_id}, model={self.tts_model}, speed={speed}"
        )

        headers = {
            "xi-api-key": self.api_key,
            "Content-Type": "application/json",
            "Accept": "audio/mpeg",
        }

        payload = {
            "text": text,
            "model_id": self.tts_model,
            "voice_settings": {
                "stability": 0.5,
                "similarity_boost": 0.75,
                "speed": speed,
            },
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(url, headers=headers, json=payload) as response:
                logger.info(
                    f"ElevenLabs TTS: got response status={response.status}, "
                    f"content-type={response.headers.get('content-type')}"
                )
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"ElevenLabs TTS failed: {error_text}")
                    raise RuntimeError(f"ElevenLabs TTS failed: {error_text}")

                # Use 8192 byte chunks for smoother streaming
                chunk_count = 0
                total_bytes = 0
                async for chunk in response.content.iter_chunked(8192):
                    if chunk:
                        chunk_count += 1
                        total_bytes += len(chunk)
                        yield chunk
                logger.info(
                    f"ElevenLabs TTS: streaming complete, {chunk_count} chunks, "
                    f"{total_bytes} total bytes"
                )

    def get_available_voices(self) -> list[dict[str, str]]:
        """Return common ElevenLabs voices (a copy; safe to mutate)."""
        return ELEVENLABS_VOICES.copy()

    def get_available_stt_models(self) -> list[dict[str, str]]:
        """Return the STT model choices exposed to the admin UI."""
        return [
            {"id": "scribe_v2_realtime", "name": "Scribe v2 Realtime (Streaming)"},
            {"id": "scribe_v1", "name": "Scribe v1 (Batch)"},
        ]

    def get_available_tts_models(self) -> list[dict[str, str]]:
        """Return the TTS model choices exposed to the admin UI."""
        return [
            {"id": "eleven_multilingual_v2", "name": "Multilingual v2"},
            {"id": "eleven_turbo_v2_5", "name": "Turbo v2.5"},
            {"id": "eleven_monolingual_v1", "name": "Monolingual v1"},
        ]

    def supports_streaming_stt(self) -> bool:
        """ElevenLabs supports streaming via Scribe Realtime API."""
        return True

    def supports_streaming_tts(self) -> bool:
        """ElevenLabs supports real-time streaming TTS via WebSocket."""
        return True

    async def create_streaming_transcriber(
        self, _audio_format: str = "webm"
    ) -> ElevenLabsStreamingTranscriber:
        """Create a streaming transcription session.

        Raises:
            ValueError: if no API key is configured.
        """
        if not self.api_key:
            raise ValueError("API key required for streaming transcription")
        # ElevenLabs realtime STT requires scribe_v2_realtime model
        # Frontend sends PCM16 at 24kHz, but ElevenLabs expects 16kHz
        # The transcriber will resample automatically
        transcriber = ElevenLabsStreamingTranscriber(
            api_key=self.api_key,
            model="scribe_v2_realtime",
            input_sample_rate=24000,  # What frontend sends
            target_sample_rate=16000,  # What ElevenLabs expects
            language_code="en",
            api_base=self.api_base,
        )
        await transcriber.connect()
        return transcriber

    async def create_streaming_synthesizer(
        self, voice: str | None = None, speed: float = 1.0
    ) -> ElevenLabsStreamingSynthesizer:
        """Create a streaming TTS session.

        Raises:
            ValueError: if no API key is configured.
        """
        if not self.api_key:
            raise ValueError("API key required for streaming TTS")
        voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM"
        synthesizer = ElevenLabsStreamingSynthesizer(
            api_key=self.api_key,
            voice_id=voice_id,
            model_id=self.tts_model,
            # Use mp3_44100_64 for streaming - good balance of quality and chunk size
            output_format="mp3_44100_64",
            api_base=self.api_base,
            speed=speed,
        )
        await synthesizer.connect()
        return synthesizer
|
||||
@@ -1,590 +0,0 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Default OpenAI API base URL
|
||||
DEFAULT_OPENAI_API_BASE = "https://api.openai.com"
|
||||
|
||||
|
||||
def _http_to_ws_url(http_url: str) -> str:
|
||||
"""Convert http(s) URL to ws(s) URL for WebSocket connections."""
|
||||
if http_url.startswith("https://"):
|
||||
return "wss://" + http_url[8:]
|
||||
elif http_url.startswith("http://"):
|
||||
return "ws://" + http_url[7:]
|
||||
return http_url
|
||||
|
||||
|
||||
class OpenAIStreamingTranscriber(StreamingTranscriberProtocol):
|
||||
"""Streaming transcription using OpenAI Realtime API."""
|
||||
|
||||
def __init__(
    self,
    api_key: str,
    model: str = "whisper-1",
    api_base: str | None = None,
):
    """Store configuration; no I/O happens until connect().

    Args:
        api_key: OpenAI API key (sent as a Bearer token).
        model: Transcription model configured on the realtime session.
        api_base: Override for the API base URL.
    """
    # Import logger first
    from onyx.utils.logger import setup_logger

    self._logger = setup_logger()

    self._logger.info(
        f"OpenAIStreamingTranscriber: initializing with model {model}"
    )
    self.api_key = api_key
    self.model = model
    self.api_base = api_base or DEFAULT_OPENAI_API_BASE
    # WebSocket / HTTP session handles, populated by connect()
    self._ws: aiohttp.ClientWebSocketResponse | None = None
    self._session: aiohttp.ClientSession | None = None
    # Queue of parsed results; a None sentinel marks end-of-stream
    self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
    self._current_turn_transcript = ""  # Transcript for current VAD turn
    self._accumulated_transcript = ""  # Accumulated across all turns
    self._receive_task: asyncio.Task | None = None
    self._closed = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish WebSocket connection to OpenAI Realtime API."""
|
||||
self._session = aiohttp.ClientSession()
|
||||
|
||||
# OpenAI Realtime transcription endpoint
|
||||
ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
|
||||
url = f"{ws_base}/v1/realtime?intent=transcription"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"OpenAI-Beta": "realtime=v1",
|
||||
}
|
||||
|
||||
try:
|
||||
self._ws = await self._session.ws_connect(url, headers=headers)
|
||||
self._logger.info("Connected to OpenAI Realtime API")
|
||||
except Exception as e:
|
||||
self._logger.error(f"Failed to connect to OpenAI Realtime API: {e}")
|
||||
raise
|
||||
|
||||
# Configure the session for transcription
|
||||
# Enable server-side VAD (Voice Activity Detection) for automatic speech detection
|
||||
config_message = {
|
||||
"type": "transcription_session.update",
|
||||
"session": {
|
||||
"input_audio_format": "pcm16", # 16-bit PCM at 24kHz mono
|
||||
"input_audio_transcription": {
|
||||
"model": self.model,
|
||||
},
|
||||
"turn_detection": {
|
||||
"type": "server_vad",
|
||||
"threshold": 0.5,
|
||||
"prefix_padding_ms": 300,
|
||||
"silence_duration_ms": 500,
|
||||
},
|
||||
},
|
||||
}
|
||||
await self._ws.send_str(json.dumps(config_message))
|
||||
self._logger.info(f"Sent config for model: {self.model} with server VAD")
|
||||
|
||||
# Start receiving transcripts
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
"""Background task to receive transcripts."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
data = json.loads(msg.data)
|
||||
msg_type = data.get("type", "")
|
||||
self._logger.debug(f"Received message type: {msg_type}")
|
||||
|
||||
# Handle errors
|
||||
if msg_type == "error":
|
||||
error = data.get("error", {})
|
||||
self._logger.error(f"OpenAI error: {error}")
|
||||
continue
|
||||
|
||||
# Handle VAD events
|
||||
if msg_type == "input_audio_buffer.speech_started":
|
||||
self._logger.info("OpenAI: Speech started")
|
||||
# Reset current turn transcript for new speech
|
||||
self._current_turn_transcript = ""
|
||||
continue
|
||||
elif msg_type == "input_audio_buffer.speech_stopped":
|
||||
self._logger.info(
|
||||
"OpenAI: Speech stopped (VAD detected silence)"
|
||||
)
|
||||
continue
|
||||
elif msg_type == "input_audio_buffer.committed":
|
||||
self._logger.info("OpenAI: Audio buffer committed")
|
||||
continue
|
||||
|
||||
# Handle transcription events
|
||||
if msg_type == "conversation.item.input_audio_transcription.delta":
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
self._logger.info(f"OpenAI: Transcription delta: {delta}")
|
||||
self._current_turn_transcript += delta
|
||||
# Show accumulated + current turn transcript
|
||||
full_transcript = self._accumulated_transcript
|
||||
if full_transcript and self._current_turn_transcript:
|
||||
full_transcript += " "
|
||||
full_transcript += self._current_turn_transcript
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(text=full_transcript, is_vad_end=False)
|
||||
)
|
||||
elif (
|
||||
msg_type
|
||||
== "conversation.item.input_audio_transcription.completed"
|
||||
):
|
||||
transcript = data.get("transcript", "")
|
||||
if transcript:
|
||||
self._logger.info(
|
||||
f"OpenAI: Transcription completed (VAD turn end): {transcript[:50]}..."
|
||||
)
|
||||
# This is the final transcript for this VAD turn
|
||||
self._current_turn_transcript = transcript
|
||||
# Accumulate this turn's transcript
|
||||
if self._accumulated_transcript:
|
||||
self._accumulated_transcript += " " + transcript
|
||||
else:
|
||||
self._accumulated_transcript = transcript
|
||||
# Send with is_vad_end=True to trigger auto-send
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(
|
||||
text=self._accumulated_transcript,
|
||||
is_vad_end=True,
|
||||
)
|
||||
)
|
||||
elif msg_type not in (
|
||||
"transcription_session.created",
|
||||
"transcription_session.updated",
|
||||
"conversation.item.created",
|
||||
):
|
||||
# Log any other message types we might be missing
|
||||
self._logger.info(
|
||||
f"OpenAI: Unhandled message type '{msg_type}': {data}"
|
||||
)
|
||||
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
self._logger.error(f"WebSocket error: {self._ws.exception()}")
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
self._logger.info("WebSocket closed by server")
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error in receive loop: {e}")
|
||||
finally:
|
||||
await self._transcript_queue.put(None)
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send audio chunk to OpenAI."""
|
||||
if self._ws and not self._closed:
|
||||
# OpenAI expects base64-encoded PCM16 audio at 24kHz mono
|
||||
# PCM16 at 24kHz: 24000 samples/sec * 2 bytes/sample = 48000 bytes/sec
|
||||
# So chunk_bytes / 48000 = duration in seconds
|
||||
duration_ms = (len(chunk) / 48000) * 1000
|
||||
self._logger.debug(
|
||||
f"Sending {len(chunk)} bytes ({duration_ms:.1f}ms) of audio to OpenAI. "
|
||||
f"First 10 bytes: {chunk[:10].hex() if len(chunk) >= 10 else chunk.hex()}"
|
||||
)
|
||||
message = {
|
||||
"type": "input_audio_buffer.append",
|
||||
"audio": base64.b64encode(chunk).decode("utf-8"),
|
||||
}
|
||||
await self._ws.send_str(json.dumps(message))
|
||||
|
||||
def reset_transcript(self) -> None:
|
||||
"""Reset accumulated transcript. Call after auto-send to start fresh."""
|
||||
self._logger.info("OpenAI: Resetting accumulated transcript")
|
||||
self._accumulated_transcript = ""
|
||||
self._current_turn_transcript = ""
|
||||
|
||||
async def receive_transcript(self) -> TranscriptResult | None:
|
||||
"""Receive next transcript."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return TranscriptResult(text="", is_vad_end=False)
|
||||
|
||||
async def close(self) -> str:
|
||||
"""Close session and return final transcript."""
|
||||
self._closed = True
|
||||
if self._ws:
|
||||
# With server VAD, the buffer is auto-committed when speech stops.
|
||||
# But we should still commit any remaining audio and wait for transcription.
|
||||
try:
|
||||
await self._ws.send_str(
|
||||
json.dumps({"type": "input_audio_buffer.commit"})
|
||||
)
|
||||
except Exception as e:
|
||||
self._logger.debug(f"Error sending commit (may be expected): {e}")
|
||||
|
||||
# Wait for transcription to arrive (up to 5 seconds)
|
||||
self._logger.info("Waiting for transcription to complete...")
|
||||
for _ in range(50): # 50 * 100ms = 5 seconds max
|
||||
await asyncio.sleep(0.1)
|
||||
if self._accumulated_transcript:
|
||||
self._logger.info(
|
||||
f"Got final transcript: {self._accumulated_transcript[:50]}..."
|
||||
)
|
||||
break
|
||||
else:
|
||||
self._logger.warning("Timed out waiting for transcription")
|
||||
|
||||
await self._ws.close()
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
return self._accumulated_transcript
|
||||
|
||||
|
||||
# OpenAI available voices for TTS
# (ids are passed verbatim as the "voice" parameter of /v1/audio/speech)
OPENAI_VOICES = [
    {"id": "alloy", "name": "Alloy"},
    {"id": "echo", "name": "Echo"},
    {"id": "fable", "name": "Fable"},
    {"id": "onyx", "name": "Onyx"},
    {"id": "nova", "name": "Nova"},
    {"id": "shimmer", "name": "Shimmer"},
]

# OpenAI available STT models (all support streaming via Realtime API)
OPENAI_STT_MODELS = [
    {"id": "whisper-1", "name": "Whisper v1"},
    {"id": "gpt-4o-transcribe", "name": "GPT-4o Transcribe"},
    {"id": "gpt-4o-mini-transcribe", "name": "GPT-4o Mini Transcribe"},
]

# OpenAI available TTS models
OPENAI_TTS_MODELS = [
    {"id": "tts-1", "name": "TTS-1 (Standard)"},
    {"id": "tts-1-hd", "name": "TTS-1 HD (High Quality)"},
]
|
||||
|
||||
|
||||
def _create_wav_header(
|
||||
data_length: int,
|
||||
sample_rate: int = 24000,
|
||||
channels: int = 1,
|
||||
bits_per_sample: int = 16,
|
||||
) -> bytes:
|
||||
"""Create a WAV file header for PCM audio data."""
|
||||
import struct
|
||||
|
||||
byte_rate = sample_rate * channels * bits_per_sample // 8
|
||||
block_align = channels * bits_per_sample // 8
|
||||
|
||||
# WAV header is 44 bytes
|
||||
header = struct.pack(
|
||||
"<4sI4s4sIHHIIHH4sI",
|
||||
b"RIFF", # ChunkID
|
||||
36 + data_length, # ChunkSize
|
||||
b"WAVE", # Format
|
||||
b"fmt ", # Subchunk1ID
|
||||
16, # Subchunk1Size (PCM)
|
||||
1, # AudioFormat (1 = PCM)
|
||||
channels, # NumChannels
|
||||
sample_rate, # SampleRate
|
||||
byte_rate, # ByteRate
|
||||
block_align, # BlockAlign
|
||||
bits_per_sample, # BitsPerSample
|
||||
b"data", # Subchunk2ID
|
||||
data_length, # Subchunk2Size
|
||||
)
|
||||
return header
|
||||
|
||||
|
||||
class OpenAIStreamingSynthesizer(StreamingSynthesizerProtocol):
    """Streaming TTS using OpenAI HTTP TTS API with streaming responses.

    Text is queued via ``send_text``; a background task issues one
    /v1/audio/speech request per queued text and streams MP3 chunks into
    an audio queue drained by ``receive_audio``. ``flush`` then ``close``
    terminate the stream cleanly.
    """

    def __init__(
        self,
        api_key: str,
        voice: str = "alloy",
        model: str = "tts-1",
        speed: float = 1.0,
        api_base: str | None = None,
    ):
        from onyx.utils.logger import setup_logger

        self._logger = setup_logger()
        self.api_key = api_key
        self.voice = voice
        self.model = model
        # Clamp to the range the OpenAI speech endpoint accepts
        self.speed = max(0.25, min(4.0, speed))
        self.api_base = api_base or DEFAULT_OPENAI_API_BASE
        self._session: aiohttp.ClientSession | None = None
        # None items act as end-of-stream sentinels on both queues
        self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
        self._text_queue: asyncio.Queue[str | None] = asyncio.Queue()
        self._synthesis_task: asyncio.Task | None = None
        self._closed = False
        self._flushed = False

    async def connect(self) -> None:
        """Initialize HTTP session for TTS requests."""
        self._logger.info("OpenAIStreamingSynthesizer: connecting")
        self._session = aiohttp.ClientSession()
        # Start background task to process text queue
        self._synthesis_task = asyncio.create_task(self._process_text_queue())
        self._logger.info("OpenAIStreamingSynthesizer: connected")

    async def _process_text_queue(self) -> None:
        """Background task to process queued text for synthesis.

        Polls with a short timeout so it can notice ``self._closed``;
        exits on the None sentinel or cancellation.
        """
        while not self._closed:
            try:
                text = await asyncio.wait_for(self._text_queue.get(), timeout=0.1)
                if text is None:
                    break
                await self._synthesize_text(text)
            except asyncio.TimeoutError:
                continue
            except asyncio.CancelledError:
                break
            except Exception as e:
                self._logger.error(f"Error processing text queue: {e}")

    async def _synthesize_text(self, text: str) -> None:
        """Make HTTP TTS request and stream audio to queue.

        Errors are logged, not raised — a failed request simply produces
        no audio for that text.
        """
        if not self._session or self._closed:
            return

        url = f"{self.api_base.rstrip('/')}/v1/audio/speech"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "voice": self.voice,
            "input": text,
            "speed": self.speed,
            "response_format": "mp3",
        }

        try:
            async with self._session.post(
                url, headers=headers, json=payload
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    self._logger.error(f"OpenAI TTS error: {error_text}")
                    return

                # Use 8192 byte chunks for smoother streaming
                # (larger chunks = more complete MP3 frames, better playback)
                async for chunk in response.content.iter_chunked(8192):
                    if self._closed:
                        break
                    if chunk:
                        await self._audio_queue.put(chunk)
        except Exception as e:
            self._logger.error(f"OpenAIStreamingSynthesizer synthesis error: {e}")

    async def send_text(self, text: str) -> None:
        """Queue text to be synthesized via HTTP streaming.

        Whitespace-only text is ignored; calls after close() are no-ops.
        """
        if not text.strip() or self._closed:
            return
        await self._text_queue.put(text)

    async def receive_audio(self) -> bytes | None:
        """Receive next audio chunk (MP3 format).

        Returns b"" when no audio is ready yet (100ms poll) and None when
        the stream has ended.
        """
        try:
            return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            return b""  # No audio yet, but not done

    async def flush(self) -> None:
        """Signal end of text input - wait for synthesis to complete.

        Idempotent: subsequent calls return immediately.
        """
        if self._flushed:
            return
        self._flushed = True

        # Signal end of text input
        await self._text_queue.put(None)

        # Wait for synthesis task to complete processing all text
        if self._synthesis_task and not self._synthesis_task.done():
            try:
                await asyncio.wait_for(self._synthesis_task, timeout=60.0)
            except asyncio.TimeoutError:
                self._logger.warning("OpenAIStreamingSynthesizer: flush timeout")
            except asyncio.CancelledError:
                pass

        # Signal end of audio stream
        await self._audio_queue.put(None)

    async def close(self) -> None:
        """Close the session. Idempotent; cancels any in-flight synthesis."""
        if self._closed:
            return
        self._closed = True

        # Signal end of queues only if flush wasn't already called
        if not self._flushed:
            await self._text_queue.put(None)
            await self._audio_queue.put(None)

        if self._synthesis_task and not self._synthesis_task.done():
            self._synthesis_task.cancel()
            try:
                await self._synthesis_task
            except asyncio.CancelledError:
                pass

        if self._session:
            await self._session.close()
|
||||
|
||||
|
||||
class OpenAIVoiceProvider(VoiceProviderInterface):
    """OpenAI voice provider using Whisper for STT and TTS API for speech synthesis.

    Non-streaming calls go through the official AsyncOpenAI client
    (created lazily); streaming STT/TTS use the dedicated transcriber and
    synthesizer classes above.
    """

    def __init__(
        self,
        api_key: str | None,
        api_base: str | None = None,
        stt_model: str | None = None,
        tts_model: str | None = None,
        default_voice: str | None = None,
    ):
        self.api_key = api_key
        self.api_base = api_base
        self.stt_model = stt_model or "whisper-1"
        self.tts_model = tts_model or "tts-1"
        self.default_voice = default_voice or "alloy"

        # Lazily constructed AsyncOpenAI client (see _get_client)
        self._client: "AsyncOpenAI | None" = None

    def _get_client(self) -> "AsyncOpenAI":
        """Return the shared AsyncOpenAI client, creating it on first use."""
        if self._client is None:
            from openai import AsyncOpenAI

            self._client = AsyncOpenAI(
                api_key=self.api_key,
                base_url=self.api_base,
            )
        return self._client

    async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
        """
        Transcribe audio using OpenAI Whisper.

        Args:
            audio_data: Raw audio bytes
            audio_format: Audio format (e.g., "webm", "wav", "mp3")

        Returns:
            Transcribed text
        """
        client = self._get_client()

        # Create a file-like object from the audio bytes; the .name suffix
        # tells the API what container format to expect
        audio_file = io.BytesIO(audio_data)
        audio_file.name = f"audio.{audio_format}"

        response = await client.audio.transcriptions.create(
            model=self.stt_model,
            file=audio_file,
        )

        return response.text

    async def synthesize_stream(
        self, text: str, voice: str | None = None, speed: float = 1.0
    ) -> AsyncIterator[bytes]:
        """
        Convert text to audio using OpenAI TTS with streaming.

        Args:
            text: Text to convert to speech
            voice: Voice identifier (defaults to provider's default voice)
            speed: Playback speed multiplier (0.25 to 4.0)

        Yields:
            Audio data chunks (mp3 format)
        """
        client = self._get_client()

        # Clamp speed to valid range
        speed = max(0.25, min(4.0, speed))

        # Use with_streaming_response for proper async streaming
        # Using 8192 byte chunks for better streaming performance
        # (larger chunks = fewer round-trips, more complete MP3 frames)
        async with client.audio.speech.with_streaming_response.create(
            model=self.tts_model,
            voice=voice or self.default_voice,
            input=text,
            speed=speed,
            response_format="mp3",
        ) as response:
            async for chunk in response.iter_bytes(chunk_size=8192):
                yield chunk

    def get_available_voices(self) -> list[dict[str, str]]:
        """Get available OpenAI TTS voices (copy — safe for callers to mutate)."""
        return OPENAI_VOICES.copy()

    def get_available_stt_models(self) -> list[dict[str, str]]:
        """Get available OpenAI STT models (copy — safe for callers to mutate)."""
        return OPENAI_STT_MODELS.copy()

    def get_available_tts_models(self) -> list[dict[str, str]]:
        """Get available OpenAI TTS models (copy — safe for callers to mutate)."""
        return OPENAI_TTS_MODELS.copy()

    def supports_streaming_stt(self) -> bool:
        """OpenAI supports streaming via Realtime API for all STT models."""
        return True

    def supports_streaming_tts(self) -> bool:
        """OpenAI supports real-time streaming TTS via Realtime API."""
        return True

    async def create_streaming_transcriber(
        self, _audio_format: str = "webm"
    ) -> OpenAIStreamingTranscriber:
        """Create a streaming transcription session using Realtime API.

        Raises:
            ValueError: If no API key is configured.
        """
        if not self.api_key:
            raise ValueError("API key required for streaming transcription")
        transcriber = OpenAIStreamingTranscriber(
            api_key=self.api_key,
            model=self.stt_model,
            api_base=self.api_base,
        )
        await transcriber.connect()
        return transcriber

    async def create_streaming_synthesizer(
        self, voice: str | None = None, speed: float = 1.0
    ) -> OpenAIStreamingSynthesizer:
        """Create a streaming TTS session using HTTP streaming API.

        Raises:
            ValueError: If no API key is configured.
        """
        if not self.api_key:
            raise ValueError("API key required for streaming TTS")
        synthesizer = OpenAIStreamingSynthesizer(
            api_key=self.api_key,
            voice=voice or self.default_voice or "alloy",
            model=self.tts_model or "tts-1",
            speed=speed,
            api_base=self.api_base,
        )
        await synthesizer.connect()
        return synthesizer
|
||||
@@ -67,8 +67,6 @@ attrs==25.4.0
|
||||
# zeep
|
||||
authlib==1.6.7
|
||||
# via fastmcp
|
||||
azure-cognitiveservices-speech==1.38.0
|
||||
# via onyx
|
||||
babel==2.17.0
|
||||
# via courlan
|
||||
backoff==2.2.1
|
||||
|
||||
@@ -1,93 +1,48 @@
|
||||
"""Decrypt a raw hex-encoded credential value.
|
||||
|
||||
Usage:
|
||||
python -m scripts.decrypt <hex_value>
|
||||
python -m scripts.decrypt <hex_value> --key "my-encryption-key"
|
||||
python -m scripts.decrypt <hex_value> --key ""
|
||||
|
||||
Pass --key "" to skip decryption and just decode the raw bytes as UTF-8.
|
||||
Omit --key to use the current ENCRYPTION_KEY_SECRET from the environment.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import binascii
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string # noqa: E402
|
||||
from onyx.utils.variable_functionality import global_version # noqa: E402
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string
|
||||
|
||||
|
||||
def decrypt_raw_credential(encrypted_value: str, key: str | None = None) -> None:
|
||||
"""Decrypt and display a raw encrypted credential value.
|
||||
def decrypt_raw_credential(encrypted_value: str) -> None:
|
||||
"""Decrypt and display a raw encrypted credential value
|
||||
|
||||
Args:
|
||||
encrypted_value: The hex-encoded encrypted credential value.
|
||||
key: Encryption key to use. None means use ENCRYPTION_KEY_SECRET,
|
||||
empty string means just decode as UTF-8.
|
||||
encrypted_value: The hex encoded encrypted credential value
|
||||
"""
|
||||
# Strip common hex prefixes
|
||||
if encrypted_value.startswith("\\x"):
|
||||
encrypted_value = encrypted_value[2:]
|
||||
elif encrypted_value.startswith("x"):
|
||||
encrypted_value = encrypted_value[1:]
|
||||
print(encrypted_value)
|
||||
|
||||
try:
|
||||
raw_bytes = binascii.unhexlify(encrypted_value)
|
||||
# If string starts with 'x', remove it as it's just a prefix indicating hex
|
||||
if encrypted_value.startswith("x"):
|
||||
encrypted_value = encrypted_value[1:]
|
||||
elif encrypted_value.startswith("\\x"):
|
||||
encrypted_value = encrypted_value[2:]
|
||||
|
||||
# Convert hex string to bytes
|
||||
encrypted_bytes = binascii.unhexlify(encrypted_value)
|
||||
|
||||
# Decrypt the bytes
|
||||
decrypted_str = decrypt_bytes_to_string(encrypted_bytes)
|
||||
|
||||
# Parse and pretty print the decrypted JSON
|
||||
decrypted_json = json.loads(decrypted_str)
|
||||
print("Decrypted credential value:")
|
||||
print(json.dumps(decrypted_json, indent=2))
|
||||
|
||||
except binascii.Error:
|
||||
print("Error: Invalid hex-encoded string")
|
||||
sys.exit(1)
|
||||
print("Error: Invalid hex encoded string")
|
||||
|
||||
if key == "":
|
||||
# Empty key → just decode as UTF-8, no decryption
|
||||
try:
|
||||
decrypted_str = raw_bytes.decode("utf-8")
|
||||
except UnicodeDecodeError as e:
|
||||
print(f"Error decoding bytes as UTF-8: {e}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(key)
|
||||
try:
|
||||
decrypted_str = decrypt_bytes_to_string(raw_bytes, key=key)
|
||||
except Exception as e:
|
||||
print(f"Error decrypting value: {e}")
|
||||
sys.exit(1)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Decrypted raw value (not JSON): {e}")
|
||||
|
||||
# Try to pretty-print as JSON, otherwise print raw
|
||||
try:
|
||||
parsed = json.loads(decrypted_str)
|
||||
print(json.dumps(parsed, indent=2))
|
||||
except json.JSONDecodeError:
|
||||
print(decrypted_str)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Decrypt a hex-encoded credential value."
|
||||
)
|
||||
parser.add_argument(
|
||||
"value",
|
||||
help="Hex-encoded encrypted value to decrypt.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--key",
|
||||
default=None,
|
||||
help=(
|
||||
"Encryption key. Omit to use ENCRYPTION_KEY_SECRET from env. "
|
||||
'Pass "" (empty) to just decode as UTF-8 without decryption.'
|
||||
),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
global_version.set_ee()
|
||||
decrypt_raw_credential(args.value, key=args.key)
|
||||
global_version.unset_ee()
|
||||
except Exception as e:
|
||||
print(f"Error decrypting value: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python decrypt.py <hex_encoded_encrypted_value>")
|
||||
sys.exit(1)
|
||||
|
||||
encrypted_value = sys.argv[1]
|
||||
decrypt_raw_credential(encrypted_value)
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
"""Re-encrypt secrets under the current ENCRYPTION_KEY_SECRET.
|
||||
|
||||
Decrypts all encrypted columns using the old key (or raw decode if the old key
|
||||
is empty), then re-encrypts them with the current ENCRYPTION_KEY_SECRET.
|
||||
|
||||
Usage (docker):
|
||||
docker exec -it onyx-api_server-1 \
|
||||
python -m scripts.reencrypt_secrets --old-key "previous-key"
|
||||
|
||||
Usage (kubernetes):
|
||||
kubectl exec -it <pod> -- \
|
||||
python -m scripts.reencrypt_secrets --old-key "previous-key"
|
||||
|
||||
Omit --old-key (or pass "") if secrets were not previously encrypted.
|
||||
|
||||
For multi-tenant deployments, pass --tenant-id to target a specific tenant,
|
||||
or --all-tenants to iterate every tenant.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from onyx.db.rotate_encryption_key import rotate_encryption_key # noqa: E402
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant # noqa: E402
|
||||
from onyx.db.engine.sql_engine import SqlEngine # noqa: E402
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids # noqa: E402
|
||||
from onyx.utils.variable_functionality import global_version # noqa: E402
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA # noqa: E402
|
||||
|
||||
|
||||
def _run_for_tenant(tenant_id: str, old_key: str | None, dry_run: bool = False) -> None:
    """Rotate encrypted columns for one tenant and print a per-column summary.

    Args:
        tenant_id: Tenant schema to operate on.
        old_key: Previous encryption key, or None if values were stored
            without a previous key.
        dry_run: When True, report what would change without writing.
    """
    print(f"Re-encrypting secrets for tenant: {tenant_id}")
    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
        results = rotate_encryption_key(db_session, old_key=old_key, dry_run=dry_run)

    if results:
        # results maps column name -> number of rows touched
        for col, count in results.items():
            print(
                f"  {col}: {count} row(s) {'would be ' if dry_run else ''}re-encrypted"
            )
    else:
        print("No rows needed re-encryption.")
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: parse args, init the DB engine, re-encrypt per tenant.

    Exits non-zero if any tenant fails when running with --all-tenants.
    """
    parser = argparse.ArgumentParser(
        description="Re-encrypt secrets under the current encryption key."
    )
    parser.add_argument(
        "--old-key",
        default=None,
        help="Previous encryption key. Omit or pass empty string if not applicable.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be re-encrypted without making changes.",
    )

    tenant_group = parser.add_mutually_exclusive_group()
    tenant_group.add_argument(
        "--tenant-id",
        default=None,
        help="Target a specific tenant schema.",
    )
    tenant_group.add_argument(
        "--all-tenants",
        action="store_true",
        help="Iterate all tenants.",
    )

    args = parser.parse_args()

    # Treat an empty --old-key the same as omitted: no previous key
    old_key = args.old_key if args.old_key else None

    global_version.set_ee()
    SqlEngine.init_engine(pool_size=5, max_overflow=2)

    if args.dry_run:
        print("DRY RUN — no changes will be made")

    if args.all_tenants:
        tenant_ids = get_all_tenant_ids()
        print(f"Found {len(tenant_ids)} tenant(s)")
        # Keep going past per-tenant failures; report them all at the end
        failed_tenants: list[str] = []
        for tid in tenant_ids:
            try:
                _run_for_tenant(tid, old_key, dry_run=args.dry_run)
            except Exception as e:
                print(f"  ERROR for tenant {tid}: {e}")
                failed_tenants.append(tid)
        if failed_tenants:
            print(f"FAILED tenants ({len(failed_tenants)}): {failed_tenants}")
            sys.exit(1)
    else:
        tenant_id = args.tenant_id or POSTGRES_DEFAULT_SCHEMA
        _run_for_tenant(tenant_id, old_key, dry_run=args.dry_run)

    print("Done.")


if __name__ == "__main__":
    main()
|
||||
@@ -1,71 +0,0 @@
|
||||
# Backend Tests
|
||||
|
||||
## Test Types
|
||||
|
||||
There are four test categories, ordered by increasing scope:
|
||||
|
||||
### Unit Tests (`tests/unit/`)
|
||||
|
||||
No external services. Mock all I/O with `unittest.mock`. Use for complex, isolated
|
||||
logic (e.g. citation processing, encryption).
|
||||
|
||||
```bash
|
||||
pytest -xv backend/tests/unit
|
||||
```
|
||||
|
||||
### External Dependency Unit Tests (`tests/external_dependency_unit/`)
|
||||
|
||||
External services (Postgres, Redis, Vespa, OpenAI, etc.) are running, but Onyx
|
||||
application containers are not. Tests call functions directly and can mock selectively.
|
||||
|
||||
Use when you need a real database or real API calls but want control over setup.
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/external_dependency_unit
|
||||
```
|
||||
|
||||
### Integration Tests (`tests/integration/`)
|
||||
|
||||
Full Onyx deployment running. No mocking. Prefer this over other test types when possible.
|
||||
|
||||
```bash
|
||||
python -m dotenv -f .vscode/.env run -- pytest backend/tests/integration
|
||||
```
|
||||
|
||||
### Playwright / E2E Tests (`web/tests/e2e/`)
|
||||
|
||||
Full stack including web server. Use for frontend-backend coordination.
|
||||
|
||||
```bash
|
||||
npx playwright test <TEST_NAME>
|
||||
```
|
||||
|
||||
## Shared Fixtures
|
||||
|
||||
Shared fixtures live in `backend/tests/conftest.py`. Test subdirectories can define
|
||||
their own `conftest.py` for directory-scoped fixtures.
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Use `enable_ee` fixture instead of inlining
|
||||
|
||||
Enables EE mode for a test, with proper teardown and cache clearing.
|
||||
|
||||
```python
|
||||
# Whole file (in a test module, NOT in conftest.py)
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
# Whole directory — add an autouse wrapper to the directory's conftest.py
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_ee_for_directory(enable_ee: None) -> None: # noqa: ARG001
|
||||
"""Wraps the shared enable_ee fixture with autouse for this directory."""
|
||||
|
||||
# Single test
|
||||
def test_something(enable_ee: None) -> None: ...
|
||||
```
|
||||
|
||||
**Note:** `pytestmark` in a `conftest.py` does NOT apply markers to tests in that
|
||||
directory — it only affects tests defined in the conftest itself (which is none).
|
||||
Use the autouse fixture wrapper pattern shown above instead.
|
||||
|
||||
Do NOT inline `global_version.set_ee()` — always use the fixture.
|
||||
@@ -1,24 +0,0 @@
|
||||
"""Root conftest — shared fixtures available to all test directories."""
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture()
def enable_ee() -> Generator[None, None, None]:
    """Temporarily enable EE mode for a single test.

    Restores the previous EE state and clears the versioned-implementation
    cache on teardown so state doesn't leak between tests.
    """
    # Remember prior state so a test run that was already in EE mode
    # isn't switched off by this fixture's teardown
    was_ee = global_version.is_ee_version()
    global_version.set_ee()
    # Drop memoized lookups so EE implementations are picked up
    fetch_versioned_implementation.cache_clear()
    yield
    if not was_ee:
        global_version.unset_ee()
    # Clear again so later non-EE tests don't see cached EE implementations
    fetch_versioned_implementation.cache_clear()
|
||||
@@ -45,7 +45,7 @@ def confluence_connector() -> ConfluenceConnector:
|
||||
def test_confluence_connector_permissions(
|
||||
mock_get_api_key: MagicMock, # noqa: ARG001
|
||||
confluence_connector: ConfluenceConnector,
|
||||
enable_ee: None, # noqa: ARG001
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
) -> None:
|
||||
# Get all doc IDs from the full connector
|
||||
all_full_doc_ids = set()
|
||||
@@ -93,7 +93,7 @@ def test_confluence_connector_permissions(
|
||||
def test_confluence_connector_restriction_handling(
|
||||
mock_get_api_key: MagicMock, # noqa: ARG001
|
||||
mock_db_provider_class: MagicMock,
|
||||
enable_ee: None, # noqa: ARG001
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
) -> None:
|
||||
# Test space key
|
||||
test_space_key = "DailyPermS"
|
||||
|
||||
@@ -4,6 +4,8 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
|
||||
@@ -12,3 +14,14 @@ def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
|
||||
return_value=None,
|
||||
) as mock:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
def set_ee_on() -> Generator[None, None, None]:
    """Need EE to be enabled for these tests to work since
    perm syncing is an EE-only feature."""
    global_version.set_ee()

    yield

    # NOTE(review): resets via the private attribute rather than an
    # unset_ee() call — confirm this is intentional
    global_version._is_ee = False
|
||||
|
||||
@@ -98,7 +98,7 @@ def _build_connector(
|
||||
|
||||
def test_gdrive_perm_sync_with_real_data(
|
||||
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
|
||||
enable_ee: None, # noqa: ARG001
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test gdrive_doc_sync and gdrive_group_sync with real data from the test drive.
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.slack.connector import SlackConnector
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from tests.daily.connectors.utils import load_all_from_connector
|
||||
|
||||
|
||||
@@ -17,7 +19,16 @@ PRIVATE_CHANNEL_USERS = [
|
||||
"test_user_2@onyx-test.com",
|
||||
]
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_ee_on() -> Generator[None, None, None]:
|
||||
"""Need EE to be enabled for these tests to work since
|
||||
perm syncing is a an EE-only feature."""
|
||||
global_version.set_ee()
|
||||
|
||||
yield
|
||||
|
||||
global_version._is_ee = False
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.teams.connector import TeamsConnector
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from tests.daily.connectors.teams.models import TeamsThread
|
||||
from tests.daily.connectors.utils import load_all_from_connector
|
||||
|
||||
@@ -166,9 +168,18 @@ def test_slim_docs_retrieval_from_teams_connector(
|
||||
_assert_is_valid_external_access(external_access=slim_doc.external_access)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=False)
|
||||
def set_ee_on() -> Generator[None, None, None]:
|
||||
"""Need EE to be enabled for perm sync tests to work since
|
||||
perm syncing is an EE-only feature."""
|
||||
global_version.set_ee()
|
||||
yield
|
||||
global_version._is_ee = False
|
||||
|
||||
|
||||
def test_load_from_checkpoint_with_perm_sync(
|
||||
teams_connector: TeamsConnector,
|
||||
enable_ee: None, # noqa: ARG001
|
||||
set_ee_on: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Test that load_from_checkpoint_with_perm_sync returns documents with external_access.
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -15,14 +14,13 @@ from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.utils import DocumentRow
|
||||
from onyx.db.utils import SortOrder
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
# In order to get these tests to run, use the credentials from Bitwarden.
|
||||
# Search up "ENV vars for local and Github tests", and find the Jira relevant key-value pairs.
|
||||
# Required env vars: JIRA_USER_EMAIL, JIRA_API_TOKEN
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
|
||||
class DocExternalAccessSet(BaseModel):
|
||||
"""A version of DocExternalAccess that uses sets for comparison."""
|
||||
@@ -54,6 +52,9 @@ def test_jira_doc_sync(
|
||||
This test uses the AS project which has applicationRole permission,
|
||||
meaning all documents should be marked as public.
|
||||
"""
|
||||
# NOTE: must set EE on or else the connector will skip the perm syncing
|
||||
global_version.set_ee()
|
||||
|
||||
try:
|
||||
# Use AS project specifically for this test
|
||||
connector_config = {
|
||||
@@ -149,6 +150,9 @@ def test_jira_doc_sync_with_specific_permissions(
|
||||
This test uses a project that has specific user permissions to verify
|
||||
that specific users are correctly extracted.
|
||||
"""
|
||||
# NOTE: must set EE on or else the connector will skip the perm syncing
|
||||
global_version.set_ee()
|
||||
|
||||
try:
|
||||
# Use SUP project which has specific user permissions
|
||||
connector_config = {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.external_permissions.jira.group_sync import jira_group_sync
|
||||
@@ -19,8 +18,6 @@ from tests.daily.connectors.confluence.models import ExternalUserGroupSet
|
||||
# Search up "ENV vars for local and Github tests", and find the Jira relevant key-value pairs.
|
||||
# Required env vars: JIRA_USER_EMAIL, JIRA_API_TOKEN
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
# Expected groups from the danswerai.atlassian.net Jira instance
|
||||
# Note: These groups are shared with Confluence since they're both Atlassian products
|
||||
# App accounts (bots, integrations) are filtered out
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
"""Test that Credential with nested JSON round-trips through SensitiveValue correctly.
|
||||
|
||||
Exercises the full encrypt → store → read → decrypt → SensitiveValue path
|
||||
with realistic nested OAuth credential data, and verifies SQLAlchemy dirty
|
||||
tracking works with nested dict comparison.
|
||||
|
||||
Requires a running Postgres instance.
|
||||
"""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import Credential
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
|
||||
# NOTE: this is not the real shape of a Drive credential,
|
||||
# but it is intended to test nested JSON credential handling
|
||||
|
||||
_NESTED_CRED_JSON = {
|
||||
"oauth_tokens": {
|
||||
"access_token": "ya29.abc123",
|
||||
"refresh_token": "1//xEg-def456",
|
||||
},
|
||||
"scopes": ["read", "write", "admin"],
|
||||
"client_config": {
|
||||
"client_id": "123.apps.googleusercontent.com",
|
||||
"client_secret": "GOCSPX-secret",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_nested_credential_json_round_trip(db_session: Session) -> None:
|
||||
"""Nested OAuth credential survives encrypt → store → read → decrypt."""
|
||||
credential = Credential(
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
credential_json=_NESTED_CRED_JSON,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
|
||||
# Immediate read (no DB round-trip) — tests the set event wrapping
|
||||
assert isinstance(credential.credential_json, SensitiveValue)
|
||||
assert credential.credential_json.get_value(apply_mask=False) == _NESTED_CRED_JSON
|
||||
|
||||
# DB round-trip — tests process_result_value
|
||||
db_session.expire(credential)
|
||||
reloaded = credential.credential_json
|
||||
assert isinstance(reloaded, SensitiveValue)
|
||||
assert reloaded.get_value(apply_mask=False) == _NESTED_CRED_JSON
|
||||
|
||||
db_session.rollback()
|
||||
|
||||
|
||||
def test_reassign_same_nested_json_not_dirty(db_session: Session) -> None:
|
||||
"""Re-assigning the same nested dict should not mark the session dirty."""
|
||||
credential = Credential(
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
credential_json=_NESTED_CRED_JSON,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
|
||||
# Clear dirty state from the insert
|
||||
db_session.expire(credential)
|
||||
_ = credential.credential_json # force reload
|
||||
|
||||
# Re-assign identical value
|
||||
credential.credential_json = _NESTED_CRED_JSON # type: ignore[assignment]
|
||||
assert not db_session.is_modified(credential)
|
||||
|
||||
db_session.rollback()
|
||||
|
||||
|
||||
def test_assign_different_nested_json_is_dirty(db_session: Session) -> None:
|
||||
"""Assigning a different nested dict should mark the session dirty."""
|
||||
credential = Credential(
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
credential_json=_NESTED_CRED_JSON,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
|
||||
db_session.expire(credential)
|
||||
_ = credential.credential_json # force reload
|
||||
|
||||
modified_cred = {**_NESTED_CRED_JSON, "scopes": ["read"]}
|
||||
credential.credential_json = modified_cred # type: ignore[assignment]
|
||||
assert db_session.is_modified(credential)
|
||||
|
||||
db_session.rollback()
|
||||
@@ -1,305 +0,0 @@
|
||||
"""Tests for rotate_encryption_key against real Postgres.
|
||||
|
||||
Uses real ORM models (Credential, InternetSearchProvider) and the actual
|
||||
Postgres database. Discovery is mocked in rotation tests to scope mutations
|
||||
to only the test rows — the real _discover_encrypted_columns walk is tested
|
||||
separately in TestDiscoverEncryptedColumns.
|
||||
|
||||
Requires a running Postgres instance. Run with::
|
||||
|
||||
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/db/test_rotate_encryption_key.py
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import LargeBinary
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.utils.encryption import _decrypt_bytes
|
||||
from ee.onyx.utils.encryption import _encrypt_string
|
||||
from ee.onyx.utils.encryption import _get_trimmed_key
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import EncryptedJson
|
||||
from onyx.db.models import EncryptedString
|
||||
from onyx.db.models import InternetSearchProvider
|
||||
from onyx.db.rotate_encryption_key import _discover_encrypted_columns
|
||||
from onyx.db.rotate_encryption_key import rotate_encryption_key
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
EE_MODULE = "ee.onyx.utils.encryption"
|
||||
ROTATE_MODULE = "onyx.db.rotate_encryption_key"
|
||||
|
||||
OLD_KEY = "o" * 16
|
||||
NEW_KEY = "n" * 16
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_ee() -> Generator[None, None, None]:
|
||||
prev = global_version._is_ee
|
||||
global_version.set_ee()
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
yield
|
||||
global_version._is_ee = prev
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_key_cache() -> None:
|
||||
_get_trimmed_key.cache_clear()
|
||||
|
||||
|
||||
def _raw_credential_bytes(db_session: Session, credential_id: int) -> bytes | None:
|
||||
"""Read raw bytes from credential_json, bypassing the TypeDecorator."""
|
||||
col = Credential.__table__.c.credential_json
|
||||
stmt = select(col.cast(LargeBinary)).where(
|
||||
Credential.__table__.c.id == credential_id
|
||||
)
|
||||
return db_session.execute(stmt).scalar()
|
||||
|
||||
|
||||
def _raw_isp_bytes(db_session: Session, isp_id: int) -> bytes | None:
|
||||
"""Read raw bytes from InternetSearchProvider.api_key."""
|
||||
col = InternetSearchProvider.__table__.c.api_key
|
||||
stmt = select(col.cast(LargeBinary)).where(
|
||||
InternetSearchProvider.__table__.c.id == isp_id
|
||||
)
|
||||
return db_session.execute(stmt).scalar()
|
||||
|
||||
|
||||
class TestDiscoverEncryptedColumns:
|
||||
"""Verify _discover_encrypted_columns finds real production models."""
|
||||
|
||||
def test_discovers_credential_json(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
found = {
|
||||
(model_cls.__tablename__, col_name, is_json) # type: ignore[attr-defined]
|
||||
for model_cls, col_name, _, is_json in results
|
||||
}
|
||||
assert ("credential", "credential_json", True) in found
|
||||
|
||||
def test_discovers_internet_search_provider_api_key(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
found = {
|
||||
(model_cls.__tablename__, col_name, is_json) # type: ignore[attr-defined]
|
||||
for model_cls, col_name, _, is_json in results
|
||||
}
|
||||
assert ("internet_search_provider", "api_key", False) in found
|
||||
|
||||
def test_all_encrypted_string_columns_are_not_json(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
for model_cls, col_name, _, is_json in results:
|
||||
col = getattr(model_cls, col_name).property.columns[0]
|
||||
if isinstance(col.type, EncryptedString):
|
||||
assert not is_json, (
|
||||
f"{model_cls.__tablename__}.{col_name} is EncryptedString " # type: ignore[attr-defined]
|
||||
f"but is_json={is_json}"
|
||||
)
|
||||
|
||||
def test_all_encrypted_json_columns_are_json(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
for model_cls, col_name, _, is_json in results:
|
||||
col = getattr(model_cls, col_name).property.columns[0]
|
||||
if isinstance(col.type, EncryptedJson):
|
||||
assert is_json, (
|
||||
f"{model_cls.__tablename__}.{col_name} is EncryptedJson " # type: ignore[attr-defined]
|
||||
f"but is_json={is_json}"
|
||||
)
|
||||
|
||||
|
||||
class TestRotateCredential:
|
||||
"""Test rotation against the real Credential table (EncryptedJson).
|
||||
|
||||
Discovery is scoped to only the Credential model to avoid mutating
|
||||
other tables in the test database.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _limit_discovery(self) -> Generator[None, None, None]:
|
||||
with patch(
|
||||
f"{ROTATE_MODULE}._discover_encrypted_columns",
|
||||
return_value=[(Credential, "credential_json", ["id"], True)],
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.fixture()
|
||||
def credential_id(
|
||||
self, db_session: Session, tenant_context: None # noqa: ARG002
|
||||
) -> Generator[int, None, None]:
|
||||
"""Insert a Credential row with raw encrypted bytes, clean up after."""
|
||||
config = {"api_key": "sk-test-1234", "endpoint": "https://example.com"}
|
||||
encrypted = _encrypt_string(json.dumps(config), key=OLD_KEY)
|
||||
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"INSERT INTO credential "
|
||||
"(source, credential_json, admin_public, curator_public) "
|
||||
"VALUES (:source, :cred_json, true, false) "
|
||||
"RETURNING id"
|
||||
),
|
||||
{"source": DocumentSource.INGESTION_API.value, "cred_json": encrypted},
|
||||
)
|
||||
cred_id = result.scalar_one()
|
||||
db_session.commit()
|
||||
|
||||
yield cred_id
|
||||
|
||||
db_session.execute(
|
||||
text("DELETE FROM credential WHERE id = :id"), {"id": cred_id}
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
def test_rotates_credential_json(
|
||||
self, db_session: Session, credential_id: int
|
||||
) -> None:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
|
||||
assert totals.get("credential.credential_json", 0) >= 1
|
||||
|
||||
raw = _raw_credential_bytes(db_session, credential_id)
|
||||
assert raw is not None
|
||||
decrypted = json.loads(_decrypt_bytes(raw, key=NEW_KEY))
|
||||
assert decrypted["api_key"] == "sk-test-1234"
|
||||
assert decrypted["endpoint"] == "https://example.com"
|
||||
|
||||
def test_skips_already_rotated(
|
||||
self, db_session: Session, credential_id: int
|
||||
) -> None:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
_ = rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
|
||||
raw = _raw_credential_bytes(db_session, credential_id)
|
||||
assert raw is not None
|
||||
decrypted = json.loads(_decrypt_bytes(raw, key=NEW_KEY))
|
||||
assert decrypted["api_key"] == "sk-test-1234"
|
||||
|
||||
def test_dry_run_does_not_modify(
|
||||
self, db_session: Session, credential_id: int
|
||||
) -> None:
|
||||
original = _raw_credential_bytes(db_session, credential_id)
|
||||
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=OLD_KEY, dry_run=True)
|
||||
|
||||
assert totals.get("credential.credential_json", 0) >= 1
|
||||
|
||||
raw_after = _raw_credential_bytes(db_session, credential_id)
|
||||
assert raw_after == original
|
||||
|
||||
|
||||
class TestRotateInternetSearchProvider:
|
||||
"""Test rotation against the real InternetSearchProvider table (EncryptedString).
|
||||
|
||||
Discovery is scoped to only the InternetSearchProvider model to avoid
|
||||
mutating other tables in the test database.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _limit_discovery(self) -> Generator[None, None, None]:
|
||||
with patch(
|
||||
f"{ROTATE_MODULE}._discover_encrypted_columns",
|
||||
return_value=[
|
||||
(InternetSearchProvider, "api_key", ["id"], False),
|
||||
],
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.fixture()
|
||||
def isp_id(
|
||||
self, db_session: Session, tenant_context: None # noqa: ARG002
|
||||
) -> Generator[int, None, None]:
|
||||
"""Insert an InternetSearchProvider row with raw encrypted bytes."""
|
||||
encrypted = _encrypt_string("sk-secret-api-key", key=OLD_KEY)
|
||||
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"INSERT INTO internet_search_provider "
|
||||
"(name, provider_type, api_key, is_active) "
|
||||
"VALUES (:name, :ptype, :api_key, false) "
|
||||
"RETURNING id"
|
||||
),
|
||||
{
|
||||
"name": f"test-rotation-{id(self)}",
|
||||
"ptype": "test",
|
||||
"api_key": encrypted,
|
||||
},
|
||||
)
|
||||
isp_id = result.scalar_one()
|
||||
db_session.commit()
|
||||
|
||||
yield isp_id
|
||||
|
||||
db_session.execute(
|
||||
text("DELETE FROM internet_search_provider WHERE id = :id"),
|
||||
{"id": isp_id},
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
def test_rotates_api_key(self, db_session: Session, isp_id: int) -> None:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
|
||||
assert totals.get("internet_search_provider.api_key", 0) >= 1
|
||||
|
||||
raw = _raw_isp_bytes(db_session, isp_id)
|
||||
assert raw is not None
|
||||
assert _decrypt_bytes(raw, key=NEW_KEY) == "sk-secret-api-key"
|
||||
|
||||
def test_rotates_from_unencrypted(
|
||||
self, db_session: Session, tenant_context: None # noqa: ARG002
|
||||
) -> None:
|
||||
"""Test rotating data that was stored without any encryption key."""
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"INSERT INTO internet_search_provider "
|
||||
"(name, provider_type, api_key, is_active) "
|
||||
"VALUES (:name, :ptype, :api_key, false) "
|
||||
"RETURNING id"
|
||||
),
|
||||
{
|
||||
"name": f"test-raw-{id(self)}",
|
||||
"ptype": "test",
|
||||
"api_key": b"raw-api-key",
|
||||
},
|
||||
)
|
||||
isp_id = result.scalar_one()
|
||||
db_session.commit()
|
||||
|
||||
try:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=None)
|
||||
|
||||
assert totals.get("internet_search_provider.api_key", 0) >= 1
|
||||
|
||||
raw = _raw_isp_bytes(db_session, isp_id)
|
||||
assert raw is not None
|
||||
assert _decrypt_bytes(raw, key=NEW_KEY) == "raw-api-key"
|
||||
finally:
|
||||
db_session.execute(
|
||||
text("DELETE FROM internet_search_provider WHERE id = :id"),
|
||||
{"id": isp_id},
|
||||
)
|
||||
db_session.commit()
|
||||
@@ -1,85 +0,0 @@
|
||||
"""Tests that SlackBot CRUD operations return properly typed SensitiveValue fields.
|
||||
|
||||
Regression test for the bug where insert_slack_bot/update_slack_bot returned
|
||||
objects with raw string tokens instead of SensitiveValue wrappers, causing
|
||||
'str object has no attribute get_value' errors in SlackBot.from_model().
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.slack_bot import insert_slack_bot
|
||||
from onyx.db.slack_bot import update_slack_bot
|
||||
from onyx.server.manage.models import SlackBot
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
|
||||
|
||||
def _unique(prefix: str) -> str:
|
||||
return f"{prefix}-{uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def test_insert_slack_bot_returns_sensitive_values(db_session: Session) -> None:
|
||||
bot_token = _unique("xoxb-insert")
|
||||
app_token = _unique("xapp-insert")
|
||||
user_token = _unique("xoxp-insert")
|
||||
|
||||
slack_bot = insert_slack_bot(
|
||||
db_session=db_session,
|
||||
name=_unique("test-bot-insert"),
|
||||
enabled=True,
|
||||
bot_token=bot_token,
|
||||
app_token=app_token,
|
||||
user_token=user_token,
|
||||
)
|
||||
|
||||
assert isinstance(slack_bot.bot_token, SensitiveValue)
|
||||
assert isinstance(slack_bot.app_token, SensitiveValue)
|
||||
assert isinstance(slack_bot.user_token, SensitiveValue)
|
||||
|
||||
assert slack_bot.bot_token.get_value(apply_mask=False) == bot_token
|
||||
assert slack_bot.app_token.get_value(apply_mask=False) == app_token
|
||||
assert slack_bot.user_token.get_value(apply_mask=False) == user_token
|
||||
|
||||
# Verify from_model works without error
|
||||
pydantic_bot = SlackBot.from_model(slack_bot)
|
||||
assert pydantic_bot.bot_token # masked, but not empty
|
||||
assert pydantic_bot.app_token
|
||||
|
||||
|
||||
def test_update_slack_bot_returns_sensitive_values(db_session: Session) -> None:
|
||||
slack_bot = insert_slack_bot(
|
||||
db_session=db_session,
|
||||
name=_unique("test-bot-update"),
|
||||
enabled=True,
|
||||
bot_token=_unique("xoxb-update"),
|
||||
app_token=_unique("xapp-update"),
|
||||
)
|
||||
|
||||
new_bot_token = _unique("xoxb-update-new")
|
||||
new_app_token = _unique("xapp-update-new")
|
||||
new_user_token = _unique("xoxp-update-new")
|
||||
|
||||
updated = update_slack_bot(
|
||||
db_session=db_session,
|
||||
slack_bot_id=slack_bot.id,
|
||||
name=_unique("test-bot-updated"),
|
||||
enabled=False,
|
||||
bot_token=new_bot_token,
|
||||
app_token=new_app_token,
|
||||
user_token=new_user_token,
|
||||
)
|
||||
|
||||
assert isinstance(updated.bot_token, SensitiveValue)
|
||||
assert isinstance(updated.app_token, SensitiveValue)
|
||||
assert isinstance(updated.user_token, SensitiveValue)
|
||||
|
||||
assert updated.bot_token.get_value(apply_mask=False) == new_bot_token
|
||||
assert updated.app_token.get_value(apply_mask=False) == new_app_token
|
||||
assert updated.user_token.get_value(apply_mask=False) == new_user_token
|
||||
|
||||
# Verify from_model works without error
|
||||
pydantic_bot = SlackBot.from_model(updated)
|
||||
assert pydantic_bot.bot_token
|
||||
assert pydantic_bot.app_token
|
||||
assert pydantic_bot.user_token is not None
|
||||
@@ -148,16 +148,8 @@ class TestOAuthConfigCRUD:
|
||||
)
|
||||
|
||||
# Secrets should be preserved
|
||||
assert updated_config.client_id is not None
|
||||
assert original_client_id is not None
|
||||
assert updated_config.client_id.get_value(
|
||||
apply_mask=False
|
||||
) == original_client_id.get_value(apply_mask=False)
|
||||
assert updated_config.client_secret is not None
|
||||
assert original_client_secret is not None
|
||||
assert updated_config.client_secret.get_value(
|
||||
apply_mask=False
|
||||
) == original_client_secret.get_value(apply_mask=False)
|
||||
assert updated_config.client_id == original_client_id
|
||||
assert updated_config.client_secret == original_client_secret
|
||||
# But name should be updated
|
||||
assert updated_config.name == new_name
|
||||
|
||||
@@ -181,14 +173,9 @@ class TestOAuthConfigCRUD:
|
||||
)
|
||||
|
||||
# client_id should be cleared (empty string)
|
||||
assert updated_config.client_id is not None
|
||||
assert updated_config.client_id.get_value(apply_mask=False) == ""
|
||||
assert updated_config.client_id == ""
|
||||
# client_secret should be preserved
|
||||
assert updated_config.client_secret is not None
|
||||
assert original_client_secret is not None
|
||||
assert updated_config.client_secret.get_value(
|
||||
apply_mask=False
|
||||
) == original_client_secret.get_value(apply_mask=False)
|
||||
assert updated_config.client_secret == original_client_secret
|
||||
|
||||
def test_update_oauth_config_clear_client_secret(self, db_session: Session) -> None:
|
||||
"""Test clearing client_secret while preserving client_id"""
|
||||
@@ -203,14 +190,9 @@ class TestOAuthConfigCRUD:
|
||||
)
|
||||
|
||||
# client_secret should be cleared (empty string)
|
||||
assert updated_config.client_secret is not None
|
||||
assert updated_config.client_secret.get_value(apply_mask=False) == ""
|
||||
assert updated_config.client_secret == ""
|
||||
# client_id should be preserved
|
||||
assert updated_config.client_id is not None
|
||||
assert original_client_id is not None
|
||||
assert updated_config.client_id.get_value(
|
||||
apply_mask=False
|
||||
) == original_client_id.get_value(apply_mask=False)
|
||||
assert updated_config.client_id == original_client_id
|
||||
|
||||
def test_update_oauth_config_clear_both_secrets(self, db_session: Session) -> None:
|
||||
"""Test clearing both client_id and client_secret"""
|
||||
@@ -225,10 +207,8 @@ class TestOAuthConfigCRUD:
|
||||
)
|
||||
|
||||
# Both should be cleared (empty strings)
|
||||
assert updated_config.client_id is not None
|
||||
assert updated_config.client_id.get_value(apply_mask=False) == ""
|
||||
assert updated_config.client_secret is not None
|
||||
assert updated_config.client_secret.get_value(apply_mask=False) == ""
|
||||
assert updated_config.client_id == ""
|
||||
assert updated_config.client_secret == ""
|
||||
|
||||
def test_update_oauth_config_authorization_url(self, db_session: Session) -> None:
|
||||
"""Test updating authorization_url"""
|
||||
@@ -295,8 +275,7 @@ class TestOAuthConfigCRUD:
|
||||
assert updated_config.token_url == new_token_url
|
||||
assert updated_config.scopes == new_scopes
|
||||
assert updated_config.additional_params == new_params
|
||||
assert updated_config.client_id is not None
|
||||
assert updated_config.client_id.get_value(apply_mask=False) == new_client_id
|
||||
assert updated_config.client_id == new_client_id
|
||||
|
||||
def test_delete_oauth_config(self, db_session: Session) -> None:
|
||||
"""Test deleting an OAuth configuration"""
|
||||
@@ -437,8 +416,7 @@ class TestOAuthUserTokenCRUD:
|
||||
assert user_token.id is not None
|
||||
assert user_token.oauth_config_id == oauth_config.id
|
||||
assert user_token.user_id == user.id
|
||||
assert user_token.token_data is not None
|
||||
assert user_token.token_data.get_value(apply_mask=False) == token_data
|
||||
assert user_token.token_data == token_data
|
||||
assert user_token.created_at is not None
|
||||
assert user_token.updated_at is not None
|
||||
|
||||
@@ -468,13 +446,8 @@ class TestOAuthUserTokenCRUD:
|
||||
|
||||
# Should be the same token record (updated, not inserted)
|
||||
assert updated_token.id == initial_token_id
|
||||
assert updated_token.token_data is not None
|
||||
assert (
|
||||
updated_token.token_data.get_value(apply_mask=False) == updated_token_data
|
||||
)
|
||||
assert (
|
||||
updated_token.token_data.get_value(apply_mask=False) != initial_token_data
|
||||
)
|
||||
assert updated_token.token_data == updated_token_data
|
||||
assert updated_token.token_data != initial_token_data
|
||||
|
||||
def test_get_user_oauth_token(self, db_session: Session) -> None:
|
||||
"""Test retrieving a user's OAuth token"""
|
||||
@@ -490,8 +463,7 @@ class TestOAuthUserTokenCRUD:
|
||||
|
||||
assert retrieved_token is not None
|
||||
assert retrieved_token.id == created_token.id
|
||||
assert retrieved_token.token_data is not None
|
||||
assert retrieved_token.token_data.get_value(apply_mask=False) == token_data
|
||||
assert retrieved_token.token_data == token_data
|
||||
|
||||
def test_get_user_oauth_token_not_found(self, db_session: Session) -> None:
|
||||
"""Test retrieving a non-existent user token returns None"""
|
||||
@@ -547,8 +519,7 @@ class TestOAuthUserTokenCRUD:
|
||||
retrieved_token = get_user_oauth_token(oauth_config.id, user.id, db_session)
|
||||
assert retrieved_token is not None
|
||||
assert retrieved_token.id == updated_token.id
|
||||
assert retrieved_token.token_data is not None
|
||||
assert retrieved_token.token_data.get_value(apply_mask=False) == token_data2
|
||||
assert retrieved_token.token_data == token_data2
|
||||
|
||||
def test_cascade_delete_user_tokens_on_config_deletion(
|
||||
self, db_session: Session
|
||||
|
||||
@@ -374,14 +374,8 @@ class TestOAuthTokenManagerCodeExchange:
|
||||
assert call_args[0][0] == oauth_config.token_url
|
||||
assert call_args[1]["data"]["grant_type"] == "authorization_code"
|
||||
assert call_args[1]["data"]["code"] == "auth_code_123"
|
||||
assert oauth_config.client_id is not None
|
||||
assert oauth_config.client_secret is not None
|
||||
assert call_args[1]["data"]["client_id"] == oauth_config.client_id.get_value(
|
||||
apply_mask=False
|
||||
)
|
||||
assert call_args[1]["data"][
|
||||
"client_secret"
|
||||
] == oauth_config.client_secret.get_value(apply_mask=False)
|
||||
assert call_args[1]["data"]["client_id"] == oauth_config.client_id
|
||||
assert call_args[1]["data"]["client_secret"] == oauth_config.client_secret
|
||||
assert call_args[1]["data"]["redirect_uri"] == "https://example.com/callback"
|
||||
|
||||
@patch("onyx.auth.oauth_token_manager.requests.post")
|
||||
|
||||
@@ -950,7 +950,6 @@ from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from tests.external_dependency_unit.answer.stream_test_builder import StreamTestBuilder
|
||||
from tests.external_dependency_unit.answer.stream_test_utils import create_chat_session
|
||||
@@ -1292,21 +1291,12 @@ def test_code_interpreter_replay_packets_include_code_and_output(
|
||||
tool_call_id="call_replay_test",
|
||||
tool_call_argument_tokens=[json.dumps({"code": code})],
|
||||
)
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
obj=ToolCallArgumentDelta(
|
||||
tool_type="python",
|
||||
argument_deltas={"code": code},
|
||||
),
|
||||
),
|
||||
forward=2,
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
obj=PythonToolStart(code=code),
|
||||
),
|
||||
forward=False,
|
||||
forward=2,
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
|
||||
@@ -64,8 +64,7 @@ class TestBotConfigAPI:
|
||||
db_session.commit()
|
||||
|
||||
assert config is not None
|
||||
assert config.bot_token is not None
|
||||
assert config.bot_token.get_value(apply_mask=False) == "test_token_123"
|
||||
assert config.bot_token == "test_token_123"
|
||||
|
||||
# Cleanup
|
||||
delete_discord_bot_config(db_session)
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
"""Auto-enable EE mode for all tests under tests/unit/ee/."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_ee_for_directory(enable_ee: None) -> None: # noqa: ARG001
|
||||
"""Wraps the shared enable_ee fixture with autouse for this directory."""
|
||||
@@ -1,165 +0,0 @@
|
||||
"""Tests for EE AES-CBC encryption/decryption with explicit key support.
|
||||
|
||||
With EE mode enabled (via conftest), fetch_versioned_implementation resolves
|
||||
to the EE implementations, so no patching of the MIT layer is needed.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ee.onyx.utils.encryption import _decrypt_bytes
|
||||
from ee.onyx.utils.encryption import _encrypt_string
|
||||
from ee.onyx.utils.encryption import _get_trimmed_key
|
||||
from ee.onyx.utils.encryption import decrypt_bytes_to_string
|
||||
from ee.onyx.utils.encryption import encrypt_string_to_bytes
|
||||
|
||||
EE_MODULE = "ee.onyx.utils.encryption"
|
||||
|
||||
# Keys must be exactly 16, 24, or 32 bytes for AES
|
||||
KEY_16 = "a" * 16
|
||||
KEY_16_ALT = "b" * 16
|
||||
KEY_24 = "d" * 24
|
||||
KEY_32 = "c" * 32
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_key_cache() -> None:
|
||||
_get_trimmed_key.cache_clear()
|
||||
|
||||
|
||||
class TestEncryptDecryptRoundTrip:
|
||||
def test_roundtrip_with_env_key(self) -> None:
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
|
||||
encrypted = _encrypt_string("hello world")
|
||||
assert encrypted != b"hello world"
|
||||
assert _decrypt_bytes(encrypted) == "hello world"
|
||||
|
||||
def test_roundtrip_with_explicit_key(self) -> None:
|
||||
encrypted = _encrypt_string("secret data", key=KEY_32)
|
||||
assert encrypted != b"secret data"
|
||||
assert _decrypt_bytes(encrypted, key=KEY_32) == "secret data"
|
||||
|
||||
def test_roundtrip_no_key(self) -> None:
|
||||
"""Without any key, data is raw-encoded (no encryption)."""
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", ""):
|
||||
encrypted = _encrypt_string("plain text")
|
||||
assert encrypted == b"plain text"
|
||||
assert _decrypt_bytes(encrypted) == "plain text"
|
||||
|
||||
def test_explicit_key_overrides_env(self) -> None:
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
|
||||
encrypted = _encrypt_string("data", key=KEY_16_ALT)
|
||||
with pytest.raises(ValueError):
|
||||
_decrypt_bytes(encrypted, key=KEY_16)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_16_ALT) == "data"
|
||||
|
||||
def test_different_encryptions_produce_different_bytes(self) -> None:
|
||||
"""Each encryption uses a random IV, so results differ."""
|
||||
a = _encrypt_string("same", key=KEY_16)
|
||||
b = _encrypt_string("same", key=KEY_16)
|
||||
assert a != b
|
||||
|
||||
def test_roundtrip_empty_string(self) -> None:
|
||||
encrypted = _encrypt_string("", key=KEY_16)
|
||||
assert encrypted != b""
|
||||
assert _decrypt_bytes(encrypted, key=KEY_16) == ""
|
||||
|
||||
def test_roundtrip_unicode(self) -> None:
|
||||
text = "日本語テスト 🔐 émojis"
|
||||
encrypted = _encrypt_string(text, key=KEY_16)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_16) == text
|
||||
|
||||
|
||||
class TestDecryptFallbackBehavior:
|
||||
def test_wrong_env_key_falls_back_to_raw_decode(self) -> None:
|
||||
"""Default key path: AES fails on non-AES data → fallback to raw decode."""
|
||||
raw = "readable text".encode()
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
|
||||
assert _decrypt_bytes(raw) == "readable text"
|
||||
|
||||
def test_explicit_wrong_key_raises(self) -> None:
|
||||
"""Explicit key path: AES fails → raises, no fallback."""
|
||||
encrypted = _encrypt_string("secret", key=KEY_16)
|
||||
with pytest.raises(ValueError):
|
||||
_decrypt_bytes(encrypted, key=KEY_16_ALT)
|
||||
|
||||
def test_explicit_none_key_with_no_env(self) -> None:
|
||||
"""key=None with empty env → raw decode."""
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", ""):
|
||||
assert _decrypt_bytes(b"hello", key=None) == "hello"
|
||||
|
||||
def test_explicit_empty_string_key(self) -> None:
|
||||
"""key='' means no encryption."""
|
||||
encrypted = _encrypt_string("test", key="")
|
||||
assert encrypted == b"test"
|
||||
assert _decrypt_bytes(encrypted, key="") == "test"
|
||||
|
||||
|
||||
class TestKeyValidation:
|
||||
def test_key_too_short_raises(self) -> None:
|
||||
with pytest.raises(RuntimeError, match="too short"):
|
||||
_encrypt_string("data", key="short")
|
||||
|
||||
def test_16_byte_key(self) -> None:
|
||||
encrypted = _encrypt_string("data", key=KEY_16)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_16) == "data"
|
||||
|
||||
def test_24_byte_key(self) -> None:
|
||||
encrypted = _encrypt_string("data", key=KEY_24)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_24) == "data"
|
||||
|
||||
def test_32_byte_key(self) -> None:
|
||||
encrypted = _encrypt_string("data", key=KEY_32)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_32) == "data"
|
||||
|
||||
def test_long_key_truncated_to_32(self) -> None:
|
||||
"""Keys longer than 32 bytes are truncated to 32."""
|
||||
long_key = "e" * 64
|
||||
encrypted = _encrypt_string("data", key=long_key)
|
||||
assert _decrypt_bytes(encrypted, key=long_key) == "data"
|
||||
|
||||
def test_20_byte_key_trimmed_to_16(self) -> None:
|
||||
"""A 20-byte key is trimmed to the largest valid AES size that fits (16)."""
|
||||
key_20 = "f" * 20
|
||||
encrypted = _encrypt_string("data", key=key_20)
|
||||
assert _decrypt_bytes(encrypted, key=key_20) == "data"
|
||||
|
||||
# Verify it was trimmed to 16 by checking that the first 16 bytes
|
||||
# of the key can also decrypt it
|
||||
key_16_same_prefix = "f" * 16
|
||||
assert _decrypt_bytes(encrypted, key=key_16_same_prefix) == "data"
|
||||
|
||||
def test_25_byte_key_trimmed_to_24(self) -> None:
|
||||
"""A 25-byte key is trimmed to the largest valid AES size that fits (24)."""
|
||||
key_25 = "g" * 25
|
||||
encrypted = _encrypt_string("data", key=key_25)
|
||||
assert _decrypt_bytes(encrypted, key=key_25) == "data"
|
||||
|
||||
key_24_same_prefix = "g" * 24
|
||||
assert _decrypt_bytes(encrypted, key=key_24_same_prefix) == "data"
|
||||
|
||||
def test_30_byte_key_trimmed_to_24(self) -> None:
|
||||
"""A 30-byte key is trimmed to the largest valid AES size that fits (24)."""
|
||||
key_30 = "h" * 30
|
||||
encrypted = _encrypt_string("data", key=key_30)
|
||||
assert _decrypt_bytes(encrypted, key=key_30) == "data"
|
||||
|
||||
key_24_same_prefix = "h" * 24
|
||||
assert _decrypt_bytes(encrypted, key=key_24_same_prefix) == "data"
|
||||
|
||||
|
||||
class TestWrapperFunctions:
|
||||
"""Test encrypt_string_to_bytes / decrypt_bytes_to_string pass key through.
|
||||
|
||||
With EE mode enabled, the wrappers resolve to EE implementations automatically.
|
||||
"""
|
||||
|
||||
def test_wrapper_passes_key(self) -> None:
|
||||
encrypted = encrypt_string_to_bytes("payload", key=KEY_16)
|
||||
assert decrypt_bytes_to_string(encrypted, key=KEY_16) == "payload"
|
||||
|
||||
def test_wrapper_no_key_uses_env(self) -> None:
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_32):
|
||||
encrypted = encrypt_string_to_bytes("payload")
|
||||
assert decrypt_bytes_to_string(encrypted) == "payload"
|
||||
@@ -1,630 +0,0 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.utils.jsonriver import Parser
|
||||
|
||||
|
||||
def _make_tool_call_delta(
|
||||
index: int = 0,
|
||||
name: str | None = None,
|
||||
arguments: str | None = None,
|
||||
function_is_none: bool = False,
|
||||
) -> MagicMock:
|
||||
"""Create a mock tool_call_delta matching the LiteLLM streaming shape."""
|
||||
delta = MagicMock()
|
||||
delta.index = index
|
||||
if function_is_none:
|
||||
delta.function = None
|
||||
else:
|
||||
delta.function = MagicMock()
|
||||
delta.function.name = name
|
||||
delta.function.arguments = arguments
|
||||
return delta
|
||||
|
||||
|
||||
def _make_placement() -> Placement:
|
||||
return Placement(turn_index=0, tab_index=0)
|
||||
|
||||
|
||||
def _mock_tool_class(emit: bool = True) -> MagicMock:
|
||||
cls = MagicMock()
|
||||
cls.should_emit_argument_deltas.return_value = emit
|
||||
return cls
|
||||
|
||||
|
||||
def _collect(
|
||||
tc_map: dict[int, dict[str, Any]],
|
||||
delta: MagicMock,
|
||||
placement: Placement | None = None,
|
||||
parsers: dict[int, Parser] | None = None,
|
||||
) -> list[Any]:
|
||||
"""Run maybe_emit_argument_delta and return the yielded packets."""
|
||||
return list(
|
||||
maybe_emit_argument_delta(
|
||||
tc_map,
|
||||
delta,
|
||||
placement or _make_placement(),
|
||||
parsers if parsers is not None else {},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _stream_fragments(
|
||||
fragments: list[str],
|
||||
tc_map: dict[int, dict[str, Any]],
|
||||
placement: Placement | None = None,
|
||||
) -> list[str]:
|
||||
"""Feed fragments into maybe_emit_argument_delta one by one, returning
|
||||
all emitted content values concatenated per-key as a flat list."""
|
||||
pl = placement or _make_placement()
|
||||
parsers: dict[int, Parser] = {}
|
||||
emitted: list[str] = []
|
||||
for frag in fragments:
|
||||
tc_map[0]["arguments"] += frag
|
||||
delta = _make_tool_call_delta(arguments=frag)
|
||||
for packet in maybe_emit_argument_delta(tc_map, delta, pl, parsers=parsers):
|
||||
obj = packet.obj
|
||||
assert isinstance(obj, ToolCallArgumentDelta)
|
||||
for value in obj.argument_deltas.values():
|
||||
emitted.append(value)
|
||||
return emitted
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaGuards:
|
||||
"""Tests for conditions that cause no packet to be emitted."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_tool_does_not_opt_in(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""Tools that return False from should_emit_argument_deltas emit nothing."""
|
||||
mock_get_tool.return_value = _mock_tool_class(emit=False)
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments="x")) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_tool_class_unknown(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
mock_get_tool.return_value = None
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "unknown", "arguments": '{"code": "x'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments="x")) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_no_argument_fragment(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments=None)) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_key_value_not_started(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""Key exists in JSON but its string value hasn't begun yet."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code":'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments=":")) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_before_any_key(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Only the opening brace has arrived — no key to stream yet."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": "{"}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments="{")) == []
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaBasic:
|
||||
"""Tests for correct packet content and incremental emission."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_emits_packet_with_correct_fields(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "', "print(1)", '"}']
|
||||
|
||||
pl = _make_placement()
|
||||
parsers: dict[int, Parser] = {}
|
||||
all_packets = []
|
||||
for frag in fragments:
|
||||
tc_map[0]["arguments"] += frag
|
||||
packets = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments=frag), pl, parsers
|
||||
)
|
||||
all_packets.extend(packets)
|
||||
|
||||
assert len(all_packets) >= 1
|
||||
# Verify packet structure
|
||||
obj = all_packets[0].obj
|
||||
assert isinstance(obj, ToolCallArgumentDelta)
|
||||
assert obj.tool_type == "python"
|
||||
# All emitted content should reconstruct the value
|
||||
full_code = ""
|
||||
for p in all_packets:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
if "code" in p.obj.argument_deltas:
|
||||
full_code += p.obj.argument_deltas["code"]
|
||||
assert full_code == "print(1)"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_emits_only_new_content_on_subsequent_call(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""After a first emission, subsequent calls emit only the diff."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
parsers: dict[int, Parser] = {}
|
||||
pl = _make_placement()
|
||||
|
||||
# First fragment opens the string
|
||||
tc_map[0]["arguments"] = '{"code": "abc'
|
||||
packets_1 = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments='{"code": "abc'), pl, parsers
|
||||
)
|
||||
code_1 = ""
|
||||
for p in packets_1:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
code_1 += p.obj.argument_deltas.get("code", "")
|
||||
assert code_1 == "abc"
|
||||
|
||||
# Second fragment appends more
|
||||
tc_map[0]["arguments"] = '{"code": "abcdef'
|
||||
packets_2 = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments="def"), pl, parsers
|
||||
)
|
||||
code_2 = ""
|
||||
for p in packets_2:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
code_2 += p.obj.argument_deltas.get("code", "")
|
||||
assert code_2 == "def"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_handles_multiple_keys_sequentially(self, mock_get_tool: MagicMock) -> None:
|
||||
"""When a second key starts, emissions switch to that key."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "x',
|
||||
'", "output": "hello',
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
full = "".join(emitted)
|
||||
assert "x" in full
|
||||
assert "hello" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_delta_spans_key_boundary(self, mock_get_tool: MagicMock) -> None:
|
||||
"""A single delta contains the end of one value and the start of the next key."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "x',
|
||||
'y", "lang": "py',
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
full = "".join(emitted)
|
||||
assert "xy" in full
|
||||
assert "py" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_empty_value_emits_nothing(self, mock_get_tool: MagicMock) -> None:
|
||||
"""An empty string value has nothing to emit."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
# Opening quote just arrived, value is empty
|
||||
tc_map[0]["arguments"] = '{"code": "'
|
||||
packets = _collect(tc_map, _make_tool_call_delta(arguments='{"code": "'))
|
||||
# No string content yet, so either no packet or empty deltas
|
||||
for p in packets:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
assert p.obj.argument_deltas.get("code", "") == ""
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaDecoding:
|
||||
"""Tests verifying that JSON escape sequences are properly decoded."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_newlines(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "line1\\nline2"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "line1\nline2"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_tabs(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "\\tindented"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "\tindented"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_escaped_quotes(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "say \\"hi\\""}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == 'say "hi"'
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_escaped_backslashes(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "path\\\\dir"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "path\\dir"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_unicode_escape(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "\\u0041"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "A"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_incomplete_escape_at_end_decoded_on_next_chunk(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""A trailing backslash (incomplete escape) is completed in the next chunk."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "hello\\', 'n"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "hello\n"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_incomplete_unicode_escape_completed_on_next_chunk(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""A partial \\uXX sequence is completed in the next chunk."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "hello\\u00', '41"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "helloA"
|
||||
|
||||
|
||||
class TestArgumentDeltaStreamingE2E:
|
||||
"""Simulates realistic sequences of LLM argument deltas to verify
|
||||
the full pipeline produces correct decoded output."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_realistic_python_code_streaming(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams: {"code": "print('hello')\\nprint('world')"}"""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"',
|
||||
"code",
|
||||
'": "',
|
||||
"print(",
|
||||
"'hello')",
|
||||
"\\n",
|
||||
"print(",
|
||||
"'world')",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "print('hello')\nprint('world')"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_streaming_with_tabs_and_newlines(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams code with tabs and newlines."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"if True:",
|
||||
"\\n",
|
||||
"\\t",
|
||||
"pass",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "if True:\n\tpass"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_split_escape_sequence(self, mock_get_tool: MagicMock) -> None:
|
||||
"""An escape sequence split across two fragments (backslash in one,
|
||||
'n' in the next) should still decode correctly."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "hello',
|
||||
"\\",
|
||||
"n",
|
||||
'world"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "hello\nworld"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_multiple_newlines_and_indentation(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams a multi-line function with multiple escape sequences."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"def foo():",
|
||||
"\\n",
|
||||
"\\t",
|
||||
"x = 1",
|
||||
"\\n",
|
||||
"\\t",
|
||||
"return x",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "def foo():\n\tx = 1\n\treturn x"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_two_keys_streamed_sequentially(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams code first, then a second key (language) — both decoded."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"x = 1",
|
||||
'", "language": "',
|
||||
"python",
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
# Should have emissions for both keys
|
||||
full = "".join(emitted)
|
||||
assert "x = 1" in full
|
||||
assert "python" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_code_containing_dict_literal(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Python code like `x = {"key": "val"}` contains JSON-like patterns.
|
||||
The escaped quotes inside the *outer* JSON value should prevent the
|
||||
inner `"key":` from being mistaken for a top-level JSON key."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
# The LLM sends: {"code": "x = {\"key\": \"val\"}"}
|
||||
# The inner quotes are escaped as \" in the JSON value.
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"x = {",
|
||||
'\\"key\\"',
|
||||
": ",
|
||||
'\\"val\\"',
|
||||
"}",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == 'x = {"key": "val"}'
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_code_with_colon_in_value(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Colons inside the string value should not confuse key detection."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"url = ",
|
||||
'\\"https://example.com\\"',
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == 'url = "https://example.com"'
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaEdgeCases:
|
||||
"""Edge cases not covered by the standard test classes."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_function_is_none(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Some delta chunks have function=None (e.g. role-only deltas)."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
delta = _make_tool_call_delta(arguments=None, function_is_none=True)
|
||||
assert _collect(tc_map, delta) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_multiple_concurrent_tool_calls(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Two tool calls streaming at different indices in parallel."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""},
|
||||
1: {"id": "tc_2", "name": "python", "arguments": ""},
|
||||
}
|
||||
|
||||
parsers: dict[int, Parser] = {}
|
||||
pl = _make_placement()
|
||||
|
||||
# Feed full JSON to index 0
|
||||
tc_map[0]["arguments"] = '{"code": "aaa"}'
|
||||
packets_0 = _collect(
|
||||
tc_map,
|
||||
_make_tool_call_delta(index=0, arguments='{"code": "aaa"}'),
|
||||
pl,
|
||||
parsers,
|
||||
)
|
||||
code_0 = ""
|
||||
for p in packets_0:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
code_0 += p.obj.argument_deltas.get("code", "")
|
||||
assert code_0 == "aaa"
|
||||
|
||||
# Feed full JSON to index 1
|
||||
tc_map[1]["arguments"] = '{"code": "bbb"}'
|
||||
packets_1 = _collect(
|
||||
tc_map,
|
||||
_make_tool_call_delta(index=1, arguments='{"code": "bbb"}'),
|
||||
pl,
|
||||
parsers,
|
||||
)
|
||||
code_1 = ""
|
||||
for p in packets_1:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
code_1 += p.obj.argument_deltas.get("code", "")
|
||||
assert code_1 == "bbb"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_delta_with_four_arguments(self, mock_get_tool: MagicMock) -> None:
|
||||
"""A single delta contains four complete key-value pairs."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
full = '{"a": "one", "b": "two", "c": "three", "d": "four"}'
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
tc_map[0]["arguments"] = full
|
||||
parsers: dict[int, Parser] = {}
|
||||
packets = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments=full), parsers=parsers
|
||||
)
|
||||
|
||||
# Collect all argument deltas across packets
|
||||
all_deltas: dict[str, str] = {}
|
||||
for p in packets:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
for k, v in p.obj.argument_deltas.items():
|
||||
all_deltas[k] = all_deltas.get(k, "") + v
|
||||
|
||||
assert all_deltas == {
|
||||
"a": "one",
|
||||
"b": "two",
|
||||
"c": "three",
|
||||
"d": "four",
|
||||
}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_delta_on_second_arg_after_first_complete(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""First argument is fully complete; delta only adds to the second."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
|
||||
fragments = [
|
||||
'{"code": "print(1)", "lang": "py',
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
full = "".join(emitted)
|
||||
assert "print(1)" in full
|
||||
assert "py" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_non_string_values_skipped(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Non-string values (numbers, booleans, null) are skipped — they are
|
||||
available in the final tool-call kickoff packet. String arguments
|
||||
following them are still emitted."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"timeout": 30, "code": "hello"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
full = "".join(emitted)
|
||||
assert full == "hello"
|
||||
@@ -9,8 +9,6 @@ from onyx.connectors.jira.utils import JIRA_SERVER_API_VERSION
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.sensitive import make_mock_sensitive_value
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("enable_ee")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jira_cc_pair(
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.server.manage.voice.api import _validate_voice_api_base
|
||||
|
||||
|
||||
def test_validate_voice_api_base_blocks_private_for_non_azure() -> None:
|
||||
with pytest.raises(HTTPException, match="Invalid target URI"):
|
||||
_validate_voice_api_base("openai", "http://127.0.0.1:11434")
|
||||
|
||||
|
||||
def test_validate_voice_api_base_allows_private_for_azure() -> None:
|
||||
validated = _validate_voice_api_base("azure", "http://127.0.0.1:5000")
|
||||
assert validated == "http://127.0.0.1:5000"
|
||||
|
||||
|
||||
def test_validate_voice_api_base_blocks_metadata_for_azure() -> None:
|
||||
with pytest.raises(HTTPException, match="Invalid target URI"):
|
||||
_validate_voice_api_base("azure", "http://metadata.google.internal/")
|
||||
@@ -1,188 +0,0 @@
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import UploadFile
|
||||
|
||||
from onyx.server.features.projects import projects_file_utils as utils
|
||||
|
||||
|
||||
class _Tokenizer:
|
||||
def encode(self, text: str) -> list[int]:
|
||||
return [1] * len(text)
|
||||
|
||||
|
||||
class _NonSeekableFile(BytesIO):
|
||||
def tell(self) -> int:
|
||||
raise OSError("tell not supported")
|
||||
|
||||
def seek(self, *_args: object, **_kwargs: object) -> int:
|
||||
raise OSError("seek not supported")
|
||||
|
||||
|
||||
def _make_upload(filename: str, size: int, content: bytes | None = None) -> UploadFile:
|
||||
payload = content if content is not None else (b"x" * size)
|
||||
return UploadFile(filename=filename, file=BytesIO(payload), size=size)
|
||||
|
||||
|
||||
def _make_upload_no_size(filename: str, content: bytes) -> UploadFile:
|
||||
return UploadFile(filename=filename, file=BytesIO(content), size=None)
|
||||
|
||||
|
||||
def _patch_common_dependencies(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(utils, "fetch_default_llm_model", lambda _db: None)
|
||||
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _Tokenizer())
|
||||
monkeypatch.setattr(utils, "is_file_password_protected", lambda **_kwargs: False)
|
||||
|
||||
|
||||
def test_get_upload_size_bytes_falls_back_to_stream_size() -> None:
|
||||
upload = UploadFile(filename="example.txt", file=BytesIO(b"abcdef"), size=None)
|
||||
upload.file.seek(2)
|
||||
|
||||
size = utils.get_upload_size_bytes(upload)
|
||||
|
||||
assert size == 6
|
||||
assert upload.file.tell() == 2
|
||||
|
||||
|
||||
def test_get_upload_size_bytes_logs_warning_when_stream_size_unavailable(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
upload = UploadFile(filename="non_seekable.txt", file=_NonSeekableFile(), size=None)
|
||||
|
||||
caplog.set_level("WARNING")
|
||||
size = utils.get_upload_size_bytes(upload)
|
||||
|
||||
assert size is None
|
||||
assert "Could not determine upload size via stream seek" in caplog.text
|
||||
assert "non_seekable.txt" in caplog.text
|
||||
|
||||
|
||||
def test_is_upload_too_large_logs_warning_when_size_unknown(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
upload = _make_upload("size_unknown.txt", size=1)
|
||||
monkeypatch.setattr(utils, "get_upload_size_bytes", lambda _upload: None)
|
||||
|
||||
caplog.set_level("WARNING")
|
||||
is_too_large = utils.is_upload_too_large(upload, max_bytes=100)
|
||||
|
||||
assert is_too_large is False
|
||||
assert "Could not determine upload size; skipping size-limit check" in caplog.text
|
||||
assert "size_unknown.txt" in caplog.text
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_accepts_size_under_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("small.png", size=99)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert len(result.rejected) == 0
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload_no_size("small.png", content=b"x" * 99)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert len(result.rejected) == 0
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_accepts_size_at_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("edge.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert len(result.rejected) == 0
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("large.png", size=101)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
small = _make_upload("small.png", size=50)
|
||||
large = _make_upload("large.png", size=101)
|
||||
|
||||
result = utils.categorize_uploaded_files([small, large], MagicMock())
|
||||
|
||||
assert [file.filename for file in result.acceptable] == ["small.png"]
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].filename == "large.png"
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_skipped(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "SKIP_USERFILE_THRESHOLD", True)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
|
||||
upload = _make_upload("oversized.pdf", size=101)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_checks_size_before_text_extraction(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
|
||||
extract_mock = MagicMock(return_value="this should not run")
|
||||
monkeypatch.setattr(utils, "extract_file_text", extract_mock)
|
||||
|
||||
oversized_doc = _make_upload("oversized.pdf", size=101)
|
||||
result = utils.categorize_uploaded_files([oversized_doc], MagicMock())
|
||||
|
||||
extract_mock.assert_not_called()
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
@@ -1,32 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.settings import store as settings_store
|
||||
|
||||
|
||||
class _FakeKvStore:
|
||||
def load(self, _key: str) -> dict:
|
||||
raise KvKeyNotFoundError()
|
||||
|
||||
|
||||
class _FakeCache:
|
||||
def __init__(self) -> None:
|
||||
self._vals: dict[str, bytes] = {}
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
return self._vals.get(key)
|
||||
|
||||
def set(self, key: str, value: str, ex: int | None = None) -> None: # noqa: ARG002
|
||||
self._vals[key] = value.encode("utf-8")
|
||||
|
||||
|
||||
def test_load_settings_includes_user_file_max_upload_size_mb(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "USER_FILE_MAX_UPLOAD_SIZE_MB", 77)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 77
|
||||
@@ -1,394 +0,0 @@
|
||||
"""Tests for the jsonriver incremental JSON parser."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.jsonriver import JsonValue
|
||||
from onyx.utils.jsonriver import Parser
|
||||
|
||||
|
||||
def _all_deltas(chunks: list[str]) -> list[JsonValue]:
|
||||
"""Feed chunks one at a time and collect all emitted deltas."""
|
||||
parser = Parser()
|
||||
deltas: list[JsonValue] = []
|
||||
for chunk in chunks:
|
||||
deltas.extend(parser.feed(chunk))
|
||||
deltas.extend(parser.finish())
|
||||
return deltas
|
||||
|
||||
|
||||
class TestParseComplete:
|
||||
"""Parsing complete JSON in a single chunk."""
|
||||
|
||||
def test_simple_object(self) -> None:
|
||||
deltas = _all_deltas(['{"a": 1}'])
|
||||
assert any(r == {"a": 1.0} or r == {"a": 1} for r in deltas)
|
||||
|
||||
def test_simple_array(self) -> None:
|
||||
deltas = _all_deltas(["[1, 2, 3]"])
|
||||
assert any(isinstance(r, list) for r in deltas)
|
||||
|
||||
def test_simple_string(self) -> None:
|
||||
deltas = _all_deltas(['"hello"'])
|
||||
assert "hello" in deltas or any("hello" in str(r) for r in deltas)
|
||||
|
||||
def test_null(self) -> None:
|
||||
deltas = _all_deltas(["null"])
|
||||
assert None in deltas
|
||||
|
||||
def test_boolean_true(self) -> None:
|
||||
deltas = _all_deltas(["true"])
|
||||
assert True in deltas
|
||||
|
||||
def test_boolean_false(self) -> None:
|
||||
deltas = _all_deltas(["false"])
|
||||
assert any(r is False for r in deltas)
|
||||
|
||||
def test_number(self) -> None:
|
||||
deltas = _all_deltas(["42"])
|
||||
assert 42.0 in deltas
|
||||
|
||||
def test_negative_number(self) -> None:
|
||||
deltas = _all_deltas(["-3.14"])
|
||||
assert any(abs(r - (-3.14)) < 1e-10 for r in deltas if isinstance(r, float))
|
||||
|
||||
def test_empty_object(self) -> None:
|
||||
deltas = _all_deltas(["{}"])
|
||||
assert {} in deltas
|
||||
|
||||
def test_empty_array(self) -> None:
|
||||
deltas = _all_deltas(["[]"])
|
||||
assert [] in deltas
|
||||
|
||||
|
||||
class TestStreamingDeltas:
|
||||
"""Incremental feeding produces correct deltas."""
|
||||
|
||||
def test_object_string_value_streamed_char_by_char(self) -> None:
|
||||
chunks = list('{"code": "abc"}')
|
||||
deltas = _all_deltas(chunks)
|
||||
str_parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "code" in d:
|
||||
val = d["code"]
|
||||
if isinstance(val, str):
|
||||
str_parts.append(val)
|
||||
assert "".join(str_parts) == "abc"
|
||||
|
||||
def test_object_streamed_in_two_halves(self) -> None:
|
||||
deltas = _all_deltas(['{"name": "Al', 'ice"}'])
|
||||
str_parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "name" in d:
|
||||
val = d["name"]
|
||||
if isinstance(val, str):
|
||||
str_parts.append(val)
|
||||
assert "".join(str_parts) == "Alice"
|
||||
|
||||
def test_multiple_keys_streamed(self) -> None:
|
||||
deltas = _all_deltas(['{"a": "x', '", "b": "y"}'])
|
||||
a_parts: list[str] = []
|
||||
b_parts: list[str] = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict):
|
||||
if "a" in d and isinstance(d["a"], str):
|
||||
a_parts.append(d["a"])
|
||||
if "b" in d and isinstance(d["b"], str):
|
||||
b_parts.append(d["b"])
|
||||
assert "".join(a_parts) == "x"
|
||||
assert "".join(b_parts) == "y"
|
||||
|
||||
def test_deltas_only_contain_new_string_content(self) -> None:
|
||||
parser = Parser()
|
||||
d1 = parser.feed('{"msg": "hel')
|
||||
d2 = parser.feed('lo"}')
|
||||
parser.finish()
|
||||
|
||||
msg_parts = []
|
||||
for d in d1 + d2:
|
||||
if isinstance(d, dict) and "msg" in d:
|
||||
val = d["msg"]
|
||||
if isinstance(val, str):
|
||||
msg_parts.append(val)
|
||||
assert "".join(msg_parts) == "hello"
|
||||
|
||||
# Each delta should only contain new chars, not repeat previous ones
|
||||
if len(msg_parts) == 2:
|
||||
assert msg_parts[0] == "hel"
|
||||
assert msg_parts[1] == "lo"
|
||||
|
||||
|
||||
class TestEscapeSequences:
|
||||
"""JSON escape sequences are decoded correctly, even across chunk boundaries."""
|
||||
|
||||
def test_newline_escape(self) -> None:
|
||||
deltas = _all_deltas(['{"text": "line1\\nline2"}'])
|
||||
text_parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "text" in d and isinstance(d["text"], str):
|
||||
text_parts.append(d["text"])
|
||||
assert "".join(text_parts) == "line1\nline2"
|
||||
|
||||
def test_tab_escape(self) -> None:
|
||||
deltas = _all_deltas(['{"t": "a\\tb"}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "t" in d and isinstance(d["t"], str):
|
||||
parts.append(d["t"])
|
||||
assert "".join(parts) == "a\tb"
|
||||
|
||||
def test_escaped_quote(self) -> None:
|
||||
deltas = _all_deltas(['{"q": "say \\"hi\\""}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "q" in d and isinstance(d["q"], str):
|
||||
parts.append(d["q"])
|
||||
assert "".join(parts) == 'say "hi"'
|
||||
|
||||
def test_unicode_escape(self) -> None:
|
||||
deltas = _all_deltas(['{"u": "\\u0041\\u0042"}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "u" in d and isinstance(d["u"], str):
|
||||
parts.append(d["u"])
|
||||
assert "".join(parts) == "AB"
|
||||
|
||||
def test_escape_split_across_chunks(self) -> None:
|
||||
deltas = _all_deltas(['{"x": "a\\', 'nb"}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "x" in d and isinstance(d["x"], str):
|
||||
parts.append(d["x"])
|
||||
assert "".join(parts) == "a\nb"
|
||||
|
||||
def test_unicode_escape_split_across_chunks(self) -> None:
|
||||
deltas = _all_deltas(['{"u": "\\u00', '41"}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "u" in d and isinstance(d["u"], str):
|
||||
parts.append(d["u"])
|
||||
assert "".join(parts) == "A"
|
||||
|
||||
def test_backslash_escape(self) -> None:
|
||||
deltas = _all_deltas(['{"p": "c:\\\\dir"}'])
|
||||
parts = []
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "p" in d and isinstance(d["p"], str):
|
||||
parts.append(d["p"])
|
||||
assert "".join(parts) == "c:\\dir"
|
||||
|
||||
|
||||
class TestNestedStructures:
|
||||
"""Nested objects and arrays produce correct deltas."""
|
||||
|
||||
def test_nested_object(self) -> None:
|
||||
deltas = _all_deltas(['{"outer": {"inner": "val"}}'])
|
||||
found = False
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "outer" in d:
|
||||
outer = d["outer"]
|
||||
if isinstance(outer, dict) and "inner" in outer:
|
||||
found = True
|
||||
assert found
|
||||
|
||||
def test_array_of_strings(self) -> None:
|
||||
deltas = _all_deltas(['["a', '", "b"]'])
|
||||
all_items: list[str] = []
|
||||
for d in deltas:
|
||||
if isinstance(d, list):
|
||||
for item in d:
|
||||
if isinstance(item, str):
|
||||
all_items.append(item)
|
||||
elif isinstance(d, str):
|
||||
all_items.append(d)
|
||||
joined = "".join(all_items)
|
||||
assert "a" in joined
|
||||
assert "b" in joined
|
||||
|
||||
def test_object_with_number_and_bool(self) -> None:
|
||||
deltas = _all_deltas(['{"count": 42, "active": true}'])
|
||||
has_count = False
|
||||
has_active = False
|
||||
for d in deltas:
|
||||
if isinstance(d, dict):
|
||||
if "count" in d and d["count"] == 42.0:
|
||||
has_count = True
|
||||
if "active" in d and d["active"] is True:
|
||||
has_active = True
|
||||
assert has_count
|
||||
assert has_active
|
||||
|
||||
def test_object_with_null_value(self) -> None:
|
||||
deltas = _all_deltas(['{"key": null}'])
|
||||
found = False
|
||||
for d in deltas:
|
||||
if isinstance(d, dict) and "key" in d and d["key"] is None:
|
||||
found = True
|
||||
assert found
|
||||
|
||||
|
||||
class TestComputeDelta:
|
||||
"""Direct tests for the _compute_delta static method."""
|
||||
|
||||
def test_none_prev_returns_current(self) -> None:
|
||||
assert Parser._compute_delta(None, {"a": "b"}) == {"a": "b"}
|
||||
|
||||
def test_string_delta(self) -> None:
|
||||
assert Parser._compute_delta("hel", "hello") == "lo"
|
||||
|
||||
def test_string_no_change(self) -> None:
|
||||
assert Parser._compute_delta("same", "same") is None
|
||||
|
||||
def test_dict_new_key(self) -> None:
|
||||
assert Parser._compute_delta({"a": "x"}, {"a": "x", "b": "y"}) == {"b": "y"}
|
||||
|
||||
def test_dict_string_append(self) -> None:
|
||||
assert Parser._compute_delta({"code": "def"}, {"code": "def hello()"}) == {
|
||||
"code": " hello()"
|
||||
}
|
||||
|
||||
def test_dict_no_change(self) -> None:
|
||||
assert Parser._compute_delta({"a": 1}, {"a": 1}) is None
|
||||
|
||||
def test_list_new_items(self) -> None:
|
||||
assert Parser._compute_delta([1, 2], [1, 2, 3]) == [3]
|
||||
|
||||
def test_list_last_item_updated(self) -> None:
|
||||
assert Parser._compute_delta(["a"], ["ab"]) == ["ab"]
|
||||
|
||||
def test_list_no_change(self) -> None:
|
||||
assert Parser._compute_delta([1, 2], [1, 2]) is None
|
||||
|
||||
def test_primitive_change(self) -> None:
|
||||
assert Parser._compute_delta(1, 2) == 2
|
||||
|
||||
def test_primitive_no_change(self) -> None:
|
||||
assert Parser._compute_delta(42, 42) is None
|
||||
|
||||
|
||||
class TestParserLifecycle:
|
||||
"""Edge cases around parser state and lifecycle."""
|
||||
|
||||
def test_feed_after_finish_returns_empty(self) -> None:
|
||||
parser = Parser()
|
||||
parser.feed('{"a": 1}')
|
||||
parser.finish()
|
||||
assert parser.feed("more") == []
|
||||
|
||||
def test_empty_feed_returns_empty(self) -> None:
|
||||
parser = Parser()
|
||||
assert parser.feed("") == []
|
||||
|
||||
def test_whitespace_only_returns_empty(self) -> None:
|
||||
parser = Parser()
|
||||
assert parser.feed(" ") == []
|
||||
|
||||
def test_finish_with_trailing_whitespace(self) -> None:
|
||||
parser = Parser()
|
||||
# Trailing whitespace terminates the number, so feed() emits it
|
||||
deltas = parser.feed("42 ")
|
||||
assert 42.0 in deltas
|
||||
parser.finish() # Should not raise
|
||||
|
||||
def test_finish_with_trailing_content_raises(self) -> None:
|
||||
parser = Parser()
|
||||
# Feed a complete JSON value followed by non-whitespace in one chunk
|
||||
parser.feed('{"a": 1} extra')
|
||||
with pytest.raises(ValueError, match="Unexpected trailing"):
|
||||
parser.finish()
|
||||
|
||||
def test_finish_flushes_pending_number(self) -> None:
|
||||
parser = Parser()
|
||||
deltas = parser.feed("42")
|
||||
# Number has no terminator, so feed() can't emit it yet
|
||||
assert deltas == []
|
||||
final = parser.finish()
|
||||
assert 42.0 in final
|
||||
|
||||
|
||||
class TestToolCallSimulation:
|
||||
"""Simulate the LLM tool-call streaming use case."""
|
||||
|
||||
def test_python_tool_call_streaming(self) -> None:
|
||||
full_json = json.dumps({"code": "print('hello world')"})
|
||||
chunk_size = 5
|
||||
chunks = [
|
||||
full_json[i : i + chunk_size] for i in range(0, len(full_json), chunk_size)
|
||||
]
|
||||
|
||||
parser = Parser()
|
||||
code_parts: list[str] = []
|
||||
for chunk in chunks:
|
||||
for delta in parser.feed(chunk):
|
||||
if isinstance(delta, dict) and "code" in delta:
|
||||
val = delta["code"]
|
||||
if isinstance(val, str):
|
||||
code_parts.append(val)
|
||||
for delta in parser.finish():
|
||||
if isinstance(delta, dict) and "code" in delta:
|
||||
val = delta["code"]
|
||||
if isinstance(val, str):
|
||||
code_parts.append(val)
|
||||
assert "".join(code_parts) == "print('hello world')"
|
||||
|
||||
def test_multi_arg_tool_call(self) -> None:
|
||||
full = '{"query": "search term", "num_results": 5}'
|
||||
chunks = [full[:15], full[15:30], full[30:]]
|
||||
|
||||
parser = Parser()
|
||||
query_parts: list[str] = []
|
||||
has_num_results = False
|
||||
for chunk in chunks:
|
||||
for delta in parser.feed(chunk):
|
||||
if isinstance(delta, dict):
|
||||
if "query" in delta and isinstance(delta["query"], str):
|
||||
query_parts.append(delta["query"])
|
||||
if "num_results" in delta:
|
||||
has_num_results = True
|
||||
for delta in parser.finish():
|
||||
if isinstance(delta, dict):
|
||||
if "query" in delta and isinstance(delta["query"], str):
|
||||
query_parts.append(delta["query"])
|
||||
if "num_results" in delta:
|
||||
has_num_results = True
|
||||
assert "".join(query_parts) == "search term"
|
||||
assert has_num_results
|
||||
|
||||
def test_code_with_newlines_and_escapes(self) -> None:
|
||||
code = 'def greet(name):\n print(f"Hello, {name}!")\n return True'
|
||||
full = json.dumps({"code": code})
|
||||
chunk_size = 8
|
||||
chunks = [full[i : i + chunk_size] for i in range(0, len(full), chunk_size)]
|
||||
|
||||
parser = Parser()
|
||||
code_parts: list[str] = []
|
||||
for chunk in chunks:
|
||||
for delta in parser.feed(chunk):
|
||||
if isinstance(delta, dict) and "code" in delta:
|
||||
val = delta["code"]
|
||||
if isinstance(val, str):
|
||||
code_parts.append(val)
|
||||
for delta in parser.finish():
|
||||
if isinstance(delta, dict) and "code" in delta:
|
||||
val = delta["code"]
|
||||
if isinstance(val, str):
|
||||
code_parts.append(val)
|
||||
assert "".join(code_parts) == code
|
||||
|
||||
def test_single_char_streaming(self) -> None:
|
||||
full = '{"key": "value"}'
|
||||
parser = Parser()
|
||||
key_parts: list[str] = []
|
||||
for ch in full:
|
||||
for delta in parser.feed(ch):
|
||||
if isinstance(delta, dict) and "key" in delta:
|
||||
val = delta["key"]
|
||||
if isinstance(val, str):
|
||||
key_parts.append(val)
|
||||
for delta in parser.finish():
|
||||
if isinstance(delta, dict) and "key" in delta:
|
||||
val = delta["key"]
|
||||
if isinstance(val, str):
|
||||
key_parts.append(val)
|
||||
assert "".join(key_parts) == "value"
|
||||
@@ -147,18 +147,15 @@ class TestSensitiveValueString:
|
||||
)
|
||||
assert sensitive1 != sensitive2
|
||||
|
||||
def test_equality_with_non_sensitive_returns_not_equal(self) -> None:
|
||||
"""Test that comparing with non-SensitiveValue is always not-equal.
|
||||
|
||||
Returns NotImplemented so Python falls back to identity comparison.
|
||||
This is required for compatibility with SQLAlchemy's attribute tracking.
|
||||
"""
|
||||
def test_equality_with_non_sensitive_raises(self) -> None:
|
||||
"""Test that comparing with non-SensitiveValue raises error."""
|
||||
sensitive = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string("secret"),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=False,
|
||||
)
|
||||
assert not (sensitive == "secret")
|
||||
with pytest.raises(SensitiveAccessError):
|
||||
_ = sensitive == "secret"
|
||||
|
||||
|
||||
class TestSensitiveValueJson:
|
||||
|
||||
@@ -14,7 +14,6 @@ from onyx.utils.url import _is_ip_private_or_reserved
|
||||
from onyx.utils.url import _validate_and_resolve_url
|
||||
from onyx.utils.url import ssrf_safe_get
|
||||
from onyx.utils.url import SSRFException
|
||||
from onyx.utils.url import validate_outbound_http_url
|
||||
|
||||
|
||||
class TestIsIpPrivateOrReserved:
|
||||
@@ -306,22 +305,3 @@ class TestSsrfSafeGet:
|
||||
|
||||
call_args = mock_get.call_args
|
||||
assert call_args[1]["timeout"] == (5, 15)
|
||||
|
||||
|
||||
class TestValidateOutboundHttpUrl:
|
||||
def test_rejects_private_ip_by_default(self) -> None:
|
||||
with pytest.raises(SSRFException, match="internal/private IP"):
|
||||
validate_outbound_http_url("http://127.0.0.1:8000")
|
||||
|
||||
def test_allows_private_ip_when_explicitly_enabled(self) -> None:
|
||||
validated_url = validate_outbound_http_url(
|
||||
"http://127.0.0.1:8000", allow_private_network=True
|
||||
)
|
||||
assert validated_url == "http://127.0.0.1:8000"
|
||||
|
||||
def test_blocks_metadata_hostname_when_private_is_enabled(self) -> None:
|
||||
with pytest.raises(SSRFException, match="not allowed"):
|
||||
validate_outbound_http_url(
|
||||
"http://metadata.google.internal/latest",
|
||||
allow_private_network=True,
|
||||
)
|
||||
|
||||
@@ -20,6 +20,8 @@ from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.document_index.vespa_constants import USER_PROJECT
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# Import the function under test
|
||||
|
||||
|
||||
class TestBuildVespaFilters:
|
||||
def test_empty_filters(self) -> None:
|
||||
@@ -177,27 +179,11 @@ class TestBuildVespaFilters:
|
||||
assert f"!({HIDDEN}=true) and " == result
|
||||
|
||||
def test_user_project_filter(self) -> None:
|
||||
"""Test user project filtering.
|
||||
|
||||
project_id alone does NOT trigger a knowledge scope restriction
|
||||
(an agent with no explicit knowledge should search everything).
|
||||
It only participates when explicit knowledge filters are present.
|
||||
"""
|
||||
# project_id alone → no restriction
|
||||
"""Test user project filtering (replacement for user folder IDs)."""
|
||||
# Single project id
|
||||
filters = IndexFilters(access_control_list=[], project_id=789)
|
||||
result = build_vespa_filters(filters)
|
||||
assert f"!({HIDDEN}=true) and " == result
|
||||
|
||||
# project_id with user_file_ids → both OR'd
|
||||
id1 = UUID("00000000-0000-0000-0000-000000000123")
|
||||
filters = IndexFilters(
|
||||
access_control_list=[], project_id=789, user_file_ids=[id1]
|
||||
)
|
||||
result = build_vespa_filters(filters)
|
||||
assert (
|
||||
f'!({HIDDEN}=true) and (({DOCUMENT_ID} contains "{str(id1)}") or ({USER_PROJECT} contains "789")) and '
|
||||
== result
|
||||
)
|
||||
assert f'!({HIDDEN}=true) and ({USER_PROJECT} contains "789") and ' == result
|
||||
|
||||
# No project id
|
||||
filters = IndexFilters(access_control_list=[], project_id=None)
|
||||
@@ -231,11 +217,7 @@ class TestBuildVespaFilters:
|
||||
)
|
||||
|
||||
def test_combined_filters(self) -> None:
|
||||
"""Test combining multiple filter types.
|
||||
|
||||
Knowledge-scope filters (document_set, user_file_ids, project_id,
|
||||
persona_id) are OR'd together, while all other filters are AND'd.
|
||||
"""
|
||||
"""Test combining multiple filter types."""
|
||||
id1 = UUID("00000000-0000-0000-0000-000000000123")
|
||||
filters = IndexFilters(
|
||||
access_control_list=["user1", "group1"],
|
||||
@@ -249,6 +231,7 @@ class TestBuildVespaFilters:
|
||||
|
||||
result = build_vespa_filters(filters)
|
||||
|
||||
# Build expected result piece by piece for readability
|
||||
expected = f"!({HIDDEN}=true) and "
|
||||
expected += (
|
||||
'(access_control_list contains "user1" or '
|
||||
@@ -256,13 +239,9 @@ class TestBuildVespaFilters:
|
||||
)
|
||||
expected += f'({SOURCE_TYPE} contains "web") and '
|
||||
expected += f'({METADATA_LIST} contains "color{INDEX_SEPARATOR}red") and '
|
||||
# Knowledge scope filters are OR'd together
|
||||
expected += (
|
||||
f'(({DOCUMENT_SETS} contains "set1")'
|
||||
f' or ({DOCUMENT_ID} contains "{str(id1)}")'
|
||||
f' or ({USER_PROJECT} contains "789")'
|
||||
f") and "
|
||||
)
|
||||
expected += f'({DOCUMENT_SETS} contains "set1") and '
|
||||
expected += f'({DOCUMENT_ID} contains "{str(id1)}") and '
|
||||
expected += f'({USER_PROJECT} contains "789") and '
|
||||
cutoff_secs = int(datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp())
|
||||
expected += f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
|
||||
|
||||
@@ -272,32 +251,6 @@ class TestBuildVespaFilters:
|
||||
result_no_trailing = build_vespa_filters(filters, remove_trailing_and=True)
|
||||
assert expected[:-5] == result_no_trailing # Remove trailing " and "
|
||||
|
||||
def test_knowledge_scope_single_filter_not_wrapped(self) -> None:
|
||||
"""When only one knowledge-scope filter is present it should not
|
||||
be wrapped in an extra OR group."""
|
||||
filters = IndexFilters(access_control_list=[], document_set=["set1"])
|
||||
result = build_vespa_filters(filters)
|
||||
assert f'!({HIDDEN}=true) and ({DOCUMENT_SETS} contains "set1") and ' == result
|
||||
|
||||
def test_knowledge_scope_document_set_and_user_files_ored(self) -> None:
|
||||
"""Document set filter and user file IDs must be OR'd so that
|
||||
connector documents (in the set) and user files (with specific
|
||||
IDs) can both be found."""
|
||||
id1 = UUID("00000000-0000-0000-0000-000000000123")
|
||||
filters = IndexFilters(
|
||||
access_control_list=[],
|
||||
document_set=["engineering"],
|
||||
user_file_ids=[id1],
|
||||
)
|
||||
result = build_vespa_filters(filters)
|
||||
expected = (
|
||||
f"!({HIDDEN}=true) and "
|
||||
f'(({DOCUMENT_SETS} contains "engineering")'
|
||||
f' or ({DOCUMENT_ID} contains "{str(id1)}")'
|
||||
f") and "
|
||||
)
|
||||
assert expected == result
|
||||
|
||||
def test_empty_or_none_values(self) -> None:
|
||||
"""Test with empty or None values in filter lists."""
|
||||
# Empty strings in document set
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from onyx.voice.providers.azure import AzureVoiceProvider
|
||||
|
||||
|
||||
def test_azure_provider_extracts_region_from_target_uri() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key",
|
||||
api_base="https://westus.api.cognitive.microsoft.com/",
|
||||
custom_config={},
|
||||
)
|
||||
assert provider.speech_region == "westus"
|
||||
|
||||
|
||||
def test_azure_provider_normalizes_uppercase_region() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key",
|
||||
api_base=None,
|
||||
custom_config={"speech_region": "WestUS2"},
|
||||
)
|
||||
assert provider.speech_region == "westus2"
|
||||
|
||||
|
||||
def test_azure_provider_rejects_invalid_speech_region() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid Azure speech_region"):
|
||||
AzureVoiceProvider(
|
||||
api_key="key",
|
||||
api_base=None,
|
||||
custom_config={"speech_region": "westus/../../etc"},
|
||||
)
|
||||
@@ -1,22 +0,0 @@
|
||||
FROM golang:1.26-alpine@sha256:2389ebfa5b7f43eeafbd6be0c3700cc46690ef842ad962f6c5bd6be49ed82039 AS builder
|
||||
|
||||
WORKDIR /app
|
||||
COPY ./ .
|
||||
|
||||
ARG TARGETARCH
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=${TARGETARCH} go build -ldflags="-s -w" -o onyx-cli .
|
||||
RUN mkdir -p /home/onyx/.config
|
||||
|
||||
FROM scratch
|
||||
|
||||
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
|
||||
COPY --from=builder --chown=65534:65534 /home/onyx /home/onyx
|
||||
|
||||
COPY --from=builder /app/onyx-cli /onyx-cli
|
||||
|
||||
ENV HOME=/home/onyx
|
||||
ENV XDG_CONFIG_HOME=/home/onyx/.config
|
||||
|
||||
USER 65534:65534
|
||||
|
||||
ENTRYPOINT ["/onyx-cli"]
|
||||
@@ -1,8 +1,5 @@
|
||||
# Onyx CLI
|
||||
|
||||
[](https://github.com/onyx-dot-app/onyx/actions/workflows/release-cli.yml)
|
||||
[](https://pypi.org/project/onyx-cli/)
|
||||
|
||||
A terminal interface for chatting with your [Onyx](https://github.com/onyx-dot-app/onyx) agent. Built with Go using [Bubble Tea](https://github.com/charmbracelet/bubbletea) for the TUI framework.
|
||||
|
||||
## Installation
|
||||
@@ -31,7 +28,7 @@ Environment variables override config file values:
|
||||
|
||||
| Variable | Required | Description |
|
||||
|----------|----------|-------------|
|
||||
| `ONYX_SERVER_URL` | No | Server base URL (default: `https://cloud.onyx.app`) |
|
||||
| `ONYX_SERVER_URL` | No | Server base URL (default: `http://localhost:3000`) |
|
||||
| `ONYX_API_KEY` | Yes | API key for authentication |
|
||||
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
|
||||
|
||||
|
||||
@@ -1261,5 +1261,3 @@ configMap:
|
||||
SKIP_USERFILE_THRESHOLD: ""
|
||||
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
|
||||
SKIP_USERFILE_THRESHOLD_TENANT_IDS: ""
|
||||
# Maximum user upload file size in MB for chat/projects uploads
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB: ""
|
||||
|
||||
@@ -18,10 +18,6 @@ variable "INTEGRATION_REPOSITORY" {
|
||||
default = "onyxdotapp/onyx-integration"
|
||||
}
|
||||
|
||||
variable "CLI_REPOSITORY" {
|
||||
default = "onyxdotapp/onyx-cli"
|
||||
}
|
||||
|
||||
variable "TAG" {
|
||||
default = "latest"
|
||||
}
|
||||
@@ -68,13 +64,3 @@ target "integration" {
|
||||
|
||||
tags = ["${INTEGRATION_REPOSITORY}:${TAG}"]
|
||||
}
|
||||
|
||||
target "cli" {
|
||||
context = "cli"
|
||||
dockerfile = "Dockerfile"
|
||||
|
||||
cache-from = ["type=registry,ref=${CLI_REPOSITORY}:latest"]
|
||||
cache-to = ["type=inline"]
|
||||
|
||||
tags = ["${CLI_REPOSITORY}:${TAG}"]
|
||||
}
|
||||
|
||||
@@ -35,7 +35,6 @@ backend = [
|
||||
"alembic==1.10.4",
|
||||
"asyncpg==0.30.0",
|
||||
"atlassian-python-api==3.41.16",
|
||||
"azure-cognitiveservices-speech==1.38.0",
|
||||
"beautifulsoup4==4.12.3",
|
||||
"boto3==1.39.11",
|
||||
"boto3-stubs[s3]==1.39.11",
|
||||
|
||||
@@ -26,16 +26,20 @@ class CustomBuildHook(BuildHookInterface):
|
||||
|
||||
# Get config and environment
|
||||
binary_name = self.config["binary_name"]
|
||||
tag_prefix = self.config.get("tag_prefix", binary_name)
|
||||
tag = os.getenv("GITHUB_REF_NAME", "dev").removeprefix(f"{tag_prefix}/")
|
||||
tag = os.getenv("GITHUB_REF_NAME", "dev").removeprefix(f"{binary_name}/")
|
||||
commit = os.getenv("GITHUB_SHA", "none")
|
||||
|
||||
# Build the Go binary if it doesn't exist
|
||||
if not os.path.exists(binary_name):
|
||||
print(f"Building Go binary '{binary_name}'...")
|
||||
ldflags = f"-X main.version={tag} -X main.commit={commit} -s -w"
|
||||
subprocess.check_call( # noqa: S603
|
||||
["go", "build", f"-ldflags={ldflags}", "-o", binary_name],
|
||||
[
|
||||
"go",
|
||||
"build",
|
||||
f"-ldflags=-X main.version={tag} -X main.commit={commit} -s -w",
|
||||
"-o",
|
||||
binary_name,
|
||||
],
|
||||
)
|
||||
|
||||
build_data["shared_scripts"] = {binary_name: binary_name}
|
||||
|
||||
@@ -3,9 +3,6 @@ from __future__ import annotations
|
||||
import os
|
||||
import re
|
||||
|
||||
# Must match tag_prefix in pyproject.toml [tool.hatch.build.targets.wheel.hooks.custom]
|
||||
TAG_PREFIX: str = "ods"
|
||||
|
||||
_tag = os.environ.get("GITHUB_REF_NAME", "v0.0.0-dev").removeprefix(f"{TAG_PREFIX}/")
|
||||
_tag = os.environ.get("GITHUB_REF_NAME", "v0.0.0-dev").removeprefix("ods/")
|
||||
_match = re.search(r"v?(\d+\.\d+\.\d+)", _tag)
|
||||
__version__ = _match.group(1) if _match else "0.0.0"
|
||||
|
||||
@@ -14,9 +14,7 @@ keywords = [
|
||||
classifiers = [
|
||||
"Programming Language :: Go",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: POSIX :: Linux",
|
||||
"Operating System :: MacOS",
|
||||
"Operating System :: Microsoft :: Windows",
|
||||
"Operating System :: OS Independent",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
dependencies = [
|
||||
@@ -29,7 +27,7 @@ dependencies = [
|
||||
Repository = "https://github.com/onyx-dot-app/onyx"
|
||||
|
||||
[tool.hatch.build]
|
||||
include = ["go.mod", "go.sum", "main.go", "**/*.go", "**/*.py", "README.md"]
|
||||
include = ["go.mod", "go.sum", "main.go", "**/*.go", "**/*.py"]
|
||||
|
||||
[tool.hatch.version]
|
||||
source = "code"
|
||||
@@ -38,7 +36,6 @@ path = "internal/_version.py"
|
||||
[tool.hatch.build.targets.wheel.hooks.custom]
|
||||
path = "hatch_build.py"
|
||||
binary_name = "ods"
|
||||
tag_prefix = "ods"
|
||||
|
||||
[tool.uv]
|
||||
managed = false
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user