mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-13 11:42:40 +00:00
Compare commits
4 Commits
voice-mode
...
xlsx-parse
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
83558ae04c | ||
|
|
005009602c | ||
|
|
b93875353b | ||
|
|
2290141b53 |
60
.github/workflows/deployment.yml
vendored
60
.github/workflows/deployment.yml
vendored
@@ -29,32 +29,20 @@ jobs:
|
||||
build-backend-craft: ${{ steps.check.outputs.build-backend-craft }}
|
||||
build-model-server: ${{ steps.check.outputs.build-model-server }}
|
||||
is-cloud-tag: ${{ steps.check.outputs.is-cloud-tag }}
|
||||
is-stable: ${{ steps.check.outputs.is-stable }}
|
||||
is-beta: ${{ steps.check.outputs.is-beta }}
|
||||
is-stable-standalone: ${{ steps.check.outputs.is-stable-standalone }}
|
||||
is-beta-standalone: ${{ steps.check.outputs.is-beta-standalone }}
|
||||
is-latest: ${{ steps.check.outputs.is-latest }}
|
||||
is-craft-latest: ${{ steps.check.outputs.is-craft-latest }}
|
||||
is-test-run: ${{ steps.check.outputs.is-test-run }}
|
||||
sanitized-tag: ${{ steps.check.outputs.sanitized-tag }}
|
||||
short-sha: ${{ steps.check.outputs.short-sha }}
|
||||
steps:
|
||||
- name: Checkout (for git tags)
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
fetch-tags: true
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
enable-cache: false
|
||||
|
||||
- name: Check which components to build and version info
|
||||
id: check
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
run: |
|
||||
set -eo pipefail
|
||||
TAG="${GITHUB_REF_NAME}"
|
||||
# Sanitize tag name by replacing slashes with hyphens (for Docker tag compatibility)
|
||||
SANITIZED_TAG=$(echo "$TAG" | tr '/' '-')
|
||||
@@ -66,8 +54,9 @@ jobs:
|
||||
IS_VERSION_TAG=false
|
||||
IS_STABLE=false
|
||||
IS_BETA=false
|
||||
IS_STABLE_STANDALONE=false
|
||||
IS_BETA_STANDALONE=false
|
||||
IS_LATEST=false
|
||||
IS_CRAFT_LATEST=false
|
||||
IS_PROD_TAG=false
|
||||
IS_TEST_RUN=false
|
||||
BUILD_DESKTOP=false
|
||||
@@ -78,6 +67,9 @@ jobs:
|
||||
BUILD_MODEL_SERVER=true
|
||||
|
||||
# Determine tag type based on pattern matching (do regex checks once)
|
||||
if [[ "$TAG" == craft-* ]]; then
|
||||
IS_CRAFT_LATEST=true
|
||||
fi
|
||||
if [[ "$TAG" == *cloud* ]]; then
|
||||
IS_CLOUD=true
|
||||
fi
|
||||
@@ -105,28 +97,20 @@ jobs:
|
||||
fi
|
||||
fi
|
||||
|
||||
# Craft-latest builds backend with Craft enabled
|
||||
if [[ "$IS_CRAFT_LATEST" == "true" ]]; then
|
||||
BUILD_BACKEND_CRAFT=true
|
||||
BUILD_BACKEND=false
|
||||
fi
|
||||
|
||||
# Standalone version checks (for backend/model-server - version excluding cloud tags)
|
||||
if [[ "$IS_STABLE" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then
|
||||
IS_STABLE_STANDALONE=true
|
||||
fi
|
||||
if [[ "$IS_BETA" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then
|
||||
IS_BETA_STANDALONE=true
|
||||
fi
|
||||
|
||||
# Determine if this tag should get the "latest" Docker tag.
|
||||
# Only the highest semver stable tag (vX.Y.Z exactly) gets "latest".
|
||||
if [[ "$IS_STABLE" == "true" ]]; then
|
||||
HIGHEST_STABLE=$(uv run --no-sync --with onyx-devtools ods latest-stable-tag) || {
|
||||
echo "::error::Failed to determine highest stable tag via 'ods latest-stable-tag'"
|
||||
exit 1
|
||||
}
|
||||
if [[ "$TAG" == "$HIGHEST_STABLE" ]]; then
|
||||
IS_LATEST=true
|
||||
fi
|
||||
fi
|
||||
|
||||
# Build craft-latest backend alongside the regular latest.
|
||||
if [[ "$IS_LATEST" == "true" ]]; then
|
||||
BUILD_BACKEND_CRAFT=true
|
||||
fi
|
||||
|
||||
# Determine if this is a production tag
|
||||
# Production tags are: version tags (v1.2.3*) or nightly tags
|
||||
if [[ "$IS_VERSION_TAG" == "true" ]] || [[ "$IS_NIGHTLY" == "true" ]]; then
|
||||
@@ -145,9 +129,11 @@ jobs:
|
||||
echo "build-backend-craft=$BUILD_BACKEND_CRAFT"
|
||||
echo "build-model-server=$BUILD_MODEL_SERVER"
|
||||
echo "is-cloud-tag=$IS_CLOUD"
|
||||
echo "is-stable=$IS_STABLE"
|
||||
echo "is-beta=$IS_BETA"
|
||||
echo "is-stable-standalone=$IS_STABLE_STANDALONE"
|
||||
echo "is-beta-standalone=$IS_BETA_STANDALONE"
|
||||
echo "is-latest=$IS_LATEST"
|
||||
echo "is-craft-latest=$IS_CRAFT_LATEST"
|
||||
echo "is-test-run=$IS_TEST_RUN"
|
||||
echo "sanitized-tag=$SANITIZED_TAG"
|
||||
echo "short-sha=$SHORT_SHA"
|
||||
@@ -614,7 +600,7 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}
|
||||
|
||||
@@ -1051,7 +1037,7 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('backend-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
|
||||
@@ -1487,7 +1473,7 @@ jobs:
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('model-server-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
|
||||
|
||||
187
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
187
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
@@ -1,102 +1,67 @@
|
||||
name: Post-Merge Beta Cherry-Pick
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types:
|
||||
- closed
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
# SECURITY NOTE:
|
||||
# This workflow intentionally uses pull_request_target so post-merge automation can
|
||||
# use base-repo credentials. Do not checkout PR head refs in this workflow
|
||||
# (e.g. github.event.pull_request.head.sha). Only trusted base refs are allowed.
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
resolve-cherry-pick-request:
|
||||
if: >-
|
||||
github.event.pull_request.merged == true
|
||||
&& github.event.pull_request.base.ref == 'main'
|
||||
&& github.event.pull_request.head.repo.full_name == github.repository
|
||||
outputs:
|
||||
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
|
||||
pr_number: ${{ steps.gate.outputs.pr_number }}
|
||||
merge_commit_sha: ${{ steps.gate.outputs.merge_commit_sha }}
|
||||
merged_by: ${{ steps.gate.outputs.merged_by }}
|
||||
gate_error: ${{ steps.gate.outputs.gate_error }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Resolve merged PR and checkbox state
|
||||
id: gate
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
# SECURITY: keep PR body in env/plain-text handling; avoid directly
|
||||
# inlining github.event.pull_request.body into shell commands.
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
MERGE_COMMIT_SHA: ${{ github.event.pull_request.merge_commit_sha }}
|
||||
MERGED_BY: ${{ github.event.pull_request.merged_by.login }}
|
||||
# GitHub team slug authorized to trigger cherry-picks (e.g. "core-eng").
|
||||
# For private/secret teams the GITHUB_TOKEN may need org:read scope;
|
||||
# visible teams work with the default token.
|
||||
ALLOWED_TEAM: "onyx-core-team"
|
||||
run: |
|
||||
echo "pr_number=${PR_NUMBER}" >> "$GITHUB_OUTPUT"
|
||||
echo "merged_by=${MERGED_BY}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
if ! echo "${PR_BODY}" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox not checked for PR #${PR_NUMBER}. Skipping."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Keep should_cherrypick output before any possible exit 1 below so
|
||||
# notify-slack can still gate on this output even if this job fails.
|
||||
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox checked for PR #${PR_NUMBER}."
|
||||
|
||||
if [ -z "${MERGE_COMMIT_SHA}" ] || [ "${MERGE_COMMIT_SHA}" = "null" ]; then
|
||||
echo "gate_error=missing-merge-commit-sha" >> "$GITHUB_OUTPUT"
|
||||
echo "::error::PR #${PR_NUMBER} requested cherry-pick, but merge_commit_sha is missing."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "merge_commit_sha=${MERGE_COMMIT_SHA}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
member_state_file="$(mktemp)"
|
||||
member_err_file="$(mktemp)"
|
||||
if ! gh api "orgs/${GITHUB_REPOSITORY_OWNER}/teams/${ALLOWED_TEAM}/memberships/${MERGED_BY}" --jq '.state' >"${member_state_file}" 2>"${member_err_file}"; then
|
||||
api_err="$(tr '\n' ' ' < "${member_err_file}" | sed 's/[[:space:]]\+/ /g' | cut -c1-300)"
|
||||
echo "gate_error=team-api-error" >> "$GITHUB_OUTPUT"
|
||||
echo "::error::Team membership API call failed for ${MERGED_BY} in ${ALLOWED_TEAM}: ${api_err}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
member_state="$(cat "${member_state_file}")"
|
||||
if [ "${member_state}" != "active" ]; then
|
||||
echo "gate_error=not-team-member" >> "$GITHUB_OUTPUT"
|
||||
echo "::error::${MERGED_BY} is not an active member of team ${ALLOWED_TEAM} (state: ${member_state}). Failing cherry-pick gate."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
exit 0
|
||||
|
||||
cherry-pick-to-latest-release:
|
||||
needs:
|
||||
- resolve-cherry-pick-request
|
||||
if: needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && needs.resolve-cherry-pick-request.result == 'success'
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
outputs:
|
||||
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
|
||||
pr_number: ${{ steps.gate.outputs.pr_number }}
|
||||
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Resolve merged PR and checkbox state
|
||||
id: gate
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
# For the commit that triggered this workflow (HEAD on main), fetch all
|
||||
# associated PRs and keep only the PR that was actually merged into main
|
||||
# with this exact merge commit SHA.
|
||||
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
|
||||
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
|
||||
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
|
||||
|
||||
if [ "${match_count}" -gt 1 ]; then
|
||||
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
|
||||
fi
|
||||
|
||||
if [ -z "$pr_number" ]; then
|
||||
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Read the PR once so we can gate behavior and infer preferred actor.
|
||||
pr_json="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}")"
|
||||
pr_body="$(printf '%s' "$pr_json" | jq -r '.body // ""')"
|
||||
merged_by="$(printf '%s' "$pr_json" | jq -r '.merged_by.login // ""')"
|
||||
|
||||
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
|
||||
echo "merged_by=$merged_by" >> "$GITHUB_OUTPUT"
|
||||
|
||||
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
|
||||
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox checked for PR #${pr_number}."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
|
||||
|
||||
- name: Checkout repository
|
||||
# SECURITY: keep checkout pinned to trusted base branch; do not switch to PR head refs.
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
@@ -104,37 +69,31 @@ jobs:
|
||||
ref: main
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Configure git identity
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Create cherry-pick PR to latest release
|
||||
id: run_cherry_pick
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
continue-on-error: true
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
CHERRY_PICK_ASSIGNEE: ${{ needs.resolve-cherry-pick-request.outputs.merged_by }}
|
||||
MERGE_COMMIT_SHA: ${{ needs.resolve-cherry-pick-request.outputs.merge_commit_sha }}
|
||||
CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
|
||||
run: |
|
||||
set -o pipefail
|
||||
output_file="$(mktemp)"
|
||||
set +e
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${MERGE_COMMIT_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
|
||||
pipe_statuses=("${PIPESTATUS[@]}")
|
||||
exit_code="${pipe_statuses[0]}"
|
||||
tee_exit="${pipe_statuses[1]:-0}"
|
||||
set -e
|
||||
if [ "${tee_exit}" -ne 0 ]; then
|
||||
echo "status=failure" >> "$GITHUB_OUTPUT"
|
||||
echo "reason=output-capture-failed" >> "$GITHUB_OUTPUT"
|
||||
echo "::error::tee failed to capture cherry-pick output (exit ${tee_exit}); cannot classify result."
|
||||
exit 1
|
||||
fi
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
|
||||
exit_code="${PIPESTATUS[0]}"
|
||||
|
||||
if [ "${exit_code}" -eq 0 ]; then
|
||||
echo "status=success" >> "$GITHUB_OUTPUT"
|
||||
@@ -156,7 +115,7 @@ jobs:
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Mark workflow as failed if cherry-pick failed
|
||||
if: steps.run_cherry_pick.outputs.status == 'failure'
|
||||
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
|
||||
env:
|
||||
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
run: |
|
||||
@@ -165,9 +124,8 @@ jobs:
|
||||
|
||||
notify-slack-on-cherry-pick-failure:
|
||||
needs:
|
||||
- resolve-cherry-pick-request
|
||||
- cherry-pick-to-latest-release
|
||||
if: always() && needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && (needs.resolve-cherry-pick-request.result == 'failure' || needs.cherry-pick-to-latest-release.result == 'failure')
|
||||
if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
@@ -176,49 +134,22 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Fail if Slack webhook secret is missing
|
||||
env:
|
||||
CHERRY_PICK_PRS_WEBHOOK: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
|
||||
run: |
|
||||
if [ -z "${CHERRY_PICK_PRS_WEBHOOK}" ]; then
|
||||
echo "::error::CHERRY_PICK_PRS_WEBHOOK is not configured."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Build cherry-pick failure summary
|
||||
id: failure-summary
|
||||
env:
|
||||
SOURCE_PR_NUMBER: ${{ needs.resolve-cherry-pick-request.outputs.pr_number }}
|
||||
MERGE_COMMIT_SHA: ${{ needs.resolve-cherry-pick-request.outputs.merge_commit_sha }}
|
||||
GATE_ERROR: ${{ needs.resolve-cherry-pick-request.outputs.gate_error }}
|
||||
SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
|
||||
CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
|
||||
CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
|
||||
run: |
|
||||
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
|
||||
|
||||
reason_text="cherry-pick command failed"
|
||||
if [ "${GATE_ERROR}" = "missing-merge-commit-sha" ]; then
|
||||
reason_text="requested cherry-pick but merge commit SHA was missing"
|
||||
elif [ "${GATE_ERROR}" = "team-api-error" ]; then
|
||||
reason_text="team membership lookup failed while validating cherry-pick permissions"
|
||||
elif [ "${GATE_ERROR}" = "not-team-member" ]; then
|
||||
reason_text="merger is not an active member of the allowed team"
|
||||
elif [ "${CHERRY_PICK_REASON}" = "output-capture-failed" ]; then
|
||||
reason_text="failed to capture cherry-pick output for classification"
|
||||
elif [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
|
||||
if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
|
||||
reason_text="merge conflict during cherry-pick"
|
||||
fi
|
||||
|
||||
details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
|
||||
if [ -n "${GATE_ERROR}" ]; then
|
||||
failed_job_label="resolve-cherry-pick-request"
|
||||
else
|
||||
failed_job_label="cherry-pick-to-latest-release"
|
||||
fi
|
||||
failed_jobs="• ${failed_job_label}\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
|
||||
if [ -n "${MERGE_COMMIT_SHA}" ]; then
|
||||
failed_jobs="${failed_jobs}\\n• merge SHA: ${MERGE_COMMIT_SHA}"
|
||||
fi
|
||||
failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
|
||||
if [ -n "${details_excerpt}" ]; then
|
||||
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
|
||||
fi
|
||||
@@ -231,4 +162,4 @@ jobs:
|
||||
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
|
||||
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
|
||||
title: "🚨 Automated Cherry-Pick Failed"
|
||||
ref-name: ${{ github.event.pull_request.base.ref }}
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
11
.github/workflows/pr-helm-chart-testing.yml
vendored
11
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -133,7 +133,7 @@ jobs:
|
||||
echo "=== Validating chart dependencies ==="
|
||||
cd deployment/helm/charts/onyx
|
||||
helm dependency update
|
||||
helm lint . --set auth.userauth.values.user_auth_secret=placeholder
|
||||
helm lint .
|
||||
|
||||
- name: Run chart-testing (install) with enhanced monitoring
|
||||
timeout-minutes: 25
|
||||
@@ -194,7 +194,6 @@ jobs:
|
||||
--set=vespa.enabled=false \
|
||||
--set=opensearch.enabled=true \
|
||||
--set=auth.opensearch.enabled=true \
|
||||
--set=auth.userauth.values.user_auth_secret=test-secret \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.cluster.storage.storageClass=standard \
|
||||
@@ -231,10 +230,6 @@ jobs:
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Post-install verification ==="
|
||||
if ! kubectl cluster-info >/dev/null 2>&1; then
|
||||
echo "ERROR: Kubernetes cluster is not reachable after install"
|
||||
exit 1
|
||||
fi
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get services --all-namespaces
|
||||
# Only show issues if they exist
|
||||
@@ -244,10 +239,6 @@ jobs:
|
||||
if: failure() && steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Cleanup on failure ==="
|
||||
if ! kubectl cluster-info >/dev/null 2>&1; then
|
||||
echo "Skipping failure cleanup: Kubernetes cluster is not reachable"
|
||||
exit 0
|
||||
fi
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
|
||||
|
||||
2
.github/workflows/storybook-deploy.yml
vendored
2
.github/workflows/storybook-deploy.yml
vendored
@@ -48,7 +48,7 @@ jobs:
|
||||
|
||||
- name: Deploy to Vercel (Production)
|
||||
working-directory: web
|
||||
run: npx --yes "$VERCEL_CLI" deploy storybook-static/ --prod --yes --token="$VERCEL_TOKEN"
|
||||
run: npx --yes "$VERCEL_CLI" deploy storybook-static/ --prod --yes
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: Deploy-Storybook
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
"""add timestamps to user table
|
||||
|
||||
Revision ID: 27fb147a843f
|
||||
Revises: b5c4d7e8f9a1
|
||||
Create Date: 2026-03-08 17:18:40.828644
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "27fb147a843f"
|
||||
down_revision = "b5c4d7e8f9a1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "updated_at")
|
||||
op.drop_column("user", "created_at")
|
||||
@@ -1,117 +0,0 @@
|
||||
"""add_voice_provider_and_user_voice_prefs
|
||||
|
||||
Revision ID: 93a2e195e25c
|
||||
Revises: 27fb147a843f
|
||||
Create Date: 2026-02-23 15:16:39.507304
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import column
|
||||
from sqlalchemy import true
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "93a2e195e25c"
|
||||
down_revision = "27fb147a843f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create voice_provider table
|
||||
op.create_table(
|
||||
"voice_provider",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("name", sa.String(), unique=True, nullable=False),
|
||||
sa.Column("provider_type", sa.String(), nullable=False),
|
||||
sa.Column("api_key", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("api_base", sa.String(), nullable=True),
|
||||
sa.Column("custom_config", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("stt_model", sa.String(), nullable=True),
|
||||
sa.Column("tts_model", sa.String(), nullable=True),
|
||||
sa.Column("default_voice", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"is_default_stt", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.Column(
|
||||
"is_default_tts", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
sa.Column("deleted", sa.Boolean(), nullable=False, server_default="false"),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Add partial unique indexes to enforce only one default STT/TTS provider
|
||||
op.create_index(
|
||||
"ix_voice_provider_one_default_stt",
|
||||
"voice_provider",
|
||||
["is_default_stt"],
|
||||
unique=True,
|
||||
postgresql_where=column("is_default_stt") == true(),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_voice_provider_one_default_tts",
|
||||
"voice_provider",
|
||||
["is_default_tts"],
|
||||
unique=True,
|
||||
postgresql_where=column("is_default_tts") == true(),
|
||||
)
|
||||
|
||||
# Add voice preference columns to user table
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_auto_send",
|
||||
sa.Boolean(),
|
||||
default=False,
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_auto_playback",
|
||||
sa.Boolean(),
|
||||
default=False,
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"voice_playback_speed",
|
||||
sa.Float(),
|
||||
default=1.0,
|
||||
nullable=False,
|
||||
server_default="1.0",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove user voice preference columns
|
||||
op.drop_column("user", "voice_playback_speed")
|
||||
op.drop_column("user", "voice_auto_playback")
|
||||
op.drop_column("user", "voice_auto_send")
|
||||
|
||||
op.drop_index("ix_voice_provider_one_default_tts", table_name="voice_provider")
|
||||
op.drop_index("ix_voice_provider_one_default_stt", table_name="voice_provider")
|
||||
|
||||
# Drop voice_provider table
|
||||
op.drop_table("voice_provider")
|
||||
@@ -1,8 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from jira import JIRA
|
||||
from jira.exceptions import JIRAError
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.connectors.jira.utils import build_jira_client
|
||||
@@ -11,102 +9,107 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_ATLASSIAN_ACCOUNT_TYPE = "atlassian"
|
||||
_GROUP_MEMBER_PAGE_SIZE = 50
|
||||
|
||||
# The GET /group/member endpoint was introduced in Jira 6.0.
|
||||
# Jira versions older than 6.0 do not have group management REST APIs at all.
|
||||
_MIN_JIRA_VERSION_FOR_GROUP_MEMBER = "6.0"
|
||||
|
||||
|
||||
def _fetch_group_member_page(
|
||||
def _get_jira_group_members_email(
|
||||
jira_client: JIRA,
|
||||
group_name: str,
|
||||
start_at: int,
|
||||
) -> dict[str, Any]:
|
||||
"""Fetch a single page from the non-deprecated GET /group/member endpoint.
|
||||
) -> list[str]:
|
||||
"""Get all member emails for a Jira group.
|
||||
|
||||
The old GET /group endpoint (used by jira_client.group_members()) is deprecated
|
||||
and decommissioned in Jira Server 10.3+. This uses the replacement endpoint
|
||||
directly via the library's internal _get_json helper, following the same pattern
|
||||
as enhanced_search_ids / bulk_fetch_issues in connector.py.
|
||||
|
||||
There is an open PR to the library to switch to this endpoint since last year:
|
||||
https://github.com/pycontribs/jira/pull/2356
|
||||
so once it is merged and released, we can switch to using the library function.
|
||||
Filters out app accounts (bots, integrations) and only returns real user emails.
|
||||
"""
|
||||
emails: list[str] = []
|
||||
|
||||
try:
|
||||
return jira_client._get_json(
|
||||
"group/member",
|
||||
params={
|
||||
"groupname": group_name,
|
||||
"includeInactiveUsers": "false",
|
||||
"startAt": start_at,
|
||||
"maxResults": _GROUP_MEMBER_PAGE_SIZE,
|
||||
},
|
||||
)
|
||||
except JIRAError as e:
|
||||
if e.status_code == 404:
|
||||
raise RuntimeError(
|
||||
f"GET /group/member returned 404 for group '{group_name}'. "
|
||||
f"This endpoint requires Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}+. "
|
||||
f"If you are running a self-hosted Jira instance, please upgrade "
|
||||
f"to at least Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}."
|
||||
) from e
|
||||
raise
|
||||
# group_members returns an OrderedDict of account_id -> member_info
|
||||
members = jira_client.group_members(group=group_name)
|
||||
|
||||
if not members:
|
||||
logger.warning(f"No members found for group {group_name}")
|
||||
return emails
|
||||
|
||||
def _get_group_member_emails(
|
||||
jira_client: JIRA,
|
||||
group_name: str,
|
||||
) -> set[str]:
|
||||
"""Get all member emails for a single Jira group.
|
||||
for account_id, member_info in members.items():
|
||||
# member_info is a dict with keys like 'fullname', 'email', 'active'
|
||||
email = member_info.get("email")
|
||||
|
||||
Uses the non-deprecated GET /group/member endpoint which returns full user
|
||||
objects including accountType, so we can filter out app/customer accounts
|
||||
without making separate user() calls.
|
||||
"""
|
||||
emails: set[str] = set()
|
||||
start_at = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
page = _fetch_group_member_page(jira_client, group_name, start_at)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching members for group {group_name}: {e}")
|
||||
raise
|
||||
|
||||
members: list[dict[str, Any]] = page.get("values", [])
|
||||
for member in members:
|
||||
account_type = member.get("accountType")
|
||||
# On Jira DC < 9.0, accountType is absent; include those users.
|
||||
# On Cloud / DC 9.0+, filter to real user accounts only.
|
||||
if account_type is not None and account_type != _ATLASSIAN_ACCOUNT_TYPE:
|
||||
continue
|
||||
|
||||
email = member.get("emailAddress")
|
||||
if email:
|
||||
emails.add(email)
|
||||
# Skip "hidden" emails - these are typically app accounts
|
||||
if email and email != "hidden":
|
||||
emails.append(email)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Atlassian user {member.get('accountId', 'unknown')} "
|
||||
f"in group {group_name} has no visible email address"
|
||||
)
|
||||
# For cloud, we might need to fetch user details separately
|
||||
try:
|
||||
user = jira_client.user(id=account_id)
|
||||
|
||||
if page.get("isLast", True) or not members:
|
||||
break
|
||||
start_at += len(members)
|
||||
# Skip app accounts (bots, integrations, etc.)
|
||||
if hasattr(user, "accountType") and user.accountType == "app":
|
||||
logger.info(
|
||||
f"Skipping app account {account_id} for group {group_name}"
|
||||
)
|
||||
continue
|
||||
|
||||
if hasattr(user, "emailAddress") and user.emailAddress:
|
||||
emails.append(user.emailAddress)
|
||||
else:
|
||||
logger.warning(f"User {account_id} has no email address")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch email for user {account_id} in group {group_name}: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching members for group {group_name}: {e}")
|
||||
|
||||
return emails
|
||||
|
||||
|
||||
def _build_group_member_email_map(
|
||||
jira_client: JIRA,
|
||||
) -> dict[str, set[str]]:
|
||||
"""Build a map of group names to member emails."""
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
|
||||
try:
|
||||
# Get all groups from Jira - returns a list of group name strings
|
||||
group_names = jira_client.groups()
|
||||
|
||||
if not group_names:
|
||||
logger.warning("No groups found in Jira")
|
||||
return group_member_emails
|
||||
|
||||
logger.info(f"Found {len(group_names)} groups in Jira")
|
||||
|
||||
for group_name in group_names:
|
||||
if not group_name:
|
||||
continue
|
||||
|
||||
member_emails = _get_jira_group_members_email(
|
||||
jira_client=jira_client,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
if member_emails:
|
||||
group_member_emails[group_name] = set(member_emails)
|
||||
logger.debug(
|
||||
f"Found {len(member_emails)} members for group {group_name}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"No members found for group {group_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building group member email map: {e}")
|
||||
|
||||
return group_member_emails
|
||||
|
||||
|
||||
def jira_group_sync(
|
||||
tenant_id: str, # noqa: ARG001
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
"""Sync Jira groups and their members, yielding one group at a time.
|
||||
"""
|
||||
Sync Jira groups and their members.
|
||||
|
||||
Streams group-by-group rather than accumulating all groups in memory.
|
||||
This function fetches all groups from Jira and yields ExternalUserGroup
|
||||
objects containing the group ID and member emails.
|
||||
"""
|
||||
jira_base_url = cc_pair.connector.connector_specific_config.get("jira_base_url", "")
|
||||
scoped_token = cc_pair.connector.connector_specific_config.get(
|
||||
@@ -127,26 +130,12 @@ def jira_group_sync(
|
||||
scoped_token=scoped_token,
|
||||
)
|
||||
|
||||
group_names = jira_client.groups()
|
||||
if not group_names:
|
||||
raise ValueError(f"No groups found for cc_pair_id={cc_pair.id}")
|
||||
group_member_email_map = _build_group_member_email_map(jira_client=jira_client)
|
||||
if not group_member_email_map:
|
||||
raise ValueError(f"No groups with members found for cc_pair_id={cc_pair.id}")
|
||||
|
||||
logger.info(f"Found {len(group_names)} groups in Jira")
|
||||
|
||||
for group_name in group_names:
|
||||
if not group_name:
|
||||
continue
|
||||
|
||||
member_emails = _get_group_member_emails(
|
||||
jira_client=jira_client,
|
||||
group_name=group_name,
|
||||
)
|
||||
if not member_emails:
|
||||
logger.debug(f"No members found for group {group_name}")
|
||||
continue
|
||||
|
||||
logger.debug(f"Found {len(member_emails)} members for group {group_name}")
|
||||
for group_id, group_member_emails in group_member_email_map.items():
|
||||
yield ExternalUserGroup(
|
||||
id=group_name,
|
||||
user_emails=list(member_emails),
|
||||
id=group_id,
|
||||
user_emails=list(group_member_emails),
|
||||
)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
@@ -31,8 +29,6 @@ from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import status
|
||||
from fastapi import WebSocket
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from fastapi_users import BaseUserManager
|
||||
@@ -59,7 +55,6 @@ from fastapi_users.router.common import ErrorModel
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
|
||||
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import GetAccessTokenError
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import nulls_last
|
||||
@@ -125,12 +120,7 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import log_onyx_error
|
||||
from onyx.error_handling.exceptions import onyx_error_to_json_response
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import retrieve_ws_token_data
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -1622,102 +1612,6 @@ async def current_admin_user(user: User = Depends(current_user)) -> User:
|
||||
return user
|
||||
|
||||
|
||||
async def _get_user_from_token_data(token_data: dict) -> User | None:
|
||||
"""Shared logic: token data dict → User object.
|
||||
|
||||
Args:
|
||||
token_data: Decoded token data containing 'sub' (user ID).
|
||||
|
||||
Returns:
|
||||
User object if found and active, None otherwise.
|
||||
"""
|
||||
user_id = token_data.get("sub")
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
user_uuid = uuid.UUID(user_id)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
async with get_async_session_context_manager() as async_db_session:
|
||||
user = await async_db_session.get(User, user_uuid)
|
||||
if user is None or not user.is_active:
|
||||
return None
|
||||
return user
|
||||
|
||||
|
||||
async def current_user_from_websocket(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(..., description="WebSocket authentication token"),
|
||||
) -> User:
|
||||
"""
|
||||
WebSocket authentication dependency using query parameter.
|
||||
|
||||
Validates the WS token from query param and returns the User.
|
||||
Raises BasicAuthenticationError if authentication fails.
|
||||
|
||||
The token must be obtained from POST /voice/ws-token before connecting.
|
||||
Tokens are single-use and expire after 60 seconds.
|
||||
|
||||
Usage:
|
||||
1. POST /voice/ws-token -> {"token": "xxx"}
|
||||
2. Connect to ws://host/path?token=xxx
|
||||
|
||||
This applies the same auth checks as current_user() for HTTP endpoints.
|
||||
"""
|
||||
# Check Origin header to prevent Cross-Site WebSocket Hijacking (CSWSH)
|
||||
# Browsers always send Origin on WebSocket connections
|
||||
origin = websocket.headers.get("origin")
|
||||
expected_origin = WEB_DOMAIN.rstrip("/")
|
||||
if not origin:
|
||||
logger.warning("WS auth: missing Origin header")
|
||||
raise BasicAuthenticationError(detail="Access denied. Missing origin.")
|
||||
|
||||
actual_origin = origin.rstrip("/")
|
||||
if actual_origin != expected_origin:
|
||||
logger.warning(
|
||||
f"WS auth: origin mismatch. Expected {expected_origin}, got {actual_origin}"
|
||||
)
|
||||
raise BasicAuthenticationError(detail="Access denied. Invalid origin.")
|
||||
|
||||
# Validate WS token in Redis (single-use, deleted after retrieval)
|
||||
try:
|
||||
token_data = await retrieve_ws_token_data(token)
|
||||
if token_data is None:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. Invalid or expired authentication token."
|
||||
)
|
||||
except BasicAuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"WS auth: error during token validation: {e}")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Authentication verification failed."
|
||||
) from e
|
||||
|
||||
# Get user from token data
|
||||
user = await _get_user_from_token_data(token_data)
|
||||
if user is None:
|
||||
logger.warning(f"WS auth: user not found for id={token_data.get('sub')}")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User not found or inactive."
|
||||
)
|
||||
|
||||
# Apply same checks as HTTP auth (verification, OIDC expiry, role)
|
||||
user = await double_check_user(user)
|
||||
|
||||
# Block LIMITED users (same as current_user)
|
||||
if user.role == UserRole.LIMITED:
|
||||
logger.warning(f"WS auth: user {user.email} has LIMITED role")
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
|
||||
)
|
||||
|
||||
logger.debug(f"WS auth: authenticated {user.email}")
|
||||
return user
|
||||
|
||||
|
||||
def get_default_admin_user_emails_() -> list[str]:
|
||||
# No default seeding available for Onyx MIT
|
||||
return []
|
||||
@@ -1727,7 +1621,6 @@ STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
|
||||
STATE_TOKEN_LIFETIME_SECONDS = 3600
|
||||
CSRF_TOKEN_KEY = "csrftoken"
|
||||
CSRF_TOKEN_COOKIE_NAME = "fastapiusersoauthcsrf"
|
||||
PKCE_COOKIE_NAME_PREFIX = "fastapiusersoauthpkce"
|
||||
|
||||
|
||||
class OAuth2AuthorizeResponse(BaseModel):
|
||||
@@ -1748,21 +1641,6 @@ def generate_csrf_token() -> str:
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
def _base64url_encode(data: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
|
||||
|
||||
|
||||
def generate_pkce_pair() -> tuple[str, str]:
|
||||
verifier = secrets.token_urlsafe(64)
|
||||
challenge = _base64url_encode(hashlib.sha256(verifier.encode("ascii")).digest())
|
||||
return verifier, challenge
|
||||
|
||||
|
||||
def get_pkce_cookie_name(state: str) -> str:
|
||||
state_hash = hashlib.sha256(state.encode("utf-8")).hexdigest()
|
||||
return f"{PKCE_COOKIE_NAME_PREFIX}_{state_hash}"
|
||||
|
||||
|
||||
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
|
||||
def create_onyx_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
@@ -1771,7 +1649,6 @@ def create_onyx_oauth_router(
|
||||
redirect_url: Optional[str] = None,
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
enable_pkce: bool = False,
|
||||
) -> APIRouter:
|
||||
return get_oauth_router(
|
||||
oauth_client,
|
||||
@@ -1781,7 +1658,6 @@ def create_onyx_oauth_router(
|
||||
redirect_url,
|
||||
associate_by_email,
|
||||
is_verified_by_default,
|
||||
enable_pkce=enable_pkce,
|
||||
)
|
||||
|
||||
|
||||
@@ -1800,7 +1676,6 @@ def get_oauth_router(
|
||||
csrf_token_cookie_secure: Optional[bool] = None,
|
||||
csrf_token_cookie_httponly: bool = True,
|
||||
csrf_token_cookie_samesite: Optional[Literal["lax", "strict", "none"]] = "lax",
|
||||
enable_pkce: bool = False,
|
||||
) -> APIRouter:
|
||||
"""Generate a router with the OAuth routes."""
|
||||
router = APIRouter()
|
||||
@@ -1817,13 +1692,6 @@ def get_oauth_router(
|
||||
route_name=callback_route_name,
|
||||
)
|
||||
|
||||
async def null_access_token_state() -> tuple[OAuth2Token, Optional[str]] | None:
|
||||
return None
|
||||
|
||||
access_token_state_dependency = (
|
||||
oauth2_authorize_callback if not enable_pkce else null_access_token_state
|
||||
)
|
||||
|
||||
if csrf_token_cookie_secure is None:
|
||||
csrf_token_cookie_secure = WEB_DOMAIN.startswith("https")
|
||||
|
||||
@@ -1857,26 +1725,13 @@ def get_oauth_router(
|
||||
CSRF_TOKEN_KEY: csrf_token,
|
||||
}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
pkce_cookie: tuple[str, str] | None = None
|
||||
|
||||
if enable_pkce:
|
||||
code_verifier, code_challenge = generate_pkce_pair()
|
||||
pkce_cookie_name = get_pkce_cookie_name(state)
|
||||
pkce_cookie = (pkce_cookie_name, code_verifier)
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
state,
|
||||
scopes,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method="S256",
|
||||
)
|
||||
else:
|
||||
# Get the basic authorization URL
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
state,
|
||||
scopes,
|
||||
)
|
||||
# Get the basic authorization URL
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
state,
|
||||
scopes,
|
||||
)
|
||||
|
||||
# For Google OAuth, add parameters to request refresh tokens
|
||||
if oauth_client.name == "google":
|
||||
@@ -1884,15 +1739,11 @@ def get_oauth_router(
|
||||
authorization_url, {"access_type": "offline", "prompt": "consent"}
|
||||
)
|
||||
|
||||
def set_oauth_cookie(
|
||||
target_response: Response,
|
||||
*,
|
||||
key: str,
|
||||
value: str,
|
||||
) -> None:
|
||||
target_response.set_cookie(
|
||||
key=key,
|
||||
value=value,
|
||||
if redirect:
|
||||
redirect_response = RedirectResponse(authorization_url, status_code=302)
|
||||
redirect_response.set_cookie(
|
||||
key=csrf_token_cookie_name,
|
||||
value=csrf_token,
|
||||
max_age=STATE_TOKEN_LIFETIME_SECONDS,
|
||||
path=csrf_token_cookie_path,
|
||||
domain=csrf_token_cookie_domain,
|
||||
@@ -1900,28 +1751,18 @@ def get_oauth_router(
|
||||
httponly=csrf_token_cookie_httponly,
|
||||
samesite=csrf_token_cookie_samesite,
|
||||
)
|
||||
return redirect_response
|
||||
|
||||
response_with_cookies: Response
|
||||
if redirect:
|
||||
response_with_cookies = RedirectResponse(authorization_url, status_code=302)
|
||||
else:
|
||||
response_with_cookies = response
|
||||
|
||||
set_oauth_cookie(
|
||||
response_with_cookies,
|
||||
response.set_cookie(
|
||||
key=csrf_token_cookie_name,
|
||||
value=csrf_token,
|
||||
max_age=STATE_TOKEN_LIFETIME_SECONDS,
|
||||
path=csrf_token_cookie_path,
|
||||
domain=csrf_token_cookie_domain,
|
||||
secure=csrf_token_cookie_secure,
|
||||
httponly=csrf_token_cookie_httponly,
|
||||
samesite=csrf_token_cookie_samesite,
|
||||
)
|
||||
if pkce_cookie is not None:
|
||||
pkce_cookie_name, code_verifier = pkce_cookie
|
||||
set_oauth_cookie(
|
||||
response_with_cookies,
|
||||
key=pkce_cookie_name,
|
||||
value=code_verifier,
|
||||
)
|
||||
|
||||
if redirect:
|
||||
return response_with_cookies
|
||||
|
||||
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
|
||||
|
||||
@@ -1952,242 +1793,119 @@ def get_oauth_router(
|
||||
)
|
||||
async def callback(
|
||||
request: Request,
|
||||
access_token_state: Tuple[OAuth2Token, Optional[str]] | None = Depends(
|
||||
access_token_state_dependency
|
||||
access_token_state: Tuple[OAuth2Token, str] = Depends(
|
||||
oauth2_authorize_callback
|
||||
),
|
||||
code: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
) -> Response:
|
||||
pkce_cookie_name: str | None = None
|
||||
) -> RedirectResponse:
|
||||
token, state = access_token_state
|
||||
account_id, account_email = await oauth_client.get_id_email(
|
||||
token["access_token"]
|
||||
)
|
||||
|
||||
def delete_pkce_cookie(response: Response) -> None:
|
||||
if enable_pkce and pkce_cookie_name:
|
||||
response.delete_cookie(
|
||||
key=pkce_cookie_name,
|
||||
path=csrf_token_cookie_path,
|
||||
domain=csrf_token_cookie_domain,
|
||||
secure=csrf_token_cookie_secure,
|
||||
httponly=csrf_token_cookie_httponly,
|
||||
samesite=csrf_token_cookie_samesite,
|
||||
)
|
||||
|
||||
def build_error_response(exc: OnyxError) -> JSONResponse:
|
||||
log_onyx_error(exc)
|
||||
error_response = onyx_error_to_json_response(exc)
|
||||
delete_pkce_cookie(error_response)
|
||||
return error_response
|
||||
|
||||
def decode_and_validate_state(state_value: str) -> Dict[str, str]:
|
||||
try:
|
||||
state_data = decode_jwt(
|
||||
state_value, state_secret, [STATE_TOKEN_AUDIENCE]
|
||||
)
|
||||
except jwt.DecodeError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
getattr(
|
||||
ErrorCode,
|
||||
"ACCESS_TOKEN_DECODE_ERROR",
|
||||
"ACCESS_TOKEN_DECODE_ERROR",
|
||||
),
|
||||
)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
getattr(
|
||||
ErrorCode,
|
||||
"ACCESS_TOKEN_ALREADY_EXPIRED",
|
||||
"ACCESS_TOKEN_ALREADY_EXPIRED",
|
||||
),
|
||||
)
|
||||
except jwt.PyJWTError:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
getattr(
|
||||
ErrorCode,
|
||||
"ACCESS_TOKEN_DECODE_ERROR",
|
||||
"ACCESS_TOKEN_DECODE_ERROR",
|
||||
),
|
||||
)
|
||||
|
||||
cookie_csrf_token = request.cookies.get(csrf_token_cookie_name)
|
||||
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
|
||||
if (
|
||||
not cookie_csrf_token
|
||||
or not state_csrf_token
|
||||
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
|
||||
):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
getattr(ErrorCode, "OAUTH_INVALID_STATE", "OAUTH_INVALID_STATE"),
|
||||
)
|
||||
|
||||
return state_data
|
||||
|
||||
token: OAuth2Token
|
||||
state_data: Dict[str, str]
|
||||
|
||||
# `code`, `state`, and `error` are read directly only in the PKCE path.
|
||||
# In the non-PKCE path, `oauth2_authorize_callback` consumes them.
|
||||
if enable_pkce:
|
||||
if state is not None:
|
||||
pkce_cookie_name = get_pkce_cookie_name(state)
|
||||
|
||||
if error is not None:
|
||||
return build_error_response(
|
||||
OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Authorization request failed or was denied",
|
||||
)
|
||||
)
|
||||
if code is None:
|
||||
return build_error_response(
|
||||
OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Missing authorization code in OAuth callback",
|
||||
)
|
||||
)
|
||||
if state is None:
|
||||
return build_error_response(
|
||||
OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Missing state parameter in OAuth callback",
|
||||
)
|
||||
)
|
||||
|
||||
state_value = state
|
||||
|
||||
if redirect_url is not None:
|
||||
callback_redirect_url = redirect_url
|
||||
else:
|
||||
callback_path = request.app.url_path_for(callback_route_name)
|
||||
callback_redirect_url = f"{WEB_DOMAIN}{callback_path}"
|
||||
|
||||
code_verifier = request.cookies.get(cast(str, pkce_cookie_name))
|
||||
if not code_verifier:
|
||||
return build_error_response(
|
||||
OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Missing PKCE verifier cookie in OAuth callback",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
state_data = decode_and_validate_state(state_value)
|
||||
except OnyxError as e:
|
||||
return build_error_response(e)
|
||||
|
||||
try:
|
||||
token = await oauth_client.get_access_token(
|
||||
code, callback_redirect_url, code_verifier
|
||||
)
|
||||
except GetAccessTokenError:
|
||||
return build_error_response(
|
||||
OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Authorization code exchange failed",
|
||||
)
|
||||
)
|
||||
else:
|
||||
if access_token_state is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR, "Missing OAuth callback state"
|
||||
)
|
||||
token, callback_state = access_token_state
|
||||
if callback_state is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Missing state parameter in OAuth callback",
|
||||
)
|
||||
state_data = decode_and_validate_state(callback_state)
|
||||
|
||||
async def complete_login_flow(
|
||||
token: OAuth2Token, state_data: Dict[str, str]
|
||||
) -> RedirectResponse:
|
||||
account_id, account_email = await oauth_client.get_id_email(
|
||||
token["access_token"]
|
||||
if account_email is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
|
||||
)
|
||||
|
||||
if account_email is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
|
||||
)
|
||||
try:
|
||||
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
|
||||
except jwt.DecodeError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=getattr(
|
||||
ErrorCode, "ACCESS_TOKEN_DECODE_ERROR", "ACCESS_TOKEN_DECODE_ERROR"
|
||||
),
|
||||
)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=getattr(
|
||||
ErrorCode,
|
||||
"ACCESS_TOKEN_ALREADY_EXPIRED",
|
||||
"ACCESS_TOKEN_ALREADY_EXPIRED",
|
||||
),
|
||||
)
|
||||
|
||||
next_url = state_data.get("next_url", "/")
|
||||
referral_source = state_data.get("referral_source", None)
|
||||
try:
|
||||
tenant_id = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
|
||||
)(account_email)
|
||||
except exceptions.UserNotExists:
|
||||
tenant_id = None
|
||||
cookie_csrf_token = request.cookies.get(csrf_token_cookie_name)
|
||||
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
|
||||
if (
|
||||
not cookie_csrf_token
|
||||
or not state_csrf_token
|
||||
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=getattr(ErrorCode, "OAUTH_INVALID_STATE", "OAUTH_INVALID_STATE"),
|
||||
)
|
||||
|
||||
request.state.referral_source = referral_source
|
||||
next_url = state_data.get("next_url", "/")
|
||||
referral_source = state_data.get("referral_source", None)
|
||||
try:
|
||||
tenant_id = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
|
||||
)(account_email)
|
||||
except exceptions.UserNotExists:
|
||||
tenant_id = None
|
||||
|
||||
# Proceed to authenticate or create the user
|
||||
try:
|
||||
user = await user_manager.oauth_callback(
|
||||
oauth_client.name,
|
||||
token["access_token"],
|
||||
account_id,
|
||||
account_email,
|
||||
token.get("expires_at"),
|
||||
token.get("refresh_token"),
|
||||
request,
|
||||
associate_by_email=associate_by_email,
|
||||
is_verified_by_default=is_verified_by_default,
|
||||
)
|
||||
except UserAlreadyExists:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
ErrorCode.OAUTH_USER_ALREADY_EXISTS,
|
||||
)
|
||||
request.state.referral_source = referral_source
|
||||
|
||||
if not user.is_active:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
ErrorCode.LOGIN_BAD_CREDENTIALS,
|
||||
)
|
||||
# Proceed to authenticate or create the user
|
||||
try:
|
||||
user = await user_manager.oauth_callback(
|
||||
oauth_client.name,
|
||||
token["access_token"],
|
||||
account_id,
|
||||
account_email,
|
||||
token.get("expires_at"),
|
||||
token.get("refresh_token"),
|
||||
request,
|
||||
associate_by_email=associate_by_email,
|
||||
is_verified_by_default=is_verified_by_default,
|
||||
)
|
||||
except UserAlreadyExists:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.OAUTH_USER_ALREADY_EXISTS,
|
||||
)
|
||||
|
||||
# Login user
|
||||
response = await backend.login(strategy, user)
|
||||
await user_manager.on_after_login(user, request, response)
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
|
||||
)
|
||||
|
||||
# Prepare redirect response
|
||||
if tenant_id is None:
|
||||
# Use URL utility to add parameters
|
||||
redirect_destination = add_url_params(next_url, {"new_team": "true"})
|
||||
redirect_response = RedirectResponse(
|
||||
redirect_destination, status_code=302
|
||||
)
|
||||
# Login user
|
||||
response = await backend.login(strategy, user)
|
||||
await user_manager.on_after_login(user, request, response)
|
||||
|
||||
# Prepare redirect response
|
||||
if tenant_id is None:
|
||||
# Use URL utility to add parameters
|
||||
redirect_url = add_url_params(next_url, {"new_team": "true"})
|
||||
redirect_response = RedirectResponse(redirect_url, status_code=302)
|
||||
else:
|
||||
# No parameters to add
|
||||
redirect_response = RedirectResponse(next_url, status_code=302)
|
||||
|
||||
# Copy headers from auth response to redirect response, with special handling for Set-Cookie
|
||||
for header_name, header_value in response.headers.items():
|
||||
# FastAPI can have multiple Set-Cookie headers as a list
|
||||
if header_name.lower() == "set-cookie" and isinstance(header_value, list):
|
||||
for cookie_value in header_value:
|
||||
redirect_response.headers.append(header_name, cookie_value)
|
||||
else:
|
||||
# No parameters to add
|
||||
redirect_response = RedirectResponse(next_url, status_code=302)
|
||||
|
||||
# Copy headers from auth response to redirect response, with special handling for Set-Cookie
|
||||
for header_name, header_value in response.headers.items():
|
||||
header_name_lower = header_name.lower()
|
||||
if header_name_lower == "set-cookie":
|
||||
redirect_response.headers.append(header_name, header_value)
|
||||
continue
|
||||
if header_name_lower in {"location", "content-length"}:
|
||||
continue
|
||||
redirect_response.headers[header_name] = header_value
|
||||
|
||||
return redirect_response
|
||||
if hasattr(response, "body"):
|
||||
redirect_response.body = response.body
|
||||
if hasattr(response, "status_code"):
|
||||
redirect_response.status_code = response.status_code
|
||||
if hasattr(response, "media_type"):
|
||||
redirect_response.media_type = response.media_type
|
||||
|
||||
if enable_pkce:
|
||||
try:
|
||||
redirect_response = await complete_login_flow(token, state_data)
|
||||
except OnyxError as e:
|
||||
return build_error_response(e)
|
||||
delete_pkce_cookie(redirect_response)
|
||||
return redirect_response
|
||||
|
||||
return await complete_login_flow(token, state_data)
|
||||
return redirect_response
|
||||
|
||||
return router
|
||||
|
||||
@@ -11,9 +11,6 @@
|
||||
# lock after its cleanup which happens at most after its soft timeout.
|
||||
|
||||
# Constants corresponding to migrate_documents_from_vespa_to_opensearch_task.
|
||||
from onyx.configs.app_configs import OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE
|
||||
|
||||
|
||||
MIGRATION_TASK_SOFT_TIME_LIMIT_S = 60 * 5 # 5 minutes.
|
||||
MIGRATION_TASK_TIME_LIMIT_S = 60 * 6 # 6 minutes.
|
||||
# The maximum time the lock can be held for. Will automatically be released
|
||||
@@ -47,7 +44,7 @@ TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15
|
||||
|
||||
# WARNING: Do not change these values without knowing what changes also need to
|
||||
# be made to OpenSearchTenantMigrationRecord.
|
||||
GET_VESPA_CHUNKS_PAGE_SIZE = OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE
|
||||
GET_VESPA_CHUNKS_PAGE_SIZE = 500
|
||||
GET_VESPA_CHUNKS_SLICE_COUNT = 4
|
||||
|
||||
# String used to indicate in the vespa_visit_continuation_token mapping that the
|
||||
|
||||
@@ -196,10 +196,6 @@ if _OIDC_SCOPE_OVERRIDE:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Enables PKCE for OIDC login flow. Disabled by default to preserve
|
||||
# backwards compatibility for existing OIDC deployments.
|
||||
OIDC_PKCE_ENABLED = os.environ.get("OIDC_PKCE_ENABLED", "").lower() == "true"
|
||||
|
||||
# Applicable for SAML Auth
|
||||
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/onyx/configs/saml_config"
|
||||
|
||||
@@ -315,12 +311,6 @@ VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
|
||||
os.environ.get("VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT", "true").lower()
|
||||
== "true"
|
||||
)
|
||||
OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE = int(
|
||||
os.environ.get("OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE") or 500
|
||||
)
|
||||
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = int(
|
||||
os.environ.get("OPENSEARCH_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES") or 0
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
# NOTE: this is used if and only if the vespa config server is accessible via a
|
||||
|
||||
@@ -33,7 +33,6 @@ from office365.runtime.queries.client_query import ClientQuery # type: ignore[i
|
||||
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
|
||||
@@ -259,10 +258,6 @@ class SharepointConnectorCheckpoint(ConnectorCheckpoint):
|
||||
# Track yielded hierarchy nodes by their raw_node_id (URLs) to avoid duplicates
|
||||
seen_hierarchy_node_raw_ids: set[str] = Field(default_factory=set)
|
||||
|
||||
# Track yielded document IDs to avoid processing the same document twice.
|
||||
# The Microsoft Graph delta API can return the same item on multiple pages.
|
||||
seen_document_ids: set[str] = Field(default_factory=set)
|
||||
|
||||
|
||||
class SharepointAuthMethod(Enum):
|
||||
CLIENT_SECRET = "client_secret"
|
||||
@@ -273,15 +268,6 @@ class SizeCapExceeded(Exception):
|
||||
"""Exception raised when the size cap is exceeded."""
|
||||
|
||||
|
||||
def _log_and_raise_for_status(response: requests.Response) -> None:
|
||||
"""Log the response text and raise for status."""
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except Exception:
|
||||
logger.error(f"HTTP request failed: {response.text}")
|
||||
raise
|
||||
|
||||
|
||||
def load_certificate_from_pfx(pfx_data: bytes, password: str) -> CertificateData | None:
|
||||
"""Load certificate from .pfx file for MSAL authentication"""
|
||||
try:
|
||||
@@ -358,7 +344,7 @@ def _probe_remote_size(url: str, timeout: int) -> int | None:
|
||||
"""Determine remote size using HEAD or a range GET probe. Returns None if unknown."""
|
||||
try:
|
||||
head_resp = requests.head(url, timeout=timeout, allow_redirects=True)
|
||||
_log_and_raise_for_status(head_resp)
|
||||
head_resp.raise_for_status()
|
||||
cl = head_resp.headers.get("Content-Length")
|
||||
if cl and cl.isdigit():
|
||||
return int(cl)
|
||||
@@ -373,7 +359,7 @@ def _probe_remote_size(url: str, timeout: int) -> int | None:
|
||||
timeout=timeout,
|
||||
stream=True,
|
||||
) as range_resp:
|
||||
_log_and_raise_for_status(range_resp)
|
||||
range_resp.raise_for_status()
|
||||
cr = range_resp.headers.get("Content-Range") # e.g., "bytes 0-0/12345"
|
||||
if cr and "/" in cr:
|
||||
total = cr.split("/")[-1]
|
||||
@@ -398,7 +384,7 @@ def _download_with_cap(url: str, timeout: int, cap: int) -> bytes:
|
||||
- Returns the full bytes if the content fits within `cap`.
|
||||
"""
|
||||
with requests.get(url, stream=True, timeout=timeout) as resp:
|
||||
_log_and_raise_for_status(resp)
|
||||
resp.raise_for_status()
|
||||
|
||||
# If the server provides Content-Length, prefer an early decision.
|
||||
cl_header = resp.headers.get("Content-Length")
|
||||
@@ -442,7 +428,7 @@ def _download_via_graph_api(
|
||||
with requests.get(
|
||||
url, headers=headers, stream=True, timeout=REQUEST_TIMEOUT_SECONDS
|
||||
) as resp:
|
||||
_log_and_raise_for_status(resp)
|
||||
resp.raise_for_status()
|
||||
buf = io.BytesIO()
|
||||
for chunk in resp.iter_content(64 * 1024):
|
||||
if not chunk:
|
||||
@@ -1259,14 +1245,7 @@ class SharepointConnector(
|
||||
total_yielded = 0
|
||||
|
||||
while page_url:
|
||||
try:
|
||||
data = self._graph_api_get_json(page_url, params)
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 404:
|
||||
logger.warning(f"Site page not found: {page_url}")
|
||||
break
|
||||
raise
|
||||
|
||||
data = self._graph_api_get_json(page_url, params)
|
||||
params = None # nextLink already embeds query params
|
||||
|
||||
for page in data.get("value", []):
|
||||
@@ -1330,7 +1309,7 @@ class SharepointConnector(
|
||||
access_token = self._get_graph_access_token()
|
||||
headers = {"Authorization": f"Bearer {access_token}"}
|
||||
continue
|
||||
_log_and_raise_for_status(response)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except (requests.ConnectionError, requests.Timeout):
|
||||
if attempt < GRAPH_API_MAX_RETRIES:
|
||||
@@ -1578,7 +1557,6 @@ class SharepointConnector(
|
||||
checkpoint.current_drive_id = None
|
||||
checkpoint.current_drive_web_url = None
|
||||
checkpoint.current_drive_delta_next_link = None
|
||||
checkpoint.seen_document_ids.clear()
|
||||
|
||||
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
|
||||
site_descriptors = self.site_descriptors or self.fetch_sites()
|
||||
@@ -2159,14 +2137,6 @@ class SharepointConnector(
|
||||
item_count = 0
|
||||
for driveitem in driveitems:
|
||||
item_count += 1
|
||||
|
||||
if driveitem.id and driveitem.id in checkpoint.seen_document_ids:
|
||||
logger.debug(
|
||||
f"Skipping duplicate document {driveitem.id} "
|
||||
f"({driveitem.name})"
|
||||
)
|
||||
continue
|
||||
|
||||
driveitem_extension = get_file_ext(driveitem.name)
|
||||
if driveitem_extension not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
|
||||
logger.warning(
|
||||
@@ -2219,13 +2189,11 @@ class SharepointConnector(
|
||||
|
||||
if isinstance(doc_or_failure, Document):
|
||||
if doc_or_failure.sections:
|
||||
checkpoint.seen_document_ids.add(doc_or_failure.id)
|
||||
yield doc_or_failure
|
||||
elif should_yield_if_empty:
|
||||
doc_or_failure.sections = [
|
||||
TextSection(link=driveitem.web_url, text="")
|
||||
]
|
||||
checkpoint.seen_document_ids.add(doc_or_failure.id)
|
||||
yield doc_or_failure
|
||||
else:
|
||||
logger.warning(
|
||||
|
||||
@@ -25,7 +25,6 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
@@ -370,9 +369,9 @@ def upsert_llm_provider(
|
||||
def sync_model_configurations(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
models: list[SyncModelEntry],
|
||||
models: list[dict],
|
||||
) -> int:
|
||||
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama, etc.).
|
||||
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama).
|
||||
|
||||
This inserts NEW models from the source API without overwriting existing ones.
|
||||
User preferences (is_visible, max_input_tokens) are preserved for existing models.
|
||||
@@ -380,7 +379,7 @@ def sync_model_configurations(
|
||||
Args:
|
||||
db_session: Database session
|
||||
provider_name: Name of the LLM provider
|
||||
models: List of SyncModelEntry objects describing the fetched models
|
||||
models: List of model dicts with keys: name, display_name, max_input_tokens, supports_image_input
|
||||
|
||||
Returns:
|
||||
Number of new models added
|
||||
@@ -394,20 +393,21 @@ def sync_model_configurations(
|
||||
|
||||
new_count = 0
|
||||
for model in models:
|
||||
if model.name not in existing_names:
|
||||
model_name = model["name"]
|
||||
if model_name not in existing_names:
|
||||
# Insert new model with is_visible=False (user must explicitly enable)
|
||||
supported_flows = [LLMModelFlowType.CHAT]
|
||||
if model.supports_image_input:
|
||||
if model.get("supports_image_input", False):
|
||||
supported_flows.append(LLMModelFlowType.VISION)
|
||||
|
||||
insert_new_model_configuration__no_commit(
|
||||
db_session=db_session,
|
||||
llm_provider_id=provider.id,
|
||||
model_name=model.name,
|
||||
model_name=model_name,
|
||||
supported_flows=supported_flows,
|
||||
is_visible=False,
|
||||
max_input_tokens=model.max_input_tokens,
|
||||
display_name=model.display_name,
|
||||
max_input_tokens=model.get("max_input_tokens"),
|
||||
display_name=model.get("display_name"),
|
||||
)
|
||||
new_count += 1
|
||||
|
||||
|
||||
@@ -163,8 +163,6 @@ class _EncryptedBase(TypeDecorator):
|
||||
|
||||
|
||||
class EncryptedString(_EncryptedBase):
|
||||
# Must redeclare cache_ok in this child class since we explicitly redeclare _is_json
|
||||
cache_ok = True
|
||||
_is_json: bool = False
|
||||
|
||||
def process_bind_param(
|
||||
@@ -191,7 +189,6 @@ class EncryptedString(_EncryptedBase):
|
||||
|
||||
|
||||
class EncryptedJson(_EncryptedBase):
|
||||
cache_ok = True
|
||||
_is_json: bool = True
|
||||
|
||||
def process_bind_param(
|
||||
@@ -339,25 +336,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
TIMESTAMPAware(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
default_model: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
# organized in typical structured fashion
|
||||
# formatted as `displayName__provider__modelName`
|
||||
|
||||
# Voice preferences
|
||||
voice_auto_send: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
voice_auto_playback: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
voice_playback_speed: Mapped[float] = mapped_column(Float, default=1.0)
|
||||
|
||||
# relationships
|
||||
credentials: Mapped[list["Credential"]] = relationship(
|
||||
"Credential", back_populates="user"
|
||||
@@ -3070,65 +3052,6 @@ class ImageGenerationConfig(Base):
|
||||
)
|
||||
|
||||
|
||||
class VoiceProvider(Base):
|
||||
"""Configuration for voice services (STT and TTS)."""
|
||||
|
||||
__tablename__ = "voice_provider"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
provider_type: Mapped[str] = mapped_column(
|
||||
String
|
||||
) # "openai", "azure", "elevenlabs"
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
api_base: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
custom_config: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
|
||||
# Model/voice configuration
|
||||
stt_model: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
) # e.g., "whisper-1"
|
||||
tts_model: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
) # e.g., "tts-1", "tts-1-hd"
|
||||
default_voice: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
) # e.g., "alloy", "echo"
|
||||
|
||||
# STT and TTS can use different providers - only one provider per type
|
||||
is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
# Enforce only one default STT provider and one default TTS provider at DB level
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_voice_provider_one_default_stt",
|
||||
"is_default_stt",
|
||||
unique=True,
|
||||
postgresql_where=(is_default_stt == True), # noqa: E712
|
||||
),
|
||||
Index(
|
||||
"ix_voice_provider_one_default_tts",
|
||||
"is_default_tts",
|
||||
unique=True,
|
||||
postgresql_where=(is_default_tts == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(Base):
|
||||
__tablename__ = "embedding_provider"
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import case
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@@ -12,7 +11,6 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlalchemy.sql.elements import KeyedColumnElement
|
||||
from sqlalchemy.sql.expression import or_
|
||||
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
@@ -26,7 +24,6 @@ from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
|
||||
@@ -165,13 +162,7 @@ def _get_accepted_user_where_clause(
|
||||
where_clause.append(User.role != UserRole.EXT_PERM_USER)
|
||||
|
||||
if email_filter_string is not None:
|
||||
personal_name_col: KeyedColumnElement[Any] = User.__table__.c.personal_name
|
||||
where_clause.append(
|
||||
or_(
|
||||
email_col.ilike(f"%{email_filter_string}%"),
|
||||
personal_name_col.ilike(f"%{email_filter_string}%"),
|
||||
)
|
||||
)
|
||||
where_clause.append(email_col.ilike(f"%{email_filter_string}%"))
|
||||
|
||||
if roles_filter:
|
||||
where_clause.append(User.role.in_(roles_filter))
|
||||
@@ -182,21 +173,6 @@ def _get_accepted_user_where_clause(
|
||||
return where_clause
|
||||
|
||||
|
||||
def get_all_accepted_users(
|
||||
db_session: Session,
|
||||
include_external: bool = False,
|
||||
) -> Sequence[User]:
|
||||
"""Returns all accepted users without pagination.
|
||||
Uses the same filtering as the paginated endpoint but without
|
||||
search, role, or active filters."""
|
||||
stmt = select(User)
|
||||
where_clause = _get_accepted_user_where_clause(
|
||||
include_external=include_external,
|
||||
)
|
||||
stmt = stmt.where(*where_clause).order_by(User.email)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def get_page_of_filtered_users(
|
||||
db_session: Session,
|
||||
page_size: int,
|
||||
@@ -242,41 +218,6 @@ def get_total_filtered_users_count(
|
||||
return db_session.scalar(total_count_stmt) or 0
|
||||
|
||||
|
||||
def get_user_counts_by_role_and_status(
|
||||
db_session: Session,
|
||||
) -> dict[str, dict[str, int]]:
|
||||
"""Returns user counts grouped by role and by active/inactive status.
|
||||
|
||||
Excludes API key users, anonymous users, and no-auth placeholder users.
|
||||
Uses a single query with conditional aggregation.
|
||||
"""
|
||||
base_where = _get_accepted_user_where_clause()
|
||||
role_col = User.__table__.c.role
|
||||
is_active_col = User.__table__.c.is_active
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
role_col,
|
||||
func.count().label("total"),
|
||||
func.sum(case((is_active_col.is_(True), 1), else_=0)).label("active"),
|
||||
func.sum(case((is_active_col.is_(False), 1), else_=0)).label("inactive"),
|
||||
)
|
||||
.where(*base_where)
|
||||
.group_by(role_col)
|
||||
)
|
||||
|
||||
role_counts: dict[str, int] = {}
|
||||
status_counts: dict[str, int] = {"active": 0, "inactive": 0}
|
||||
|
||||
for role_val, total, active, inactive in db_session.execute(stmt).all():
|
||||
key = role_val.value if hasattr(role_val, "value") else str(role_val)
|
||||
role_counts[key] = total
|
||||
status_counts["active"] += active or 0
|
||||
status_counts["inactive"] += inactive or 0
|
||||
|
||||
return {"role_counts": role_counts, "status_counts": status_counts}
|
||||
|
||||
|
||||
def get_user_by_email(email: str, db_session: Session) -> User | None:
|
||||
user = (
|
||||
db_session.query(User)
|
||||
@@ -353,23 +294,24 @@ def batch_add_ext_perm_user_if_not_exists(
|
||||
lower_emails = [email.lower() for email in emails]
|
||||
found_users, missing_lower_emails = _get_users_by_emails(db_session, lower_emails)
|
||||
|
||||
# Use savepoints (begin_nested) so that a failed insert only rolls back
|
||||
# that single user, not the entire transaction. A plain rollback() would
|
||||
# discard all previously flushed users in the same transaction.
|
||||
# We also avoid add_all() because SQLAlchemy 2.0's insertmanyvalues
|
||||
# batch path hits a UUID sentinel mismatch with server_default columns.
|
||||
new_users: list[User] = []
|
||||
for email in missing_lower_emails:
|
||||
user = _generate_ext_permissioned_user(email=email)
|
||||
savepoint = db_session.begin_nested()
|
||||
try:
|
||||
db_session.add(user)
|
||||
savepoint.commit()
|
||||
except IntegrityError:
|
||||
savepoint.rollback()
|
||||
if not continue_on_error:
|
||||
raise
|
||||
new_users.append(_generate_ext_permissioned_user(email=email))
|
||||
|
||||
db_session.commit()
|
||||
try:
|
||||
db_session.add_all(new_users)
|
||||
db_session.commit()
|
||||
except IntegrityError:
|
||||
db_session.rollback()
|
||||
if not continue_on_error:
|
||||
raise
|
||||
for user in new_users:
|
||||
try:
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
except IntegrityError:
|
||||
db_session.rollback()
|
||||
continue
|
||||
# Fetch all users again to ensure we have the most up-to-date list
|
||||
all_users, _ = _get_users_by_emails(db_session, lower_emails)
|
||||
return all_users
|
||||
@@ -416,28 +358,3 @@ def delete_user_from_db(
|
||||
# NOTE: edge case may exist with race conditions
|
||||
# with this `invited user` scheme generally.
|
||||
remove_user_from_invited_users(user_to_delete.email)
|
||||
|
||||
|
||||
def batch_get_user_groups(
|
||||
db_session: Session,
|
||||
user_ids: list[UUID],
|
||||
) -> dict[UUID, list[tuple[int, str]]]:
|
||||
"""Fetch group memberships for a batch of users in a single query.
|
||||
Returns a mapping of user_id -> list of (group_id, group_name) tuples."""
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
rows = db_session.execute(
|
||||
select(
|
||||
User__UserGroup.user_id,
|
||||
UserGroup.id,
|
||||
UserGroup.name,
|
||||
)
|
||||
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
|
||||
.where(User__UserGroup.user_id.in_(user_ids))
|
||||
).all()
|
||||
|
||||
result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids}
|
||||
for user_id, group_id, group_name in rows:
|
||||
result[user_id].append((group_id, group_name))
|
||||
return result
|
||||
|
||||
@@ -1,248 +0,0 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
MIN_VOICE_PLAYBACK_SPEED = 0.5
|
||||
MAX_VOICE_PLAYBACK_SPEED = 2.0
|
||||
|
||||
|
||||
def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]:
|
||||
"""Fetch all voice providers."""
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
.order_by(VoiceProvider.name)
|
||||
).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_id(
|
||||
db_session: Session, provider_id: int, include_deleted: bool = False
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by ID."""
|
||||
stmt = select(VoiceProvider).where(VoiceProvider.id == provider_id)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(VoiceProvider.deleted.is_(False))
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default STT provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_stt.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None:
|
||||
"""Fetch the default TTS provider."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.is_default_tts.is_(True))
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
)
|
||||
|
||||
|
||||
def fetch_voice_provider_by_type(
|
||||
db_session: Session, provider_type: str
|
||||
) -> VoiceProvider | None:
|
||||
"""Fetch a voice provider by type."""
|
||||
return db_session.scalar(
|
||||
select(VoiceProvider)
|
||||
.where(VoiceProvider.provider_type == provider_type)
|
||||
.where(VoiceProvider.deleted.is_(False))
|
||||
)
|
||||
|
||||
|
||||
def upsert_voice_provider(
|
||||
*,
|
||||
db_session: Session,
|
||||
provider_id: int | None,
|
||||
name: str,
|
||||
provider_type: str,
|
||||
api_key: str | None,
|
||||
api_key_changed: bool,
|
||||
api_base: str | None = None,
|
||||
custom_config: dict[str, Any] | None = None,
|
||||
stt_model: str | None = None,
|
||||
tts_model: str | None = None,
|
||||
default_voice: str | None = None,
|
||||
activate_stt: bool = False,
|
||||
activate_tts: bool = False,
|
||||
) -> VoiceProvider:
|
||||
"""Create or update a voice provider."""
|
||||
provider: VoiceProvider | None = None
|
||||
|
||||
if provider_id is not None:
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"No voice provider with id {provider_id} exists.",
|
||||
)
|
||||
else:
|
||||
provider = VoiceProvider()
|
||||
db_session.add(provider)
|
||||
|
||||
# Apply updates
|
||||
provider.name = name
|
||||
provider.provider_type = provider_type
|
||||
provider.api_base = api_base
|
||||
provider.custom_config = custom_config
|
||||
provider.stt_model = stt_model
|
||||
provider.tts_model = tts_model
|
||||
provider.default_voice = default_voice
|
||||
|
||||
# Only update API key if explicitly changed or if provider has no key
|
||||
if api_key_changed or provider.api_key is None:
|
||||
provider.api_key = api_key # type: ignore[assignment]
|
||||
|
||||
db_session.flush()
|
||||
|
||||
if activate_stt:
|
||||
set_default_stt_provider(db_session=db_session, provider_id=provider.id)
|
||||
if activate_tts:
|
||||
set_default_tts_provider(db_session=db_session, provider_id=provider.id)
|
||||
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
def delete_voice_provider(db_session: Session, provider_id: int) -> None:
|
||||
"""Soft-delete a voice provider by ID."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider:
|
||||
provider.deleted = True
|
||||
db_session.flush()
|
||||
|
||||
|
||||
def set_default_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
|
||||
"""Set a voice provider as the default STT provider."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"No voice provider with id {provider_id} exists.",
|
||||
)
|
||||
|
||||
# Deactivate all other STT providers
|
||||
db_session.execute(
|
||||
update(VoiceProvider)
|
||||
.where(
|
||||
VoiceProvider.is_default_stt.is_(True),
|
||||
VoiceProvider.id != provider_id,
|
||||
)
|
||||
.values(is_default_stt=False)
|
||||
)
|
||||
|
||||
# Activate this provider
|
||||
provider.is_default_stt = True
|
||||
|
||||
db_session.flush()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
def set_default_tts_provider(
|
||||
*, db_session: Session, provider_id: int, tts_model: str | None = None
|
||||
) -> VoiceProvider:
|
||||
"""Set a voice provider as the default TTS provider."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"No voice provider with id {provider_id} exists.",
|
||||
)
|
||||
|
||||
# Deactivate all other TTS providers
|
||||
db_session.execute(
|
||||
update(VoiceProvider)
|
||||
.where(
|
||||
VoiceProvider.is_default_tts.is_(True),
|
||||
VoiceProvider.id != provider_id,
|
||||
)
|
||||
.values(is_default_tts=False)
|
||||
)
|
||||
|
||||
# Activate this provider
|
||||
provider.is_default_tts = True
|
||||
|
||||
# Update the TTS model if specified
|
||||
if tts_model is not None:
|
||||
provider.tts_model = tts_model
|
||||
|
||||
db_session.flush()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
def deactivate_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
|
||||
"""Remove the default STT status from a voice provider."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"No voice provider with id {provider_id} exists.",
|
||||
)
|
||||
|
||||
provider.is_default_stt = False
|
||||
|
||||
db_session.flush()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
def deactivate_tts_provider(*, db_session: Session, provider_id: int) -> VoiceProvider:
|
||||
"""Remove the default TTS status from a voice provider."""
|
||||
provider = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"No voice provider with id {provider_id} exists.",
|
||||
)
|
||||
|
||||
provider.is_default_tts = False
|
||||
|
||||
db_session.flush()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
# User voice preferences
|
||||
|
||||
|
||||
def update_user_voice_settings(
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
auto_send: bool | None = None,
|
||||
auto_playback: bool | None = None,
|
||||
playback_speed: float | None = None,
|
||||
) -> None:
|
||||
"""Update user's voice settings.
|
||||
|
||||
For all fields, None means "don't update this field".
|
||||
"""
|
||||
values: dict[str, bool | float] = {}
|
||||
|
||||
if auto_send is not None:
|
||||
values["voice_auto_send"] = auto_send
|
||||
if auto_playback is not None:
|
||||
values["voice_auto_playback"] = auto_playback
|
||||
if playback_speed is not None:
|
||||
values["voice_playback_speed"] = max(
|
||||
MIN_VOICE_PLAYBACK_SPEED, min(MAX_VOICE_PLAYBACK_SPEED, playback_speed)
|
||||
)
|
||||
|
||||
if values:
|
||||
db_session.execute(update(User).where(User.id == user_id).values(**values)) # type: ignore[arg-type]
|
||||
db_session.flush()
|
||||
@@ -1,10 +1,5 @@
|
||||
# Default value for the maximum number of tokens a chunk can hold, if none is
|
||||
# specified when creating an index.
|
||||
from onyx.configs.app_configs import (
|
||||
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_MAX_CHUNK_SIZE = 512
|
||||
|
||||
# Size of the dynamic list used to consider elements during kNN graph creation.
|
||||
@@ -15,43 +10,27 @@ EF_CONSTRUCTION = 256
|
||||
# quality but increase memory footprint. Values typically range between 12 - 48.
|
||||
M = 32 # Set relatively high for better accuracy.
|
||||
|
||||
# When performing hybrid search, we need to consider more candidates than the
|
||||
# number of results to be returned. This is because the scoring is hybrid and
|
||||
# the results are reordered due to the hybrid scoring. Higher = more candidates
|
||||
# for hybrid fusion = better retrieval accuracy, but results in more computation
|
||||
# per query. Imagine a simple case with a single keyword query and a single
|
||||
# vector query and we want 10 final docs. If we only fetch 10 candidates from
|
||||
# each of keyword and vector, they would have to have perfect overlap to get a
|
||||
# good hybrid ranking for the 10 results. If we fetch 1000 candidates from each,
|
||||
# we have a much higher chance of all 10 of the final desired docs showing up
|
||||
# and getting scored. In worse situations, the final 10 docs don't even show up
|
||||
# as the final 10 (worse than just a miss at the reranking step).
|
||||
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = (
|
||||
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
|
||||
if OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES > 0
|
||||
else 750
|
||||
)
|
||||
# When performing hybrid search, we need to consider more candidates than the number of results to be returned.
|
||||
# This is because the scoring is hybrid and the results are reordered due to the hybrid scoring.
|
||||
# Higher = more candidates for hybrid fusion = better retrieval accuracy, but results in more computation per query.
|
||||
# Imagine a simple case with a single keyword query and a single vector query and we want 10 final docs.
|
||||
# If we only fetch 10 candidates from each of keyword and vector, they would have to have perfect overlap to get a good hybrid
|
||||
# ranking for the 10 results. If we fetch 1000 candidates from each, we have a much higher chance of all 10 of the final desired
|
||||
# docs showing up and getting scored. In worse situations, the final 10 docs don't even show up as the final 10 (worse than just
|
||||
# a miss at the reranking step).
|
||||
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = 750
|
||||
|
||||
# Number of vectors to examine to decide the top k neighbors for the HNSW
|
||||
# method.
|
||||
# NOTE: "When creating a search query, you must specify k. If you provide both k
|
||||
# and ef_search, then the larger value is passed to the engine. If ef_search is
|
||||
# larger than k, you can provide the size parameter to limit the final number of
|
||||
# results to k." from
|
||||
# https://docs.opensearch.org/latest/query-dsl/specialized/k-nn/index/#ef_search
|
||||
# Number of vectors to examine for top k neighbors for the HNSW method.
|
||||
EF_SEARCH = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
|
||||
|
||||
# Since the titles are included in the contents, the embedding matches are
|
||||
# heavily downweighted as they act as a boost rather than an independent scoring
|
||||
# component.
|
||||
# Since the titles are included in the contents, they are heavily downweighted as they act as a boost
|
||||
# rather than an independent scoring component.
|
||||
SEARCH_TITLE_VECTOR_WEIGHT = 0.1
|
||||
SEARCH_CONTENT_VECTOR_WEIGHT = 0.45
|
||||
# Single keyword weight for both title and content (merged from former title
|
||||
# keyword + content keyword).
|
||||
# Single keyword weight for both title and content (merged from former title keyword + content keyword).
|
||||
SEARCH_KEYWORD_WEIGHT = 0.45
|
||||
|
||||
# NOTE: It is critical that the order of these weights matches the order of the
|
||||
# sub-queries in the hybrid search.
|
||||
# NOTE: it is critical that the order of these weights matches the order of the sub-queries in the hybrid search.
|
||||
HYBRID_SEARCH_NORMALIZATION_WEIGHTS = [
|
||||
SEARCH_TITLE_VECTOR_WEIGHT,
|
||||
SEARCH_CONTENT_VECTOR_WEIGHT,
|
||||
|
||||
@@ -433,16 +433,12 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
hidden=fields.hidden if fields else None,
|
||||
project_ids=(
|
||||
set(user_fields.user_projects)
|
||||
# NOTE: Empty user_projects is semantically different from None
|
||||
# user_projects.
|
||||
if user_fields and user_fields.user_projects is not None
|
||||
if user_fields and user_fields.user_projects
|
||||
else None
|
||||
),
|
||||
persona_ids=(
|
||||
set(user_fields.personas)
|
||||
# NOTE: Empty personas is semantically different from None
|
||||
# personas.
|
||||
if user_fields and user_fields.personas is not None
|
||||
if user_fields and user_fields.personas
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
@@ -255,12 +255,8 @@ class DocumentQuery:
|
||||
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
|
||||
)
|
||||
|
||||
# TODO(andrei, yuhong): We can tune this more dynamically based on
|
||||
# num_hits.
|
||||
max_results_per_subquery = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
|
||||
|
||||
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
|
||||
query_text, query_vector, vector_candidates=max_results_per_subquery
|
||||
query_text, query_vector
|
||||
)
|
||||
hybrid_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
@@ -289,16 +285,13 @@ class DocumentQuery:
|
||||
hybrid_search_query: dict[str, Any] = {
|
||||
"hybrid": {
|
||||
"queries": hybrid_search_subqueries,
|
||||
# Max results per subquery per shard before aggregation. Ensures
|
||||
# keyword and vector subqueries contribute equally to the
|
||||
# candidate pool for hybrid fusion.
|
||||
# Max results per subquery per shard before aggregation. Ensures keyword and vector
|
||||
# subqueries contribute equally to the candidate pool for hybrid fusion.
|
||||
# Sources:
|
||||
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
|
||||
# https://opensearch.org/blog/navigating-pagination-in-hybrid-queries-with-the-pagination_depth-parameter/
|
||||
"pagination_depth": max_results_per_subquery,
|
||||
# Applied to all the sub-queries independently (this avoids
|
||||
# subqueries having a lot of results thrown out during
|
||||
# aggregation).
|
||||
"pagination_depth": DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
# Applied to all the sub-queries independently (this avoids having subqueries having a lot of results thrown out).
|
||||
# Sources:
|
||||
# https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
# https://opensearch.org/blog/introducing-common-filter-support-for-hybrid-search-queries
|
||||
@@ -381,10 +374,9 @@ class DocumentQuery:
|
||||
def _get_hybrid_search_subqueries(
|
||||
query_text: str,
|
||||
query_vector: list[float],
|
||||
# The default number of neighbors to consider for knn vector similarity
|
||||
# search. This is higher than the number of results because the scoring
|
||||
# is hybrid. For a detailed breakdown, see where the default value is
|
||||
# set.
|
||||
# The default number of neighbors to consider for knn vector similarity search.
|
||||
# This is higher than the number of results because the scoring is hybrid.
|
||||
# for a detailed breakdown, see where the default value is set.
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Returns subqueries for hybrid search.
|
||||
@@ -408,27 +400,20 @@ class DocumentQuery:
|
||||
in a single hybrid query. Source:
|
||||
https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
|
||||
NOTE: Each query is independent during the search phase, there is no
|
||||
backfilling of scores for missing query components. What this means is
|
||||
that if a document was a good vector match but did not show up for
|
||||
keyword, it gets a score of 0 for the keyword component of the hybrid
|
||||
scoring. This is not as bad as just disregarding a score though as there
|
||||
is normalization applied after. So really it is "increasing" the missing
|
||||
score compared to if it was included and the range was renormalized.
|
||||
This does however mean that between docs that have high scores for say
|
||||
the vector field, the keyword scores between them are completely ignored
|
||||
unless they also showed up in the keyword query as a reasonably high
|
||||
match. TLDR, this is a bit of unique funky behavior but it seems ok.
|
||||
NOTE: Each query is independent during the search phase, there is no backfilling of scores for missing query components.
|
||||
What this means is that if a document was a good vector match but did not show up for keyword, it gets a score of 0 for
|
||||
the keyword component of the hybrid scoring. This is not as bad as just disregarding a score though as there is
|
||||
normalization applied after. So really it is "increasing" the missing score compared to if it was included and the range
|
||||
was renormalized. This does however mean that between docs that have high scores for say the vector field, the keyword
|
||||
scores between them are completely ignored unless they also showed up in the keyword query as a reasonably high match.
|
||||
TLDR, this is a bit of unique funky behavior but it seems ok.
|
||||
|
||||
NOTE: Options considered and rejected:
|
||||
- minimum_should_match: Since it's hybrid search and users often provide
|
||||
semantic queries, there is often a lot of terms, and very low number
|
||||
of meaningful keywords (and a low ratio of keywords).
|
||||
- fuzziness AUTO: Typo tolerance (0/1/2 edit distance by term length).
|
||||
It's mostly for typos as the analyzer ("english" by default) already
|
||||
does some stemming and tokenization. In testing datasets, this makes
|
||||
recall slightly worse. It also is less performant so not really any
|
||||
reason to do it.
|
||||
- minimum_should_match: Since it's hybrid search and users often provide semantic queries, there is often a lot of terms,
|
||||
and very low number of meaningful keywords (and a low ratio of keywords).
|
||||
- fuzziness AUTO: typo tolerance (0/1/2 edit distance by term length). It's mostly for typos as the analyzer ("english by
|
||||
default") already does some stemming and tokenization. In testing datasets, this makes recall slightly worse. It also is
|
||||
less performant so not really any reason to do it.
|
||||
|
||||
Args:
|
||||
query_text: The text of the query to search for.
|
||||
@@ -738,13 +723,14 @@ class DocumentQuery:
|
||||
# document's metadata list.
|
||||
filter_clauses.append(_get_tag_filter(tags))
|
||||
|
||||
# Knowledge scope: explicit knowledge attachments restrict what an
|
||||
# assistant can see. When none are set the assistant searches
|
||||
# everything.
|
||||
# Knowledge scope: explicit knowledge attachments restrict what
|
||||
# an assistant can see. When none are set the assistant
|
||||
# searches everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing user files
|
||||
# findable but must NOT trigger the restriction on their own (an agent
|
||||
# with no explicit knowledge should search everything).
|
||||
# project_id / persona_id are additive: they make overflowing
|
||||
# user files findable but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
has_knowledge_scope = (
|
||||
attached_document_ids
|
||||
or hierarchy_node_ids
|
||||
@@ -772,8 +758,9 @@ class DocumentQuery:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
# Additive: widen scope to also cover overflowing user files, but
|
||||
# only when an explicit restriction is already in effect.
|
||||
# Additive: widen scope to also cover overflowing user
|
||||
# files, but only when an explicit restriction is already
|
||||
# in effect.
|
||||
if project_id is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_project_filter(project_id)
|
||||
|
||||
@@ -690,12 +690,9 @@ class VespaIndex(DocumentIndex):
|
||||
)
|
||||
|
||||
project_ids: set[int] | None = None
|
||||
# NOTE: Empty user_projects is semantically different from None
|
||||
# user_projects.
|
||||
if user_fields is not None and user_fields.user_projects is not None:
|
||||
project_ids = set(user_fields.user_projects)
|
||||
persona_ids: set[int] | None = None
|
||||
# NOTE: Empty personas is semantically different from None personas.
|
||||
if user_fields is not None and user_fields.personas is not None:
|
||||
persona_ids = set(user_fields.personas)
|
||||
update_request = MetadataUpdateRequest(
|
||||
|
||||
@@ -66,11 +66,6 @@ class OnyxErrorCode(Enum):
|
||||
RATE_LIMITED = ("RATE_LIMITED", 429)
|
||||
SEAT_LIMIT_EXCEEDED = ("SEAT_LIMIT_EXCEEDED", 402)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Payload (413)
|
||||
# ------------------------------------------------------------------
|
||||
PAYLOAD_TOO_LARGE = ("PAYLOAD_TOO_LARGE", 413)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Connector / Credential Errors (400-range)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -59,22 +59,6 @@ class OnyxError(Exception):
|
||||
return self._status_code_override or self.error_code.status_code
|
||||
|
||||
|
||||
def log_onyx_error(exc: OnyxError) -> None:
|
||||
detail = exc.detail
|
||||
status_code = exc.status_code
|
||||
if status_code >= 500:
|
||||
logger.error(f"OnyxError {exc.error_code.code}: {detail}")
|
||||
elif status_code >= 400:
|
||||
logger.warning(f"OnyxError {exc.error_code.code}: {detail}")
|
||||
|
||||
|
||||
def onyx_error_to_json_response(exc: OnyxError) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=exc.error_code.detail(exc.detail),
|
||||
)
|
||||
|
||||
|
||||
def register_onyx_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register a global handler that converts ``OnyxError`` to JSON responses.
|
||||
|
||||
@@ -87,5 +71,13 @@ def register_onyx_exception_handlers(app: FastAPI) -> None:
|
||||
request: Request, # noqa: ARG001
|
||||
exc: OnyxError,
|
||||
) -> JSONResponse:
|
||||
log_onyx_error(exc)
|
||||
return onyx_error_to_json_response(exc)
|
||||
status_code = exc.status_code
|
||||
if status_code >= 500:
|
||||
logger.error(f"OnyxError {exc.error_code.code}: {exc.detail}")
|
||||
elif status_code >= 400:
|
||||
logger.warning(f"OnyxError {exc.error_code.code}: {exc.detail}")
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status_code,
|
||||
content=exc.error_code.detail(exc.detail),
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import csv
|
||||
import gc
|
||||
import io
|
||||
import json
|
||||
@@ -19,6 +20,7 @@ from zipfile import BadZipFile
|
||||
|
||||
import chardet
|
||||
import openpyxl
|
||||
from openpyxl.worksheet.worksheet import Worksheet
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.constants import ONYX_METADATA_FILENAME
|
||||
@@ -352,6 +354,65 @@ def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
return presentation.markdown
|
||||
|
||||
|
||||
def _worksheet_to_matrix(
|
||||
worksheet: Worksheet,
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Converts a singular worksheet to a matrix of values
|
||||
"""
|
||||
rows: list[list[str]] = []
|
||||
for worksheet_row in worksheet.iter_rows(min_row=1, values_only=True):
|
||||
row = ["" if cell is None else str(cell) for cell in worksheet_row]
|
||||
rows.append(row)
|
||||
|
||||
return rows
|
||||
|
||||
|
||||
def _clean_worksheet_matrix(matrix: list[list[str]]) -> list[list[str]]:
|
||||
"""
|
||||
Cleans a worksheet matrix by removing rows if there are N consecutive empty
|
||||
rows and removing cols if there are M consecutive empty columns
|
||||
"""
|
||||
MAX_EMPTY_ROWS = 2 # Runs longer than this are capped to max_empty; shorter runs are preserved as-is
|
||||
MAX_EMPTY_COLS = 2
|
||||
|
||||
# Row cleanup
|
||||
matrix = _remove_empty_runs(matrix, max_empty=MAX_EMPTY_ROWS)
|
||||
|
||||
# Column cleanup (transpose, clean, transpose back)
|
||||
transposed = list(map(list, zip(*matrix))) if matrix else []
|
||||
transposed = _remove_empty_runs(transposed, max_empty=MAX_EMPTY_COLS)
|
||||
matrix = list(map(list, zip(*transposed))) if transposed else []
|
||||
|
||||
return matrix
|
||||
|
||||
|
||||
def _remove_empty_runs(
|
||||
rows: list[list[str]],
|
||||
max_empty: int,
|
||||
) -> list[list[str]]:
|
||||
"""Removes entire runs of empty rows when the run length exceeds max_empty.
|
||||
|
||||
Leading and trailing empty rows are always dropped regardless of run length,
|
||||
since there is no adjacent non-empty row to bound the run.
|
||||
"""
|
||||
result: list[list[str]] = []
|
||||
empty_buffer: list[list[str]] = []
|
||||
|
||||
for row in rows:
|
||||
# Check if empty
|
||||
if not any(row):
|
||||
empty_buffer.append(row)
|
||||
else:
|
||||
# Add upto max empty rows onto the result - that's what we allow
|
||||
result.extend(empty_buffer[:max_empty])
|
||||
# Add the new non-empty row
|
||||
result.append(row)
|
||||
empty_buffer = []
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
# TODO: switch back to this approach in a few months when markitdown
|
||||
# fixes their handling of excel files
|
||||
@@ -390,30 +451,15 @@ def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
|
||||
)
|
||||
return ""
|
||||
raise e
|
||||
raise
|
||||
|
||||
text_content = []
|
||||
for sheet in workbook.worksheets:
|
||||
rows = []
|
||||
num_empty_consecutive_rows = 0
|
||||
for row in sheet.iter_rows(min_row=1, values_only=True):
|
||||
row_str = ",".join(str(cell or "") for cell in row)
|
||||
|
||||
# Only add the row if there are any values in the cells
|
||||
if len(row_str) >= len(row):
|
||||
rows.append(row_str)
|
||||
num_empty_consecutive_rows = 0
|
||||
else:
|
||||
num_empty_consecutive_rows += 1
|
||||
|
||||
if num_empty_consecutive_rows > 100:
|
||||
# handle massive excel sheets with mostly empty cells
|
||||
logger.warning(
|
||||
f"Found {num_empty_consecutive_rows} empty rows in {file_name}, skipping rest of file"
|
||||
)
|
||||
break
|
||||
sheet_str = "\n".join(rows)
|
||||
text_content.append(sheet_str)
|
||||
sheet_matrix = _clean_worksheet_matrix(_worksheet_to_matrix(sheet))
|
||||
buf = io.StringIO()
|
||||
writer = csv.writer(buf, lineterminator="\n")
|
||||
writer.writerows(sheet_matrix)
|
||||
text_content.append(buf.getvalue().rstrip("\n"))
|
||||
return TEXT_SECTION_SEPARATOR.join(text_content)
|
||||
|
||||
|
||||
|
||||
@@ -123,11 +123,15 @@ class DocumentIndexingBatchAdapter:
|
||||
}
|
||||
|
||||
doc_id_to_new_chunk_cnt: dict[str, int] = {
|
||||
doc_id: 0 for doc_id in updatable_ids
|
||||
document_id: len(
|
||||
[
|
||||
chunk
|
||||
for chunk in chunks_with_embeddings
|
||||
if chunk.source_document.id == document_id
|
||||
]
|
||||
)
|
||||
for document_id in updatable_ids
|
||||
}
|
||||
for chunk in chunks_with_embeddings:
|
||||
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
|
||||
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
|
||||
|
||||
# Get ancestor hierarchy node IDs for each document
|
||||
doc_id_to_ancestor_ids = self._get_ancestor_ids_for_documents(
|
||||
|
||||
@@ -16,7 +16,6 @@ from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.pydantic_util import shallow_model_dump
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
@@ -211,8 +210,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
)[0]
|
||||
title_embed_dict[title] = title_embedding
|
||||
|
||||
new_embedded_chunk = IndexChunk.model_construct(
|
||||
**shallow_model_dump(chunk),
|
||||
new_embedded_chunk = IndexChunk(
|
||||
**chunk.model_dump(),
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=chunk_embeddings[0],
|
||||
mini_chunk_embeddings=chunk_embeddings[1:],
|
||||
|
||||
@@ -12,7 +12,6 @@ from onyx.connectors.models import Document
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.enums import SwitchoverType
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.pydantic_util import shallow_model_dump
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
@@ -134,8 +133,9 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
tenant_id: str,
|
||||
ancestor_hierarchy_node_ids: list[int] | None = None,
|
||||
) -> "DocMetadataAwareIndexChunk":
|
||||
return cls.model_construct(
|
||||
**shallow_model_dump(index_chunk),
|
||||
index_chunk_data = index_chunk.model_dump()
|
||||
return cls(
|
||||
**index_chunk_data,
|
||||
access=access,
|
||||
document_sets=document_sets,
|
||||
user_project=user_project,
|
||||
|
||||
@@ -43,7 +43,6 @@ WELL_KNOWN_PROVIDER_NAMES = [
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.OLLAMA_CHAT,
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
]
|
||||
|
||||
|
||||
@@ -60,7 +59,6 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
"ollama": "Ollama",
|
||||
LlmProviderNames.OLLAMA_CHAT: "Ollama",
|
||||
LlmProviderNames.LM_STUDIO: "LM Studio",
|
||||
LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy",
|
||||
"groq": "Groq",
|
||||
"anyscale": "Anyscale",
|
||||
"deepseek": "DeepSeek",
|
||||
@@ -111,7 +109,6 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
LlmProviderNames.LM_STUDIO,
|
||||
LlmProviderNames.VERTEX_AI,
|
||||
LlmProviderNames.AZURE,
|
||||
LlmProviderNames.LITELLM_PROXY,
|
||||
}
|
||||
|
||||
# Model family name mappings for display name generation
|
||||
|
||||
@@ -3782,6 +3782,16 @@
|
||||
"display_name": "Claude Sonnet 3.5",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"vertex_ai/claude-3-5-sonnet-v2": {
|
||||
"display_name": "Claude Sonnet 3.5",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "v2"
|
||||
},
|
||||
"vertex_ai/claude-3-5-sonnet-v2@20241022": {
|
||||
"display_name": "Claude Sonnet 3.5 v2",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022"
|
||||
},
|
||||
"vertex_ai/claude-3-5-sonnet@20240620": {
|
||||
"display_name": "Claude Sonnet 3.5",
|
||||
"model_vendor": "anthropic",
|
||||
|
||||
@@ -11,8 +11,6 @@ OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"
|
||||
LM_STUDIO_PROVIDER_NAME = "lm_studio"
|
||||
LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY"
|
||||
|
||||
LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy"
|
||||
|
||||
# Providers that use optional Bearer auth from custom_config
|
||||
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = {
|
||||
LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY,
|
||||
|
||||
@@ -15,7 +15,6 @@ from onyx.llm.well_known_providers.auto_update_service import (
|
||||
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
|
||||
@@ -48,7 +47,6 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
OLLAMA_PROVIDER_NAME: [], # Dynamic - fetched from Ollama API
|
||||
LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API
|
||||
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
|
||||
LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API
|
||||
}
|
||||
|
||||
|
||||
@@ -333,7 +331,6 @@ def get_provider_display_name(provider_name: str) -> str:
|
||||
BEDROCK_PROVIDER_NAME: "Amazon Bedrock",
|
||||
VERTEXAI_PROVIDER_NAME: "Google Vertex AI",
|
||||
OPENROUTER_PROVIDER_NAME: "OpenRouter",
|
||||
LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy",
|
||||
}
|
||||
|
||||
if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES:
|
||||
|
||||
@@ -44,7 +44,6 @@ from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import OAUTH_ENABLED
|
||||
from onyx.configs.app_configs import OIDC_PKCE_ENABLED
|
||||
from onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
|
||||
from onyx.configs.app_configs import OPENID_CONFIG_URL
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
@@ -120,9 +119,6 @@ from onyx.server.manage.opensearch_migration.api import (
|
||||
from onyx.server.manage.search_settings import router as search_settings_router
|
||||
from onyx.server.manage.slack_bot import router as slack_bot_management_router
|
||||
from onyx.server.manage.users import router as user_router
|
||||
from onyx.server.manage.voice.api import admin_router as voice_admin_router
|
||||
from onyx.server.manage.voice.user_api import router as voice_router
|
||||
from onyx.server.manage.voice.websocket_api import router as voice_websocket_router
|
||||
from onyx.server.manage.web_search.api import (
|
||||
admin_router as web_search_admin_router,
|
||||
)
|
||||
@@ -501,9 +497,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, embedding_router)
|
||||
include_router_with_global_prefix_prepended(application, web_search_router)
|
||||
include_router_with_global_prefix_prepended(application, web_search_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_router)
|
||||
include_router_with_global_prefix_prepended(application, voice_websocket_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, opensearch_migration_admin_router
|
||||
)
|
||||
@@ -604,7 +597,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
associate_by_email=True,
|
||||
is_verified_by_default=True,
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
|
||||
enable_pkce=OIDC_PKCE_ENABLED,
|
||||
),
|
||||
prefix="/auth/oidc",
|
||||
)
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
import re
|
||||
from enum import Enum
|
||||
|
||||
# Matches Slack channel references like <#C097NBWMY8Y> or <#C097NBWMY8Y|channel-name>
|
||||
SLACK_CHANNEL_REF_PATTERN = re.compile(r"<#([A-Z0-9]+)(?:\|([^>]+))?>")
|
||||
|
||||
LIKE_BLOCK_ACTION_ID = "feedback-like"
|
||||
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
|
||||
SHOW_EVERYONE_ACTION_ID = "show-everyone"
|
||||
|
||||
@@ -18,18 +18,15 @@ from onyx.configs.onyxbot_configs import ONYX_BOT_DISPLAY_ERROR_MSGS
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_RETRIES
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.onyxbot.slack.blocks import build_slack_response_blocks
|
||||
from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN
|
||||
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
|
||||
from onyx.onyxbot.slack.models import SlackMessageInfo
|
||||
from onyx.onyxbot.slack.models import ThreadMessage
|
||||
from onyx.onyxbot.slack.utils import get_channel_from_id
|
||||
from onyx.onyxbot.slack.utils import get_channel_name_from_id
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import SlackRateLimiter
|
||||
@@ -44,51 +41,6 @@ srl = SlackRateLimiter()
|
||||
RT = TypeVar("RT") # return type
|
||||
|
||||
|
||||
def resolve_channel_references(
|
||||
message: str,
|
||||
client: WebClient,
|
||||
logger: OnyxLoggingAdapter,
|
||||
) -> tuple[str, list[Tag]]:
|
||||
"""Parse Slack channel references from a message, resolve IDs to names,
|
||||
replace the raw markup with readable #channel-name, and return channel tags
|
||||
for search filtering."""
|
||||
tags: list[Tag] = []
|
||||
channel_matches = SLACK_CHANNEL_REF_PATTERN.findall(message)
|
||||
seen_channel_ids: set[str] = set()
|
||||
|
||||
for channel_id, channel_name_from_markup in channel_matches:
|
||||
if channel_id in seen_channel_ids:
|
||||
continue
|
||||
seen_channel_ids.add(channel_id)
|
||||
|
||||
channel_name = channel_name_from_markup or None
|
||||
|
||||
if not channel_name:
|
||||
try:
|
||||
channel_info = get_channel_from_id(client=client, channel_id=channel_id)
|
||||
channel_name = channel_info.get("name") or None
|
||||
except Exception:
|
||||
logger.warning(f"Failed to resolve channel name for ID: {channel_id}")
|
||||
|
||||
if not channel_name:
|
||||
continue
|
||||
|
||||
# Replace raw Slack markup with readable channel name
|
||||
if channel_name_from_markup:
|
||||
message = message.replace(
|
||||
f"<#{channel_id}|{channel_name_from_markup}>",
|
||||
f"#{channel_name}",
|
||||
)
|
||||
else:
|
||||
message = message.replace(
|
||||
f"<#{channel_id}>",
|
||||
f"#{channel_name}",
|
||||
)
|
||||
tags.append(Tag(tag_key="Channel", tag_value=channel_name))
|
||||
|
||||
return message, tags
|
||||
|
||||
|
||||
def rate_limits(
|
||||
client: WebClient, channel: str, thread_ts: Optional[str]
|
||||
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
|
||||
@@ -205,20 +157,6 @@ def handle_regular_answer(
|
||||
user_message = messages[-1]
|
||||
history_messages = messages[:-1]
|
||||
|
||||
# Resolve any <#CHANNEL_ID> references in the user message to readable
|
||||
# channel names and extract channel tags for search filtering
|
||||
resolved_message, channel_tags = resolve_channel_references(
|
||||
message=user_message.message,
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
user_message = ThreadMessage(
|
||||
message=resolved_message,
|
||||
sender=user_message.sender,
|
||||
role=user_message.role,
|
||||
)
|
||||
|
||||
channel_name, _ = get_channel_name_from_id(
|
||||
client=client,
|
||||
channel_id=channel,
|
||||
@@ -269,7 +207,6 @@ def handle_regular_answer(
|
||||
source_type=None,
|
||||
document_set=document_set_names,
|
||||
time_cutoff=None,
|
||||
tags=channel_tags if channel_tags else None,
|
||||
)
|
||||
|
||||
new_message_request = SendMessageRequest(
|
||||
@@ -294,16 +231,6 @@ def handle_regular_answer(
|
||||
slack_context_str=slack_context_str,
|
||||
)
|
||||
|
||||
# If a channel filter was applied but no results were found, override
|
||||
# the LLM response to avoid hallucinated answers about unindexed channels
|
||||
if channel_tags and not answer.citation_info and not answer.top_documents:
|
||||
channel_names = ", ".join(f"#{tag.tag_value}" for tag in channel_tags)
|
||||
answer.answer = (
|
||||
f"No indexed data found for {channel_names}. "
|
||||
"This channel may not be indexed, or there may be no messages "
|
||||
"matching your query within it."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unable to process message - did not successfully answer "
|
||||
@@ -358,7 +285,6 @@ def handle_regular_answer(
|
||||
only_respond_if_citations
|
||||
and not answer.citation_info
|
||||
and not message_info.bypass_filters
|
||||
and not channel_tags
|
||||
):
|
||||
logger.error(
|
||||
f"Unable to find citations to answer: '{answer.answer}' - not answering!"
|
||||
|
||||
@@ -419,15 +419,12 @@ async def get_async_redis_connection() -> aioredis.Redis:
|
||||
return _async_redis_connection
|
||||
|
||||
|
||||
async def retrieve_auth_token_data(token: str) -> dict | None:
|
||||
"""Validate auth token against Redis and return token data.
|
||||
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
|
||||
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
if not token:
|
||||
logger.debug("No auth token cookie found")
|
||||
return None
|
||||
|
||||
Args:
|
||||
token: The raw authentication token string.
|
||||
|
||||
Returns:
|
||||
Token data dict if valid, None if invalid/expired.
|
||||
"""
|
||||
try:
|
||||
redis = await get_async_redis_connection()
|
||||
redis_key = REDIS_AUTH_KEY_PREFIX + token
|
||||
@@ -442,96 +439,12 @@ async def retrieve_auth_token_data(token: str) -> dict | None:
|
||||
logger.error("Error decoding token data from Redis")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
|
||||
raise ValueError(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
|
||||
|
||||
|
||||
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
|
||||
"""Validate auth token from request cookie. Wrapper for backwards compatibility."""
|
||||
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
if not token:
|
||||
logger.debug("No auth token cookie found")
|
||||
return None
|
||||
return await retrieve_auth_token_data(token)
|
||||
|
||||
|
||||
# WebSocket token prefix (separate from regular auth tokens)
|
||||
REDIS_WS_TOKEN_PREFIX = "ws_token:"
|
||||
# WebSocket tokens expire after 60 seconds
|
||||
WS_TOKEN_TTL_SECONDS = 60
|
||||
# Rate limit: max tokens per user per window
|
||||
WS_TOKEN_RATE_LIMIT_MAX = 10
|
||||
WS_TOKEN_RATE_LIMIT_WINDOW_SECONDS = 60
|
||||
REDIS_WS_TOKEN_RATE_LIMIT_PREFIX = "ws_token_rate:"
|
||||
|
||||
|
||||
class WsTokenRateLimitExceeded(Exception):
|
||||
"""Raised when a user exceeds the WS token generation rate limit."""
|
||||
|
||||
|
||||
async def store_ws_token(token: str, user_id: str) -> None:
|
||||
"""Store a short-lived WebSocket authentication token in Redis.
|
||||
|
||||
Args:
|
||||
token: The generated WS token.
|
||||
user_id: The user ID to associate with this token.
|
||||
|
||||
Raises:
|
||||
WsTokenRateLimitExceeded: If the user has exceeded the rate limit.
|
||||
"""
|
||||
redis = await get_async_redis_connection()
|
||||
|
||||
# Atomically increment and check rate limit to avoid TOCTOU races
|
||||
rate_limit_key = REDIS_WS_TOKEN_RATE_LIMIT_PREFIX + user_id
|
||||
pipe = redis.pipeline()
|
||||
pipe.incr(rate_limit_key)
|
||||
pipe.expire(rate_limit_key, WS_TOKEN_RATE_LIMIT_WINDOW_SECONDS)
|
||||
results = await pipe.execute()
|
||||
new_count = results[0]
|
||||
|
||||
if new_count > WS_TOKEN_RATE_LIMIT_MAX:
|
||||
# Over limit — decrement back since we won't use this slot
|
||||
await redis.decr(rate_limit_key)
|
||||
logger.warning(f"WS token rate limit exceeded for user {user_id}")
|
||||
raise WsTokenRateLimitExceeded(
|
||||
f"Rate limit exceeded. Maximum {WS_TOKEN_RATE_LIMIT_MAX} tokens per minute."
|
||||
logger.error(
|
||||
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
|
||||
)
|
||||
|
||||
# Store the actual token
|
||||
redis_key = REDIS_WS_TOKEN_PREFIX + token
|
||||
token_data = json.dumps({"sub": user_id})
|
||||
await redis.set(redis_key, token_data, ex=WS_TOKEN_TTL_SECONDS)
|
||||
|
||||
|
||||
async def retrieve_ws_token_data(token: str) -> dict | None:
|
||||
"""Validate a WebSocket token and return the token data.
|
||||
|
||||
This uses GETDEL for atomic get-and-delete to prevent race conditions
|
||||
where the same token could be used twice.
|
||||
|
||||
Args:
|
||||
token: The WS token to validate.
|
||||
|
||||
Returns:
|
||||
Token data dict with 'sub' (user ID) if valid, None if invalid/expired.
|
||||
"""
|
||||
try:
|
||||
redis = await get_async_redis_connection()
|
||||
redis_key = REDIS_WS_TOKEN_PREFIX + token
|
||||
|
||||
# Atomic get-and-delete to prevent race conditions (Redis 6.2+)
|
||||
token_data_str = await redis.getdel(redis_key)
|
||||
|
||||
if not token_data_str:
|
||||
return None
|
||||
|
||||
return json.loads(token_data_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Error decoding WS token data from Redis")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in retrieve_ws_token_data: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def redis_lock_dump(lock: RedisLock, r: Redis) -> None:
|
||||
|
||||
@@ -9,7 +9,6 @@ from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import current_user_from_websocket
|
||||
from onyx.auth.users import current_user_with_expired_token
|
||||
from onyx.configs.app_configs import APP_API_PREFIX
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
@@ -130,7 +129,6 @@ def check_router_auth(
|
||||
or depends_fn == current_curator_or_admin_user
|
||||
or depends_fn == current_user_with_expired_token
|
||||
or depends_fn == current_chat_accessible_user
|
||||
or depends_fn == current_user_from_websocket
|
||||
or depends_fn == control_plane_dep
|
||||
or depends_fn == current_cloud_superuser
|
||||
or depends_fn == verify_scim_token
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import re
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from uuid import UUID
|
||||
@@ -41,9 +40,6 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_TEMPLATES_DIR = Path(__file__).parent / "templates"
|
||||
_WEBAPP_HMR_FIXER_TEMPLATE = (_TEMPLATES_DIR / "webapp_hmr_fixer.js").read_text()
|
||||
|
||||
|
||||
def require_onyx_craft_enabled(user: User = Depends(current_user)) -> User:
|
||||
"""
|
||||
@@ -243,62 +239,18 @@ def _stream_response(response: httpx.Response) -> Iterator[bytes]:
|
||||
yield chunk
|
||||
|
||||
|
||||
def _inject_hmr_fixer(content: bytes, session_id: str) -> bytes:
|
||||
"""Inject a script that stubs root-scoped Next HMR websocket connections."""
|
||||
base = f"/api/build/sessions/{session_id}/webapp"
|
||||
script = f"<script>{_WEBAPP_HMR_FIXER_TEMPLATE.replace('__WEBAPP_BASE__', base)}</script>"
|
||||
text = content.decode("utf-8")
|
||||
text = re.sub(
|
||||
r"(<head\b[^>]*>)",
|
||||
lambda m: m.group(0) + script,
|
||||
text,
|
||||
count=1,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
return text.encode("utf-8")
|
||||
|
||||
|
||||
def _rewrite_asset_paths(content: bytes, session_id: str) -> bytes:
|
||||
"""Rewrite Next.js asset paths to go through the proxy."""
|
||||
import re
|
||||
|
||||
# Base path includes session_id for routing
|
||||
webapp_base_path = f"/api/build/sessions/{session_id}/webapp"
|
||||
escaped_webapp_base_path = webapp_base_path.replace("/", r"\/")
|
||||
hmr_paths = ("/_next/webpack-hmr", "/_next/hmr")
|
||||
|
||||
text = content.decode("utf-8")
|
||||
# Anchor on delimiter so already-prefixed URLs (from assetPrefix) aren't double-rewritten.
|
||||
for delim in ('"', "'", "("):
|
||||
text = text.replace(f"{delim}/_next/", f"{delim}{webapp_base_path}/_next/")
|
||||
text = re.sub(
|
||||
rf"{re.escape(delim)}https?://[^/\"')]+/_next/",
|
||||
f"{delim}{webapp_base_path}/_next/",
|
||||
text,
|
||||
)
|
||||
text = re.sub(
|
||||
rf"{re.escape(delim)}wss?://[^/\"')]+/_next/",
|
||||
f"{delim}{webapp_base_path}/_next/",
|
||||
text,
|
||||
)
|
||||
text = text.replace(r"\/_next\/", rf"{escaped_webapp_base_path}\/_next\/")
|
||||
text = re.sub(
|
||||
r"https?:\\\/\\\/[^\"']+?\\\/_next\\\/",
|
||||
rf"{escaped_webapp_base_path}\/_next\/",
|
||||
text,
|
||||
)
|
||||
text = re.sub(
|
||||
r"wss?:\\\/\\\/[^\"']+?\\\/_next\\\/",
|
||||
rf"{escaped_webapp_base_path}\/_next\/",
|
||||
text,
|
||||
)
|
||||
for hmr_path in hmr_paths:
|
||||
escaped_hmr_path = hmr_path.replace("/", r"\/")
|
||||
text = text.replace(
|
||||
f"{webapp_base_path}{hmr_path}",
|
||||
hmr_path,
|
||||
)
|
||||
text = text.replace(
|
||||
f"{escaped_webapp_base_path}{escaped_hmr_path}",
|
||||
escaped_hmr_path,
|
||||
)
|
||||
# Rewrite /_next/ paths to go through our proxy
|
||||
text = text.replace("/_next/", f"{webapp_base_path}/_next/")
|
||||
# Rewrite JSON data file fetch paths (e.g., /data.json, /data/tickets.json)
|
||||
# Matches paths like "/filename.json" or "/path/to/file.json"
|
||||
text = re.sub(
|
||||
r'"(/(?:[a-zA-Z0-9_-]+/)*[a-zA-Z0-9_-]+\.json)"',
|
||||
f'"{webapp_base_path}\\1"',
|
||||
@@ -309,29 +261,11 @@ def _rewrite_asset_paths(content: bytes, session_id: str) -> bytes:
|
||||
f"'{webapp_base_path}\\1'",
|
||||
text,
|
||||
)
|
||||
# Rewrite favicon
|
||||
text = text.replace('"/favicon.ico', f'"{webapp_base_path}/favicon.ico')
|
||||
return text.encode("utf-8")
|
||||
|
||||
|
||||
def _rewrite_proxy_response_headers(
|
||||
headers: dict[str, str], session_id: str
|
||||
) -> dict[str, str]:
|
||||
"""Rewrite response headers that can leak root-scoped asset URLs."""
|
||||
link = headers.get("link")
|
||||
if link:
|
||||
webapp_base_path = f"/api/build/sessions/{session_id}/webapp"
|
||||
rewritten_link = re.sub(
|
||||
r"<https?://[^>]+/_next/",
|
||||
f"<{webapp_base_path}/_next/",
|
||||
link,
|
||||
)
|
||||
rewritten_link = rewritten_link.replace(
|
||||
"</_next/", f"<{webapp_base_path}/_next/"
|
||||
)
|
||||
headers["link"] = rewritten_link
|
||||
return headers
|
||||
|
||||
|
||||
# Content types that may contain asset path references that need rewriting
|
||||
REWRITABLE_CONTENT_TYPES = {
|
||||
"text/html",
|
||||
@@ -408,17 +342,12 @@ def _proxy_request(
|
||||
for key, value in response.headers.items()
|
||||
if key.lower() not in EXCLUDED_HEADERS
|
||||
}
|
||||
response_headers = _rewrite_proxy_response_headers(
|
||||
response_headers, str(session_id)
|
||||
)
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
|
||||
# For HTML/CSS/JS responses, rewrite asset paths
|
||||
if any(ct in content_type for ct in REWRITABLE_CONTENT_TYPES):
|
||||
content = _rewrite_asset_paths(response.content, str(session_id))
|
||||
if "text/html" in content_type:
|
||||
content = _inject_hmr_fixer(content, str(session_id))
|
||||
return Response(
|
||||
content=content,
|
||||
status_code=response.status_code,
|
||||
@@ -462,7 +391,7 @@ def _check_webapp_access(
|
||||
return session
|
||||
|
||||
|
||||
_OFFLINE_HTML_PATH = _TEMPLATES_DIR / "webapp_offline.html"
|
||||
_OFFLINE_HTML_PATH = Path(__file__).parent / "templates" / "webapp_offline.html"
|
||||
|
||||
|
||||
def _offline_html_response() -> Response:
|
||||
@@ -470,7 +399,6 @@ def _offline_html_response() -> Response:
|
||||
|
||||
Design mirrors the default Craft web template (outputs/web/app/page.tsx):
|
||||
terminal window aesthetic with Minecraft-themed typing animation.
|
||||
|
||||
"""
|
||||
html = _OFFLINE_HTML_PATH.read_text()
|
||||
return Response(content=html, status_code=503, media_type="text/html")
|
||||
|
||||
@@ -732,7 +732,7 @@ def get_webapp_info(
|
||||
return WebappInfo(**webapp_info)
|
||||
|
||||
|
||||
@router.get("/{session_id}/webapp-download")
|
||||
@router.get("/{session_id}/webapp/download")
|
||||
def download_webapp(
|
||||
session_id: UUID,
|
||||
user: User = Depends(current_user),
|
||||
|
||||
@@ -1,135 +0,0 @@
|
||||
(function () {
|
||||
var WEBAPP_BASE = "__WEBAPP_BASE__";
|
||||
var PROXIED_NEXT_PREFIX = WEBAPP_BASE + "/_next/";
|
||||
var PROXIED_HMR_PREFIX = WEBAPP_BASE + "/_next/webpack-hmr";
|
||||
var PROXIED_ALT_HMR_PREFIX = WEBAPP_BASE + "/_next/hmr";
|
||||
|
||||
function isHmrWebSocketUrl(url) {
|
||||
if (!url) return false;
|
||||
try {
|
||||
var parsedUrl = new URL(String(url), window.location.href);
|
||||
return (
|
||||
parsedUrl.pathname.indexOf("/_next/webpack-hmr") === 0 ||
|
||||
parsedUrl.pathname.indexOf("/_next/hmr") === 0 ||
|
||||
parsedUrl.pathname.indexOf(PROXIED_HMR_PREFIX) === 0 ||
|
||||
parsedUrl.pathname.indexOf(PROXIED_ALT_HMR_PREFIX) === 0
|
||||
);
|
||||
} catch (e) {}
|
||||
if (typeof url === "string") {
|
||||
return (
|
||||
url.indexOf("/_next/webpack-hmr") === 0 ||
|
||||
url.indexOf("/_next/hmr") === 0 ||
|
||||
url.indexOf(PROXIED_HMR_PREFIX) === 0 ||
|
||||
url.indexOf(PROXIED_ALT_HMR_PREFIX) === 0
|
||||
);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function rewriteNextAssetUrl(url) {
|
||||
if (!url) return url;
|
||||
try {
|
||||
var parsedUrl = new URL(String(url), window.location.href);
|
||||
if (parsedUrl.pathname.indexOf(PROXIED_NEXT_PREFIX) === 0) {
|
||||
return parsedUrl.pathname + parsedUrl.search + parsedUrl.hash;
|
||||
}
|
||||
if (parsedUrl.pathname.indexOf("/_next/") === 0) {
|
||||
return (
|
||||
WEBAPP_BASE + parsedUrl.pathname + parsedUrl.search + parsedUrl.hash
|
||||
);
|
||||
}
|
||||
} catch (e) {}
|
||||
if (typeof url === "string") {
|
||||
if (url.indexOf(PROXIED_NEXT_PREFIX) === 0) {
|
||||
return url;
|
||||
}
|
||||
if (url.indexOf("/_next/") === 0) {
|
||||
return WEBAPP_BASE + url;
|
||||
}
|
||||
}
|
||||
return url;
|
||||
}
|
||||
|
||||
function createEvent(eventType) {
|
||||
return typeof Event === "function"
|
||||
? new Event(eventType)
|
||||
: { type: eventType };
|
||||
}
|
||||
|
||||
function MockHmrWebSocket(url) {
|
||||
this.url = String(url);
|
||||
this.readyState = 1;
|
||||
this.bufferedAmount = 0;
|
||||
this.extensions = "";
|
||||
this.protocol = "";
|
||||
this.binaryType = "blob";
|
||||
this.onopen = null;
|
||||
this.onmessage = null;
|
||||
this.onerror = null;
|
||||
this.onclose = null;
|
||||
this._l = {};
|
||||
var socket = this;
|
||||
setTimeout(function () {
|
||||
socket._d("open", createEvent("open"));
|
||||
}, 0);
|
||||
}
|
||||
|
||||
MockHmrWebSocket.CONNECTING = 0;
|
||||
MockHmrWebSocket.OPEN = 1;
|
||||
MockHmrWebSocket.CLOSING = 2;
|
||||
MockHmrWebSocket.CLOSED = 3;
|
||||
|
||||
MockHmrWebSocket.prototype.addEventListener = function (eventType, callback) {
|
||||
(this._l[eventType] || (this._l[eventType] = [])).push(callback);
|
||||
};
|
||||
|
||||
MockHmrWebSocket.prototype.removeEventListener = function (
|
||||
eventType,
|
||||
callback,
|
||||
) {
|
||||
var listeners = this._l[eventType] || [];
|
||||
this._l[eventType] = listeners.filter(function (listener) {
|
||||
return listener !== callback;
|
||||
});
|
||||
};
|
||||
|
||||
MockHmrWebSocket.prototype._d = function (eventType, eventValue) {
|
||||
var listeners = this._l[eventType] || [];
|
||||
for (var i = 0; i < listeners.length; i++) {
|
||||
listeners[i].call(this, eventValue);
|
||||
}
|
||||
var handler = this["on" + eventType];
|
||||
if (typeof handler === "function") {
|
||||
handler.call(this, eventValue);
|
||||
}
|
||||
};
|
||||
|
||||
MockHmrWebSocket.prototype.send = function () {};
|
||||
|
||||
MockHmrWebSocket.prototype.close = function (code, reason) {
|
||||
if (this.readyState >= 2) return;
|
||||
this.readyState = 3;
|
||||
var closeEvent = createEvent("close");
|
||||
closeEvent.code = code === undefined ? 1000 : code;
|
||||
closeEvent.reason = reason || "";
|
||||
closeEvent.wasClean = true;
|
||||
this._d("close", closeEvent);
|
||||
};
|
||||
|
||||
if (window.WebSocket) {
|
||||
var OriginalWebSocket = window.WebSocket;
|
||||
window.WebSocket = function (url, protocols) {
|
||||
if (isHmrWebSocketUrl(url)) {
|
||||
return new MockHmrWebSocket(rewriteNextAssetUrl(url));
|
||||
}
|
||||
return protocols === undefined
|
||||
? new OriginalWebSocket(url)
|
||||
: new OriginalWebSocket(url, protocols);
|
||||
};
|
||||
window.WebSocket.prototype = OriginalWebSocket.prototype;
|
||||
Object.setPrototypeOf(window.WebSocket, OriginalWebSocket);
|
||||
["CONNECTING", "OPEN", "CLOSING", "CLOSED"].forEach(function (stateKey) {
|
||||
window.WebSocket[stateKey] = OriginalWebSocket[stateKey];
|
||||
});
|
||||
}
|
||||
})();
|
||||
@@ -7424,9 +7424,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/hono": {
|
||||
"version": "4.12.7",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.7.tgz",
|
||||
"integrity": "sha512-jq9l1DM0zVIvsm3lv9Nw9nlJnMNPOcAtsbsgiUhWcFzPE99Gvo6yRTlszSLLYacMeQ6quHD6hMfId8crVHvexw==",
|
||||
"version": "4.12.5",
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.12.5.tgz",
|
||||
"integrity": "sha512-3qq+FUBtlTHhtYxbxheZgY8NIFnkkC/MR8u5TTsr7YZ3wixryQ3cCwn3iZbg8p8B88iDBBAYSfZDS75t8MN7Vg==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=16.9.0"
|
||||
|
||||
@@ -157,13 +157,10 @@ def categorize_uploaded_files(
|
||||
"""
|
||||
Categorize uploaded files based on text extractability and tokenized length.
|
||||
|
||||
- Images are estimated for token cost via a patch-based heuristic.
|
||||
- All other files are run through extract_file_text, which handles known
|
||||
document formats (.pdf, .docx, …) and falls back to a text-detection
|
||||
heuristic for unknown extensions (.py, .js, .rs, …).
|
||||
- Extracts text using extract_file_text for supported plain/document extensions.
|
||||
- Uses default tokenizer to compute token length.
|
||||
- If token length > threshold, reject file (unless threshold skip is enabled).
|
||||
- If text cannot be extracted, reject file.
|
||||
- If token length > 100,000, reject file (unless threshold skip is enabled).
|
||||
- If extension unsupported or text cannot be extracted, reject file.
|
||||
- Otherwise marked as acceptable.
|
||||
"""
|
||||
|
||||
@@ -220,7 +217,8 @@ def categorize_uploaded_files(
|
||||
)
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename, reason="Unsupported file contents"
|
||||
filename=filename,
|
||||
reason=f"Unsupported file type: {extension}",
|
||||
)
|
||||
)
|
||||
continue
|
||||
@@ -237,10 +235,8 @@ def categorize_uploaded_files(
|
||||
results.acceptable_file_to_token_count[filename] = token_count
|
||||
continue
|
||||
|
||||
# Handle as text/document: attempt text extraction and count tokens.
|
||||
# This accepts any file that extract_file_text can handle, including
|
||||
# code files (.py, .js, .rs, etc.) via its is_text_file() fallback.
|
||||
else:
|
||||
# Otherwise, handle as text/document: extract text and count tokens
|
||||
elif extension in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
|
||||
if is_file_password_protected(
|
||||
file=upload.file,
|
||||
file_name=filename,
|
||||
@@ -263,10 +259,7 @@ def categorize_uploaded_files(
|
||||
if not text_content:
|
||||
logger.warning(f"No text content extracted from '{filename}'")
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Unsupported file type: {extension}",
|
||||
)
|
||||
RejectedFile(filename=filename, reason="Could not read file")
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -289,6 +282,17 @@ def categorize_uploaded_files(
|
||||
logger.warning(
|
||||
f"Failed to reset file pointer for '{filename}': {str(e)}"
|
||||
)
|
||||
continue
|
||||
|
||||
# If not recognized as supported types above, mark unsupported
|
||||
logger.warning(
|
||||
f"Unsupported file extension '{extension}' for file '{filename}'"
|
||||
)
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename, reason=f"Unsupported file type: {extension}"
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to process uploaded file '{get_safe_filename(upload)}' (error_type={type(e).__name__}, error={str(e)})"
|
||||
|
||||
@@ -58,9 +58,6 @@ from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
from onyx.server.manage.llm.models import BedrockFinalModelResponse
|
||||
from onyx.server.manage.llm.models import BedrockModelsRequest
|
||||
from onyx.server.manage.llm.models import DefaultModel
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LitellmModelDetails
|
||||
from onyx.server.manage.llm.models import LitellmModelsRequest
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderResponse
|
||||
@@ -75,7 +72,6 @@ from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OpenRouterModelDetails
|
||||
from onyx.server.manage.llm.models import OpenRouterModelsRequest
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
from onyx.server.manage.llm.models import TestLLMRequest
|
||||
from onyx.server.manage.llm.models import VisionProviderResponse
|
||||
from onyx.server.manage.llm.utils import generate_bedrock_display_name
|
||||
@@ -102,34 +98,6 @@ def _mask_string(value: str) -> str:
|
||||
return value[:4] + "****" + value[-4:]
|
||||
|
||||
|
||||
def _sync_fetched_models(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
models: list[SyncModelEntry],
|
||||
source_label: str,
|
||||
) -> None:
|
||||
"""Sync fetched models to DB for the given provider.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
provider_name: Name of the LLM provider
|
||||
models: List of SyncModelEntry objects describing the fetched models
|
||||
source_label: Human-readable label for log messages (e.g. "Bedrock", "LiteLLM")
|
||||
"""
|
||||
try:
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=provider_name,
|
||||
models=models,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new {source_label} models to provider '{provider_name}'"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync {source_label} models to DB: {e}")
|
||||
|
||||
|
||||
# Keys in custom_config that contain sensitive credentials
|
||||
_SENSITIVE_CONFIG_KEYS = {
|
||||
"vertex_credentials",
|
||||
@@ -995,20 +963,27 @@ def get_bedrock_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in results
|
||||
],
|
||||
source_label="Bedrock",
|
||||
)
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new Bedrock models to provider '{request.provider_name}'"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync Bedrock models to DB: {e}")
|
||||
|
||||
return results
|
||||
|
||||
@@ -1126,20 +1101,27 @@ def get_ollama_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="Ollama",
|
||||
)
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new Ollama models to provider '{request.provider_name}'"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync Ollama models to DB: {e}")
|
||||
|
||||
return sorted_results
|
||||
|
||||
@@ -1228,20 +1210,27 @@ def get_openrouter_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
try:
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="OpenRouter",
|
||||
)
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new OpenRouter models to provider '{request.provider_name}'"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync OpenRouter models to DB: {e}")
|
||||
|
||||
return sorted_results
|
||||
|
||||
@@ -1335,119 +1324,26 @@ def get_lm_studio_available_models(
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.name,
|
||||
display_name=r.display_name,
|
||||
max_input_tokens=r.max_input_tokens,
|
||||
supports_image_input=r.supports_image_input,
|
||||
)
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="LM Studio",
|
||||
)
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
@admin_router.post("/litellm/available-models")
|
||||
def get_litellm_available_models(
|
||||
request: LitellmModelsRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LitellmFinalModelResponse]:
|
||||
"""Fetch available models from Litellm proxy /v1/models endpoint."""
|
||||
response_json = _get_litellm_models_response(
|
||||
api_key=request.api_key, api_base=request.api_base
|
||||
)
|
||||
|
||||
models = response_json.get("data", [])
|
||||
if not isinstance(models, list) or len(models) == 0:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No models found from your Litellm endpoint",
|
||||
)
|
||||
|
||||
results: list[LitellmFinalModelResponse] = []
|
||||
for model in models:
|
||||
try:
|
||||
model_details = LitellmModelDetails.model_validate(model)
|
||||
|
||||
results.append(
|
||||
LitellmFinalModelResponse(
|
||||
provider_name=model_details.owned_by,
|
||||
model_name=model_details.id,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to parse Litellm model entry",
|
||||
extra={"error": str(e), "item": str(model)[:1000]},
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No compatible models found from Litellm",
|
||||
)
|
||||
|
||||
sorted_results = sorted(results, key=lambda m: m.model_name.lower())
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
_sync_fetched_models(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=[
|
||||
SyncModelEntry(
|
||||
name=r.model_name,
|
||||
display_name=r.model_name,
|
||||
)
|
||||
models_to_sync = [
|
||||
{
|
||||
"name": r.name,
|
||||
"display_name": r.display_name,
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in sorted_results
|
||||
],
|
||||
source_label="LiteLLM",
|
||||
)
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
provider_name=request.provider_name,
|
||||
models=models_to_sync,
|
||||
)
|
||||
if new_count > 0:
|
||||
logger.info(
|
||||
f"Added {new_count} new LM Studio models to provider '{request.provider_name}'"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync LM Studio models to DB: {e}")
|
||||
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_litellm_models_response(api_key: str, api_base: str) -> dict:
|
||||
"""Perform GET to Litellm proxy /api/v1/models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/v1/models"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Authentication failed: invalid or missing API key for LiteLLM proxy.",
|
||||
)
|
||||
elif e.response.status_code == 404:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"LiteLLM models endpoint not found at {url}. "
|
||||
"Please verify the API base URL.",
|
||||
)
|
||||
else:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
)
|
||||
except Exception as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Failed to fetch LiteLLM models: {e}",
|
||||
)
|
||||
|
||||
@@ -420,32 +420,3 @@ class LLMProviderResponse(BaseModel, Generic[T]):
|
||||
default_text=default_text,
|
||||
default_vision=default_vision,
|
||||
)
|
||||
|
||||
|
||||
class SyncModelEntry(BaseModel):
|
||||
"""Typed model for syncing fetched models to the DB."""
|
||||
|
||||
name: str
|
||||
display_name: str
|
||||
max_input_tokens: int | None = None
|
||||
supports_image_input: bool = False
|
||||
|
||||
|
||||
class LitellmModelsRequest(BaseModel):
|
||||
api_key: str
|
||||
api_base: str
|
||||
provider_name: str | None = None # Optional: to save models to existing provider
|
||||
|
||||
|
||||
class LitellmModelDetails(BaseModel):
|
||||
"""Response model for Litellm proxy /api/v1/models endpoint"""
|
||||
|
||||
id: str # Model ID (e.g. "gpt-4o")
|
||||
object: str # "model"
|
||||
created: int # Unix timestamp in seconds
|
||||
owned_by: str # Provider name (e.g. "openai")
|
||||
|
||||
|
||||
class LitellmFinalModelResponse(BaseModel):
|
||||
provider_name: str # Provider name (e.g. "openai")
|
||||
model_name: str # Model ID (e.g. "gpt-4o")
|
||||
|
||||
@@ -85,11 +85,6 @@ class UserPreferences(BaseModel):
|
||||
chat_background: str | None = None
|
||||
default_app_mode: DefaultAppMode = DefaultAppMode.CHAT
|
||||
|
||||
# Voice preferences
|
||||
voice_auto_send: bool | None = None
|
||||
voice_auto_playback: bool | None = None
|
||||
voice_playback_speed: float | None = None
|
||||
|
||||
# controls which tools are enabled for the user for a specific assistant
|
||||
assistant_specific_configs: UserSpecificAssistantPreferences | None = None
|
||||
|
||||
@@ -169,9 +164,6 @@ class UserInfo(BaseModel):
|
||||
theme_preference=user.theme_preference,
|
||||
chat_background=user.chat_background,
|
||||
default_app_mode=user.default_app_mode,
|
||||
voice_auto_send=user.voice_auto_send,
|
||||
voice_auto_playback=user.voice_auto_playback,
|
||||
voice_playback_speed=user.voice_playback_speed,
|
||||
assistant_specific_configs=assistant_specific_configs,
|
||||
)
|
||||
),
|
||||
@@ -248,12 +240,6 @@ class ChatBackgroundRequest(BaseModel):
|
||||
chat_background: str | None
|
||||
|
||||
|
||||
class VoiceSettingsUpdateRequest(BaseModel):
|
||||
auto_send: bool | None = None
|
||||
auto_playback: bool | None = None
|
||||
playback_speed: float | None = Field(default=None, ge=0.5, le=2.0)
|
||||
|
||||
|
||||
class PersonalizationUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
role: str | None = None
|
||||
|
||||
@@ -5,7 +5,6 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
from email_validator import EmailNotValidError
|
||||
@@ -19,7 +18,6 @@ from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.anonymous_user import fetch_anonymous_user_info
|
||||
@@ -69,14 +67,11 @@ from onyx.db.user_preferences import update_user_role
|
||||
from onyx.db.user_preferences import update_user_shortcut_enabled
|
||||
from onyx.db.user_preferences import update_user_temperature_override_enabled
|
||||
from onyx.db.user_preferences import update_user_theme_preference
|
||||
from onyx.db.users import batch_get_user_groups
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_all_accepted_users
|
||||
from onyx.db.users import get_all_users
|
||||
from onyx.db.users import get_page_of_filtered_users
|
||||
from onyx.db.users import get_total_filtered_users_count
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.db.users import get_user_counts_by_role_and_status
|
||||
from onyx.db.users import validate_user_role_update
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.redis.redis_pool import get_raw_redis_client
|
||||
@@ -103,7 +98,6 @@ from onyx.server.manage.models import UserSpecificAssistantPreferences
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import InvitedUserSnapshot
|
||||
from onyx.server.models import MinimalUserSnapshot
|
||||
from onyx.server.models import UserGroupInfo
|
||||
from onyx.server.usage_limits import is_tenant_on_trial_fn
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -209,91 +203,14 @@ def list_accepted_users(
|
||||
total_items=0,
|
||||
)
|
||||
|
||||
user_ids = [user.id for user in filtered_accepted_users]
|
||||
groups_by_user = batch_get_user_groups(db_session, user_ids)
|
||||
|
||||
# Batch-fetch SCIM mappings to mark synced users
|
||||
scim_synced_ids: set[UUID] = set()
|
||||
try:
|
||||
from onyx.db.models import ScimUserMapping
|
||||
|
||||
scim_mappings = db_session.scalars(
|
||||
select(ScimUserMapping.user_id).where(ScimUserMapping.user_id.in_(user_ids))
|
||||
).all()
|
||||
scim_synced_ids = set(scim_mappings)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch SCIM mappings; marking all users as non-synced",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return PaginatedReturn(
|
||||
items=[
|
||||
FullUserSnapshot.from_user_model(
|
||||
user,
|
||||
groups=[
|
||||
UserGroupInfo(id=gid, name=gname)
|
||||
for gid, gname in groups_by_user.get(user.id, [])
|
||||
],
|
||||
is_scim_synced=user.id in scim_synced_ids,
|
||||
)
|
||||
for user in filtered_accepted_users
|
||||
FullUserSnapshot.from_user_model(user) for user in filtered_accepted_users
|
||||
],
|
||||
total_items=total_accepted_users_count,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/manage/users/accepted/all", tags=PUBLIC_API_TAGS)
|
||||
def list_all_accepted_users(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[FullUserSnapshot]:
|
||||
"""Returns all accepted users without pagination.
|
||||
Used by the admin Users page for client-side filtering/sorting."""
|
||||
users = get_all_accepted_users(db_session=db_session)
|
||||
|
||||
if not users:
|
||||
return []
|
||||
|
||||
user_ids = [user.id for user in users]
|
||||
groups_by_user = batch_get_user_groups(db_session, user_ids)
|
||||
|
||||
# Batch-fetch SCIM mappings to mark synced users
|
||||
scim_synced_ids: set[UUID] = set()
|
||||
try:
|
||||
from onyx.db.models import ScimUserMapping
|
||||
|
||||
scim_mappings = db_session.scalars(
|
||||
select(ScimUserMapping.user_id).where(ScimUserMapping.user_id.in_(user_ids))
|
||||
).all()
|
||||
scim_synced_ids = set(scim_mappings)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to fetch SCIM mappings; marking all users as non-synced",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return [
|
||||
FullUserSnapshot.from_user_model(
|
||||
user,
|
||||
groups=[
|
||||
UserGroupInfo(id=gid, name=gname)
|
||||
for gid, gname in groups_by_user.get(user.id, [])
|
||||
],
|
||||
is_scim_synced=user.id in scim_synced_ids,
|
||||
)
|
||||
for user in users
|
||||
]
|
||||
|
||||
|
||||
@router.get("/manage/users/counts")
|
||||
def get_user_counts(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, dict[str, int]]:
|
||||
return get_user_counts_by_role_and_status(db_session)
|
||||
|
||||
|
||||
@router.get("/manage/users/invited", tags=PUBLIC_API_TAGS)
|
||||
def list_invited_users(
|
||||
_: User = Depends(current_admin_user),
|
||||
@@ -352,10 +269,24 @@ def list_all_users(
|
||||
if accepted_page is None or invited_page is None or slack_users_page is None:
|
||||
return AllUsersResponse(
|
||||
accepted=[
|
||||
FullUserSnapshot.from_user_model(user) for user in accepted_users
|
||||
FullUserSnapshot(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
)
|
||||
for user in accepted_users
|
||||
],
|
||||
slack_users=[
|
||||
FullUserSnapshot.from_user_model(user) for user in slack_users
|
||||
FullUserSnapshot(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
)
|
||||
for user in slack_users
|
||||
],
|
||||
invited=[InvitedUserSnapshot(email=email) for email in invited_emails],
|
||||
accepted_pages=1,
|
||||
@@ -365,10 +296,26 @@ def list_all_users(
|
||||
|
||||
# Otherwise, return paginated results
|
||||
return AllUsersResponse(
|
||||
accepted=[FullUserSnapshot.from_user_model(user) for user in accepted_users][
|
||||
accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE
|
||||
],
|
||||
slack_users=[FullUserSnapshot.from_user_model(user) for user in slack_users][
|
||||
accepted=[
|
||||
FullUserSnapshot(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
)
|
||||
for user in accepted_users
|
||||
][accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE],
|
||||
slack_users=[
|
||||
FullUserSnapshot(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
)
|
||||
for user in slack_users
|
||||
][
|
||||
slack_users_page
|
||||
* USERS_PAGE_SIZE : (slack_users_page + 1)
|
||||
* USERS_PAGE_SIZE
|
||||
|
||||
@@ -1,318 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.db.voice import deactivate_stt_provider
|
||||
from onyx.db.voice import deactivate_tts_provider
|
||||
from onyx.db.voice import delete_voice_provider
|
||||
from onyx.db.voice import fetch_voice_provider_by_id
|
||||
from onyx.db.voice import fetch_voice_provider_by_type
|
||||
from onyx.db.voice import fetch_voice_providers
|
||||
from onyx.db.voice import set_default_stt_provider
|
||||
from onyx.db.voice import set_default_tts_provider
|
||||
from onyx.db.voice import upsert_voice_provider
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.voice.models import VoiceOption
|
||||
from onyx.server.manage.voice.models import VoiceProviderTestRequest
|
||||
from onyx.server.manage.voice.models import VoiceProviderUpdateSuccess
|
||||
from onyx.server.manage.voice.models import VoiceProviderUpsertRequest
|
||||
from onyx.server.manage.voice.models import VoiceProviderView
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.url import SSRFException
|
||||
from onyx.utils.url import validate_outbound_http_url
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/voice")
|
||||
|
||||
|
||||
def _validate_voice_api_base(provider_type: str, api_base: str | None) -> str | None:
|
||||
"""Validate and normalize provider api_base / target URI."""
|
||||
if api_base is None:
|
||||
return None
|
||||
|
||||
allow_private_network = provider_type.lower() == "azure"
|
||||
try:
|
||||
return validate_outbound_http_url(
|
||||
api_base, allow_private_network=allow_private_network
|
||||
)
|
||||
except (ValueError, SSRFException) as e:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Invalid target URI: {str(e)}",
|
||||
) from e
|
||||
|
||||
|
||||
def _provider_to_view(provider: VoiceProvider) -> VoiceProviderView:
|
||||
"""Convert a VoiceProvider model to a VoiceProviderView."""
|
||||
return VoiceProviderView(
|
||||
id=provider.id,
|
||||
name=provider.name,
|
||||
provider_type=provider.provider_type,
|
||||
is_default_stt=provider.is_default_stt,
|
||||
is_default_tts=provider.is_default_tts,
|
||||
stt_model=provider.stt_model,
|
||||
tts_model=provider.tts_model,
|
||||
default_voice=provider.default_voice,
|
||||
has_api_key=bool(provider.api_key),
|
||||
target_uri=provider.api_base, # api_base stores the target URI for Azure
|
||||
)
|
||||
|
||||
|
||||
@admin_router.get("/providers")
|
||||
def list_voice_providers(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[VoiceProviderView]:
|
||||
"""List all configured voice providers."""
|
||||
providers = fetch_voice_providers(db_session)
|
||||
return [_provider_to_view(provider) for provider in providers]
|
||||
|
||||
|
||||
@admin_router.post("/providers")
|
||||
async def upsert_voice_provider_endpoint(
|
||||
request: VoiceProviderUpsertRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderView:
|
||||
"""Create or update a voice provider."""
|
||||
api_key = request.api_key
|
||||
api_key_changed = request.api_key_changed
|
||||
|
||||
# If llm_provider_id is specified, copy the API key from that LLM provider
|
||||
if request.llm_provider_id is not None:
|
||||
llm_provider = db_session.get(LLMProviderModel, request.llm_provider_id)
|
||||
if llm_provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_FOUND,
|
||||
f"LLM provider with id {request.llm_provider_id} not found.",
|
||||
)
|
||||
if llm_provider.api_key is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Selected LLM provider has no API key configured.",
|
||||
)
|
||||
api_key = llm_provider.api_key.get_value(apply_mask=False)
|
||||
api_key_changed = True
|
||||
|
||||
# Use target_uri if provided, otherwise fall back to api_base
|
||||
api_base = _validate_voice_api_base(
|
||||
request.provider_type, request.target_uri or request.api_base
|
||||
)
|
||||
|
||||
provider = upsert_voice_provider(
|
||||
db_session=db_session,
|
||||
provider_id=request.id,
|
||||
name=request.name,
|
||||
provider_type=request.provider_type,
|
||||
api_key=api_key,
|
||||
api_key_changed=api_key_changed,
|
||||
api_base=api_base,
|
||||
custom_config=request.custom_config,
|
||||
stt_model=request.stt_model,
|
||||
tts_model=request.tts_model,
|
||||
default_voice=request.default_voice,
|
||||
activate_stt=request.activate_stt,
|
||||
activate_tts=request.activate_tts,
|
||||
)
|
||||
|
||||
# Validate credentials before committing - rollback on failure
|
||||
try:
|
||||
voice_provider = get_voice_provider(provider)
|
||||
await voice_provider.validate_credentials()
|
||||
except OnyxError:
|
||||
db_session.rollback()
|
||||
raise
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.error(f"Voice provider credential validation failed on save: {e}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
str(e),
|
||||
) from e
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return _provider_to_view(provider)
|
||||
|
||||
|
||||
@admin_router.delete(
|
||||
"/providers/{provider_id}", status_code=204, response_class=Response
|
||||
)
|
||||
def delete_voice_provider_endpoint(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Delete a voice provider."""
|
||||
delete_voice_provider(db_session, provider_id)
|
||||
db_session.commit()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/activate-stt")
|
||||
def activate_stt_provider_endpoint(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderView:
|
||||
"""Set a voice provider as the default STT provider."""
|
||||
provider = set_default_stt_provider(db_session=db_session, provider_id=provider_id)
|
||||
db_session.commit()
|
||||
return _provider_to_view(provider)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/deactivate-stt")
|
||||
def deactivate_stt_provider_endpoint(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderUpdateSuccess:
|
||||
"""Remove the default STT status from a voice provider."""
|
||||
deactivate_stt_provider(db_session=db_session, provider_id=provider_id)
|
||||
db_session.commit()
|
||||
return VoiceProviderUpdateSuccess()
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/activate-tts")
|
||||
def activate_tts_provider_endpoint(
|
||||
provider_id: int,
|
||||
tts_model: str | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderView:
|
||||
"""Set a voice provider as the default TTS provider."""
|
||||
provider = set_default_tts_provider(
|
||||
db_session=db_session, provider_id=provider_id, tts_model=tts_model
|
||||
)
|
||||
db_session.commit()
|
||||
return _provider_to_view(provider)
|
||||
|
||||
|
||||
@admin_router.post("/providers/{provider_id}/deactivate-tts")
|
||||
def deactivate_tts_provider_endpoint(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderUpdateSuccess:
|
||||
"""Remove the default TTS status from a voice provider."""
|
||||
deactivate_tts_provider(db_session=db_session, provider_id=provider_id)
|
||||
db_session.commit()
|
||||
return VoiceProviderUpdateSuccess()
|
||||
|
||||
|
||||
@admin_router.post("/providers/test")
|
||||
async def test_voice_provider(
|
||||
request: VoiceProviderTestRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceProviderUpdateSuccess:
|
||||
"""Test a voice provider connection by making a real API call."""
|
||||
api_key = request.api_key
|
||||
|
||||
if request.use_stored_key:
|
||||
existing_provider = fetch_voice_provider_by_type(
|
||||
db_session, request.provider_type
|
||||
)
|
||||
if existing_provider is None or not existing_provider.api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No stored API key found for this provider type.",
|
||||
)
|
||||
api_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
|
||||
if not api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"API key is required. Either provide api_key or set use_stored_key to true.",
|
||||
)
|
||||
|
||||
# Use target_uri if provided, otherwise fall back to api_base
|
||||
api_base = _validate_voice_api_base(
|
||||
request.provider_type, request.target_uri or request.api_base
|
||||
)
|
||||
|
||||
# Create a temporary VoiceProvider for testing (not saved to DB)
|
||||
temp_provider = VoiceProvider(
|
||||
name="__test__",
|
||||
provider_type=request.provider_type,
|
||||
api_base=api_base,
|
||||
custom_config=request.custom_config or {},
|
||||
)
|
||||
temp_provider.api_key = api_key # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(temp_provider)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
|
||||
|
||||
# Validate credentials with a real API call
|
||||
try:
|
||||
await provider.validate_credentials()
|
||||
except OnyxError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Voice provider connection test failed: {e}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
str(e),
|
||||
) from e
|
||||
|
||||
logger.info(f"Voice provider test succeeded for {request.provider_type}.")
|
||||
return VoiceProviderUpdateSuccess()
|
||||
|
||||
|
||||
@admin_router.get("/providers/{provider_id}/voices")
|
||||
def get_provider_voices(
|
||||
provider_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[VoiceOption]:
|
||||
"""Get available voices for a provider."""
|
||||
provider_db = fetch_voice_provider_by_id(db_session, provider_id)
|
||||
if provider_db is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Voice provider not found.")
|
||||
|
||||
if not provider_db.api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR, "Provider has no API key configured."
|
||||
)
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
|
||||
|
||||
return [VoiceOption(**voice) for voice in provider.get_available_voices()]
|
||||
|
||||
|
||||
@admin_router.get("/voices")
|
||||
def get_voices_by_type(
|
||||
provider_type: str,
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> list[VoiceOption]:
|
||||
"""Get available voices for a provider type.
|
||||
|
||||
For providers like ElevenLabs and OpenAI, this fetches voices
|
||||
without requiring an existing provider configuration.
|
||||
"""
|
||||
# Create a temporary VoiceProvider to get static voice list
|
||||
temp_provider = VoiceProvider(
|
||||
name="__temp__",
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(temp_provider)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(exc)) from exc
|
||||
|
||||
return [VoiceOption(**voice) for voice in provider.get_available_voices()]
|
||||
@@ -1,95 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class VoiceProviderView(BaseModel):
|
||||
"""Response model for voice provider listing."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
provider_type: str # "openai", "azure", "elevenlabs"
|
||||
is_default_stt: bool
|
||||
is_default_tts: bool
|
||||
stt_model: str | None
|
||||
tts_model: str | None
|
||||
default_voice: str | None
|
||||
has_api_key: bool = Field(
|
||||
default=False,
|
||||
description="Indicates whether an API key is stored for this provider.",
|
||||
)
|
||||
target_uri: str | None = Field(
|
||||
default=None,
|
||||
description="Target URI for Azure Speech Services.",
|
||||
)
|
||||
|
||||
|
||||
class VoiceProviderUpdateSuccess(BaseModel):
|
||||
"""Simple status response for voice provider actions."""
|
||||
|
||||
status: str = "ok"
|
||||
|
||||
|
||||
class VoiceOption(BaseModel):
|
||||
"""Voice option returned by voice providers."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class VoiceProviderUpsertRequest(BaseModel):
|
||||
"""Request model for creating or updating a voice provider."""
|
||||
|
||||
id: int | None = Field(default=None, description="Existing provider ID to update.")
|
||||
name: str
|
||||
provider_type: str # "openai", "azure", "elevenlabs"
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for the provider.",
|
||||
)
|
||||
api_key_changed: bool = Field(
|
||||
default=False,
|
||||
description="Set to true when providing a new API key for an existing provider.",
|
||||
)
|
||||
llm_provider_id: int | None = Field(
|
||||
default=None,
|
||||
description="If set, copies the API key from the specified LLM provider.",
|
||||
)
|
||||
api_base: str | None = None
|
||||
target_uri: str | None = Field(
|
||||
default=None,
|
||||
description="Target URI for Azure Speech Services (maps to api_base).",
|
||||
)
|
||||
custom_config: dict[str, Any] | None = None
|
||||
stt_model: str | None = None
|
||||
tts_model: str | None = None
|
||||
default_voice: str | None = None
|
||||
activate_stt: bool = Field(
|
||||
default=False,
|
||||
description="If true, sets this provider as the default STT provider after upsert.",
|
||||
)
|
||||
activate_tts: bool = Field(
|
||||
default=False,
|
||||
description="If true, sets this provider as the default TTS provider after upsert.",
|
||||
)
|
||||
|
||||
|
||||
class VoiceProviderTestRequest(BaseModel):
|
||||
"""Request model for testing a voice provider connection."""
|
||||
|
||||
provider_type: str
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for testing. If not provided, use_stored_key must be true.",
|
||||
)
|
||||
use_stored_key: bool = Field(
|
||||
default=False,
|
||||
description="If true, use the stored API key for this provider type.",
|
||||
)
|
||||
api_base: str | None = None
|
||||
target_uri: str | None = Field(
|
||||
default=None,
|
||||
description="Target URI for Azure Speech Services (maps to api_base).",
|
||||
)
|
||||
custom_config: dict[str, Any] | None = None
|
||||
@@ -1,251 +0,0 @@
|
||||
import secrets
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import User
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.db.voice import update_user_voice_settings
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.redis.redis_pool import store_ws_token
|
||||
from onyx.redis.redis_pool import WsTokenRateLimitExceeded
|
||||
from onyx.server.manage.models import VoiceSettingsUpdateRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/voice")
|
||||
|
||||
# Max audio file size: 25MB (Whisper limit)
|
||||
MAX_AUDIO_SIZE = 25 * 1024 * 1024
|
||||
# Chunk size for streaming uploads (8KB)
|
||||
UPLOAD_READ_CHUNK_SIZE = 8192
|
||||
|
||||
|
||||
class VoiceStatusResponse(BaseModel):
|
||||
stt_enabled: bool
|
||||
tts_enabled: bool
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
def get_voice_status(
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> VoiceStatusResponse:
|
||||
"""Check whether STT and TTS providers are configured and ready."""
|
||||
stt_provider = fetch_default_stt_provider(db_session)
|
||||
tts_provider = fetch_default_tts_provider(db_session)
|
||||
return VoiceStatusResponse(
|
||||
stt_enabled=stt_provider is not None and stt_provider.api_key is not None,
|
||||
tts_enabled=tts_provider is not None and tts_provider.api_key is not None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/transcribe")
|
||||
async def transcribe_audio(
|
||||
audio: UploadFile = File(...),
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Transcribe audio to text using the default STT provider."""
|
||||
provider_db = fetch_default_stt_provider(db_session)
|
||||
if provider_db is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No speech-to-text provider configured. Please contact your administrator.",
|
||||
)
|
||||
|
||||
if not provider_db.api_key:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Voice provider API key not configured.",
|
||||
)
|
||||
|
||||
# Read in chunks to enforce size limit during streaming (prevents OOM attacks)
|
||||
chunks: list[bytes] = []
|
||||
total = 0
|
||||
while chunk := await audio.read(UPLOAD_READ_CHUNK_SIZE):
|
||||
total += len(chunk)
|
||||
if total > MAX_AUDIO_SIZE:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.PAYLOAD_TOO_LARGE,
|
||||
f"Audio file too large. Maximum size is {MAX_AUDIO_SIZE // (1024 * 1024)}MB.",
|
||||
)
|
||||
chunks.append(chunk)
|
||||
audio_data = b"".join(chunks)
|
||||
|
||||
# Extract format from filename
|
||||
filename = audio.filename or "audio.webm"
|
||||
audio_format = filename.rsplit(".", 1)[-1] if "." in filename else "webm"
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
|
||||
|
||||
try:
|
||||
text = await provider.transcribe(audio_data, audio_format)
|
||||
return {"text": text}
|
||||
except NotImplementedError as exc:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.NOT_IMPLEMENTED,
|
||||
f"Speech-to-text not implemented for {provider_db.provider_type}.",
|
||||
) from exc
|
||||
except Exception as exc:
|
||||
logger.error(f"Transcription failed: {exc}")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Transcription failed. Please try again.",
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post("/synthesize")
|
||||
async def synthesize_speech(
|
||||
text: str | None = Query(
|
||||
default=None, description="Text to synthesize", max_length=4096
|
||||
),
|
||||
voice: str | None = Query(default=None, description="Voice ID to use"),
|
||||
speed: float | None = Query(
|
||||
default=None, description="Playback speed (0.5-2.0)", ge=0.5, le=2.0
|
||||
),
|
||||
user: User = Depends(current_user),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Synthesize text to speech using the default TTS provider.
|
||||
|
||||
Accepts parameters via query string for streaming compatibility.
|
||||
"""
|
||||
logger.info(
|
||||
f"TTS request: text length={len(text) if text else 0}, voice={voice}, speed={speed}"
|
||||
)
|
||||
|
||||
if not text:
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Text is required")
|
||||
|
||||
# Use short-lived session to fetch provider config, then release connection
|
||||
# before starting the long-running streaming response
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
provider_db = fetch_default_tts_provider(db_session)
|
||||
if provider_db is None:
|
||||
logger.error("No TTS provider configured")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"No text-to-speech provider configured. Please contact your administrator.",
|
||||
)
|
||||
|
||||
if not provider_db.api_key:
|
||||
logger.error("TTS provider has no API key")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
"Voice provider API key not configured.",
|
||||
)
|
||||
|
||||
# Use request voice or provider default
|
||||
final_voice = voice or provider_db.default_voice
|
||||
# Use explicit None checks to avoid falsy float issues (0.0 would be skipped with `or`)
|
||||
final_speed = (
|
||||
speed
|
||||
if speed is not None
|
||||
else (
|
||||
user.voice_playback_speed
|
||||
if user.voice_playback_speed is not None
|
||||
else 1.0
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"TTS using provider: {provider_db.provider_type}, voice: {final_voice}, speed: {final_speed}"
|
||||
)
|
||||
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
except ValueError as exc:
|
||||
logger.error(f"Failed to get voice provider: {exc}")
|
||||
raise OnyxError(OnyxErrorCode.INTERNAL_ERROR, str(exc)) from exc
|
||||
|
||||
# Session is now closed - streaming response won't hold DB connection
|
||||
async def audio_stream() -> AsyncIterator[bytes]:
|
||||
try:
|
||||
chunk_count = 0
|
||||
async for chunk in provider.synthesize_stream(
|
||||
text=text, voice=final_voice, speed=final_speed
|
||||
):
|
||||
chunk_count += 1
|
||||
yield chunk
|
||||
logger.info(f"TTS streaming complete: {chunk_count} chunks sent")
|
||||
except NotImplementedError as exc:
|
||||
logger.error(f"TTS not implemented: {exc}")
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Synthesis failed: {exc}")
|
||||
raise
|
||||
|
||||
return StreamingResponse(
|
||||
audio_stream(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Content-Disposition": "inline; filename=speech.mp3",
|
||||
# Allow streaming by not setting content-length
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/settings")
|
||||
def update_voice_settings(
|
||||
request: VoiceSettingsUpdateRequest,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Update user's voice settings."""
|
||||
update_user_voice_settings(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
auto_send=request.auto_send,
|
||||
auto_playback=request.auto_playback,
|
||||
playback_speed=request.playback_speed,
|
||||
)
|
||||
db_session.commit()
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
class WSTokenResponse(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
@router.post("/ws-token")
|
||||
async def get_ws_token(
|
||||
user: User = Depends(current_user),
|
||||
) -> WSTokenResponse:
|
||||
"""
|
||||
Generate a short-lived token for WebSocket authentication.
|
||||
|
||||
This token should be passed as a query parameter when connecting
|
||||
to voice WebSocket endpoints (e.g., /voice/transcribe/stream?token=xxx).
|
||||
|
||||
The token expires after 60 seconds and is single-use.
|
||||
Rate limited to 10 tokens per minute per user.
|
||||
"""
|
||||
token = secrets.token_urlsafe(32)
|
||||
try:
|
||||
await store_ws_token(token, str(user.id))
|
||||
except WsTokenRateLimitExceeded:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.RATE_LIMITED,
|
||||
"Too many token requests. Please wait before requesting another.",
|
||||
)
|
||||
return WSTokenResponse(token=token)
|
||||
@@ -1,860 +0,0 @@
|
||||
"""WebSocket API for streaming speech-to-text and text-to-speech."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import WebSocket
|
||||
from fastapi import WebSocketDisconnect
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user_from_websocket
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
from onyx.db.models import User
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.factory import get_voice_provider
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/voice")
|
||||
|
||||
|
||||
# Transcribe every ~0.5 seconds of audio (webm/opus is ~2-4KB/s, so ~1-2KB per 0.5s)
|
||||
MIN_CHUNK_BYTES = 1500
|
||||
VOICE_DISABLE_STREAMING_FALLBACK = (
|
||||
os.environ.get("VOICE_DISABLE_STREAMING_FALLBACK", "").lower() == "true"
|
||||
)
|
||||
|
||||
# WebSocket size limits to prevent memory exhaustion attacks
|
||||
WS_MAX_MESSAGE_SIZE = 64 * 1024 # 64KB per message (OWASP recommendation)
|
||||
WS_MAX_TOTAL_BYTES = 25 * 1024 * 1024 # 25MB total per connection (matches REST API)
|
||||
WS_MAX_TEXT_MESSAGE_SIZE = 16 * 1024 # 16KB for text/JSON messages
|
||||
WS_MAX_TTS_TEXT_LENGTH = 4096 # Max text length per synthesize call (matches REST API)
|
||||
|
||||
|
||||
class ChunkedTranscriber:
|
||||
"""Fallback transcriber for providers without streaming support."""
|
||||
|
||||
def __init__(self, provider: Any, audio_format: str = "webm"):
|
||||
self.provider = provider
|
||||
self.audio_format = audio_format
|
||||
self.chunk_buffer = io.BytesIO()
|
||||
self.full_audio = io.BytesIO()
|
||||
self.chunk_bytes = 0
|
||||
self.transcripts: list[str] = []
|
||||
|
||||
async def add_chunk(self, chunk: bytes) -> str | None:
|
||||
"""Add audio chunk. Returns transcript if enough audio accumulated."""
|
||||
self.chunk_buffer.write(chunk)
|
||||
self.full_audio.write(chunk)
|
||||
self.chunk_bytes += len(chunk)
|
||||
|
||||
if self.chunk_bytes >= MIN_CHUNK_BYTES:
|
||||
return await self._transcribe_chunk()
|
||||
return None
|
||||
|
||||
async def _transcribe_chunk(self) -> str | None:
|
||||
"""Transcribe current chunk and append to running transcript."""
|
||||
audio_data = self.chunk_buffer.getvalue()
|
||||
if not audio_data:
|
||||
return None
|
||||
|
||||
try:
|
||||
transcript = await self.provider.transcribe(audio_data, self.audio_format)
|
||||
self.chunk_buffer = io.BytesIO()
|
||||
self.chunk_bytes = 0
|
||||
|
||||
if transcript and transcript.strip():
|
||||
self.transcripts.append(transcript.strip())
|
||||
return " ".join(self.transcripts)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription error: {e}")
|
||||
self.chunk_buffer = io.BytesIO()
|
||||
self.chunk_bytes = 0
|
||||
return None
|
||||
|
||||
async def flush(self) -> str:
|
||||
"""Get final transcript from full audio for best accuracy."""
|
||||
full_audio_data = self.full_audio.getvalue()
|
||||
if full_audio_data:
|
||||
try:
|
||||
transcript = await self.provider.transcribe(
|
||||
full_audio_data, self.audio_format
|
||||
)
|
||||
if transcript and transcript.strip():
|
||||
return transcript.strip()
|
||||
except Exception as e:
|
||||
logger.error(f"Final transcription error: {e}")
|
||||
return " ".join(self.transcripts)
|
||||
|
||||
|
||||
async def handle_streaming_transcription(
|
||||
websocket: WebSocket,
|
||||
transcriber: StreamingTranscriberProtocol,
|
||||
) -> None:
|
||||
"""Handle transcription using native streaming API."""
|
||||
logger.info("Streaming transcription: starting handler")
|
||||
last_transcript = ""
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
|
||||
async def receive_transcripts() -> None:
|
||||
"""Background task to receive and send transcripts."""
|
||||
nonlocal last_transcript
|
||||
logger.info("Streaming transcription: starting transcript receiver")
|
||||
while True:
|
||||
result: TranscriptResult | None = await transcriber.receive_transcript()
|
||||
if result is None: # End of stream
|
||||
logger.info("Streaming transcription: transcript stream ended")
|
||||
break
|
||||
# Send if text changed OR if VAD detected end of speech (for auto-send trigger)
|
||||
if result.text and (result.text != last_transcript or result.is_vad_end):
|
||||
last_transcript = result.text
|
||||
logger.debug(
|
||||
f"Streaming transcription: got transcript: {result.text[:50]}... "
|
||||
f"(is_vad_end={result.is_vad_end})"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "transcript",
|
||||
"text": result.text,
|
||||
"is_final": result.is_vad_end,
|
||||
}
|
||||
)
|
||||
|
||||
# Start receiving transcripts in background
|
||||
receive_task = asyncio.create_task(receive_transcripts())
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive()
|
||||
msg_type = message.get("type", "unknown")
|
||||
|
||||
if msg_type == "websocket.disconnect":
|
||||
logger.info(
|
||||
f"Streaming transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
|
||||
)
|
||||
break
|
||||
|
||||
if "bytes" in message:
|
||||
chunk_size = len(message["bytes"])
|
||||
|
||||
# Enforce per-message size limit
|
||||
if chunk_size > WS_MAX_MESSAGE_SIZE:
|
||||
logger.warning(
|
||||
f"Streaming transcription: message too large ({chunk_size} bytes)"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "Message too large"}
|
||||
)
|
||||
break
|
||||
|
||||
# Enforce total connection size limit
|
||||
if total_bytes + chunk_size > WS_MAX_TOTAL_BYTES:
|
||||
logger.warning(
|
||||
f"Streaming transcription: total size limit exceeded ({total_bytes + chunk_size} bytes)"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "Total size limit exceeded"}
|
||||
)
|
||||
break
|
||||
|
||||
chunk_count += 1
|
||||
total_bytes += chunk_size
|
||||
logger.debug(
|
||||
f"Streaming transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
|
||||
)
|
||||
await transcriber.send_audio(message["bytes"])
|
||||
|
||||
elif "text" in message:
|
||||
try:
|
||||
data = json.loads(message["text"])
|
||||
logger.debug(
|
||||
f"Streaming transcription: received text message: {data}"
|
||||
)
|
||||
if data.get("type") == "end":
|
||||
logger.info(
|
||||
"Streaming transcription: end signal received, closing transcriber"
|
||||
)
|
||||
final_transcript = await transcriber.close()
|
||||
receive_task.cancel()
|
||||
logger.info(
|
||||
"Streaming transcription: final transcript: "
|
||||
f"{final_transcript[:100] if final_transcript else '(empty)'}..."
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "transcript",
|
||||
"text": final_transcript,
|
||||
"is_final": True,
|
||||
}
|
||||
)
|
||||
break
|
||||
elif data.get("type") == "reset":
|
||||
# Reset accumulated transcript after auto-send
|
||||
logger.info(
|
||||
"Streaming transcription: reset signal received, clearing transcript"
|
||||
)
|
||||
transcriber.reset_transcript()
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Streaming transcription: failed to parse JSON: {message.get('text', '')[:100]}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming transcription: error: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
receive_task.cancel()
|
||||
try:
|
||||
await receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info(
|
||||
f"Streaming transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
|
||||
)
|
||||
|
||||
|
||||
async def handle_chunked_transcription(
|
||||
websocket: WebSocket,
|
||||
transcriber: ChunkedTranscriber,
|
||||
) -> None:
|
||||
"""Handle transcription using chunked batch API."""
|
||||
logger.info("Chunked transcription: starting handler")
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
|
||||
while True:
|
||||
message = await websocket.receive()
|
||||
msg_type = message.get("type", "unknown")
|
||||
|
||||
if msg_type == "websocket.disconnect":
|
||||
logger.info(
|
||||
f"Chunked transcription: client disconnected after {chunk_count} chunks ({total_bytes} bytes)"
|
||||
)
|
||||
break
|
||||
|
||||
if "bytes" in message:
|
||||
chunk_size = len(message["bytes"])
|
||||
|
||||
# Enforce per-message size limit
|
||||
if chunk_size > WS_MAX_MESSAGE_SIZE:
|
||||
logger.warning(
|
||||
f"Chunked transcription: message too large ({chunk_size} bytes)"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "Message too large"}
|
||||
)
|
||||
break
|
||||
|
||||
# Enforce total connection size limit
|
||||
if total_bytes + chunk_size > WS_MAX_TOTAL_BYTES:
|
||||
logger.warning(
|
||||
f"Chunked transcription: total size limit exceeded ({total_bytes + chunk_size} bytes)"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "Total size limit exceeded"}
|
||||
)
|
||||
break
|
||||
|
||||
chunk_count += 1
|
||||
total_bytes += chunk_size
|
||||
logger.debug(
|
||||
f"Chunked transcription: received chunk {chunk_count} ({chunk_size} bytes, total: {total_bytes})"
|
||||
)
|
||||
|
||||
transcript = await transcriber.add_chunk(message["bytes"])
|
||||
if transcript:
|
||||
logger.debug(
|
||||
f"Chunked transcription: got transcript: {transcript[:50]}..."
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "transcript",
|
||||
"text": transcript,
|
||||
"is_final": False,
|
||||
}
|
||||
)
|
||||
|
||||
elif "text" in message:
|
||||
try:
|
||||
data = json.loads(message["text"])
|
||||
logger.debug(f"Chunked transcription: received text message: {data}")
|
||||
if data.get("type") == "end":
|
||||
logger.info("Chunked transcription: end signal received, flushing")
|
||||
final_transcript = await transcriber.flush()
|
||||
logger.info(
|
||||
f"Chunked transcription: final transcript: {final_transcript[:100] if final_transcript else '(empty)'}..."
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "transcript",
|
||||
"text": final_transcript,
|
||||
"is_final": True,
|
||||
}
|
||||
)
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Chunked transcription: failed to parse JSON: {message.get('text', '')[:100]}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Chunked transcription: handler finished. Processed {chunk_count} chunks, {total_bytes} total bytes"
|
||||
)
|
||||
|
||||
|
||||
@router.websocket("/transcribe/stream")
|
||||
async def websocket_transcribe(
|
||||
websocket: WebSocket,
|
||||
_user: User = Depends(current_user_from_websocket),
|
||||
) -> None:
|
||||
"""
|
||||
WebSocket endpoint for streaming speech-to-text.
|
||||
|
||||
Protocol:
|
||||
- Client sends binary audio chunks
|
||||
- Server sends JSON: {"type": "transcript", "text": "...", "is_final": false}
|
||||
- Client sends JSON {"type": "end"} to signal end
|
||||
- Server responds with final transcript and closes
|
||||
|
||||
Authentication:
|
||||
Requires `token` query parameter (e.g., /voice/transcribe/stream?token=xxx).
|
||||
Applies same auth checks as HTTP endpoints (verification, role checks).
|
||||
"""
|
||||
logger.info("WebSocket transcribe: connection request received (authenticated)")
|
||||
|
||||
try:
|
||||
await websocket.accept()
|
||||
logger.info("WebSocket transcribe: connection accepted")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket transcribe: failed to accept connection: {e}")
|
||||
return
|
||||
|
||||
streaming_transcriber = None
|
||||
provider = None
|
||||
|
||||
try:
|
||||
# Get STT provider
|
||||
logger.info("WebSocket transcribe: fetching STT provider from database")
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
provider_db = fetch_default_stt_provider(db_session)
|
||||
if provider_db is None:
|
||||
logger.warning(
|
||||
"WebSocket transcribe: no default STT provider configured"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "No speech-to-text provider configured",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
if not provider_db.api_key:
|
||||
logger.warning("WebSocket transcribe: STT provider has no API key")
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Speech-to-text provider has no API key configured",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"WebSocket transcribe: creating voice provider: {provider_db.provider_type}"
|
||||
)
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
logger.info(
|
||||
f"WebSocket transcribe: voice provider created, streaming supported: {provider.supports_streaming_stt()}"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"WebSocket transcribe: failed to create voice provider: {e}"
|
||||
)
|
||||
await websocket.send_json({"type": "error", "message": str(e)})
|
||||
return
|
||||
|
||||
# Use native streaming if provider supports it
|
||||
if provider.supports_streaming_stt():
|
||||
logger.info("WebSocket transcribe: using native streaming STT")
|
||||
try:
|
||||
streaming_transcriber = await provider.create_streaming_transcriber()
|
||||
logger.info(
|
||||
"WebSocket transcribe: streaming transcriber created successfully"
|
||||
)
|
||||
await handle_streaming_transcription(websocket, streaming_transcriber)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"WebSocket transcribe: failed to create streaming transcriber: {e}"
|
||||
)
|
||||
if VOICE_DISABLE_STREAMING_FALLBACK:
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": f"Streaming STT failed: {e}"}
|
||||
)
|
||||
return
|
||||
logger.info("WebSocket transcribe: falling back to chunked STT")
|
||||
# Browser stream provides raw PCM16 chunks over WebSocket.
|
||||
chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
|
||||
await handle_chunked_transcription(websocket, chunked_transcriber)
|
||||
else:
|
||||
# Fall back to chunked transcription
|
||||
if VOICE_DISABLE_STREAMING_FALLBACK:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Provider doesn't support streaming STT",
|
||||
}
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"WebSocket transcribe: using chunked STT (provider doesn't support streaming)"
|
||||
)
|
||||
chunked_transcriber = ChunkedTranscriber(provider, audio_format="pcm16")
|
||||
await handle_chunked_transcription(websocket, chunked_transcriber)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.debug("WebSocket transcribe: client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket transcribe: unhandled error: {e}", exc_info=True)
|
||||
try:
|
||||
# Send generic error to avoid leaking sensitive details
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "An unexpected error occurred"}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
if streaming_transcriber:
|
||||
try:
|
||||
await streaming_transcriber.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("WebSocket transcribe: connection closed")
|
||||
|
||||
|
||||
async def handle_streaming_synthesis(
|
||||
websocket: WebSocket,
|
||||
synthesizer: StreamingSynthesizerProtocol,
|
||||
) -> None:
|
||||
"""Handle TTS using native streaming API."""
|
||||
logger.info("Streaming synthesis: starting handler")
|
||||
|
||||
async def send_audio() -> None:
|
||||
"""Background task to send audio chunks to client."""
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
try:
|
||||
while True:
|
||||
audio_chunk = await synthesizer.receive_audio()
|
||||
if audio_chunk is None:
|
||||
logger.info(
|
||||
f"Streaming synthesis: audio stream ended, sent {chunk_count} chunks, {total_bytes} bytes"
|
||||
)
|
||||
try:
|
||||
await websocket.send_json({"type": "audio_done"})
|
||||
logger.info("Streaming synthesis: sent audio_done to client")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Streaming synthesis: failed to send audio_done: {e}"
|
||||
)
|
||||
break
|
||||
if audio_chunk: # Skip empty chunks
|
||||
chunk_count += 1
|
||||
total_bytes += len(audio_chunk)
|
||||
try:
|
||||
await websocket.send_bytes(audio_chunk)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Streaming synthesis: failed to send chunk: {e}"
|
||||
)
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
logger.info(
|
||||
f"Streaming synthesis: send_audio cancelled after {chunk_count} chunks"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming synthesis: send_audio error: {e}")
|
||||
|
||||
send_task: asyncio.Task | None = None
|
||||
disconnected = False
|
||||
|
||||
try:
|
||||
while not disconnected:
|
||||
try:
|
||||
message = await websocket.receive()
|
||||
except WebSocketDisconnect:
|
||||
logger.info("Streaming synthesis: client disconnected")
|
||||
break
|
||||
|
||||
msg_type = message.get("type", "unknown") # type: ignore[possibly-undefined]
|
||||
|
||||
if msg_type == "websocket.disconnect":
|
||||
logger.info("Streaming synthesis: client disconnected")
|
||||
disconnected = True
|
||||
break
|
||||
|
||||
if "text" in message:
|
||||
# Enforce text message size limit
|
||||
msg_size = len(message["text"])
|
||||
if msg_size > WS_MAX_TEXT_MESSAGE_SIZE:
|
||||
logger.warning(
|
||||
f"Streaming synthesis: text message too large ({msg_size} bytes)"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "Message too large"}
|
||||
)
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(message["text"])
|
||||
|
||||
if data.get("type") == "synthesize":
|
||||
text = data.get("text", "")
|
||||
# Enforce per-text size limit
|
||||
if len(text) > WS_MAX_TTS_TEXT_LENGTH:
|
||||
logger.warning(
|
||||
f"Streaming synthesis: text too long ({len(text)} chars)"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "Text too long"}
|
||||
)
|
||||
continue
|
||||
if text:
|
||||
# Start audio receiver on first text chunk so playback
|
||||
# can begin before the full assistant response completes.
|
||||
if send_task is None:
|
||||
send_task = asyncio.create_task(send_audio())
|
||||
logger.debug(
|
||||
f"Streaming synthesis: forwarding text chunk ({len(text)} chars)"
|
||||
)
|
||||
await synthesizer.send_text(text)
|
||||
|
||||
elif data.get("type") == "end":
|
||||
logger.info("Streaming synthesis: end signal received")
|
||||
|
||||
# Ensure receiver is active even if no prior text chunks arrived.
|
||||
if send_task is None:
|
||||
send_task = asyncio.create_task(send_audio())
|
||||
|
||||
# Signal end of input
|
||||
if hasattr(synthesizer, "flush"):
|
||||
await synthesizer.flush()
|
||||
|
||||
# Wait for all audio to be sent
|
||||
logger.info(
|
||||
"Streaming synthesis: waiting for audio stream to complete"
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(send_task, timeout=60.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"Streaming synthesis: timeout waiting for audio"
|
||||
)
|
||||
break
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f"Streaming synthesis: failed to parse JSON: {message.get('text', '')[:100]}"
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.debug("Streaming synthesis: client disconnected during synthesis")
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming synthesis: error: {e}", exc_info=True)
|
||||
finally:
|
||||
if send_task and not send_task.done():
|
||||
logger.info("Streaming synthesis: waiting for send_task to finish")
|
||||
try:
|
||||
await asyncio.wait_for(send_task, timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Streaming synthesis: timeout waiting for send_task")
|
||||
send_task.cancel()
|
||||
try:
|
||||
await send_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("Streaming synthesis: handler finished")
|
||||
|
||||
|
||||
async def handle_chunked_synthesis(
|
||||
websocket: WebSocket,
|
||||
provider: Any,
|
||||
first_message: MutableMapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Fallback TTS handler using provider.synthesize_stream.
|
||||
|
||||
Args:
|
||||
websocket: The WebSocket connection
|
||||
provider: Voice provider instance
|
||||
first_message: Optional first message already received (used when falling
|
||||
back from streaming mode, where the first message was already consumed)
|
||||
"""
|
||||
logger.info("Chunked synthesis: starting handler")
|
||||
text_buffer: list[str] = []
|
||||
voice: str | None = None
|
||||
speed = 1.0
|
||||
|
||||
# Process pre-received message if provided
|
||||
pending_message = first_message
|
||||
|
||||
try:
|
||||
while True:
|
||||
if pending_message is not None:
|
||||
message = pending_message
|
||||
pending_message = None
|
||||
else:
|
||||
message = await websocket.receive()
|
||||
msg_type = message.get("type", "unknown")
|
||||
|
||||
if msg_type == "websocket.disconnect":
|
||||
logger.info("Chunked synthesis: client disconnected")
|
||||
break
|
||||
|
||||
if "text" not in message:
|
||||
continue
|
||||
|
||||
# Enforce text message size limit
|
||||
msg_size = len(message["text"])
|
||||
if msg_size > WS_MAX_TEXT_MESSAGE_SIZE:
|
||||
logger.warning(
|
||||
f"Chunked synthesis: text message too large ({msg_size} bytes)"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "Message too large"}
|
||||
)
|
||||
break
|
||||
|
||||
try:
|
||||
data = json.loads(message["text"])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
"Chunked synthesis: failed to parse JSON: "
|
||||
f"{message.get('text', '')[:100]}"
|
||||
)
|
||||
continue
|
||||
|
||||
msg_data_type = data.get("type") # type: ignore[possibly-undefined]
|
||||
if msg_data_type == "synthesize":
|
||||
text = data.get("text", "")
|
||||
# Enforce per-text size limit
|
||||
if len(text) > WS_MAX_TTS_TEXT_LENGTH:
|
||||
logger.warning(
|
||||
f"Chunked synthesis: text too long ({len(text)} chars)"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "Text too long"}
|
||||
)
|
||||
continue
|
||||
if text:
|
||||
text_buffer.append(text)
|
||||
logger.debug(
|
||||
f"Chunked synthesis: buffered text ({len(text)} chars), "
|
||||
f"total buffered: {len(text_buffer)} chunks"
|
||||
)
|
||||
if isinstance(data.get("voice"), str) and data["voice"]:
|
||||
voice = data["voice"]
|
||||
if isinstance(data.get("speed"), (int, float)):
|
||||
speed = float(data["speed"])
|
||||
elif msg_data_type == "end":
|
||||
logger.info("Chunked synthesis: end signal received")
|
||||
full_text = " ".join(text_buffer).strip()
|
||||
if not full_text:
|
||||
await websocket.send_json({"type": "audio_done"})
|
||||
logger.info("Chunked synthesis: no text, sent audio_done")
|
||||
break
|
||||
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
logger.info(
|
||||
f"Chunked synthesis: sending full text ({len(full_text)} chars)"
|
||||
)
|
||||
async for audio_chunk in provider.synthesize_stream(
|
||||
full_text, voice=voice, speed=speed
|
||||
):
|
||||
if not audio_chunk:
|
||||
continue
|
||||
chunk_count += 1
|
||||
total_bytes += len(audio_chunk)
|
||||
await websocket.send_bytes(audio_chunk)
|
||||
await websocket.send_json({"type": "audio_done"})
|
||||
logger.info(
|
||||
f"Chunked synthesis: sent audio_done after {chunk_count} chunks, {total_bytes} bytes"
|
||||
)
|
||||
break
|
||||
except WebSocketDisconnect:
|
||||
logger.debug("Chunked synthesis: client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"Chunked synthesis: error: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
logger.info("Chunked synthesis: handler finished")
|
||||
|
||||
|
||||
@router.websocket("/synthesize/stream")
|
||||
async def websocket_synthesize(
|
||||
websocket: WebSocket,
|
||||
_user: User = Depends(current_user_from_websocket),
|
||||
) -> None:
|
||||
"""
|
||||
WebSocket endpoint for streaming text-to-speech.
|
||||
|
||||
Protocol:
|
||||
- Client sends JSON: {"type": "synthesize", "text": "...", "voice": "...", "speed": 1.0}
|
||||
- Server sends binary audio chunks
|
||||
- Server sends JSON: {"type": "audio_done"} when synthesis completes
|
||||
- Client sends JSON {"type": "end"} to close connection
|
||||
|
||||
Authentication:
|
||||
Requires `token` query parameter (e.g., /voice/synthesize/stream?token=xxx).
|
||||
Applies same auth checks as HTTP endpoints (verification, role checks).
|
||||
"""
|
||||
logger.info("WebSocket synthesize: connection request received (authenticated)")
|
||||
|
||||
try:
|
||||
await websocket.accept()
|
||||
logger.info("WebSocket synthesize: connection accepted")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket synthesize: failed to accept connection: {e}")
|
||||
return
|
||||
|
||||
streaming_synthesizer: StreamingSynthesizerProtocol | None = None
|
||||
provider = None
|
||||
|
||||
try:
|
||||
# Get TTS provider
|
||||
logger.info("WebSocket synthesize: fetching TTS provider from database")
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
provider_db = fetch_default_tts_provider(db_session)
|
||||
if provider_db is None:
|
||||
logger.warning(
|
||||
"WebSocket synthesize: no default TTS provider configured"
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "No text-to-speech provider configured",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
if not provider_db.api_key:
|
||||
logger.warning("WebSocket synthesize: TTS provider has no API key")
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Text-to-speech provider has no API key configured",
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"WebSocket synthesize: creating voice provider: {provider_db.provider_type}"
|
||||
)
|
||||
try:
|
||||
provider = get_voice_provider(provider_db)
|
||||
logger.info(
|
||||
f"WebSocket synthesize: voice provider created, streaming TTS supported: {provider.supports_streaming_tts()}"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"WebSocket synthesize: failed to create voice provider: {e}"
|
||||
)
|
||||
await websocket.send_json({"type": "error", "message": str(e)})
|
||||
return
|
||||
|
||||
# Use native streaming if provider supports it
|
||||
if provider.supports_streaming_tts():
|
||||
logger.info("WebSocket synthesize: using native streaming TTS")
|
||||
message = None # Initialize to avoid UnboundLocalError in except block
|
||||
try:
|
||||
# Wait for initial config message with voice/speed
|
||||
message = await websocket.receive()
|
||||
voice = None
|
||||
speed = 1.0
|
||||
if "text" in message:
|
||||
try:
|
||||
data = json.loads(message["text"])
|
||||
voice = data.get("voice")
|
||||
speed = data.get("speed", 1.0)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
streaming_synthesizer = await provider.create_streaming_synthesizer(
|
||||
voice=voice, speed=speed
|
||||
)
|
||||
logger.info(
|
||||
"WebSocket synthesize: streaming synthesizer created successfully"
|
||||
)
|
||||
await handle_streaming_synthesis(websocket, streaming_synthesizer)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"WebSocket synthesize: failed to create streaming synthesizer: {e}"
|
||||
)
|
||||
if VOICE_DISABLE_STREAMING_FALLBACK:
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": f"Streaming TTS failed: {e}"}
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"WebSocket synthesize: falling back to chunked TTS synthesis"
|
||||
)
|
||||
# Pass the first message so it's not lost in the fallback
|
||||
await handle_chunked_synthesis(
|
||||
websocket, provider, first_message=message
|
||||
)
|
||||
else:
|
||||
if VOICE_DISABLE_STREAMING_FALLBACK:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "Provider doesn't support streaming TTS",
|
||||
}
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"WebSocket synthesize: using chunked TTS (provider doesn't support streaming)"
|
||||
)
|
||||
await handle_chunked_synthesis(websocket, provider)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.debug("WebSocket synthesize: client disconnected")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket synthesize: unhandled error: {e}", exc_info=True)
|
||||
try:
|
||||
# Send generic error to avoid leaking sensitive details
|
||||
await websocket.send_json(
|
||||
{"type": "error", "message": "An unexpected error occurred"}
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
if streaming_synthesizer:
|
||||
try:
|
||||
await streaming_synthesizer.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("WebSocket synthesize: connection closed")
|
||||
@@ -1,4 +1,3 @@
|
||||
import datetime
|
||||
from typing import Generic
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
@@ -32,41 +31,21 @@ class MinimalUserSnapshot(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class UserGroupInfo(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class FullUserSnapshot(BaseModel):
|
||||
id: UUID
|
||||
email: str
|
||||
role: UserRole
|
||||
is_active: bool
|
||||
password_configured: bool
|
||||
personal_name: str | None
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
groups: list[UserGroupInfo]
|
||||
is_scim_synced: bool
|
||||
|
||||
@classmethod
|
||||
def from_user_model(
|
||||
cls,
|
||||
user: User,
|
||||
groups: list[UserGroupInfo] | None = None,
|
||||
is_scim_synced: bool = False,
|
||||
) -> "FullUserSnapshot":
|
||||
def from_user_model(cls, user: User) -> "FullUserSnapshot":
|
||||
return cls(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
personal_name=user.personal_name,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at,
|
||||
groups=groups or [],
|
||||
is_scim_synced=is_scim_synced,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def shallow_model_dump(model_instance: BaseModel) -> dict[str, Any]:
|
||||
"""Like model_dump(), but returns references to field values instead of
|
||||
deep copies. Use with model_construct() to avoid unnecessary memory
|
||||
duplication when building subclass instances."""
|
||||
return {
|
||||
field_name: getattr(model_instance, field_name)
|
||||
for field_name in model_instance.__class__.model_fields
|
||||
}
|
||||
@@ -140,44 +140,6 @@ def _validate_and_resolve_url(url: str) -> tuple[str, str, int]:
|
||||
return validated_ip, hostname, port
|
||||
|
||||
|
||||
def validate_outbound_http_url(url: str, *, allow_private_network: bool = False) -> str:
|
||||
"""
|
||||
Validate a URL that will be used by backend outbound HTTP calls.
|
||||
|
||||
Returns:
|
||||
A normalized URL string with surrounding whitespace removed.
|
||||
|
||||
Raises:
|
||||
ValueError: If URL is malformed.
|
||||
SSRFException: If URL fails SSRF checks.
|
||||
"""
|
||||
normalized_url = url.strip()
|
||||
if not normalized_url:
|
||||
raise ValueError("URL cannot be empty")
|
||||
|
||||
parsed = urlparse(normalized_url)
|
||||
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise SSRFException(
|
||||
f"Invalid URL scheme '{parsed.scheme}'. Only http and https are allowed."
|
||||
)
|
||||
|
||||
if not parsed.hostname:
|
||||
raise ValueError("URL must contain a hostname")
|
||||
|
||||
if parsed.username or parsed.password:
|
||||
raise SSRFException("URLs with embedded credentials are not allowed.")
|
||||
|
||||
hostname = parsed.hostname.lower()
|
||||
if hostname in BLOCKED_HOSTNAMES:
|
||||
raise SSRFException(f"Access to hostname '{parsed.hostname}' is not allowed.")
|
||||
|
||||
if not allow_private_network:
|
||||
_validate_and_resolve_url(normalized_url)
|
||||
|
||||
return normalized_url
|
||||
|
||||
|
||||
MAX_REDIRECTS = 10
|
||||
|
||||
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
|
||||
def get_voice_provider(provider: VoiceProvider) -> VoiceProviderInterface:
|
||||
"""
|
||||
Factory function to get the appropriate voice provider implementation.
|
||||
|
||||
Args:
|
||||
provider: VoiceProvider model instance (can be from DB or constructed temporarily)
|
||||
|
||||
Returns:
|
||||
VoiceProviderInterface implementation
|
||||
|
||||
Raises:
|
||||
ValueError: If provider_type is not supported
|
||||
"""
|
||||
provider_type = provider.provider_type.lower()
|
||||
|
||||
# Handle both SensitiveValue (from DB) and plain string (from temp model)
|
||||
if provider.api_key is None:
|
||||
api_key = None
|
||||
elif hasattr(provider.api_key, "get_value"):
|
||||
# SensitiveValue from database
|
||||
api_key = provider.api_key.get_value(apply_mask=False)
|
||||
else:
|
||||
# Plain string from temporary model
|
||||
api_key = provider.api_key # type: ignore[assignment]
|
||||
api_base = provider.api_base
|
||||
custom_config = provider.custom_config
|
||||
stt_model = provider.stt_model
|
||||
tts_model = provider.tts_model
|
||||
default_voice = provider.default_voice
|
||||
|
||||
if provider_type == "openai":
|
||||
from onyx.voice.providers.openai import OpenAIVoiceProvider
|
||||
|
||||
return OpenAIVoiceProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
stt_model=stt_model,
|
||||
tts_model=tts_model,
|
||||
default_voice=default_voice,
|
||||
)
|
||||
|
||||
elif provider_type == "azure":
|
||||
from onyx.voice.providers.azure import AzureVoiceProvider
|
||||
|
||||
return AzureVoiceProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
custom_config=custom_config or {},
|
||||
stt_model=stt_model,
|
||||
tts_model=tts_model,
|
||||
default_voice=default_voice,
|
||||
)
|
||||
|
||||
elif provider_type == "elevenlabs":
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsVoiceProvider
|
||||
|
||||
return ElevenLabsVoiceProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
stt_model=stt_model,
|
||||
tts_model=tts_model,
|
||||
default_voice=default_voice,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported voice provider type: {provider_type}")
|
||||
@@ -1,182 +0,0 @@
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TranscriptResult(BaseModel):
|
||||
"""Result from streaming transcription."""
|
||||
|
||||
text: str
|
||||
"""The accumulated transcript text."""
|
||||
|
||||
is_vad_end: bool = False
|
||||
"""True if VAD detected end of speech (silence). Use for auto-send."""
|
||||
|
||||
|
||||
class StreamingTranscriberProtocol(Protocol):
|
||||
"""Protocol for streaming transcription sessions."""
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send an audio chunk for transcription."""
|
||||
...
|
||||
|
||||
async def receive_transcript(self) -> TranscriptResult | None:
|
||||
"""
|
||||
Receive next transcript update.
|
||||
|
||||
Returns:
|
||||
TranscriptResult with accumulated text and VAD status, or None when stream ends.
|
||||
"""
|
||||
...
|
||||
|
||||
async def close(self) -> str:
|
||||
"""Close the session and return final transcript."""
|
||||
...
|
||||
|
||||
def reset_transcript(self) -> None:
|
||||
"""Reset accumulated transcript. Call after auto-send to start fresh."""
|
||||
...
|
||||
|
||||
|
||||
class StreamingSynthesizerProtocol(Protocol):
|
||||
"""Protocol for streaming TTS sessions (real-time text-to-speech)."""
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish connection to TTS provider."""
|
||||
...
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Send text to be synthesized."""
|
||||
...
|
||||
|
||||
async def receive_audio(self) -> bytes | None:
|
||||
"""
|
||||
Receive next audio chunk.
|
||||
|
||||
Returns:
|
||||
Audio bytes, or None when stream ends.
|
||||
"""
|
||||
...
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Signal end of text input and wait for pending audio."""
|
||||
...
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the session."""
|
||||
...
|
||||
|
||||
|
||||
class VoiceProviderInterface(ABC):
|
||||
"""Abstract base class for voice providers (STT and TTS)."""
|
||||
|
||||
@abstractmethod
|
||||
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
|
||||
"""
|
||||
Convert audio to text (Speech-to-Text).
|
||||
|
||||
Args:
|
||||
audio_data: Raw audio bytes
|
||||
audio_format: Audio format (e.g., "webm", "wav", "mp3")
|
||||
|
||||
Returns:
|
||||
Transcribed text
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def synthesize_stream(
|
||||
self, text: str, voice: str | None = None, speed: float = 1.0
|
||||
) -> AsyncIterator[bytes]:
|
||||
"""
|
||||
Convert text to audio stream (Text-to-Speech).
|
||||
|
||||
Streams audio chunks progressively for lower latency playback.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech
|
||||
voice: Voice identifier (e.g., "alloy", "echo"), or None for default
|
||||
speed: Playback speed multiplier (0.25 to 4.0)
|
||||
|
||||
Yields:
|
||||
Audio data chunks
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def validate_credentials(self) -> None:
|
||||
"""
|
||||
Validate that the provider credentials are correct by making a
|
||||
lightweight API call. Raises on failure.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_available_voices(self) -> list[dict[str, str]]:
|
||||
"""
|
||||
Get list of available voices for this provider.
|
||||
|
||||
Returns:
|
||||
List of voice dictionaries with 'id' and 'name' keys
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_available_stt_models(self) -> list[dict[str, str]]:
|
||||
"""
|
||||
Get list of available STT models for this provider.
|
||||
|
||||
Returns:
|
||||
List of model dictionaries with 'id' and 'name' keys
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_available_tts_models(self) -> list[dict[str, str]]:
|
||||
"""
|
||||
Get list of available TTS models for this provider.
|
||||
|
||||
Returns:
|
||||
List of model dictionaries with 'id' and 'name' keys
|
||||
"""
|
||||
|
||||
def supports_streaming_stt(self) -> bool:
|
||||
"""Returns True if this provider supports streaming STT."""
|
||||
return False
|
||||
|
||||
def supports_streaming_tts(self) -> bool:
|
||||
"""Returns True if this provider supports real-time streaming TTS."""
|
||||
return False
|
||||
|
||||
async def create_streaming_transcriber(
|
||||
self, audio_format: str = "webm"
|
||||
) -> StreamingTranscriberProtocol:
|
||||
"""
|
||||
Create a streaming transcription session.
|
||||
|
||||
Args:
|
||||
audio_format: Audio format being sent (e.g., "webm", "pcm16")
|
||||
|
||||
Returns:
|
||||
A streaming transcriber that can send audio chunks and receive transcripts
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If streaming STT is not supported
|
||||
"""
|
||||
raise NotImplementedError("Streaming STT not supported by this provider")
|
||||
|
||||
async def create_streaming_synthesizer(
|
||||
self, voice: str | None = None, speed: float = 1.0
|
||||
) -> "StreamingSynthesizerProtocol":
|
||||
"""
|
||||
Create a streaming TTS session for real-time audio synthesis.
|
||||
|
||||
Args:
|
||||
voice: Voice identifier
|
||||
speed: Playback speed multiplier
|
||||
|
||||
Returns:
|
||||
A streaming synthesizer that can send text and receive audio chunks
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If streaming TTS is not supported
|
||||
"""
|
||||
raise NotImplementedError("Streaming TTS not supported by this provider")
|
||||
@@ -1,625 +0,0 @@
|
||||
"""Azure Speech Services voice provider for STT and TTS.
|
||||
|
||||
Azure supports:
|
||||
- **STT**: Batch transcription via REST API (audio/wav POST) and real-time
|
||||
streaming via the Azure Speech SDK (push audio stream with continuous
|
||||
recognition). The SDK handles VAD natively through its recognizing/recognized
|
||||
events.
|
||||
- **TTS**: SSML-based synthesis via REST API (streaming response) and real-time
|
||||
synthesis via the Speech SDK. Text is escaped with ``xml.sax.saxutils.escape``
|
||||
and attributes with ``quoteattr`` to prevent SSML injection.
|
||||
|
||||
Both modes support Azure cloud endpoints (region-based URLs) and self-hosted
|
||||
Speech containers (custom endpoint URLs). The ``speech_region`` is validated to
|
||||
contain only ``[a-z0-9-]`` to prevent URL injection.
|
||||
|
||||
The Azure Speech SDK (``azure-cognitiveservices-speech``) is an optional C
|
||||
extension dependency — it is imported lazily inside streaming methods so the
|
||||
provider can still be instantiated and used for REST-based operations without it.
|
||||
|
||||
See https://learn.microsoft.com/en-us/azure/cognitive-services/speech-service/
|
||||
for API reference.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import re
|
||||
import struct
|
||||
import wave
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
from xml.sax.saxutils import escape
|
||||
from xml.sax.saxutils import quoteattr
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
# SSML namespace — W3C standard for Speech Synthesis Markup Language.
|
||||
# This is a fixed W3C specification and will not change.
|
||||
SSML_NAMESPACE = "http://www.w3.org/2001/10/synthesis"
|
||||
|
||||
# Common Azure Neural voices
|
||||
AZURE_VOICES = [
|
||||
{"id": "en-US-JennyNeural", "name": "Jenny (en-US, Female)"},
|
||||
{"id": "en-US-GuyNeural", "name": "Guy (en-US, Male)"},
|
||||
{"id": "en-US-AriaNeural", "name": "Aria (en-US, Female)"},
|
||||
{"id": "en-US-DavisNeural", "name": "Davis (en-US, Male)"},
|
||||
{"id": "en-US-AmberNeural", "name": "Amber (en-US, Female)"},
|
||||
{"id": "en-US-AnaNeural", "name": "Ana (en-US, Female)"},
|
||||
{"id": "en-US-BrandonNeural", "name": "Brandon (en-US, Male)"},
|
||||
{"id": "en-US-ChristopherNeural", "name": "Christopher (en-US, Male)"},
|
||||
{"id": "en-US-CoraNeural", "name": "Cora (en-US, Female)"},
|
||||
{"id": "en-GB-SoniaNeural", "name": "Sonia (en-GB, Female)"},
|
||||
{"id": "en-GB-RyanNeural", "name": "Ryan (en-GB, Male)"},
|
||||
]
|
||||
|
||||
|
||||
class AzureStreamingTranscriber(StreamingTranscriberProtocol):
|
||||
"""Streaming transcription using Azure Speech SDK."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
region: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
input_sample_rate: int = 24000,
|
||||
target_sample_rate: int = 16000,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.region = region
|
||||
self.endpoint = endpoint
|
||||
self.input_sample_rate = input_sample_rate
|
||||
self.target_sample_rate = target_sample_rate
|
||||
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
|
||||
self._accumulated_transcript = ""
|
||||
self._recognizer: Any = None
|
||||
self._audio_stream: Any = None
|
||||
self._closed = False
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize Azure Speech recognizer with push stream."""
|
||||
try:
|
||||
import azure.cognitiveservices.speech as speechsdk # type: ignore
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Azure Speech SDK is required for streaming STT. "
|
||||
"Install `azure-cognitiveservices-speech`."
|
||||
) from e
|
||||
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
# Use endpoint for self-hosted containers, region for Azure cloud
|
||||
if self.endpoint:
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key,
|
||||
endpoint=self.endpoint,
|
||||
)
|
||||
else:
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key,
|
||||
region=self.region,
|
||||
)
|
||||
|
||||
audio_format = speechsdk.audio.AudioStreamFormat(
|
||||
samples_per_second=16000,
|
||||
bits_per_sample=16,
|
||||
channels=1,
|
||||
)
|
||||
self._audio_stream = speechsdk.audio.PushAudioInputStream(audio_format)
|
||||
audio_config = speechsdk.audio.AudioConfig(stream=self._audio_stream)
|
||||
|
||||
self._recognizer = speechsdk.SpeechRecognizer(
|
||||
speech_config=speech_config,
|
||||
audio_config=audio_config,
|
||||
)
|
||||
|
||||
transcriber = self
|
||||
|
||||
def on_recognizing(evt: Any) -> None:
|
||||
if evt.result.text and transcriber._loop and not transcriber._closed:
|
||||
full_text = transcriber._accumulated_transcript
|
||||
if full_text:
|
||||
full_text += " " + evt.result.text
|
||||
else:
|
||||
full_text = evt.result.text
|
||||
transcriber._loop.call_soon_threadsafe(
|
||||
transcriber._transcript_queue.put_nowait,
|
||||
TranscriptResult(text=full_text, is_vad_end=False),
|
||||
)
|
||||
|
||||
def on_recognized(evt: Any) -> None:
|
||||
if evt.result.text and transcriber._loop and not transcriber._closed:
|
||||
if transcriber._accumulated_transcript:
|
||||
transcriber._accumulated_transcript += " " + evt.result.text
|
||||
else:
|
||||
transcriber._accumulated_transcript = evt.result.text
|
||||
transcriber._loop.call_soon_threadsafe(
|
||||
transcriber._transcript_queue.put_nowait,
|
||||
TranscriptResult(
|
||||
text=transcriber._accumulated_transcript, is_vad_end=True
|
||||
),
|
||||
)
|
||||
|
||||
self._recognizer.recognizing.connect(on_recognizing)
|
||||
self._recognizer.recognized.connect(on_recognized)
|
||||
self._recognizer.start_continuous_recognition_async()
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send audio chunk to Azure."""
|
||||
if self._audio_stream and not self._closed:
|
||||
self._audio_stream.write(self._resample_pcm16(chunk))
|
||||
|
||||
def _resample_pcm16(self, data: bytes) -> bytes:
|
||||
"""Resample PCM16 audio from input_sample_rate to target_sample_rate."""
|
||||
if self.input_sample_rate == self.target_sample_rate:
|
||||
return data
|
||||
|
||||
num_samples = len(data) // 2
|
||||
if num_samples == 0:
|
||||
return b""
|
||||
|
||||
samples = list(struct.unpack(f"<{num_samples}h", data))
|
||||
ratio = self.input_sample_rate / self.target_sample_rate
|
||||
new_length = int(num_samples / ratio)
|
||||
|
||||
resampled: list[int] = []
|
||||
for i in range(new_length):
|
||||
src_idx = i * ratio
|
||||
idx_floor = int(src_idx)
|
||||
idx_ceil = min(idx_floor + 1, num_samples - 1)
|
||||
frac = src_idx - idx_floor
|
||||
sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
|
||||
sample = max(-32768, min(32767, sample))
|
||||
resampled.append(sample)
|
||||
|
||||
return struct.pack(f"<{len(resampled)}h", *resampled)
|
||||
|
||||
async def receive_transcript(self) -> TranscriptResult | None:
|
||||
"""Receive next transcript."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return TranscriptResult(text="", is_vad_end=False)
|
||||
|
||||
async def close(self) -> str:
|
||||
"""Stop recognition and return final transcript."""
|
||||
self._closed = True
|
||||
if self._recognizer:
|
||||
self._recognizer.stop_continuous_recognition_async()
|
||||
if self._audio_stream:
|
||||
self._audio_stream.close()
|
||||
self._loop = None
|
||||
return self._accumulated_transcript
|
||||
|
||||
def reset_transcript(self) -> None:
|
||||
"""Reset accumulated transcript."""
|
||||
self._accumulated_transcript = ""
|
||||
|
||||
|
||||
class AzureStreamingSynthesizer(StreamingSynthesizerProtocol):
|
||||
"""Real-time streaming TTS using Azure Speech SDK."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
region: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
voice: str = "en-US-JennyNeural",
|
||||
speed: float = 1.0,
|
||||
):
|
||||
self._logger = setup_logger()
|
||||
self.api_key = api_key
|
||||
self.region = region
|
||||
self.endpoint = endpoint
|
||||
self.voice = voice
|
||||
self.speed = max(0.5, min(2.0, speed))
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._synthesizer: Any = None
|
||||
self._closed = False
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize Azure Speech synthesizer with push stream."""
|
||||
try:
|
||||
import azure.cognitiveservices.speech as speechsdk
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"Azure Speech SDK is required for streaming TTS. "
|
||||
"Install `azure-cognitiveservices-speech`."
|
||||
) from e
|
||||
|
||||
self._logger.info("AzureStreamingSynthesizer: connecting")
|
||||
|
||||
# Store the event loop for thread-safe queue operations
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
# Use endpoint for self-hosted containers, region for Azure cloud
|
||||
if self.endpoint:
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key,
|
||||
endpoint=self.endpoint,
|
||||
)
|
||||
else:
|
||||
speech_config = speechsdk.SpeechConfig(
|
||||
subscription=self.api_key,
|
||||
region=self.region,
|
||||
)
|
||||
speech_config.speech_synthesis_voice_name = self.voice
|
||||
# Use MP3 format for streaming - compatible with MediaSource Extensions
|
||||
speech_config.set_speech_synthesis_output_format(
|
||||
speechsdk.SpeechSynthesisOutputFormat.Audio16Khz64KBitRateMonoMp3
|
||||
)
|
||||
|
||||
# Create synthesizer with pull audio output stream
|
||||
self._synthesizer = speechsdk.SpeechSynthesizer(
|
||||
speech_config=speech_config,
|
||||
audio_config=None, # We'll manually handle audio
|
||||
)
|
||||
|
||||
# Connect to synthesis events
|
||||
self._synthesizer.synthesizing.connect(self._on_synthesizing)
|
||||
self._synthesizer.synthesis_completed.connect(self._on_completed)
|
||||
|
||||
self._logger.info("AzureStreamingSynthesizer: connected")
|
||||
|
||||
def _on_synthesizing(self, evt: Any) -> None:
|
||||
"""Called when audio chunk is available (runs in Azure SDK thread)."""
|
||||
if evt.result.audio_data and self._loop and not self._closed:
|
||||
# Thread-safe way to put item in async queue
|
||||
self._loop.call_soon_threadsafe(
|
||||
self._audio_queue.put_nowait, evt.result.audio_data
|
||||
)
|
||||
|
||||
def _on_completed(self, _evt: Any) -> None:
|
||||
"""Called when synthesis is complete (runs in Azure SDK thread)."""
|
||||
if self._loop and not self._closed:
|
||||
self._loop.call_soon_threadsafe(self._audio_queue.put_nowait, None)
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Send text to be synthesized using SSML for prosody control."""
|
||||
if self._synthesizer and not self._closed:
|
||||
# Build SSML with prosody for speed control
|
||||
rate = f"{int((self.speed - 1) * 100):+d}%"
|
||||
escaped_text = escape(text)
|
||||
ssml = f"""<speak version='1.0' xmlns='{SSML_NAMESPACE}' xml:lang='en-US'>
|
||||
<voice name={quoteattr(self.voice)}>
|
||||
<prosody rate='{rate}'>{escaped_text}</prosody>
|
||||
</voice>
|
||||
</speak>"""
|
||||
# Use speak_ssml_async for SSML support (includes speed/prosody)
|
||||
self._synthesizer.speak_ssml_async(ssml)
|
||||
|
||||
async def receive_audio(self) -> bytes | None:
|
||||
"""Receive next audio chunk."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return b"" # No audio yet, but not done
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Signal end of text input - wait for pending audio."""
|
||||
# Azure SDK handles flushing automatically
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the session."""
|
||||
self._closed = True
|
||||
if self._synthesizer:
|
||||
self._synthesizer.synthesis_completed.disconnect_all()
|
||||
self._synthesizer.synthesizing.disconnect_all()
|
||||
self._loop = None
|
||||
|
||||
|
||||
class AzureVoiceProvider(VoiceProviderInterface):
|
||||
"""Azure Speech Services voice provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None,
|
||||
api_base: str | None,
|
||||
custom_config: dict[str, Any],
|
||||
stt_model: str | None = None,
|
||||
tts_model: str | None = None,
|
||||
default_voice: str | None = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.custom_config = custom_config
|
||||
raw_speech_region = (
|
||||
custom_config.get("speech_region")
|
||||
or self._extract_speech_region_from_uri(api_base)
|
||||
or ""
|
||||
)
|
||||
self.speech_region = self._validate_speech_region(raw_speech_region)
|
||||
self.stt_model = stt_model
|
||||
self.tts_model = tts_model
|
||||
self.default_voice = default_voice or "en-US-JennyNeural"
|
||||
|
||||
@staticmethod
|
||||
def _is_azure_cloud_url(uri: str | None) -> bool:
|
||||
"""Check if URI is an Azure cloud endpoint (vs custom/self-hosted)."""
|
||||
if not uri:
|
||||
return False
|
||||
try:
|
||||
hostname = (urlparse(uri).hostname or "").lower()
|
||||
except ValueError:
|
||||
return False
|
||||
return hostname.endswith(
|
||||
(
|
||||
".speech.microsoft.com",
|
||||
".api.cognitive.microsoft.com",
|
||||
".cognitiveservices.azure.com",
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_speech_region_from_uri(uri: str | None) -> str | None:
|
||||
"""Extract Azure speech region from endpoint URI.
|
||||
|
||||
Note: Custom domains (*.cognitiveservices.azure.com) contain the resource
|
||||
name, not the region. For custom domains, the region must be specified
|
||||
explicitly via custom_config["speech_region"].
|
||||
"""
|
||||
if not uri:
|
||||
return None
|
||||
# Accepted examples:
|
||||
# - https://eastus.tts.speech.microsoft.com/cognitiveservices/v1
|
||||
# - https://eastus.stt.speech.microsoft.com/speech/recognition/...
|
||||
# - https://westus.api.cognitive.microsoft.com/
|
||||
#
|
||||
# NOT supported (requires explicit speech_region config):
|
||||
# - https://<resource>.cognitiveservices.azure.com/ (resource name != region)
|
||||
try:
|
||||
hostname = (urlparse(uri).hostname or "").lower()
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
stt_tts_match = re.match(
|
||||
r"^([a-z0-9-]+)\.(?:tts|stt)\.speech\.microsoft\.com$", hostname
|
||||
)
|
||||
if stt_tts_match:
|
||||
return stt_tts_match.group(1)
|
||||
|
||||
api_match = re.match(
|
||||
r"^([a-z0-9-]+)\.api\.cognitive\.microsoft\.com$", hostname
|
||||
)
|
||||
if api_match:
|
||||
return api_match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _validate_speech_region(speech_region: str) -> str:
|
||||
normalized_region = speech_region.strip().lower()
|
||||
if not normalized_region:
|
||||
return ""
|
||||
if not re.fullmatch(r"[a-z0-9-]+", normalized_region):
|
||||
raise ValueError(
|
||||
"Invalid Azure speech_region. Use lowercase letters, digits, and hyphens only."
|
||||
)
|
||||
return normalized_region
|
||||
|
||||
def _get_stt_url(self) -> str:
|
||||
"""Get the STT endpoint URL (auto-detects cloud vs self-hosted)."""
|
||||
if self.api_base and not self._is_azure_cloud_url(self.api_base):
|
||||
# Self-hosted container endpoint
|
||||
return f"{self.api_base.rstrip('/')}/speech/recognition/conversation/cognitiveservices/v1"
|
||||
# Azure cloud endpoint
|
||||
return (
|
||||
f"https://{self.speech_region}.stt.speech.microsoft.com/"
|
||||
"speech/recognition/conversation/cognitiveservices/v1"
|
||||
)
|
||||
|
||||
def _get_tts_url(self) -> str:
|
||||
"""Get the TTS endpoint URL (auto-detects cloud vs self-hosted)."""
|
||||
if self.api_base and not self._is_azure_cloud_url(self.api_base):
|
||||
# Self-hosted container endpoint
|
||||
return f"{self.api_base.rstrip('/')}/cognitiveservices/v1"
|
||||
# Azure cloud endpoint
|
||||
return f"https://{self.speech_region}.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
|
||||
def _is_self_hosted(self) -> bool:
|
||||
"""Check if using self-hosted container vs Azure cloud."""
|
||||
return bool(self.api_base and not self._is_azure_cloud_url(self.api_base))
|
||||
|
||||
@staticmethod
|
||||
def _pcm16_to_wav(pcm_data: bytes, sample_rate: int = 24000) -> bytes:
|
||||
"""Wrap raw PCM16 mono bytes into a WAV container."""
|
||||
buffer = io.BytesIO()
|
||||
with wave.open(buffer, "wb") as wav_file:
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setframerate(sample_rate)
|
||||
wav_file.writeframes(pcm_data)
|
||||
return buffer.getvalue()
|
||||
|
||||
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure API key required for STT")
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError("Azure speech region required for STT (cloud mode)")
|
||||
|
||||
normalized_format = audio_format.lower()
|
||||
payload = audio_data
|
||||
content_type = f"audio/{normalized_format}"
|
||||
|
||||
# WebSocket chunked fallback sends raw PCM16 bytes.
|
||||
if normalized_format in {"pcm", "pcm16", "raw"}:
|
||||
payload = self._pcm16_to_wav(audio_data, sample_rate=24000)
|
||||
content_type = "audio/wav"
|
||||
elif normalized_format in {"wav", "wave"}:
|
||||
content_type = "audio/wav"
|
||||
elif normalized_format == "webm":
|
||||
content_type = "audio/webm; codecs=opus"
|
||||
|
||||
url = self._get_stt_url()
|
||||
params = {"language": "en-US", "format": "detailed"}
|
||||
headers = {
|
||||
"Ocp-Apim-Subscription-Key": self.api_key,
|
||||
"Content-Type": content_type,
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url, params=params, headers=headers, data=payload
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(f"Azure STT failed: {error_text}")
|
||||
result = await response.json()
|
||||
|
||||
if result.get("RecognitionStatus") != "Success":
|
||||
return ""
|
||||
nbest = result.get("NBest") or []
|
||||
if nbest and isinstance(nbest, list):
|
||||
display = nbest[0].get("Display")
|
||||
if isinstance(display, str):
|
||||
return display
|
||||
display_text = result.get("DisplayText", "")
|
||||
return display_text if isinstance(display_text, str) else ""
|
||||
|
||||
async def synthesize_stream(
|
||||
self, text: str, voice: str | None = None, speed: float = 1.0
|
||||
) -> AsyncIterator[bytes]:
|
||||
"""
|
||||
Convert text to audio using Azure TTS with streaming.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech
|
||||
voice: Voice name (defaults to provider's default voice)
|
||||
speed: Playback speed multiplier (0.5 to 2.0)
|
||||
|
||||
Yields:
|
||||
Audio data chunks (mp3 format)
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure API key required for TTS")
|
||||
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError("Azure speech region required for TTS (cloud mode)")
|
||||
|
||||
voice_name = voice or self.default_voice
|
||||
|
||||
# Clamp speed to valid range and convert to rate format
|
||||
speed = max(0.5, min(2.0, speed))
|
||||
rate = f"{int((speed - 1) * 100):+d}%" # e.g., 1.0 -> "+0%", 1.5 -> "+50%"
|
||||
|
||||
# Build SSML with escaped text and quoted attributes to prevent injection
|
||||
escaped_text = escape(text)
|
||||
ssml = f"""<speak version='1.0' xmlns='{SSML_NAMESPACE}' xml:lang='en-US'>
|
||||
<voice name={quoteattr(voice_name)}>
|
||||
<prosody rate='{rate}'>{escaped_text}</prosody>
|
||||
</voice>
|
||||
</speak>"""
|
||||
|
||||
url = self._get_tts_url()
|
||||
|
||||
headers = {
|
||||
"Ocp-Apim-Subscription-Key": self.api_key,
|
||||
"Content-Type": "application/ssml+xml",
|
||||
"X-Microsoft-OutputFormat": "audio-16khz-128kbitrate-mono-mp3",
|
||||
"User-Agent": "Onyx",
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, data=ssml) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise RuntimeError(f"Azure TTS failed: {error_text}")
|
||||
|
||||
# Use 8192 byte chunks for smoother streaming
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
async def validate_credentials(self) -> None:
|
||||
"""Validate Azure credentials by listing available voices."""
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure API key required")
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError("Azure speech region required (cloud mode)")
|
||||
|
||||
url = f"https://{self.speech_region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
|
||||
if self._is_self_hosted():
|
||||
url = f"{(self.api_base or '').rstrip('/')}/cognitiveservices/voices/list"
|
||||
|
||||
headers = {"Ocp-Apim-Subscription-Key": self.api_key}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers) as response:
|
||||
if response.status in (401, 403):
|
||||
raise RuntimeError("Invalid Azure API key.")
|
||||
if response.status != 200:
|
||||
raise RuntimeError("Azure credential validation failed.")
|
||||
|
||||
def get_available_voices(self) -> list[dict[str, str]]:
|
||||
"""Return common Azure Neural voices."""
|
||||
return AZURE_VOICES.copy()
|
||||
|
||||
def get_available_stt_models(self) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"id": "default", "name": "Azure Speech Recognition"},
|
||||
]
|
||||
|
||||
def get_available_tts_models(self) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"id": "neural", "name": "Neural TTS"},
|
||||
]
|
||||
|
||||
def supports_streaming_stt(self) -> bool:
|
||||
"""Azure supports streaming STT via Speech SDK."""
|
||||
return True
|
||||
|
||||
def supports_streaming_tts(self) -> bool:
|
||||
"""Azure supports real-time streaming TTS via Speech SDK."""
|
||||
return True
|
||||
|
||||
async def create_streaming_transcriber(
|
||||
self, _audio_format: str = "webm"
|
||||
) -> AzureStreamingTranscriber:
|
||||
"""Create a streaming transcription session."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming transcription")
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError(
|
||||
"Speech region required for Azure streaming transcription (cloud mode)"
|
||||
)
|
||||
|
||||
# Use endpoint for self-hosted, region for cloud
|
||||
transcriber = AzureStreamingTranscriber(
|
||||
api_key=self.api_key,
|
||||
region=self.speech_region if not self._is_self_hosted() else None,
|
||||
endpoint=self.api_base if self._is_self_hosted() else None,
|
||||
input_sample_rate=24000,
|
||||
target_sample_rate=16000,
|
||||
)
|
||||
await transcriber.connect()
|
||||
return transcriber
|
||||
|
||||
async def create_streaming_synthesizer(
|
||||
self, voice: str | None = None, speed: float = 1.0
|
||||
) -> AzureStreamingSynthesizer:
|
||||
"""Create a streaming TTS session."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming TTS")
|
||||
if not self._is_self_hosted() and not self.speech_region:
|
||||
raise ValueError(
|
||||
"Speech region required for Azure streaming TTS (cloud mode)"
|
||||
)
|
||||
|
||||
# Use endpoint for self-hosted, region for cloud
|
||||
synthesizer = AzureStreamingSynthesizer(
|
||||
api_key=self.api_key,
|
||||
region=self.speech_region if not self._is_self_hosted() else None,
|
||||
endpoint=self.api_base if self._is_self_hosted() else None,
|
||||
voice=voice or self.default_voice or "en-US-JennyNeural",
|
||||
speed=speed,
|
||||
)
|
||||
await synthesizer.connect()
|
||||
return synthesizer
|
||||
@@ -1,876 +0,0 @@
|
||||
"""ElevenLabs voice provider for STT and TTS.
|
||||
|
||||
ElevenLabs supports:
|
||||
- **STT**: Scribe API (batch via REST, streaming via WebSocket with Scribe v2 Realtime).
|
||||
The streaming endpoint sends base64-encoded PCM16 audio chunks and receives JSON
|
||||
transcript messages (partial_transcript, committed_transcript, utterance_end).
|
||||
- **TTS**: Text-to-speech via REST streaming and WebSocket stream-input.
|
||||
The WebSocket variant accepts incremental text chunks and returns audio in order,
|
||||
enabling low-latency playback before the full text is available.
|
||||
|
||||
See https://elevenlabs.io/docs for API reference.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
# Default ElevenLabs API base URL
|
||||
DEFAULT_ELEVENLABS_API_BASE = "https://api.elevenlabs.io"
|
||||
|
||||
# Default sample rates for STT streaming
|
||||
DEFAULT_INPUT_SAMPLE_RATE = 24000 # What the browser frontend sends
|
||||
DEFAULT_TARGET_SAMPLE_RATE = 16000 # What ElevenLabs Scribe expects
|
||||
|
||||
# Default streaming TTS output format
|
||||
DEFAULT_TTS_OUTPUT_FORMAT = "mp3_44100_64"
|
||||
|
||||
# Default TTS voice settings
|
||||
DEFAULT_VOICE_STABILITY = 0.5
|
||||
DEFAULT_VOICE_SIMILARITY_BOOST = 0.75
|
||||
|
||||
# Chunk length schedule for streaming TTS (optimized for real-time playback)
|
||||
DEFAULT_CHUNK_LENGTH_SCHEDULE = [120, 160, 250, 290]
|
||||
|
||||
# Default STT streaming VAD configuration
|
||||
DEFAULT_VAD_SILENCE_THRESHOLD_SECS = 1.0
|
||||
DEFAULT_VAD_THRESHOLD = 0.4
|
||||
DEFAULT_MIN_SPEECH_DURATION_MS = 100
|
||||
DEFAULT_MIN_SILENCE_DURATION_MS = 300
|
||||
|
||||
|
||||
class ElevenLabsSTTMessageType(StrEnum):
|
||||
"""Message types from ElevenLabs Scribe Realtime STT API."""
|
||||
|
||||
SESSION_STARTED = "session_started"
|
||||
PARTIAL_TRANSCRIPT = "partial_transcript"
|
||||
COMMITTED_TRANSCRIPT = "committed_transcript"
|
||||
UTTERANCE_END = "utterance_end"
|
||||
SESSION_ENDED = "session_ended"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class ElevenLabsTTSMessageType(StrEnum):
|
||||
"""Message types from ElevenLabs stream-input TTS API."""
|
||||
|
||||
AUDIO = "audio"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
def _http_to_ws_url(http_url: str) -> str:
|
||||
"""Convert http(s) URL to ws(s) URL for WebSocket connections."""
|
||||
if http_url.startswith("https://"):
|
||||
return "wss://" + http_url[8:]
|
||||
elif http_url.startswith("http://"):
|
||||
return "ws://" + http_url[7:]
|
||||
return http_url
|
||||
|
||||
|
||||
# Common ElevenLabs voices
|
||||
ELEVENLABS_VOICES = [
|
||||
{"id": "21m00Tcm4TlvDq8ikWAM", "name": "Rachel"},
|
||||
{"id": "AZnzlk1XvdvUeBnXmlld", "name": "Domi"},
|
||||
{"id": "EXAVITQu4vr4xnSDxMaL", "name": "Bella"},
|
||||
{"id": "ErXwobaYiN019PkySvjV", "name": "Antoni"},
|
||||
{"id": "MF3mGyEYCl7XYWbV9V6O", "name": "Elli"},
|
||||
{"id": "TxGEqnHWrfWFTfGW9XjX", "name": "Josh"},
|
||||
{"id": "VR6AewLTigWG4xSOukaG", "name": "Arnold"},
|
||||
{"id": "pNInz6obpgDQGcFmaJgB", "name": "Adam"},
|
||||
{"id": "yoZ06aMxZJJ28mfd3POQ", "name": "Sam"},
|
||||
]
|
||||
|
||||
|
||||
class ElevenLabsStreamingTranscriber(StreamingTranscriberProtocol):
|
||||
"""Streaming transcription session using ElevenLabs Scribe Realtime API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
model: str = "scribe_v2_realtime",
|
||||
input_sample_rate: int = DEFAULT_INPUT_SAMPLE_RATE,
|
||||
target_sample_rate: int = DEFAULT_TARGET_SAMPLE_RATE,
|
||||
language_code: str = "en",
|
||||
api_base: str | None = None,
|
||||
):
|
||||
# Import logger first
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
self._logger = setup_logger()
|
||||
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: initializing with model {model}"
|
||||
)
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.input_sample_rate = input_sample_rate
|
||||
self.target_sample_rate = target_sample_rate
|
||||
self.language_code = language_code
|
||||
self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
|
||||
self._final_transcript = ""
|
||||
self._receive_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish WebSocket connection to ElevenLabs."""
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingTranscriber: connecting to ElevenLabs API"
|
||||
)
|
||||
self._session = aiohttp.ClientSession()
|
||||
|
||||
# VAD is configured via query parameters.
|
||||
# commit_strategy=vad enables automatic transcript commit on silence detection.
|
||||
# These params are part of the ElevenLabs Scribe Realtime API contract:
|
||||
# https://elevenlabs.io/docs/api-reference/speech-to-text/realtime
|
||||
ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
|
||||
url = (
|
||||
f"{ws_base}/v1/speech-to-text/realtime"
|
||||
f"?model_id={self.model}"
|
||||
f"&sample_rate={self.target_sample_rate}"
|
||||
f"&language_code={self.language_code}"
|
||||
f"&commit_strategy=vad"
|
||||
f"&vad_silence_threshold_secs={DEFAULT_VAD_SILENCE_THRESHOLD_SECS}"
|
||||
f"&vad_threshold={DEFAULT_VAD_THRESHOLD}"
|
||||
f"&min_speech_duration_ms={DEFAULT_MIN_SPEECH_DURATION_MS}"
|
||||
f"&min_silence_duration_ms={DEFAULT_MIN_SILENCE_DURATION_MS}"
|
||||
)
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: connecting to {url} "
|
||||
f"(input={self.input_sample_rate}Hz, target={self.target_sample_rate}Hz)"
|
||||
)
|
||||
|
||||
try:
|
||||
self._ws = await self._session.ws_connect(
|
||||
url,
|
||||
headers={"xi-api-key": self.api_key},
|
||||
)
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: connected successfully, "
|
||||
f"ws.closed={self._ws.closed}, close_code={self._ws.close_code}"
|
||||
)
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: failed to connect: {e}"
|
||||
)
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
raise
|
||||
|
||||
# Start receiving transcripts in background
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
"""Background task to receive transcripts from WebSocket."""
|
||||
self._logger.info("ElevenLabsStreamingTranscriber: receive loop started")
|
||||
if not self._ws:
|
||||
self._logger.warning(
|
||||
"ElevenLabsStreamingTranscriber: no WebSocket connection"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
self._logger.debug(
|
||||
f"ElevenLabsStreamingTranscriber: raw message type: {msg.type}"
|
||||
)
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
parsed_data: Any = None
|
||||
data: dict[str, Any]
|
||||
try:
|
||||
parsed_data = json.loads(msg.data)
|
||||
except json.JSONDecodeError:
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: failed to parse JSON: {msg.data[:200]}"
|
||||
)
|
||||
continue
|
||||
if not isinstance(parsed_data, dict):
|
||||
self._logger.error(
|
||||
"ElevenLabsStreamingTranscriber: expected object JSON payload"
|
||||
)
|
||||
continue
|
||||
data = parsed_data
|
||||
|
||||
# ElevenLabs uses message_type field - fail fast if missing
|
||||
if "message_type" not in data and "type" not in data:
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: malformed packet missing 'message_type' field: {data}"
|
||||
)
|
||||
continue
|
||||
msg_type = data.get("message_type", data.get("type", ""))
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: received message_type: '{msg_type}', data keys: {list(data.keys())}"
|
||||
)
|
||||
# Check for error in various formats
|
||||
if "error" in data or msg_type == ElevenLabsSTTMessageType.ERROR:
|
||||
error_msg = data.get("error", data.get("message", data))
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: API error: {error_msg}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Handle message types from ElevenLabs Scribe Realtime API.
|
||||
# See https://elevenlabs.io/docs/api-reference/speech-to-text/realtime
|
||||
if msg_type == ElevenLabsSTTMessageType.SESSION_STARTED:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: session started, "
|
||||
f"id={data.get('session_id')}, config={data.get('config')}"
|
||||
)
|
||||
elif msg_type == ElevenLabsSTTMessageType.PARTIAL_TRANSCRIPT:
|
||||
# Interim result — updated as more audio is processed
|
||||
text = data.get("text", "")
|
||||
if text:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: partial_transcript: {text[:50]}..."
|
||||
)
|
||||
self._final_transcript = text
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(text=text, is_vad_end=False)
|
||||
)
|
||||
elif msg_type == ElevenLabsSTTMessageType.COMMITTED_TRANSCRIPT:
|
||||
# Final transcript for the current utterance (VAD detected end)
|
||||
text = data.get("text", "")
|
||||
if text:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: committed_transcript: {text[:50]}..."
|
||||
)
|
||||
self._final_transcript = text
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(text=text, is_vad_end=True)
|
||||
)
|
||||
elif msg_type == ElevenLabsSTTMessageType.UTTERANCE_END:
|
||||
# VAD detected end of speech (may carry text or be empty)
|
||||
text = data.get("text", "") or self._final_transcript
|
||||
if text:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: utterance_end: {text[:50]}..."
|
||||
)
|
||||
self._final_transcript = text
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(text=text, is_vad_end=True)
|
||||
)
|
||||
elif msg_type == ElevenLabsSTTMessageType.SESSION_ENDED:
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingTranscriber: session ended"
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Log unhandled message types with full data for debugging
|
||||
self._logger.warning(
|
||||
f"ElevenLabsStreamingTranscriber: unhandled message_type: {msg_type}, full data: {data}"
|
||||
)
|
||||
elif msg.type == aiohttp.WSMsgType.BINARY:
|
||||
self._logger.debug(
|
||||
f"ElevenLabsStreamingTranscriber: received binary message: {len(msg.data)} bytes"
|
||||
)
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
close_code = self._ws.close_code if self._ws else "N/A"
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingTranscriber: WebSocket closed by "
|
||||
f"server, close_code={close_code}"
|
||||
)
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: WebSocket error: {self._ws.exception() if self._ws else 'N/A'}"
|
||||
)
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSE:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: WebSocket CLOSE frame received, data={msg.data}, extra={msg.extra}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingTranscriber: error in receive loop: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
close_code = self._ws.close_code if self._ws else "N/A"
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingTranscriber: receive loop ended, close_code={close_code}"
|
||||
)
|
||||
await self._transcript_queue.put(None) # Signal end
|
||||
|
||||
def _resample_pcm16(self, data: bytes) -> bytes:
|
||||
"""Resample PCM16 audio from input_sample_rate to target_sample_rate."""
|
||||
import struct
|
||||
|
||||
if self.input_sample_rate == self.target_sample_rate:
|
||||
return data
|
||||
|
||||
# Parse int16 samples
|
||||
num_samples = len(data) // 2
|
||||
samples = list(struct.unpack(f"<{num_samples}h", data))
|
||||
|
||||
# Calculate resampling ratio
|
||||
ratio = self.input_sample_rate / self.target_sample_rate
|
||||
new_length = int(num_samples / ratio)
|
||||
|
||||
# Linear interpolation resampling
|
||||
resampled = []
|
||||
for i in range(new_length):
|
||||
src_idx = i * ratio
|
||||
idx_floor = int(src_idx)
|
||||
idx_ceil = min(idx_floor + 1, num_samples - 1)
|
||||
frac = src_idx - idx_floor
|
||||
sample = int(samples[idx_floor] * (1 - frac) + samples[idx_ceil] * frac)
|
||||
# Clamp to int16 range
|
||||
sample = max(-32768, min(32767, sample))
|
||||
resampled.append(sample)
|
||||
|
||||
return struct.pack(f"<{len(resampled)}h", *resampled)
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send an audio chunk for transcription."""
|
||||
if not self._ws:
|
||||
self._logger.warning("send_audio: no WebSocket connection")
|
||||
return
|
||||
if self._closed:
|
||||
self._logger.warning("send_audio: transcriber is closed")
|
||||
return
|
||||
if self._ws.closed:
|
||||
self._logger.warning(
|
||||
f"send_audio: WebSocket is closed, close_code={self._ws.close_code}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Resample from input rate (24kHz) to target rate (16kHz)
|
||||
resampled = self._resample_pcm16(chunk)
|
||||
# ElevenLabs expects input_audio_chunk message format with audio_base_64
|
||||
audio_b64 = base64.b64encode(resampled).decode("utf-8")
|
||||
message = {
|
||||
"message_type": "input_audio_chunk",
|
||||
"audio_base_64": audio_b64,
|
||||
"sample_rate": self.target_sample_rate,
|
||||
}
|
||||
self._logger.info(
|
||||
f"send_audio: {len(chunk)} bytes -> {len(resampled)} bytes (resampled) -> {len(audio_b64)} chars base64"
|
||||
)
|
||||
await self._ws.send_str(json.dumps(message))
|
||||
self._logger.info("send_audio: message sent successfully")
|
||||
except Exception as e:
|
||||
self._logger.error(f"send_audio: failed to send: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def receive_transcript(self) -> TranscriptResult | None:
|
||||
"""Receive next transcript. Returns None when done."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return TranscriptResult(
|
||||
text="", is_vad_end=False
|
||||
) # No transcript yet, but not done
|
||||
|
||||
async def close(self) -> str:
|
||||
"""Close the session and return final transcript."""
|
||||
self._logger.info("ElevenLabsStreamingTranscriber: closing session")
|
||||
self._closed = True
|
||||
if self._ws and not self._ws.closed:
|
||||
try:
|
||||
# Just close the WebSocket - ElevenLabs Scribe doesn't need a special end message
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingTranscriber: closing WebSocket connection"
|
||||
)
|
||||
await self._ws.close()
|
||||
except Exception as e:
|
||||
self._logger.debug(f"Error closing WebSocket: {e}")
|
||||
if self._receive_task and not self._receive_task.done():
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
return self._final_transcript
|
||||
|
||||
def reset_transcript(self) -> None:
|
||||
"""Reset accumulated transcript. Call after auto-send to start fresh."""
|
||||
self._final_transcript = ""
|
||||
|
||||
|
||||
class ElevenLabsStreamingSynthesizer(StreamingSynthesizerProtocol):
|
||||
"""Real-time streaming TTS using ElevenLabs WebSocket API.
|
||||
|
||||
Uses ElevenLabs' stream-input WebSocket which processes text as one
|
||||
continuous stream and returns audio in order.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
model_id: str = "eleven_multilingual_v2",
|
||||
output_format: str = "mp3_44100_64",
|
||||
api_base: str | None = None,
|
||||
speed: float = 1.0,
|
||||
):
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
self._logger = setup_logger()
|
||||
self.api_key = api_key
|
||||
self.voice_id = voice_id
|
||||
self.model_id = model_id
|
||||
self.output_format = output_format
|
||||
self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
|
||||
self.speed = speed
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._receive_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish WebSocket connection to ElevenLabs TTS."""
|
||||
self._logger.info("ElevenLabsStreamingSynthesizer: connecting")
|
||||
self._session = aiohttp.ClientSession()
|
||||
|
||||
# WebSocket URL for streaming input TTS with output format for streaming compatibility
|
||||
# Using mp3_44100_64 for good quality with smaller chunks for real-time playback
|
||||
ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
|
||||
url = (
|
||||
f"{ws_base}/v1/text-to-speech/{self.voice_id}/stream-input"
|
||||
f"?model_id={self.model_id}&output_format={self.output_format}"
|
||||
)
|
||||
|
||||
self._ws = await self._session.ws_connect(
|
||||
url,
|
||||
headers={"xi-api-key": self.api_key},
|
||||
)
|
||||
|
||||
# Send initial configuration with generation settings optimized for streaming.
|
||||
# Note: API key is sent via header only (not in body to avoid log exposure).
|
||||
# See https://elevenlabs.io/docs/api-reference/text-to-speech/stream-input
|
||||
await self._ws.send_str(
|
||||
json.dumps(
|
||||
{
|
||||
"text": " ", # Initial space to start the stream
|
||||
"voice_settings": {
|
||||
"stability": DEFAULT_VOICE_STABILITY,
|
||||
"similarity_boost": DEFAULT_VOICE_SIMILARITY_BOOST,
|
||||
"speed": self.speed,
|
||||
},
|
||||
"generation_config": {
|
||||
"chunk_length_schedule": DEFAULT_CHUNK_LENGTH_SCHEDULE,
|
||||
},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Start receiving audio in background
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
self._logger.info("ElevenLabsStreamingSynthesizer: connected")
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
"""Background task to receive audio chunks from WebSocket.
|
||||
|
||||
Audio is returned in order as one continuous stream.
|
||||
"""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
if self._closed:
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingSynthesizer: closed flag set, stopping "
|
||||
"receive loop"
|
||||
)
|
||||
break
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
data = json.loads(msg.data)
|
||||
# Process audio if present
|
||||
if "audio" in data and data["audio"]:
|
||||
audio_bytes = base64.b64decode(data["audio"])
|
||||
chunk_count += 1
|
||||
total_bytes += len(audio_bytes)
|
||||
await self._audio_queue.put(audio_bytes)
|
||||
|
||||
# Check isFinal separately - a message can have both audio AND isFinal
|
||||
if "isFinal" in data:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingSynthesizer: received isFinal={data['isFinal']}, "
|
||||
f"chunks so far: {chunk_count}, bytes: {total_bytes}"
|
||||
)
|
||||
if data.get("isFinal"):
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingSynthesizer: isFinal=true, signaling end of audio"
|
||||
)
|
||||
await self._audio_queue.put(None)
|
||||
|
||||
# Check for errors
|
||||
if "error" in data or data.get("type") == "error":
|
||||
self._logger.error(
|
||||
f"ElevenLabsStreamingSynthesizer: received error: {data}"
|
||||
)
|
||||
elif msg.type == aiohttp.WSMsgType.BINARY:
|
||||
chunk_count += 1
|
||||
total_bytes += len(msg.data)
|
||||
await self._audio_queue.put(msg.data)
|
||||
elif msg.type in (
|
||||
aiohttp.WSMsgType.CLOSE,
|
||||
aiohttp.WSMsgType.ERROR,
|
||||
):
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingSynthesizer: WebSocket closed/error, type={msg.type}"
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.error(f"ElevenLabsStreamingSynthesizer receive error: {e}")
|
||||
finally:
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingSynthesizer: receive loop ended, {chunk_count} chunks, {total_bytes} bytes"
|
||||
)
|
||||
await self._audio_queue.put(None) # Signal end of stream
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Send text to be synthesized.
|
||||
|
||||
ElevenLabs processes text as a continuous stream and returns
|
||||
audio in order. We let ElevenLabs handle buffering via chunk_length_schedule
|
||||
and only force generation when flush() is called at the end.
|
||||
|
||||
Args:
|
||||
text: Text to synthesize
|
||||
"""
|
||||
if self._ws and not self._closed and text.strip():
|
||||
self._logger.info(
|
||||
f"ElevenLabsStreamingSynthesizer: sending text ({len(text)} chars): '{text}'"
|
||||
)
|
||||
# Let ElevenLabs buffer and auto-generate based on chunk_length_schedule
|
||||
# Don't trigger generation here - wait for flush() at the end
|
||||
await self._ws.send_str(
|
||||
json.dumps(
|
||||
{
|
||||
"text": text + " ", # Space for natural speech flow
|
||||
}
|
||||
)
|
||||
)
|
||||
self._logger.info("ElevenLabsStreamingSynthesizer: text sent successfully")
|
||||
else:
|
||||
self._logger.warning(
|
||||
f"ElevenLabsStreamingSynthesizer: skipping send_text - "
|
||||
f"ws={self._ws is not None}, closed={self._closed}, text='{text[:30] if text else ''}'"
|
||||
)
|
||||
|
||||
async def receive_audio(self) -> bytes | None:
|
||||
"""Receive next audio chunk."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return b"" # No audio yet, but not done
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Signal end of text input. ElevenLabs will generate remaining audio and close."""
|
||||
if self._ws and not self._closed:
|
||||
# Send empty string to signal end of input
|
||||
# ElevenLabs will generate any remaining buffered text,
|
||||
# send all audio chunks, send isFinal, then close the connection
|
||||
self._logger.info(
|
||||
"ElevenLabsStreamingSynthesizer: sending end-of-input (empty string)"
|
||||
)
|
||||
await self._ws.send_str(json.dumps({"text": ""}))
|
||||
self._logger.info("ElevenLabsStreamingSynthesizer: end-of-input sent")
|
||||
else:
|
||||
self._logger.warning(
|
||||
f"ElevenLabsStreamingSynthesizer: skipping flush - "
|
||||
f"ws={self._ws is not None}, closed={self._closed}"
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the session."""
|
||||
self._closed = True
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
|
||||
|
||||
# Valid ElevenLabs model IDs
|
||||
ELEVENLABS_STT_MODELS = {"scribe_v1", "scribe_v2_realtime"}
|
||||
ELEVENLABS_TTS_MODELS = {
|
||||
"eleven_multilingual_v2",
|
||||
"eleven_turbo_v2_5",
|
||||
"eleven_monolingual_v1",
|
||||
"eleven_flash_v2_5",
|
||||
"eleven_flash_v2",
|
||||
}
|
||||
|
||||
|
||||
class ElevenLabsVoiceProvider(VoiceProviderInterface):
|
||||
"""ElevenLabs voice provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None,
|
||||
api_base: str | None = None,
|
||||
stt_model: str | None = None,
|
||||
tts_model: str | None = None,
|
||||
default_voice: str | None = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base or DEFAULT_ELEVENLABS_API_BASE
|
||||
# Validate and default models - use valid ElevenLabs model IDs
|
||||
self.stt_model = (
|
||||
stt_model if stt_model in ELEVENLABS_STT_MODELS else "scribe_v1"
|
||||
)
|
||||
self.tts_model = (
|
||||
tts_model
|
||||
if tts_model in ELEVENLABS_TTS_MODELS
|
||||
else "eleven_multilingual_v2"
|
||||
)
|
||||
self.default_voice = default_voice
|
||||
|
||||
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
|
||||
"""
|
||||
Transcribe audio using ElevenLabs Speech-to-Text API.
|
||||
|
||||
Args:
|
||||
audio_data: Raw audio bytes
|
||||
audio_format: Format of the audio (e.g., 'webm', 'mp3', 'wav')
|
||||
|
||||
Returns:
|
||||
Transcribed text
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise ValueError("ElevenLabs API key required for transcription")
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
url = f"{self.api_base}/v1/speech-to-text"
|
||||
|
||||
# Map common formats to MIME types
|
||||
mime_types = {
|
||||
"webm": "audio/webm",
|
||||
"mp3": "audio/mpeg",
|
||||
"wav": "audio/wav",
|
||||
"ogg": "audio/ogg",
|
||||
"flac": "audio/flac",
|
||||
"m4a": "audio/mp4",
|
||||
}
|
||||
mime_type = mime_types.get(audio_format.lower(), f"audio/{audio_format}")
|
||||
|
||||
headers = {
|
||||
"xi-api-key": self.api_key,
|
||||
}
|
||||
|
||||
# ElevenLabs expects multipart form data
|
||||
form_data = aiohttp.FormData()
|
||||
form_data.add_field(
|
||||
"audio",
|
||||
audio_data,
|
||||
filename=f"audio.{audio_format}",
|
||||
content_type=mime_type,
|
||||
)
|
||||
# For batch STT, use scribe_v1 (not the realtime model)
|
||||
batch_model = (
|
||||
self.stt_model if self.stt_model in ("scribe_v1",) else "scribe_v1"
|
||||
)
|
||||
form_data.add_field("model_id", batch_model)
|
||||
|
||||
logger.info(
|
||||
f"ElevenLabs transcribe: sending {len(audio_data)} bytes, format={audio_format}"
|
||||
)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, data=form_data) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"ElevenLabs transcribe failed: {error_text}")
|
||||
raise RuntimeError(f"ElevenLabs transcription failed: {error_text}")
|
||||
|
||||
result = await response.json()
|
||||
text = result.get("text", "")
|
||||
logger.info(f"ElevenLabs transcribe: got result: {text[:50]}...")
|
||||
return text
|
||||
|
||||
async def synthesize_stream(
|
||||
self, text: str, voice: str | None = None, speed: float = 1.0
|
||||
) -> AsyncIterator[bytes]:
|
||||
"""
|
||||
Convert text to audio using ElevenLabs TTS with streaming.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech
|
||||
voice: Voice ID (defaults to provider's default voice or Rachel)
|
||||
speed: Playback speed multiplier
|
||||
|
||||
Yields:
|
||||
Audio data chunks (mp3 format)
|
||||
"""
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("ElevenLabs API key required for TTS")
|
||||
|
||||
voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM" # Rachel
|
||||
|
||||
url = f"{self.api_base}/v1/text-to-speech/{voice_id}/stream"
|
||||
|
||||
logger.info(
|
||||
f"ElevenLabs TTS: starting synthesis, text='{text[:50]}...', "
|
||||
f"voice={voice_id}, model={self.tts_model}, speed={speed}"
|
||||
)
|
||||
|
||||
headers = {
|
||||
"xi-api-key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "audio/mpeg",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"text": text,
|
||||
"model_id": self.tts_model,
|
||||
"voice_settings": {
|
||||
"stability": DEFAULT_VOICE_STABILITY,
|
||||
"similarity_boost": DEFAULT_VOICE_SIMILARITY_BOOST,
|
||||
"speed": speed,
|
||||
},
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, json=payload) as response:
|
||||
logger.info(
|
||||
f"ElevenLabs TTS: got response status={response.status}, "
|
||||
f"content-type={response.headers.get('content-type')}"
|
||||
)
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"ElevenLabs TTS failed: {error_text}")
|
||||
raise RuntimeError(f"ElevenLabs TTS failed: {error_text}")
|
||||
|
||||
# Use 8192 byte chunks for smoother streaming
|
||||
chunk_count = 0
|
||||
total_bytes = 0
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
if chunk:
|
||||
chunk_count += 1
|
||||
total_bytes += len(chunk)
|
||||
yield chunk
|
||||
logger.info(
|
||||
f"ElevenLabs TTS: streaming complete, {chunk_count} chunks, "
|
||||
f"{total_bytes} total bytes"
|
||||
)
|
||||
|
||||
async def validate_credentials(self) -> None:
|
||||
"""Validate ElevenLabs API key.
|
||||
|
||||
Calls /v1/models as a lightweight check. ElevenLabs returns 401 for
|
||||
both truly invalid keys and valid keys with restricted scopes, so we
|
||||
inspect the response body: a "missing_permissions" status means the
|
||||
key authenticated successfully but lacks a specific scope.
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise ValueError("ElevenLabs API key required")
|
||||
|
||||
headers = {"xi-api-key": self.api_key}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"{self.api_base}/v1/models", headers=headers
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return
|
||||
if response.status in (401, 403):
|
||||
try:
|
||||
body = await response.json()
|
||||
detail = body.get("detail", {})
|
||||
status = (
|
||||
detail.get("status", "") if isinstance(detail, dict) else ""
|
||||
)
|
||||
except Exception:
|
||||
status = ""
|
||||
# "missing_permissions" means the key is valid but
|
||||
# lacks this specific scope — that's fine.
|
||||
if status == "missing_permissions":
|
||||
return
|
||||
raise RuntimeError("Invalid ElevenLabs API key.")
|
||||
raise RuntimeError("ElevenLabs credential validation failed.")
|
||||
|
||||
def get_available_voices(self) -> list[dict[str, str]]:
|
||||
"""Return common ElevenLabs voices."""
|
||||
return ELEVENLABS_VOICES.copy()
|
||||
|
||||
def get_available_stt_models(self) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"id": "scribe_v2_realtime", "name": "Scribe v2 Realtime (Streaming)"},
|
||||
{"id": "scribe_v1", "name": "Scribe v1 (Batch)"},
|
||||
]
|
||||
|
||||
def get_available_tts_models(self) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"id": "eleven_multilingual_v2", "name": "Multilingual v2"},
|
||||
{"id": "eleven_turbo_v2_5", "name": "Turbo v2.5"},
|
||||
{"id": "eleven_monolingual_v1", "name": "Monolingual v1"},
|
||||
]
|
||||
|
||||
def supports_streaming_stt(self) -> bool:
|
||||
"""ElevenLabs supports streaming via Scribe Realtime API."""
|
||||
return True
|
||||
|
||||
def supports_streaming_tts(self) -> bool:
|
||||
"""ElevenLabs supports real-time streaming TTS via WebSocket."""
|
||||
return True
|
||||
|
||||
async def create_streaming_transcriber(
|
||||
self, _audio_format: str = "webm"
|
||||
) -> ElevenLabsStreamingTranscriber:
|
||||
"""Create a streaming transcription session."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming transcription")
|
||||
# ElevenLabs realtime STT requires scribe_v2_realtime model.
|
||||
# Frontend sends PCM16 at DEFAULT_INPUT_SAMPLE_RATE (24kHz),
|
||||
# but ElevenLabs expects DEFAULT_TARGET_SAMPLE_RATE (16kHz).
|
||||
# The transcriber resamples automatically.
|
||||
transcriber = ElevenLabsStreamingTranscriber(
|
||||
api_key=self.api_key,
|
||||
model="scribe_v2_realtime",
|
||||
input_sample_rate=DEFAULT_INPUT_SAMPLE_RATE,
|
||||
target_sample_rate=DEFAULT_TARGET_SAMPLE_RATE,
|
||||
language_code="en",
|
||||
api_base=self.api_base,
|
||||
)
|
||||
await transcriber.connect()
|
||||
return transcriber
|
||||
|
||||
async def create_streaming_synthesizer(
|
||||
self, voice: str | None = None, speed: float = 1.0
|
||||
) -> ElevenLabsStreamingSynthesizer:
|
||||
"""Create a streaming TTS session."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming TTS")
|
||||
voice_id = voice or self.default_voice or "21m00Tcm4TlvDq8ikWAM"
|
||||
synthesizer = ElevenLabsStreamingSynthesizer(
|
||||
api_key=self.api_key,
|
||||
voice_id=voice_id,
|
||||
model_id=self.tts_model,
|
||||
output_format=DEFAULT_TTS_OUTPUT_FORMAT,
|
||||
api_base=self.api_base,
|
||||
speed=speed,
|
||||
)
|
||||
await synthesizer.connect()
|
||||
return synthesizer
|
||||
@@ -1,633 +0,0 @@
|
||||
"""OpenAI voice provider for STT and TTS.
|
||||
|
||||
OpenAI supports:
|
||||
- **STT**: Whisper (batch transcription via REST) and Realtime API (streaming
|
||||
transcription via WebSocket with server-side VAD). Audio is sent as base64-encoded
|
||||
PCM16 at 24kHz mono. The Realtime API returns transcript deltas and completed
|
||||
transcription events per VAD-detected utterance.
|
||||
- **TTS**: HTTP streaming endpoint that returns audio chunks progressively.
|
||||
Supported models: tts-1 (standard) and tts-1-hd (high quality).
|
||||
|
||||
See https://platform.openai.com/docs for API reference.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.voice.interface import StreamingSynthesizerProtocol
|
||||
from onyx.voice.interface import StreamingTranscriberProtocol
|
||||
from onyx.voice.interface import TranscriptResult
|
||||
from onyx.voice.interface import VoiceProviderInterface
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Default OpenAI API base URL
|
||||
DEFAULT_OPENAI_API_BASE = "https://api.openai.com"
|
||||
|
||||
|
||||
class OpenAIRealtimeMessageType(StrEnum):
|
||||
"""Message types from OpenAI Realtime transcription API."""
|
||||
|
||||
ERROR = "error"
|
||||
SPEECH_STARTED = "input_audio_buffer.speech_started"
|
||||
SPEECH_STOPPED = "input_audio_buffer.speech_stopped"
|
||||
BUFFER_COMMITTED = "input_audio_buffer.committed"
|
||||
TRANSCRIPTION_DELTA = "conversation.item.input_audio_transcription.delta"
|
||||
TRANSCRIPTION_COMPLETED = "conversation.item.input_audio_transcription.completed"
|
||||
SESSION_CREATED = "transcription_session.created"
|
||||
SESSION_UPDATED = "transcription_session.updated"
|
||||
ITEM_CREATED = "conversation.item.created"
|
||||
|
||||
|
||||
def _http_to_ws_url(http_url: str) -> str:
|
||||
"""Convert http(s) URL to ws(s) URL for WebSocket connections."""
|
||||
if http_url.startswith("https://"):
|
||||
return "wss://" + http_url[8:]
|
||||
elif http_url.startswith("http://"):
|
||||
return "ws://" + http_url[7:]
|
||||
return http_url
|
||||
|
||||
|
||||
class OpenAIStreamingTranscriber(StreamingTranscriberProtocol):
|
||||
"""Streaming transcription using OpenAI Realtime API."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
model: str = "whisper-1",
|
||||
api_base: str | None = None,
|
||||
):
|
||||
# Import logger first
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
self._logger = setup_logger()
|
||||
|
||||
self._logger.info(
|
||||
f"OpenAIStreamingTranscriber: initializing with model {model}"
|
||||
)
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.api_base = api_base or DEFAULT_OPENAI_API_BASE
|
||||
self._ws: aiohttp.ClientWebSocketResponse | None = None
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._transcript_queue: asyncio.Queue[TranscriptResult | None] = asyncio.Queue()
|
||||
self._current_turn_transcript = "" # Transcript for current VAD turn
|
||||
self._accumulated_transcript = "" # Accumulated across all turns
|
||||
self._receive_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Establish WebSocket connection to OpenAI Realtime API."""
|
||||
self._session = aiohttp.ClientSession()
|
||||
|
||||
# OpenAI Realtime transcription endpoint
|
||||
ws_base = _http_to_ws_url(self.api_base.rstrip("/"))
|
||||
url = f"{ws_base}/v1/realtime?intent=transcription"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"OpenAI-Beta": "realtime=v1",
|
||||
}
|
||||
|
||||
try:
|
||||
self._ws = await self._session.ws_connect(url, headers=headers)
|
||||
self._logger.info("Connected to OpenAI Realtime API")
|
||||
except Exception as e:
|
||||
self._logger.error(f"Failed to connect to OpenAI Realtime API: {e}")
|
||||
raise
|
||||
|
||||
# Configure the session for transcription
|
||||
# Enable server-side VAD (Voice Activity Detection) for automatic speech detection
|
||||
config_message = {
|
||||
"type": "transcription_session.update",
|
||||
"session": {
|
||||
"input_audio_format": "pcm16", # 16-bit PCM at 24kHz mono
|
||||
"input_audio_transcription": {
|
||||
"model": self.model,
|
||||
},
|
||||
"turn_detection": {
|
||||
"type": "server_vad",
|
||||
"threshold": 0.5,
|
||||
"prefix_padding_ms": 300,
|
||||
"silence_duration_ms": 500,
|
||||
},
|
||||
},
|
||||
}
|
||||
await self._ws.send_str(json.dumps(config_message))
|
||||
self._logger.info(f"Sent config for model: {self.model} with server VAD")
|
||||
|
||||
# Start receiving transcripts
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
"""Background task to receive transcripts."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
try:
|
||||
async for msg in self._ws:
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
data = json.loads(msg.data)
|
||||
msg_type = data.get("type", "")
|
||||
self._logger.debug(f"Received message type: {msg_type}")
|
||||
|
||||
# Handle errors
|
||||
if msg_type == OpenAIRealtimeMessageType.ERROR:
|
||||
error = data.get("error", {})
|
||||
self._logger.error(f"OpenAI error: {error}")
|
||||
continue
|
||||
|
||||
# Handle VAD events
|
||||
if msg_type == OpenAIRealtimeMessageType.SPEECH_STARTED:
|
||||
self._logger.info("OpenAI: Speech started")
|
||||
# Reset current turn transcript for new speech
|
||||
self._current_turn_transcript = ""
|
||||
continue
|
||||
elif msg_type == OpenAIRealtimeMessageType.SPEECH_STOPPED:
|
||||
self._logger.info(
|
||||
"OpenAI: Speech stopped (VAD detected silence)"
|
||||
)
|
||||
continue
|
||||
elif msg_type == OpenAIRealtimeMessageType.BUFFER_COMMITTED:
|
||||
self._logger.info("OpenAI: Audio buffer committed")
|
||||
continue
|
||||
|
||||
# Handle transcription events
|
||||
if msg_type == OpenAIRealtimeMessageType.TRANSCRIPTION_DELTA:
|
||||
delta = data.get("delta", "")
|
||||
if delta:
|
||||
self._logger.info(f"OpenAI: Transcription delta: {delta}")
|
||||
self._current_turn_transcript += delta
|
||||
# Show accumulated + current turn transcript
|
||||
full_transcript = self._accumulated_transcript
|
||||
if full_transcript and self._current_turn_transcript:
|
||||
full_transcript += " "
|
||||
full_transcript += self._current_turn_transcript
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(text=full_transcript, is_vad_end=False)
|
||||
)
|
||||
elif msg_type == OpenAIRealtimeMessageType.TRANSCRIPTION_COMPLETED:
|
||||
transcript = data.get("transcript", "")
|
||||
if transcript:
|
||||
self._logger.info(
|
||||
f"OpenAI: Transcription completed (VAD turn end): {transcript[:50]}..."
|
||||
)
|
||||
# This is the final transcript for this VAD turn
|
||||
self._current_turn_transcript = transcript
|
||||
# Accumulate this turn's transcript
|
||||
if self._accumulated_transcript:
|
||||
self._accumulated_transcript += " " + transcript
|
||||
else:
|
||||
self._accumulated_transcript = transcript
|
||||
# Send with is_vad_end=True to trigger auto-send
|
||||
await self._transcript_queue.put(
|
||||
TranscriptResult(
|
||||
text=self._accumulated_transcript,
|
||||
is_vad_end=True,
|
||||
)
|
||||
)
|
||||
elif msg_type not in (
|
||||
OpenAIRealtimeMessageType.SESSION_CREATED,
|
||||
OpenAIRealtimeMessageType.SESSION_UPDATED,
|
||||
OpenAIRealtimeMessageType.ITEM_CREATED,
|
||||
):
|
||||
# Log any other message types we might be missing
|
||||
self._logger.info(
|
||||
f"OpenAI: Unhandled message type '{msg_type}': {data}"
|
||||
)
|
||||
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
self._logger.error(f"WebSocket error: {self._ws.exception()}")
|
||||
break
|
||||
elif msg.type == aiohttp.WSMsgType.CLOSED:
|
||||
self._logger.info("WebSocket closed by server")
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error in receive loop: {e}")
|
||||
finally:
|
||||
await self._transcript_queue.put(None)
|
||||
|
||||
async def send_audio(self, chunk: bytes) -> None:
|
||||
"""Send audio chunk to OpenAI."""
|
||||
if self._ws and not self._closed:
|
||||
# OpenAI expects base64-encoded PCM16 audio at 24kHz mono
|
||||
# PCM16 at 24kHz: 24000 samples/sec * 2 bytes/sample = 48000 bytes/sec
|
||||
# So chunk_bytes / 48000 = duration in seconds
|
||||
duration_ms = (len(chunk) / 48000) * 1000
|
||||
self._logger.debug(
|
||||
f"Sending {len(chunk)} bytes ({duration_ms:.1f}ms) of audio to OpenAI. "
|
||||
f"First 10 bytes: {chunk[:10].hex() if len(chunk) >= 10 else chunk.hex()}"
|
||||
)
|
||||
message = {
|
||||
"type": "input_audio_buffer.append",
|
||||
"audio": base64.b64encode(chunk).decode("utf-8"),
|
||||
}
|
||||
await self._ws.send_str(json.dumps(message))
|
||||
|
||||
def reset_transcript(self) -> None:
|
||||
"""Reset accumulated transcript. Call after auto-send to start fresh."""
|
||||
self._logger.info("OpenAI: Resetting accumulated transcript")
|
||||
self._accumulated_transcript = ""
|
||||
self._current_turn_transcript = ""
|
||||
|
||||
async def receive_transcript(self) -> TranscriptResult | None:
|
||||
"""Receive next transcript."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._transcript_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return TranscriptResult(text="", is_vad_end=False)
|
||||
|
||||
async def close(self) -> str:
|
||||
"""Close session and return final transcript."""
|
||||
self._closed = True
|
||||
if self._ws:
|
||||
# With server VAD, the buffer is auto-committed when speech stops.
|
||||
# But we should still commit any remaining audio and wait for transcription.
|
||||
try:
|
||||
await self._ws.send_str(
|
||||
json.dumps({"type": "input_audio_buffer.commit"})
|
||||
)
|
||||
except Exception as e:
|
||||
self._logger.debug(f"Error sending commit (may be expected): {e}")
|
||||
|
||||
# Wait for *new* transcription to arrive (up to 5 seconds)
|
||||
self._logger.info("Waiting for transcription to complete...")
|
||||
transcript_before_commit = self._accumulated_transcript
|
||||
for _ in range(50): # 50 * 100ms = 5 seconds max
|
||||
await asyncio.sleep(0.1)
|
||||
if self._accumulated_transcript != transcript_before_commit:
|
||||
self._logger.info(
|
||||
f"Got final transcript: {self._accumulated_transcript[:50]}..."
|
||||
)
|
||||
break
|
||||
else:
|
||||
self._logger.warning("Timed out waiting for transcription")
|
||||
|
||||
await self._ws.close()
|
||||
if self._receive_task:
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
return self._accumulated_transcript
|
||||
|
||||
|
||||
# OpenAI available voices for TTS
|
||||
OPENAI_VOICES = [
|
||||
{"id": "alloy", "name": "Alloy"},
|
||||
{"id": "echo", "name": "Echo"},
|
||||
{"id": "fable", "name": "Fable"},
|
||||
{"id": "onyx", "name": "Onyx"},
|
||||
{"id": "nova", "name": "Nova"},
|
||||
{"id": "shimmer", "name": "Shimmer"},
|
||||
]
|
||||
|
||||
# OpenAI available STT models (all support streaming via Realtime API)
|
||||
OPENAI_STT_MODELS = [
|
||||
{"id": "whisper-1", "name": "Whisper v1"},
|
||||
{"id": "gpt-4o-transcribe", "name": "GPT-4o Transcribe"},
|
||||
{"id": "gpt-4o-mini-transcribe", "name": "GPT-4o Mini Transcribe"},
|
||||
]
|
||||
|
||||
# OpenAI available TTS models
|
||||
OPENAI_TTS_MODELS = [
|
||||
{"id": "tts-1", "name": "TTS-1 (Standard)"},
|
||||
{"id": "tts-1-hd", "name": "TTS-1 HD (High Quality)"},
|
||||
]
|
||||
|
||||
|
||||
def _create_wav_header(
|
||||
data_length: int,
|
||||
sample_rate: int = 24000,
|
||||
channels: int = 1,
|
||||
bits_per_sample: int = 16,
|
||||
) -> bytes:
|
||||
"""Create a WAV file header for PCM audio data."""
|
||||
import struct
|
||||
|
||||
byte_rate = sample_rate * channels * bits_per_sample // 8
|
||||
block_align = channels * bits_per_sample // 8
|
||||
|
||||
# WAV header is 44 bytes
|
||||
header = struct.pack(
|
||||
"<4sI4s4sIHHIIHH4sI",
|
||||
b"RIFF", # ChunkID
|
||||
36 + data_length, # ChunkSize
|
||||
b"WAVE", # Format
|
||||
b"fmt ", # Subchunk1ID
|
||||
16, # Subchunk1Size (PCM)
|
||||
1, # AudioFormat (1 = PCM)
|
||||
channels, # NumChannels
|
||||
sample_rate, # SampleRate
|
||||
byte_rate, # ByteRate
|
||||
block_align, # BlockAlign
|
||||
bits_per_sample, # BitsPerSample
|
||||
b"data", # Subchunk2ID
|
||||
data_length, # Subchunk2Size
|
||||
)
|
||||
return header
|
||||
|
||||
|
||||
class OpenAIStreamingSynthesizer(StreamingSynthesizerProtocol):
|
||||
"""Streaming TTS using OpenAI HTTP TTS API with streaming responses."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
voice: str = "alloy",
|
||||
model: str = "tts-1",
|
||||
speed: float = 1.0,
|
||||
api_base: str | None = None,
|
||||
):
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
self._logger = setup_logger()
|
||||
self.api_key = api_key
|
||||
self.voice = voice
|
||||
self.model = model
|
||||
self.speed = max(0.25, min(4.0, speed))
|
||||
self.api_base = api_base or DEFAULT_OPENAI_API_BASE
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._text_queue: asyncio.Queue[str | None] = asyncio.Queue()
|
||||
self._synthesis_task: asyncio.Task | None = None
|
||||
self._closed = False
|
||||
self._flushed = False
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Initialize HTTP session for TTS requests."""
|
||||
self._logger.info("OpenAIStreamingSynthesizer: connecting")
|
||||
self._session = aiohttp.ClientSession()
|
||||
# Start background task to process text queue
|
||||
self._synthesis_task = asyncio.create_task(self._process_text_queue())
|
||||
self._logger.info("OpenAIStreamingSynthesizer: connected")
|
||||
|
||||
async def _process_text_queue(self) -> None:
|
||||
"""Background task to process queued text for synthesis."""
|
||||
while not self._closed:
|
||||
try:
|
||||
text = await asyncio.wait_for(self._text_queue.get(), timeout=0.1)
|
||||
if text is None:
|
||||
break
|
||||
await self._synthesize_text(text)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.error(f"Error processing text queue: {e}")
|
||||
|
||||
async def _synthesize_text(self, text: str) -> None:
|
||||
"""Make HTTP TTS request and stream audio to queue."""
|
||||
if not self._session or self._closed:
|
||||
return
|
||||
|
||||
url = f"{self.api_base.rstrip('/')}/v1/audio/speech"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"voice": self.voice,
|
||||
"input": text,
|
||||
"speed": self.speed,
|
||||
"response_format": "mp3",
|
||||
}
|
||||
|
||||
try:
|
||||
async with self._session.post(
|
||||
url, headers=headers, json=payload
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
self._logger.error(f"OpenAI TTS error: {error_text}")
|
||||
return
|
||||
|
||||
# Use 8192 byte chunks for smoother streaming
|
||||
# (larger chunks = more complete MP3 frames, better playback)
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
if self._closed:
|
||||
break
|
||||
if chunk:
|
||||
await self._audio_queue.put(chunk)
|
||||
except Exception as e:
|
||||
self._logger.error(f"OpenAIStreamingSynthesizer synthesis error: {e}")
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Queue text to be synthesized via HTTP streaming."""
|
||||
if not text.strip() or self._closed:
|
||||
return
|
||||
await self._text_queue.put(text)
|
||||
|
||||
async def receive_audio(self) -> bytes | None:
|
||||
"""Receive next audio chunk (MP3 format)."""
|
||||
try:
|
||||
return await asyncio.wait_for(self._audio_queue.get(), timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
return b"" # No audio yet, but not done
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Signal end of text input - wait for synthesis to complete."""
|
||||
if self._flushed:
|
||||
return
|
||||
self._flushed = True
|
||||
|
||||
# Signal end of text input
|
||||
await self._text_queue.put(None)
|
||||
|
||||
# Wait for synthesis task to complete processing all text
|
||||
if self._synthesis_task and not self._synthesis_task.done():
|
||||
try:
|
||||
await asyncio.wait_for(self._synthesis_task, timeout=60.0)
|
||||
except asyncio.TimeoutError:
|
||||
self._logger.warning("OpenAIStreamingSynthesizer: flush timeout")
|
||||
self._synthesis_task.cancel()
|
||||
try:
|
||||
await self._synthesis_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Signal end of audio stream
|
||||
await self._audio_queue.put(None)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the session."""
|
||||
if self._closed:
|
||||
return
|
||||
self._closed = True
|
||||
|
||||
# Signal end of queues only if flush wasn't already called
|
||||
if not self._flushed:
|
||||
await self._text_queue.put(None)
|
||||
await self._audio_queue.put(None)
|
||||
|
||||
if self._synthesis_task and not self._synthesis_task.done():
|
||||
self._synthesis_task.cancel()
|
||||
try:
|
||||
await self._synthesis_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
|
||||
|
||||
class OpenAIVoiceProvider(VoiceProviderInterface):
|
||||
"""OpenAI voice provider using Whisper for STT and TTS API for speech synthesis."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None,
|
||||
api_base: str | None = None,
|
||||
stt_model: str | None = None,
|
||||
tts_model: str | None = None,
|
||||
default_voice: str | None = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.stt_model = stt_model or "whisper-1"
|
||||
self.tts_model = tts_model or "tts-1"
|
||||
self.default_voice = default_voice or "alloy"
|
||||
|
||||
self._client: "AsyncOpenAI | None" = None
|
||||
|
||||
def _get_client(self) -> "AsyncOpenAI":
|
||||
if self._client is None:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.api_base,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def transcribe(self, audio_data: bytes, audio_format: str) -> str:
|
||||
"""
|
||||
Transcribe audio using OpenAI Whisper.
|
||||
|
||||
Args:
|
||||
audio_data: Raw audio bytes
|
||||
audio_format: Audio format (e.g., "webm", "wav", "mp3")
|
||||
|
||||
Returns:
|
||||
Transcribed text
|
||||
"""
|
||||
client = self._get_client()
|
||||
|
||||
# Create a file-like object from the audio bytes
|
||||
audio_file = io.BytesIO(audio_data)
|
||||
audio_file.name = f"audio.{audio_format}"
|
||||
|
||||
response = await client.audio.transcriptions.create(
|
||||
model=self.stt_model,
|
||||
file=audio_file,
|
||||
)
|
||||
|
||||
return response.text
|
||||
|
||||
async def synthesize_stream(
|
||||
self, text: str, voice: str | None = None, speed: float = 1.0
|
||||
) -> AsyncIterator[bytes]:
|
||||
"""
|
||||
Convert text to audio using OpenAI TTS with streaming.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech
|
||||
voice: Voice identifier (defaults to provider's default voice)
|
||||
speed: Playback speed multiplier (0.25 to 4.0)
|
||||
|
||||
Yields:
|
||||
Audio data chunks (mp3 format)
|
||||
"""
|
||||
client = self._get_client()
|
||||
|
||||
# Clamp speed to valid range
|
||||
speed = max(0.25, min(4.0, speed))
|
||||
|
||||
# Use with_streaming_response for proper async streaming
|
||||
# Using 8192 byte chunks for better streaming performance
|
||||
# (larger chunks = fewer round-trips, more complete MP3 frames)
|
||||
async with client.audio.speech.with_streaming_response.create(
|
||||
model=self.tts_model,
|
||||
voice=voice or self.default_voice,
|
||||
input=text,
|
||||
speed=speed,
|
||||
response_format="mp3",
|
||||
) as response:
|
||||
async for chunk in response.iter_bytes(chunk_size=8192):
|
||||
yield chunk
|
||||
|
||||
async def validate_credentials(self) -> None:
|
||||
"""Validate OpenAI API key by listing models."""
|
||||
from openai import AuthenticationError, PermissionDeniedError
|
||||
|
||||
client = self._get_client()
|
||||
try:
|
||||
await client.models.list()
|
||||
except AuthenticationError:
|
||||
raise RuntimeError("Invalid OpenAI API key.")
|
||||
except PermissionDeniedError:
|
||||
raise RuntimeError("OpenAI API key does not have sufficient permissions.")
|
||||
|
||||
def get_available_voices(self) -> list[dict[str, str]]:
|
||||
"""Get available OpenAI TTS voices."""
|
||||
return OPENAI_VOICES.copy()
|
||||
|
||||
def get_available_stt_models(self) -> list[dict[str, str]]:
|
||||
"""Get available OpenAI STT models."""
|
||||
return OPENAI_STT_MODELS.copy()
|
||||
|
||||
def get_available_tts_models(self) -> list[dict[str, str]]:
|
||||
"""Get available OpenAI TTS models."""
|
||||
return OPENAI_TTS_MODELS.copy()
|
||||
|
||||
def supports_streaming_stt(self) -> bool:
|
||||
"""OpenAI supports streaming via Realtime API for all STT models."""
|
||||
return True
|
||||
|
||||
def supports_streaming_tts(self) -> bool:
|
||||
"""OpenAI supports real-time streaming TTS via Realtime API."""
|
||||
return True
|
||||
|
||||
async def create_streaming_transcriber(
|
||||
self, _audio_format: str = "webm"
|
||||
) -> OpenAIStreamingTranscriber:
|
||||
"""Create a streaming transcription session using Realtime API."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming transcription")
|
||||
transcriber = OpenAIStreamingTranscriber(
|
||||
api_key=self.api_key,
|
||||
model=self.stt_model,
|
||||
api_base=self.api_base,
|
||||
)
|
||||
await transcriber.connect()
|
||||
return transcriber
|
||||
|
||||
async def create_streaming_synthesizer(
|
||||
self, voice: str | None = None, speed: float = 1.0
|
||||
) -> OpenAIStreamingSynthesizer:
|
||||
"""Create a streaming TTS session using HTTP streaming API."""
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for streaming TTS")
|
||||
synthesizer = OpenAIStreamingSynthesizer(
|
||||
api_key=self.api_key,
|
||||
voice=voice or self.default_voice or "alloy",
|
||||
model=self.tts_model or "tts-1",
|
||||
speed=speed,
|
||||
api_base=self.api_base,
|
||||
)
|
||||
await synthesizer.connect()
|
||||
return synthesizer
|
||||
@@ -67,8 +67,6 @@ attrs==25.4.0
|
||||
# zeep
|
||||
authlib==1.6.7
|
||||
# via fastmcp
|
||||
azure-cognitiveservices-speech==1.38.0
|
||||
# via onyx
|
||||
babel==2.17.0
|
||||
# via courlan
|
||||
backoff==2.2.1
|
||||
@@ -616,7 +614,7 @@ opentelemetry-sdk==1.39.1
|
||||
# opentelemetry-exporter-otlp-proto-http
|
||||
opentelemetry-semantic-conventions==0.60b1
|
||||
# via opentelemetry-sdk
|
||||
orjson==3.11.6 ; platform_python_implementation != 'PyPy'
|
||||
orjson==3.11.4 ; platform_python_implementation != 'PyPy'
|
||||
# via langsmith
|
||||
packaging==24.2
|
||||
# via
|
||||
@@ -752,7 +750,7 @@ pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.8.0
|
||||
pypdf==6.7.5
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
@@ -1022,7 +1020,7 @@ toolz==1.1.0
|
||||
# dask
|
||||
# distributed
|
||||
# partd
|
||||
tornado==6.5.5
|
||||
tornado==6.5.2
|
||||
# via distributed
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
|
||||
@@ -263,7 +263,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.7.0
|
||||
onyx-devtools==0.6.3
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
@@ -406,7 +406,7 @@ referencing==0.36.2
|
||||
# jsonschema-specifications
|
||||
regex==2025.11.3
|
||||
# via tiktoken
|
||||
release-tag==0.5.2
|
||||
release-tag==0.4.3
|
||||
# via onyx
|
||||
reorder-python-imports-black==3.14.0
|
||||
# via onyx
|
||||
@@ -466,7 +466,7 @@ tokenizers==0.21.4
|
||||
# via
|
||||
# cohere
|
||||
# litellm
|
||||
tornado==6.5.5
|
||||
tornado==6.5.2
|
||||
# via
|
||||
# ipykernel
|
||||
# jupyter-client
|
||||
|
||||
@@ -19,7 +19,7 @@ from fastapi.testclient import TestClient
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.main import get_application
|
||||
from onyx.main import fetch_versioned_implementation
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -51,8 +51,11 @@ def client() -> Generator[TestClient, None, None]:
|
||||
# Patch out prometheus metrics setup to avoid "Duplicated timeseries in
|
||||
# CollectorRegistry" errors when multiple tests each create a new app
|
||||
# (prometheus registers metrics globally and rejects duplicate names).
|
||||
get_app = fetch_versioned_implementation(
|
||||
module="onyx.main", attribute="get_application"
|
||||
)
|
||||
with patch("onyx.main.setup_prometheus_metrics"):
|
||||
app: FastAPI = get_application(lifespan_override=test_lifespan)
|
||||
app: FastAPI = get_app(lifespan_override=test_lifespan)
|
||||
|
||||
# Override the database session dependency with a mock
|
||||
# (these tests don't actually need DB access)
|
||||
|
||||
@@ -1,398 +0,0 @@
|
||||
"""External dependency tests for the old DocumentIndex interface.
|
||||
|
||||
These tests assume Vespa and OpenSearch are running.
|
||||
|
||||
TODO(ENG-3764)(andrei): Consolidate some of these test fixtures.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def opensearch_available() -> Generator[None, None, None]:
|
||||
"""Verifies OpenSearch is running, fails the test if not."""
|
||||
if not wait_for_opensearch_with_timeout():
|
||||
pytest.fail("OpenSearch is not available.")
|
||||
yield # Test runs here.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def test_index_name() -> Generator[str, None, None]:
|
||||
yield f"test_index_{uuid.uuid4().hex[:8]}" # Test runs here.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tenant_context() -> Generator[None, None, None]:
|
||||
"""Sets up tenant context for testing."""
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
try:
|
||||
yield # Test runs here.
|
||||
finally:
|
||||
# Reset the tenant context after the test
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def httpx_client() -> Generator[httpx.Client, None, None]:
|
||||
client = get_vespa_http_client()
|
||||
try:
|
||||
yield client
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vespa_document_index(
|
||||
httpx_client: httpx.Client,
|
||||
tenant_context: None, # noqa: ARG001
|
||||
test_index_name: str,
|
||||
) -> Generator[VespaIndex, None, None]:
|
||||
vespa_index = VespaIndex(
|
||||
index_name=test_index_name,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
backend_dir = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "..")
|
||||
)
|
||||
with patch("os.getcwd", return_value=backend_dir):
|
||||
vespa_index.ensure_indices_exist(
|
||||
primary_embedding_dim=128,
|
||||
primary_embedding_precision=EmbeddingPrecision.FLOAT,
|
||||
secondary_index_embedding_dim=None,
|
||||
secondary_index_embedding_precision=None,
|
||||
)
|
||||
# Verify Vespa is running, fails the test if not. Try 90 seconds for testing
|
||||
# in CI. We have to do this here because this endpoint only becomes live
|
||||
# once we create an index.
|
||||
if not wait_for_vespa_with_timeout(wait_limit=90):
|
||||
pytest.fail("Vespa is not available.")
|
||||
|
||||
# Wait until the schema is actually ready for writes on content nodes. We
|
||||
# probe by attempting a PUT; 200 means the schema is live, 400 means not
|
||||
# yet. This is so scuffed but running the test is really flakey otherwise;
|
||||
# this is only temporary until we entirely move off of Vespa.
|
||||
probe_doc = {
|
||||
"fields": {
|
||||
"document_id": "__probe__",
|
||||
"chunk_id": 0,
|
||||
"blurb": "",
|
||||
"title": "",
|
||||
"skip_title": True,
|
||||
"content": "",
|
||||
"content_summary": "",
|
||||
"source_type": "file",
|
||||
"source_links": "null",
|
||||
"semantic_identifier": "",
|
||||
"section_continuation": False,
|
||||
"large_chunk_reference_ids": [],
|
||||
"metadata": "{}",
|
||||
"metadata_list": [],
|
||||
"metadata_suffix": "",
|
||||
"chunk_context": "",
|
||||
"doc_summary": "",
|
||||
"embeddings": {"full_chunk": [1.0] + [0.0] * 127},
|
||||
"access_control_list": {},
|
||||
"document_sets": {},
|
||||
"image_file_name": None,
|
||||
"user_project": [],
|
||||
"personas": [],
|
||||
"boost": 0.0,
|
||||
"aggregated_chunk_boost_factor": 0.0,
|
||||
"primary_owners": [],
|
||||
"secondary_owners": [],
|
||||
}
|
||||
}
|
||||
schema_ready = False
|
||||
probe_url = (
|
||||
f"http://localhost:8081/document/v1/default/{test_index_name}/docid/__probe__"
|
||||
)
|
||||
for _ in range(60):
|
||||
resp = httpx_client.post(probe_url, json=probe_doc)
|
||||
if resp.status_code == 200:
|
||||
schema_ready = True
|
||||
# Clean up the probe document.
|
||||
httpx_client.delete(probe_url)
|
||||
break
|
||||
time.sleep(1)
|
||||
if not schema_ready:
|
||||
pytest.fail(f"Vespa schema '{test_index_name}' did not become ready in time.")
|
||||
|
||||
yield vespa_index # Test runs here.
|
||||
|
||||
# TODO(ENG-3765)(andrei): Explicitly cleanup index. Not immediately
|
||||
# pressing; in CI we should be using fresh instances of dependencies each
|
||||
# time anyway.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def opensearch_document_index(
|
||||
opensearch_available: None, # noqa: ARG001
|
||||
tenant_context: None, # noqa: ARG001
|
||||
test_index_name: str,
|
||||
) -> Generator[OpenSearchOldDocumentIndex, None, None]:
|
||||
opensearch_index = OpenSearchOldDocumentIndex(
|
||||
index_name=test_index_name,
|
||||
embedding_dim=128,
|
||||
embedding_precision=EmbeddingPrecision.FLOAT,
|
||||
secondary_index_name=None,
|
||||
secondary_embedding_dim=None,
|
||||
secondary_embedding_precision=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
multitenant=MULTI_TENANT,
|
||||
)
|
||||
opensearch_index.ensure_indices_exist(
|
||||
primary_embedding_dim=128,
|
||||
primary_embedding_precision=EmbeddingPrecision.FLOAT,
|
||||
secondary_index_embedding_dim=None,
|
||||
secondary_index_embedding_precision=None,
|
||||
)
|
||||
|
||||
yield opensearch_index # Test runs here.
|
||||
|
||||
# TODO(ENG-3765)(andrei): Explicitly cleanup index. Not immediately
|
||||
# pressing; in CI we should be using fresh instances of dependencies each
|
||||
# time anyway.
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def document_indices(
|
||||
vespa_document_index: VespaIndex,
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex,
|
||||
) -> Generator[list[DocumentIndex], None, None]:
|
||||
# Ideally these are parametrized; doing so with pytest fixtures is tricky.
|
||||
yield [opensearch_document_index, vespa_document_index] # Test runs here.
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def chunks(
|
||||
tenant_context: None, # noqa: ARG001
|
||||
) -> Generator[list[DocMetadataAwareIndexChunk], None, None]:
|
||||
result = []
|
||||
chunk_count = 5
|
||||
doc_id = "test_doc"
|
||||
tenant_id = get_current_tenant_id()
|
||||
access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
)
|
||||
document_sets: set[str] = set()
|
||||
user_project: list[int] = list()
|
||||
personas: list[int] = list()
|
||||
boost = 0
|
||||
blurb = "blurb"
|
||||
content = "content"
|
||||
title_prefix = ""
|
||||
doc_summary = ""
|
||||
chunk_context = ""
|
||||
title_embedding = [1.0] + [0] * 127
|
||||
# Full 0 vectors are not supported for cos similarity.
|
||||
embeddings = ChunkEmbedding(
|
||||
full_embedding=[1.0] + [0] * 127, mini_chunk_embeddings=[]
|
||||
)
|
||||
source_document = Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="semantic identifier",
|
||||
source=DocumentSource.FILE,
|
||||
sections=[],
|
||||
metadata={},
|
||||
title="title",
|
||||
)
|
||||
metadata_suffix_keyword = ""
|
||||
image_file_id = None
|
||||
source_links: dict[int, str] = {0: ""}
|
||||
ancestor_hierarchy_node_ids: list[int] = []
|
||||
for i in range(chunk_count):
|
||||
result.append(
|
||||
DocMetadataAwareIndexChunk(
|
||||
tenant_id=tenant_id,
|
||||
access=access,
|
||||
document_sets=document_sets,
|
||||
user_project=user_project,
|
||||
personas=personas,
|
||||
boost=boost,
|
||||
aggregated_chunk_boost_factor=0,
|
||||
ancestor_hierarchy_node_ids=ancestor_hierarchy_node_ids,
|
||||
embeddings=embeddings,
|
||||
title_embedding=title_embedding,
|
||||
source_document=source_document,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
metadata_suffix_semantic="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary=doc_summary,
|
||||
chunk_context=chunk_context,
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
chunk_id=i,
|
||||
blurb=blurb,
|
||||
content=content,
|
||||
source_links=source_links,
|
||||
image_file_id=image_file_id,
|
||||
section_continuation=False,
|
||||
)
|
||||
)
|
||||
yield result # Test runs here.
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def index_batch_params(
|
||||
tenant_context: None, # noqa: ARG001
|
||||
) -> Generator[IndexBatchParams, None, None]:
|
||||
# WARNING: doc_id_to_previous_chunk_cnt={"test_doc": 0} is hardcoded to 0,
|
||||
# which is only correct on the very first index call. The document_indices
|
||||
# fixture is scope="module", meaning the same OpenSearch and Vespa backends
|
||||
# persist across all test functions in this module. When a second test
|
||||
# function uses this fixture and calls document_index.index(...), the
|
||||
# backend already has 5 chunks for "test_doc" from the previous test run,
|
||||
# but the batch params still claim 0 prior chunks exist. This can lead to
|
||||
# orphaned/duplicate chunks that make subsequent assertions incorrect.
|
||||
# TODO: Whenever adding a second test, either change this or cleanup the
|
||||
# index between test cases.
|
||||
yield IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt={"test_doc": 0},
|
||||
doc_id_to_new_chunk_cnt={"test_doc": 5},
|
||||
tenant_id=get_current_tenant_id(),
|
||||
large_chunks_enabled=False,
|
||||
)
|
||||
|
||||
|
||||
class TestDocumentIndexOld:
|
||||
"""Tests the old DocumentIndex interface."""
|
||||
|
||||
def test_update_single_can_clear_user_projects_and_personas(
|
||||
self,
|
||||
document_indices: list[DocumentIndex],
|
||||
# This test case assumes all these chunks correspond to one document.
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> None:
|
||||
"""
|
||||
Tests that update_single can clear user_projects and personas.
|
||||
"""
|
||||
for document_index in document_indices:
|
||||
# Precondition.
|
||||
# Ensure there is some non-empty value for user project and
|
||||
# personas.
|
||||
for chunk in chunks:
|
||||
chunk.user_project = [1]
|
||||
chunk.personas = [2]
|
||||
document_index.index(chunks, index_batch_params)
|
||||
|
||||
# Ensure that we can get chunks as expected with filters.
|
||||
doc_id = chunks[0].source_document.id
|
||||
chunk_count = len(chunks)
|
||||
tenant_id = get_current_tenant_id()
|
||||
# We need to specify the chunk index range and specify
|
||||
# batch_retrieval=True below to trigger the codepath for Vespa's
|
||||
# search API, which uses the expected additive filtering for
|
||||
# project_id and persona_id. Otherwise we would use the codepath for
|
||||
# the visit API, which does not have this kind of filtering
|
||||
# implemented.
|
||||
chunk_request = VespaChunkRequest(
|
||||
document_id=doc_id, min_chunk_ind=0, max_chunk_ind=chunk_count - 1
|
||||
)
|
||||
project_persona_filters = IndexFilters(
|
||||
access_control_list=None,
|
||||
tenant_id=tenant_id,
|
||||
project_id=1,
|
||||
persona_id=2,
|
||||
# We need this even though none of the chunks belong to a
|
||||
# document set because project_id and persona_id are only
|
||||
# additive filters in the event the agent has knowledge scope;
|
||||
# if the agent does not, it is implied that it can see
|
||||
# everything it is allowed to.
|
||||
document_set=["1"],
|
||||
)
|
||||
# Not best practice here but the API for refreshing the index to
|
||||
# ensure that the latest data is present is not exposed in this
|
||||
# class and is not the same for Vespa and OpenSearch, so we just
|
||||
# tolerate a sleep for now. As a consequence the number of tests in
|
||||
# this suite should be small. We only need to tolerate this for as
|
||||
# long as we continue to use Vespa, we can consider exposing
|
||||
# something for OpenSearch later.
|
||||
time.sleep(1)
|
||||
inference_chunks = document_index.id_based_retrieval(
|
||||
chunk_requests=[chunk_request],
|
||||
filters=project_persona_filters,
|
||||
batch_retrieval=True,
|
||||
)
|
||||
assert len(inference_chunks) == chunk_count
|
||||
# Sort by chunk id to easily test if we have all chunks.
|
||||
for i, inference_chunk in enumerate(
|
||||
sorted(inference_chunks, key=lambda x: x.chunk_id)
|
||||
):
|
||||
assert inference_chunk.chunk_id == i
|
||||
assert inference_chunk.document_id == doc_id
|
||||
|
||||
# Under test.
|
||||
# Explicitly set empty fields here.
|
||||
user_fields = VespaDocumentUserFields(user_projects=[], personas=[])
|
||||
document_index.update_single(
|
||||
doc_id=doc_id,
|
||||
chunk_count=chunk_count,
|
||||
tenant_id=tenant_id,
|
||||
fields=None,
|
||||
user_fields=user_fields,
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
filters = IndexFilters(access_control_list=None, tenant_id=tenant_id)
|
||||
# We should expect to get back all expected chunks with no filters.
|
||||
# Again, not best practice here.
|
||||
time.sleep(1)
|
||||
inference_chunks = document_index.id_based_retrieval(
|
||||
chunk_requests=[chunk_request], filters=filters, batch_retrieval=True
|
||||
)
|
||||
assert len(inference_chunks) == chunk_count
|
||||
# Sort by chunk id to easily test if we have all chunks.
|
||||
for i, inference_chunk in enumerate(
|
||||
sorted(inference_chunks, key=lambda x: x.chunk_id)
|
||||
):
|
||||
assert inference_chunk.chunk_id == i
|
||||
assert inference_chunk.document_id == doc_id
|
||||
# Now, we should expect to not get any chunks if we specify the user
|
||||
# project and personas filters.
|
||||
inference_chunks = document_index.id_based_retrieval(
|
||||
chunk_requests=[chunk_request],
|
||||
filters=project_persona_filters,
|
||||
batch_retrieval=True,
|
||||
)
|
||||
assert len(inference_chunks) == 0
|
||||
@@ -17,9 +17,6 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
GET_VESPA_CHUNKS_SLICE_COUNT,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.tasks import (
|
||||
is_continuation_token_done_for_all_slices,
|
||||
)
|
||||
@@ -239,8 +236,6 @@ def full_deployment_setup() -> Generator[None, None, None]:
|
||||
NOTE: We deliberately duplicate this logic from
|
||||
backend/tests/external_dependency_unit/conftest.py because we need to set
|
||||
opensearch_available just for this module, not the entire test session.
|
||||
|
||||
TODO(ENG-3764)(andrei): Consolidate some of these test fixtures.
|
||||
"""
|
||||
# Patch ENABLE_OPENSEARCH_INDEXING_FOR_ONYX just for this test because we
|
||||
# don't yet want that enabled for all tests.
|
||||
@@ -325,15 +320,9 @@ def test_embedding_dimension(db_session: Session) -> Generator[int, None, None]:
|
||||
@pytest.fixture(scope="function")
|
||||
def patch_get_vespa_chunks_page_size() -> Generator[int, None, None]:
|
||||
test_page_size = 5
|
||||
with (
|
||||
patch(
|
||||
"onyx.background.celery.tasks.opensearch_migration.tasks.GET_VESPA_CHUNKS_PAGE_SIZE",
|
||||
test_page_size,
|
||||
),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.opensearch_migration.constants.GET_VESPA_CHUNKS_PAGE_SIZE",
|
||||
test_page_size,
|
||||
),
|
||||
with patch(
|
||||
"onyx.background.celery.tasks.opensearch_migration.tasks.GET_VESPA_CHUNKS_PAGE_SIZE",
|
||||
test_page_size,
|
||||
):
|
||||
yield test_page_size # Test runs here.
|
||||
|
||||
@@ -593,175 +582,6 @@ class TestMigrateChunksFromVespaToOpenSearchTask:
|
||||
document_chunks[document.id][opensearch_chunk.chunk_index],
|
||||
)
|
||||
|
||||
def test_chunk_migration_visits_all_chunks_even_when_batch_size_varies(
|
||||
self,
|
||||
db_session: Session,
|
||||
test_documents: list[Document],
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
opensearch_client: OpenSearchIndexClient,
|
||||
test_embedding_dimension: int,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""
|
||||
Tests that chunk migration works correctly even when the batch size
|
||||
changes halfway through a migration.
|
||||
|
||||
Simulates task time running out my mocking the locking behavior.
|
||||
"""
|
||||
# Precondition.
|
||||
# Index chunks into Vespa.
|
||||
document_chunks: dict[str, list[dict[str, Any]]] = {
|
||||
document.id: [
|
||||
_create_raw_document_chunk(
|
||||
document_id=document.id,
|
||||
chunk_index=i,
|
||||
content=f"Test content {i} for {document.id}",
|
||||
embedding=_generate_test_vector(test_embedding_dimension),
|
||||
now=datetime.now(),
|
||||
title=f"Test title {document.id}",
|
||||
title_embedding=_generate_test_vector(test_embedding_dimension),
|
||||
)
|
||||
for i in range(CHUNK_COUNT)
|
||||
]
|
||||
for document in test_documents
|
||||
}
|
||||
all_chunks: list[dict[str, Any]] = []
|
||||
for chunks in document_chunks.values():
|
||||
all_chunks.extend(chunks)
|
||||
vespa_document_index.index_raw_chunks(all_chunks)
|
||||
|
||||
# Run the initial batch. To simulate partial progress we will mock the
|
||||
# redis lock to return True for the first invocation of .owned() and
|
||||
# False subsequently.
|
||||
# NOTE: The batch size is currently set to 5 in
|
||||
# patch_get_vespa_chunks_page_size.
|
||||
mock_redis_client = Mock()
|
||||
mock_lock = Mock()
|
||||
mock_lock.owned.side_effect = [True, False, False]
|
||||
mock_lock.acquire.return_value = True
|
||||
mock_redis_client.lock.return_value = mock_lock
|
||||
with patch(
|
||||
"onyx.background.celery.tasks.opensearch_migration.tasks.get_redis_client",
|
||||
return_value=mock_redis_client,
|
||||
):
|
||||
result_1 = migrate_chunks_from_vespa_to_opensearch_task(
|
||||
tenant_id=get_current_tenant_id()
|
||||
)
|
||||
|
||||
assert result_1 is True
|
||||
# Expire the session cache to see the committed changes from the task.
|
||||
db_session.expire_all()
|
||||
|
||||
# Verify partial progress was saved.
|
||||
tenant_record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
assert tenant_record is not None
|
||||
partial_chunks_migrated = tenant_record.total_chunks_migrated
|
||||
assert partial_chunks_migrated > 0
|
||||
# page_size applies per slice, so one iteration can fetch up to
|
||||
# page_size * GET_VESPA_CHUNKS_SLICE_COUNT chunks total.
|
||||
assert partial_chunks_migrated <= 5 * GET_VESPA_CHUNKS_SLICE_COUNT
|
||||
assert tenant_record.vespa_visit_continuation_token is not None
|
||||
# Slices are not necessarily evenly distributed across all document
|
||||
# chunks so we can't test that every token is non-None, but certainly at
|
||||
# least one must be.
|
||||
assert any(json.loads(tenant_record.vespa_visit_continuation_token).values())
|
||||
assert tenant_record.migration_completed_at is None
|
||||
assert tenant_record.approx_chunk_count_in_vespa is not None
|
||||
|
||||
# Under test.
|
||||
# Now patch the batch size to be some other number, like 2.
|
||||
mock_redis_client = Mock()
|
||||
mock_lock = Mock()
|
||||
mock_lock.owned.side_effect = [True, False, False]
|
||||
mock_lock.acquire.return_value = True
|
||||
mock_redis_client.lock.return_value = mock_lock
|
||||
with (
|
||||
patch(
|
||||
"onyx.background.celery.tasks.opensearch_migration.tasks.GET_VESPA_CHUNKS_PAGE_SIZE",
|
||||
2,
|
||||
),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.opensearch_migration.constants.GET_VESPA_CHUNKS_PAGE_SIZE",
|
||||
2,
|
||||
),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.opensearch_migration.tasks.get_redis_client",
|
||||
return_value=mock_redis_client,
|
||||
),
|
||||
):
|
||||
result_2 = migrate_chunks_from_vespa_to_opensearch_task(
|
||||
tenant_id=get_current_tenant_id()
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
assert result_2 is True
|
||||
# Expire the session cache to see the committed changes from the task.
|
||||
db_session.expire_all()
|
||||
|
||||
# Verify next partial progress was saved.
|
||||
tenant_record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
assert tenant_record is not None
|
||||
new_partial_chunks_migrated = tenant_record.total_chunks_migrated
|
||||
assert new_partial_chunks_migrated > partial_chunks_migrated
|
||||
# page_size applies per slice, so one iteration can fetch up to
|
||||
# page_size * GET_VESPA_CHUNKS_SLICE_COUNT chunks total.
|
||||
assert new_partial_chunks_migrated <= (5 + 2) * GET_VESPA_CHUNKS_SLICE_COUNT
|
||||
assert tenant_record.vespa_visit_continuation_token is not None
|
||||
# Slices are not necessarily evenly distributed across all document
|
||||
# chunks so we can't test that every token is non-None, but certainly at
|
||||
# least one must be.
|
||||
assert any(json.loads(tenant_record.vespa_visit_continuation_token).values())
|
||||
assert tenant_record.migration_completed_at is None
|
||||
assert tenant_record.approx_chunk_count_in_vespa is not None
|
||||
|
||||
# Under test.
|
||||
# Run the remainder of the migration.
|
||||
with (
|
||||
patch(
|
||||
"onyx.background.celery.tasks.opensearch_migration.tasks.GET_VESPA_CHUNKS_PAGE_SIZE",
|
||||
2,
|
||||
),
|
||||
patch(
|
||||
"onyx.background.celery.tasks.opensearch_migration.constants.GET_VESPA_CHUNKS_PAGE_SIZE",
|
||||
2,
|
||||
),
|
||||
):
|
||||
result_3 = migrate_chunks_from_vespa_to_opensearch_task(
|
||||
tenant_id=get_current_tenant_id()
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
assert result_3 is True
|
||||
# Expire the session cache to see the committed changes from the task.
|
||||
db_session.expire_all()
|
||||
|
||||
# Verify completion.
|
||||
tenant_record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
assert tenant_record is not None
|
||||
assert tenant_record.total_chunks_migrated > new_partial_chunks_migrated
|
||||
assert tenant_record.total_chunks_migrated == len(all_chunks)
|
||||
# Visit is complete so continuation token should be None.
|
||||
assert tenant_record.vespa_visit_continuation_token is not None
|
||||
assert is_continuation_token_done_for_all_slices(
|
||||
json.loads(tenant_record.vespa_visit_continuation_token)
|
||||
)
|
||||
assert tenant_record.migration_completed_at is not None
|
||||
assert tenant_record.approx_chunk_count_in_vespa == len(all_chunks)
|
||||
|
||||
# Verify chunks were indexed in OpenSearch.
|
||||
for document in test_documents:
|
||||
opensearch_chunks = _get_document_chunks_from_opensearch(
|
||||
opensearch_client, document.id, get_current_tenant_id()
|
||||
)
|
||||
assert len(opensearch_chunks) == CHUNK_COUNT
|
||||
opensearch_chunks.sort(key=lambda x: x.chunk_index)
|
||||
for opensearch_chunk in opensearch_chunks:
|
||||
_assert_chunk_matches_vespa_chunk(
|
||||
opensearch_chunk,
|
||||
document_chunks[document.id][opensearch_chunk.chunk_index],
|
||||
)
|
||||
|
||||
def test_chunk_migration_empty_vespa(
|
||||
self,
|
||||
db_session: Session,
|
||||
|
||||
@@ -1,256 +0,0 @@
|
||||
"""Unit tests for webapp proxy path rewriting/injection."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.server.features.build.api import api
|
||||
from onyx.server.features.build.api.api import _inject_hmr_fixer
|
||||
from onyx.server.features.build.api.api import _rewrite_asset_paths
|
||||
from onyx.server.features.build.api.api import _rewrite_proxy_response_headers
|
||||
|
||||
SESSION_ID = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
BASE = f"/api/build/sessions/{SESSION_ID}/webapp"
|
||||
|
||||
|
||||
def rewrite(html: str) -> str:
|
||||
return _rewrite_asset_paths(html.encode(), SESSION_ID).decode()
|
||||
|
||||
|
||||
def inject(html: str) -> str:
|
||||
return _inject_hmr_fixer(html.encode(), SESSION_ID).decode()
|
||||
|
||||
|
||||
class TestNextjsPathRewriting:
|
||||
def test_rewrites_bare_next_script_src(self) -> None:
|
||||
html = '<script src="/_next/static/chunks/main.js">'
|
||||
result = rewrite(html)
|
||||
assert f'src="{BASE}/_next/static/chunks/main.js"' in result
|
||||
assert '"/_next/' not in result
|
||||
|
||||
def test_rewrites_bare_next_in_single_quotes(self) -> None:
|
||||
html = "<link href='/_next/static/css/app.css'>"
|
||||
result = rewrite(html)
|
||||
assert f"'{BASE}/_next/static/css/app.css'" in result
|
||||
|
||||
def test_rewrites_bare_next_in_url_parens(self) -> None:
|
||||
html = "background: url(/_next/static/media/font.woff2)"
|
||||
result = rewrite(html)
|
||||
assert f"url({BASE}/_next/static/media/font.woff2)" in result
|
||||
|
||||
def test_no_double_prefix_when_already_proxied(self) -> None:
|
||||
"""assetPrefix makes Next.js emit already-prefixed URLs — must not double-rewrite."""
|
||||
already_prefixed = f'<script src="{BASE}/_next/static/chunks/main.js">'
|
||||
result = rewrite(already_prefixed)
|
||||
# Should be unchanged
|
||||
assert result == already_prefixed
|
||||
# Specifically, no double path
|
||||
assert f"{BASE}/{BASE}" not in result
|
||||
|
||||
def test_rewrites_favicon(self) -> None:
|
||||
html = '<link rel="icon" href="/favicon.ico">'
|
||||
result = rewrite(html)
|
||||
assert f'"{BASE}/favicon.ico"' in result
|
||||
|
||||
def test_rewrites_json_data_path_double_quoted(self) -> None:
|
||||
html = 'fetch("/data/tickets.json")'
|
||||
result = rewrite(html)
|
||||
assert f'"{BASE}/data/tickets.json"' in result
|
||||
|
||||
def test_rewrites_json_data_path_single_quoted(self) -> None:
|
||||
html = "fetch('/data/items.json')"
|
||||
result = rewrite(html)
|
||||
assert f"'{BASE}/data/items.json'" in result
|
||||
|
||||
def test_rewrites_escaped_next_font_path_in_json_script(self) -> None:
|
||||
"""Next dev can embed font asset paths in JSON-escaped script payloads."""
|
||||
html = r'{"src":"\/_next\/static\/media\/font.woff2"}'
|
||||
result = rewrite(html)
|
||||
assert (
|
||||
r'{"src":"\/api\/build\/sessions\/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee\/webapp\/_next\/static\/media\/font.woff2"}'
|
||||
in result
|
||||
)
|
||||
|
||||
def test_rewrites_escaped_next_font_path_in_style_payload(self) -> None:
|
||||
"""Keep dynamically generated next/font URLs inside the session proxy."""
|
||||
html = r'{"css":"@font-face{src:url(\"\/_next\/static\/media\/font.woff2\")"}'
|
||||
result = rewrite(html)
|
||||
assert (
|
||||
r"\/api\/build\/sessions\/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee\/webapp\/_next\/static\/media\/font.woff2"
|
||||
in result
|
||||
)
|
||||
|
||||
def test_rewrites_absolute_next_font_url(self) -> None:
|
||||
html = (
|
||||
'<link rel="preload" as="font" '
|
||||
'href="https://craft-dev.onyx.app/_next/static/media/font.woff2">'
|
||||
)
|
||||
result = rewrite(html)
|
||||
assert f'"{BASE}/_next/static/media/font.woff2"' in result
|
||||
|
||||
def test_rewrites_root_hmr_path(self) -> None:
|
||||
html = 'new WebSocket("wss://craft-dev.onyx.app/_next/webpack-hmr?id=abc")'
|
||||
result = rewrite(html)
|
||||
assert '"wss://craft-dev.onyx.app/_next/webpack-hmr?id=abc"' not in result
|
||||
assert '"/_next/webpack-hmr?id=abc"' in result
|
||||
|
||||
def test_rewrites_escaped_absolute_next_font_url(self) -> None:
|
||||
html = (
|
||||
r'{"href":"https:\/\/craft-dev.onyx.app\/_next\/static\/media\/font.woff2"}'
|
||||
)
|
||||
result = rewrite(html)
|
||||
assert (
|
||||
r'{"href":"\/api\/build\/sessions\/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee\/webapp\/_next\/static\/media\/font.woff2"}'
|
||||
in result
|
||||
)
|
||||
|
||||
|
||||
class TestRuntimeFixerInjection:
|
||||
def test_injects_websocket_rewrite_shim(self) -> None:
|
||||
html = "<html><head></head><body></body></html>"
|
||||
result = inject(html)
|
||||
assert "window.WebSocket = function (url, protocols)" in result
|
||||
assert f'var WEBAPP_BASE = "{BASE}"' in result
|
||||
|
||||
def test_injects_hmr_websocket_stub(self) -> None:
|
||||
html = "<html><head></head><body></body></html>"
|
||||
result = inject(html)
|
||||
assert "function MockHmrWebSocket(url)" in result
|
||||
assert "return new MockHmrWebSocket(rewriteNextAssetUrl(url));" in result
|
||||
|
||||
def test_injects_before_head_contents(self) -> None:
|
||||
html = "<html><head><title>x</title></head><body></body></html>"
|
||||
result = inject(html)
|
||||
assert result.index(
|
||||
"window.WebSocket = function (url, protocols)"
|
||||
) < result.index("<title>x</title>")
|
||||
|
||||
def test_rewritten_hmr_url_still_matches_shim_intercept_logic(self) -> None:
|
||||
html = (
|
||||
"<html><head></head><body>"
|
||||
'new WebSocket("wss://craft-dev.onyx.app/_next/webpack-hmr?id=abc")'
|
||||
"</body></html>"
|
||||
)
|
||||
|
||||
rewritten = rewrite(html)
|
||||
assert '"wss://craft-dev.onyx.app/_next/webpack-hmr?id=abc"' not in rewritten
|
||||
assert 'new WebSocket("/_next/webpack-hmr?id=abc")' in rewritten
|
||||
|
||||
injected = inject(rewritten)
|
||||
|
||||
assert 'new WebSocket("/_next/webpack-hmr?id=abc")' in injected
|
||||
assert 'parsedUrl.pathname.indexOf("/_next/webpack-hmr") === 0' in injected
|
||||
|
||||
|
||||
class TestProxyHeaderRewriting:
|
||||
def test_rewrites_link_header_font_preload_paths(self) -> None:
|
||||
headers = {
|
||||
"link": (
|
||||
'</_next/static/media/font.woff2>; rel=preload; as="font"; crossorigin, '
|
||||
'</_next/static/media/font2.woff2>; rel=preload; as="font"; crossorigin'
|
||||
)
|
||||
}
|
||||
|
||||
result = _rewrite_proxy_response_headers(headers, SESSION_ID)
|
||||
|
||||
assert f"<{BASE}/_next/static/media/font.woff2>" in result["link"]
|
||||
|
||||
|
||||
class TestProxyRequestWiring:
|
||||
def test_proxy_request_rewrites_link_header_on_html_response(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
html = b"<html><head></head><body>ok</body></html>"
|
||||
upstream = httpx.Response(
|
||||
200,
|
||||
headers={
|
||||
"content-type": "text/html; charset=utf-8",
|
||||
"link": '</_next/static/media/font.woff2>; rel=preload; as="font"',
|
||||
},
|
||||
content=html,
|
||||
)
|
||||
|
||||
monkeypatch.setattr(api, "_get_sandbox_url", lambda *_args: "http://sandbox")
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, *_args: object, **_kwargs: object) -> None:
|
||||
pass
|
||||
|
||||
def __enter__(self) -> "FakeClient":
|
||||
return self
|
||||
|
||||
def __exit__(self, *_args: object) -> Literal[False]:
|
||||
return False
|
||||
|
||||
def get(self, _url: str, headers: dict[str, str]) -> httpx.Response:
|
||||
assert "host" not in {key.lower() for key in headers}
|
||||
return upstream
|
||||
|
||||
monkeypatch.setattr(api.httpx, "Client", FakeClient)
|
||||
|
||||
request = cast(Request, SimpleNamespace(headers={}, query_params=""))
|
||||
|
||||
response = api._proxy_request(
|
||||
"", request, UUID(SESSION_ID), cast(Session, SimpleNamespace())
|
||||
)
|
||||
|
||||
assert response.headers["link"] == (
|
||||
f'<{BASE}/_next/static/media/font.woff2>; rel=preload; as="font"'
|
||||
)
|
||||
|
||||
def test_proxy_request_injects_hmr_fixer_for_html_response(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
upstream = httpx.Response(
|
||||
200,
|
||||
headers={"content-type": "text/html; charset=utf-8"},
|
||||
content=b"<html><head><title>x</title></head><body></body></html>",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(api, "_get_sandbox_url", lambda *_args: "http://sandbox")
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, *_args: object, **_kwargs: object) -> None:
|
||||
pass
|
||||
|
||||
def __enter__(self) -> "FakeClient":
|
||||
return self
|
||||
|
||||
def __exit__(self, *_args: object) -> Literal[False]:
|
||||
return False
|
||||
|
||||
def get(self, _url: str, headers: dict[str, str]) -> httpx.Response:
|
||||
assert "host" not in {key.lower() for key in headers}
|
||||
return upstream
|
||||
|
||||
monkeypatch.setattr(api.httpx, "Client", FakeClient)
|
||||
|
||||
request = cast(Request, SimpleNamespace(headers={}, query_params=""))
|
||||
|
||||
response = api._proxy_request(
|
||||
"", request, UUID(SESSION_ID), cast(Session, SimpleNamespace())
|
||||
)
|
||||
body = cast(bytes, response.body).decode("utf-8")
|
||||
|
||||
assert "window.WebSocket = function (url, protocols)" in body
|
||||
assert body.index("window.WebSocket = function (url, protocols)") < body.index(
|
||||
"<title>x</title>"
|
||||
)
|
||||
|
||||
def test_rewrites_absolute_link_header_font_preload_paths(self) -> None:
|
||||
headers = {
|
||||
"link": (
|
||||
"<https://craft-dev.onyx.app/_next/static/media/font.woff2>; "
|
||||
'rel=preload; as="font"; crossorigin'
|
||||
)
|
||||
}
|
||||
|
||||
result = _rewrite_proxy_response_headers(headers, SESSION_ID)
|
||||
|
||||
assert f"<{BASE}/_next/static/media/font.woff2>" in result["link"]
|
||||
@@ -1,400 +0,0 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Response
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.jwt import generate_jwt
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import GetAccessTokenError
|
||||
|
||||
from onyx.auth.users import CSRF_TOKEN_COOKIE_NAME
|
||||
from onyx.auth.users import CSRF_TOKEN_KEY
|
||||
from onyx.auth.users import get_oauth_router
|
||||
from onyx.auth.users import get_pkce_cookie_name
|
||||
from onyx.auth.users import PKCE_COOKIE_NAME_PREFIX
|
||||
from onyx.auth.users import STATE_TOKEN_AUDIENCE
|
||||
from onyx.error_handling.exceptions import register_onyx_exception_handlers
|
||||
|
||||
|
||||
class _StubOAuthClient:
|
||||
def __init__(self) -> None:
|
||||
self.name = "openid"
|
||||
self.authorization_calls: list[dict[str, str | list[str] | None]] = []
|
||||
self.access_token_calls: list[dict[str, str | None]] = []
|
||||
|
||||
async def get_authorization_url(
|
||||
self,
|
||||
redirect_uri: str,
|
||||
state: str | None = None,
|
||||
scope: list[str] | None = None,
|
||||
code_challenge: str | None = None,
|
||||
code_challenge_method: str | None = None,
|
||||
) -> str:
|
||||
self.authorization_calls.append(
|
||||
{
|
||||
"redirect_uri": redirect_uri,
|
||||
"state": state,
|
||||
"scope": scope,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": code_challenge_method,
|
||||
}
|
||||
)
|
||||
return f"https://idp.example.com/authorize?state={state}"
|
||||
|
||||
async def get_access_token(
|
||||
self, code: str, redirect_uri: str, code_verifier: str | None = None
|
||||
) -> dict[str, str | int]:
|
||||
self.access_token_calls.append(
|
||||
{
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_verifier": code_verifier,
|
||||
}
|
||||
)
|
||||
return {
|
||||
"access_token": "oidc_access_token",
|
||||
"refresh_token": "oidc_refresh_token",
|
||||
"expires_at": 1730000000,
|
||||
}
|
||||
|
||||
async def get_id_email(self, _access_token: str) -> tuple[str, str | None]:
|
||||
return ("oidc_account_id", "oidc_user@example.com")
|
||||
|
||||
|
||||
def _build_test_client(
|
||||
enable_pkce: bool,
|
||||
login_status_code: int = 302,
|
||||
) -> tuple[TestClient, _StubOAuthClient, MagicMock]:
|
||||
oauth_client = _StubOAuthClient()
|
||||
transport = CookieTransport(cookie_name="testsession")
|
||||
|
||||
async def get_strategy() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
backend = AuthenticationBackend(
|
||||
name="test_backend",
|
||||
transport=transport,
|
||||
get_strategy=get_strategy,
|
||||
)
|
||||
|
||||
login_response = Response(status_code=login_status_code)
|
||||
if login_status_code in {301, 302, 303, 307, 308}:
|
||||
login_response.headers["location"] = "/app"
|
||||
login_response.set_cookie("testsession", "session-token")
|
||||
backend.login = AsyncMock(return_value=login_response) # type: ignore[method-assign]
|
||||
|
||||
user = MagicMock()
|
||||
user.is_active = True
|
||||
user_manager = MagicMock()
|
||||
user_manager.oauth_callback = AsyncMock(return_value=user)
|
||||
user_manager.on_after_login = AsyncMock()
|
||||
|
||||
async def get_user_manager() -> MagicMock:
|
||||
return user_manager
|
||||
|
||||
router = get_oauth_router(
|
||||
oauth_client=cast(BaseOAuth2[Any], oauth_client),
|
||||
backend=backend,
|
||||
get_user_manager=get_user_manager,
|
||||
state_secret="test-secret",
|
||||
redirect_url="http://localhost/auth/oidc/callback",
|
||||
associate_by_email=True,
|
||||
is_verified_by_default=True,
|
||||
enable_pkce=enable_pkce,
|
||||
)
|
||||
app = FastAPI()
|
||||
app.include_router(router, prefix="/auth/oidc")
|
||||
register_onyx_exception_handlers(app)
|
||||
|
||||
client = TestClient(app, raise_server_exceptions=False)
|
||||
return client, oauth_client, user_manager
|
||||
|
||||
|
||||
def _extract_state_from_authorize_response(response: Any) -> str:
|
||||
auth_url = response.json()["authorization_url"]
|
||||
return parse_qs(urlparse(auth_url).query)["state"][0]
|
||||
|
||||
|
||||
def test_oidc_authorize_omits_pkce_when_flag_disabled() -> None:
|
||||
client, oauth_client, _ = _build_test_client(enable_pkce=False)
|
||||
|
||||
response = client.get("/auth/oidc/authorize")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert oauth_client.authorization_calls[0]["code_challenge"] is None
|
||||
assert oauth_client.authorization_calls[0]["code_challenge_method"] is None
|
||||
assert "fastapiusersoauthcsrf" in response.cookies.keys()
|
||||
assert not any(
|
||||
key.startswith(PKCE_COOKIE_NAME_PREFIX) for key in response.cookies.keys()
|
||||
)
|
||||
|
||||
|
||||
def test_oidc_authorize_adds_pkce_when_flag_enabled() -> None:
|
||||
client, oauth_client, _ = _build_test_client(enable_pkce=True)
|
||||
|
||||
response = client.get("/auth/oidc/authorize")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert oauth_client.authorization_calls[0]["code_challenge"] is not None
|
||||
assert oauth_client.authorization_calls[0]["code_challenge_method"] == "S256"
|
||||
assert any(
|
||||
key.startswith(PKCE_COOKIE_NAME_PREFIX) for key in response.cookies.keys()
|
||||
)
|
||||
|
||||
|
||||
def test_oidc_callback_fails_when_pkce_cookie_missing() -> None:
|
||||
client, oauth_client, _ = _build_test_client(enable_pkce=True)
|
||||
authorize_response = client.get("/auth/oidc/authorize")
|
||||
state = _extract_state_from_authorize_response(authorize_response)
|
||||
|
||||
for key in list(client.cookies.keys()):
|
||||
if key.startswith(PKCE_COOKIE_NAME_PREFIX):
|
||||
del client.cookies[key]
|
||||
|
||||
response = client.get(
|
||||
"/auth/oidc/callback", params={"code": "abc123", "state": state}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["error_code"] == "VALIDATION_ERROR"
|
||||
assert oauth_client.access_token_calls == []
|
||||
assert "Max-Age=0" in response.headers.get("set-cookie", "")
|
||||
|
||||
|
||||
def test_oidc_callback_rejects_bad_state_before_token_exchange() -> None:
|
||||
client, oauth_client, _ = _build_test_client(enable_pkce=True)
|
||||
client.get("/auth/oidc/authorize")
|
||||
tampered_state = "not-a-valid-state-jwt"
|
||||
client.cookies.set(get_pkce_cookie_name(tampered_state), "verifier123")
|
||||
|
||||
response = client.get(
|
||||
"/auth/oidc/callback", params={"code": "abc123", "state": tampered_state}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["error_code"] == "VALIDATION_ERROR"
|
||||
assert oauth_client.access_token_calls == []
|
||||
assert "Max-Age=0" in response.headers.get("set-cookie", "")
|
||||
|
||||
|
||||
def test_oidc_callback_rejects_wrongly_signed_state_before_token_exchange() -> None:
|
||||
client, oauth_client, _ = _build_test_client(enable_pkce=True)
|
||||
client.get("/auth/oidc/authorize")
|
||||
csrf_token = client.cookies.get(CSRF_TOKEN_COOKIE_NAME)
|
||||
assert csrf_token is not None
|
||||
tampered_state = generate_jwt(
|
||||
{
|
||||
"aud": STATE_TOKEN_AUDIENCE,
|
||||
CSRF_TOKEN_KEY: csrf_token,
|
||||
},
|
||||
"wrong-secret",
|
||||
3600,
|
||||
)
|
||||
client.cookies.set(get_pkce_cookie_name(tampered_state), "verifier123")
|
||||
|
||||
response = client.get(
|
||||
"/auth/oidc/callback",
|
||||
params={"code": "abc123", "state": tampered_state},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["error_code"] == "VALIDATION_ERROR"
|
||||
assert response.json()["detail"] == "ACCESS_TOKEN_DECODE_ERROR"
|
||||
assert oauth_client.access_token_calls == []
|
||||
assert "Max-Age=0" in response.headers.get("set-cookie", "")
|
||||
|
||||
|
||||
def test_oidc_callback_rejects_csrf_mismatch_in_pkce_path() -> None:
|
||||
client, oauth_client, _ = _build_test_client(enable_pkce=True)
|
||||
authorize_response = client.get("/auth/oidc/authorize")
|
||||
state = _extract_state_from_authorize_response(authorize_response)
|
||||
|
||||
# Keep PKCE verifier cookie intact, but invalidate CSRF match against state JWT.
|
||||
client.cookies.set("fastapiusersoauthcsrf", "wrong-csrf-token")
|
||||
|
||||
response = client.get(
|
||||
"/auth/oidc/callback",
|
||||
params={"code": "abc123", "state": state},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["error_code"] == "VALIDATION_ERROR"
|
||||
assert oauth_client.access_token_calls == []
|
||||
assert "Max-Age=0" in response.headers.get("set-cookie", "")
|
||||
|
||||
|
||||
def test_oidc_callback_get_access_token_error_is_400() -> None:
|
||||
client, oauth_client, _ = _build_test_client(enable_pkce=True)
|
||||
authorize_response = client.get("/auth/oidc/authorize")
|
||||
state = _extract_state_from_authorize_response(authorize_response)
|
||||
with patch.object(
|
||||
oauth_client,
|
||||
"get_access_token",
|
||||
AsyncMock(side_effect=GetAccessTokenError("token exchange failed")),
|
||||
):
|
||||
response = client.get(
|
||||
"/auth/oidc/callback", params={"code": "abc123", "state": state}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["error_code"] == "VALIDATION_ERROR"
|
||||
assert response.json()["detail"] == "Authorization code exchange failed"
|
||||
assert "Max-Age=0" in response.headers.get("set-cookie", "")
|
||||
|
||||
|
||||
def test_oidc_callback_cleans_pkce_cookie_on_idp_error_with_state() -> None:
|
||||
client, oauth_client, _ = _build_test_client(enable_pkce=True)
|
||||
authorize_response = client.get("/auth/oidc/authorize")
|
||||
state = _extract_state_from_authorize_response(authorize_response)
|
||||
|
||||
response = client.get(
|
||||
"/auth/oidc/callback",
|
||||
params={"error": "access_denied", "state": state},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["error_code"] == "VALIDATION_ERROR"
|
||||
assert response.json()["detail"] == "Authorization request failed or was denied"
|
||||
assert oauth_client.access_token_calls == []
|
||||
assert "Max-Age=0" in response.headers.get("set-cookie", "")
|
||||
|
||||
|
||||
def test_oidc_callback_cleans_pkce_cookie_on_missing_email() -> None:
|
||||
client, oauth_client, _ = _build_test_client(enable_pkce=True)
|
||||
authorize_response = client.get("/auth/oidc/authorize")
|
||||
state = _extract_state_from_authorize_response(authorize_response)
|
||||
|
||||
with patch.object(
|
||||
oauth_client, "get_id_email", AsyncMock(return_value=("oidc_account_id", None))
|
||||
):
|
||||
response = client.get(
|
||||
"/auth/oidc/callback", params={"code": "abc123", "state": state}
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["error_code"] == "VALIDATION_ERROR"
|
||||
assert "Max-Age=0" in response.headers.get("set-cookie", "")
|
||||
|
||||
|
||||
def test_oidc_callback_rejects_wrong_audience_state_before_token_exchange() -> None:
|
||||
client, oauth_client, _ = _build_test_client(enable_pkce=True)
|
||||
client.get("/auth/oidc/authorize")
|
||||
csrf_token = client.cookies.get(CSRF_TOKEN_COOKIE_NAME)
|
||||
assert csrf_token is not None
|
||||
wrong_audience_state = generate_jwt(
|
||||
{
|
||||
"aud": "wrong-audience",
|
||||
CSRF_TOKEN_KEY: csrf_token,
|
||||
},
|
||||
"test-secret",
|
||||
3600,
|
||||
)
|
||||
client.cookies.set(get_pkce_cookie_name(wrong_audience_state), "verifier123")
|
||||
|
||||
response = client.get(
|
||||
"/auth/oidc/callback",
|
||||
params={"code": "abc123", "state": wrong_audience_state},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["error_code"] == "VALIDATION_ERROR"
|
||||
assert response.json()["detail"] == "ACCESS_TOKEN_DECODE_ERROR"
|
||||
assert oauth_client.access_token_calls == []
|
||||
assert "Max-Age=0" in response.headers.get("set-cookie", "")
|
||||
|
||||
|
||||
def test_oidc_callback_uses_code_verifier_when_pkce_enabled() -> None:
|
||||
client, oauth_client, user_manager = _build_test_client(enable_pkce=True)
|
||||
authorize_response = client.get("/auth/oidc/authorize")
|
||||
state = _extract_state_from_authorize_response(authorize_response)
|
||||
|
||||
with patch(
|
||||
"onyx.auth.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda _email: "tenant_1",
|
||||
):
|
||||
response = client.get(
|
||||
"/auth/oidc/callback",
|
||||
params={"code": "abc123", "state": state},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert response.headers.get("location") == "/"
|
||||
assert oauth_client.access_token_calls[0]["code_verifier"] is not None
|
||||
user_manager.oauth_callback.assert_awaited_once()
|
||||
assert "Max-Age=0" in response.headers.get("set-cookie", "")
|
||||
|
||||
|
||||
def test_oidc_callback_works_without_pkce_when_flag_disabled() -> None:
|
||||
client, oauth_client, user_manager = _build_test_client(enable_pkce=False)
|
||||
authorize_response = client.get("/auth/oidc/authorize")
|
||||
state = _extract_state_from_authorize_response(authorize_response)
|
||||
|
||||
with patch(
|
||||
"onyx.auth.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda _email: "tenant_1",
|
||||
):
|
||||
response = client.get(
|
||||
"/auth/oidc/callback",
|
||||
params={"code": "abc123", "state": state},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert oauth_client.access_token_calls[0]["code_verifier"] is None
|
||||
user_manager.oauth_callback.assert_awaited_once()
|
||||
|
||||
|
||||
def test_oidc_callback_pkce_preserves_redirect_when_backend_login_is_non_redirect() -> (
|
||||
None
|
||||
):
|
||||
client, oauth_client, user_manager = _build_test_client(
|
||||
enable_pkce=True,
|
||||
login_status_code=200,
|
||||
)
|
||||
authorize_response = client.get("/auth/oidc/authorize")
|
||||
state = _extract_state_from_authorize_response(authorize_response)
|
||||
|
||||
with patch(
|
||||
"onyx.auth.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda _email: "tenant_1",
|
||||
):
|
||||
response = client.get(
|
||||
"/auth/oidc/callback",
|
||||
params={"code": "abc123", "state": state},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert response.headers.get("location") == "/"
|
||||
assert oauth_client.access_token_calls[0]["code_verifier"] is not None
|
||||
user_manager.oauth_callback.assert_awaited_once()
|
||||
assert "Max-Age=0" in response.headers.get("set-cookie", "")
|
||||
|
||||
|
||||
def test_oidc_callback_non_pkce_rejects_csrf_mismatch() -> None:
|
||||
client, oauth_client, _ = _build_test_client(enable_pkce=False)
|
||||
authorize_response = client.get("/auth/oidc/authorize")
|
||||
state = _extract_state_from_authorize_response(authorize_response)
|
||||
|
||||
client.cookies.set(CSRF_TOKEN_COOKIE_NAME, "wrong-csrf-token")
|
||||
|
||||
response = client.get(
|
||||
"/auth/oidc/callback",
|
||||
params={"code": "abc123", "state": state},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert response.json()["error_code"] == "VALIDATION_ERROR"
|
||||
assert response.json()["detail"] == "OAUTH_INVALID_STATE"
|
||||
# NOTE: In the non-PKCE path, oauth2_authorize_callback exchanges the code
|
||||
# before route-body CSRF validation runs. This is a known ordering trade-off.
|
||||
assert oauth_client.access_token_calls
|
||||
@@ -6,7 +6,6 @@ Validates that:
|
||||
- Crash + resume skips already-processed pages
|
||||
- BFS (folder-scoped) drives process all items in one call
|
||||
- 410 Gone triggers a full-resync URL in the checkpoint
|
||||
- Duplicate document IDs across delta pages are deduplicated
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -458,228 +457,3 @@ class TestDeltaPageFetchFailure:
|
||||
assert final_cp.current_drive_name is None
|
||||
assert final_cp.current_drive_id is None
|
||||
assert final_cp.current_drive_delta_next_link is None
|
||||
|
||||
|
||||
class TestDeltaDuplicateDocumentDedup:
|
||||
"""The Microsoft Graph delta API can return the same item on multiple
|
||||
pages. Documents already yielded should be skipped via
|
||||
checkpoint.seen_document_ids."""
|
||||
|
||||
def test_duplicate_across_pages_is_skipped(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Item 'dup' appears on both page 1 and page 2. It should only be
|
||||
yielded once."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
|
||||
call_count = 0
|
||||
|
||||
def fake_fetch_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
page_url: str, # noqa: ARG001
|
||||
drive_id: str, # noqa: ARG001
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return [_make_item("a"), _make_item("dup")], "https://next2"
|
||||
return [_make_item("dup"), _make_item("b")], None
|
||||
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
checkpoint = _build_ready_checkpoint()
|
||||
|
||||
# Page 1: yields a, dup
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, checkpoint = _consume_generator(gen)
|
||||
docs = _docs_from(yielded)
|
||||
assert [d.id for d in docs] == ["a", "dup"]
|
||||
assert "dup" in checkpoint.seen_document_ids
|
||||
|
||||
# Page 2: dup should be skipped, only b yielded
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, checkpoint = _consume_generator(gen)
|
||||
docs = _docs_from(yielded)
|
||||
assert [d.id for d in docs] == ["b"]
|
||||
|
||||
def test_duplicate_within_same_page_is_skipped(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""If the same item appears twice on a single delta page, only the
|
||||
first occurrence should be yielded."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
|
||||
def fake_fetch_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
page_url: str, # noqa: ARG001
|
||||
drive_id: str, # noqa: ARG001
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
return [_make_item("x"), _make_item("x"), _make_item("y")], None
|
||||
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
checkpoint = _build_ready_checkpoint()
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, checkpoint = _consume_generator(gen)
|
||||
docs = _docs_from(yielded)
|
||||
assert [d.id for d in docs] == ["x", "y"]
|
||||
|
||||
def test_seen_ids_survive_checkpoint_serialization(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""seen_document_ids must survive JSON serialization so that
|
||||
dedup works across crash + resume."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
|
||||
call_count = 0
|
||||
|
||||
def fake_fetch_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
page_url: str, # noqa: ARG001
|
||||
drive_id: str, # noqa: ARG001
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return [_make_item("a")], "https://next2"
|
||||
return [_make_item("a"), _make_item("b")], None
|
||||
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
checkpoint = _build_ready_checkpoint()
|
||||
|
||||
# Page 1
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
_, checkpoint = _consume_generator(gen)
|
||||
assert "a" in checkpoint.seen_document_ids
|
||||
|
||||
# Simulate crash: round-trip through JSON
|
||||
restored = SharepointConnectorCheckpoint.model_validate_json(
|
||||
checkpoint.model_dump_json()
|
||||
)
|
||||
assert "a" in restored.seen_document_ids
|
||||
|
||||
# Page 2 with restored checkpoint: 'a' should be skipped
|
||||
connector2 = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
gen = connector2._load_from_checkpoint(
|
||||
_START_TS, _END_TS, restored, include_permissions=False
|
||||
)
|
||||
yielded, final_cp = _consume_generator(gen)
|
||||
docs = _docs_from(yielded)
|
||||
assert [d.id for d in docs] == ["b"]
|
||||
|
||||
def test_no_dedup_across_separate_indexing_runs(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""A fresh checkpoint (new indexing run) should have an empty
|
||||
seen_document_ids, so previously-indexed docs are re-processed."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
|
||||
def fake_fetch_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
page_url: str, # noqa: ARG001
|
||||
drive_id: str, # noqa: ARG001
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
return [_make_item("a")], None
|
||||
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
# First run
|
||||
cp1 = _build_ready_checkpoint()
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, cp1, include_permissions=False
|
||||
)
|
||||
yielded, _ = _consume_generator(gen)
|
||||
assert len(_docs_from(yielded)) == 1
|
||||
|
||||
# Second run with a fresh checkpoint — same doc should appear again
|
||||
cp2 = _build_ready_checkpoint()
|
||||
assert len(cp2.seen_document_ids) == 0
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, cp2, include_permissions=False
|
||||
)
|
||||
yielded, _ = _consume_generator(gen)
|
||||
assert len(_docs_from(yielded)) == 1
|
||||
|
||||
def test_same_id_across_drives_not_skipped(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Graph item IDs are only unique within a drive. An item in drive B
|
||||
that happens to share an ID with an item already seen in drive A must
|
||||
NOT be skipped."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
_mock_convert(monkeypatch)
|
||||
|
||||
def fake_fetch_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
page_url: str, # noqa: ARG001
|
||||
drive_id: str, # noqa: ARG001
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
return [_make_item("shared-id")], None
|
||||
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
|
||||
)
|
||||
|
||||
checkpoint = _build_ready_checkpoint(drive_names=["DriveA", "DriveB"])
|
||||
|
||||
# Drive A: yields the item
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, checkpoint = _consume_generator(gen)
|
||||
docs = _docs_from(yielded)
|
||||
assert len(docs) == 1
|
||||
assert docs[0].id == "shared-id"
|
||||
|
||||
# seen_document_ids should have been cleared when drive A finished
|
||||
assert len(checkpoint.seen_document_ids) == 0
|
||||
|
||||
# Drive B: same ID must be yielded again (different drive)
|
||||
gen = connector._load_from_checkpoint(
|
||||
_START_TS, _END_TS, checkpoint, include_permissions=False
|
||||
)
|
||||
yielded, checkpoint = _consume_generator(gen)
|
||||
docs = _docs_from(yielded)
|
||||
assert len(docs) == 1
|
||||
assert docs[0].id == "shared-id"
|
||||
|
||||
@@ -1,179 +0,0 @@
|
||||
"""Unit tests for SharepointConnector._fetch_site_pages 404 handling.
|
||||
|
||||
The Graph Pages API returns 404 for classic sites or sites without
|
||||
modern pages enabled. _fetch_site_pages should gracefully skip these
|
||||
rather than crashing the entire indexing run.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from requests import Response
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.connectors.sharepoint.connector import SiteDescriptor
|
||||
|
||||
SITE_URL = "https://tenant.sharepoint.com/sites/ClassicSite"
|
||||
FAKE_SITE_ID = "tenant.sharepoint.com,abc123,def456"
|
||||
|
||||
|
||||
def _site_descriptor() -> SiteDescriptor:
|
||||
return SiteDescriptor(url=SITE_URL, drive_name=None, folder_path=None)
|
||||
|
||||
|
||||
def _make_http_error(status_code: int) -> HTTPError:
|
||||
response = Response()
|
||||
response.status_code = status_code
|
||||
response._content = b'{"error":{"code":"itemNotFound","message":"Item not found"}}'
|
||||
return HTTPError(response=response)
|
||||
|
||||
|
||||
def _setup_connector(
|
||||
monkeypatch: pytest.MonkeyPatch, # noqa: ARG001
|
||||
) -> SharepointConnector:
|
||||
"""Create a connector with the graph client and site resolution mocked."""
|
||||
connector = SharepointConnector(sites=[SITE_URL])
|
||||
connector.graph_api_base = "https://graph.microsoft.com/v1.0"
|
||||
|
||||
mock_sites = type(
|
||||
"FakeSites",
|
||||
(),
|
||||
{
|
||||
"get_by_url": staticmethod(
|
||||
lambda url: type( # noqa: ARG005
|
||||
"Q",
|
||||
(),
|
||||
{
|
||||
"execute_query": lambda self: None, # noqa: ARG005
|
||||
"id": FAKE_SITE_ID,
|
||||
},
|
||||
)()
|
||||
),
|
||||
},
|
||||
)()
|
||||
connector._graph_client = type("FakeGraphClient", (), {"sites": mock_sites})()
|
||||
|
||||
return connector
|
||||
|
||||
|
||||
def _patch_graph_api_get_json(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
fake_fn: Any,
|
||||
) -> None:
|
||||
monkeypatch.setattr(SharepointConnector, "_graph_api_get_json", fake_fn)
|
||||
|
||||
|
||||
class TestFetchSitePages404:
|
||||
def test_404_yields_no_pages(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""A 404 from the Pages API should result in zero yielded pages."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
raise _make_http_error(404)
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
pages = list(connector._fetch_site_pages(_site_descriptor()))
|
||||
assert pages == []
|
||||
|
||||
def test_404_does_not_raise(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""A 404 must not propagate as an exception."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
raise _make_http_error(404)
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
for _ in connector._fetch_site_pages(_site_descriptor()):
|
||||
pass
|
||||
|
||||
def test_non_404_http_error_still_raises(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Non-404 HTTP errors (e.g. 403) must still propagate."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
raise _make_http_error(403)
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
with pytest.raises(HTTPError):
|
||||
list(connector._fetch_site_pages(_site_descriptor()))
|
||||
|
||||
def test_successful_fetch_yields_pages(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""When the API succeeds, pages should be yielded normally."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
fake_page = {
|
||||
"id": "page-1",
|
||||
"title": "Hello World",
|
||||
"webUrl": f"{SITE_URL}/SitePages/Hello.aspx",
|
||||
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
|
||||
}
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
return {"value": [fake_page]}
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
pages = list(connector._fetch_site_pages(_site_descriptor()))
|
||||
assert len(pages) == 1
|
||||
assert pages[0]["id"] == "page-1"
|
||||
|
||||
def test_404_on_second_page_stops_pagination(
|
||||
self, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""If the first API page succeeds but a nextLink returns 404,
|
||||
already-yielded pages are kept and iteration stops cleanly."""
|
||||
connector = _setup_connector(monkeypatch)
|
||||
|
||||
call_count = 0
|
||||
first_page = {
|
||||
"id": "page-1",
|
||||
"title": "First",
|
||||
"webUrl": f"{SITE_URL}/SitePages/First.aspx",
|
||||
"lastModifiedDateTime": "2025-06-01T00:00:00Z",
|
||||
}
|
||||
|
||||
def fake_get_json(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
url: str, # noqa: ARG001
|
||||
params: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> dict[str, Any]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return {
|
||||
"value": [first_page],
|
||||
"@odata.nextLink": "https://graph.microsoft.com/next",
|
||||
}
|
||||
raise _make_http_error(404)
|
||||
|
||||
_patch_graph_api_get_json(monkeypatch, fake_get_json)
|
||||
|
||||
pages = list(connector._fetch_site_pages(_site_descriptor()))
|
||||
assert len(pages) == 1
|
||||
assert pages[0]["id"] == "page-1"
|
||||
@@ -7,7 +7,6 @@ import pytest
|
||||
|
||||
from onyx.db.llm import sync_model_configurations
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.server.manage.llm.models import SyncModelEntry
|
||||
|
||||
|
||||
class TestSyncModelConfigurations:
|
||||
@@ -26,18 +25,18 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
SyncModelEntry(
|
||||
name="gpt-4",
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
SyncModelEntry(
|
||||
name="gpt-4o",
|
||||
display_name="GPT-4o",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
{
|
||||
"name": "gpt-4",
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o",
|
||||
"display_name": "GPT-4o",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -68,18 +67,18 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
SyncModelEntry(
|
||||
name="gpt-4", # Existing - should be skipped
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
SyncModelEntry(
|
||||
name="gpt-4o", # New - should be inserted
|
||||
display_name="GPT-4o",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
{
|
||||
"name": "gpt-4", # Existing - should be skipped
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
{
|
||||
"name": "gpt-4o", # New - should be inserted
|
||||
"display_name": "GPT-4o",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -106,12 +105,12 @@ class TestSyncModelConfigurations:
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
models = [
|
||||
SyncModelEntry(
|
||||
name="gpt-4", # Already exists
|
||||
display_name="GPT-4",
|
||||
max_input_tokens=128000,
|
||||
supports_image_input=True,
|
||||
),
|
||||
{
|
||||
"name": "gpt-4", # Already exists
|
||||
"display_name": "GPT-4",
|
||||
"max_input_tokens": 128000,
|
||||
"supports_image_input": True,
|
||||
},
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
@@ -132,7 +131,7 @@ class TestSyncModelConfigurations:
|
||||
sync_model_configurations(
|
||||
db_session=mock_session,
|
||||
provider_name="nonexistent",
|
||||
models=[SyncModelEntry(name="model", display_name="Model")],
|
||||
models=[{"name": "model", "display_name": "Model"}],
|
||||
)
|
||||
|
||||
def test_handles_missing_optional_fields(self) -> None:
|
||||
@@ -146,12 +145,12 @@ class TestSyncModelConfigurations:
|
||||
with patch(
|
||||
"onyx.db.llm.fetch_existing_llm_provider", return_value=mock_provider
|
||||
):
|
||||
# Model with only required fields (max_input_tokens and supports_image_input default)
|
||||
# Model with only required fields
|
||||
models = [
|
||||
SyncModelEntry(
|
||||
name="model-1",
|
||||
display_name="Model 1",
|
||||
),
|
||||
{
|
||||
"name": "model-1",
|
||||
# No display_name, max_input_tokens, or supports_image_input
|
||||
},
|
||||
]
|
||||
|
||||
result = sync_model_configurations(
|
||||
|
||||
@@ -1,507 +0,0 @@
|
||||
"""Unit tests for onyx.db.voice module."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.db.models import VoiceProvider
|
||||
from onyx.db.voice import deactivate_stt_provider
|
||||
from onyx.db.voice import deactivate_tts_provider
|
||||
from onyx.db.voice import delete_voice_provider
|
||||
from onyx.db.voice import fetch_default_stt_provider
|
||||
from onyx.db.voice import fetch_default_tts_provider
|
||||
from onyx.db.voice import fetch_voice_provider_by_id
|
||||
from onyx.db.voice import fetch_voice_provider_by_type
|
||||
from onyx.db.voice import fetch_voice_providers
|
||||
from onyx.db.voice import MAX_VOICE_PLAYBACK_SPEED
|
||||
from onyx.db.voice import MIN_VOICE_PLAYBACK_SPEED
|
||||
from onyx.db.voice import set_default_stt_provider
|
||||
from onyx.db.voice import set_default_tts_provider
|
||||
from onyx.db.voice import update_user_voice_settings
|
||||
from onyx.db.voice import upsert_voice_provider
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
def _make_voice_provider(
|
||||
id: int = 1,
|
||||
name: str = "Test Provider",
|
||||
provider_type: str = "openai",
|
||||
is_default_stt: bool = False,
|
||||
is_default_tts: bool = False,
|
||||
) -> VoiceProvider:
|
||||
"""Create a VoiceProvider instance for testing."""
|
||||
provider = VoiceProvider()
|
||||
provider.id = id
|
||||
provider.name = name
|
||||
provider.provider_type = provider_type
|
||||
provider.is_default_stt = is_default_stt
|
||||
provider.is_default_tts = is_default_tts
|
||||
provider.api_key = None
|
||||
provider.api_base = None
|
||||
provider.custom_config = None
|
||||
provider.stt_model = None
|
||||
provider.tts_model = None
|
||||
provider.default_voice = None
|
||||
return provider
|
||||
|
||||
|
||||
class TestFetchVoiceProviders:
|
||||
"""Tests for fetch_voice_providers."""
|
||||
|
||||
def test_returns_all_providers(self, mock_db_session: MagicMock) -> None:
|
||||
providers = [
|
||||
_make_voice_provider(id=1, name="Provider A"),
|
||||
_make_voice_provider(id=2, name="Provider B"),
|
||||
]
|
||||
mock_db_session.scalars.return_value.all.return_value = providers
|
||||
|
||||
result = fetch_voice_providers(mock_db_session)
|
||||
|
||||
assert result == providers
|
||||
mock_db_session.scalars.assert_called_once()
|
||||
|
||||
def test_returns_empty_list_when_no_providers(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalars.return_value.all.return_value = []
|
||||
|
||||
result = fetch_voice_providers(mock_db_session)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestFetchVoiceProviderById:
|
||||
"""Tests for fetch_voice_provider_by_id."""
|
||||
|
||||
def test_returns_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
result = fetch_voice_provider_by_id(mock_db_session, 1)
|
||||
|
||||
assert result is provider
|
||||
mock_db_session.scalar.assert_called_once()
|
||||
|
||||
def test_returns_none_when_not_found(self, mock_db_session: MagicMock) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
result = fetch_voice_provider_by_id(mock_db_session, 999)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestFetchDefaultProviders:
|
||||
"""Tests for fetch_default_stt_provider and fetch_default_tts_provider."""
|
||||
|
||||
def test_fetch_default_stt_provider_returns_provider(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1, is_default_stt=True)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
result = fetch_default_stt_provider(mock_db_session)
|
||||
|
||||
assert result is provider
|
||||
|
||||
def test_fetch_default_stt_provider_returns_none_when_no_default(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
result = fetch_default_stt_provider(mock_db_session)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_fetch_default_tts_provider_returns_provider(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1, is_default_tts=True)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
result = fetch_default_tts_provider(mock_db_session)
|
||||
|
||||
assert result is provider
|
||||
|
||||
def test_fetch_default_tts_provider_returns_none_when_no_default(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
result = fetch_default_tts_provider(mock_db_session)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestFetchVoiceProviderByType:
|
||||
"""Tests for fetch_voice_provider_by_type."""
|
||||
|
||||
def test_returns_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
provider = _make_voice_provider(id=1, provider_type="openai")
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
result = fetch_voice_provider_by_type(mock_db_session, "openai")
|
||||
|
||||
assert result is provider
|
||||
|
||||
def test_returns_none_when_not_found(self, mock_db_session: MagicMock) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
result = fetch_voice_provider_by_type(mock_db_session, "nonexistent")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestUpsertVoiceProvider:
|
||||
"""Tests for upsert_voice_provider."""
|
||||
|
||||
def test_creates_new_provider_when_no_id(self, mock_db_session: MagicMock) -> None:
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=None,
|
||||
name="New Provider",
|
||||
provider_type="openai",
|
||||
api_key="test-key",
|
||||
api_key_changed=True,
|
||||
)
|
||||
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.flush.assert_called()
|
||||
added_obj = mock_db_session.add.call_args[0][0]
|
||||
assert added_obj.name == "New Provider"
|
||||
assert added_obj.provider_type == "openai"
|
||||
|
||||
def test_updates_existing_provider(self, mock_db_session: MagicMock) -> None:
|
||||
existing_provider = _make_voice_provider(id=1, name="Old Name")
|
||||
mock_db_session.scalar.return_value = existing_provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=1,
|
||||
name="Updated Name",
|
||||
provider_type="elevenlabs",
|
||||
api_key="new-key",
|
||||
api_key_changed=True,
|
||||
)
|
||||
|
||||
mock_db_session.add.assert_not_called()
|
||||
assert existing_provider.name == "Updated Name"
|
||||
assert existing_provider.provider_type == "elevenlabs"
|
||||
|
||||
def test_raises_when_provider_not_found(self, mock_db_session: MagicMock) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=999,
|
||||
name="Test",
|
||||
provider_type="openai",
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
)
|
||||
|
||||
assert "No voice provider with id 999" in str(exc_info.value)
|
||||
|
||||
def test_does_not_update_api_key_when_not_changed(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
existing_provider = _make_voice_provider(id=1)
|
||||
existing_provider.api_key = "original-key" # type: ignore[assignment]
|
||||
original_api_key = existing_provider.api_key
|
||||
mock_db_session.scalar.return_value = existing_provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=1,
|
||||
name="Test",
|
||||
provider_type="openai",
|
||||
api_key="new-key",
|
||||
api_key_changed=False,
|
||||
)
|
||||
|
||||
# api_key should remain unchanged (same object reference)
|
||||
assert existing_provider.api_key is original_api_key
|
||||
|
||||
def test_activates_stt_when_requested(self, mock_db_session: MagicMock) -> None:
|
||||
existing_provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = existing_provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
mock_db_session.execute.return_value = None
|
||||
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=1,
|
||||
name="Test",
|
||||
provider_type="openai",
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
activate_stt=True,
|
||||
)
|
||||
|
||||
assert existing_provider.is_default_stt is True
|
||||
|
||||
def test_activates_tts_when_requested(self, mock_db_session: MagicMock) -> None:
|
||||
existing_provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = existing_provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
mock_db_session.execute.return_value = None
|
||||
|
||||
upsert_voice_provider(
|
||||
db_session=mock_db_session,
|
||||
provider_id=1,
|
||||
name="Test",
|
||||
provider_type="openai",
|
||||
api_key=None,
|
||||
api_key_changed=False,
|
||||
activate_tts=True,
|
||||
)
|
||||
|
||||
assert existing_provider.is_default_tts is True
|
||||
|
||||
|
||||
class TestDeleteVoiceProvider:
|
||||
"""Tests for delete_voice_provider."""
|
||||
|
||||
def test_soft_deletes_provider_when_found(self, mock_db_session: MagicMock) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
|
||||
delete_voice_provider(mock_db_session, 1)
|
||||
|
||||
assert provider.deleted is True
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_does_nothing_when_provider_not_found(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
delete_voice_provider(mock_db_session, 999)
|
||||
|
||||
mock_db_session.flush.assert_not_called()
|
||||
|
||||
|
||||
class TestSetDefaultProviders:
|
||||
"""Tests for set_default_stt_provider and set_default_tts_provider."""
|
||||
|
||||
def test_set_default_stt_provider_deactivates_others(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
mock_db_session.execute.return_value = None
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
result = set_default_stt_provider(db_session=mock_db_session, provider_id=1)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
assert result.is_default_stt is True
|
||||
|
||||
def test_set_default_stt_provider_raises_when_not_found(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
set_default_stt_provider(db_session=mock_db_session, provider_id=999)
|
||||
|
||||
assert "No voice provider with id 999" in str(exc_info.value)
|
||||
|
||||
def test_set_default_tts_provider_deactivates_others(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
mock_db_session.execute.return_value = None
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
result = set_default_tts_provider(db_session=mock_db_session, provider_id=1)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
assert result.is_default_tts is True
|
||||
|
||||
def test_set_default_tts_provider_updates_model_when_provided(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
mock_db_session.execute.return_value = None
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
result = set_default_tts_provider(
|
||||
db_session=mock_db_session, provider_id=1, tts_model="tts-1-hd"
|
||||
)
|
||||
|
||||
assert result.tts_model == "tts-1-hd"
|
||||
|
||||
def test_set_default_tts_provider_raises_when_not_found(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
set_default_tts_provider(db_session=mock_db_session, provider_id=999)
|
||||
|
||||
assert "No voice provider with id 999" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestDeactivateProviders:
|
||||
"""Tests for deactivate_stt_provider and deactivate_tts_provider."""
|
||||
|
||||
def test_deactivate_stt_provider_sets_false(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1, is_default_stt=True)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
result = deactivate_stt_provider(db_session=mock_db_session, provider_id=1)
|
||||
|
||||
assert result.is_default_stt is False
|
||||
|
||||
def test_deactivate_stt_provider_raises_when_not_found(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
deactivate_stt_provider(db_session=mock_db_session, provider_id=999)
|
||||
|
||||
assert "No voice provider with id 999" in str(exc_info.value)
|
||||
|
||||
def test_deactivate_tts_provider_sets_false(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
provider = _make_voice_provider(id=1, is_default_tts=True)
|
||||
mock_db_session.scalar.return_value = provider
|
||||
mock_db_session.flush.return_value = None
|
||||
mock_db_session.refresh.return_value = None
|
||||
|
||||
result = deactivate_tts_provider(db_session=mock_db_session, provider_id=1)
|
||||
|
||||
assert result.is_default_tts is False
|
||||
|
||||
def test_deactivate_tts_provider_raises_when_not_found(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
deactivate_tts_provider(db_session=mock_db_session, provider_id=999)
|
||||
|
||||
assert "No voice provider with id 999" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestUpdateUserVoiceSettings:
|
||||
"""Tests for update_user_voice_settings."""
|
||||
|
||||
def test_updates_auto_send(self, mock_db_session: MagicMock) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id, auto_send=True)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_updates_auto_playback(self, mock_db_session: MagicMock) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id, auto_playback=True)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_updates_playback_speed_within_range(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id, playback_speed=1.5)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
|
||||
def test_clamps_playback_speed_to_min(self, mock_db_session: MagicMock) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id, playback_speed=0.1)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
stmt = mock_db_session.execute.call_args[0][0]
|
||||
compiled = stmt.compile(compile_kwargs={"literal_binds": True})
|
||||
assert str(MIN_VOICE_PLAYBACK_SPEED) in str(compiled)
|
||||
|
||||
def test_clamps_playback_speed_to_max(self, mock_db_session: MagicMock) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id, playback_speed=5.0)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
stmt = mock_db_session.execute.call_args[0][0]
|
||||
compiled = stmt.compile(compile_kwargs={"literal_binds": True})
|
||||
assert str(MAX_VOICE_PLAYBACK_SPEED) in str(compiled)
|
||||
|
||||
def test_updates_multiple_settings(self, mock_db_session: MagicMock) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(
|
||||
mock_db_session,
|
||||
user_id,
|
||||
auto_send=True,
|
||||
auto_playback=False,
|
||||
playback_speed=1.25,
|
||||
)
|
||||
|
||||
mock_db_session.execute.assert_called_once()
|
||||
mock_db_session.flush.assert_called_once()
|
||||
|
||||
def test_does_nothing_when_no_settings_provided(
|
||||
self, mock_db_session: MagicMock
|
||||
) -> None:
|
||||
user_id = uuid4()
|
||||
|
||||
update_user_voice_settings(mock_db_session, user_id)
|
||||
|
||||
mock_db_session.execute.assert_not_called()
|
||||
mock_db_session.flush.assert_not_called()
|
||||
|
||||
|
||||
class TestSpeedClampingLogic:
|
||||
"""Tests for the speed clamping constants and logic."""
|
||||
|
||||
def test_min_speed_constant(self) -> None:
|
||||
assert MIN_VOICE_PLAYBACK_SPEED == 0.5
|
||||
|
||||
def test_max_speed_constant(self) -> None:
|
||||
assert MAX_VOICE_PLAYBACK_SPEED == 2.0
|
||||
|
||||
def test_clamping_formula(self) -> None:
|
||||
"""Verify the clamping formula used in update_user_voice_settings."""
|
||||
test_cases = [
|
||||
(0.1, MIN_VOICE_PLAYBACK_SPEED),
|
||||
(0.5, 0.5),
|
||||
(1.0, 1.0),
|
||||
(1.5, 1.5),
|
||||
(2.0, 2.0),
|
||||
(3.0, MAX_VOICE_PLAYBACK_SPEED),
|
||||
]
|
||||
for speed, expected in test_cases:
|
||||
clamped = max(
|
||||
MIN_VOICE_PLAYBACK_SPEED, min(MAX_VOICE_PLAYBACK_SPEED, speed)
|
||||
)
|
||||
assert (
|
||||
clamped == expected
|
||||
), f"speed={speed} expected={expected} got={clamped}"
|
||||
196
backend/tests/unit/onyx/file_processing/test_xlsx_to_text.py
Normal file
196
backend/tests/unit/onyx/file_processing/test_xlsx_to_text.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import io
|
||||
|
||||
import openpyxl
|
||||
|
||||
from onyx.file_processing.extract_file_text import xlsx_to_text
|
||||
|
||||
|
||||
def _make_xlsx(sheets: dict[str, list[list[str]]]) -> io.BytesIO:
|
||||
"""Create an in-memory xlsx file from a dict of sheet_name -> matrix of strings."""
|
||||
wb = openpyxl.Workbook()
|
||||
if wb.active is not None:
|
||||
wb.remove(wb.active)
|
||||
for sheet_name, rows in sheets.items():
|
||||
ws = wb.create_sheet(title=sheet_name)
|
||||
for row in rows:
|
||||
ws.append(row)
|
||||
buf = io.BytesIO()
|
||||
wb.save(buf)
|
||||
buf.seek(0)
|
||||
return buf
|
||||
|
||||
|
||||
class TestXlsxToText:
|
||||
def test_single_sheet_basic(self) -> None:
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["Name", "Age"],
|
||||
["Alice", "30"],
|
||||
["Bob", "25"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
assert len(lines) == 3
|
||||
assert "Name" in lines[0]
|
||||
assert "Age" in lines[0]
|
||||
assert "Alice" in lines[1]
|
||||
assert "30" in lines[1]
|
||||
assert "Bob" in lines[2]
|
||||
|
||||
def test_multiple_sheets_separated(self) -> None:
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [["a", "b"]],
|
||||
"Sheet2": [["c", "d"]],
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
# TEXT_SECTION_SEPARATOR is "\n\n"
|
||||
assert "\n\n" in result
|
||||
parts = result.split("\n\n")
|
||||
assert any("a" in p for p in parts)
|
||||
assert any("c" in p for p in parts)
|
||||
|
||||
def test_empty_cells(self) -> None:
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["a", "", "b"],
|
||||
["", "c", ""],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
assert len(lines) == 2
|
||||
|
||||
def test_commas_in_cells_are_quoted(self) -> None:
|
||||
"""Cells containing commas should be quoted in CSV output."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["hello, world", "normal"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
assert '"hello, world"' in result
|
||||
|
||||
def test_empty_workbook(self) -> None:
|
||||
xlsx = _make_xlsx({"Sheet1": []})
|
||||
result = xlsx_to_text(xlsx)
|
||||
assert result.strip() == ""
|
||||
|
||||
def test_long_empty_row_run_capped(self) -> None:
|
||||
"""Runs of >2 empty rows should be capped to 2."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["header"],
|
||||
[""],
|
||||
[""],
|
||||
[""],
|
||||
[""],
|
||||
["data"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
# 4 empty rows capped to 2, so: header + 2 empty + data = 4 lines
|
||||
assert len(lines) == 4
|
||||
assert "header" in lines[0]
|
||||
assert "data" in lines[-1]
|
||||
|
||||
def test_long_empty_col_run_capped(self) -> None:
|
||||
"""Runs of >2 empty columns should be capped to 2."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["a", "", "", "", "b"],
|
||||
["c", "", "", "", "d"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
assert len(lines) == 2
|
||||
# Each row should have 4 fields (a + 2 empty + b), not 5
|
||||
# csv format: a,,,b (3 commas = 4 fields)
|
||||
first_line = lines[0].strip()
|
||||
# Count commas to verify column reduction
|
||||
assert first_line.count(",") == 3
|
||||
|
||||
def test_short_empty_runs_kept(self) -> None:
|
||||
"""Runs of <=2 empty rows/cols should be preserved."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["a", "b"],
|
||||
["", ""],
|
||||
["", ""],
|
||||
["c", "d"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
# All 4 rows preserved (2 empty rows <= threshold)
|
||||
assert len(lines) == 4
|
||||
|
||||
def test_bad_zip_file_returns_empty(self) -> None:
|
||||
bad_file = io.BytesIO(b"not a zip file")
|
||||
result = xlsx_to_text(bad_file, file_name="test.xlsx")
|
||||
assert result == ""
|
||||
|
||||
def test_bad_zip_tilde_file_returns_empty(self) -> None:
|
||||
bad_file = io.BytesIO(b"not a zip file")
|
||||
result = xlsx_to_text(bad_file, file_name="~$temp.xlsx")
|
||||
assert result == ""
|
||||
|
||||
def test_large_sparse_sheet(self) -> None:
|
||||
"""A sheet with data, a big empty gap, and more data — gap is capped to 2."""
|
||||
rows: list[list[str]] = [["row1_data"]]
|
||||
rows.extend([[""] for _ in range(10)])
|
||||
rows.append(["row2_data"])
|
||||
xlsx = _make_xlsx({"Sheet1": rows})
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
# 10 empty rows capped to 2: row1_data + 2 empty + row2_data = 4
|
||||
assert len(lines) == 4
|
||||
assert "row1_data" in lines[0]
|
||||
assert "row2_data" in lines[-1]
|
||||
|
||||
def test_quotes_in_cells(self) -> None:
|
||||
"""Cells containing quotes should be properly escaped."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
['say "hello"', "normal"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
# csv.writer escapes quotes by doubling them
|
||||
assert '""hello""' in result
|
||||
|
||||
def test_each_row_is_separate_line(self) -> None:
|
||||
"""Each row should produce its own line (regression for writerow vs writerows)."""
|
||||
xlsx = _make_xlsx(
|
||||
{
|
||||
"Sheet1": [
|
||||
["r1c1", "r1c2"],
|
||||
["r2c1", "r2c2"],
|
||||
["r3c1", "r3c2"],
|
||||
]
|
||||
}
|
||||
)
|
||||
result = xlsx_to_text(xlsx)
|
||||
lines = [line for line in result.strip().split("\n") if line.strip()]
|
||||
assert len(lines) == 3
|
||||
assert "r1c1" in lines[0] and "r1c2" in lines[0]
|
||||
assert "r2c1" in lines[1] and "r2c2" in lines[1]
|
||||
assert "r3c1" in lines[2] and "r3c2" in lines[2]
|
||||
@@ -26,6 +26,14 @@ class TestIsTrueOpenAIModel:
|
||||
"""Test that real OpenAI GPT-4o-mini model is correctly identified."""
|
||||
assert is_true_openai_model(LlmProviderNames.OPENAI, "gpt-4o-mini") is True
|
||||
|
||||
def test_real_openai_o1_preview(self) -> None:
|
||||
"""Test that real OpenAI o1-preview reasoning model is correctly identified."""
|
||||
assert is_true_openai_model(LlmProviderNames.OPENAI, "o1-preview") is True
|
||||
|
||||
def test_real_openai_o1_mini(self) -> None:
|
||||
"""Test that real OpenAI o1-mini reasoning model is correctly identified."""
|
||||
assert is_true_openai_model(LlmProviderNames.OPENAI, "o1-mini") is True
|
||||
|
||||
def test_openai_with_provider_prefix(self) -> None:
|
||||
"""Test that OpenAI model with provider prefix is correctly identified."""
|
||||
assert is_true_openai_model(LlmProviderNames.OPENAI, "openai/gpt-4") is False
|
||||
|
||||
@@ -1,204 +0,0 @@
|
||||
"""Tests for Slack channel reference resolution and tag filtering
|
||||
in handle_regular_answer.py."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN
|
||||
from onyx.onyxbot.slack.handlers.handle_regular_answer import resolve_channel_references
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _mock_client_with_channels(
|
||||
channel_map: dict[str, str],
|
||||
) -> MagicMock:
|
||||
"""Return a mock WebClient where conversations_info resolves IDs to names."""
|
||||
client = MagicMock()
|
||||
|
||||
def _conversations_info(channel: str) -> MagicMock:
|
||||
if channel in channel_map:
|
||||
resp = MagicMock()
|
||||
resp.validate = MagicMock()
|
||||
resp.__getitem__ = lambda _self, key: {
|
||||
"channel": {
|
||||
"name": channel_map[channel],
|
||||
"is_im": False,
|
||||
"is_mpim": False,
|
||||
}
|
||||
}[key]
|
||||
return resp
|
||||
raise SlackApiError("channel_not_found", response=MagicMock())
|
||||
|
||||
client.conversations_info = _conversations_info
|
||||
return client
|
||||
|
||||
|
||||
def _mock_logger() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SLACK_CHANNEL_REF_PATTERN regex tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSlackChannelRefPattern:
|
||||
def test_matches_bare_channel_id(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("<#C097NBWMY8Y>")
|
||||
assert matches == [("C097NBWMY8Y", "")]
|
||||
|
||||
def test_matches_channel_id_with_name(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("<#C097NBWMY8Y|eng-infra>")
|
||||
assert matches == [("C097NBWMY8Y", "eng-infra")]
|
||||
|
||||
def test_matches_multiple_channels(self) -> None:
|
||||
msg = "compare <#C111AAA> and <#C222BBB|general>"
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall(msg)
|
||||
assert len(matches) == 2
|
||||
assert ("C111AAA", "") in matches
|
||||
assert ("C222BBB", "general") in matches
|
||||
|
||||
def test_no_match_on_plain_text(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("no channels here")
|
||||
assert matches == []
|
||||
|
||||
def test_no_match_on_user_mention(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("<@U12345>")
|
||||
assert matches == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_channel_references tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveChannelReferences:
|
||||
def test_resolves_bare_channel_id_via_api(self) -> None:
|
||||
client = _mock_client_with_channels({"C097NBWMY8Y": "eng-infra"})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="summary of <#C097NBWMY8Y> this week",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert message == "summary of #eng-infra this week"
|
||||
assert len(tags) == 1
|
||||
assert tags[0] == Tag(tag_key="Channel", tag_value="eng-infra")
|
||||
|
||||
def test_uses_name_from_pipe_format_without_api_call(self) -> None:
|
||||
client = MagicMock()
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="check <#C097NBWMY8Y|eng-infra> for updates",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert message == "check #eng-infra for updates"
|
||||
assert tags == [Tag(tag_key="Channel", tag_value="eng-infra")]
|
||||
# Should NOT have called the API since name was in the markup
|
||||
client.conversations_info.assert_not_called()
|
||||
|
||||
def test_multiple_channels(self) -> None:
|
||||
client = _mock_client_with_channels(
|
||||
{
|
||||
"C111AAA": "eng-infra",
|
||||
"C222BBB": "eng-general",
|
||||
}
|
||||
)
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="compare <#C111AAA> and <#C222BBB>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert "#eng-infra" in message
|
||||
assert "#eng-general" in message
|
||||
assert "<#" not in message
|
||||
assert len(tags) == 2
|
||||
tag_values = {t.tag_value for t in tags}
|
||||
assert tag_values == {"eng-infra", "eng-general"}
|
||||
|
||||
def test_no_channel_references_returns_unchanged(self) -> None:
|
||||
client = MagicMock()
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="just a normal message with no channels",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert message == "just a normal message with no channels"
|
||||
assert tags == []
|
||||
|
||||
def test_api_failure_skips_channel_gracefully(self) -> None:
|
||||
# Client that fails for all channel lookups
|
||||
client = _mock_client_with_channels({})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="check <#CBADID123>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
# Message should remain unchanged for the failed channel
|
||||
assert "<#CBADID123>" in message
|
||||
assert tags == []
|
||||
logger.warning.assert_called_once()
|
||||
|
||||
def test_partial_failure_resolves_what_it_can(self) -> None:
|
||||
# Only one of two channels resolves
|
||||
client = _mock_client_with_channels({"C111AAA": "eng-infra"})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="compare <#C111AAA> and <#CBADID123>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert "#eng-infra" in message
|
||||
assert "<#CBADID123>" in message # failed one stays raw
|
||||
assert len(tags) == 1
|
||||
assert tags[0].tag_value == "eng-infra"
|
||||
|
||||
def test_duplicate_channel_produces_single_tag(self) -> None:
|
||||
client = _mock_client_with_channels({"C111AAA": "eng-infra"})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="summarize <#C111AAA> and compare with <#C111AAA>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert message == "summarize #eng-infra and compare with #eng-infra"
|
||||
assert len(tags) == 1
|
||||
assert tags[0].tag_value == "eng-infra"
|
||||
|
||||
def test_mixed_pipe_and_bare_formats(self) -> None:
|
||||
client = _mock_client_with_channels({"C222BBB": "random"})
|
||||
logger = _mock_logger()
|
||||
|
||||
message, tags = resolve_channel_references(
|
||||
message="see <#C111AAA|eng-infra> and <#C222BBB>",
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
assert "#eng-infra" in message
|
||||
assert "#random" in message
|
||||
assert len(tags) == 2
|
||||
@@ -1,19 +1,15 @@
|
||||
"""Tests for LLM model fetch endpoints.
|
||||
|
||||
These tests verify the full request/response flow for fetching models
|
||||
from dynamic providers (Ollama, OpenRouter, Litellm), including the
|
||||
from dynamic providers (Ollama, OpenRouter), including the
|
||||
sync-to-DB behavior when provider_name is specified.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.llm.models import LitellmFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LitellmModelsRequest
|
||||
from onyx.server.manage.llm.models import LMStudioFinalModelResponse
|
||||
from onyx.server.manage.llm.models import LMStudioModelsRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
@@ -618,283 +614,3 @@ class TestGetLMStudioAvailableModels:
|
||||
request = LMStudioModelsRequest(api_base="http://localhost:1234")
|
||||
with pytest.raises(OnyxError):
|
||||
get_lm_studio_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
|
||||
class TestGetLitellmAvailableModels:
|
||||
"""Tests for the Litellm proxy model fetch endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm_response(self) -> dict:
|
||||
"""Mock response from Litellm /v1/models endpoint."""
|
||||
return {
|
||||
"data": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1700000000,
|
||||
"owned_by": "openai",
|
||||
},
|
||||
{
|
||||
"id": "claude-3-5-sonnet",
|
||||
"object": "model",
|
||||
"created": 1700000001,
|
||||
"owned_by": "anthropic",
|
||||
},
|
||||
{
|
||||
"id": "gemini-pro",
|
||||
"object": "model",
|
||||
"created": 1700000002,
|
||||
"owned_by": "google",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
def test_returns_model_list(self, mock_litellm_response: dict) -> None:
|
||||
"""Test that endpoint returns properly formatted model list."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
assert len(results) == 3
|
||||
assert all(isinstance(r, LitellmFinalModelResponse) for r in results)
|
||||
|
||||
def test_model_fields_parsed_correctly(self, mock_litellm_response: dict) -> None:
|
||||
"""Test that provider_name and model_name are correctly extracted."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
gpt = next(r for r in results if r.model_name == "gpt-4o")
|
||||
assert gpt.provider_name == "openai"
|
||||
|
||||
claude = next(r for r in results if r.model_name == "claude-3-5-sonnet")
|
||||
assert claude.provider_name == "anthropic"
|
||||
|
||||
def test_results_sorted_by_model_name(self, mock_litellm_response: dict) -> None:
|
||||
"""Test that results are alphabetically sorted by model_name."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
model_names = [r.model_name for r in results]
|
||||
assert model_names == sorted(model_names, key=str.lower)
|
||||
|
||||
def test_empty_data_raises_onyx_error(self) -> None:
|
||||
"""Test that empty model list raises OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"data": []}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="No models found"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_missing_data_key_raises_onyx_error(self) -> None:
|
||||
"""Test that response without 'data' key raises OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_skips_unparseable_entries(self) -> None:
|
||||
"""Test that malformed model entries are skipped without failing."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
response_with_bad_entry = {
|
||||
"data": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1700000000,
|
||||
"owned_by": "openai",
|
||||
},
|
||||
# Missing required fields
|
||||
{"bad_field": "bad_value"},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response_with_bad_entry
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
results = get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].model_name == "gpt-4o"
|
||||
|
||||
def test_all_entries_unparseable_raises_onyx_error(self) -> None:
|
||||
"""Test that OnyxError is raised when all entries fail to parse."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
response_all_bad = {
|
||||
"data": [
|
||||
{"bad_field": "bad_value"},
|
||||
{"another_bad": 123},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = response_all_bad
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="No compatible models"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_api_base_trailing_slash_handled(self) -> None:
|
||||
"""Test that trailing slashes in api_base are handled correctly."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_litellm_response = {
|
||||
"data": [
|
||||
{
|
||||
"id": "gpt-4o",
|
||||
"object": "model",
|
||||
"created": 1700000000,
|
||||
"owned_by": "openai",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = mock_litellm_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000/",
|
||||
api_key="test-key",
|
||||
)
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
# Should call /v1/models without double slashes
|
||||
call_args = mock_get.call_args
|
||||
assert call_args[0][0] == "http://localhost:4000/v1/models"
|
||||
|
||||
def test_connection_failure_raises_onyx_error(self) -> None:
|
||||
"""Test that connection failures are wrapped in OnyxError."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_get.side_effect = Exception("Connection refused")
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="Failed to fetch LiteLLM models"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_401_raises_authentication_error(self) -> None:
|
||||
"""Test that a 401 response raises OnyxError with authentication message."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
mock_get.side_effect = httpx.HTTPStatusError(
|
||||
"Unauthorized", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="bad-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="Authentication failed"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
def test_404_raises_not_found_error(self) -> None:
|
||||
"""Test that a 404 response raises OnyxError with endpoint not found message."""
|
||||
from onyx.server.manage.llm.api import get_litellm_available_models
|
||||
|
||||
mock_session = MagicMock()
|
||||
|
||||
with patch("onyx.server.manage.llm.api.httpx.get") as mock_get:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 404
|
||||
mock_get.side_effect = httpx.HTTPStatusError(
|
||||
"Not Found", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
request = LitellmModelsRequest(
|
||||
api_base="http://localhost:4000",
|
||||
api_key="test-key",
|
||||
)
|
||||
with pytest.raises(OnyxError, match="endpoint not found"):
|
||||
get_litellm_available_models(request, MagicMock(), mock_session)
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.manage.voice.api import _validate_voice_api_base
|
||||
|
||||
|
||||
def test_validate_voice_api_base_blocks_private_for_non_azure() -> None:
|
||||
with pytest.raises(OnyxError, match="Invalid target URI"):
|
||||
_validate_voice_api_base("openai", "http://127.0.0.1:11434")
|
||||
|
||||
|
||||
def test_validate_voice_api_base_allows_private_for_azure() -> None:
|
||||
validated = _validate_voice_api_base("azure", "http://127.0.0.1:5000")
|
||||
assert validated == "http://127.0.0.1:5000"
|
||||
|
||||
|
||||
def test_validate_voice_api_base_blocks_metadata_for_azure() -> None:
|
||||
with pytest.raises(OnyxError, match="Invalid target URI"):
|
||||
_validate_voice_api_base("azure", "http://metadata.google.internal/")
|
||||
|
||||
|
||||
def test_validate_voice_api_base_returns_none_for_none() -> None:
|
||||
assert _validate_voice_api_base("openai", None) is None
|
||||
@@ -1,54 +0,0 @@
|
||||
import datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import UserGroupInfo
|
||||
|
||||
|
||||
def _mock_user(
|
||||
personal_name: str | None = "Test User",
|
||||
created_at: datetime.datetime | None = None,
|
||||
updated_at: datetime.datetime | None = None,
|
||||
) -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = uuid4()
|
||||
user.email = "test@example.com"
|
||||
user.role = UserRole.BASIC
|
||||
user.is_active = True
|
||||
user.password_configured = True
|
||||
user.personal_name = personal_name
|
||||
user.created_at = created_at or datetime.datetime(
|
||||
2025, 1, 1, tzinfo=datetime.timezone.utc
|
||||
)
|
||||
user.updated_at = updated_at or datetime.datetime(
|
||||
2025, 6, 15, tzinfo=datetime.timezone.utc
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
def test_from_user_model_includes_new_fields() -> None:
|
||||
user = _mock_user(personal_name="Alice")
|
||||
groups = [UserGroupInfo(id=1, name="Engineering")]
|
||||
|
||||
snapshot = FullUserSnapshot.from_user_model(user, groups=groups)
|
||||
|
||||
assert snapshot.personal_name == "Alice"
|
||||
assert snapshot.created_at == user.created_at
|
||||
assert snapshot.updated_at == user.updated_at
|
||||
assert snapshot.groups == groups
|
||||
|
||||
|
||||
def test_from_user_model_defaults_groups_to_empty() -> None:
|
||||
user = _mock_user()
|
||||
snapshot = FullUserSnapshot.from_user_model(user)
|
||||
|
||||
assert snapshot.groups == []
|
||||
|
||||
|
||||
def test_from_user_model_personal_name_none() -> None:
|
||||
user = _mock_user(personal_name=None)
|
||||
snapshot = FullUserSnapshot.from_user_model(user)
|
||||
|
||||
assert snapshot.personal_name is None
|
||||
@@ -186,42 +186,3 @@ def test_categorize_uploaded_files_checks_size_before_text_extraction(
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_accepts_python_file(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
|
||||
py_source = b'def hello():\n print("world")\n'
|
||||
monkeypatch.setattr(
|
||||
utils, "extract_file_text", lambda **_kwargs: py_source.decode()
|
||||
)
|
||||
|
||||
upload = _make_upload("script.py", size=len(py_source), content=py_source)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert result.acceptable[0].filename == "script.py"
|
||||
assert len(result.rejected) == 0
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_rejects_binary_file(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 10_000)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
|
||||
monkeypatch.setattr(utils, "extract_file_text", lambda **_kwargs: "")
|
||||
|
||||
binary_content = bytes(range(256)) * 4
|
||||
upload = _make_upload("data.bin", size=len(binary_content), content=binary_content)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].filename == "data.bin"
|
||||
assert "Unsupported file type" in result.rejected[0].reason
|
||||
|
||||
@@ -14,7 +14,6 @@ from onyx.utils.url import _is_ip_private_or_reserved
|
||||
from onyx.utils.url import _validate_and_resolve_url
|
||||
from onyx.utils.url import ssrf_safe_get
|
||||
from onyx.utils.url import SSRFException
|
||||
from onyx.utils.url import validate_outbound_http_url
|
||||
|
||||
|
||||
class TestIsIpPrivateOrReserved:
|
||||
@@ -306,22 +305,3 @@ class TestSsrfSafeGet:
|
||||
|
||||
call_args = mock_get.call_args
|
||||
assert call_args[1]["timeout"] == (5, 15)
|
||||
|
||||
|
||||
class TestValidateOutboundHttpUrl:
|
||||
def test_rejects_private_ip_by_default(self) -> None:
|
||||
with pytest.raises(SSRFException, match="internal/private IP"):
|
||||
validate_outbound_http_url("http://127.0.0.1:8000")
|
||||
|
||||
def test_allows_private_ip_when_explicitly_enabled(self) -> None:
|
||||
validated_url = validate_outbound_http_url(
|
||||
"http://127.0.0.1:8000", allow_private_network=True
|
||||
)
|
||||
assert validated_url == "http://127.0.0.1:8000"
|
||||
|
||||
def test_blocks_metadata_hostname_when_private_is_enabled(self) -> None:
|
||||
with pytest.raises(SSRFException, match="not allowed"):
|
||||
validate_outbound_http_url(
|
||||
"http://metadata.google.internal/latest",
|
||||
allow_private_network=True,
|
||||
)
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from onyx.voice.providers.azure import AzureVoiceProvider
|
||||
|
||||
|
||||
def test_azure_provider_extracts_region_from_target_uri() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key",
|
||||
api_base="https://westus.api.cognitive.microsoft.com/",
|
||||
custom_config={},
|
||||
)
|
||||
assert provider.speech_region == "westus"
|
||||
|
||||
|
||||
def test_azure_provider_normalizes_uppercase_region() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key",
|
||||
api_base=None,
|
||||
custom_config={"speech_region": "WestUS2"},
|
||||
)
|
||||
assert provider.speech_region == "westus2"
|
||||
|
||||
|
||||
def test_azure_provider_rejects_invalid_speech_region() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid Azure speech_region"):
|
||||
AzureVoiceProvider(
|
||||
api_key="key",
|
||||
api_base=None,
|
||||
custom_config={"speech_region": "westus/../../etc"},
|
||||
)
|
||||
@@ -1,194 +0,0 @@
|
||||
import io
|
||||
import struct
|
||||
import wave
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.voice.providers.azure import AzureVoiceProvider
|
||||
|
||||
|
||||
# --- _is_azure_cloud_url ---
|
||||
|
||||
|
||||
def test_is_azure_cloud_url_speech_microsoft() -> None:
|
||||
assert AzureVoiceProvider._is_azure_cloud_url(
|
||||
"https://eastus.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
)
|
||||
|
||||
|
||||
def test_is_azure_cloud_url_cognitive_microsoft() -> None:
|
||||
assert AzureVoiceProvider._is_azure_cloud_url(
|
||||
"https://westus.api.cognitive.microsoft.com/"
|
||||
)
|
||||
|
||||
|
||||
def test_is_azure_cloud_url_rejects_custom_host() -> None:
|
||||
assert not AzureVoiceProvider._is_azure_cloud_url("https://my-custom-host.com/")
|
||||
|
||||
|
||||
def test_is_azure_cloud_url_rejects_none() -> None:
|
||||
assert not AzureVoiceProvider._is_azure_cloud_url(None)
|
||||
|
||||
|
||||
# --- _extract_speech_region_from_uri ---
|
||||
|
||||
|
||||
def test_extract_region_from_tts_url() -> None:
|
||||
assert (
|
||||
AzureVoiceProvider._extract_speech_region_from_uri(
|
||||
"https://eastus.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
)
|
||||
== "eastus"
|
||||
)
|
||||
|
||||
|
||||
def test_extract_region_from_cognitive_api_url() -> None:
|
||||
assert (
|
||||
AzureVoiceProvider._extract_speech_region_from_uri(
|
||||
"https://eastus.api.cognitive.microsoft.com/"
|
||||
)
|
||||
== "eastus"
|
||||
)
|
||||
|
||||
|
||||
def test_extract_region_returns_none_for_custom_domain() -> None:
|
||||
"""Custom domains use resource name, not region — must use speech_region config."""
|
||||
assert (
|
||||
AzureVoiceProvider._extract_speech_region_from_uri(
|
||||
"https://myresource.cognitiveservices.azure.com/"
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def test_extract_region_returns_none_for_none() -> None:
|
||||
assert AzureVoiceProvider._extract_speech_region_from_uri(None) is None
|
||||
|
||||
|
||||
# --- _validate_speech_region ---
|
||||
|
||||
|
||||
def test_validate_region_normalizes_to_lowercase() -> None:
|
||||
assert AzureVoiceProvider._validate_speech_region("WestUS2") == "westus2"
|
||||
|
||||
|
||||
def test_validate_region_accepts_hyphens() -> None:
|
||||
assert AzureVoiceProvider._validate_speech_region("us-east-1") == "us-east-1"
|
||||
|
||||
|
||||
def test_validate_region_rejects_path_traversal() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid Azure speech_region"):
|
||||
AzureVoiceProvider._validate_speech_region("westus/../../etc")
|
||||
|
||||
|
||||
def test_validate_region_rejects_dots() -> None:
|
||||
with pytest.raises(ValueError, match="Invalid Azure speech_region"):
|
||||
AzureVoiceProvider._validate_speech_region("west.us")
|
||||
|
||||
|
||||
# --- _pcm16_to_wav ---
|
||||
|
||||
|
||||
def test_pcm16_to_wav_produces_valid_wav() -> None:
|
||||
samples = [32767, -32768, 0, 1234]
|
||||
pcm_data = struct.pack(f"<{len(samples)}h", *samples)
|
||||
wav_bytes = AzureVoiceProvider._pcm16_to_wav(pcm_data, sample_rate=16000)
|
||||
|
||||
with wave.open(io.BytesIO(wav_bytes), "rb") as wav_file:
|
||||
assert wav_file.getnchannels() == 1
|
||||
assert wav_file.getsampwidth() == 2
|
||||
assert wav_file.getframerate() == 16000
|
||||
frames = wav_file.readframes(4)
|
||||
recovered = struct.unpack(f"<{len(samples)}h", frames)
|
||||
assert list(recovered) == samples
|
||||
|
||||
|
||||
# --- URL Construction ---
|
||||
|
||||
|
||||
def test_get_tts_url_cloud() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key", api_base=None, custom_config={"speech_region": "eastus"}
|
||||
)
|
||||
assert (
|
||||
provider._get_tts_url()
|
||||
== "https://eastus.tts.speech.microsoft.com/cognitiveservices/v1"
|
||||
)
|
||||
|
||||
|
||||
def test_get_stt_url_cloud() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key", api_base=None, custom_config={"speech_region": "westus2"}
|
||||
)
|
||||
assert "westus2.stt.speech.microsoft.com" in provider._get_stt_url()
|
||||
|
||||
|
||||
def test_get_tts_url_self_hosted() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key", api_base="http://localhost:5000", custom_config={}
|
||||
)
|
||||
assert provider._get_tts_url() == "http://localhost:5000/cognitiveservices/v1"
|
||||
|
||||
|
||||
def test_get_tts_url_self_hosted_strips_trailing_slash() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key", api_base="http://localhost:5000/", custom_config={}
|
||||
)
|
||||
assert provider._get_tts_url() == "http://localhost:5000/cognitiveservices/v1"
|
||||
|
||||
|
||||
# --- _is_self_hosted ---
|
||||
|
||||
|
||||
def test_is_self_hosted_true_for_custom_endpoint() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key", api_base="http://localhost:5000", custom_config={}
|
||||
)
|
||||
assert provider._is_self_hosted() is True
|
||||
|
||||
|
||||
def test_is_self_hosted_false_for_azure_cloud() -> None:
|
||||
provider = AzureVoiceProvider(
|
||||
api_key="key",
|
||||
api_base="https://eastus.api.cognitive.microsoft.com/",
|
||||
custom_config={},
|
||||
)
|
||||
assert provider._is_self_hosted() is False
|
||||
|
||||
|
||||
# --- Resampling ---
|
||||
|
||||
|
||||
def test_resample_pcm16_passthrough() -> None:
|
||||
from onyx.voice.providers.azure import AzureStreamingTranscriber
|
||||
|
||||
t = AzureStreamingTranscriber.__new__(AzureStreamingTranscriber)
|
||||
t.input_sample_rate = 16000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
data = struct.pack("<4h", 100, 200, 300, 400)
|
||||
assert t._resample_pcm16(data) == data
|
||||
|
||||
|
||||
def test_resample_pcm16_downsamples() -> None:
|
||||
from onyx.voice.providers.azure import AzureStreamingTranscriber
|
||||
|
||||
t = AzureStreamingTranscriber.__new__(AzureStreamingTranscriber)
|
||||
t.input_sample_rate = 24000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
input_samples = [1000, 2000, 3000, 4000, 5000, 6000]
|
||||
data = struct.pack(f"<{len(input_samples)}h", *input_samples)
|
||||
|
||||
result = t._resample_pcm16(data)
|
||||
assert len(result) // 2 == 4
|
||||
|
||||
|
||||
def test_resample_pcm16_empty_data() -> None:
|
||||
from onyx.voice.providers.azure import AzureStreamingTranscriber
|
||||
|
||||
t = AzureStreamingTranscriber.__new__(AzureStreamingTranscriber)
|
||||
t.input_sample_rate = 24000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
assert t._resample_pcm16(b"") == b""
|
||||
@@ -1,117 +0,0 @@
|
||||
import struct
|
||||
|
||||
from onyx.voice.providers.elevenlabs import _http_to_ws_url
|
||||
from onyx.voice.providers.elevenlabs import DEFAULT_ELEVENLABS_API_BASE
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsSTTMessageType
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsVoiceProvider
|
||||
|
||||
|
||||
# --- _http_to_ws_url ---
|
||||
|
||||
|
||||
def test_http_to_ws_url_converts_https_to_wss() -> None:
|
||||
assert _http_to_ws_url("https://api.elevenlabs.io") == "wss://api.elevenlabs.io"
|
||||
|
||||
|
||||
def test_http_to_ws_url_converts_http_to_ws() -> None:
|
||||
assert _http_to_ws_url("http://localhost:8080") == "ws://localhost:8080"
|
||||
|
||||
|
||||
def test_http_to_ws_url_passes_through_other_schemes() -> None:
|
||||
assert _http_to_ws_url("wss://already.ws") == "wss://already.ws"
|
||||
|
||||
|
||||
def test_http_to_ws_url_preserves_path() -> None:
|
||||
assert (
|
||||
_http_to_ws_url("https://api.elevenlabs.io/v1/tts")
|
||||
== "wss://api.elevenlabs.io/v1/tts"
|
||||
)
|
||||
|
||||
|
||||
# --- StrEnum comparison ---
|
||||
|
||||
|
||||
def test_stt_message_type_compares_as_string() -> None:
|
||||
"""StrEnum members should work in string comparisons (e.g. from JSON)."""
|
||||
assert str(ElevenLabsSTTMessageType.COMMITTED_TRANSCRIPT) == "committed_transcript"
|
||||
assert isinstance(ElevenLabsSTTMessageType.ERROR, str)
|
||||
|
||||
|
||||
# --- Resampling ---
|
||||
|
||||
|
||||
def test_resample_pcm16_passthrough_when_same_rate() -> None:
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsStreamingTranscriber
|
||||
|
||||
t = ElevenLabsStreamingTranscriber.__new__(ElevenLabsStreamingTranscriber)
|
||||
t.input_sample_rate = 16000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
data = struct.pack("<4h", 100, 200, 300, 400)
|
||||
assert t._resample_pcm16(data) == data
|
||||
|
||||
|
||||
def test_resample_pcm16_downsamples() -> None:
|
||||
"""24kHz -> 16kHz should produce fewer samples (ratio 3:2)."""
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsStreamingTranscriber
|
||||
|
||||
t = ElevenLabsStreamingTranscriber.__new__(ElevenLabsStreamingTranscriber)
|
||||
t.input_sample_rate = 24000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
input_samples = [1000, 2000, 3000, 4000, 5000, 6000]
|
||||
data = struct.pack(f"<{len(input_samples)}h", *input_samples)
|
||||
|
||||
result = t._resample_pcm16(data)
|
||||
output_samples = struct.unpack(f"<{len(result) // 2}h", result)
|
||||
|
||||
assert len(output_samples) == 4
|
||||
|
||||
|
||||
def test_resample_pcm16_clamps_to_int16_range() -> None:
|
||||
from onyx.voice.providers.elevenlabs import ElevenLabsStreamingTranscriber
|
||||
|
||||
t = ElevenLabsStreamingTranscriber.__new__(ElevenLabsStreamingTranscriber)
|
||||
t.input_sample_rate = 24000
|
||||
t.target_sample_rate = 16000
|
||||
|
||||
input_samples = [32767, -32768, 32767, -32768, 32767, -32768]
|
||||
data = struct.pack(f"<{len(input_samples)}h", *input_samples)
|
||||
|
||||
result = t._resample_pcm16(data)
|
||||
output_samples = struct.unpack(f"<{len(result) // 2}h", result)
|
||||
for s in output_samples:
|
||||
assert -32768 <= s <= 32767
|
||||
|
||||
|
||||
# --- Provider Model Defaulting ---
|
||||
|
||||
|
||||
def test_provider_defaults_invalid_stt_model() -> None:
|
||||
provider = ElevenLabsVoiceProvider(api_key="test", stt_model="invalid_model")
|
||||
assert provider.stt_model == "scribe_v1"
|
||||
|
||||
|
||||
def test_provider_defaults_invalid_tts_model() -> None:
|
||||
provider = ElevenLabsVoiceProvider(api_key="test", tts_model="invalid_model")
|
||||
assert provider.tts_model == "eleven_multilingual_v2"
|
||||
|
||||
|
||||
def test_provider_accepts_valid_models() -> None:
|
||||
provider = ElevenLabsVoiceProvider(
|
||||
api_key="test", stt_model="scribe_v2_realtime", tts_model="eleven_turbo_v2_5"
|
||||
)
|
||||
assert provider.stt_model == "scribe_v2_realtime"
|
||||
assert provider.tts_model == "eleven_turbo_v2_5"
|
||||
|
||||
|
||||
def test_provider_defaults_api_base() -> None:
|
||||
provider = ElevenLabsVoiceProvider(api_key="test")
|
||||
assert provider.api_base == DEFAULT_ELEVENLABS_API_BASE
|
||||
|
||||
|
||||
def test_provider_get_available_voices_returns_copy() -> None:
|
||||
provider = ElevenLabsVoiceProvider(api_key="test")
|
||||
voices = provider.get_available_voices()
|
||||
voices.clear()
|
||||
assert len(provider.get_available_voices()) > 0
|
||||
@@ -1,97 +0,0 @@
|
||||
import io
|
||||
import struct
|
||||
import wave
|
||||
|
||||
from onyx.voice.providers.openai import _create_wav_header
|
||||
from onyx.voice.providers.openai import _http_to_ws_url
|
||||
from onyx.voice.providers.openai import OpenAIRealtimeMessageType
|
||||
from onyx.voice.providers.openai import OpenAIVoiceProvider
|
||||
|
||||
|
||||
# --- _http_to_ws_url ---
|
||||
|
||||
|
||||
def test_http_to_ws_url_converts_https_to_wss() -> None:
|
||||
assert _http_to_ws_url("https://api.openai.com") == "wss://api.openai.com"
|
||||
|
||||
|
||||
def test_http_to_ws_url_converts_http_to_ws() -> None:
|
||||
assert _http_to_ws_url("http://localhost:9090") == "ws://localhost:9090"
|
||||
|
||||
|
||||
def test_http_to_ws_url_passes_through_ws() -> None:
|
||||
assert _http_to_ws_url("wss://already.ws") == "wss://already.ws"
|
||||
|
||||
|
||||
# --- StrEnum comparison ---
|
||||
|
||||
|
||||
def test_realtime_message_type_compares_as_string() -> None:
|
||||
assert str(OpenAIRealtimeMessageType.ERROR) == "error"
|
||||
assert (
|
||||
str(OpenAIRealtimeMessageType.TRANSCRIPTION_DELTA)
|
||||
== "conversation.item.input_audio_transcription.delta"
|
||||
)
|
||||
assert isinstance(OpenAIRealtimeMessageType.ERROR, str)
|
||||
|
||||
|
||||
# --- _create_wav_header ---
|
||||
|
||||
|
||||
def test_wav_header_is_44_bytes() -> None:
|
||||
assert len(_create_wav_header(1000)) == 44
|
||||
|
||||
|
||||
def test_wav_header_chunk_size_matches_data_length() -> None:
|
||||
data_length = 2000
|
||||
header = _create_wav_header(data_length)
|
||||
chunk_size = struct.unpack_from("<I", header, 4)[0]
|
||||
assert chunk_size == 36 + data_length
|
||||
|
||||
|
||||
def test_wav_header_byte_rate() -> None:
|
||||
header = _create_wav_header(100, sample_rate=24000, channels=1, bits_per_sample=16)
|
||||
byte_rate = struct.unpack_from("<I", header, 28)[0]
|
||||
assert byte_rate == 24000 * 1 * 16 // 8
|
||||
|
||||
|
||||
def test_wav_header_produces_valid_wav() -> None:
|
||||
"""Header + PCM data should parse as valid WAV."""
|
||||
data_length = 100
|
||||
pcm_data = b"\x00" * data_length
|
||||
header = _create_wav_header(data_length, sample_rate=24000)
|
||||
|
||||
with wave.open(io.BytesIO(header + pcm_data), "rb") as wav_file:
|
||||
assert wav_file.getnchannels() == 1
|
||||
assert wav_file.getsampwidth() == 2
|
||||
assert wav_file.getframerate() == 24000
|
||||
assert wav_file.getnframes() == data_length // 2
|
||||
|
||||
|
||||
# --- Provider Defaults ---
|
||||
|
||||
|
||||
def test_provider_default_models() -> None:
|
||||
provider = OpenAIVoiceProvider(api_key="test")
|
||||
assert provider.stt_model == "whisper-1"
|
||||
assert provider.tts_model == "tts-1"
|
||||
assert provider.default_voice == "alloy"
|
||||
|
||||
|
||||
def test_provider_custom_models() -> None:
|
||||
provider = OpenAIVoiceProvider(
|
||||
api_key="test",
|
||||
stt_model="gpt-4o-transcribe",
|
||||
tts_model="tts-1-hd",
|
||||
default_voice="nova",
|
||||
)
|
||||
assert provider.stt_model == "gpt-4o-transcribe"
|
||||
assert provider.tts_model == "tts-1-hd"
|
||||
assert provider.default_voice == "nova"
|
||||
|
||||
|
||||
def test_provider_get_available_voices_returns_copy() -> None:
|
||||
provider = OpenAIVoiceProvider(api_key="test")
|
||||
voices = provider.get_available_voices()
|
||||
voices.clear()
|
||||
assert len(provider.get_available_voices()) > 0
|
||||
@@ -38,11 +38,6 @@ services:
|
||||
opensearch:
|
||||
ports:
|
||||
- "9200:9200"
|
||||
# Rootless Docker can reject the base OpenSearch ulimit settings, so clear
|
||||
# the inherited block entirely in the dev override.
|
||||
ulimits: !reset null
|
||||
environment:
|
||||
- bootstrap.memory_lock=false
|
||||
|
||||
inference_model_server:
|
||||
ports:
|
||||
|
||||
@@ -33,7 +33,6 @@ SECRET=
|
||||
|
||||
# OpenID Connect (OIDC)
|
||||
#OPENID_CONFIG_URL=
|
||||
#OIDC_PKCE_ENABLED=
|
||||
|
||||
# SAML config directory for OneLogin compatible setups
|
||||
#SAML_CONF_DIR=
|
||||
|
||||
@@ -167,7 +167,6 @@ LOG_ONYX_MODEL_INTERACTIONS=False
|
||||
# OAUTH_CLIENT_ID=
|
||||
# OAUTH_CLIENT_SECRET=
|
||||
# OPENID_CONFIG_URL=
|
||||
# OIDC_PKCE_ENABLED=
|
||||
# TRACK_EXTERNAL_IDP_EXPIRY=
|
||||
# CORS_ALLOWED_ORIGIN=
|
||||
# INTEGRATION_TESTS_MODE=
|
||||
|
||||
@@ -5,7 +5,7 @@ home: https://www.onyx.app/
|
||||
sources:
|
||||
- "https://github.com/onyx-dot-app/onyx"
|
||||
type: application
|
||||
version: 0.4.34
|
||||
version: 0.4.33
|
||||
appVersion: latest
|
||||
annotations:
|
||||
category: Productivity
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
# Values for chart-testing (ct lint/install)
|
||||
# This file is automatically used by ct when running lint and install commands
|
||||
auth:
|
||||
userauth:
|
||||
values:
|
||||
user_auth_secret: "placeholder-for-ci-testing"
|
||||
@@ -1,29 +1,17 @@
|
||||
{{- if hasKey .Values.auth "secretKeys" }}
|
||||
{{- fail "ERROR: Secrets handling has been refactored under 'auth' and must be updated before upgrading to this chart version." }}
|
||||
{{- end }}
|
||||
{{- range $secretKey, $secretContent := .Values.auth }}
|
||||
{{- if and (empty $secretContent.existingSecret) (or (not (hasKey $secretContent "enabled")) $secretContent.enabled) }}
|
||||
{{- $secretName := include "onyx.secretName" $secretContent }}
|
||||
{{- $existingSecret := lookup "v1" "Secret" $.Release.Namespace $secretName }}
|
||||
{{- /* Pre-validate: fail before emitting YAML if any required value is missing */ -}}
|
||||
{{- range $name, $value := $secretContent.values }}
|
||||
{{- if and (empty $value) (not (and $existingSecret (hasKey $existingSecret.data $name))) }}
|
||||
{{- fail (printf "Secret value for '%s' is required but not set and no existing secret found. Please set auth.%s.values.%s in values.yaml" $name $secretKey $name) }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- range $secretContent := .Values.auth }}
|
||||
{{- if and (empty $secretContent.existingSecret) (ne ($secretContent.enabled | default true) false) }}
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Secret
|
||||
metadata:
|
||||
name: {{ $secretName }}
|
||||
name: {{ include "onyx.secretName" $secretContent }}
|
||||
type: Opaque
|
||||
stringData:
|
||||
{{- range $name, $value := $secretContent.values }}
|
||||
{{- if not (empty $value) }}
|
||||
{{- range $name, $value := $secretContent.values }}
|
||||
{{ $name }}: {{ $value | quote }}
|
||||
{{- else if and $existingSecret (hasKey $existingSecret.data $name) }}
|
||||
{{ $name }}: {{ index $existingSecret.data $name | b64dec | quote }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
||||
@@ -1183,28 +1183,10 @@ auth:
|
||||
values:
|
||||
opensearch_admin_username: "admin"
|
||||
opensearch_admin_password: "OnyxDev1!"
|
||||
userauth:
|
||||
# -- Used for signing password reset tokens, email verification tokens, and JWT tokens.
|
||||
enabled: true
|
||||
# -- Overwrite the default secret name, ignored if existingSecret is defined
|
||||
secretName: 'onyx-userauth'
|
||||
# -- Use a secret specified elsewhere
|
||||
existingSecret: ""
|
||||
# -- This defines the env var to secret map
|
||||
secretKeys:
|
||||
USER_AUTH_SECRET: user_auth_secret
|
||||
# -- Secret value. Required - generate with: openssl rand -hex 32
|
||||
# If not set, helm install/upgrade will fail.
|
||||
values:
|
||||
user_auth_secret: ""
|
||||
|
||||
configMap:
|
||||
# Auth type: "basic" (default), "google_oauth", "oidc", or "saml"
|
||||
# UPGRADE NOTE: Default changed from "disabled" to "basic" in 0.4.34.
|
||||
# You must also set auth.userauth.values.user_auth_secret.
|
||||
AUTH_TYPE: "basic"
|
||||
# Enable PKCE for OIDC login flow. Leave empty/false for backward compatibility.
|
||||
OIDC_PKCE_ENABLED: ""
|
||||
# Change this for production uses unless Onyx is only accessible behind VPN
|
||||
AUTH_TYPE: "disabled"
|
||||
# 1 Day Default
|
||||
SESSION_EXPIRE_TIME_SECONDS: "86400"
|
||||
# Can be something like onyx.app, as an extra double-check
|
||||
|
||||
@@ -52,10 +52,6 @@
|
||||
{
|
||||
"scope": [],
|
||||
"content": "Use explicit type annotations for variables to enhance code clarity, especially when moving type hints around in the code."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"content": "Use `contributing_guides/best_practices.md` as core review context. Prefer consistency with existing patterns, fix issues in code you touch, avoid tacking new features onto muddy interfaces, fail loudly instead of silently swallowing errors, keep code strictly typed, preserve clear state boundaries, remove duplicate or dead logic, break up overly long functions, avoid hidden import-time side effects, respect module boundaries, and favor correctness-by-construction over relying on callers to use an API correctly."
|
||||
}
|
||||
],
|
||||
"rules": [
|
||||
@@ -75,14 +71,6 @@
|
||||
"scope": [],
|
||||
"rule": "When hardcoding a boolean variable to a constant value, remove the variable entirely and clean up all places where it's used rather than just setting it to a constant."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies."
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"rule": "Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments."
|
||||
},
|
||||
{
|
||||
"scope": ["backend/**/*.py"],
|
||||
"rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"message\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`."
|
||||
@@ -98,21 +86,6 @@
|
||||
"scope": [],
|
||||
"path": "CLAUDE.md",
|
||||
"description": "Project instructions and coding standards"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "backend/alembic/README.md",
|
||||
"description": "Migration guidance, including multi-tenant migration behavior"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "deployment/helm/charts/onyx/values-lite.yaml",
|
||||
"description": "Lite deployment Helm values and service assumptions"
|
||||
},
|
||||
{
|
||||
"scope": [],
|
||||
"path": "deployment/docker_compose/docker-compose.onyx-lite.yml",
|
||||
"description": "Lite deployment Docker Compose overlay and disabled service behavior"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -35,7 +35,6 @@ backend = [
|
||||
"alembic==1.10.4",
|
||||
"asyncpg==0.30.0",
|
||||
"atlassian-python-api==3.41.16",
|
||||
"azure-cognitiveservices-speech==1.38.0",
|
||||
"beautifulsoup4==4.12.3",
|
||||
"boto3==1.39.11",
|
||||
"boto3-stubs[s3]==1.39.11",
|
||||
@@ -92,7 +91,7 @@ backend = [
|
||||
"python-gitlab==5.6.0",
|
||||
"python-pptx==0.6.23",
|
||||
"pypandoc_binary==1.16.2",
|
||||
"pypdf==6.8.0",
|
||||
"pypdf==6.7.5",
|
||||
"pytest-mock==3.12.0",
|
||||
"pytest-playwright==0.7.0",
|
||||
"python-docx==1.1.2",
|
||||
@@ -144,7 +143,7 @@ dev = [
|
||||
"matplotlib==3.10.8",
|
||||
"mypy-extensions==1.0.0",
|
||||
"mypy==1.13.0",
|
||||
"onyx-devtools==0.7.0",
|
||||
"onyx-devtools==0.6.3",
|
||||
"openapi-generator-cli==7.17.0",
|
||||
"pandas-stubs~=2.3.3",
|
||||
"pre-commit==3.2.2",
|
||||
@@ -154,7 +153,7 @@ dev = [
|
||||
"pytest-repeat==0.9.4",
|
||||
"pytest-xdist==3.8.0",
|
||||
"pytest==8.3.5",
|
||||
"release-tag==0.5.2",
|
||||
"release-tag==0.4.3",
|
||||
"reorder-python-imports-black==3.14.0",
|
||||
"ruff==0.12.0",
|
||||
"types-beautifulsoup4==4.12.0.3",
|
||||
|
||||
@@ -25,9 +25,6 @@ Some commands require external tools to be installed and configured:
|
||||
- **Docker** - Required for `compose`, `logs`, and `pull` commands
|
||||
- Install from [docker.com](https://docs.docker.com/get-docker/)
|
||||
|
||||
- **uv** - Required for `backend` commands
|
||||
- Install from [docs.astral.sh/uv](https://docs.astral.sh/uv/)
|
||||
|
||||
- **GitHub CLI** (`gh`) - Required for `run-ci` and `cherry-pick` commands
|
||||
- Install from [cli.github.com](https://cli.github.com/)
|
||||
- Authenticate with `gh auth login`
|
||||
@@ -173,53 +170,6 @@ ods pull
|
||||
ods pull --tag edge
|
||||
```
|
||||
|
||||
### `backend` - Run Backend Services
|
||||
|
||||
Run backend services (API server, model server) with environment loaded from
|
||||
`.vscode/.env`. On first run, copies `.vscode/env_template.txt` to `.vscode/.env`
|
||||
if the `.env` file does not already exist.
|
||||
|
||||
Enterprise Edition features are enabled by default with license enforcement
|
||||
disabled, matching the `compose` command behavior.
|
||||
|
||||
```shell
|
||||
ods backend <subcommand>
|
||||
```
|
||||
|
||||
**Subcommands:**
|
||||
|
||||
- `api` - Start the FastAPI backend server (`uvicorn onyx.main:app --reload`)
|
||||
- `model_server` - Start the model server (`uvicorn model_server.main:app --reload`)
|
||||
|
||||
**Flags:**
|
||||
|
||||
| Flag | Default | Description |
|
||||
|------|---------|-------------|
|
||||
| `--no-ee` | `false` | Disable Enterprise Edition features (enabled by default) |
|
||||
| `--port` | `8080` (api) / `9000` (model_server) | Port to listen on |
|
||||
|
||||
Shell environment takes precedence over `.env` file values, so inline overrides
|
||||
work as expected (e.g. `S3_ENDPOINT_URL=foo ods backend api`).
|
||||
|
||||
**Examples:**
|
||||
|
||||
```shell
|
||||
# Start the API server
|
||||
ods backend api
|
||||
|
||||
# Start the API server on a custom port
|
||||
ods backend api --port 9090
|
||||
|
||||
# Start without Enterprise Edition
|
||||
ods backend api --no-ee
|
||||
|
||||
# Start the model server
|
||||
ods backend model_server
|
||||
|
||||
# Start the model server on a custom port
|
||||
ods backend model_server --port 9001
|
||||
```
|
||||
|
||||
### `web` - Run Frontend Scripts
|
||||
|
||||
Run npm scripts from `web/package.json` without manually changing directories.
|
||||
|
||||
@@ -1,242 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/paths"
|
||||
)
|
||||
|
||||
// NewBackendCommand creates the parent "backend" command with subcommands for
|
||||
// running backend services.
|
||||
// BackendOptions holds options shared across backend subcommands.
|
||||
type BackendOptions struct {
|
||||
NoEE bool
|
||||
}
|
||||
|
||||
func NewBackendCommand() *cobra.Command {
|
||||
opts := &BackendOptions{}
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "backend",
|
||||
Short: "Run backend services (api, model_server)",
|
||||
Long: `Run backend services with environment from .vscode/.env.
|
||||
|
||||
On first run, copies .vscode/env_template.txt to .vscode/.env if the
|
||||
.env file does not already exist.
|
||||
|
||||
Enterprise Edition features are enabled by default for development,
|
||||
with license enforcement disabled.
|
||||
|
||||
Available subcommands:
|
||||
api Start the FastAPI backend server
|
||||
model_server Start the model server`,
|
||||
}
|
||||
|
||||
cmd.PersistentFlags().BoolVar(&opts.NoEE, "no-ee", false, "Disable Enterprise Edition features (enabled by default)")
|
||||
|
||||
cmd.AddCommand(newBackendAPICommand(opts))
|
||||
cmd.AddCommand(newBackendModelServerCommand(opts))
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newBackendAPICommand(opts *BackendOptions) *cobra.Command {
|
||||
var port string
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "api",
|
||||
Short: "Start the backend API server (uvicorn with hot-reload)",
|
||||
Long: `Start the backend API server using uvicorn with hot-reload.
|
||||
|
||||
Examples:
|
||||
ods backend api
|
||||
ods backend api --port 9090
|
||||
ods backend api --no-ee`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runBackendService("api", "onyx.main:app", port, opts)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&port, "port", "8080", "Port to listen on")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func newBackendModelServerCommand(opts *BackendOptions) *cobra.Command {
|
||||
var port string
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "model_server",
|
||||
Short: "Start the model server (uvicorn with hot-reload)",
|
||||
Long: `Start the model server using uvicorn with hot-reload.
|
||||
|
||||
Examples:
|
||||
ods backend model_server
|
||||
ods backend model_server --port 9001`,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runBackendService("model_server", "model_server.main:app", port, opts)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&port, "port", "9000", "Port to listen on")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func runBackendService(name, module, port string, opts *BackendOptions) {
|
||||
root, err := paths.GitRoot()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to find git root: %v", err)
|
||||
}
|
||||
|
||||
envFile := ensureBackendEnvFile(root)
|
||||
fileVars := loadBackendEnvFile(envFile)
|
||||
|
||||
eeDefaults := eeEnvDefaults(opts.NoEE)
|
||||
fileVars = append(fileVars, eeDefaults...)
|
||||
|
||||
backendDir := filepath.Join(root, "backend")
|
||||
|
||||
uvicornArgs := []string{
|
||||
"run", "uvicorn", module,
|
||||
"--reload",
|
||||
"--port", port,
|
||||
}
|
||||
log.Infof("Starting %s on port %s...", name, port)
|
||||
if !opts.NoEE {
|
||||
log.Info("Enterprise Edition enabled (use --no-ee to disable)")
|
||||
}
|
||||
log.Debugf("Running in %s: uv %v", backendDir, uvicornArgs)
|
||||
|
||||
mergedEnv := mergeEnv(os.Environ(), fileVars)
|
||||
log.Debugf("Applied %d env vars from %s (shell takes precedence)", len(fileVars), envFile)
|
||||
|
||||
svcCmd := exec.Command("uv", uvicornArgs...)
|
||||
svcCmd.Dir = backendDir
|
||||
svcCmd.Stdout = os.Stdout
|
||||
svcCmd.Stderr = os.Stderr
|
||||
svcCmd.Stdin = os.Stdin
|
||||
svcCmd.Env = mergedEnv
|
||||
|
||||
if err := svcCmd.Run(); err != nil {
|
||||
var exitErr *exec.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
if code := exitErr.ExitCode(); code != -1 {
|
||||
os.Exit(code)
|
||||
}
|
||||
}
|
||||
log.Fatalf("Failed to run %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
// eeEnvDefaults returns env entries for EE and license enforcement settings.
|
||||
// These are appended to the file vars so they act as defaults — shell env
|
||||
// and .env file values still take precedence via mergeEnv.
|
||||
func eeEnvDefaults(noEE bool) []string {
|
||||
if noEE {
|
||||
return []string{
|
||||
"ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=false",
|
||||
}
|
||||
}
|
||||
return []string{
|
||||
"ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true",
|
||||
"LICENSE_ENFORCEMENT_ENABLED=false",
|
||||
}
|
||||
}
|
||||
|
||||
// ensureBackendEnvFile copies env_template.txt to .env if .env doesn't exist.
|
||||
func ensureBackendEnvFile(root string) string {
|
||||
vscodeDir := filepath.Join(root, ".vscode")
|
||||
envFile := filepath.Join(vscodeDir, ".env")
|
||||
templateFile := filepath.Join(vscodeDir, "env_template.txt")
|
||||
|
||||
if _, err := os.Stat(envFile); err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
log.Fatalf("Failed to stat env file %s: %v", envFile, err)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("Using existing env file: %s", envFile)
|
||||
return envFile
|
||||
}
|
||||
|
||||
templateData, err := os.ReadFile(templateFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to read env template %s: %v", templateFile, err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(vscodeDir, 0755); err != nil {
|
||||
log.Fatalf("Failed to create .vscode directory: %v", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(envFile, templateData, 0644); err != nil {
|
||||
log.Fatalf("Failed to write env file %s: %v", envFile, err)
|
||||
}
|
||||
|
||||
log.Infof("Created %s from template (review and fill in <REPLACE THIS> values)", envFile)
|
||||
return envFile
|
||||
}
|
||||
|
||||
// mergeEnv combines shell environment with file-based defaults. Shell values
|
||||
// take precedence — file entries are only added for keys not already present.
|
||||
func mergeEnv(shellEnv, fileVars []string) []string {
|
||||
existing := make(map[string]bool, len(shellEnv))
|
||||
for _, entry := range shellEnv {
|
||||
if idx := strings.Index(entry, "="); idx > 0 {
|
||||
existing[entry[:idx]] = true
|
||||
}
|
||||
}
|
||||
|
||||
merged := make([]string, len(shellEnv))
|
||||
copy(merged, shellEnv)
|
||||
for _, entry := range fileVars {
|
||||
if idx := strings.Index(entry, "="); idx > 0 {
|
||||
key := entry[:idx]
|
||||
if !existing[key] {
|
||||
merged = append(merged, entry)
|
||||
} else {
|
||||
log.Debugf("Env var %s already set in shell, skipping .env value", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
// loadBackendEnvFile parses a .env file into KEY=VALUE entries suitable for
|
||||
// appending to os.Environ(). Blank lines and comments are skipped.
|
||||
func loadBackendEnvFile(path string) []string {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to open env file %s: %v", path, err)
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
var envVars []string
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
if idx := strings.Index(line, "="); idx > 0 {
|
||||
key := strings.TrimSpace(line[:idx])
|
||||
value := strings.TrimSpace(line[idx+1:])
|
||||
value = strings.Trim(value, `"'`)
|
||||
envVars = append(envVars, fmt.Sprintf("%s=%s", key, value))
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Fatalf("Failed to read env file %s: %v", path, err)
|
||||
}
|
||||
|
||||
return envVars
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/jmelahman/tag/git"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// NewLatestStableTagCommand creates the latest-stable-tag command.
|
||||
func NewLatestStableTagCommand() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "latest-stable-tag",
|
||||
Short: "Print the git tag that should receive the 'latest' Docker tag",
|
||||
Long: `Print the highest stable (non-pre-release) semver tag in the repository.
|
||||
|
||||
This is used during deployment to decide whether a given tag should
|
||||
receive the "latest" tag on Docker Hub. Only the highest vX.Y.Z tag
|
||||
qualifies. Tags with pre-release suffixes (e.g. v1.2.3-beta,
|
||||
v1.2.3-cloud.1) are excluded.`,
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(c *cobra.Command, _ []string) error {
|
||||
tag, err := git.GetLatestStableSemverTag("")
|
||||
if err != nil {
|
||||
return fmt.Errorf("get latest stable semver tag: %w", err)
|
||||
}
|
||||
if tag == "" {
|
||||
return fmt.Errorf("no stable semver tag found in repository")
|
||||
}
|
||||
fmt.Println(tag)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -41,7 +41,6 @@ func NewRootCommand() *cobra.Command {
|
||||
cmd.PersistentFlags().BoolVar(&opts.Debug, "debug", false, "run in debug mode")
|
||||
|
||||
// Add subcommands
|
||||
cmd.AddCommand(NewBackendCommand())
|
||||
cmd.AddCommand(NewCheckLazyImportsCommand())
|
||||
cmd.AddCommand(NewCherryPickCommand())
|
||||
cmd.AddCommand(NewDBCommand())
|
||||
@@ -53,7 +52,6 @@ func NewRootCommand() *cobra.Command {
|
||||
cmd.AddCommand(NewScreenshotDiffCommand())
|
||||
cmd.AddCommand(NewDesktopCommand())
|
||||
cmd.AddCommand(NewWebCommand())
|
||||
cmd.AddCommand(NewLatestStableTagCommand())
|
||||
cmd.AddCommand(NewWhoisCommand())
|
||||
|
||||
return cmd
|
||||
|
||||
@@ -3,13 +3,12 @@ module github.com/onyx-dot-app/onyx/tools/ods
|
||||
go 1.26.0
|
||||
|
||||
require (
|
||||
github.com/jmelahman/tag v0.5.2
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/spf13/pflag v1.0.10
|
||||
github.com/spf13/cobra v1.10.1
|
||||
github.com/spf13/pflag v1.0.9
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 // indirect
|
||||
)
|
||||
|
||||
@@ -4,26 +4,20 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jmelahman/tag v0.5.2 h1:g6A/aHehu5tkA31mPoDsXBNr1FigZ9A82Y8WVgb/WsM=
|
||||
github.com/jmelahman/tag v0.5.2/go.mod h1:qmuqk19B1BKkpcg3kn7l/Eey+UqucLxgOWkteUGiG4Q=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
|
||||
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
|
||||
github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s=
|
||||
github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0=
|
||||
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
|
||||
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
206
uv.lock
generated
206
uv.lock
generated
@@ -463,19 +463,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/00/3ed12264094ec91f534fae429945efbaa9f8c666f3aa7061cc3b2a26a0cd/authlib-1.6.7-py2.py3-none-any.whl", hash = "sha256:c637340d9a02789d2efa1d003a7437d10d3e565237bcb5fcbc6c134c7b95bab0", size = 244115, upload-time = "2026-02-06T14:04:12.141Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "azure-cognitiveservices-speech"
|
||||
version = "1.38.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/85/f4/4571c42cb00f8af317d5431f594b4ece1fbe59ab59f106947fea8e90cf89/azure_cognitiveservices_speech-1.38.0-py3-none-macosx_10_14_x86_64.whl", hash = "sha256:18dce915ab032711f687abb3297dd19176b9cbea562b322ee6fa7365ef4a5091", size = 6775838, upload-time = "2024-06-11T03:08:35.202Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/86/22/0ca2c59a573119950cad1f53531fec9872fc38810c405a4e1827f3d13a8e/azure_cognitiveservices_speech-1.38.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:9dd0800fbc4a8438c6dfd5747a658251914fe2d205a29e9b46158cadac6ab381", size = 6687975, upload-time = "2024-06-11T03:08:38.797Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/96/5436c09de3af3a9aefaa8cc00533c3a0f5d17aef5bbe017c17f0a30ad66e/azure_cognitiveservices_speech-1.38.0-py3-none-manylinux1_x86_64.whl", hash = "sha256:1c344e8a6faadb063cea451f0301e13b44d9724e1242337039bff601e81e6f86", size = 40022287, upload-time = "2024-06-11T03:08:16.777Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a9/2d/ba20d05ff77ec9870cd489e6e7a474ba7fe820524bcf6fd202025e0c11cf/azure_cognitiveservices_speech-1.38.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1e002595a749471efeac3a54c80097946570b76c13049760b97a4b881d9d24af", size = 39788653, upload-time = "2024-06-11T03:08:30.405Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0c/21/25f8c37fb6868db4346ca977c287ede9e87f609885d932653243c9ed5f63/azure_cognitiveservices_speech-1.38.0-py3-none-win32.whl", hash = "sha256:16a530e6c646eb49ea0bc05cb45a9d28b99e4b67613f6c3a6c54e26e6bf65241", size = 1428364, upload-time = "2024-06-11T03:08:03.965Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/05/a6414a3481c5ee30c4f32742abe055e5f3ce4ff69e936089d86ece354067/azure_cognitiveservices_speech-1.38.0-py3-none-win_amd64.whl", hash = "sha256:1d38d8c056fb3f513a9ff27ab4e77fd08ca487f8788cc7a6df772c1ab2c97b54", size = 1539297, upload-time = "2024-06-11T03:08:01.304Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "babel"
|
||||
version = "2.17.0"
|
||||
@@ -4240,7 +4227,6 @@ backend = [
|
||||
{ name = "asana" },
|
||||
{ name = "asyncpg" },
|
||||
{ name = "atlassian-python-api" },
|
||||
{ name = "azure-cognitiveservices-speech" },
|
||||
{ name = "beautifulsoup4" },
|
||||
{ name = "boto3" },
|
||||
{ name = "boto3-stubs", extra = ["s3"] },
|
||||
@@ -4395,7 +4381,6 @@ requires-dist = [
|
||||
{ name = "asana", marker = "extra == 'backend'", specifier = "==5.0.8" },
|
||||
{ name = "asyncpg", marker = "extra == 'backend'", specifier = "==0.30.0" },
|
||||
{ name = "atlassian-python-api", marker = "extra == 'backend'", specifier = "==3.41.16" },
|
||||
{ name = "azure-cognitiveservices-speech", marker = "extra == 'backend'", specifier = "==1.38.0" },
|
||||
{ name = "beautifulsoup4", marker = "extra == 'backend'", specifier = "==4.12.3" },
|
||||
{ name = "black", marker = "extra == 'dev'", specifier = "==25.1.0" },
|
||||
{ name = "boto3", marker = "extra == 'backend'", specifier = "==1.39.11" },
|
||||
@@ -4458,7 +4443,7 @@ requires-dist = [
|
||||
{ name = "numpy", marker = "extra == 'model-server'", specifier = "==2.4.1" },
|
||||
{ name = "oauthlib", marker = "extra == 'backend'", specifier = "==3.2.2" },
|
||||
{ name = "office365-rest-python-client", marker = "extra == 'backend'", specifier = "==2.6.2" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.7.0" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.6.3" },
|
||||
{ name = "openai", specifier = "==2.14.0" },
|
||||
{ name = "openapi-generator-cli", marker = "extra == 'dev'", specifier = "==7.17.0" },
|
||||
{ name = "openinference-instrumentation", marker = "extra == 'backend'", specifier = "==0.1.42" },
|
||||
@@ -4481,7 +4466,7 @@ requires-dist = [
|
||||
{ name = "pygithub", marker = "extra == 'backend'", specifier = "==2.5.0" },
|
||||
{ name = "pympler", marker = "extra == 'backend'", specifier = "==1.1" },
|
||||
{ name = "pypandoc-binary", marker = "extra == 'backend'", specifier = "==1.16.2" },
|
||||
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.8.0" },
|
||||
{ name = "pypdf", marker = "extra == 'backend'", specifier = "==6.7.5" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = "==8.3.5" },
|
||||
{ name = "pytest-alembic", marker = "extra == 'dev'", specifier = "==0.12.1" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = "==1.3.0" },
|
||||
@@ -4500,7 +4485,7 @@ requires-dist = [
|
||||
{ name = "pywikibot", marker = "extra == 'backend'", specifier = "==9.0.0" },
|
||||
{ name = "rapidfuzz", marker = "extra == 'backend'", specifier = "==3.13.0" },
|
||||
{ name = "redis", marker = "extra == 'backend'", specifier = "==5.0.8" },
|
||||
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.5.2" },
|
||||
{ name = "release-tag", marker = "extra == 'dev'", specifier = "==0.4.3" },
|
||||
{ name = "reorder-python-imports-black", marker = "extra == 'dev'", specifier = "==3.14.0" },
|
||||
{ name = "requests", marker = "extra == 'backend'", specifier = "==2.32.5" },
|
||||
{ name = "requests-oauthlib", marker = "extra == 'backend'", specifier = "==1.3.1" },
|
||||
@@ -4563,19 +4548,20 @@ requires-dist = [{ name = "onyx", extras = ["backend", "dev", "ee"], editable =
|
||||
|
||||
[[package]]
|
||||
name = "onyx-devtools"
|
||||
version = "0.7.0"
|
||||
version = "0.6.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "fastapi" },
|
||||
{ name = "openapi-generator-cli" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/22/9e/6957b11555da57d9e97092f4cd8ac09a86666264b0c9491838f4b27db5dc/onyx_devtools-0.7.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ad962a168d46ea11dcde9fa3b37e4f12ec520b4a4cb4d49d8732de110d46c4b6", size = 3998057, upload-time = "2026-03-12T03:09:11.585Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/90/c72f3d06ba677012d77c77de36195b6a32a15c755c79ba0282be74e3c366/onyx_devtools-0.7.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e46d252e2b048ff053b03519c3a875998780738d7c334eaa1c9a32ff445e3e1a", size = 3687753, upload-time = "2026-03-12T03:09:11.742Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/10/42/4e9fe36eccf9f76d67ba8f4ff6539196a09cd60351fb63f5865e1544cbfa/onyx_devtools-0.7.0-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:f280bc9320e1cc310e7d753a371009bfaab02cc0e0cfd78559663b15655b5a50", size = 3560144, upload-time = "2026-03-12T03:12:24.02Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/76/40/36dc12d99760b358c7f39b27361cb18fa9681ffe194107f982d0e1a74016/onyx_devtools-0.7.0-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:e31df751c7540ae7e70a7fe8e1153c79c31c2254af6aa4c72c0dd54fa381d2ab", size = 3964387, upload-time = "2026-03-12T03:09:11.356Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/34/18/74744230c3820a5a7687335507ca5f1dbebab2c5325805041c1cd5703e6a/onyx_devtools-0.7.0-py3-none-win_amd64.whl", hash = "sha256:541bfd347c2d5b11e7f63ab5001d2594df91d215ad9d07b1562f5e715700f7e6", size = 4068030, upload-time = "2026-03-12T03:09:12.98Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/78/1320436607d3ffcb321ba7b064556c020ea15843a7e7d903fbb7529a71f5/onyx_devtools-0.7.0-py3-none-win_arm64.whl", hash = "sha256:83016330a9d39712431916cc25b2fb2cfcaa0112a55cc4f919d545da3a8974f9", size = 3626409, upload-time = "2026-03-12T03:09:10.222Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/84/e2/e7619722c3ccd18eb38100f776fb3dd6b4ae0fbbee09fca5af7c69a279b5/onyx_devtools-0.6.3-py3-none-any.whl", hash = "sha256:d3a5422945d9da12cafc185f64b39f6e727ee4cc92b37427deb7a38f9aad4966", size = 3945381, upload-time = "2026-03-05T20:39:25.896Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f2/09/513d2dabedc1e54ad4376830fc9b34a3d9c164bdbcdedfcdbb8b8154dc5a/onyx_devtools-0.6.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:efe300e9f3a2e7ae75f88a4f9e0a5c4c471478296cb1615b6a1f03d247582e13", size = 3978761, upload-time = "2026-03-05T20:39:28.822Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/39/41/e757602a0de032d74ed01c7ee57f30e57728fb9cd4f922f50d2affda3889/onyx_devtools-0.6.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:594066eed3f917cfab5a8c7eac3d4a210df30259f2049f664787749709345e19", size = 3665378, upload-time = "2026-03-05T20:44:22.696Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/33/1c/c93b65d0b32e202596a2647922a75c7011cb982f899ddfcfd171f792c58f/onyx_devtools-0.6.3-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:384ef66030b55c0fd68b3898782b5b4b868ff3de119569dfc8544e2ce534b98a", size = 3540890, upload-time = "2026-03-05T20:39:28.886Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f4/33/760eb656013f7f0cdff24570480d3dc4e52bbd8e6147ea1e8cf6fad7554f/onyx_devtools-0.6.3-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:82e218f3a49f64910c2c4c34d5dc12d1ea1520a27e0b0f6e4c0949ff9abaf0e1", size = 3945396, upload-time = "2026-03-05T20:39:34.323Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1a/eb/f54b3675c464df8a51194ff75afc97c2417659e3a209dc46948b47c28860/onyx_devtools-0.6.3-py3-none-win_amd64.whl", hash = "sha256:8af614ae7229290ef2417cb85270184a1e826ed9a3a34658da93851edb36df57", size = 4045936, upload-time = "2026-03-05T20:39:28.375Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/b8/5bee38e748f3d4b8ec935766224db1bbc1214c91092e5822c080fccd9130/onyx_devtools-0.6.3-py3-none-win_arm64.whl", hash = "sha256:717589db4b42528d33ae96f8006ee6aad3555034dcfee724705b6576be6a6ec4", size = 3608268, upload-time = "2026-03-05T20:39:28.731Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4754,70 +4740,70 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "orjson"
|
||||
version = "3.11.6"
|
||||
version = "3.11.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/70/a3/4e09c61a5f0c521cba0bb433639610ae037437669f1a4cbc93799e731d78/orjson-3.11.6.tar.gz", hash = "sha256:0a54c72259f35299fd033042367df781c2f66d10252955ca1efb7db309b954cb", size = 6175856, upload-time = "2026-01-29T15:13:07.942Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c6/fe/ed708782d6709cc60eb4c2d8a361a440661f74134675c72990f2c48c785f/orjson-3.11.4.tar.gz", hash = "sha256:39485f4ab4c9b30a3943cfe99e1a213c4776fb69e8abd68f66b83d5a0b0fdc6d", size = 5945188, upload-time = "2025-10-24T15:50:38.027Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/fd/d6b0a36854179b93ed77839f107c4089d91cccc9f9ba1b752b6e3bac5f34/orjson-3.11.6-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:e259e85a81d76d9665f03d6129e09e4435531870de5961ddcd0bf6e3a7fde7d7", size = 250029, upload-time = "2026-01-29T15:11:35.942Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a3/bb/22902619826641cf3b627c24aab62e2ad6b571bdd1d34733abb0dd57f67a/orjson-3.11.6-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:52263949f41b4a4822c6b1353bcc5ee2f7109d53a3b493501d3369d6d0e7937a", size = 134518, upload-time = "2026-01-29T15:11:37.347Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/72/90/7a818da4bba1de711a9653c420749c0ac95ef8f8651cbc1dca551f462fe0/orjson-3.11.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6439e742fa7834a24698d358a27346bb203bff356ae0402e7f5df8f749c621a8", size = 137917, upload-time = "2026-01-29T15:11:38.511Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/59/0f/02846c1cac8e205cb3822dd8aa8f9114acda216f41fd1999ace6b543418d/orjson-3.11.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b81ffd68f084b4e993e3867acb554a049fa7787cc8710bbcc1e26965580d99be", size = 134923, upload-time = "2026-01-29T15:11:39.711Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/94/cf/aeaf683001b474bb3c3c757073a4231dfdfe8467fceaefa5bfd40902c99f/orjson-3.11.6-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a5a5468e5e60f7ef6d7f9044b06c8f94a3c56ba528c6e4f7f06ae95164b595ec", size = 140752, upload-time = "2026-01-29T15:11:41.347Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/fe/dad52d8315a65f084044a0819d74c4c9daf9ebe0681d30f525b0d29a31f0/orjson-3.11.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:72c5005eb45bd2535632d4f3bec7ad392832cfc46b62a3021da3b48a67734b45", size = 144201, upload-time = "2026-01-29T15:11:42.537Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/36/bc/ab070dd421565b831801077f1e390c4d4af8bfcecafc110336680a33866b/orjson-3.11.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0b14dd49f3462b014455a28a4d810d3549bf990567653eb43765cd847df09145", size = 142380, upload-time = "2026-01-29T15:11:44.309Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e6/d8/4b581c725c3a308717f28bf45a9fdac210bca08b67e8430143699413ff06/orjson-3.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e0bb2c1ea30ef302f0f89f9bf3e7f9ab5e2af29dc9f80eb87aa99788e4e2d65", size = 145582, upload-time = "2026-01-29T15:11:45.506Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5b/a2/09aab99b39f9a7f175ea8fa29adb9933a3d01e7d5d603cdee7f1c40c8da2/orjson-3.11.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:825e0a85d189533c6bff7e2fc417a28f6fcea53d27125c4551979aecd6c9a197", size = 147270, upload-time = "2026-01-29T15:11:46.782Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/2f/5ef8eaf7829dc50da3bf497c7775b21ee88437bc8c41f959aa3504ca6631/orjson-3.11.6-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:b04575417a26530637f6ab4b1f7b4f666eb0433491091da4de38611f97f2fcf3", size = 421222, upload-time = "2026-01-29T15:11:48.106Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/b0/dd6b941294c2b5b13da5fdc7e749e58d0c55a5114ab37497155e83050e95/orjson-3.11.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b83eb2e40e8c4da6d6b340ee6b1d6125f5195eb1b0ebb7eac23c6d9d4f92d224", size = 155562, upload-time = "2026-01-29T15:11:49.408Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/09/43924331a847476ae2f9a16bd6d3c9dab301265006212ba0d3d7fd58763a/orjson-3.11.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1f42da604ee65a6b87eef858c913ce3e5777872b19321d11e6fc6d21de89b64f", size = 147432, upload-time = "2026-01-29T15:11:50.635Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5d/e9/d9865961081816909f6b49d880749dbbd88425afd7c5bbce0549e2290d77/orjson-3.11.6-cp311-cp311-win32.whl", hash = "sha256:5ae45df804f2d344cffb36c43fdf03c82fb6cd247f5faa41e21891b40dfbf733", size = 139623, upload-time = "2026-01-29T15:11:51.82Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/f9/6836edb92f76eec1082919101eb1145d2f9c33c8f2c5e6fa399b82a2aaa8/orjson-3.11.6-cp311-cp311-win_amd64.whl", hash = "sha256:f4295948d65ace0a2d8f2c4ccc429668b7eb8af547578ec882e16bf79b0050b2", size = 136647, upload-time = "2026-01-29T15:11:53.454Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/0c/4954082eea948c9ae52ee0bcbaa2f99da3216a71bcc314ab129bde22e565/orjson-3.11.6-cp311-cp311-win_arm64.whl", hash = "sha256:314e9c45e0b81b547e3a1cfa3df3e07a815821b3dac9fe8cb75014071d0c16a4", size = 135327, upload-time = "2026-01-29T15:11:56.616Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/14/ba/759f2879f41910b7e5e0cdbd9cf82a4f017c527fb0e972e9869ca7fe4c8e/orjson-3.11.6-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:6f03f30cd8953f75f2a439070c743c7336d10ee940da918d71c6f3556af3ddcf", size = 249988, upload-time = "2026-01-29T15:11:58.294Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f0/70/54cecb929e6c8b10104fcf580b0cc7dc551aa193e83787dd6f3daba28bb5/orjson-3.11.6-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:af44baae65ef386ad971469a8557a0673bb042b0b9fd4397becd9c2dfaa02588", size = 134445, upload-time = "2026-01-29T15:11:59.819Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f2/6f/ec0309154457b9ba1ad05f11faa4441f76037152f75e1ac577db3ce7ca96/orjson-3.11.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c310a48542094e4f7dbb6ac076880994986dda8ca9186a58c3cb70a3514d3231", size = 137708, upload-time = "2026-01-29T15:12:01.488Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/20/52/3c71b80840f8bab9cb26417302707b7716b7d25f863f3a541bcfa232fe6e/orjson-3.11.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d8dfa7a5d387f15ecad94cb6b2d2d5f4aeea64efd8d526bfc03c9812d01e1cc0", size = 134798, upload-time = "2026-01-29T15:12:02.705Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/30/51/b490a43b22ff736282360bd02e6bded455cf31dfc3224e01cd39f919bbd2/orjson-3.11.6-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba8daee3e999411b50f8b50dbb0a3071dd1845f3f9a1a0a6fa6de86d1689d84d", size = 140839, upload-time = "2026-01-29T15:12:03.956Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/bc/4bcfe4280c1bc63c5291bb96f98298845b6355da2226d3400e17e7b51e53/orjson-3.11.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f89d104c974eafd7436d7a5fdbc57f7a1e776789959a2f4f1b2eab5c62a339f4", size = 144080, upload-time = "2026-01-29T15:12:05.151Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/01/74/22970f9ead9ab1f1b5f8c227a6c3aa8d71cd2c5acd005868a1d44f2362fa/orjson-3.11.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2e2e2456788ca5ea75616c40da06fc885a7dc0389780e8a41bf7c5389ba257b", size = 142435, upload-time = "2026-01-29T15:12:06.641Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/29/34/d564aff85847ab92c82ee43a7a203683566c2fca0723a5f50aebbe759603/orjson-3.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a42efebc45afabb1448001e90458c4020d5c64fbac8a8dc4045b777db76cb5a", size = 145631, upload-time = "2026-01-29T15:12:08.351Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e7/ef/016957a3890752c4aa2368326ea69fa53cdc1fdae0a94a542b6410dbdf52/orjson-3.11.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:71b7cbef8471324966c3738c90ba38775563ef01b512feb5ad4805682188d1b9", size = 147058, upload-time = "2026-01-29T15:12:10.023Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/cc/9a899c3972085645b3225569f91a30e221f441e5dc8126e6d060b971c252/orjson-3.11.6-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:f8515e5910f454fe9a8e13c2bb9dc4bae4c1836313e967e72eb8a4ad874f0248", size = 421161, upload-time = "2026-01-29T15:12:11.308Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/21/a8/767d3fbd6d9b8fdee76974db40619399355fd49bf91a6dd2c4b6909ccf05/orjson-3.11.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:300360edf27c8c9bf7047345a94fddf3a8b8922df0ff69d71d854a170cb375cf", size = 155757, upload-time = "2026-01-29T15:12:12.776Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ad/0b/205cd69ac87e2272e13ef3f5f03a3d4657e317e38c1b08aaa2ef97060bbc/orjson-3.11.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:caaed4dad39e271adfadc106fab634d173b2bb23d9cf7e67bd645f879175ebfc", size = 147446, upload-time = "2026-01-29T15:12:14.166Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/de/c5/dd9f22aa9f27c54c7d05cc32f4580c9ac9b6f13811eeb81d6c4c3f50d6b1/orjson-3.11.6-cp312-cp312-win32.whl", hash = "sha256:955368c11808c89793e847830e1b1007503a5923ddadc108547d3b77df761044", size = 139717, upload-time = "2026-01-29T15:12:15.7Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/23/a1/e62fc50d904486970315a1654b8cfb5832eb46abb18cd5405118e7e1fc79/orjson-3.11.6-cp312-cp312-win_amd64.whl", hash = "sha256:2c68de30131481150073d90a5d227a4a421982f42c025ecdfb66157f9579e06f", size = 136711, upload-time = "2026-01-29T15:12:17.055Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/3d/b4fefad8bdf91e0fe212eb04975aeb36ea92997269d68857efcc7eb1dda3/orjson-3.11.6-cp312-cp312-win_arm64.whl", hash = "sha256:65dfa096f4e3a5e02834b681f539a87fbe85adc82001383c0db907557f666bfc", size = 135212, upload-time = "2026-01-29T15:12:18.3Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ae/45/d9c71c8c321277bc1ceebf599bc55ba826ae538b7c61f287e9a7e71bd589/orjson-3.11.6-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:e4ae1670caabb598a88d385798692ce2a1b2f078971b3329cfb85253c6097f5b", size = 249828, upload-time = "2026-01-29T15:12:20.14Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/7e/4afcf4cfa9c2f93846d70eee9c53c3c0123286edcbeb530b7e9bd2aea1b2/orjson-3.11.6-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:2c6b81f47b13dac2caa5d20fbc953c75eb802543abf48403a4703ed3bff225f0", size = 134339, upload-time = "2026-01-29T15:12:22.01Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/40/10/6d2b8a064c8d2411d3d0ea6ab43125fae70152aef6bea77bb50fa54d4097/orjson-3.11.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:647d6d034e463764e86670644bdcaf8e68b076e6e74783383b01085ae9ab334f", size = 137662, upload-time = "2026-01-29T15:12:23.307Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5a/50/5804ea7d586baf83ee88969eefda97a24f9a5bdba0727f73e16305175b26/orjson-3.11.6-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8523b9cc4ef174ae52414f7699e95ee657c16aa18b3c3c285d48d7966cce9081", size = 134626, upload-time = "2026-01-29T15:12:25.099Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/2e/f0492ed43e376722bb4afd648e06cc1e627fc7ec8ff55f6ee739277813ea/orjson-3.11.6-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:313dfd7184cde50c733fc0d5c8c0e2f09017b573afd11dc36bd7476b30b4cb17", size = 140873, upload-time = "2026-01-29T15:12:26.369Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/10/15/6f874857463421794a303a39ac5494786ad46a4ab46d92bda6705d78c5aa/orjson-3.11.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:905ee036064ff1e1fd1fb800055ac477cdcb547a78c22c1bc2bbf8d5d1a6fb42", size = 144044, upload-time = "2026-01-29T15:12:28.082Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d2/c7/b7223a3a70f1d0cc2d86953825de45f33877ee1b124a91ca1f79aa6e643f/orjson-3.11.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ce374cb98411356ba906914441fc993f271a7a666d838d8de0e0900dd4a4bc12", size = 142396, upload-time = "2026-01-29T15:12:30.529Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/87/e3/aa1b6d3ad3cd80f10394134f73ae92a1d11fdbe974c34aa199cc18bb5fcf/orjson-3.11.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cded072b9f65fcfd188aead45efa5bd528ba552add619b3ad2a81f67400ec450", size = 145600, upload-time = "2026-01-29T15:12:31.848Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/cf/e4aac5a46cbd39d7e769ef8650efa851dfce22df1ba97ae2b33efe893b12/orjson-3.11.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7ab85bdbc138e1f73a234db6bb2e4cc1f0fcec8f4bd2bd2430e957a01aadf746", size = 146967, upload-time = "2026-01-29T15:12:33.203Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0b/04/975b86a4bcf6cfeda47aad15956d52fbeda280811206e9967380fa9355c8/orjson-3.11.6-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:351b96b614e3c37a27b8ab048239ebc1e0be76cc17481a430d70a77fb95d3844", size = 421003, upload-time = "2026-01-29T15:12:35.097Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/28/d1/0369d0baf40eea5ff2300cebfe209883b2473ab4aa4c4974c8bd5ee42bb2/orjson-3.11.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f9959c85576beae5cdcaaf39510b15105f1ee8b70d5dacd90152617f57be8c83", size = 155695, upload-time = "2026-01-29T15:12:36.589Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/1f/d10c6d6ae26ff1d7c3eea6fd048280ef2e796d4fb260c5424fd021f68ecf/orjson-3.11.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:75682d62b1b16b61a30716d7a2ec1f4c36195de4a1c61f6665aedd947b93a5d5", size = 147392, upload-time = "2026-01-29T15:12:37.876Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8d/43/7479921c174441a0aa5277c313732e20713c0969ac303be9f03d88d3db5d/orjson-3.11.6-cp313-cp313-win32.whl", hash = "sha256:40dc277999c2ef227dcc13072be879b4cfd325502daeb5c35ed768f706f2bf30", size = 139718, upload-time = "2026-01-29T15:12:39.274Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/88/bc/9ffe7dfbf8454bc4e75bb8bf3a405ed9e0598df1d3535bb4adcd46be07d0/orjson-3.11.6-cp313-cp313-win_amd64.whl", hash = "sha256:f0f6e9f8ff7905660bc3c8a54cd4a675aa98f7f175cf00a59815e2ff42c0d916", size = 136635, upload-time = "2026-01-29T15:12:40.593Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6f/7e/51fa90b451470447ea5023b20d83331ec741ae28d1e6d8ed547c24e7de14/orjson-3.11.6-cp313-cp313-win_arm64.whl", hash = "sha256:1608999478664de848e5900ce41f25c4ecdfc4beacbc632b6fd55e1a586e5d38", size = 135175, upload-time = "2026-01-29T15:12:41.997Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/9f/46ca908abaeeec7560638ff20276ab327b980d73b3cc2f5b205b4a1c60b3/orjson-3.11.6-cp314-cp314-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:6026db2692041d2a23fe2545606df591687787825ad5821971ef0974f2c47630", size = 249823, upload-time = "2026-01-29T15:12:43.332Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/78/ca478089818d18c9cd04f79c43f74ddd031b63c70fa2a946eb5e85414623/orjson-3.11.6-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:132b0ab2e20c73afa85cf142e547511feb3d2f5b7943468984658f3952b467d4", size = 134328, upload-time = "2026-01-29T15:12:45.171Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/39/5e/cbb9d830ed4e47f4375ad8eef8e4fff1bf1328437732c3809054fc4e80be/orjson-3.11.6-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b376fb05f20a96ec117d47987dd3b39265c635725bda40661b4c5b73b77b5fde", size = 137651, upload-time = "2026-01-29T15:12:46.602Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7c/3a/35df6558c5bc3a65ce0961aefee7f8364e59af78749fc796ea255bfa0cf5/orjson-3.11.6-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:954dae4e080574672a1dfcf2a840eddef0f27bd89b0e94903dd0824e9c1db060", size = 134596, upload-time = "2026-01-29T15:12:47.95Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cd/8e/3d32dd7b7f26a19cc4512d6ed0ae3429567c71feef720fe699ff43c5bc9e/orjson-3.11.6-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fe515bb89d59e1e4b48637a964f480b35c0a2676de24e65e55310f6016cca7ce", size = 140923, upload-time = "2026-01-29T15:12:49.333Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6c/9c/1efbf5c99b3304f25d6f0d493a8d1492ee98693637c10ce65d57be839d7b/orjson-3.11.6-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:380f9709c275917af28feb086813923251e11ee10687257cd7f1ea188bcd4485", size = 144068, upload-time = "2026-01-29T15:12:50.927Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/83/0d19eeb5be797de217303bbb55dde58dba26f996ed905d301d98fd2d4637/orjson-3.11.6-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a8173e0d3f6081e7034c51cf984036d02f6bab2a2126de5a759d79f8e5a140e7", size = 142493, upload-time = "2026-01-29T15:12:52.432Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/32/a7/573fec3df4dc8fc259b7770dc6c0656f91adce6e19330c78d23f87945d1e/orjson-3.11.6-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dddf9ba706294906c56ef5150a958317b09aa3a8a48df1c52ccf22ec1907eac", size = 145616, upload-time = "2026-01-29T15:12:53.903Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/0e/23551b16f21690f7fd5122e3cf40fdca5d77052a434d0071990f97f5fe2f/orjson-3.11.6-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:cbae5c34588dc79938dffb0b6fbe8c531f4dc8a6ad7f39759a9eb5d2da405ef2", size = 146951, upload-time = "2026-01-29T15:12:55.698Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/63/5e6c8f39805c39123a18e412434ea364349ee0012548d08aa586e2bd6aa9/orjson-3.11.6-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:f75c318640acbddc419733b57f8a07515e587a939d8f54363654041fd1f4e465", size = 421024, upload-time = "2026-01-29T15:12:57.434Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1d/4d/724975cf0087f6550bd01fd62203418afc0ea33fd099aed318c5bcc52df8/orjson-3.11.6-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:e0ab8d13aa2a3e98b4a43487c9205b2c92c38c054b4237777484d503357c8437", size = 155774, upload-time = "2026-01-29T15:12:59.397Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a8/a3/f4c4e3f46b55db29e0a5f20493b924fc791092d9a03ff2068c9fe6c1002f/orjson-3.11.6-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:f884c7fb1020d44612bd7ac0db0babba0e2f78b68d9a650c7959bf99c783773f", size = 147393, upload-time = "2026-01-29T15:13:00.769Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ee/86/6f5529dd27230966171ee126cecb237ed08e9f05f6102bfaf63e5b32277d/orjson-3.11.6-cp314-cp314-win32.whl", hash = "sha256:8d1035d1b25732ec9f971e833a3e299d2b1a330236f75e6fd945ad982c76aaf3", size = 139760, upload-time = "2026-01-29T15:13:02.173Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d3/b5/91ae7037b2894a6b5002fb33f4fbccec98424a928469835c3837fbb22a9b/orjson-3.11.6-cp314-cp314-win_amd64.whl", hash = "sha256:931607a8865d21682bb72de54231655c86df1870502d2962dbfd12c82890d077", size = 136633, upload-time = "2026-01-29T15:13:04.267Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/55/74/f473a3ec7a0a7ebc825ca8e3c86763f7d039f379860c81ba12dcdd456547/orjson-3.11.6-cp314-cp314-win_arm64.whl", hash = "sha256:fe71f6b283f4f1832204ab8235ce07adad145052614f77c876fcf0dac97bc06f", size = 135168, upload-time = "2026-01-29T15:13:05.932Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/63/1d/1ea6005fffb56715fd48f632611e163d1604e8316a5bad2288bee9a1c9eb/orjson-3.11.4-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:5e59d23cd93ada23ec59a96f215139753fbfe3a4d989549bcb390f8c00370b39", size = 243498, upload-time = "2025-10-24T15:48:48.101Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/37/d7/ffed10c7da677f2a9da307d491b9eb1d0125b0307019c4ad3d665fd31f4f/orjson-3.11.4-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:5c3aedecfc1beb988c27c79d52ebefab93b6c3921dbec361167e6559aba2d36d", size = 128961, upload-time = "2025-10-24T15:48:49.571Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a2/96/3e4d10a18866d1368f73c8c44b7fe37cc8a15c32f2a7620be3877d4c55a3/orjson-3.11.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da9e5301f1c2caa2a9a4a303480d79c9ad73560b2e7761de742ab39fe59d9175", size = 130321, upload-time = "2025-10-24T15:48:50.713Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/eb/1f/465f66e93f434f968dd74d5b623eb62c657bdba2332f5a8be9f118bb74c7/orjson-3.11.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8873812c164a90a79f65368f8f96817e59e35d0cc02786a5356f0e2abed78040", size = 129207, upload-time = "2025-10-24T15:48:52.193Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/28/43/d1e94837543321c119dff277ae8e348562fe8c0fafbb648ef7cb0c67e521/orjson-3.11.4-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5d7feb0741ebb15204e748f26c9638e6665a5fa93c37a2c73d64f1669b0ddc63", size = 136323, upload-time = "2025-10-24T15:48:54.806Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bf/04/93303776c8890e422a5847dd012b4853cdd88206b8bbd3edc292c90102d1/orjson-3.11.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:01ee5487fefee21e6910da4c2ee9eef005bee568a0879834df86f888d2ffbdd9", size = 137440, upload-time = "2025-10-24T15:48:56.326Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/ef/75519d039e5ae6b0f34d0336854d55544ba903e21bf56c83adc51cd8bf82/orjson-3.11.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d40d46f348c0321df01507f92b95a377240c4ec31985225a6668f10e2676f9a", size = 136680, upload-time = "2025-10-24T15:48:57.476Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b5/18/bf8581eaae0b941b44efe14fee7b7862c3382fbc9a0842132cfc7cf5ecf4/orjson-3.11.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95713e5fc8af84d8edc75b785d2386f653b63d62b16d681687746734b4dfc0be", size = 136160, upload-time = "2025-10-24T15:48:59.631Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c4/35/a6d582766d351f87fc0a22ad740a641b0a8e6fc47515e8614d2e4790ae10/orjson-3.11.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ad73ede24f9083614d6c4ca9a85fe70e33be7bf047ec586ee2363bc7418fe4d7", size = 140318, upload-time = "2025-10-24T15:49:00.834Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/76/b3/5a4801803ab2e2e2d703bce1a56540d9f99a9143fbec7bf63d225044fef8/orjson-3.11.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:842289889de515421f3f224ef9c1f1efb199a32d76d8d2ca2706fa8afe749549", size = 406330, upload-time = "2025-10-24T15:49:02.327Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/80/55/a8f682f64833e3a649f620eafefee175cbfeb9854fc5b710b90c3bca45df/orjson-3.11.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:3b2427ed5791619851c52a1261b45c233930977e7de8cf36de05636c708fa905", size = 149580, upload-time = "2025-10-24T15:49:03.517Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ad/e4/c132fa0c67afbb3eb88274fa98df9ac1f631a675e7877037c611805a4413/orjson-3.11.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3c36e524af1d29982e9b190573677ea02781456b2e537d5840e4538a5ec41907", size = 139846, upload-time = "2025-10-24T15:49:04.761Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/54/06/dc3491489efd651fef99c5908e13951abd1aead1257c67f16135f95ce209/orjson-3.11.4-cp311-cp311-win32.whl", hash = "sha256:87255b88756eab4a68ec61837ca754e5d10fa8bc47dc57f75cedfeaec358d54c", size = 135781, upload-time = "2025-10-24T15:49:05.969Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/b7/5e5e8d77bd4ea02a6ac54c42c818afb01dd31961be8a574eb79f1d2cfb1e/orjson-3.11.4-cp311-cp311-win_amd64.whl", hash = "sha256:e2d5d5d798aba9a0e1fede8d853fa899ce2cb930ec0857365f700dffc2c7af6a", size = 131391, upload-time = "2025-10-24T15:49:07.355Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0f/dc/9484127cc1aa213be398ed735f5f270eedcb0c0977303a6f6ddc46b60204/orjson-3.11.4-cp311-cp311-win_arm64.whl", hash = "sha256:6bb6bb41b14c95d4f2702bce9975fda4516f1db48e500102fc4d8119032ff045", size = 126252, upload-time = "2025-10-24T15:49:08.869Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/63/51/6b556192a04595b93e277a9ff71cd0cc06c21a7df98bcce5963fa0f5e36f/orjson-3.11.4-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:d4371de39319d05d3f482f372720b841c841b52f5385bd99c61ed69d55d9ab50", size = 243571, upload-time = "2025-10-24T15:49:10.008Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1c/2c/2602392ddf2601d538ff11848b98621cd465d1a1ceb9db9e8043181f2f7b/orjson-3.11.4-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:e41fd3b3cac850eaae78232f37325ed7d7436e11c471246b87b2cd294ec94853", size = 128891, upload-time = "2025-10-24T15:49:11.297Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4e/47/bf85dcf95f7a3a12bf223394a4f849430acd82633848d52def09fa3f46ad/orjson-3.11.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:600e0e9ca042878c7fdf189cf1b028fe2c1418cc9195f6cb9824eb6ed99cb938", size = 130137, upload-time = "2025-10-24T15:49:12.544Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/4d/a0cb31007f3ab6f1fd2a1b17057c7c349bc2baf8921a85c0180cc7be8011/orjson-3.11.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7bbf9b333f1568ef5da42bc96e18bf30fd7f8d54e9ae066d711056add508e415", size = 129152, upload-time = "2025-10-24T15:49:13.754Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/ef/2811def7ce3d8576b19e3929fff8f8f0d44bc5eb2e0fdecb2e6e6cc6c720/orjson-3.11.4-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4806363144bb6e7297b8e95870e78d30a649fdc4e23fc84daa80c8ebd366ce44", size = 136834, upload-time = "2025-10-24T15:49:15.307Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/00/d4/9aee9e54f1809cec8ed5abd9bc31e8a9631d19460e3b8470145d25140106/orjson-3.11.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad355e8308493f527d41154e9053b86a5be892b3b359a5c6d5d95cda23601cb2", size = 137519, upload-time = "2025-10-24T15:49:16.557Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/db/ea/67bfdb5465d5679e8ae8d68c11753aaf4f47e3e7264bad66dc2f2249e643/orjson-3.11.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c8a7517482667fb9f0ff1b2f16fe5829296ed7a655d04d68cd9711a4d8a4e708", size = 136749, upload-time = "2025-10-24T15:49:17.796Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/01/7e/62517dddcfce6d53a39543cd74d0dccfcbdf53967017c58af68822100272/orjson-3.11.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97eb5942c7395a171cbfecc4ef6701fc3c403e762194683772df4c54cfbb2210", size = 136325, upload-time = "2025-10-24T15:49:19.347Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/18/ae/40516739f99ab4c7ec3aaa5cc242d341fcb03a45d89edeeaabc5f69cb2cf/orjson-3.11.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:149d95d5e018bdd822e3f38c103b1a7c91f88d38a88aada5c4e9b3a73a244241", size = 140204, upload-time = "2025-10-24T15:49:20.545Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/18/ff5734365623a8916e3a4037fcef1cd1782bfc14cf0992afe7940c5320bf/orjson-3.11.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:624f3951181eb46fc47dea3d221554e98784c823e7069edb5dbd0dc826ac909b", size = 406242, upload-time = "2025-10-24T15:49:21.884Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e1/43/96436041f0a0c8c8deca6a05ebeaf529bf1de04839f93ac5e7c479807aec/orjson-3.11.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:03bfa548cf35e3f8b3a96c4e8e41f753c686ff3d8e182ce275b1751deddab58c", size = 150013, upload-time = "2025-10-24T15:49:23.185Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1b/48/78302d98423ed8780479a1e682b9aecb869e8404545d999d34fa486e573e/orjson-3.11.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:525021896afef44a68148f6ed8a8bf8375553d6066c7f48537657f64823565b9", size = 139951, upload-time = "2025-10-24T15:49:24.428Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4a/7b/ad613fdcdaa812f075ec0875143c3d37f8654457d2af17703905425981bf/orjson-3.11.4-cp312-cp312-win32.whl", hash = "sha256:b58430396687ce0f7d9eeb3dd47761ca7d8fda8e9eb92b3077a7a353a75efefa", size = 136049, upload-time = "2025-10-24T15:49:25.973Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b9/3c/9cf47c3ff5f39b8350fb21ba65d789b6a1129d4cbb3033ba36c8a9023520/orjson-3.11.4-cp312-cp312-win_amd64.whl", hash = "sha256:c6dbf422894e1e3c80a177133c0dda260f81428f9de16d61041949f6a2e5c140", size = 131461, upload-time = "2025-10-24T15:49:27.259Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c6/3b/e2425f61e5825dc5b08c2a5a2b3af387eaaca22a12b9c8c01504f8614c36/orjson-3.11.4-cp312-cp312-win_arm64.whl", hash = "sha256:d38d2bc06d6415852224fcc9c0bfa834c25431e466dc319f0edd56cca81aa96e", size = 126167, upload-time = "2025-10-24T15:49:28.511Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/23/15/c52aa7112006b0f3d6180386c3a46ae057f932ab3425bc6f6ac50431cca1/orjson-3.11.4-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:2d6737d0e616a6e053c8b4acc9eccea6b6cce078533666f32d140e4f85002534", size = 243525, upload-time = "2025-10-24T15:49:29.737Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ec/38/05340734c33b933fd114f161f25a04e651b0c7c33ab95e9416ade5cb44b8/orjson-3.11.4-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:afb14052690aa328cc118a8e09f07c651d301a72e44920b887c519b313d892ff", size = 128871, upload-time = "2025-10-24T15:49:31.109Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/55/b9/ae8d34899ff0c012039b5a7cb96a389b2476e917733294e498586b45472d/orjson-3.11.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38aa9e65c591febb1b0aed8da4d469eba239d434c218562df179885c94e1a3ad", size = 130055, upload-time = "2025-10-24T15:49:33.382Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/33/aa/6346dd5073730451bee3681d901e3c337e7ec17342fb79659ec9794fc023/orjson-3.11.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f2cf4dfaf9163b0728d061bebc1e08631875c51cd30bf47cb9e3293bfbd7dcd5", size = 129061, upload-time = "2025-10-24T15:49:34.935Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/39/e4/8eea51598f66a6c853c380979912d17ec510e8e66b280d968602e680b942/orjson-3.11.4-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:89216ff3dfdde0e4070932e126320a1752c9d9a758d6a32ec54b3b9334991a6a", size = 136541, upload-time = "2025-10-24T15:49:36.923Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9a/47/cb8c654fa9adcc60e99580e17c32b9e633290e6239a99efa6b885aba9dbc/orjson-3.11.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9daa26ca8e97fae0ce8aa5d80606ef8f7914e9b129b6b5df9104266f764ce436", size = 137535, upload-time = "2025-10-24T15:49:38.307Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/43/92/04b8cc5c2b729f3437ee013ce14a60ab3d3001465d95c184758f19362f23/orjson-3.11.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c8b2769dc31883c44a9cd126560327767f848eb95f99c36c9932f51090bfce9", size = 136703, upload-time = "2025-10-24T15:49:40.795Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/aa/fd/d0733fcb9086b8be4ebcfcda2d0312865d17d0d9884378b7cffb29d0763f/orjson-3.11.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1469d254b9884f984026bd9b0fa5bbab477a4bfe558bba6848086f6d43eb5e73", size = 136293, upload-time = "2025-10-24T15:49:42.347Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/d7/3c5514e806837c210492d72ae30ccf050ce3f940f45bf085bab272699ef4/orjson-3.11.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:68e44722541983614e37117209a194e8c3ad07838ccb3127d96863c95ec7f1e0", size = 140131, upload-time = "2025-10-24T15:49:43.638Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9c/dd/ba9d32a53207babf65bd510ac4d0faaa818bd0df9a9c6f472fe7c254f2e3/orjson-3.11.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:8e7805fda9672c12be2f22ae124dcd7b03928d6c197544fe12174b86553f3196", size = 406164, upload-time = "2025-10-24T15:49:45.498Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/f9/f68ad68f4af7c7bde57cd514eaa2c785e500477a8bc8f834838eb696a685/orjson-3.11.4-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:04b69c14615fb4434ab867bf6f38b2d649f6f300af30a6705397e895f7aec67a", size = 149859, upload-time = "2025-10-24T15:49:46.981Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b6/d2/7f847761d0c26818395b3d6b21fb6bc2305d94612a35b0a30eae65a22728/orjson-3.11.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:639c3735b8ae7f970066930e58cf0ed39a852d417c24acd4a25fc0b3da3c39a6", size = 139926, upload-time = "2025-10-24T15:49:48.321Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9f/37/acd14b12dc62db9a0e1d12386271b8661faae270b22492580d5258808975/orjson-3.11.4-cp313-cp313-win32.whl", hash = "sha256:6c13879c0d2964335491463302a6ca5ad98105fc5db3565499dcb80b1b4bd839", size = 136007, upload-time = "2025-10-24T15:49:49.938Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c0/a9/967be009ddf0a1fffd7a67de9c36656b28c763659ef91352acc02cbe364c/orjson-3.11.4-cp313-cp313-win_amd64.whl", hash = "sha256:09bf242a4af98732db9f9a1ec57ca2604848e16f132e3f72edfd3c5c96de009a", size = 131314, upload-time = "2025-10-24T15:49:51.248Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/db/399abd6950fbd94ce125cb8cd1a968def95174792e127b0642781e040ed4/orjson-3.11.4-cp313-cp313-win_arm64.whl", hash = "sha256:a85f0adf63319d6c1ba06fb0dbf997fced64a01179cf17939a6caca662bf92de", size = 126152, upload-time = "2025-10-24T15:49:52.922Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/25/e3/54ff63c093cc1697e758e4fceb53164dd2661a7d1bcd522260ba09f54533/orjson-3.11.4-cp314-cp314-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:42d43a1f552be1a112af0b21c10a5f553983c2a0938d2bbb8ecd8bc9fb572803", size = 243501, upload-time = "2025-10-24T15:49:54.288Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/7d/e2d1076ed2e8e0ae9badca65bf7ef22710f93887b29eaa37f09850604e09/orjson-3.11.4-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:26a20f3fbc6c7ff2cb8e89c4c5897762c9d88cf37330c6a117312365d6781d54", size = 128862, upload-time = "2025-10-24T15:49:55.961Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9f/37/ca2eb40b90621faddfa9517dfe96e25f5ae4d8057a7c0cdd613c17e07b2c/orjson-3.11.4-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6e3f20be9048941c7ffa8fc523ccbd17f82e24df1549d1d1fe9317712d19938e", size = 130047, upload-time = "2025-10-24T15:49:57.406Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/62/1021ed35a1f2bad9040f05fa4cc4f9893410df0ba3eaa323ccf899b1c90a/orjson-3.11.4-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:aac364c758dc87a52e68e349924d7e4ded348dedff553889e4d9f22f74785316", size = 129073, upload-time = "2025-10-24T15:49:58.782Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e8/3f/f84d966ec2a6fd5f73b1a707e7cd876813422ae4bf9f0145c55c9c6a0f57/orjson-3.11.4-cp314-cp314-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d5c54a6d76e3d741dcc3f2707f8eeb9ba2a791d3adbf18f900219b62942803b1", size = 136597, upload-time = "2025-10-24T15:50:00.12Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/32/78/4fa0aeca65ee82bbabb49e055bd03fa4edea33f7c080c5c7b9601661ef72/orjson-3.11.4-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f28485bdca8617b79d44627f5fb04336897041dfd9fa66d383a49d09d86798bc", size = 137515, upload-time = "2025-10-24T15:50:01.57Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/9d/0c102e26e7fde40c4c98470796d050a2ec1953897e2c8ab0cb95b0759fa2/orjson-3.11.4-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bfc2a484cad3585e4ba61985a6062a4c2ed5c7925db6d39f1fa267c9d166487f", size = 136703, upload-time = "2025-10-24T15:50:02.944Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/df/ac/2de7188705b4cdfaf0b6c97d2f7849c17d2003232f6e70df98602173f788/orjson-3.11.4-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e34dbd508cb91c54f9c9788923daca129fe5b55c5b4eebe713bf5ed3791280cf", size = 136311, upload-time = "2025-10-24T15:50:04.441Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e0/52/847fcd1a98407154e944feeb12e3b4d487a0e264c40191fb44d1269cbaa1/orjson-3.11.4-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b13c478fa413d4b4ee606ec8e11c3b2e52683a640b006bb586b3041c2ca5f606", size = 140127, upload-time = "2025-10-24T15:50:07.398Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/ae/21d208f58bdb847dd4d0d9407e2929862561841baa22bdab7aea10ca088e/orjson-3.11.4-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:724ca721ecc8a831b319dcd72cfa370cc380db0bf94537f08f7edd0a7d4e1780", size = 406201, upload-time = "2025-10-24T15:50:08.796Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8d/55/0789d6de386c8366059db098a628e2ad8798069e94409b0d8935934cbcb9/orjson-3.11.4-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:977c393f2e44845ce1b540e19a786e9643221b3323dae190668a98672d43fb23", size = 149872, upload-time = "2025-10-24T15:50:10.234Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/1d/7ff81ea23310e086c17b41d78a72270d9de04481e6113dbe2ac19118f7fb/orjson-3.11.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:1e539e382cf46edec157ad66b0b0872a90d829a6b71f17cb633d6c160a223155", size = 139931, upload-time = "2025-10-24T15:50:11.623Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/77/92/25b886252c50ed64be68c937b562b2f2333b45afe72d53d719e46a565a50/orjson-3.11.4-cp314-cp314-win32.whl", hash = "sha256:d63076d625babab9db5e7836118bdfa086e60f37d8a174194ae720161eb12394", size = 136065, upload-time = "2025-10-24T15:50:13.025Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/63/b8/718eecf0bb7e9d64e4956afaafd23db9f04c776d445f59fe94f54bdae8f0/orjson-3.11.4-cp314-cp314-win_amd64.whl", hash = "sha256:0a54d6635fa3aaa438ae32e8570b9f0de36f3f6562c308d2a2a452e8b0592db1", size = 131310, upload-time = "2025-10-24T15:50:14.46Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1a/bf/def5e25d4d8bfce296a9a7c8248109bf58622c21618b590678f945a2c59c/orjson-3.11.4-cp314-cp314-win_arm64.whl", hash = "sha256:78b999999039db3cf58f6d230f524f04f75f129ba3d1ca2ed121f8657e575d3d", size = 126151, upload-time = "2025-10-24T15:50:15.878Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5727,11 +5713,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pypdf"
|
||||
version = "6.8.0"
|
||||
version = "6.7.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b4/a3/e705b0805212b663a4c27b861c8a603dba0f8b4bb281f96f8e746576a50d/pypdf-6.8.0.tar.gz", hash = "sha256:cb7eaeaa4133ce76f762184069a854e03f4d9a08568f0e0623f7ea810407833b", size = 5307831, upload-time = "2026-03-09T13:37:40.591Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f6/52/37cc0aa9e9d1bf7729a737a0d83f8b3f851c8eb137373d9f71eafb0a3405/pypdf-6.7.5.tar.gz", hash = "sha256:40bb2e2e872078655f12b9b89e2f900888bb505e88a82150b64f9f34fa25651d", size = 5304278, upload-time = "2026-03-02T09:05:21.464Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/ec/4ccf3bb86b1afe5d7176e1c8abcdbf22b53dd682ec2eda50e1caadcf6846/pypdf-6.8.0-py3-none-any.whl", hash = "sha256:2a025080a8dd73f48123c89c57174a5ff3806c71763ee4e49572dc90454943c7", size = 332177, upload-time = "2026-03-09T13:37:38.774Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/05/89/336673efd0a88956562658aba4f0bbef7cb92a6fbcbcaf94926dbc82b408/pypdf-6.7.5-py3-none-any.whl", hash = "sha256:07ba7f1d6e6d9aa2a17f5452e320a84718d4ce863367f7ede2fd72280349ab13", size = 331421, upload-time = "2026-03-02T09:05:19.722Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -6352,16 +6338,16 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "release-tag"
|
||||
version = "0.5.2"
|
||||
version = "0.4.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/92/01192a540b29cfadaa23850c8f6a2041d541b83a3fa1dc52a5f55212b3b6/release_tag-0.5.2-py3-none-any.whl", hash = "sha256:1e9ca7618bcfc63ad7a0728c84bbad52ef82d07586c4cc11365b44ea8f588069", size = 1264752, upload-time = "2026-03-11T00:27:18.674Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4f/77/81fb42a23cd0de61caf84266f7aac1950b1c324883788b7c48e5344f61ae/release_tag-0.5.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:8fbc61ff7bac2b96fab09566ec45c6508c201efc3f081f57702e1761bbc178d5", size = 1255075, upload-time = "2026-03-11T00:27:24.442Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/e6/769f8be94304529c1a531e995f2f3ac83f3c54738ce488b0abde75b20851/release_tag-0.5.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa3d7e495a0c516858a81878d03803539712677a3d6e015503de21cce19bea5e", size = 1163627, upload-time = "2026-03-11T00:27:26.412Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/45/68/7543e9daa0dfd41c487bf140d91fd5879327bb7c001a96aa5264667c30a1/release_tag-0.5.2-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:e8b60453218d6926da1fdcb99c2e17c851be0d7ab1975e97951f0bff5f32b565", size = 1140133, upload-time = "2026-03-11T00:27:20.633Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/30/9087825696271012d889d136310dbdf0811976ae2b2f5a490f4e437903e1/release_tag-0.5.2-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:0e302ed60c2bf8b7ba5634842be28a27d83cec995869e112b0348b3f01a84ff5", size = 1264767, upload-time = "2026-03-11T00:27:28.355Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/a3/5b51b0cbdbf2299f545124beab182cfdfe01bf5b615efbc94aee3a64ea67/release_tag-0.5.2-py3-none-win_amd64.whl", hash = "sha256:e3c0629d373a16b9a3da965e89fca893640ce9878ec548865df3609b70989a89", size = 1340816, upload-time = "2026-03-11T00:27:22.622Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dd/6f/832c2023a8bd8414c93452bd8b43bf61cedfa5b9575f70c06fb911e51a29/release_tag-0.5.2-py3-none-win_arm64.whl", hash = "sha256:5f26b008e0be0c7a122acd8fcb1bb5c822f38e77fed0c0bf6c550cc226c6bf14", size = 1203191, upload-time = "2026-03-11T00:27:29.789Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/39/18/c1d17d973f73f0aa7e2c45f852839ab909756e1bd9727d03babe400fcef0/release_tag-0.4.3-py3-none-any.whl", hash = "sha256:4206f4fa97df930c8176bfee4d3976a7385150ed14b317bd6bae7101ac8b66dd", size = 1181112, upload-time = "2025-12-03T00:18:19.445Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/33/c7/ecc443953840ac313856b2181f55eb8d34fa2c733cdd1edd0bcceee0938d/release_tag-0.4.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:7a347a9ad3d2af16e5367e52b451fbc88a0b7b666850758e8f9a601554a8fb13", size = 1170517, upload-time = "2025-12-03T00:18:11.663Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ce/81/2f6ffa0d87c792364ca9958433fe088c8acc3d096ac9734040049c6ad506/release_tag-0.4.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2d1603aa37d8e4f5df63676bbfddc802fbc108a744ba28288ad25c997981c164", size = 1101663, upload-time = "2025-12-03T00:18:15.173Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7c/ed/9e4ebe400fc52e38dda6e6a45d9da9decd4535ab15e170b8d9b229a66730/release_tag-0.4.3-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:6db7b81a198e3ba6a87496a554684912c13f9297ea8db8600a80f4f971709d37", size = 1079322, upload-time = "2025-12-03T00:18:16.094Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/64/9e0ce6119e091ef9211fa82b9593f564eeec8bdd86eff6a97fe6e2fcb20f/release_tag-0.4.3-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:d79a9cf191dd2c29e1b3a35453fa364b08a7aadd15aeb2c556a7661c6cf4d5ad", size = 1181129, upload-time = "2025-12-03T00:18:15.82Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/09/d96acf18f0773b6355080a568ba48931faa9dbe91ab1abefc6f8c4df04a8/release_tag-0.4.3-py3-none-win_amd64.whl", hash = "sha256:3958b880375f2241d0cc2b9882363bf54b1d4d7ca8ffc6eecc63ab92f23307f0", size = 1260773, upload-time = "2025-12-03T00:18:14.723Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/51/da/ecb6346df1ffb0752fe213e25062f802c10df2948717f0d5f9816c2df914/release_tag-0.4.3-py3-none-win_arm64.whl", hash = "sha256:7d5b08000e6e398d46f05a50139031046348fba6d47909f01e468bb7600c19df", size = 1142155, upload-time = "2025-12-03T00:18:20.647Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7247,19 +7233,21 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "tornado"
|
||||
version = "6.5.5"
|
||||
version = "6.5.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f8/f1/3173dfa4a18db4a9b03e5d55325559dab51ee653763bb8745a75af491286/tornado-6.5.5.tar.gz", hash = "sha256:192b8f3ea91bd7f1f50c06955416ed76c6b72f96779b962f07f911b91e8d30e9", size = 516006, upload-time = "2026-03-10T21:31:02.067Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/09/ce/1eb500eae19f4648281bb2186927bb062d2438c2e5093d1360391afd2f90/tornado-6.5.2.tar.gz", hash = "sha256:ab53c8f9a0fa351e2c0741284e06c7a45da86afb544133201c5cc8578eb076a0", size = 510821, upload-time = "2025-08-08T18:27:00.78Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/59/8c/77f5097695f4dd8255ecbd08b2a1ed8ba8b953d337804dd7080f199e12bf/tornado-6.5.5-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:487dc9cc380e29f58c7ab88f9e27cdeef04b2140862e5076a66fb6bb68bb1bfa", size = 445983, upload-time = "2026-03-10T21:30:44.28Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/5e/7625b76cd10f98f1516c36ce0346de62061156352353ef2da44e5c21523c/tornado-6.5.5-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:65a7f1d46d4bb41df1ac99f5fcb685fb25c7e61613742d5108b010975a9a6521", size = 444246, upload-time = "2026-03-10T21:30:46.571Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/04/7b5705d5b3c0fab088f434f9c83edac1573830ca49ccf29fb83bf7178eec/tornado-6.5.5-cp39-abi3-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:e74c92e8e65086b338fd56333fb9a68b9f6f2fe7ad532645a290a464bcf46be5", size = 447229, upload-time = "2026-03-10T21:30:48.273Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/34/01/74e034a30ef59afb4097ef8659515e96a39d910b712a89af76f5e4e1f93c/tornado-6.5.5-cp39-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:435319e9e340276428bbdb4e7fa732c2d399386d1de5686cb331ec8eee754f07", size = 448192, upload-time = "2026-03-10T21:30:51.22Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/be/00/fe9e02c5a96429fce1a1d15a517f5d8444f9c412e0bb9eadfbe3b0fc55bf/tornado-6.5.5-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:3f54aa540bdbfee7b9eb268ead60e7d199de5021facd276819c193c0fb28ea4e", size = 448039, upload-time = "2026-03-10T21:30:53.52Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/82/9e/656ee4cec0398b1d18d0f1eb6372c41c6b889722641d84948351ae19556d/tornado-6.5.5-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:36abed1754faeb80fbd6e64db2758091e1320f6bba74a4cf8c09cd18ccce8aca", size = 447445, upload-time = "2026-03-10T21:30:55.541Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5a/76/4921c00511f88af86a33de770d64141170f1cfd9c00311aea689949e274e/tornado-6.5.5-cp39-abi3-win32.whl", hash = "sha256:dd3eafaaeec1c7f2f8fdcd5f964e8907ad788fe8a5a32c4426fbbdda621223b7", size = 448582, upload-time = "2026-03-10T21:30:57.142Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/23/f6c6112a04d28eed765e374435fb1a9198f73e1ec4b4024184f21faeb1ad/tornado-6.5.5-cp39-abi3-win_amd64.whl", hash = "sha256:6443a794ba961a9f619b1ae926a2e900ac20c34483eea67be4ed8f1e58d3ef7b", size = 448990, upload-time = "2026-03-10T21:30:58.857Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/c8/876602cbc96469911f0939f703453c1157b0c826ecb05bdd32e023397d4e/tornado-6.5.5-cp39-abi3-win_arm64.whl", hash = "sha256:2c9a876e094109333f888539ddb2de4361743e5d21eece20688e3e351e4990a6", size = 448016, upload-time = "2026-03-10T21:31:00.43Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/48/6a7529df2c9cc12efd2e8f5dd219516184d703b34c06786809670df5b3bd/tornado-6.5.2-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:2436822940d37cde62771cff8774f4f00b3c8024fe482e16ca8387b8a2724db6", size = 442563, upload-time = "2025-08-08T18:26:42.945Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f2/b5/9b575a0ed3e50b00c40b08cbce82eb618229091d09f6d14bce80fc01cb0b/tornado-6.5.2-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:583a52c7aa94ee046854ba81d9ebb6c81ec0fd30386d96f7640c96dad45a03ef", size = 440729, upload-time = "2025-08-08T18:26:44.473Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1b/4e/619174f52b120efcf23633c817fd3fed867c30bff785e2cd5a53a70e483c/tornado-6.5.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0fe179f28d597deab2842b86ed4060deec7388f1fd9c1b4a41adf8af058907e", size = 444295, upload-time = "2025-08-08T18:26:46.021Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/fa/87b41709552bbd393c85dd18e4e3499dcd8983f66e7972926db8d96aa065/tornado-6.5.2-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b186e85d1e3536d69583d2298423744740986018e393d0321df7340e71898882", size = 443644, upload-time = "2025-08-08T18:26:47.625Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f9/41/fb15f06e33d7430ca89420283a8762a4e6b8025b800ea51796ab5e6d9559/tornado-6.5.2-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e792706668c87709709c18b353da1f7662317b563ff69f00bab83595940c7108", size = 443878, upload-time = "2025-08-08T18:26:50.599Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/11/92/fe6d57da897776ad2e01e279170ea8ae726755b045fe5ac73b75357a5a3f/tornado-6.5.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:06ceb1300fd70cb20e43b1ad8aaee0266e69e7ced38fa910ad2e03285009ce7c", size = 444549, upload-time = "2025-08-08T18:26:51.864Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9b/02/c8f4f6c9204526daf3d760f4aa555a7a33ad0e60843eac025ccfd6ff4a93/tornado-6.5.2-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:74db443e0f5251be86cbf37929f84d8c20c27a355dd452a5cfa2aada0d001ec4", size = 443973, upload-time = "2025-08-08T18:26:53.625Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ae/2d/f5f5707b655ce2317190183868cd0f6822a1121b4baeae509ceb9590d0bd/tornado-6.5.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b5e735ab2889d7ed33b32a459cac490eda71a1ba6857b0118de476ab6c366c04", size = 443954, upload-time = "2025-08-08T18:26:55.072Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e8/59/593bd0f40f7355806bf6573b47b8c22f8e1374c9b6fd03114bd6b7a3dcfd/tornado-6.5.2-cp39-abi3-win32.whl", hash = "sha256:c6f29e94d9b37a95013bb669616352ddb82e3bfe8326fccee50583caebc8a5f0", size = 445023, upload-time = "2025-08-08T18:26:56.677Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/2a/f609b420c2f564a748a2d80ebfb2ee02a73ca80223af712fca591386cafb/tornado-6.5.2-cp39-abi3-win_amd64.whl", hash = "sha256:e56a5af51cc30dd2cae649429af65ca2f6571da29504a07995175df14c18f35f", size = 445427, upload-time = "2025-08-08T18:26:57.91Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5e/4f/e1f65e8f8c76d73658b33d33b81eed4322fb5085350e4328d5c956f0c8f9/tornado-6.5.2-cp39-abi3-win_arm64.whl", hash = "sha256:d6c33dc3672e3a1f3618eb63b7ef4683a7688e7b9e6e8f0d9aa5726360a004af", size = 444456, upload-time = "2025-08-08T18:26:59.207Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -53,8 +53,6 @@ const sharedConfig = {
|
||||
// Testing & Mocking
|
||||
"msw",
|
||||
"until-async",
|
||||
// Language Detection
|
||||
"linguist-languages",
|
||||
// Markdown & Syntax Highlighting
|
||||
"react-markdown",
|
||||
"remark-.*", // All remark packages
|
||||
@@ -145,9 +143,7 @@ module.exports = {
|
||||
"**/src/app/**/utils/*.test.ts",
|
||||
"**/src/app/**/hooks/*.test.ts", // Pure packet processor tests
|
||||
"**/src/refresh-components/**/*.test.ts",
|
||||
"**/src/refresh-pages/**/*.test.ts",
|
||||
"**/src/sections/**/*.test.ts",
|
||||
"**/src/components/**/*.test.ts",
|
||||
// Add more patterns here as you add more unit tests
|
||||
],
|
||||
},
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react";
|
||||
import { OpenButton } from "@opal/components";
|
||||
import { Disabled as DisabledProvider } from "@opal/core";
|
||||
import { SvgSettings } from "@opal/icons";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
|
||||
@@ -33,9 +32,16 @@ export const WithIcon: Story = {
|
||||
},
|
||||
};
|
||||
|
||||
export const Selected: Story = {
|
||||
args: {
|
||||
selected: true,
|
||||
children: "Selected",
|
||||
},
|
||||
};
|
||||
|
||||
export const Open: Story = {
|
||||
args: {
|
||||
interaction: "hover",
|
||||
transient: true,
|
||||
children: "Open state",
|
||||
},
|
||||
};
|
||||
@@ -47,27 +53,18 @@ export const Disabled: Story = {
|
||||
},
|
||||
};
|
||||
|
||||
export const Foldable: Story = {
|
||||
export const LightProminence: Story = {
|
||||
args: {
|
||||
foldable: true,
|
||||
icon: SvgSettings,
|
||||
children: "Settings",
|
||||
prominence: "light",
|
||||
children: "Light prominence",
|
||||
},
|
||||
};
|
||||
|
||||
export const FoldableDisabled: Story = {
|
||||
export const HeavyProminence: Story = {
|
||||
args: {
|
||||
foldable: true,
|
||||
icon: SvgSettings,
|
||||
children: "Settings",
|
||||
prominence: "heavy",
|
||||
children: "Heavy prominence",
|
||||
},
|
||||
decorators: [
|
||||
(Story) => (
|
||||
<DisabledProvider disabled>
|
||||
<Story />
|
||||
</DisabledProvider>
|
||||
),
|
||||
],
|
||||
};
|
||||
|
||||
export const Sizes: Story = {
|
||||
@@ -81,12 +78,3 @@ export const Sizes: Story = {
|
||||
</div>
|
||||
),
|
||||
};
|
||||
|
||||
export const WithTooltip: Story = {
|
||||
args: {
|
||||
icon: SvgSettings,
|
||||
children: "Settings",
|
||||
tooltip: "Open settings",
|
||||
tooltipSide: "bottom",
|
||||
},
|
||||
};
|
||||
@@ -17,9 +17,7 @@ OpenButton is a **tighter, specialized use-case** of SelectButton:
|
||||
- It hardcodes `variant="select-heavy"` (SelectButton exposes `variant`)
|
||||
- It adds a built-in chevron with CSS-driven rotation (SelectButton has no chevron)
|
||||
- It auto-detects Radix `data-state="open"` to derive `interaction` (SelectButton has no Radix awareness)
|
||||
- It does not support `rightIcon` (SelectButton does)
|
||||
|
||||
Both components support `foldable` using the same pattern: `interactive-foldable-host` class + `Interactive.Foldable` wrapper around the label and trailing icon. When foldable, the left icon stays visible while the rest collapses. If you change the foldable implementation in one, update the other to match.
|
||||
- It does not support `foldable` or `rightIcon` (SelectButton does)
|
||||
|
||||
If you need a general-purpose stateful toggle, use `SelectButton`. If you need a popover/dropdown trigger with a chevron, use `OpenButton`.
|
||||
|
||||
@@ -28,12 +26,10 @@ If you need a general-purpose stateful toggle, use `SelectButton`. If you need a
|
||||
```
|
||||
Interactive.Stateful <- variant="select-heavy", interaction, state, disabled, onClick
|
||||
└─ Interactive.Container <- height, rounding, padding (from `size`)
|
||||
└─ div.opal-button.interactive-foreground [.interactive-foldable-host]
|
||||
└─ div.opal-button.interactive-foreground
|
||||
├─ div > Icon? (interactive-foreground-icon)
|
||||
├─ [Foldable]? (wraps label + chevron when foldable)
|
||||
│ ├─ <span>? .opal-button-label
|
||||
│ └─ div > ChevronIcon .opal-open-button-chevron
|
||||
└─ <span>? / ChevronIcon (non-foldable)
|
||||
├─ <span>? .opal-button-label
|
||||
└─ div > ChevronIcon .opal-open-button-chevron (interactive-foreground-icon)
|
||||
```
|
||||
|
||||
- **`interaction` controls both the chevron and the hover visual state.** When `interaction` is `"hover"` (explicitly or via Radix `data-state="open"`), the chevron rotates 180° and the hover background activates.
|
||||
@@ -48,7 +44,6 @@ Interactive.Stateful <- variant="select-heavy", interaction, state, di
|
||||
| `interaction` | `"rest" \| "hover" \| "active"` | auto | JS-controlled interaction override. Falls back to Radix `data-state="open"` when omitted. |
|
||||
| `icon` | `IconFunctionComponent` | — | Left icon component |
|
||||
| `children` | `string` | — | Content between icon and chevron |
|
||||
| `foldable` | `boolean` | `false` | When `true`, requires both `icon` and `children`; the left icon stays visible while the label + chevron collapse when not hovered. If `tooltip` is omitted on a disabled foldable button, the label text is used as the tooltip. |
|
||||
| `size` | `SizeVariant` | `"lg"` | Size preset controlling height, rounding, and padding |
|
||||
| `width` | `WidthVariant` | — | Width preset |
|
||||
| `tooltip` | `string` | — | Tooltip text shown on hover |
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user