mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-03 14:45:46 +00:00
Compare commits
46 Commits
litellm_pr
...
v2.12.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
af0721e063 | ||
|
|
567651a812 | ||
|
|
7589767bb9 | ||
|
|
589d613f1e | ||
|
|
b17d7e0033 | ||
|
|
131d418771 | ||
|
|
0be04391b3 | ||
|
|
20351d9998 | ||
|
|
22152ad871 | ||
|
|
7caf197f98 | ||
|
|
140bc82b36 | ||
|
|
e7ecbfafd1 | ||
|
|
2c2af369f5 | ||
|
|
2032b76fbf | ||
|
|
055b30b00e | ||
|
|
360a4cf591 | ||
|
|
3d3cab9f91 | ||
|
|
6120d012ba | ||
|
|
3e7e2e93f2 | ||
|
|
ccf482fa3b | ||
|
|
fd45a612da | ||
|
|
c444d8883b | ||
|
|
9947837f9f | ||
|
|
bc324a8070 | ||
|
|
26f648c24a | ||
|
|
638f20f5f3 | ||
|
|
f6ee57f523 | ||
|
|
aae6fc7aac | ||
|
|
5d7a664250 | ||
|
|
e7386490bf | ||
|
|
106e10a143 | ||
|
|
513f430a1b | ||
|
|
696d73822f | ||
|
|
bfcc5a20a2 | ||
|
|
efe3613354 | ||
|
|
62405bdc42 | ||
|
|
8f505dc45f | ||
|
|
75f0db4fe5 | ||
|
|
f0a5c579a3 | ||
|
|
293bf30847 | ||
|
|
8774ca3b0f | ||
|
|
016a73f85f | ||
|
|
2eddb4e23e | ||
|
|
0a61660a59 | ||
|
|
a10599e76e | ||
|
|
b3d3f7af76 |
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"Playwright": {
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"@playwright/mcp"
|
||||
]
|
||||
},
|
||||
"Linear": {
|
||||
"url": "https://mcp.linear.app/mcp"
|
||||
},
|
||||
"Figma": {
|
||||
"url": "https://mcp.figma.com/mcp"
|
||||
}
|
||||
}
|
||||
}
|
||||
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -8,5 +8,5 @@
|
||||
|
||||
## Additional Options
|
||||
|
||||
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
|
||||
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
|
||||
- [ ] [Optional] Override Linear Check
|
||||
|
||||
4
.github/workflows/deployment.yml
vendored
4
.github/workflows/deployment.yml
vendored
@@ -91,8 +91,8 @@ jobs:
|
||||
BUILD_WEB_CLOUD=true
|
||||
else
|
||||
BUILD_WEB=true
|
||||
# Only build desktop for semver tags (excluding beta)
|
||||
if [[ "$IS_VERSION_TAG" == "true" ]] && [[ "$IS_BETA" != "true" ]]; then
|
||||
# Skip desktop builds on beta tags and nightly runs
|
||||
if [[ "$IS_BETA" != "true" ]] && [[ "$IS_NIGHTLY" != "true" ]]; then
|
||||
BUILD_DESKTOP=true
|
||||
fi
|
||||
fi
|
||||
|
||||
151
.github/workflows/nightly-scan-licenses.yml
vendored
151
.github/workflows/nightly-scan-licenses.yml
vendored
@@ -1,151 +0,0 @@
|
||||
# Scan for problematic software licenses
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
|
||||
name: 'Nightly - Scan licenses'
|
||||
on:
|
||||
# schedule:
|
||||
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
scan-licenses:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
|
||||
timeout-minutes: 45
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
security-events: write
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # ratchet:actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
|
||||
- name: Get explicit and transitive dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
pip freeze > requirements-all.txt
|
||||
|
||||
- name: Check python
|
||||
id: license_check_report
|
||||
uses: pilosus/action-pip-license-checker@e909b0226ff49d3235c99c4585bc617f49fff16a # ratchet:pilosus/action-pip-license-checker@v3
|
||||
with:
|
||||
requirements: 'requirements-all.txt'
|
||||
fail: 'Copyleft'
|
||||
exclude: '(?i)^(pylint|aio[-_]*).*'
|
||||
|
||||
- name: Print report
|
||||
if: always()
|
||||
env:
|
||||
REPORT: ${{ steps.license_check_report.outputs.report }}
|
||||
run: echo "$REPORT"
|
||||
|
||||
- name: Install npm dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
# be careful enabling the sarif and upload as it may spam the security tab
|
||||
# with a huge amount of items. Work out the issues before enabling upload.
|
||||
# - name: Run Trivy vulnerability scanner in repo mode
|
||||
# if: always()
|
||||
# uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
# with:
|
||||
# scan-type: fs
|
||||
# scan-ref: .
|
||||
# scanners: license
|
||||
# format: table
|
||||
# severity: HIGH,CRITICAL
|
||||
# # format: sarif
|
||||
# # output: trivy-results.sarif
|
||||
#
|
||||
# # - name: Upload Trivy scan results to GitHub Security tab
|
||||
# # uses: github/codeql-action/upload-sarif@v3
|
||||
# # with:
|
||||
# # sarif_file: trivy-results.sarif
|
||||
|
||||
scan-trivy:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-trivy"]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# Backend
|
||||
- name: Pull backend docker image
|
||||
run: docker pull onyxdotapp/onyx-backend:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on backend
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-backend:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
|
||||
|
||||
# Web server
|
||||
- name: Pull web server docker image
|
||||
run: docker pull onyxdotapp/onyx-web-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on web server
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-web-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
|
||||
# Model server
|
||||
- name: Pull model server docker image
|
||||
run: docker pull onyxdotapp/onyx-model-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-model-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
79
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
Normal file
79
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
Normal file
@@ -0,0 +1,79 @@
|
||||
name: Post-Merge Beta Cherry-Pick
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
cherry-pick-to-latest-release:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Resolve merged PR and checkbox state
|
||||
id: gate
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
# For the commit that triggered this workflow (HEAD on main), fetch all
|
||||
# associated PRs and keep only the PR that was actually merged into main
|
||||
# with this exact merge commit SHA.
|
||||
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
|
||||
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
|
||||
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
|
||||
|
||||
if [ "${match_count}" -gt 1 ]; then
|
||||
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
|
||||
fi
|
||||
|
||||
if [ -z "$pr_number" ]; then
|
||||
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Read the PR body and check whether the helper checkbox is checked.
|
||||
pr_body="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}" --jq '.body // ""')"
|
||||
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
|
||||
|
||||
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
|
||||
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox checked for PR #${pr_number}."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
|
||||
|
||||
- name: Checkout repository
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: true
|
||||
ref: main
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Configure git identity
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Create cherry-pick PR to latest release
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify
|
||||
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
@@ -1,28 +0,0 @@
|
||||
name: Require beta cherry-pick consideration
|
||||
concurrency:
|
||||
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, reopened, synchronize]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
beta-cherrypick-check:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Check PR body for beta cherry-pick consideration
|
||||
env:
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
run: |
|
||||
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
|
||||
echo "Cherry-pick consideration box is checked. Check passed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
|
||||
exit 1
|
||||
3
.github/workflows/pr-helm-chart-testing.yml
vendored
3
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -41,8 +41,7 @@ jobs:
|
||||
version: v3.19.0
|
||||
|
||||
- name: Set up chart-testing
|
||||
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
|
||||
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
|
||||
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
|
||||
with:
|
||||
uv_version: "0.9.9"
|
||||
|
||||
|
||||
137
.github/workflows/pr-integration-tests.yml
vendored
137
.github/workflows/pr-integration-tests.yml
vendored
@@ -56,7 +56,7 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" ! -name "no_vectordb" -exec basename {} \; | sort)
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
@@ -448,139 +448,6 @@ jobs:
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
no-vectordb-tests:
|
||||
needs: [build-backend-image, build-integration-image]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=4cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-no-vectordb-tests",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create .env file for no-vectordb Docker Compose
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
DISABLE_VECTOR_DB=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
USE_LIGHTWEIGHT_BACKGROUND_WORKER=true
|
||||
EOF
|
||||
|
||||
# Start only the services needed for no-vectordb mode (no Vespa, no model servers)
|
||||
- name: Start Docker containers (no-vectordb)
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml up \
|
||||
relational_db \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker_no_vectordb
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script (no-vectordb)..."
|
||||
start_time=$(date +%s)
|
||||
timeout=300
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in $timeout seconds."
|
||||
exit 1
|
||||
fi
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "API server is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error; retrying..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
|
||||
- name: Run No-VectorDB Integration Tests
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running no-vectordb integration tests..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
|
||||
/app/tests/integration/tests/no_vectordb
|
||||
|
||||
- name: Dump API server logs (no-vectordb)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
|
||||
logs --no-color api_server > $GITHUB_WORKSPACE/api_server_no_vectordb.log || true
|
||||
|
||||
- name: Dump all-container logs (no-vectordb)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
|
||||
logs --no-color > $GITHUB_WORKSPACE/docker-compose-no-vectordb.log || true
|
||||
|
||||
- name: Upload logs (no-vectordb)
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-no-vectordb
|
||||
path: ${{ github.workspace }}/docker-compose-no-vectordb.log
|
||||
|
||||
- name: Stop Docker containers (no-vectordb)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml down -v
|
||||
|
||||
multitenant-tests:
|
||||
needs:
|
||||
[build-backend-image, build-model-server-image, build-integration-image]
|
||||
@@ -720,7 +587,7 @@ jobs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
needs: [integration-tests, no-vectordb-tests, multitenant-tests]
|
||||
needs: [integration-tests, multitenant-tests]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- name: Check job status
|
||||
|
||||
272
.github/workflows/pr-playwright-tests.yml
vendored
272
.github/workflows/pr-playwright-tests.yml
vendored
@@ -22,6 +22,9 @@ env:
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }}
|
||||
GOOGLE_PSE_API_KEY: ${{ secrets.GOOGLE_PSE_API_KEY }}
|
||||
GOOGLE_PSE_SEARCH_ENGINE_ID: ${{ secrets.GOOGLE_PSE_SEARCH_ENGINE_ID }}
|
||||
|
||||
# for federated slack tests
|
||||
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
|
||||
@@ -52,9 +55,6 @@ env:
|
||||
MCP_SERVER_PUBLIC_HOST: host.docker.internal
|
||||
MCP_SERVER_PUBLIC_URL: http://host.docker.internal:8004/mcp
|
||||
|
||||
# Visual regression S3 bucket (shared across all jobs)
|
||||
PLAYWRIGHT_S3_BUCKET: onyx-playwright-artifacts
|
||||
|
||||
jobs:
|
||||
build-web-image:
|
||||
runs-on:
|
||||
@@ -242,9 +242,6 @@ jobs:
|
||||
playwright-tests:
|
||||
needs: [build-web-image, build-backend-image, build-model-server-image]
|
||||
name: Playwright Tests (${{ matrix.project }})
|
||||
permissions:
|
||||
id-token: write # Required for OIDC-based AWS credential exchange (S3 access)
|
||||
contents: read
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=8cpu-linux-arm64
|
||||
@@ -434,6 +431,8 @@ jobs:
|
||||
env:
|
||||
PROJECT: ${{ matrix.project }}
|
||||
run: |
|
||||
# Create test-results directory to ensure it exists for artifact upload
|
||||
mkdir -p test-results
|
||||
npx playwright test --project ${PROJECT}
|
||||
|
||||
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
@@ -441,134 +440,9 @@ jobs:
|
||||
with:
|
||||
# Includes test results and trace.zip files
|
||||
name: playwright-test-results-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ./web/output/playwright/
|
||||
path: ./web/test-results/
|
||||
retention-days: 30
|
||||
|
||||
- name: Upload screenshots
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-screenshots-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ./web/output/screenshots/
|
||||
retention-days: 30
|
||||
|
||||
# --- Visual Regression Diff ---
|
||||
- name: Configure AWS credentials
|
||||
if: always()
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: always()
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Determine baseline revision
|
||||
if: always()
|
||||
id: baseline-rev
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
BASE_REF: ${{ github.event.pull_request.base.ref }}
|
||||
MERGE_GROUP_BASE_REF: ${{ github.event.merge_group.base_ref }}
|
||||
GH_REF: ${{ github.ref }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ "${EVENT_NAME}" = "pull_request" ]; then
|
||||
# PRs compare against the base branch (e.g. main, release/2.5)
|
||||
echo "rev=${BASE_REF}" >> "$GITHUB_OUTPUT"
|
||||
elif [ "${EVENT_NAME}" = "merge_group" ]; then
|
||||
# Merge queue compares against the target branch (e.g. refs/heads/main -> main)
|
||||
echo "rev=${MERGE_GROUP_BASE_REF#refs/heads/}" >> "$GITHUB_OUTPUT"
|
||||
elif [[ "${GH_REF}" == refs/tags/* ]]; then
|
||||
# Tag builds compare against the tag name
|
||||
echo "rev=${REF_NAME}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
# Push builds (main, release/*) compare against the branch name
|
||||
echo "rev=${REF_NAME}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Generate screenshot diff report
|
||||
if: always()
|
||||
env:
|
||||
PROJECT: ${{ matrix.project }}
|
||||
PLAYWRIGHT_S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }}
|
||||
BASELINE_REV: ${{ steps.baseline-rev.outputs.rev }}
|
||||
run: |
|
||||
uv run --no-sync --with onyx-devtools ods screenshot-diff compare \
|
||||
--project "${PROJECT}" \
|
||||
--rev "${BASELINE_REV}"
|
||||
|
||||
- name: Upload visual diff report to S3
|
||||
if: always()
|
||||
env:
|
||||
PROJECT: ${{ matrix.project }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
SUMMARY_FILE="web/output/screenshot-diff/${PROJECT}/summary.json"
|
||||
if [ ! -f "${SUMMARY_FILE}" ]; then
|
||||
echo "No summary file found — skipping S3 upload."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
HAS_DIFF=$(jq -r '.has_differences' "${SUMMARY_FILE}")
|
||||
if [ "${HAS_DIFF}" != "true" ]; then
|
||||
echo "No visual differences for ${PROJECT} — skipping S3 upload."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
aws s3 sync "web/output/screenshot-diff/${PROJECT}/" \
|
||||
"s3://${PLAYWRIGHT_S3_BUCKET}/reports/pr-${PR_NUMBER}/${RUN_ID}/${PROJECT}/"
|
||||
|
||||
- name: Upload visual diff summary
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
if: always()
|
||||
with:
|
||||
name: screenshot-diff-summary-${{ matrix.project }}
|
||||
path: ./web/output/screenshot-diff/${{ matrix.project }}/summary.json
|
||||
if-no-files-found: ignore
|
||||
retention-days: 5
|
||||
|
||||
- name: Upload visual diff report artifact
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
if: always()
|
||||
with:
|
||||
name: screenshot-diff-report-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ./web/output/screenshot-diff/${{ matrix.project }}/
|
||||
if-no-files-found: ignore
|
||||
retention-days: 30
|
||||
|
||||
- name: Update S3 baselines
|
||||
if: >-
|
||||
success() && (
|
||||
github.ref == 'refs/heads/main' ||
|
||||
startsWith(github.ref, 'refs/heads/release/') ||
|
||||
startsWith(github.ref, 'refs/tags/v') ||
|
||||
(
|
||||
github.event_name == 'merge_group' && (
|
||||
github.event.merge_group.base_ref == 'refs/heads/main' ||
|
||||
startsWith(github.event.merge_group.base_ref, 'refs/heads/release/')
|
||||
)
|
||||
)
|
||||
)
|
||||
env:
|
||||
PROJECT: ${{ matrix.project }}
|
||||
PLAYWRIGHT_S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }}
|
||||
BASELINE_REV: ${{ steps.baseline-rev.outputs.rev }}
|
||||
run: |
|
||||
if [ -d "web/output/screenshots/" ] && [ "$(ls -A web/output/screenshots/)" ]; then
|
||||
uv run --no-sync --with onyx-devtools ods screenshot-diff upload-baselines \
|
||||
--project "${PROJECT}" \
|
||||
--rev "${BASELINE_REV}" \
|
||||
--delete
|
||||
else
|
||||
echo "No screenshots to upload for ${PROJECT} — skipping baseline update."
|
||||
fi
|
||||
|
||||
# save before stopping the containers so the logs can be captured
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
@@ -586,95 +460,6 @@ jobs:
|
||||
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
# Post a single combined visual regression comment after all matrix jobs finish
|
||||
visual-regression-comment:
|
||||
needs: [playwright-tests]
|
||||
if: always() && github.event_name == 'pull_request'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
permissions:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # ratchet:actions/download-artifact@v4
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
- name: Post combined PR comment
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
REPO: ${{ github.repository }}
|
||||
S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }}
|
||||
run: |
|
||||
MARKER="<!-- visual-regression-report -->"
|
||||
|
||||
# Build the markdown table from all summary files
|
||||
TABLE_HEADER="| Project | Changed | Added | Removed | Unchanged | Report |"
|
||||
TABLE_DIVIDER="|---------|---------|-------|---------|-----------|--------|"
|
||||
TABLE_ROWS=""
|
||||
HAS_ANY_SUMMARY=false
|
||||
|
||||
for SUMMARY_DIR in summaries/screenshot-diff-summary-*/; do
|
||||
SUMMARY_FILE="${SUMMARY_DIR}summary.json"
|
||||
if [ ! -f "${SUMMARY_FILE}" ]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
HAS_ANY_SUMMARY=true
|
||||
PROJECT=$(jq -r '.project' "${SUMMARY_FILE}")
|
||||
CHANGED=$(jq -r '.changed' "${SUMMARY_FILE}")
|
||||
ADDED=$(jq -r '.added' "${SUMMARY_FILE}")
|
||||
REMOVED=$(jq -r '.removed' "${SUMMARY_FILE}")
|
||||
UNCHANGED=$(jq -r '.unchanged' "${SUMMARY_FILE}")
|
||||
TOTAL=$(jq -r '.total' "${SUMMARY_FILE}")
|
||||
HAS_DIFF=$(jq -r '.has_differences' "${SUMMARY_FILE}")
|
||||
|
||||
if [ "${TOTAL}" = "0" ]; then
|
||||
REPORT_LINK="_No screenshots_"
|
||||
elif [ "${HAS_DIFF}" = "true" ]; then
|
||||
REPORT_URL="https://${S3_BUCKET}.s3.us-east-2.amazonaws.com/reports/pr-${PR_NUMBER}/${RUN_ID}/${PROJECT}/index.html"
|
||||
REPORT_LINK="[View Report](${REPORT_URL})"
|
||||
else
|
||||
REPORT_LINK="✅ No changes"
|
||||
fi
|
||||
|
||||
TABLE_ROWS="${TABLE_ROWS}| \`${PROJECT}\` | ${CHANGED} | ${ADDED} | ${REMOVED} | ${UNCHANGED} | ${REPORT_LINK} |\n"
|
||||
done
|
||||
|
||||
if [ "${HAS_ANY_SUMMARY}" = "false" ]; then
|
||||
echo "No visual diff summaries found — skipping PR comment."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
BODY=$(printf '%s\n' \
|
||||
"${MARKER}" \
|
||||
"### 🖼️ Visual Regression Report" \
|
||||
"" \
|
||||
"${TABLE_HEADER}" \
|
||||
"${TABLE_DIVIDER}" \
|
||||
"$(printf '%b' "${TABLE_ROWS}")")
|
||||
|
||||
# Upsert: find existing comment with the marker, or create a new one
|
||||
EXISTING_COMMENT_ID=$(gh api \
|
||||
"repos/${REPO}/issues/${PR_NUMBER}/comments" \
|
||||
--jq ".[] | select(.body | startswith(\"${MARKER}\")) | .id" \
|
||||
2>/dev/null | head -1)
|
||||
|
||||
if [ -n "${EXISTING_COMMENT_ID}" ]; then
|
||||
gh api \
|
||||
--method PATCH \
|
||||
"repos/${REPO}/issues/comments/${EXISTING_COMMENT_ID}" \
|
||||
-f body="${BODY}"
|
||||
else
|
||||
gh api \
|
||||
--method POST \
|
||||
"repos/${REPO}/issues/${PR_NUMBER}/comments" \
|
||||
-f body="${BODY}"
|
||||
fi
|
||||
|
||||
playwright-required:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
@@ -685,3 +470,48 @@ jobs:
|
||||
- name: Check job status
|
||||
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
|
||||
run: exit 1
|
||||
|
||||
# NOTE: Chromatic UI diff testing is currently disabled.
|
||||
# We are using Playwright for local and CI testing without visual regression checks.
|
||||
# Chromatic may be reintroduced in the future for UI diff testing if needed.
|
||||
|
||||
# chromatic-tests:
|
||||
# name: Chromatic Tests
|
||||
|
||||
# needs: playwright-tests
|
||||
# runs-on:
|
||||
# [
|
||||
# runs-on,
|
||||
# runner=32cpu-linux-x64,
|
||||
# disk=large,
|
||||
# "run-id=${{ github.run_id }}",
|
||||
# ]
|
||||
# steps:
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
# with:
|
||||
# fetch-depth: 0
|
||||
|
||||
# - name: Setup node
|
||||
# uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
# with:
|
||||
# node-version: 22
|
||||
|
||||
# - name: Install node dependencies
|
||||
# working-directory: ./web
|
||||
# run: npm ci
|
||||
|
||||
# - name: Download Playwright test results
|
||||
# uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # ratchet:actions/download-artifact@v4
|
||||
# with:
|
||||
# name: test-results
|
||||
# path: ./web/test-results
|
||||
|
||||
# - name: Run Chromatic
|
||||
# uses: chromaui/action@latest
|
||||
# with:
|
||||
# playwright: true
|
||||
# projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
|
||||
# workingDir: ./web
|
||||
# env:
|
||||
# CHROMATIC_ARCHIVE_LOCATION: ./test-results
|
||||
|
||||
73
.github/workflows/preview.yml
vendored
73
.github/workflows/preview.yml
vendored
@@ -1,73 +0,0 @@
|
||||
name: Preview Deployment
|
||||
env:
|
||||
VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }}
|
||||
VERCEL_PROJECT_ID: ${{ secrets.VERCEL_PROJECT_ID }}
|
||||
VERCEL_CLI: vercel@50.14.1
|
||||
on:
|
||||
push:
|
||||
branches-ignore:
|
||||
- main
|
||||
paths:
|
||||
- "web/**"
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
jobs:
|
||||
Deploy-Preview:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Pull Vercel Environment Information
|
||||
run: npx --yes ${{ env.VERCEL_CLI }} pull --yes --environment=preview --token=${{ secrets.VERCEL_TOKEN }}
|
||||
|
||||
- name: Build Project Artifacts
|
||||
run: npx --yes ${{ env.VERCEL_CLI }} build --token=${{ secrets.VERCEL_TOKEN }}
|
||||
|
||||
- name: Deploy Project Artifacts to Vercel
|
||||
id: deploy
|
||||
run: |
|
||||
DEPLOYMENT_URL=$(npx --yes ${{ env.VERCEL_CLI }} deploy --prebuilt --token=${{ secrets.VERCEL_TOKEN }})
|
||||
echo "url=$DEPLOYMENT_URL" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Update PR comment with deployment URL
|
||||
if: always() && steps.deploy.outputs.url
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
DEPLOYMENT_URL: ${{ steps.deploy.outputs.url }}
|
||||
run: |
|
||||
# Find the PR for this branch
|
||||
PR_NUMBER=$(gh pr list --head "$GITHUB_REF_NAME" --json number --jq '.[0].number')
|
||||
if [ -z "$PR_NUMBER" ]; then
|
||||
echo "No open PR found for branch $GITHUB_REF_NAME, skipping comment."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
COMMENT_MARKER="<!-- preview-deployment -->"
|
||||
COMMENT_BODY="$COMMENT_MARKER
|
||||
**Preview Deployment**
|
||||
|
||||
| Status | Preview | Commit | Updated |
|
||||
| --- | --- | --- | --- |
|
||||
| ✅ | $DEPLOYMENT_URL | \`${GITHUB_SHA::7}\` | $(date -u '+%Y-%m-%d %H:%M:%S UTC') |"
|
||||
|
||||
# Find existing comment by marker
|
||||
EXISTING_COMMENT_ID=$(gh api "repos/$GITHUB_REPOSITORY/issues/$PR_NUMBER/comments" \
|
||||
--jq ".[] | select(.body | startswith(\"$COMMENT_MARKER\")) | .id" | head -1)
|
||||
|
||||
if [ -n "$EXISTING_COMMENT_ID" ]; then
|
||||
gh api "repos/$GITHUB_REPOSITORY/issues/comments/$EXISTING_COMMENT_ID" \
|
||||
--method PATCH --field body="$COMMENT_BODY"
|
||||
else
|
||||
gh pr comment "$PR_NUMBER" --body "$COMMENT_BODY"
|
||||
fi
|
||||
290
.github/workflows/sandbox-deployment.yml
vendored
290
.github/workflows/sandbox-deployment.yml
vendored
@@ -1,290 +0,0 @@
|
||||
name: Build and Push Sandbox Image on Tag
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "experimental-cc4a.*"
|
||||
|
||||
# Restrictive defaults; jobs declare what they need.
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
check-sandbox-changes:
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
permissions:
|
||||
contents: read
|
||||
outputs:
|
||||
sandbox-changed: ${{ steps.check.outputs.sandbox-changed }}
|
||||
new-version: ${{ steps.version.outputs.new-version }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Check for sandbox-relevant file changes
|
||||
id: check
|
||||
run: |
|
||||
# Get the previous tag to diff against
|
||||
CURRENT_TAG="${GITHUB_REF_NAME}"
|
||||
PREVIOUS_TAG=$(git tag --sort=-creatordate | grep '^experimental-cc4a\.' | grep -v "^${CURRENT_TAG}$" | head -n 1)
|
||||
|
||||
if [ -z "$PREVIOUS_TAG" ]; then
|
||||
echo "No previous experimental-cc4a tag found, building unconditionally"
|
||||
echo "sandbox-changed=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Comparing ${PREVIOUS_TAG}..${CURRENT_TAG}"
|
||||
|
||||
# Check if any sandbox-relevant files changed
|
||||
SANDBOX_PATHS=(
|
||||
"backend/onyx/server/features/build/sandbox/"
|
||||
)
|
||||
|
||||
CHANGED=false
|
||||
for path in "${SANDBOX_PATHS[@]}"; do
|
||||
if git diff --name-only "${PREVIOUS_TAG}..${CURRENT_TAG}" -- "$path" | grep -q .; then
|
||||
echo "Changes detected in: $path"
|
||||
CHANGED=true
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
echo "sandbox-changed=$CHANGED" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Determine new sandbox version
|
||||
id: version
|
||||
if: steps.check.outputs.sandbox-changed == 'true'
|
||||
run: |
|
||||
# Query Docker Hub for the latest versioned tag
|
||||
LATEST_TAG=$(curl -s "https://hub.docker.com/v2/repositories/onyxdotapp/sandbox/tags?page_size=100" \
|
||||
| jq -r '.results[].name' \
|
||||
| grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' \
|
||||
| sort -V \
|
||||
| tail -n 1)
|
||||
|
||||
if [ -z "$LATEST_TAG" ]; then
|
||||
echo "No existing version tags found on Docker Hub, starting at 0.1.1"
|
||||
NEW_VERSION="0.1.1"
|
||||
else
|
||||
CURRENT_VERSION="${LATEST_TAG#v}"
|
||||
echo "Latest version on Docker Hub: $CURRENT_VERSION"
|
||||
|
||||
# Increment patch version
|
||||
MAJOR=$(echo "$CURRENT_VERSION" | cut -d. -f1)
|
||||
MINOR=$(echo "$CURRENT_VERSION" | cut -d. -f2)
|
||||
PATCH=$(echo "$CURRENT_VERSION" | cut -d. -f3)
|
||||
NEW_PATCH=$((PATCH + 1))
|
||||
NEW_VERSION="${MAJOR}.${MINOR}.${NEW_PATCH}"
|
||||
fi
|
||||
|
||||
echo "New version: $NEW_VERSION"
|
||||
echo "new-version=$NEW_VERSION" >> "$GITHUB_OUTPUT"
|
||||
|
||||
build-sandbox-amd64:
|
||||
needs: check-sandbox-changes
|
||||
if: needs.check-sandbox-changes.outputs.sandbox-changed == 'true'
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-sandbox-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/sandbox
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend/onyx/server/features/build/sandbox/kubernetes/docker
|
||||
file: ./backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile
|
||||
platforms: linux/amd64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
build-sandbox-arm64:
|
||||
needs: check-sandbox-changes
|
||||
if: needs.check-sandbox-changes.outputs.sandbox-changed == 'true'
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- run-id=${{ github.run_id }}-sandbox-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/sandbox
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend/onyx/server/features/build/sandbox/kubernetes/docker
|
||||
file: ./backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile
|
||||
platforms: linux/arm64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
merge-sandbox:
|
||||
needs:
|
||||
- check-sandbox-changes
|
||||
- build-sandbox-amd64
|
||||
- build-sandbox-arm64
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-merge-sandbox
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 30
|
||||
environment: release
|
||||
permissions:
|
||||
id-token: write
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/sandbox
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=v${{ needs.check-sandbox-changes.outputs.new-version }}
|
||||
type=raw,value=latest
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
IMAGE_REPO: ${{ env.REGISTRY_IMAGE }}
|
||||
AMD64_DIGEST: ${{ needs.build-sandbox-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.build-sandbox-arm64.outputs.digest }}
|
||||
META_TAGS: ${{ steps.meta.outputs.tags }}
|
||||
run: |
|
||||
IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
|
||||
docker buildx imagetools create \
|
||||
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
|
||||
$IMAGES
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -6,7 +6,6 @@
|
||||
!/.vscode/tasks.template.jsonc
|
||||
.zed
|
||||
.cursor
|
||||
!/.cursor/mcp.json
|
||||
|
||||
# macos
|
||||
.DS_store
|
||||
|
||||
4
.vscode/launch.json
vendored
4
.vscode/launch.json
vendored
@@ -246,7 +246,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,index_attempt_cleanup,opensearch_migration"
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,index_attempt_cleanup"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
@@ -275,7 +275,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=background@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,user_files_indexing,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,user_files_indexing,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
|
||||
@@ -144,10 +144,6 @@ function.
|
||||
If you make any updates to a celery worker and you want to test these changes, you will need
|
||||
to ask me to restart the celery worker. There is no auto-restart on code-change mechanism.
|
||||
|
||||
**Task Time Limits**:
|
||||
Since all tasks are executed in thread pools, the time limit features of Celery are silently
|
||||
disabled and won't work. Timeout logic must be implemented within the task itself.
|
||||
|
||||
### Code Quality
|
||||
|
||||
```bash
|
||||
|
||||
5
LICENSE
5
LICENSE
@@ -2,7 +2,10 @@ Copyright (c) 2023-present DanswerAI, Inc.
|
||||
|
||||
Portions of this software are licensed as follows:
|
||||
|
||||
- All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
|
||||
- All content that resides under "ee" directories of this repository is licensed under the Onyx Enterprise License. Each ee directory contains an identical copy of this license at its root:
|
||||
- backend/ee/LICENSE
|
||||
- web/src/app/ee/LICENSE
|
||||
- web/src/ee/LICENSE
|
||||
- All third party components incorporated into the Onyx Software are licensed under the original license provided by the owner of the applicable component.
|
||||
- Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.
|
||||
|
||||
|
||||
@@ -474,7 +474,7 @@ def run_migrations_online() -> None:
|
||||
|
||||
if connectable is not None:
|
||||
# pytest-alembic is providing an engine - use it directly
|
||||
logger.debug("run_migrations_online starting (pytest-alembic mode).")
|
||||
logger.info("run_migrations_online starting (pytest-alembic mode).")
|
||||
|
||||
# For pytest-alembic, we use the default schema (public)
|
||||
schema_name = context.config.attributes.get(
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
"""add default_app_mode to user
|
||||
|
||||
Revision ID: 114a638452db
|
||||
Revises: feead2911109
|
||||
Create Date: 2026-02-09 18:57:08.274640
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "114a638452db"
|
||||
down_revision = "feead2911109"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"default_app_mode",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
server_default="CHAT",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "default_app_mode")
|
||||
@@ -11,6 +11,7 @@ import sqlalchemy as sa
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
from httpx import HTTPStatusError
|
||||
import httpx
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.db.search_settings import SearchSettings
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
@@ -518,11 +519,15 @@ def delete_document_from_db(current_doc_id: str, index_name: str) -> None:
|
||||
def upgrade() -> None:
|
||||
if SKIP_CANON_DRIVE_IDS:
|
||||
return
|
||||
current_search_settings, _ = active_search_settings()
|
||||
current_search_settings, future_search_settings = active_search_settings()
|
||||
document_index = get_default_document_index(
|
||||
current_search_settings,
|
||||
future_search_settings,
|
||||
)
|
||||
|
||||
# Get the index name
|
||||
if hasattr(current_search_settings, "index_name"):
|
||||
index_name = current_search_settings.index_name
|
||||
if hasattr(document_index, "index_name"):
|
||||
index_name = document_index.index_name
|
||||
else:
|
||||
# Default index name if we can't get it from the document_index
|
||||
index_name = "danswer_index"
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
"""add_user_preferences
|
||||
|
||||
Revision ID: 175ea04c7087
|
||||
Revises: d56ffa94ca32
|
||||
Create Date: 2026-02-04 18:16:24.830873
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "175ea04c7087"
|
||||
down_revision = "d56ffa94ca32"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("user_preferences", sa.Text(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "user_preferences")
|
||||
@@ -16,6 +16,7 @@ from typing import Generator
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.db.search_settings import SearchSettings
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
@@ -125,11 +126,14 @@ def remove_old_tags() -> None:
|
||||
the document got reindexed, the old tag would not be removed.
|
||||
This function removes those old tags by comparing it against the tags in vespa.
|
||||
"""
|
||||
current_search_settings, _ = active_search_settings()
|
||||
current_search_settings, future_search_settings = active_search_settings()
|
||||
document_index = get_default_document_index(
|
||||
current_search_settings, future_search_settings
|
||||
)
|
||||
|
||||
# Get the index name
|
||||
if hasattr(current_search_settings, "index_name"):
|
||||
index_name = current_search_settings.index_name
|
||||
if hasattr(document_index, "index_name"):
|
||||
index_name = document_index.index_name
|
||||
else:
|
||||
# Default index name if we can't get it from the document_index
|
||||
index_name = "danswer_index"
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
"""add chunk error and vespa count columns to opensearch tenant migration
|
||||
|
||||
Revision ID: 93c15d6a6fbb
|
||||
Revises: d3fd499c829c
|
||||
Create Date: 2026-02-11 23:07:34.576725
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "93c15d6a6fbb"
|
||||
down_revision = "d3fd499c829c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column(
|
||||
"total_chunks_errored",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default="0",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column(
|
||||
"total_chunks_in_vespa",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default="0",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("opensearch_tenant_migration_record", "total_chunks_in_vespa")
|
||||
op.drop_column("opensearch_tenant_migration_record", "total_chunks_errored")
|
||||
@@ -1,81 +0,0 @@
|
||||
"""seed_memory_tool and add enable_memory_tool to user
|
||||
|
||||
Revision ID: b51c6844d1df
|
||||
Revises: 93c15d6a6fbb
|
||||
Create Date: 2026-02-11 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b51c6844d1df"
|
||||
down_revision = "93c15d6a6fbb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
MEMORY_TOOL = {
|
||||
"name": "MemoryTool",
|
||||
"display_name": "Add Memory",
|
||||
"description": "Save memories about the user for future conversations.",
|
||||
"in_code_tool_id": "MemoryTool",
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
existing = conn.execute(
|
||||
sa.text(
|
||||
"SELECT in_code_tool_id FROM tool WHERE in_code_tool_id = :in_code_tool_id"
|
||||
),
|
||||
{"in_code_tool_id": MEMORY_TOOL["in_code_tool_id"]},
|
||||
).fetchone()
|
||||
|
||||
if existing:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
MEMORY_TOOL,
|
||||
)
|
||||
else:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
|
||||
"""
|
||||
),
|
||||
MEMORY_TOOL,
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"enable_memory_tool",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.true(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "enable_memory_tool")
|
||||
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text("DELETE FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
|
||||
{"in_code_tool_id": MEMORY_TOOL["in_code_tool_id"]},
|
||||
)
|
||||
@@ -1,102 +0,0 @@
|
||||
"""add_file_reader_tool
|
||||
|
||||
Revision ID: d3fd499c829c
|
||||
Revises: 114a638452db
|
||||
Create Date: 2026-02-07 19:28:22.452337
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d3fd499c829c"
|
||||
down_revision = "114a638452db"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
FILE_READER_TOOL = {
|
||||
"name": "read_file",
|
||||
"display_name": "File Reader",
|
||||
"description": (
|
||||
"Read sections of user-uploaded files by character offset. "
|
||||
"Useful for inspecting large files that cannot fit entirely in context."
|
||||
),
|
||||
"in_code_tool_id": "FileReaderTool",
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Check if tool already exists
|
||||
existing = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
|
||||
{"in_code_tool_id": FILE_READER_TOOL["in_code_tool_id"]},
|
||||
).fetchone()
|
||||
|
||||
if existing:
|
||||
# Update existing tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
FILE_READER_TOOL,
|
||||
)
|
||||
tool_id = existing[0]
|
||||
else:
|
||||
# Insert new tool
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
|
||||
RETURNING id
|
||||
"""
|
||||
),
|
||||
FILE_READER_TOOL,
|
||||
)
|
||||
tool_id = result.scalar_one()
|
||||
|
||||
# Attach to the default persona (id=0) if not already attached
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": tool_id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
in_code_tool_id = FILE_READER_TOOL["in_code_tool_id"]
|
||||
|
||||
# Remove persona associations first (FK constraint)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM persona__tool
|
||||
WHERE tool_id IN (
|
||||
SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id
|
||||
)
|
||||
"""
|
||||
),
|
||||
{"in_code_tool_id": in_code_tool_id},
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
sa.text("DELETE FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
|
||||
{"in_code_tool_id": in_code_tool_id},
|
||||
)
|
||||
@@ -1,69 +0,0 @@
|
||||
"""add_opensearch_tenant_migration_columns
|
||||
|
||||
Revision ID: feead2911109
|
||||
Revises: d56ffa94ca32
|
||||
Create Date: 2026-02-10 17:46:34.029937
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "feead2911109"
|
||||
down_revision = "175ea04c7087"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column("vespa_visit_continuation_token", sa.Text(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column(
|
||||
"total_chunks_migrated",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default="0",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column(
|
||||
"migration_completed_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column(
|
||||
"enable_opensearch_retrieval",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("opensearch_tenant_migration_record", "enable_opensearch_retrieval")
|
||||
op.drop_column("opensearch_tenant_migration_record", "migration_completed_at")
|
||||
op.drop_column("opensearch_tenant_migration_record", "created_at")
|
||||
op.drop_column("opensearch_tenant_migration_record", "total_chunks_migrated")
|
||||
op.drop_column(
|
||||
"opensearch_tenant_migration_record", "vespa_visit_continuation_token"
|
||||
)
|
||||
@@ -1,20 +1,20 @@
|
||||
The DanswerAI Enterprise license (the “Enterprise License”)
|
||||
The Onyx Enterprise License (the "Enterprise License")
|
||||
Copyright (c) 2023-present DanswerAI, Inc.
|
||||
|
||||
With regard to the Onyx Software:
|
||||
|
||||
This software and associated documentation files (the "Software") may only be
|
||||
used in production, if you (and any entity that you represent) have agreed to,
|
||||
and are in compliance with, the DanswerAI Subscription Terms of Service, available
|
||||
at https://onyx.app/terms (the “Enterprise Terms”), or other
|
||||
and are in compliance with, the Onyx Subscription Terms of Service, available
|
||||
at https://www.onyx.app/legal/self-host (the "Enterprise Terms"), or other
|
||||
agreement governing the use of the Software, as agreed by you and DanswerAI,
|
||||
and otherwise have a valid Onyx Enterprise license for the
|
||||
and otherwise have a valid Onyx Enterprise License for the
|
||||
correct number of user seats. Subject to the foregoing sentence, you are free to
|
||||
modify this Software and publish patches to the Software. You agree that DanswerAI
|
||||
and/or its licensors (as applicable) retain all right, title and interest in and
|
||||
to all such modifications and/or patches, and all such modifications and/or
|
||||
patches may only be used, copied, modified, displayed, distributed, or otherwise
|
||||
exploited with a valid Onyx Enterprise license for the correct
|
||||
exploited with a valid Onyx Enterprise License for the correct
|
||||
number of user seats. Notwithstanding the foregoing, you may copy and modify
|
||||
the Software for development and testing purposes, without requiring a
|
||||
subscription. You agree that DanswerAI and/or its licensors (as applicable) retain
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.background import celery_app
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.heavy import celery_app
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.light import celery_app
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
]
|
||||
)
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.monitoring import celery_app
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
)
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.primary import celery_app
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cloud",
|
||||
"ee.onyx.background.celery.tasks.ttl_management",
|
||||
"ee.onyx.background.celery.tasks.usage_reporting",
|
||||
]
|
||||
)
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cloud",
|
||||
"ee.onyx.background.celery.tasks.ttl_management",
|
||||
"ee.onyx.background.celery.tasks.usage_reporting",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -263,9 +263,15 @@ def refresh_license_cache(
|
||||
|
||||
try:
|
||||
payload = verify_license_signature(license_record.license_data)
|
||||
# Derive source from payload: manual licenses lack stripe_customer_id
|
||||
source: LicenseSource = (
|
||||
LicenseSource.AUTO_FETCH
|
||||
if payload.stripe_customer_id
|
||||
else LicenseSource.MANUAL_UPLOAD
|
||||
)
|
||||
return update_license_cache(
|
||||
payload,
|
||||
source=LicenseSource.AUTO_FETCH,
|
||||
source=source,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -50,12 +50,7 @@ def github_doc_sync(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
github_connector.load_credentials(credential_json)
|
||||
github_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
logger.info("GitHub connector credentials loaded successfully")
|
||||
|
||||
if not github_connector.github_client:
|
||||
@@ -65,7 +60,21 @@ def github_doc_sync(
|
||||
# Get all repositories from GitHub API
|
||||
logger.info("Fetching all repositories from GitHub API")
|
||||
try:
|
||||
repos = github_connector.fetch_configured_repos()
|
||||
repos = []
|
||||
if github_connector.repositories:
|
||||
if "," in github_connector.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = github_connector.get_github_repos(
|
||||
github_connector.github_client
|
||||
)
|
||||
else:
|
||||
# Single repository
|
||||
repos = [
|
||||
github_connector.get_github_repo(github_connector.github_client)
|
||||
]
|
||||
else:
|
||||
# All repositories
|
||||
repos = github_connector.get_all_repos(github_connector.github_client)
|
||||
|
||||
logger.info(f"Found {len(repos)} repositories to check")
|
||||
except Exception as e:
|
||||
|
||||
@@ -18,12 +18,7 @@ def github_group_sync(
|
||||
github_connector: GithubConnector = GithubConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
github_connector.load_credentials(credential_json)
|
||||
github_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
if not github_connector.github_client:
|
||||
raise ValueError("github_client is required")
|
||||
|
||||
|
||||
@@ -50,12 +50,7 @@ def gmail_doc_sync(
|
||||
already populated.
|
||||
"""
|
||||
gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
gmail_connector.load_credentials(credential_json)
|
||||
gmail_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
slim_doc_generator = _get_slim_doc_generator(
|
||||
cc_pair, gmail_connector, callback=callback
|
||||
|
||||
@@ -295,12 +295,7 @@ def gdrive_doc_sync(
|
||||
google_drive_connector = GoogleDriveConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
google_drive_connector.load_credentials(credential_json)
|
||||
google_drive_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)
|
||||
|
||||
|
||||
@@ -391,12 +391,7 @@ def gdrive_group_sync(
|
||||
google_drive_connector = GoogleDriveConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
google_drive_connector.load_credentials(credential_json)
|
||||
google_drive_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
admin_service = get_admin_service(
|
||||
google_drive_connector.creds, google_drive_connector.primary_admin_email
|
||||
)
|
||||
|
||||
@@ -24,12 +24,7 @@ def jira_doc_sync(
|
||||
jira_connector = JiraConnector(
|
||||
**cc_pair.connector.connector_specific_config,
|
||||
)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
jira_connector.load_credentials(credential_json)
|
||||
jira_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
yield from generic_doc_sync(
|
||||
cc_pair=cc_pair,
|
||||
|
||||
@@ -119,13 +119,8 @@ def jira_group_sync(
|
||||
if not jira_base_url:
|
||||
raise ValueError("No jira_base_url found in connector config")
|
||||
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
jira_client = build_jira_client(
|
||||
credentials=credential_json,
|
||||
credentials=cc_pair.credential.credential_json,
|
||||
jira_base=jira_base_url,
|
||||
scoped_token=scoped_token,
|
||||
)
|
||||
|
||||
@@ -30,11 +30,7 @@ def get_any_salesforce_client_for_doc_id(
|
||||
if _ANY_SALESFORCE_CLIENT is None:
|
||||
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
|
||||
first_cc_pair = cc_pairs[0]
|
||||
credential_json = (
|
||||
first_cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if first_cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
credential_json = first_cc_pair.credential.credential_json
|
||||
_ANY_SALESFORCE_CLIENT = Salesforce(
|
||||
username=credential_json["sf_username"],
|
||||
password=credential_json["sf_password"],
|
||||
@@ -162,11 +158,7 @@ def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Sales
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"CC pair {cc_pair_id} not found")
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
credential_json = cc_pair.credential.credential_json
|
||||
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] = Salesforce(
|
||||
username=credential_json["sf_username"],
|
||||
password=credential_json["sf_password"],
|
||||
|
||||
@@ -24,12 +24,7 @@ def sharepoint_doc_sync(
|
||||
sharepoint_connector = SharepointConnector(
|
||||
**cc_pair.connector.connector_specific_config,
|
||||
)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
sharepoint_connector.load_credentials(credential_json)
|
||||
sharepoint_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
yield from generic_doc_sync(
|
||||
cc_pair=cc_pair,
|
||||
|
||||
@@ -25,12 +25,7 @@ def sharepoint_group_sync(
|
||||
|
||||
# Create SharePoint connector instance and load credentials
|
||||
connector = SharepointConnector(**connector_config)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
connector.load_credentials(credential_json)
|
||||
connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
if not connector.msal_app:
|
||||
raise RuntimeError("MSAL app not initialized in connector")
|
||||
|
||||
@@ -151,14 +151,9 @@ def slack_doc_sync(
|
||||
tenant_id = get_current_tenant_id()
|
||||
provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id)
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
slack_client = SlackConnector.make_slack_web_client(
|
||||
provider.get_provider_key(),
|
||||
credential_json["slack_bot_token"],
|
||||
cc_pair.credential.credential_json["slack_bot_token"],
|
||||
SlackConnector.MAX_RETRIES,
|
||||
r,
|
||||
)
|
||||
|
||||
@@ -63,14 +63,9 @@ def slack_group_sync(
|
||||
|
||||
provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id)
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
slack_client = SlackConnector.make_slack_web_client(
|
||||
provider.get_provider_key(),
|
||||
credential_json["slack_bot_token"],
|
||||
cc_pair.credential.credential_json["slack_bot_token"],
|
||||
SlackConnector.MAX_RETRIES,
|
||||
r,
|
||||
)
|
||||
|
||||
@@ -25,12 +25,7 @@ def teams_doc_sync(
|
||||
teams_connector = TeamsConnector(
|
||||
**cc_pair.connector.connector_specific_config,
|
||||
)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
teams_connector.load_credentials(credential_json)
|
||||
teams_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
yield from generic_doc_sync(
|
||||
cc_pair=cc_pair,
|
||||
|
||||
@@ -77,7 +77,7 @@ def stream_search_query(
|
||||
# Get document index
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
# This flow is for search so we do not get all indices.
|
||||
document_index = get_default_document_index(search_settings, None, db_session)
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
|
||||
# Determine queries to execute
|
||||
original_query = request.search_query
|
||||
|
||||
@@ -270,11 +270,7 @@ def confluence_oauth_accessible_resources(
|
||||
if not credential:
|
||||
raise HTTPException(400, f"Credential {credential_id} not found.")
|
||||
|
||||
credential_dict = (
|
||||
credential.credential_json.get_value(apply_mask=False)
|
||||
if credential.credential_json
|
||||
else {}
|
||||
)
|
||||
credential_dict = credential.credential_json
|
||||
access_token = credential_dict["confluence_access_token"]
|
||||
|
||||
try:
|
||||
@@ -341,12 +337,7 @@ def confluence_oauth_finalize(
|
||||
detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.",
|
||||
)
|
||||
|
||||
existing_credential_json = (
|
||||
credential.credential_json.get_value(apply_mask=False)
|
||||
if credential.credential_json
|
||||
else {}
|
||||
)
|
||||
new_credential_json: dict[str, Any] = dict(existing_credential_json)
|
||||
new_credential_json: dict[str, Any] = dict(credential.credential_json)
|
||||
new_credential_json["cloud_id"] = cloud_id
|
||||
new_credential_json["cloud_name"] = cloud_name
|
||||
new_credential_json["wiki_base"] = cloud_url
|
||||
|
||||
@@ -26,7 +26,6 @@ from onyx.db.models import User
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.server.utils_vector_db import require_vector_db
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -67,11 +66,7 @@ def search_flow_classification(
|
||||
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/send-search-message",
|
||||
response_model=None,
|
||||
dependencies=[Depends(require_vector_db)],
|
||||
)
|
||||
@router.post("/send-search-message", response_model=None)
|
||||
def handle_send_search_message(
|
||||
request: SendSearchQueryRequest,
|
||||
user: User = Depends(current_user),
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
"""EE Settings API - provides license-aware settings override."""
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -41,6 +44,14 @@ def check_ee_features_enabled() -> bool:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if not metadata:
|
||||
# Cache miss — warm from DB so cold-start doesn't block EE features
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
metadata = refresh_license_cache(db_session, tenant_id)
|
||||
except SQLAlchemyError as db_error:
|
||||
logger.warning(f"Failed to load license from DB: {db_error}")
|
||||
|
||||
if metadata and metadata.status != _BLOCKING_STATUS:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
return True
|
||||
@@ -82,6 +93,18 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if not metadata:
|
||||
# Cache miss (e.g. after TTL expiry). Fall back to DB so
|
||||
# the /settings request doesn't falsely return GATED_ACCESS
|
||||
# while the cache is cold.
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
metadata = refresh_license_cache(db_session, tenant_id)
|
||||
except SQLAlchemyError as db_error:
|
||||
logger.warning(
|
||||
f"Failed to load license from DB for settings: {db_error}"
|
||||
)
|
||||
|
||||
if metadata:
|
||||
if metadata.status == _BLOCKING_STATUS:
|
||||
settings.application_status = metadata.status
|
||||
@@ -90,7 +113,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
settings.ee_features_enabled = True
|
||||
else:
|
||||
# No license found.
|
||||
# No license found in cache or DB.
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
# Legacy EE flag is set → prior EE usage (e.g. permission
|
||||
# syncing) means indexed data may need protection.
|
||||
|
||||
@@ -11,7 +11,6 @@ from onyx.db.models import OAuthUserToken
|
||||
from onyx.db.oauth_config import get_user_oauth_token
|
||||
from onyx.db.oauth_config import upsert_user_oauth_token
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -34,10 +33,7 @@ class OAuthTokenManager:
|
||||
if not user_token:
|
||||
return None
|
||||
|
||||
if not user_token.token_data:
|
||||
return None
|
||||
|
||||
token_data = self._unwrap_token_data(user_token.token_data)
|
||||
token_data = user_token.token_data
|
||||
|
||||
# Check if token is expired
|
||||
if OAuthTokenManager.is_token_expired(token_data):
|
||||
@@ -55,10 +51,7 @@ class OAuthTokenManager:
|
||||
|
||||
def refresh_token(self, user_token: OAuthUserToken) -> str:
|
||||
"""Refresh access token using refresh token"""
|
||||
if not user_token.token_data:
|
||||
raise ValueError("No token data available for refresh")
|
||||
|
||||
token_data = self._unwrap_token_data(user_token.token_data)
|
||||
token_data = user_token.token_data
|
||||
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
@@ -160,11 +153,3 @@ class OAuthTokenManager:
|
||||
separator = "&" if "?" in oauth_config.authorization_url else "?"
|
||||
|
||||
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_token_data(
|
||||
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
if isinstance(token_data, SensitiveValue):
|
||||
return token_data.get_value(apply_mask=False)
|
||||
return token_data
|
||||
|
||||
@@ -1459,7 +1459,6 @@ def get_anonymous_user() -> User:
|
||||
is_superuser=False,
|
||||
role=UserRole.LIMITED,
|
||||
use_memories=False,
|
||||
enable_memory_tool=False,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.background.celery.celery_utils import make_probe_path
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
@@ -526,12 +525,6 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None: # noqa: ARG
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Raises WorkerShutdown if the timeout is reached."""
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
logger.info(
|
||||
"DISABLE_VECTOR_DB is set — skipping Vespa/OpenSearch readiness check."
|
||||
)
|
||||
return
|
||||
|
||||
if not wait_for_vespa_with_timeout():
|
||||
msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
|
||||
logger.error(msg)
|
||||
@@ -573,31 +566,3 @@ class LivenessProbe(bootsteps.StartStopStep):
|
||||
|
||||
def get_bootsteps() -> list[type]:
|
||||
return [LivenessProbe]
|
||||
|
||||
|
||||
# Task modules that require a vector DB (Vespa/OpenSearch).
|
||||
# When DISABLE_VECTOR_DB is True these are excluded from autodiscover lists.
|
||||
_VECTOR_DB_TASK_MODULES: set[str] = {
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.opensearch_migration",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
# EE modules that are vector-DB-dependent
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
}
|
||||
# NOTE: "onyx.background.celery.tasks.shared" is intentionally NOT in the set
|
||||
# above. It contains celery_beat_heartbeat (which only writes to Redis) alongside
|
||||
# document cleanup tasks. The cleanup tasks won't be invoked in minimal mode
|
||||
# because the periodic tasks that trigger them are in other filtered modules.
|
||||
|
||||
|
||||
def filter_task_modules(modules: list[str]) -> list[str]:
|
||||
"""Remove vector-DB-dependent task modules when DISABLE_VECTOR_DB is True."""
|
||||
if not DISABLE_VECTOR_DB:
|
||||
return modules
|
||||
return [m for m in modules if m not in _VECTOR_DB_TASK_MODULES]
|
||||
|
||||
@@ -118,25 +118,23 @@ for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
# Original background worker tasks
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
# Light worker tasks
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.opensearch_migration",
|
||||
# Docprocessing worker tasks
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
# Docfetching worker tasks
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
[
|
||||
# Original background worker tasks
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.opensearch_migration",
|
||||
# Light worker tasks
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
# Docprocessing worker tasks
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
# Docfetching worker tasks
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -96,9 +96,7 @@ for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
]
|
||||
)
|
||||
[
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -107,9 +107,7 @@ for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
]
|
||||
)
|
||||
[
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -96,12 +96,10 @@ for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
# Sandbox tasks (file sync, cleanup)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
]
|
||||
)
|
||||
[
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
# Sandbox tasks (file sync, cleanup)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -110,16 +110,13 @@ for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.opensearch_migration",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
[
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -94,9 +94,7 @@ for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
]
|
||||
)
|
||||
[
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -314,18 +314,17 @@ for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
[
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
"onyx.background.celery.tasks.opensearch_migration",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -107,9 +107,7 @@ for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
[
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ from celery.schedules import crontab
|
||||
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
@@ -216,39 +215,36 @@ if SCHEDULED_EVAL_DATASET_NAMES:
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
beat_task_templates.append(
|
||||
{
|
||||
"name": "migrate-chunks-from-vespa-to-opensearch",
|
||||
"task": OnyxCeleryTask.MIGRATE_CHUNKS_FROM_VESPA_TO_OPENSEARCH_TASK,
|
||||
"name": "check-for-documents-for-opensearch-migration",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DOCUMENTS_FOR_OPENSEARCH_MIGRATION_TASK,
|
||||
# Try to enqueue an invocation of this task with this frequency.
|
||||
"schedule": timedelta(seconds=120), # 2 minutes
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
# If the task was not dequeued in this time, revoke it.
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.OPENSEARCH_MIGRATION,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Beat task names that require a vector DB. Filtered out when DISABLE_VECTOR_DB.
|
||||
_VECTOR_DB_BEAT_TASK_NAMES: set[str] = {
|
||||
"check-for-indexing",
|
||||
"check-for-connector-deletion",
|
||||
"check-for-vespa-sync",
|
||||
"check-for-pruning",
|
||||
"check-for-hierarchy-fetching",
|
||||
"check-for-checkpoint-cleanup",
|
||||
"check-for-index-attempt-cleanup",
|
||||
"check-for-doc-permissions-sync",
|
||||
"check-for-external-group-sync",
|
||||
"check-for-documents-for-opensearch-migration",
|
||||
"migrate-documents-from-vespa-to-opensearch",
|
||||
}
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
beat_task_templates = [
|
||||
t for t in beat_task_templates if t["name"] not in _VECTOR_DB_BEAT_TASK_NAMES
|
||||
]
|
||||
beat_task_templates.append(
|
||||
{
|
||||
"name": "migrate-documents-from-vespa-to-opensearch",
|
||||
"task": OnyxCeleryTask.MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK,
|
||||
# Try to enqueue an invocation of this task with this frequency.
|
||||
# NOTE: If MIGRATION_TASK_SOFT_TIME_LIMIT_S is greater than this
|
||||
# value and the task is maximally busy, we can expect to see some
|
||||
# enqueued tasks be revoked over time. This is ok; by erring on the
|
||||
# side of "there will probably always be at least one task of this
|
||||
# type in the queue", we are minimizing this task's idleness while
|
||||
# still giving chances for other tasks to execute.
|
||||
"schedule": timedelta(seconds=120), # 2 minutes
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
# If the task was not dequeued in this time, revoke it.
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
@@ -2,12 +2,27 @@
|
||||
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
CHECK_FOR_DOCUMENTS_TASK_LOCK_TIMEOUT_S,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
|
||||
)
|
||||
@@ -27,32 +42,225 @@ from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapping
|
||||
from onyx.db.opensearch_migration import get_vespa_visit_state
|
||||
from onyx.db.enums import OpenSearchDocumentMigrationStatus
|
||||
from onyx.db.opensearch_migration import create_opensearch_migration_records_with_commit
|
||||
from onyx.db.opensearch_migration import get_last_opensearch_migration_document_id
|
||||
from onyx.db.opensearch_migration import (
|
||||
mark_migration_completed_time_if_not_set_with_commit,
|
||||
get_opensearch_migration_records_needing_migration,
|
||||
)
|
||||
from onyx.db.opensearch_migration import get_paginated_document_batch
|
||||
from onyx.db.opensearch_migration import (
|
||||
increment_num_times_observed_no_additional_docs_to_migrate_with_commit,
|
||||
)
|
||||
from onyx.db.opensearch_migration import (
|
||||
increment_num_times_observed_no_additional_docs_to_populate_migration_table_with_commit,
|
||||
)
|
||||
from onyx.db.opensearch_migration import should_document_migration_be_permanently_failed
|
||||
from onyx.db.opensearch_migration import (
|
||||
try_insert_opensearch_tenant_migration_record_with_commit,
|
||||
)
|
||||
from onyx.db.opensearch_migration import update_vespa_visit_progress_with_commit
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
GET_VESPA_CHUNKS_PAGE_SIZE = 1000
|
||||
def _migrate_single_document(
|
||||
document_id: str,
|
||||
opensearch_document_index: OpenSearchDocumentIndex,
|
||||
vespa_document_index: VespaDocumentIndex,
|
||||
tenant_state: TenantState,
|
||||
) -> int:
|
||||
"""Migrates a single document from Vespa to OpenSearch.
|
||||
|
||||
Args:
|
||||
document_id: The ID of the document to migrate.
|
||||
opensearch_document_index: The OpenSearch document index to use.
|
||||
vespa_document_index: The Vespa document index to use.
|
||||
tenant_state: The tenant state to use.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no chunks are found for the document in Vespa, or if
|
||||
the number of candidate chunks to migrate does not match the number
|
||||
of chunks in Vespa.
|
||||
|
||||
Returns:
|
||||
The number of chunks migrated.
|
||||
"""
|
||||
vespa_document_chunks: list[dict[str, Any]] = (
|
||||
vespa_document_index.get_raw_document_chunks(document_id=document_id)
|
||||
)
|
||||
if not vespa_document_chunks:
|
||||
raise RuntimeError(f"No chunks found for document {document_id} in Vespa.")
|
||||
|
||||
opensearch_document_chunks: list[DocumentChunk] = (
|
||||
transform_vespa_chunks_to_opensearch_chunks(
|
||||
vespa_document_chunks, tenant_state, document_id
|
||||
)
|
||||
)
|
||||
if len(opensearch_document_chunks) != len(vespa_document_chunks):
|
||||
raise RuntimeError(
|
||||
f"Bug: Number of candidate chunks to migrate ({len(opensearch_document_chunks)}) does not match "
|
||||
f"number of chunks in Vespa ({len(vespa_document_chunks)})."
|
||||
)
|
||||
|
||||
opensearch_document_index.index_raw_chunks(chunks=opensearch_document_chunks)
|
||||
|
||||
return len(opensearch_document_chunks)
|
||||
|
||||
|
||||
# shared_task allows this task to be shared across celery app instances.
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.MIGRATE_CHUNKS_FROM_VESPA_TO_OPENSEARCH_TASK,
|
||||
name=OnyxCeleryTask.CHECK_FOR_DOCUMENTS_FOR_OPENSEARCH_MIGRATION_TASK,
|
||||
# Does not store the task's return value in the result backend.
|
||||
ignore_result=True,
|
||||
# WARNING: This is here just for rigor but since we use threads for Celery
|
||||
# this config is not respected and timeout logic must be implemented in the
|
||||
# task.
|
||||
soft_time_limit=CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S,
|
||||
# WARNING: This is here just for rigor but since we use threads for Celery
|
||||
# this config is not respected and timeout logic must be implemented in the
|
||||
# task.
|
||||
time_limit=CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S,
|
||||
# Passed in self to the task to get task metadata.
|
||||
bind=True,
|
||||
)
|
||||
def check_for_documents_for_opensearch_migration_task(
|
||||
self: Task, *, tenant_id: str # noqa: ARG001
|
||||
) -> bool | None:
|
||||
"""
|
||||
Periodic task to check for and add documents to the OpenSearch migration
|
||||
table.
|
||||
|
||||
Should not execute meaningful logic at the same time as
|
||||
migrate_documents_from_vespa_to_opensearch_task.
|
||||
|
||||
Effectively tries to populate as many migration records as possible within
|
||||
CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S seconds. Does so in batches of
|
||||
1000 documents.
|
||||
|
||||
Returns:
|
||||
None if OpenSearch migration is not enabled, or if the lock could not be
|
||||
acquired; effectively a no-op. True if the task completed
|
||||
successfully. False if the task failed.
|
||||
"""
|
||||
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
task_logger.warning(
|
||||
"OpenSearch migration is not enabled, skipping check for documents for the OpenSearch migration task."
|
||||
)
|
||||
return None
|
||||
|
||||
task_logger.info("Checking for documents for OpenSearch migration.")
|
||||
task_start_time = time.monotonic()
|
||||
r = get_redis_client()
|
||||
# Use a lock to prevent overlapping tasks. Only this task or
|
||||
# migrate_documents_from_vespa_to_opensearch_task can interact with the
|
||||
# OpenSearchMigration table at once.
|
||||
lock: RedisLock = r.lock(
|
||||
name=OnyxRedisLocks.OPENSEARCH_MIGRATION_BEAT_LOCK,
|
||||
# The maximum time the lock can be held for. Will automatically be
|
||||
# released after this time.
|
||||
timeout=CHECK_FOR_DOCUMENTS_TASK_LOCK_TIMEOUT_S,
|
||||
# .acquire will block until the lock is acquired.
|
||||
blocking=True,
|
||||
# Time to wait to acquire the lock.
|
||||
blocking_timeout=CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S,
|
||||
)
|
||||
if not lock.acquire():
|
||||
task_logger.warning(
|
||||
"The OpenSearch migration check task timed out waiting for the lock."
|
||||
)
|
||||
return None
|
||||
else:
|
||||
task_logger.info(
|
||||
f"Acquired the OpenSearch migration check lock. Took {time.monotonic() - task_start_time:.3f} seconds. "
|
||||
f"Token: {lock.local.token}"
|
||||
)
|
||||
|
||||
num_documents_found_for_record_creation = 0
|
||||
try:
|
||||
# Double check that tenant info is correct.
|
||||
if tenant_id != get_current_tenant_id():
|
||||
err_str = (
|
||||
f"Tenant ID mismatch in the OpenSearch migration check task: "
|
||||
f"{tenant_id} != {get_current_tenant_id()}. This should never happen."
|
||||
)
|
||||
task_logger.error(err_str)
|
||||
return False
|
||||
while (
|
||||
time.monotonic() - task_start_time
|
||||
< CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S
|
||||
and lock.owned()
|
||||
):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# For pagination, get the last ID we've inserted into
|
||||
# OpenSearchMigration.
|
||||
last_opensearch_migration_document_id = (
|
||||
get_last_opensearch_migration_document_id(db_session)
|
||||
)
|
||||
# Now get the next batch of doc IDs starting after the last ID.
|
||||
# We'll do 1000 documents per transaction/timeout check.
|
||||
document_ids = get_paginated_document_batch(
|
||||
db_session,
|
||||
limit=1000,
|
||||
prev_ending_document_id=last_opensearch_migration_document_id,
|
||||
)
|
||||
|
||||
if not document_ids:
|
||||
task_logger.info(
|
||||
"No more documents to insert for OpenSearch migration."
|
||||
)
|
||||
increment_num_times_observed_no_additional_docs_to_populate_migration_table_with_commit(
|
||||
db_session
|
||||
)
|
||||
# TODO(andrei): Once we've done this enough times and the
|
||||
# number of documents matches the number of migration
|
||||
# records, we can be done with this task and update
|
||||
# document_migration_record_table_population_status.
|
||||
return True
|
||||
|
||||
# Create the migration records for the next batch of documents
|
||||
# with status PENDING.
|
||||
create_opensearch_migration_records_with_commit(
|
||||
db_session, document_ids
|
||||
)
|
||||
num_documents_found_for_record_creation += len(document_ids)
|
||||
|
||||
# Try to create the singleton row in
|
||||
# OpenSearchTenantMigrationRecord if it doesn't already exist.
|
||||
# This is a reasonable place to put it because we already have a
|
||||
# lock, a session, and error handling, at the cost of running
|
||||
# this small set of logic for every batch.
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
except Exception:
|
||||
task_logger.exception("Error in the OpenSearch migration check task.")
|
||||
return False
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
else:
|
||||
task_logger.warning(
|
||||
"The OpenSearch migration lock was not owned on completion of the check task."
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Finished checking for documents for OpenSearch migration. Found {num_documents_found_for_record_creation} documents "
|
||||
f"to create migration records for in {time.monotonic() - task_start_time:.3f} seconds. However, this may include "
|
||||
"documents for which there already exist records."
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# shared_task allows this task to be shared across celery app instances.
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK,
|
||||
# Does not store the task's return value in the result backend.
|
||||
ignore_result=True,
|
||||
# WARNING: This is here just for rigor but since we use threads for Celery
|
||||
@@ -66,21 +274,18 @@ GET_VESPA_CHUNKS_PAGE_SIZE = 1000
|
||||
# Passed in self to the task to get task metadata.
|
||||
bind=True,
|
||||
)
|
||||
def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
def migrate_documents_from_vespa_to_opensearch_task(
|
||||
self: Task, # noqa: ARG001
|
||||
*,
|
||||
tenant_id: str,
|
||||
) -> bool | None:
|
||||
"""
|
||||
Periodic task to migrate chunks from Vespa to OpenSearch via the Visit API.
|
||||
"""Periodic task to migrate documents from Vespa to OpenSearch.
|
||||
|
||||
Uses Vespa's Visit API to iterate through ALL chunks in bulk (not
|
||||
per-document), transform them, and index them into OpenSearch. Progress is
|
||||
tracked via a continuation token stored in the
|
||||
OpenSearchTenantMigrationRecord.
|
||||
Should not execute meaningful logic at the same time as
|
||||
check_for_documents_for_opensearch_migration_task.
|
||||
|
||||
The first time we see no continuation token and non-zero chunks migrated, we
|
||||
consider the migration complete and all subsequent invocations are no-ops.
|
||||
Effectively tries to migrate as many documents as possible within
|
||||
MIGRATION_TASK_SOFT_TIME_LIMIT_S seconds. Does so in batches of 5 documents.
|
||||
|
||||
Returns:
|
||||
None if OpenSearch migration is not enabled, or if the lock could not be
|
||||
@@ -89,13 +294,16 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
"""
|
||||
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
task_logger.warning(
|
||||
"OpenSearch migration is not enabled, skipping chunk migration task."
|
||||
"OpenSearch migration is not enabled, skipping trying to migrate documents from Vespa to OpenSearch."
|
||||
)
|
||||
return None
|
||||
|
||||
task_logger.info("Starting chunk-level migration from Vespa to OpenSearch.")
|
||||
task_logger.info("Trying a migration batch from Vespa to OpenSearch.")
|
||||
task_start_time = time.monotonic()
|
||||
r = get_redis_client()
|
||||
# Use a lock to prevent overlapping tasks. Only this task or
|
||||
# check_for_documents_for_opensearch_migration_task can interact with the
|
||||
# OpenSearchMigration table at once.
|
||||
lock: RedisLock = r.lock(
|
||||
name=OnyxRedisLocks.OPENSEARCH_MIGRATION_BEAT_LOCK,
|
||||
# The maximum time the lock can be held for. Will automatically be
|
||||
@@ -117,8 +325,9 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
f"Token: {lock.local.token}"
|
||||
)
|
||||
|
||||
total_chunks_migrated_this_task = 0
|
||||
total_chunks_errored_this_task = 0
|
||||
num_documents_migrated = 0
|
||||
num_chunks_migrated = 0
|
||||
num_documents_failed = 0
|
||||
try:
|
||||
# Double check that tenant info is correct.
|
||||
if tenant_id != get_current_tenant_id():
|
||||
@@ -128,100 +337,97 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
)
|
||||
task_logger.error(err_str)
|
||||
return False
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
opensearch_document_index = OpenSearchDocumentIndex(
|
||||
index_name=search_settings.index_name, tenant_state=tenant_state
|
||||
)
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=False,
|
||||
)
|
||||
|
||||
sanitized_doc_start_time = time.monotonic()
|
||||
# We reconstruct this mapping for every task invocation because a
|
||||
# document may have been added in the time between two tasks.
|
||||
sanitized_to_original_doc_id_mapping = (
|
||||
build_sanitized_to_original_doc_id_mapping(db_session)
|
||||
)
|
||||
task_logger.debug(
|
||||
f"Built sanitized_to_original_doc_id_mapping with {len(sanitized_to_original_doc_id_mapping)} entries "
|
||||
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
|
||||
)
|
||||
|
||||
while (
|
||||
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
|
||||
and lock.owned()
|
||||
):
|
||||
(
|
||||
continuation_token,
|
||||
total_chunks_migrated,
|
||||
) = get_vespa_visit_state(db_session)
|
||||
if continuation_token is None and total_chunks_migrated > 0:
|
||||
while (
|
||||
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
|
||||
and lock.owned()
|
||||
):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# We'll do 5 documents per transaction/timeout check.
|
||||
records_needing_migration = (
|
||||
get_opensearch_migration_records_needing_migration(
|
||||
db_session, limit=5
|
||||
)
|
||||
)
|
||||
if not records_needing_migration:
|
||||
task_logger.info(
|
||||
f"OpenSearch migration COMPLETED for tenant {tenant_id}. "
|
||||
f"Total chunks migrated: {total_chunks_migrated}."
|
||||
"No documents found that need to be migrated from Vespa to OpenSearch."
|
||||
)
|
||||
mark_migration_completed_time_if_not_set_with_commit(db_session)
|
||||
break
|
||||
task_logger.debug(
|
||||
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
|
||||
f"Continuation token: {continuation_token}"
|
||||
)
|
||||
|
||||
get_vespa_chunks_start_time = time.monotonic()
|
||||
raw_vespa_chunks, next_continuation_token = (
|
||||
vespa_document_index.get_all_raw_document_chunks_paginated(
|
||||
continuation_token=continuation_token,
|
||||
page_size=GET_VESPA_CHUNKS_PAGE_SIZE,
|
||||
increment_num_times_observed_no_additional_docs_to_migrate_with_commit(
|
||||
db_session
|
||||
)
|
||||
# TODO(andrei): Once we've done this enough times and
|
||||
# document_migration_record_table_population_status is done, we
|
||||
# can be done with this task and update
|
||||
# overall_document_migration_status accordingly. Note that this
|
||||
# includes marking connectors as needing reindexing if some
|
||||
# migrations failed.
|
||||
return True
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
tenant_state = TenantState(
|
||||
tenant_id=tenant_id, multitenant=MULTI_TENANT
|
||||
)
|
||||
task_logger.debug(
|
||||
f"Read {len(raw_vespa_chunks)} chunks from Vespa in {time.monotonic() - get_vespa_chunks_start_time:.3f} "
|
||||
f"seconds. Next continuation token: {next_continuation_token}"
|
||||
opensearch_document_index = OpenSearchDocumentIndex(
|
||||
index_name=search_settings.index_name, tenant_state=tenant_state
|
||||
)
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=False,
|
||||
)
|
||||
|
||||
opensearch_document_chunks, errored_chunks = (
|
||||
transform_vespa_chunks_to_opensearch_chunks(
|
||||
raw_vespa_chunks,
|
||||
tenant_state,
|
||||
sanitized_to_original_doc_id_mapping,
|
||||
)
|
||||
)
|
||||
if len(opensearch_document_chunks) != len(raw_vespa_chunks):
|
||||
task_logger.error(
|
||||
f"Migration task error: Number of candidate chunks to migrate ({len(opensearch_document_chunks)}) does "
|
||||
f"not match number of chunks in Vespa ({len(raw_vespa_chunks)}). {len(errored_chunks)} chunks "
|
||||
"errored."
|
||||
)
|
||||
for record in records_needing_migration:
|
||||
try:
|
||||
# If the Document's chunk count is not known, it was
|
||||
# probably just indexed so fail here to give it a chance to
|
||||
# sync. If in the rare event this Document has not been
|
||||
# re-indexed in a very long time and is still under the
|
||||
# "old" embedding/indexing logic where chunk count was never
|
||||
# stored, we will eventually permanently fail and thus force
|
||||
# a re-index of this doc, which is a desireable outcome.
|
||||
if record.document.chunk_count is None:
|
||||
raise RuntimeError(
|
||||
f"Document {record.document_id} has no chunk count."
|
||||
)
|
||||
|
||||
index_opensearch_chunks_start_time = time.monotonic()
|
||||
opensearch_document_index.index_raw_chunks(
|
||||
chunks=opensearch_document_chunks
|
||||
)
|
||||
task_logger.debug(
|
||||
f"Indexed {len(opensearch_document_chunks)} chunks into OpenSearch in "
|
||||
f"{time.monotonic() - index_opensearch_chunks_start_time:.3f} seconds."
|
||||
)
|
||||
chunks_migrated = _migrate_single_document(
|
||||
document_id=record.document_id,
|
||||
opensearch_document_index=opensearch_document_index,
|
||||
vespa_document_index=vespa_document_index,
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
|
||||
total_chunks_migrated_this_task += len(opensearch_document_chunks)
|
||||
total_chunks_errored_this_task += len(errored_chunks)
|
||||
update_vespa_visit_progress_with_commit(
|
||||
db_session,
|
||||
continuation_token=next_continuation_token,
|
||||
chunks_processed=len(opensearch_document_chunks),
|
||||
chunks_errored=len(errored_chunks),
|
||||
)
|
||||
# If the number of chunks in Vespa is not in sync with the
|
||||
# Document table for this doc let's not consider this
|
||||
# completed and let's let a subsequent run take care of it.
|
||||
if chunks_migrated != record.document.chunk_count:
|
||||
raise RuntimeError(
|
||||
f"Number of chunks migrated ({chunks_migrated}) does not match number of expected chunks "
|
||||
f"in Vespa ({record.document.chunk_count}) for document {record.document_id}."
|
||||
)
|
||||
|
||||
if next_continuation_token is None and len(raw_vespa_chunks) == 0:
|
||||
task_logger.info("Vespa reported no more chunks to migrate.")
|
||||
break
|
||||
record.status = OpenSearchDocumentMigrationStatus.COMPLETED
|
||||
num_documents_migrated += 1
|
||||
num_chunks_migrated += chunks_migrated
|
||||
except Exception:
|
||||
record.status = OpenSearchDocumentMigrationStatus.FAILED
|
||||
record.error_message = f"Attempt {record.attempts_count + 1}:\n{traceback.format_exc()}"
|
||||
task_logger.exception(
|
||||
f"Error migrating document {record.document_id} from Vespa to OpenSearch."
|
||||
)
|
||||
num_documents_failed += 1
|
||||
finally:
|
||||
record.attempts_count += 1
|
||||
record.last_attempt_at = datetime.now(timezone.utc)
|
||||
if should_document_migration_be_permanently_failed(record):
|
||||
record.status = (
|
||||
OpenSearchDocumentMigrationStatus.PERMANENTLY_FAILED
|
||||
)
|
||||
# TODO(andrei): Not necessarily here but if this happens
|
||||
# we'll need to mark the connector as needing reindex.
|
||||
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
task_logger.exception("Error in the OpenSearch migration task.")
|
||||
return False
|
||||
finally:
|
||||
@@ -233,11 +439,9 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"OpenSearch chunk migration task pausing (time limit reached). "
|
||||
f"Total chunks migrated this task: {total_chunks_migrated_this_task}. "
|
||||
f"Total chunks errored this task: {total_chunks_errored_this_task}. "
|
||||
f"Elapsed: {time.monotonic() - task_start_time:.3f}s. "
|
||||
"Will resume from continuation token on next invocation."
|
||||
f"Finished a migration batch from Vespa to OpenSearch. Migrated {num_chunks_migrated} chunks "
|
||||
f"from {num_documents_migrated} documents in {time.monotonic() - task_start_time:.3f} seconds. "
|
||||
f"Failed to migrate {num_documents_failed} documents."
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
@@ -141,7 +140,9 @@ def _transform_vespa_acl_to_opensearch_acl(
|
||||
vespa_acl: dict[str, int] | None,
|
||||
) -> tuple[bool, list[str]]:
|
||||
if not vespa_acl:
|
||||
return False, []
|
||||
raise ValueError(
|
||||
"Missing ACL in Vespa chunk. This does not make sense as it implies the document is never searchable by anyone ever."
|
||||
)
|
||||
acl_list = list(vespa_acl.keys())
|
||||
is_public = PUBLIC_DOC_PAT in acl_list
|
||||
if is_public:
|
||||
@@ -152,163 +153,133 @@ def _transform_vespa_acl_to_opensearch_acl(
|
||||
def transform_vespa_chunks_to_opensearch_chunks(
|
||||
vespa_chunks: list[dict[str, Any]],
|
||||
tenant_state: TenantState,
|
||||
sanitized_to_original_doc_id_mapping: dict[str, str],
|
||||
) -> tuple[list[DocumentChunk], list[dict[str, Any]]]:
|
||||
document_id: str,
|
||||
) -> list[DocumentChunk]:
|
||||
result: list[DocumentChunk] = []
|
||||
errored_chunks: list[dict[str, Any]] = []
|
||||
for vespa_chunk in vespa_chunks:
|
||||
try:
|
||||
# This should exist; fail loudly if it does not.
|
||||
vespa_document_id: str = vespa_chunk[DOCUMENT_ID]
|
||||
if not vespa_document_id:
|
||||
raise ValueError("Missing document_id in Vespa chunk.")
|
||||
# Vespa doc IDs were sanitized using
|
||||
# replace_invalid_doc_id_characters. This was a poor design choice
|
||||
# and we don't want this in OpenSearch; whatever restrictions there
|
||||
# may be on indexed chunk ID should have no bearing on the chunk's
|
||||
# document ID field, even if document ID is an argument to the chunk
|
||||
# ID. Deliberately choose to use the real doc ID supplied to this
|
||||
# function.
|
||||
if vespa_document_id in sanitized_to_original_doc_id_mapping:
|
||||
logger.warning(
|
||||
f"Migration warning: Vespa document ID {vespa_document_id} does not match the document ID supplied "
|
||||
f"{sanitized_to_original_doc_id_mapping[vespa_document_id]}. "
|
||||
"The Vespa ID will be discarded."
|
||||
)
|
||||
document_id = sanitized_to_original_doc_id_mapping.get(
|
||||
vespa_document_id, vespa_document_id
|
||||
# This should exist; fail loudly if it does not.
|
||||
vespa_document_id: str = vespa_chunk[DOCUMENT_ID]
|
||||
if not vespa_document_id:
|
||||
raise ValueError("Missing document_id in Vespa chunk.")
|
||||
# Vespa doc IDs were sanitized using replace_invalid_doc_id_characters.
|
||||
# This was a poor design choice and we don't want this in OpenSearch;
|
||||
# whatever restrictions there may be on indexed chunk ID should have no
|
||||
# bearing on the chunk's document ID field, even if document ID is an
|
||||
# argument to the chunk ID. Deliberately choose to use the real doc ID
|
||||
# supplied to this function.
|
||||
if vespa_document_id != document_id:
|
||||
logger.warning(
|
||||
f"Vespa document ID {vespa_document_id} does not match the document ID supplied {document_id}. "
|
||||
"The Vespa ID will be discarded."
|
||||
)
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
chunk_index: int = vespa_chunk[CHUNK_ID]
|
||||
# This should exist; fail loudly if it does not.
|
||||
chunk_index: int = vespa_chunk[CHUNK_ID]
|
||||
|
||||
title: str | None = vespa_chunk.get(TITLE)
|
||||
# WARNING: Should supply format.tensors=short-value to the Vespa
|
||||
# client in order to get a supported format for the tensors.
|
||||
title_vector: list[float] | None = _extract_title_vector(
|
||||
vespa_chunk.get(TITLE_EMBEDDING)
|
||||
title: str | None = vespa_chunk.get(TITLE)
|
||||
# WARNING: Should supply format.tensors=short-value to the Vespa client
|
||||
# in order to get a supported format for the tensors.
|
||||
title_vector: list[float] | None = _extract_title_vector(
|
||||
vespa_chunk.get(TITLE_EMBEDDING)
|
||||
)
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
content: str = vespa_chunk[CONTENT]
|
||||
if not content:
|
||||
raise ValueError("Missing content in Vespa chunk.")
|
||||
# This should exist; fail loudly if it does not.
|
||||
# WARNING: Should supply format.tensors=short-value to the Vespa client
|
||||
# in order to get a supported format for the tensors.
|
||||
content_vector: list[float] = _extract_content_vector(vespa_chunk[EMBEDDINGS])
|
||||
if not content_vector:
|
||||
raise ValueError("Missing content_vector in Vespa chunk.")
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
source_type: str = vespa_chunk[SOURCE_TYPE]
|
||||
if not source_type:
|
||||
raise ValueError("Missing source_type in Vespa chunk.")
|
||||
|
||||
metadata_list: list[str] | None = vespa_chunk.get(METADATA_LIST)
|
||||
|
||||
_raw_doc_updated_at: int | None = vespa_chunk.get(DOC_UPDATED_AT)
|
||||
last_updated: datetime | None = (
|
||||
datetime.fromtimestamp(_raw_doc_updated_at, tz=timezone.utc)
|
||||
if _raw_doc_updated_at is not None
|
||||
else None
|
||||
)
|
||||
|
||||
hidden: bool = vespa_chunk.get(HIDDEN, False)
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
global_boost: int = vespa_chunk[BOOST]
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
semantic_identifier: str = vespa_chunk[SEMANTIC_IDENTIFIER]
|
||||
if not semantic_identifier:
|
||||
raise ValueError("Missing semantic_identifier in Vespa chunk.")
|
||||
|
||||
image_file_id: str | None = vespa_chunk.get(IMAGE_FILE_NAME)
|
||||
source_links: str | None = vespa_chunk.get(SOURCE_LINKS)
|
||||
blurb: str = vespa_chunk.get(BLURB, "")
|
||||
doc_summary: str = vespa_chunk.get(DOC_SUMMARY, "")
|
||||
chunk_context: str = vespa_chunk.get(CHUNK_CONTEXT, "")
|
||||
metadata_suffix: str | None = vespa_chunk.get(METADATA_SUFFIX)
|
||||
document_sets: list[str] | None = (
|
||||
_transform_vespa_document_sets_to_opensearch_document_sets(
|
||||
vespa_chunk.get(DOCUMENT_SETS)
|
||||
)
|
||||
)
|
||||
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
|
||||
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
|
||||
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
content: str = vespa_chunk[CONTENT]
|
||||
if not content:
|
||||
# This should exist; fail loudly if it does not; this function will
|
||||
# raise in that event.
|
||||
is_public, acl_list = _transform_vespa_acl_to_opensearch_acl(
|
||||
vespa_chunk.get(ACCESS_CONTROL_LIST)
|
||||
)
|
||||
|
||||
chunk_tenant_id: str | None = vespa_chunk.get(TENANT_ID)
|
||||
if MULTI_TENANT:
|
||||
if not chunk_tenant_id:
|
||||
raise ValueError(
|
||||
f"Missing content in Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index}."
|
||||
"Missing tenant_id in Vespa chunk in a multi-tenant environment."
|
||||
)
|
||||
# This should exist; fail loudly if it does not.
|
||||
# WARNING: Should supply format.tensors=short-value to the Vespa
|
||||
# client in order to get a supported format for the tensors.
|
||||
content_vector: list[float] = _extract_content_vector(
|
||||
vespa_chunk[EMBEDDINGS]
|
||||
)
|
||||
if not content_vector:
|
||||
if chunk_tenant_id != tenant_state.tenant_id:
|
||||
raise ValueError(
|
||||
f"Missing content_vector in Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index}."
|
||||
f"Chunk tenant_id {chunk_tenant_id} does not match expected tenant_id {tenant_state.tenant_id}"
|
||||
)
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
source_type: str = vespa_chunk[SOURCE_TYPE]
|
||||
if not source_type:
|
||||
raise ValueError(
|
||||
f"Missing source_type in Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index}."
|
||||
)
|
||||
opensearch_chunk = DocumentChunk(
|
||||
# We deliberately choose to use the doc ID supplied to this function
|
||||
# over the Vespa doc ID.
|
||||
document_id=document_id,
|
||||
chunk_index=chunk_index,
|
||||
title=title,
|
||||
title_vector=title_vector,
|
||||
content=content,
|
||||
content_vector=content_vector,
|
||||
source_type=source_type,
|
||||
metadata_list=metadata_list,
|
||||
last_updated=last_updated,
|
||||
public=is_public,
|
||||
access_control_list=acl_list,
|
||||
hidden=hidden,
|
||||
global_boost=global_boost,
|
||||
semantic_identifier=semantic_identifier,
|
||||
image_file_id=image_file_id,
|
||||
source_links=source_links,
|
||||
blurb=blurb,
|
||||
doc_summary=doc_summary,
|
||||
chunk_context=chunk_context,
|
||||
metadata_suffix=metadata_suffix,
|
||||
document_sets=document_sets,
|
||||
user_projects=user_projects,
|
||||
primary_owners=primary_owners,
|
||||
secondary_owners=secondary_owners,
|
||||
tenant_id=tenant_state,
|
||||
)
|
||||
|
||||
metadata_list: list[str] | None = vespa_chunk.get(METADATA_LIST)
|
||||
result.append(opensearch_chunk)
|
||||
|
||||
_raw_doc_updated_at: int | None = vespa_chunk.get(DOC_UPDATED_AT)
|
||||
last_updated: datetime | None = (
|
||||
datetime.fromtimestamp(_raw_doc_updated_at, tz=timezone.utc)
|
||||
if _raw_doc_updated_at is not None
|
||||
else None
|
||||
)
|
||||
|
||||
hidden: bool = vespa_chunk.get(HIDDEN, False)
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
global_boost: int = vespa_chunk[BOOST]
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
semantic_identifier: str = vespa_chunk[SEMANTIC_IDENTIFIER]
|
||||
if not semantic_identifier:
|
||||
raise ValueError(
|
||||
f"Missing semantic_identifier in Vespa chunk with document ID {vespa_document_id} and chunk "
|
||||
f"index {chunk_index}."
|
||||
)
|
||||
|
||||
image_file_id: str | None = vespa_chunk.get(IMAGE_FILE_NAME)
|
||||
source_links: str | None = vespa_chunk.get(SOURCE_LINKS)
|
||||
blurb: str = vespa_chunk.get(BLURB, "")
|
||||
doc_summary: str = vespa_chunk.get(DOC_SUMMARY, "")
|
||||
chunk_context: str = vespa_chunk.get(CHUNK_CONTEXT, "")
|
||||
metadata_suffix: str | None = vespa_chunk.get(METADATA_SUFFIX)
|
||||
document_sets: list[str] | None = (
|
||||
_transform_vespa_document_sets_to_opensearch_document_sets(
|
||||
vespa_chunk.get(DOCUMENT_SETS)
|
||||
)
|
||||
)
|
||||
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
|
||||
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
|
||||
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
|
||||
|
||||
is_public, acl_list = _transform_vespa_acl_to_opensearch_acl(
|
||||
vespa_chunk.get(ACCESS_CONTROL_LIST)
|
||||
)
|
||||
if not is_public and not acl_list:
|
||||
logger.warning(
|
||||
f"Migration warning: Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index} has no "
|
||||
"public ACL and no access control list. This does not make sense as it implies the document is never "
|
||||
"searchable. Continuing with the migration..."
|
||||
)
|
||||
|
||||
chunk_tenant_id: str | None = vespa_chunk.get(TENANT_ID)
|
||||
if MULTI_TENANT:
|
||||
if not chunk_tenant_id:
|
||||
raise ValueError(
|
||||
"Missing tenant_id in Vespa chunk in a multi-tenant environment."
|
||||
)
|
||||
if chunk_tenant_id != tenant_state.tenant_id:
|
||||
raise ValueError(
|
||||
f"Chunk tenant_id {chunk_tenant_id} does not match expected tenant_id {tenant_state.tenant_id}"
|
||||
)
|
||||
|
||||
opensearch_chunk = DocumentChunk(
|
||||
# We deliberately choose to use the doc ID supplied to this function
|
||||
# over the Vespa doc ID.
|
||||
document_id=document_id,
|
||||
chunk_index=chunk_index,
|
||||
title=title,
|
||||
title_vector=title_vector,
|
||||
content=content,
|
||||
content_vector=content_vector,
|
||||
source_type=source_type,
|
||||
metadata_list=metadata_list,
|
||||
last_updated=last_updated,
|
||||
public=is_public,
|
||||
access_control_list=acl_list,
|
||||
hidden=hidden,
|
||||
global_boost=global_boost,
|
||||
semantic_identifier=semantic_identifier,
|
||||
image_file_id=image_file_id,
|
||||
source_links=source_links,
|
||||
blurb=blurb,
|
||||
doc_summary=doc_summary,
|
||||
chunk_context=chunk_context,
|
||||
metadata_suffix=metadata_suffix,
|
||||
document_sets=document_sets,
|
||||
user_projects=user_projects,
|
||||
primary_owners=primary_owners,
|
||||
secondary_owners=secondary_owners,
|
||||
tenant_id=tenant_state,
|
||||
)
|
||||
|
||||
result.append(opensearch_chunk)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
logger.exception(
|
||||
f"Migration error: Error transforming Vespa chunk with document ID {vespa_chunk.get(DOCUMENT_ID)} "
|
||||
f"and chunk index {vespa_chunk.get(CHUNK_ID)} into an OpenSearch chunk. Continuing with "
|
||||
"the migration..."
|
||||
)
|
||||
errored_chunks.append(vespa_chunk)
|
||||
|
||||
return result, errored_chunks
|
||||
return result
|
||||
|
||||
@@ -10,23 +10,24 @@ from celery import Task
|
||||
from redis.lock import Lock as RedisLock
|
||||
from retry import retry
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
@@ -39,7 +40,6 @@ from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.utils import store_user_file_plaintext
|
||||
from onyx.file_store.utils import user_file_id_to_plaintext_file_name
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
|
||||
@@ -57,6 +57,17 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_queued_key(user_file_id: str | UUID) -> str:
|
||||
"""Key that exists while a process_single_user_file task is sitting in the queue.
|
||||
|
||||
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
before enqueuing and the worker deletes it as its first action. This prevents
|
||||
the beat from adding duplicate tasks for files that already have a live task
|
||||
in flight.
|
||||
"""
|
||||
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
@@ -120,7 +131,24 @@ def _get_document_chunk_count(
|
||||
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
|
||||
|
||||
Uses direct Redis locks to avoid overlapping runs.
|
||||
Three mechanisms prevent queue runaway:
|
||||
|
||||
1. **Queue depth backpressure** – if the broker queue already has more than
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
|
||||
entirely. Workers are clearly behind; adding more tasks would only make
|
||||
the backlog worse.
|
||||
|
||||
2. **Per-file queued guard** – before enqueuing a task we set a short-lived
|
||||
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
|
||||
already exists the file already has a live task in the queue, so we skip
|
||||
it. The worker deletes the key the moment it picks up the task so the
|
||||
next beat cycle can re-enqueue if the file is still PROCESSING.
|
||||
|
||||
3. **Task expiry** – every enqueued task carries an `expires` value equal to
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
|
||||
the queue after that deadline, Celery discards it without touching the DB.
|
||||
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
|
||||
Redis restart), stale tasks evict themselves rather than piling up forever.
|
||||
"""
|
||||
task_logger.info("check_user_file_processing - Starting")
|
||||
|
||||
@@ -135,7 +163,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"check_user_file_processing - Queue depth {queue_len} exceeds "
|
||||
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
@@ -148,12 +190,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
# --- Protection 2: per-file queued guard ---
|
||||
queued_key = _user_file_queued_key(user_file_id)
|
||||
guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
nx=True,
|
||||
)
|
||||
if not guard_set:
|
||||
skipped_guard += 1
|
||||
continue
|
||||
|
||||
# --- Protection 3: task expiry ---
|
||||
# If task submission fails, clear the guard immediately so the
|
||||
# next beat cycle can retry enqueuing this file.
|
||||
try:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={
|
||||
"user_file_id": str(user_file_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
enqueued += 1
|
||||
|
||||
finally:
|
||||
@@ -161,137 +226,12 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
|
||||
f"tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _process_user_file_without_vector_db(
|
||||
uf: UserFile,
|
||||
documents: list[Document],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Process a user file when the vector DB is disabled.
|
||||
|
||||
Extracts raw text and computes a token count, stores the plaintext in
|
||||
the file store, and marks the file as COMPLETED. Skips embedding and
|
||||
the indexing pipeline entirely.
|
||||
"""
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.llm.factory import get_llm_tokenizer_encode_func
|
||||
|
||||
# Combine section text from all document sections
|
||||
combined_text = " ".join(
|
||||
section.text for doc in documents for section in doc.sections if section.text
|
||||
)
|
||||
|
||||
# Compute token count using the user's default LLM tokenizer
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
encode = get_llm_tokenizer_encode_func(llm)
|
||||
token_count: int | None = len(encode(combined_text))
|
||||
except Exception:
|
||||
task_logger.warning(
|
||||
f"_process_user_file_without_vector_db - "
|
||||
f"Failed to compute token count for {uf.id}, falling back to None"
|
||||
)
|
||||
token_count = None
|
||||
|
||||
# Persist plaintext for fast FileReaderTool loads
|
||||
store_user_file_plaintext(
|
||||
user_file_id=uf.id,
|
||||
plaintext_content=combined_text,
|
||||
)
|
||||
|
||||
# Update the DB record
|
||||
if uf.status != UserFileStatus.DELETING:
|
||||
uf.status = UserFileStatus.COMPLETED
|
||||
uf.token_count = token_count
|
||||
uf.chunk_count = 0 # no chunks without vector DB
|
||||
uf.last_project_sync_at = datetime.datetime.now(datetime.timezone.utc)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
|
||||
task_logger.info(
|
||||
f"_process_user_file_without_vector_db - "
|
||||
f"Completed id={uf.id} tokens={token_count}"
|
||||
)
|
||||
|
||||
|
||||
def _process_user_file_with_indexing(
|
||||
uf: UserFile,
|
||||
user_file_id: str,
|
||||
documents: list[Document],
|
||||
tenant_id: str,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Process a user file through the full indexing pipeline (vector DB path)."""
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
search_settings_list = get_active_search_settings_list(db_session)
|
||||
current_search_settings = next(
|
||||
(ss for ss in search_settings_list if ss.status.is_current()),
|
||||
None,
|
||||
)
|
||||
if current_search_settings is None:
|
||||
raise RuntimeError(
|
||||
f"_process_user_file_with_indexing - "
|
||||
f"No current search settings found for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=tenant_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=current_search_settings,
|
||||
)
|
||||
|
||||
document_indices = get_all_document_indices(
|
||||
current_search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
document_indices=document_indices,
|
||||
ignore_time_skip=True,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
document_batch=documents,
|
||||
request_id=None,
|
||||
adapter=adapter,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"_process_user_file_with_indexing - "
|
||||
f"Indexing pipeline completed ={index_pipeline_result}"
|
||||
)
|
||||
|
||||
if (
|
||||
index_pipeline_result.failures
|
||||
or index_pipeline_result.total_docs != len(documents)
|
||||
or index_pipeline_result.total_chunks == 0
|
||||
):
|
||||
task_logger.error(
|
||||
f"_process_user_file_with_indexing - "
|
||||
f"Indexing pipeline failed id={user_file_id}"
|
||||
)
|
||||
if uf.status != UserFileStatus.DELETING:
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
raise RuntimeError(f"Indexing pipeline failed for user file {user_file_id}")
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
@@ -304,6 +244,12 @@ def process_single_user_file(
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
@@ -334,34 +280,97 @@ def process_single_user_file(
|
||||
connector = LocalFileConnector(
|
||||
file_locations=[uf.file_id],
|
||||
file_names=[uf.name] if uf.name else None,
|
||||
zip_metadata={},
|
||||
)
|
||||
connector.load_credentials({})
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
search_settings_list = get_active_search_settings_list(db_session)
|
||||
|
||||
current_search_settings = next(
|
||||
(
|
||||
search_settings_instance
|
||||
for search_settings_instance in search_settings_list
|
||||
if search_settings_instance.status.is_current()
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if current_search_settings is None:
|
||||
raise RuntimeError(
|
||||
f"process_single_user_file - No current search settings found for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
for batch in connector.load_from_state():
|
||||
documents.extend(
|
||||
[doc for doc in batch if not isinstance(doc, HierarchyNode)]
|
||||
)
|
||||
|
||||
# update the document id to userfile id in the documents
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=tenant_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set up indexing pipeline components
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=current_search_settings,
|
||||
)
|
||||
|
||||
# This flow is for indexing so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
current_search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
# update the doument id to userfile id in the documents
|
||||
for document in documents:
|
||||
document.id = str(user_file_id)
|
||||
document.source = DocumentSource.USER_FILE
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
_process_user_file_without_vector_db(
|
||||
uf=uf,
|
||||
documents=documents,
|
||||
db_session=db_session,
|
||||
)
|
||||
else:
|
||||
_process_user_file_with_indexing(
|
||||
uf=uf,
|
||||
user_file_id=user_file_id,
|
||||
documents=documents,
|
||||
tenant_id=tenant_id,
|
||||
db_session=db_session,
|
||||
# real work happens here!
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
document_indices=document_indices,
|
||||
ignore_time_skip=True,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
document_batch=documents,
|
||||
request_id=None,
|
||||
adapter=adapter,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Indexing pipeline completed ={index_pipeline_result}"
|
||||
)
|
||||
|
||||
if (
|
||||
index_pipeline_result.failures
|
||||
or index_pipeline_result.total_docs != len(documents)
|
||||
or index_pipeline_result.total_chunks == 0
|
||||
):
|
||||
task_logger.error(
|
||||
f"process_single_user_file - Indexing pipeline failed id={user_file_id}"
|
||||
)
|
||||
# don't update the status if the user file is being deleted
|
||||
# Re-fetch to avoid mypy error
|
||||
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if (
|
||||
current_user_file
|
||||
and current_user_file.status != UserFileStatus.DELETING
|
||||
):
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
@@ -475,6 +484,28 @@ def process_single_user_file_delete(
|
||||
return None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
# This flow is for deletion so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
index_name = active_search_settings.primary.index_name
|
||||
selection = f"{index_name}.document_id=='{user_file_id}'"
|
||||
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
@@ -482,43 +513,22 @@ def process_single_user_file_delete(
|
||||
)
|
||||
return None
|
||||
|
||||
# 1) Delete vector DB chunks (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
# 1) Delete Vespa chunks for the document
|
||||
chunk_count = 0
|
||||
if user_file.chunk_count is None or user_file.chunk_count == 0:
|
||||
chunk_count = _get_document_chunk_count(
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
index_name = active_search_settings.primary.index_name
|
||||
selection = f"{index_name}.document_id=='{user_file_id}'"
|
||||
else:
|
||||
chunk_count = user_file.chunk_count
|
||||
|
||||
chunk_count = 0
|
||||
if user_file.chunk_count is None or user_file.chunk_count == 0:
|
||||
chunk_count = _get_document_chunk_count(
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
)
|
||||
else:
|
||||
chunk_count = user_file.chunk_count
|
||||
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.delete_single(
|
||||
doc_id=user_file_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.delete_single(
|
||||
doc_id=user_file_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
|
||||
file_store = get_default_file_store()
|
||||
@@ -630,6 +640,27 @@ def process_single_user_file_project_sync(
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
# This flow is for updates so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
@@ -637,35 +668,15 @@ def process_single_user_file_project_sync(
|
||||
)
|
||||
return None
|
||||
|
||||
# Sync project metadata to vector DB (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file id={user_file_id}"
|
||||
|
||||
@@ -677,6 +677,7 @@ def connector_document_extraction(
|
||||
logger.debug(f"Indexing batch of documents: {batch_description}")
|
||||
memory_tracer.increment_and_maybe_trace()
|
||||
|
||||
# cc4a
|
||||
if processing_mode == ProcessingMode.FILE_SYSTEM:
|
||||
# File system only - write directly to persistent storage,
|
||||
# skip chunking/embedding/Vespa but still track documents in DB
|
||||
@@ -816,19 +817,17 @@ def connector_document_extraction(
|
||||
if processing_mode == ProcessingMode.FILE_SYSTEM:
|
||||
creator_id = index_attempt.connector_credential_pair.creator_id
|
||||
if creator_id:
|
||||
source_value = db_connector.source.value
|
||||
app.send_task(
|
||||
OnyxCeleryTask.SANDBOX_FILE_SYNC,
|
||||
kwargs={
|
||||
"user_id": str(creator_id),
|
||||
"tenant_id": tenant_id,
|
||||
"source": source_value,
|
||||
},
|
||||
queue=OnyxCeleryQueues.SANDBOX,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered sandbox file sync for user {creator_id} "
|
||||
f"source={source_value} after indexing complete"
|
||||
f"after indexing complete"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -8,10 +8,8 @@ from fastapi.datastructures import Headers
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import is_user_admin
|
||||
from onyx.chat.models import ChatHistoryResult
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
@@ -504,22 +502,15 @@ def convert_chat_history(
|
||||
additional_context: str | None,
|
||||
token_counter: Callable[[str], int],
|
||||
tool_id_to_name_map: dict[int, str],
|
||||
) -> ChatHistoryResult:
|
||||
) -> list[ChatMessageSimple]:
|
||||
"""Convert ChatMessage history to ChatMessageSimple format.
|
||||
|
||||
For user messages: includes attached files (images attached to message, text files as separate messages)
|
||||
For assistant messages with tool calls: creates ONE ASSISTANT message with tool_calls array,
|
||||
followed by N TOOL_CALL_RESPONSE messages (OpenAI parallel tool calling format)
|
||||
For assistant messages without tool calls: creates a simple ASSISTANT message
|
||||
|
||||
Every injected text-file message is tagged with ``file_id`` and its
|
||||
metadata is collected in ``ChatHistoryResult.all_injected_file_metadata``.
|
||||
After context-window truncation, callers compare surviving ``file_id`` tags
|
||||
against this map to discover "forgotten" files and provide their metadata
|
||||
to the FileReaderTool.
|
||||
"""
|
||||
simple_messages: list[ChatMessageSimple] = []
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata] = {}
|
||||
|
||||
# Create a mapping of file IDs to loaded files for quick lookup
|
||||
file_map = {str(f.file_id): f for f in files}
|
||||
@@ -548,9 +539,7 @@ def convert_chat_history(
|
||||
# Text files (DOC, PLAIN_TEXT, CSV) are added as separate messages
|
||||
text_files.append(loaded_file)
|
||||
|
||||
# Add text files as separate messages before the user message.
|
||||
# Each message is tagged with ``file_id`` so that forgotten files
|
||||
# can be detected after context-window truncation.
|
||||
# Add text files as separate messages before the user message
|
||||
for text_file in text_files:
|
||||
file_text = text_file.content_text or ""
|
||||
filename = text_file.filename
|
||||
@@ -565,14 +554,8 @@ def convert_chat_history(
|
||||
token_count=text_file.token_count,
|
||||
message_type=MessageType.USER,
|
||||
image_files=None,
|
||||
file_id=text_file.file_id,
|
||||
)
|
||||
)
|
||||
all_injected_file_metadata[text_file.file_id] = FileToolMetadata(
|
||||
file_id=text_file.file_id,
|
||||
filename=filename or "unknown",
|
||||
approx_char_count=len(file_text),
|
||||
)
|
||||
|
||||
# Sum token counts from image files (excluding project image files)
|
||||
image_token_count = (
|
||||
@@ -681,10 +664,7 @@ def convert_chat_history(
|
||||
f"Invalid message type when constructing simple history: {chat_message.message_type}"
|
||||
)
|
||||
|
||||
return ChatHistoryResult(
|
||||
simple_messages=simple_messages,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
return simple_messages
|
||||
|
||||
|
||||
def get_custom_agent_prompt(persona: Persona, chat_session: ChatSession) -> str | None:
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -16,7 +14,6 @@ from onyx.chat.llm_step import extract_tool_calls_from_response_text
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
@@ -30,9 +27,6 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.db.memory import add_memory
|
||||
from onyx.db.memory import update_memory_at_index
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.db.models import Persona
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -49,14 +43,12 @@ from onyx.server.query_and_chat.streaming_models import TopLevelBranching
|
||||
from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
|
||||
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import MemoryToolResponseSnapshot
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool_implementations.images.models import (
|
||||
FinalImageGenerationResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
@@ -209,35 +201,6 @@ def _build_project_file_citation_mapping(
|
||||
return citation_mapping
|
||||
|
||||
|
||||
def _build_project_message(
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
token_counter: Callable[[str], int] | None,
|
||||
) -> list[ChatMessageSimple]:
|
||||
"""Build messages for project / tool-backed files.
|
||||
|
||||
Returns up to two messages:
|
||||
1. The full-text project files message (if project_file_texts is populated).
|
||||
2. A lightweight metadata message for files the LLM should access via the
|
||||
FileReaderTool (e.g. oversized chat-attached files or project files that
|
||||
don't fit in context).
|
||||
"""
|
||||
if not project_files:
|
||||
return []
|
||||
|
||||
messages: list[ChatMessageSimple] = []
|
||||
if project_files.project_file_texts:
|
||||
messages.append(
|
||||
_create_project_files_message(project_files, token_counter=None)
|
||||
)
|
||||
if project_files.file_metadata_for_tool and token_counter:
|
||||
messages.append(
|
||||
_create_file_tool_metadata_message(
|
||||
project_files.file_metadata_for_tool, token_counter
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
def construct_message_history(
|
||||
system_prompt: ChatMessageSimple | None,
|
||||
custom_agent_prompt: ChatMessageSimple | None,
|
||||
@@ -246,8 +209,6 @@ def construct_message_history(
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
available_tokens: int,
|
||||
last_n_user_messages: int | None = None,
|
||||
token_counter: Callable[[str], int] | None = None,
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata] | None = None,
|
||||
) -> list[ChatMessageSimple]:
|
||||
if last_n_user_messages is not None:
|
||||
if last_n_user_messages <= 0:
|
||||
@@ -255,17 +216,13 @@ def construct_message_history(
|
||||
"filtering chat history by last N user messages must be a value greater than 0"
|
||||
)
|
||||
|
||||
# Build the project / file-metadata messages up front so we can use their
|
||||
# actual token counts for the budget.
|
||||
project_messages = _build_project_message(project_files, token_counter)
|
||||
project_messages_tokens = sum(m.token_count for m in project_messages)
|
||||
|
||||
history_token_budget = available_tokens
|
||||
history_token_budget -= system_prompt.token_count if system_prompt else 0
|
||||
history_token_budget -= (
|
||||
custom_agent_prompt.token_count if custom_agent_prompt else 0
|
||||
)
|
||||
history_token_budget -= project_messages_tokens
|
||||
if project_files:
|
||||
history_token_budget -= project_files.total_token_count
|
||||
history_token_budget -= reminder_message.token_count if reminder_message else 0
|
||||
|
||||
if history_token_budget < 0:
|
||||
@@ -279,7 +236,11 @@ def construct_message_history(
|
||||
result = [system_prompt] if system_prompt else []
|
||||
if custom_agent_prompt:
|
||||
result.append(custom_agent_prompt)
|
||||
result.extend(project_messages)
|
||||
if project_files and project_files.project_file_texts:
|
||||
project_message = _create_project_files_message(
|
||||
project_files, token_counter=None
|
||||
)
|
||||
result.append(project_message)
|
||||
if reminder_message:
|
||||
result.append(reminder_message)
|
||||
return result
|
||||
@@ -338,11 +299,8 @@ def construct_message_history(
|
||||
# Calculate remaining budget for history before the last user message
|
||||
remaining_budget = history_token_budget - required_tokens
|
||||
|
||||
# Truncate history_before_last_user from the top to fit in remaining budget.
|
||||
# Track dropped file messages so we can provide their metadata to the
|
||||
# FileReaderTool instead.
|
||||
# Truncate history_before_last_user from the top to fit in remaining budget
|
||||
truncated_history_before: list[ChatMessageSimple] = []
|
||||
dropped_file_ids: list[str] = []
|
||||
current_token_count = 0
|
||||
|
||||
for msg in reversed(history_before_last_user):
|
||||
@@ -351,67 +309,9 @@ def construct_message_history(
|
||||
truncated_history_before.insert(0, msg)
|
||||
current_token_count += msg.token_count
|
||||
else:
|
||||
# Can't fit this message, stop truncating.
|
||||
# This message and everything older is dropped.
|
||||
# Can't fit this message, stop truncating
|
||||
break
|
||||
|
||||
# Collect file_ids from ALL dropped messages (those not in
|
||||
# truncated_history_before). The truncation loop above keeps the most
|
||||
# recent messages, so the dropped ones are at the start of the original
|
||||
# list up to (len(history) - len(kept)).
|
||||
num_kept = len(truncated_history_before)
|
||||
for msg in history_before_last_user[: len(history_before_last_user) - num_kept]:
|
||||
if msg.file_id is not None:
|
||||
dropped_file_ids.append(msg.file_id)
|
||||
|
||||
# Also treat "orphaned" metadata entries as dropped -- these are files
|
||||
# from messages removed by summary truncation (before convert_chat_history
|
||||
# ran), so no ChatMessageSimple was ever tagged with their file_id.
|
||||
if all_injected_file_metadata:
|
||||
surviving_file_ids = {
|
||||
msg.file_id for msg in simple_chat_history if msg.file_id is not None
|
||||
}
|
||||
for fid in all_injected_file_metadata:
|
||||
if fid not in surviving_file_ids and fid not in dropped_file_ids:
|
||||
dropped_file_ids.append(fid)
|
||||
|
||||
# Build a forgotten-files metadata message if any file messages were
|
||||
# dropped AND we have metadata for them (meaning the FileReaderTool is
|
||||
# available). Reserve tokens for this message in the budget.
|
||||
forgotten_files_message: ChatMessageSimple | None = None
|
||||
if dropped_file_ids and all_injected_file_metadata and token_counter:
|
||||
forgotten_meta = [
|
||||
all_injected_file_metadata[fid]
|
||||
for fid in dropped_file_ids
|
||||
if fid in all_injected_file_metadata
|
||||
]
|
||||
if forgotten_meta:
|
||||
logger.debug(
|
||||
f"FileReader: building forgotten-files message for "
|
||||
f"{[(m.file_id, m.filename) for m in forgotten_meta]}"
|
||||
)
|
||||
forgotten_files_message = _create_file_tool_metadata_message(
|
||||
forgotten_meta, token_counter
|
||||
)
|
||||
# Shrink the remaining budget. If the metadata message doesn't
|
||||
# fit we may need to drop more history messages.
|
||||
remaining_budget -= forgotten_files_message.token_count
|
||||
while truncated_history_before and current_token_count > remaining_budget:
|
||||
evicted = truncated_history_before.pop(0)
|
||||
current_token_count -= evicted.token_count
|
||||
# If the evicted message is itself a file, add it to the
|
||||
# forgotten metadata (it's now dropped too).
|
||||
if (
|
||||
evicted.file_id is not None
|
||||
and evicted.file_id in all_injected_file_metadata
|
||||
and evicted.file_id not in {m.file_id for m in forgotten_meta}
|
||||
):
|
||||
forgotten_meta.append(all_injected_file_metadata[evicted.file_id])
|
||||
# Rebuild the message with the new entry
|
||||
forgotten_files_message = _create_file_tool_metadata_message(
|
||||
forgotten_meta, token_counter
|
||||
)
|
||||
|
||||
# Attach project images to the last user message
|
||||
if project_files and project_files.project_image_files:
|
||||
existing_images = last_user_message.image_files or []
|
||||
@@ -424,7 +324,7 @@ def construct_message_history(
|
||||
|
||||
# Build the final message list according to README ordering:
|
||||
# [system], [history_before_last_user], [custom_agent], [project_files],
|
||||
# [forgotten_files], [last_user_message], [messages_after_last_user], [reminder]
|
||||
# [last_user_message], [messages_after_last_user], [reminder]
|
||||
result = [system_prompt] if system_prompt else []
|
||||
|
||||
# 1. Add truncated history before last user message
|
||||
@@ -434,52 +334,26 @@ def construct_message_history(
|
||||
if custom_agent_prompt:
|
||||
result.append(custom_agent_prompt)
|
||||
|
||||
# 3. Add project files / file-metadata messages (inserted before last user message)
|
||||
result.extend(project_messages)
|
||||
# 3. Add project files message (inserted before last user message)
|
||||
if project_files and project_files.project_file_texts:
|
||||
project_message = _create_project_files_message(
|
||||
project_files, token_counter=None
|
||||
)
|
||||
result.append(project_message)
|
||||
|
||||
# 4. Add forgotten-files metadata (right before the user's question)
|
||||
if forgotten_files_message:
|
||||
result.append(forgotten_files_message)
|
||||
|
||||
# 5. Add last user message (with project images attached)
|
||||
# 4. Add last user message (with project images attached)
|
||||
result.append(last_user_message)
|
||||
|
||||
# 6. Add messages after last user message (tool calls, responses, etc.)
|
||||
# 5. Add messages after last user message (tool calls, responses, etc.)
|
||||
result.extend(messages_after_last_user)
|
||||
|
||||
# 7. Add reminder message at the very end
|
||||
# 6. Add reminder message at the very end
|
||||
if reminder_message:
|
||||
result.append(reminder_message)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _create_file_tool_metadata_message(
|
||||
file_metadata: list[FileToolMetadata],
|
||||
token_counter: Callable[[str], int],
|
||||
) -> ChatMessageSimple:
|
||||
"""Build a lightweight metadata-only message listing files available via FileReaderTool.
|
||||
|
||||
Used when files are too large to fit in context and the vector DB is
|
||||
disabled, so the LLM must use ``read_file`` to inspect them.
|
||||
"""
|
||||
lines = [
|
||||
"You have access to the following files. Use the read_file tool to "
|
||||
"read sections of any file:"
|
||||
]
|
||||
for meta in file_metadata:
|
||||
lines.append(
|
||||
f'- {meta.file_id}: "{meta.filename}" (~{meta.approx_char_count:,} chars)'
|
||||
)
|
||||
|
||||
message_content = "\n".join(lines)
|
||||
return ChatMessageSimple(
|
||||
message=message_content,
|
||||
token_count=token_counter(message_content),
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
|
||||
|
||||
def _create_project_files_message(
|
||||
project_files: ExtractedProjectFiles,
|
||||
token_counter: Callable[[str], int] | None, # noqa: ARG001
|
||||
@@ -519,7 +393,7 @@ def run_llm_loop(
|
||||
custom_agent_prompt: str | None,
|
||||
project_files: ExtractedProjectFiles,
|
||||
persona: Persona | None,
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
memories: list[str] | None,
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
@@ -527,8 +401,6 @@ def run_llm_loop(
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
include_citations: bool = True,
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata] | None = None,
|
||||
inject_memories_in_prompt: bool = True,
|
||||
) -> None:
|
||||
with trace(
|
||||
"run_llm_loop",
|
||||
@@ -636,13 +508,10 @@ def run_llm_loop(
|
||||
llm.config.model_name
|
||||
)
|
||||
|
||||
prompt_memory_context = (
|
||||
user_memory_context if inject_memories_in_prompt else None
|
||||
)
|
||||
system_prompt_str = build_system_prompt(
|
||||
base_system_prompt=default_base_system_prompt,
|
||||
datetime_aware=persona.datetime_aware if persona else True,
|
||||
user_memory_context=prompt_memory_context,
|
||||
memories=memories,
|
||||
tools=tools,
|
||||
should_cite_documents=should_cite_documents
|
||||
or always_cite_documents,
|
||||
@@ -699,7 +568,7 @@ def run_llm_loop(
|
||||
ChatMessageSimple(
|
||||
message=reminder_message_text,
|
||||
token_count=token_counter(reminder_message_text),
|
||||
message_type=MessageType.USER_REMINDER,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
if reminder_message_text
|
||||
else None
|
||||
@@ -712,8 +581,6 @@ def run_llm_loop(
|
||||
reminder_message=reminder_msg,
|
||||
project_files=project_files,
|
||||
available_tokens=available_tokens,
|
||||
token_counter=token_counter,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
|
||||
# This calls the LLM, yields packets (reasoning, answers, etc.) and returns the result
|
||||
@@ -798,14 +665,13 @@ def run_llm_loop(
|
||||
tool_calls=tool_calls,
|
||||
tools=final_tools,
|
||||
message_history=truncated_message_history,
|
||||
user_memory_context=user_memory_context,
|
||||
memories=memories,
|
||||
user_info=None, # TODO, this is part of memories right now, might want to separate it out
|
||||
citation_mapping=citation_mapping,
|
||||
next_citation_num=citation_processor.get_next_citation_number(),
|
||||
max_concurrent_tools=None,
|
||||
skip_search_query_expansion=has_called_search_tool,
|
||||
url_snippet_map=extract_url_snippet_map(gathered_documents or []),
|
||||
inject_memories_in_prompt=inject_memories_in_prompt,
|
||||
)
|
||||
tool_responses = parallel_tool_call_results.tool_responses
|
||||
citation_mapping = parallel_tool_call_results.updated_citation_mapping
|
||||
@@ -870,44 +736,11 @@ def run_llm_loop(
|
||||
):
|
||||
generated_images = tool_response.rich_response.generated_images
|
||||
|
||||
# Persist memory if this is a memory tool response
|
||||
memory_snapshot: MemoryToolResponseSnapshot | None = None
|
||||
if isinstance(tool_response.rich_response, MemoryToolResponse):
|
||||
persisted_memory_id: int | None = None
|
||||
if user_memory_context and user_memory_context.user_id:
|
||||
if tool_response.rich_response.index_to_replace is not None:
|
||||
memory = update_memory_at_index(
|
||||
user_id=user_memory_context.user_id,
|
||||
index=tool_response.rich_response.index_to_replace,
|
||||
new_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id if memory else None
|
||||
else:
|
||||
memory = add_memory(
|
||||
user_id=user_memory_context.user_id,
|
||||
memory_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id
|
||||
operation: Literal["add", "update"] = (
|
||||
"update"
|
||||
if tool_response.rich_response.index_to_replace is not None
|
||||
else "add"
|
||||
)
|
||||
memory_snapshot = MemoryToolResponseSnapshot(
|
||||
memory_text=tool_response.rich_response.memory_text,
|
||||
operation=operation,
|
||||
memory_id=persisted_memory_id,
|
||||
index=tool_response.rich_response.index_to_replace,
|
||||
)
|
||||
|
||||
if memory_snapshot:
|
||||
saved_response = json.dumps(memory_snapshot.model_dump())
|
||||
elif isinstance(tool_response.rich_response, str):
|
||||
saved_response = tool_response.rich_response
|
||||
else:
|
||||
saved_response = tool_response.llm_facing_response
|
||||
saved_response = (
|
||||
tool_response.rich_response
|
||||
if isinstance(tool_response.rich_response, str)
|
||||
else tool_response.llm_facing_response
|
||||
)
|
||||
|
||||
tool_call_info = ToolCallInfo(
|
||||
parent_tool_call_id=None, # Top-level tool calls are attached to the chat message
|
||||
|
||||
@@ -36,8 +36,6 @@ from onyx.llm.models import ToolCall
|
||||
from onyx.llm.models import ToolMessage
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.prompt_cache.processor import process_with_prompt_cache
|
||||
from onyx.prompts.constants import SYSTEM_REMINDER_TAG_CLOSE
|
||||
from onyx.prompts.constants import SYSTEM_REMINDER_TAG_OPEN
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
@@ -334,48 +332,26 @@ def extract_tool_calls_from_response_text(
|
||||
# Find all JSON objects in the response text
|
||||
json_objects = find_all_json_objects(response_text)
|
||||
|
||||
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
|
||||
prev_json_obj: dict[str, Any] | None = None
|
||||
prev_tool_call: tuple[str, dict[str, Any]] | None = None
|
||||
tool_calls: list[ToolCallKickoff] = []
|
||||
tab_index = 0
|
||||
|
||||
for json_obj in json_objects:
|
||||
matched_tool_call = _try_match_json_to_tool(json_obj, tool_name_to_def)
|
||||
if not matched_tool_call:
|
||||
continue
|
||||
|
||||
# `find_all_json_objects` can return both an outer tool-call object and
|
||||
# its nested arguments object. If both resolve to the same tool call,
|
||||
# drop only this nested duplicate artifact.
|
||||
if (
|
||||
prev_json_obj is not None
|
||||
and prev_tool_call is not None
|
||||
and matched_tool_call == prev_tool_call
|
||||
and _is_nested_arguments_duplicate(
|
||||
previous_json_obj=prev_json_obj,
|
||||
current_json_obj=json_obj,
|
||||
tool_name_to_def=tool_name_to_def,
|
||||
if matched_tool_call:
|
||||
tool_name, tool_args = matched_tool_call
|
||||
tool_calls.append(
|
||||
ToolCallKickoff(
|
||||
tool_call_id=f"extracted_{uuid.uuid4().hex[:8]}",
|
||||
tool_name=tool_name,
|
||||
tool_args=tool_args,
|
||||
placement=Placement(
|
||||
turn_index=placement.turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=placement.sub_turn_index,
|
||||
),
|
||||
)
|
||||
)
|
||||
):
|
||||
continue
|
||||
|
||||
matched_tool_calls.append(matched_tool_call)
|
||||
prev_json_obj = json_obj
|
||||
prev_tool_call = matched_tool_call
|
||||
|
||||
tool_calls: list[ToolCallKickoff] = []
|
||||
for tab_index, (tool_name, tool_args) in enumerate(matched_tool_calls):
|
||||
tool_calls.append(
|
||||
ToolCallKickoff(
|
||||
tool_call_id=f"extracted_{uuid.uuid4().hex[:8]}",
|
||||
tool_name=tool_name,
|
||||
tool_args=tool_args,
|
||||
placement=Placement(
|
||||
turn_index=placement.turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=placement.sub_turn_index,
|
||||
),
|
||||
)
|
||||
)
|
||||
tab_index += 1
|
||||
|
||||
logger.info(
|
||||
f"Extracted {len(tool_calls)} tool call(s) from response text as fallback"
|
||||
@@ -457,42 +433,6 @@ def _try_match_json_to_tool(
|
||||
return None
|
||||
|
||||
|
||||
def _is_nested_arguments_duplicate(
|
||||
previous_json_obj: dict[str, Any],
|
||||
current_json_obj: dict[str, Any],
|
||||
tool_name_to_def: dict[str, dict],
|
||||
) -> bool:
|
||||
"""Detect when current object is the nested args object from previous tool call."""
|
||||
extracted_args = _extract_nested_arguments_obj(previous_json_obj, tool_name_to_def)
|
||||
return extracted_args is not None and current_json_obj == extracted_args
|
||||
|
||||
|
||||
def _extract_nested_arguments_obj(
|
||||
json_obj: dict[str, Any],
|
||||
tool_name_to_def: dict[str, dict],
|
||||
) -> dict[str, Any] | None:
|
||||
# Format 1: {"name": "...", "arguments": {...}} or {"name": "...", "parameters": {...}}
|
||||
if "name" in json_obj and json_obj["name"] in tool_name_to_def:
|
||||
args_obj = json_obj.get("arguments", json_obj.get("parameters"))
|
||||
if isinstance(args_obj, dict):
|
||||
return args_obj
|
||||
|
||||
# Format 2: {"function": {"name": "...", "arguments": {...}}}
|
||||
if "function" in json_obj and isinstance(json_obj["function"], dict):
|
||||
function_obj = json_obj["function"]
|
||||
if "name" in function_obj and function_obj["name"] in tool_name_to_def:
|
||||
args_obj = function_obj.get("arguments", function_obj.get("parameters"))
|
||||
if isinstance(args_obj, dict):
|
||||
return args_obj
|
||||
|
||||
# Format 3: {"tool_name": {...arguments...}}
|
||||
for tool_name in tool_name_to_def:
|
||||
if tool_name in json_obj and isinstance(json_obj[tool_name], dict):
|
||||
return json_obj[tool_name]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def translate_history_to_llm_format(
|
||||
history: list[ChatMessageSimple],
|
||||
llm_config: LLMConfig,
|
||||
@@ -511,7 +451,6 @@ def translate_history_to_llm_format(
|
||||
if PROMPT_CACHE_CHAT_HISTORY and msg.message_type in [
|
||||
MessageType.SYSTEM,
|
||||
MessageType.USER,
|
||||
MessageType.USER_REMINDER,
|
||||
MessageType.ASSISTANT,
|
||||
MessageType.TOOL_CALL_RESPONSE,
|
||||
]:
|
||||
@@ -573,16 +512,6 @@ def translate_history_to_llm_format(
|
||||
)
|
||||
messages.append(user_msg_text)
|
||||
|
||||
elif msg.message_type == MessageType.USER_REMINDER:
|
||||
# User reminder messages are wrapped with system-reminder tags
|
||||
# and converted to UserMessage (LLM APIs don't have a native reminder type)
|
||||
wrapped_content = f"{SYSTEM_REMINDER_TAG_OPEN}\n{msg.message}\n{SYSTEM_REMINDER_TAG_CLOSE}"
|
||||
reminder_msg = UserMessage(
|
||||
role="user",
|
||||
content=wrapped_content,
|
||||
)
|
||||
messages.append(reminder_msg)
|
||||
|
||||
elif msg.message_type == MessageType.ASSISTANT:
|
||||
tool_calls_list: list[ToolCall] | None = None
|
||||
if msg.tool_calls:
|
||||
|
||||
@@ -244,9 +244,6 @@ class ChatMessageSimple(BaseModel):
|
||||
# represents the end of the cacheable prefix
|
||||
# used for prompt caching
|
||||
should_cache: bool = False
|
||||
# When this message represents an injected text file, this is the file's ID.
|
||||
# Used to detect which file messages survive context-window truncation.
|
||||
file_id: str | None = None
|
||||
|
||||
|
||||
class ProjectFileMetadata(BaseModel):
|
||||
@@ -257,33 +254,6 @@ class ProjectFileMetadata(BaseModel):
|
||||
file_content: str
|
||||
|
||||
|
||||
class FileToolMetadata(BaseModel):
|
||||
"""Lightweight metadata for exposing files to the FileReaderTool.
|
||||
|
||||
Used when files cannot be loaded directly into context (project too large
|
||||
or persona-attached user_files without direct-load path). The LLM receives
|
||||
a listing of these so it knows which files it can read via ``read_file``.
|
||||
"""
|
||||
|
||||
file_id: str
|
||||
filename: str
|
||||
approx_char_count: int
|
||||
|
||||
|
||||
class ChatHistoryResult(BaseModel):
|
||||
"""Result of converting chat history to simple format.
|
||||
|
||||
Bundles the simple messages with metadata for every text file that was
|
||||
injected into the history. After context-window truncation drops older
|
||||
messages, callers compare surviving ``file_id`` tags against this map
|
||||
to discover "forgotten" files whose metadata should be provided to the
|
||||
FileReaderTool.
|
||||
"""
|
||||
|
||||
simple_messages: list[ChatMessageSimple]
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata]
|
||||
|
||||
|
||||
class ExtractedProjectFiles(BaseModel):
|
||||
project_file_texts: list[str]
|
||||
project_image_files: list[ChatLoadedFile]
|
||||
@@ -293,9 +263,6 @@ class ExtractedProjectFiles(BaseModel):
|
||||
project_file_metadata: list[ProjectFileMetadata]
|
||||
# None if not a project
|
||||
project_uncapped_token_count: int | None
|
||||
# Lightweight metadata for files exposed via FileReaderTool
|
||||
# (populated when files don't fit in context and vector DB is disabled)
|
||||
file_metadata_for_tool: list[FileToolMetadata] = []
|
||||
|
||||
|
||||
class LlmStepResult(BaseModel):
|
||||
|
||||
@@ -10,7 +10,6 @@ from collections.abc import Callable
|
||||
from contextvars import Token
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -36,7 +35,6 @@ from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ProjectSearchConfig
|
||||
@@ -46,7 +44,6 @@ from onyx.chat.prompt_utils import calculate_reserved_tokens
|
||||
from onyx.chat.save_chat import save_chat_turn
|
||||
from onyx.chat.stop_signal_checker import is_connected as check_stop_signal
|
||||
from onyx.chat.stop_signal_checker import reset_cancel_status
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -63,7 +60,6 @@ from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.tools import get_tools
|
||||
@@ -94,11 +90,7 @@ from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
from onyx.tools.tool_constructor import CustomToolConfig
|
||||
from onyx.tools.tool_constructor import FileReaderToolConfig
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import (
|
||||
FileReaderTool,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
@@ -108,53 +100,6 @@ logger = setup_logger()
|
||||
ERROR_TYPE_CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class _AvailableFiles(BaseModel):
|
||||
"""Separated file IDs for the FileReaderTool so it knows which loader to use."""
|
||||
|
||||
# IDs from the ``user_file`` table (project / persona-attached files).
|
||||
user_file_ids: list[UUID] = []
|
||||
# IDs from the ``file_record`` table (chat-attached files).
|
||||
chat_file_ids: list[UUID] = []
|
||||
|
||||
|
||||
def _collect_available_file_ids(
|
||||
chat_history: list[ChatMessage],
|
||||
project_id: int | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> _AvailableFiles:
|
||||
"""Collect all file IDs the FileReaderTool should be allowed to access.
|
||||
|
||||
Returns *separate* lists for chat-attached files (``file_record`` IDs) and
|
||||
project/user files (``user_file`` IDs) so the tool can pick the right
|
||||
loader without a try/except fallback."""
|
||||
chat_file_ids: set[UUID] = set()
|
||||
user_file_ids: set[UUID] = set()
|
||||
|
||||
for msg in chat_history:
|
||||
if not msg.files:
|
||||
continue
|
||||
for fd in msg.files:
|
||||
try:
|
||||
chat_file_ids.add(UUID(fd["id"]))
|
||||
except (ValueError, KeyError):
|
||||
pass
|
||||
|
||||
if project_id:
|
||||
project_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
for uf in project_files:
|
||||
user_file_ids.add(uf.id)
|
||||
|
||||
return _AvailableFiles(
|
||||
user_file_ids=list(user_file_ids),
|
||||
chat_file_ids=list(chat_file_ids),
|
||||
)
|
||||
|
||||
|
||||
def _should_enable_slack_search(
|
||||
persona: Persona,
|
||||
filters: BaseFilters | None,
|
||||
@@ -287,24 +232,6 @@ def _extract_project_file_texts_and_images(
|
||||
)
|
||||
project_image_files.append(chat_loaded_file)
|
||||
else:
|
||||
if DISABLE_VECTOR_DB:
|
||||
# Without a vector DB we can't use project-as-filter search.
|
||||
# Instead, build lightweight metadata so the LLM can call the
|
||||
# FileReaderTool to inspect individual files on demand.
|
||||
file_metadata_for_tool = _build_file_tool_metadata_for_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=[],
|
||||
project_image_files=[],
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=project_tokens,
|
||||
file_metadata_for_tool=file_metadata_for_tool,
|
||||
)
|
||||
project_as_filter = True
|
||||
|
||||
return ExtractedProjectFiles(
|
||||
@@ -317,49 +244,6 @@ def _extract_project_file_texts_and_images(
|
||||
)
|
||||
|
||||
|
||||
APPROX_CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_project(
|
||||
project_id: int,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[FileToolMetadata]:
|
||||
"""Build lightweight FileToolMetadata for every file in a project.
|
||||
|
||||
Used when files are too large to fit in context and the vector DB is
|
||||
disabled, so the LLM needs to know which files it can read via the
|
||||
FileReaderTool.
|
||||
"""
|
||||
project_user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return [
|
||||
FileToolMetadata(
|
||||
file_id=str(uf.id),
|
||||
filename=uf.name,
|
||||
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
|
||||
)
|
||||
for uf in project_user_files
|
||||
]
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_user_files(
|
||||
user_files: list[UserFile],
|
||||
) -> list[FileToolMetadata]:
|
||||
"""Build lightweight FileToolMetadata from a list of UserFile records."""
|
||||
return [
|
||||
FileToolMetadata(
|
||||
file_id=str(uf.id),
|
||||
filename=uf.name,
|
||||
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
|
||||
)
|
||||
for uf in user_files
|
||||
]
|
||||
|
||||
|
||||
def _get_project_search_availability(
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
@@ -446,10 +330,12 @@ def handle_stream_message_objects(
|
||||
else:
|
||||
llm_user_identifier = user.email or str(user_id)
|
||||
|
||||
if new_msg_req.mock_llm_response is not None and not INTEGRATION_TESTS_MODE:
|
||||
raise ValueError(
|
||||
"mock_llm_response can only be used when INTEGRATION_TESTS_MODE=true"
|
||||
)
|
||||
if new_msg_req.mock_llm_response is not None:
|
||||
if not INTEGRATION_TESTS_MODE:
|
||||
raise ValueError(
|
||||
"mock_llm_response can only be used when INTEGRATION_TESTS_MODE=true"
|
||||
)
|
||||
mock_response_token = set_llm_mock_response(new_msg_req.mock_llm_response)
|
||||
|
||||
try:
|
||||
if not new_msg_req.chat_session_id:
|
||||
@@ -577,57 +463,24 @@ def handle_stream_message_objects(
|
||||
|
||||
chat_history.append(user_message)
|
||||
|
||||
# Collect file IDs for the file reader tool *before* summary
|
||||
# truncation so that files attached to older (summarized-away)
|
||||
# messages are still accessible via the FileReaderTool.
|
||||
available_files = _collect_available_file_ids(
|
||||
chat_history=chat_history,
|
||||
project_id=chat_session.project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Find applicable summary for the current branch
|
||||
# Summary applies if its parent_message_id is in current chat_history
|
||||
summary_message = find_summary_for_branch(db_session, chat_history)
|
||||
# Collect file metadata from messages that will be dropped by
|
||||
# summary truncation. These become "pre-summarized" file metadata
|
||||
# so the forgotten-file mechanism can still tell the LLM about them.
|
||||
summarized_file_metadata: dict[str, FileToolMetadata] = {}
|
||||
if summary_message and summary_message.last_summarized_message_id:
|
||||
cutoff_id = summary_message.last_summarized_message_id
|
||||
for msg in chat_history:
|
||||
if msg.id > cutoff_id or not msg.files:
|
||||
continue
|
||||
for fd in msg.files:
|
||||
file_id = fd.get("id")
|
||||
if not file_id:
|
||||
continue
|
||||
summarized_file_metadata[file_id] = FileToolMetadata(
|
||||
file_id=file_id,
|
||||
filename=fd.get("name") or "unknown",
|
||||
# We don't know the exact size without loading the
|
||||
# file, but 0 signals "unknown" to the LLM.
|
||||
approx_char_count=0,
|
||||
)
|
||||
# Filter chat_history to only messages after the cutoff
|
||||
chat_history = [m for m in chat_history if m.id > cutoff_id]
|
||||
|
||||
user_memory_context = get_memories(user, db_session)
|
||||
memories = get_memories(user, db_session)
|
||||
|
||||
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
# When use_memories is disabled, don't inject memories into the prompt
|
||||
# or count them in token reservation, but still pass the full context
|
||||
# to the LLM loop for memory tool persistence.
|
||||
prompt_memory_context = user_memory_context if user.use_memories else None
|
||||
|
||||
reserved_token_count = calculate_reserved_tokens(
|
||||
db_session=db_session,
|
||||
persona_system_prompt=custom_agent_prompt or "",
|
||||
token_counter=token_counter,
|
||||
files=new_msg_req.file_descriptors,
|
||||
user_memory_context=prompt_memory_context,
|
||||
memories=memories,
|
||||
)
|
||||
|
||||
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
|
||||
@@ -639,16 +492,6 @@ def handle_stream_message_objects(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# When the vector DB is disabled, persona-attached user_files have no
|
||||
# search pipeline path. Inject them as file_metadata_for_tool so the
|
||||
# LLM can read them via the FileReaderTool.
|
||||
if DISABLE_VECTOR_DB and persona.user_files:
|
||||
persona_file_metadata = _build_file_tool_metadata_for_user_files(
|
||||
persona.user_files
|
||||
)
|
||||
# Merge persona file metadata into the extracted project files
|
||||
extracted_project_files.file_metadata_for_tool.extend(persona_file_metadata)
|
||||
|
||||
# Build a mapping of tool_id to tool_name for history reconstruction
|
||||
all_tools = get_tools(db_session)
|
||||
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
|
||||
@@ -675,13 +518,6 @@ def handle_stream_message_objects(
|
||||
|
||||
emitter = get_default_emitter()
|
||||
|
||||
# Also grant access to persona-attached user files
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
# Construct tools based on the persona configurations
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
@@ -708,10 +544,6 @@ def handle_stream_message_objects(
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
mcp_headers=mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=available_files.user_file_ids,
|
||||
chat_file_ids=available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=project_search_config.search_usage,
|
||||
)
|
||||
@@ -741,12 +573,9 @@ def handle_stream_message_objects(
|
||||
reserved_assistant_message_id=assistant_response.id,
|
||||
)
|
||||
|
||||
# Check whether the FileReaderTool is among the constructed tools.
|
||||
has_file_reader_tool = any(isinstance(t, FileReaderTool) for t in tools)
|
||||
|
||||
# Convert the chat history into a simple format that is free of any DB objects
|
||||
# and is easy to parse for the agent loop
|
||||
chat_history_result = convert_chat_history(
|
||||
simple_chat_history = convert_chat_history(
|
||||
chat_history=chat_history,
|
||||
files=files,
|
||||
project_image_files=extracted_project_files.project_image_files,
|
||||
@@ -754,32 +583,6 @@ def handle_stream_message_objects(
|
||||
token_counter=token_counter,
|
||||
tool_id_to_name_map=tool_id_to_name_map,
|
||||
)
|
||||
simple_chat_history = chat_history_result.simple_messages
|
||||
|
||||
# Metadata for every text file injected into the history. After
|
||||
# context-window truncation drops older messages, the LLM loop
|
||||
# compares surviving file_id tags against this map to discover
|
||||
# "forgotten" files and provide their metadata to FileReaderTool.
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata] = (
|
||||
chat_history_result.all_injected_file_metadata
|
||||
if has_file_reader_tool
|
||||
else {}
|
||||
)
|
||||
|
||||
# Merge in file metadata from messages dropped by summary
|
||||
# truncation. These files are no longer in simple_chat_history
|
||||
# so they would otherwise be invisible to the forgotten-file
|
||||
# mechanism. They will always appear as "forgotten" since no
|
||||
# surviving message carries their file_id tag.
|
||||
if summarized_file_metadata:
|
||||
for fid, meta in summarized_file_metadata.items():
|
||||
all_injected_file_metadata.setdefault(fid, meta)
|
||||
|
||||
if all_injected_file_metadata:
|
||||
logger.debug(
|
||||
"FileReader: file metadata for LLM: "
|
||||
f"{[(fid, m.filename) for fid, m in all_injected_file_metadata.items()]}"
|
||||
)
|
||||
|
||||
# Prepend summary message if compression exists
|
||||
if summary_message is not None:
|
||||
@@ -823,11 +626,6 @@ def handle_stream_message_objects(
|
||||
processing_start_time=processing_start_time,
|
||||
)
|
||||
|
||||
# The stream generator can resume on a different worker thread after early yields.
|
||||
# Set this right before launching the LLM loop so run_in_background copies the right context.
|
||||
if new_msg_req.mock_llm_response is not None:
|
||||
mock_response_token = set_llm_mock_response(new_msg_req.mock_llm_response)
|
||||
|
||||
# Run the LLM loop with explicit wrapper for stop signal handling
|
||||
# The wrapper runs run_llm_loop in a background thread and polls every 300ms
|
||||
# for stop signals. run_llm_loop itself doesn't know about stopping.
|
||||
@@ -856,7 +654,6 @@ def handle_stream_message_objects(
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
@@ -870,7 +667,7 @@ def handle_stream_message_objects(
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
project_files=extracted_project_files,
|
||||
persona=persona,
|
||||
user_memory_context=user_memory_context,
|
||||
memories=memories,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
@@ -878,8 +675,6 @@ def handle_stream_message_objects(
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
|
||||
@@ -4,7 +4,6 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.db.persona import get_default_behavior_persona
|
||||
from onyx.db.user_file import calculate_user_files_token_count
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
@@ -13,6 +12,7 @@ from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
|
||||
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
|
||||
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
|
||||
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
|
||||
from onyx.prompts.chat_prompts import USER_INFO_HEADER
|
||||
from onyx.prompts.prompt_utils import get_company_context
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
|
||||
@@ -25,7 +25,6 @@ from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
from onyx.prompts.user_info import USER_INFORMATION_HEADER
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
@@ -53,7 +52,7 @@ def calculate_reserved_tokens(
|
||||
persona_system_prompt: str,
|
||||
token_counter: Callable[[str], int],
|
||||
files: list[FileDescriptor] | None = None,
|
||||
user_memory_context: UserMemoryContext | None = None,
|
||||
memories: list[str] | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Calculate reserved token count for system prompt and user files.
|
||||
@@ -67,7 +66,7 @@ def calculate_reserved_tokens(
|
||||
persona_system_prompt: Custom agent system prompt (can be empty string)
|
||||
token_counter: Function that counts tokens in text
|
||||
files: List of file descriptors from the chat message (optional)
|
||||
user_memory_context: User memory context (optional)
|
||||
memories: List of memory strings (optional)
|
||||
|
||||
Returns:
|
||||
Total reserved token count
|
||||
@@ -78,7 +77,7 @@ def calculate_reserved_tokens(
|
||||
fake_system_prompt = build_system_prompt(
|
||||
base_system_prompt=base_system_prompt,
|
||||
datetime_aware=True,
|
||||
user_memory_context=user_memory_context,
|
||||
memories=memories,
|
||||
tools=None,
|
||||
should_cite_documents=True,
|
||||
include_all_guidance=True,
|
||||
@@ -134,7 +133,7 @@ def build_reminder_message(
|
||||
def build_system_prompt(
|
||||
base_system_prompt: str,
|
||||
datetime_aware: bool = False,
|
||||
user_memory_context: UserMemoryContext | None = None,
|
||||
memories: list[str] | None = None,
|
||||
tools: Sequence[Tool] | None = None,
|
||||
should_cite_documents: bool = False,
|
||||
include_all_guidance: bool = False,
|
||||
@@ -158,15 +157,14 @@ def build_system_prompt(
|
||||
)
|
||||
|
||||
company_context = get_company_context()
|
||||
formatted_user_context = (
|
||||
user_memory_context.as_formatted_prompt() if user_memory_context else ""
|
||||
)
|
||||
if company_context or formatted_user_context:
|
||||
system_prompt += USER_INFORMATION_HEADER
|
||||
if company_context or memories:
|
||||
system_prompt += USER_INFO_HEADER
|
||||
if company_context:
|
||||
system_prompt += company_context
|
||||
if formatted_user_context:
|
||||
system_prompt += formatted_user_context
|
||||
if memories:
|
||||
system_prompt += "\n".join(
|
||||
"- " + memory.strip() for memory in memories if memory.strip()
|
||||
)
|
||||
|
||||
# Append citation guidance after company context if placeholder was not present
|
||||
# This maintains backward compatibility and ensures citations are always enforced when needed
|
||||
|
||||
@@ -50,17 +50,6 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
# Controls whether users can use User Knowledge (personal documents) in assistants
|
||||
DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() == "true"
|
||||
|
||||
# Disables vector DB (Vespa/OpenSearch) entirely. When True, connectors and RAG search
|
||||
# are disabled but core chat, tools, user file uploads, and Projects still work.
|
||||
DISABLE_VECTOR_DB = os.environ.get("DISABLE_VECTOR_DB", "").lower() == "true"
|
||||
|
||||
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
|
||||
# Defaults to 100k tokens (or 10M when vector DB is disabled).
|
||||
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
|
||||
FILE_TOKEN_COUNT_THRESHOLD = int(
|
||||
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
|
||||
)
|
||||
|
||||
# If set to true, will show extra/uncommon connectors in the "Other" category
|
||||
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
|
||||
|
||||
@@ -236,32 +225,11 @@ DOCUMENT_INDEX_NAME = "danswer_index"
|
||||
# OpenSearch Configs
|
||||
OPENSEARCH_HOST = os.environ.get("OPENSEARCH_HOST") or "localhost"
|
||||
OPENSEARCH_REST_API_PORT = int(os.environ.get("OPENSEARCH_REST_API_PORT") or 9200)
|
||||
# TODO(andrei): 60 seconds is too much, we're just setting a high default
|
||||
# timeout for now to examine why queries are slow.
|
||||
# NOTE: This timeout applies to all requests the client makes, including bulk
|
||||
# indexing.
|
||||
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S = int(
|
||||
os.environ.get("DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S") or 60
|
||||
)
|
||||
# TODO(andrei): 50 seconds is too much, we're just setting a high default
|
||||
# timeout for now to examine why queries are slow.
|
||||
# NOTE: To get useful partial results, this value should be less than the client
|
||||
# timeout above.
|
||||
DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S = int(
|
||||
os.environ.get("DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S") or 50
|
||||
)
|
||||
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
|
||||
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
|
||||
USING_AWS_MANAGED_OPENSEARCH = (
|
||||
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
|
||||
)
|
||||
# Profiling adds some overhead to OpenSearch operations. This overhead is
|
||||
# unknown right now. It is enabled by default so we can get useful logs for
|
||||
# investigating slow queries. We may never disable it if the overhead is
|
||||
# minimal.
|
||||
OPENSEARCH_PROFILING_DISABLED = (
|
||||
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
# This is the "base" config for now, the idea is that at least for our dev
|
||||
# environments we always want to be dual indexing into both OpenSearch and Vespa
|
||||
|
||||
@@ -102,6 +102,7 @@ DISCORD_SERVICE_API_KEY_NAME = "discord-bot-service"
|
||||
|
||||
# Key-Value store keys
|
||||
KV_REINDEX_KEY = "needs_reindexing"
|
||||
KV_SEARCH_SETTINGS = "search_settings"
|
||||
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
|
||||
KV_USER_STORE_KEY = "INVITED_USERS"
|
||||
KV_PENDING_USERS_KEY = "PENDING_USERS"
|
||||
@@ -157,9 +158,18 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
# How long a queued user-file task is valid before workers discard it.
|
||||
# Should be longer than the beat interval (20 s) but short enough to prevent
|
||||
# indefinite queue growth. Workers drop tasks older than this without touching
|
||||
# the DB, so a shorter value = faster drain of stale duplicates.
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
# Maximum number of tasks allowed in the user-file-processing queue before the
|
||||
# beat generator stops adding more. Prevents unbounded queue growth when workers
|
||||
# fall behind.
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
|
||||
@@ -227,9 +237,6 @@ class DocumentSource(str, Enum):
|
||||
MOCK_CONNECTOR = "mock_connector"
|
||||
# Special case for user files
|
||||
USER_FILE = "user_file"
|
||||
# Raw files for Craft sandbox access (xlsx, pptx, docx, etc.)
|
||||
# Uses RAW_BINARY processing mode - no text extraction
|
||||
CRAFT_FILE = "craft_file"
|
||||
|
||||
|
||||
class FederatedConnectorSource(str, Enum):
|
||||
@@ -311,7 +318,6 @@ class MessageType(str, Enum):
|
||||
USER = "user" # HumanMessage
|
||||
ASSISTANT = "assistant" # AIMessage - Can include tool_calls field for parallel tool calling
|
||||
TOOL_CALL_RESPONSE = "tool_call_response"
|
||||
USER_REMINDER = "user_reminder" # Custom Onyx message type which is translated into a USER message when passed to the LLM
|
||||
|
||||
|
||||
class ChatMessageSimpleType(str, Enum):
|
||||
@@ -336,7 +342,6 @@ class FileOrigin(str, Enum):
|
||||
CHAT_UPLOAD = "chat_upload"
|
||||
CHAT_IMAGE_GEN = "chat_image_gen"
|
||||
CONNECTOR = "connector"
|
||||
CONNECTOR_METADATA = "connector_metadata"
|
||||
GENERATED_REPORT = "generated_report"
|
||||
INDEXING_CHECKPOINT = "indexing_checkpoint"
|
||||
PLAINTEXT_CACHE = "plaintext_cache"
|
||||
@@ -402,8 +407,6 @@ class OnyxCeleryQueues:
|
||||
# Sandbox processing queue
|
||||
SANDBOX = "sandbox"
|
||||
|
||||
OPENSEARCH_MIGRATION = "opensearch_migration"
|
||||
|
||||
|
||||
class OnyxRedisLocks:
|
||||
PRIMARY_WORKER = "da_lock:primary_worker"
|
||||
@@ -443,6 +446,9 @@ class OnyxRedisLocks:
|
||||
# User file processing
|
||||
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
|
||||
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
|
||||
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
|
||||
# Prevents the beat from re-enqueuing the same file while a task is already queued.
|
||||
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
@@ -455,9 +461,6 @@ class OnyxRedisLocks:
|
||||
CLEANUP_IDLE_SANDBOXES_BEAT_LOCK = "da_lock:cleanup_idle_sandboxes_beat"
|
||||
CLEANUP_OLD_SNAPSHOTS_BEAT_LOCK = "da_lock:cleanup_old_snapshots_beat"
|
||||
|
||||
# Sandbox file sync
|
||||
SANDBOX_FILE_SYNC_LOCK_PREFIX = "da_lock:sandbox_file_sync"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
BLOCK_VALIDATE_INDEXING_FENCES = "signal:block_validate_indexing_fences"
|
||||
@@ -588,9 +591,6 @@ class OnyxCeleryTask:
|
||||
MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK = (
|
||||
"migrate_documents_from_vespa_to_opensearch_task"
|
||||
)
|
||||
MIGRATE_CHUNKS_FROM_VESPA_TO_OPENSEARCH_TASK = (
|
||||
"migrate_chunks_from_vespa_to_opensearch_task"
|
||||
)
|
||||
|
||||
|
||||
# this needs to correspond to the matching entry in supervisord
|
||||
|
||||
@@ -65,9 +65,7 @@ class OnyxDBCredentialsProvider(
|
||||
f"No credential found: credential={self._credential_id}"
|
||||
)
|
||||
|
||||
if credential.credential_json is None:
|
||||
return {}
|
||||
return credential.credential_json.get_value(apply_mask=False)
|
||||
return credential.credential_json
|
||||
|
||||
def set_credentials(self, credential_json: dict[str, Any]) -> None:
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
|
||||
@@ -83,7 +81,7 @@ class OnyxDBCredentialsProvider(
|
||||
f"No credential found: credential={self._credential_id}"
|
||||
)
|
||||
|
||||
credential.credential_json = credential_json # type: ignore[assignment]
|
||||
credential.credential_json = credential_json
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
|
||||
@@ -171,7 +171,6 @@ def process_onyx_metadata(
|
||||
|
||||
return (
|
||||
OnyxMetadata(
|
||||
document_id=metadata.get("id"),
|
||||
source_type=source_type,
|
||||
link=metadata.get("link"),
|
||||
file_display_name=metadata.get("file_display_name"),
|
||||
|
||||
@@ -118,12 +118,7 @@ def instantiate_connector(
|
||||
)
|
||||
connector.set_credentials_provider(provider)
|
||||
else:
|
||||
credential_json = (
|
||||
credential.credential_json.get_value(apply_mask=False)
|
||||
if credential.credential_json
|
||||
else {}
|
||||
)
|
||||
new_credentials = connector.load_credentials(credential_json)
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
if new_credentials is not None:
|
||||
backend_update_credential_json(credential, new_credentials, db_session)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
@@ -108,7 +107,7 @@ def _process_file(
|
||||
# These metadata items are not settable by the user
|
||||
source_type = onyx_metadata.source_type or DocumentSource.FILE
|
||||
|
||||
doc_id = onyx_metadata.document_id or f"FILE_CONNECTOR__{file_id}"
|
||||
doc_id = f"FILE_CONNECTOR__{file_id}"
|
||||
title = metadata.get("title") or file_display_name
|
||||
|
||||
# 1) If the file itself is an image, handle that scenario quickly
|
||||
@@ -241,49 +240,29 @@ class LocalFileConnector(LoadConnector):
|
||||
self,
|
||||
file_locations: list[Path | str],
|
||||
file_names: list[str] | None = None, # noqa: ARG002
|
||||
zip_metadata_file_id: str | None = None,
|
||||
zip_metadata: dict[str, Any] | None = None, # Deprecated, for backwards compat
|
||||
zip_metadata: dict[str, Any] | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.file_locations = [str(loc) for loc in file_locations]
|
||||
self.batch_size = batch_size
|
||||
self.pdf_pass: str | None = None
|
||||
self._zip_metadata_file_id = zip_metadata_file_id
|
||||
self._zip_metadata_deprecated = zip_metadata
|
||||
self.zip_metadata = zip_metadata or {}
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.pdf_pass = credentials.get("pdf_password")
|
||||
|
||||
return None
|
||||
|
||||
def _get_file_metadata(self, file_name: str) -> dict[str, Any]:
|
||||
return self.zip_metadata.get(file_name, {}) or self.zip_metadata.get(
|
||||
os.path.basename(file_name), {}
|
||||
)
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Iterates over each file path, fetches from Postgres, tries to parse text
|
||||
or images, and yields Document batches.
|
||||
"""
|
||||
# Load metadata dict at start (from file store or deprecated inline format)
|
||||
zip_metadata: dict[str, Any] = {}
|
||||
if self._zip_metadata_file_id:
|
||||
try:
|
||||
file_store = get_default_file_store()
|
||||
metadata_io = file_store.read_file(
|
||||
file_id=self._zip_metadata_file_id, mode="b"
|
||||
)
|
||||
metadata_bytes = metadata_io.read()
|
||||
loaded_metadata = json.loads(metadata_bytes)
|
||||
if isinstance(loaded_metadata, list):
|
||||
zip_metadata = {d["filename"]: d for d in loaded_metadata}
|
||||
else:
|
||||
zip_metadata = loaded_metadata
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load metadata from file store: {e}")
|
||||
elif self._zip_metadata_deprecated:
|
||||
logger.warning(
|
||||
"Using deprecated inline zip_metadata dict. "
|
||||
"Re-upload files to use the new file store format."
|
||||
)
|
||||
zip_metadata = self._zip_metadata_deprecated
|
||||
|
||||
documents: list[Document | HierarchyNode] = []
|
||||
|
||||
for file_id in self.file_locations:
|
||||
@@ -294,9 +273,7 @@ class LocalFileConnector(LoadConnector):
|
||||
logger.warning(f"No file record found for '{file_id}' in PG; skipping.")
|
||||
continue
|
||||
|
||||
metadata = zip_metadata.get(
|
||||
file_record.display_name, {}
|
||||
) or zip_metadata.get(os.path.basename(file_record.display_name), {})
|
||||
metadata = self._get_file_metadata(file_record.display_name)
|
||||
file_io = file_store.read_file(file_id=file_id, mode="b")
|
||||
new_docs = _process_file(
|
||||
file_id=file_id,
|
||||
@@ -321,6 +298,7 @@ if __name__ == "__main__":
|
||||
connector = LocalFileConnector(
|
||||
file_locations=[os.environ["TEST_FILE"]],
|
||||
file_names=[os.environ["TEST_FILE"]],
|
||||
zip_metadata={},
|
||||
)
|
||||
connector.load_credentials({"pdf_password": os.environ.get("PDF_PASSWORD")})
|
||||
doc_batches = connector.load_from_state()
|
||||
|
||||
@@ -523,22 +523,6 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
|
||||
sleep_after_rate_limit_exception(github_client)
|
||||
return self.get_all_repos(github_client, attempt_num + 1)
|
||||
|
||||
def fetch_configured_repos(self) -> list[Repository.Repository]:
|
||||
"""
|
||||
Fetch the configured repositories based on the connector settings.
|
||||
|
||||
Returns:
|
||||
list[Repository.Repository]: The configured repositories.
|
||||
"""
|
||||
assert self.github_client is not None # mypy
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
return self.get_github_repos(self.github_client)
|
||||
else:
|
||||
return [self.get_github_repo(self.github_client)]
|
||||
else:
|
||||
return self.get_all_repos(self.github_client)
|
||||
|
||||
def _pull_requests_func(
|
||||
self, repo: Repository.Repository
|
||||
) -> Callable[[], PaginatedList[PullRequest]]:
|
||||
@@ -567,7 +551,17 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
|
||||
|
||||
# First run of the connector, fetch all repos and store in checkpoint
|
||||
if checkpoint.cached_repo_ids is None:
|
||||
repos = self.fetch_configured_repos()
|
||||
repos = []
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = self.get_github_repos(self.github_client)
|
||||
else:
|
||||
# Single repository (backward compatibility)
|
||||
repos = [self.get_github_repo(self.github_client)]
|
||||
else:
|
||||
# All repositories
|
||||
repos = self.get_all_repos(self.github_client)
|
||||
if not repos:
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
|
||||
@@ -32,6 +32,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
BASE_URL = "https://api.gong.io"
|
||||
MAX_CALL_DETAILS_ATTEMPTS = 6
|
||||
CALL_DETAILS_DELAY = 30 # in seconds
|
||||
# Gong API limit is 3 calls/sec — stay safely under it
|
||||
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -45,9 +47,13 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
self.continue_on_fail = continue_on_fail
|
||||
self.auth_token_basic: str | None = None
|
||||
self.hide_user_info = hide_user_info
|
||||
self._last_request_time: float = 0.0
|
||||
|
||||
# urllib3 Retry already respects the Retry-After header by default
|
||||
# (respect_retry_after_header=True), so on 429 it will sleep for the
|
||||
# duration Gong specifies before retrying.
|
||||
retry_strategy = Retry(
|
||||
total=5,
|
||||
total=10,
|
||||
backoff_factor=2,
|
||||
status_forcelist=[429, 500, 502, 503, 504],
|
||||
)
|
||||
@@ -61,8 +67,24 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
url = f"{GongConnector.BASE_URL}{endpoint}"
|
||||
return url
|
||||
|
||||
def _throttled_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
|
||||
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_request_time
|
||||
if elapsed < self.MIN_REQUEST_INTERVAL:
|
||||
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
|
||||
|
||||
response = self._session.request(method, url, **kwargs)
|
||||
self._last_request_time = time.monotonic()
|
||||
return response
|
||||
|
||||
def _get_workspace_id_map(self) -> dict[str, str]:
|
||||
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
|
||||
response = self._throttled_request(
|
||||
"GET", GongConnector.make_url("/v2/workspaces")
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
workspaces_details = response.json().get("workspaces")
|
||||
@@ -106,8 +128,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
del body["filter"]["workspaceId"]
|
||||
|
||||
while True:
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
)
|
||||
# If no calls in the range, just break out
|
||||
if response.status_code == 404:
|
||||
@@ -142,8 +164,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
"contentSelector": {"exposedFields": {"parties": True}},
|
||||
}
|
||||
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -194,7 +216,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# There's a likely race condition in the API where a transcript will have a
|
||||
# call id but the call to v2/calls/extensive will not return all of the id's
|
||||
# retry with exponential backoff has been observed to mitigate this
|
||||
# in ~2 minutes
|
||||
# in ~2 minutes. After max attempts, proceed with whatever we have —
|
||||
# the per-call loop below will skip missing IDs gracefully.
|
||||
current_attempt = 0
|
||||
while True:
|
||||
current_attempt += 1
|
||||
@@ -213,11 +236,14 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
f"missing_call_ids={missing_call_ids}"
|
||||
)
|
||||
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
|
||||
raise RuntimeError(
|
||||
f"Attempt count exceeded for _get_call_details_by_ids: "
|
||||
f"missing_call_ids={missing_call_ids} "
|
||||
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
|
||||
logger.error(
|
||||
f"Giving up on missing call id's after "
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={missing_call_ids} — "
|
||||
f"proceeding with {len(call_details_map)} of "
|
||||
f"{len(transcript_call_ids)} calls"
|
||||
)
|
||||
break
|
||||
|
||||
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
|
||||
logger.warning(
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
@@ -167,6 +168,14 @@ class DocumentBase(BaseModel):
|
||||
# list of strings.
|
||||
metadata: dict[str, str | list[str]]
|
||||
|
||||
@field_validator("metadata", mode="before")
|
||||
@classmethod
|
||||
def _coerce_metadata_values(cls, v: dict[str, Any]) -> dict[str, str | list[str]]:
|
||||
return {
|
||||
key: [str(item) for item in val] if isinstance(val, list) else str(val)
|
||||
for key, val in v.items()
|
||||
}
|
||||
|
||||
# UTC time
|
||||
doc_updated_at: datetime | None = None
|
||||
chunk_count: int | None = None
|
||||
@@ -474,9 +483,8 @@ class ConnectorStopSignal(Exception):
|
||||
|
||||
|
||||
class OnyxMetadata(BaseModel):
|
||||
# Careful overriding the document_id, may cause visual issues in the UI.
|
||||
# Kept here for API based use cases mostly
|
||||
document_id: str | None = None
|
||||
# Note that doc_id cannot be overriden here as it may cause issues
|
||||
# with the display functionalities in the UI. Ask @chris if clarification is needed.
|
||||
source_type: DocumentSource | None = None
|
||||
link: str | None = None
|
||||
file_display_name: str | None = None
|
||||
|
||||
@@ -104,48 +104,21 @@ class CertificateData(BaseModel):
|
||||
thumbprint: str
|
||||
|
||||
|
||||
# TODO(Evan): Remove this once we have a proper token refresh mechanism.
|
||||
def _clear_cached_token(query_obj: ClientQuery) -> bool:
|
||||
"""Clear the cached access token on the query object's ClientContext so
|
||||
the next request re-invokes the token callback and gets a fresh token.
|
||||
|
||||
The office365 library's AuthenticationContext.with_access_token() caches
|
||||
the token in ``_cached_token`` and never refreshes it. Setting it to
|
||||
``None`` forces re-acquisition on the next request.
|
||||
|
||||
Returns True if the token was successfully cleared."""
|
||||
ctx = getattr(query_obj, "context", query_obj)
|
||||
auth_ctx = getattr(ctx, "authentication_context", None)
|
||||
if auth_ctx is not None and hasattr(auth_ctx, "_cached_token"):
|
||||
auth_ctx._cached_token = None
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def sleep_and_retry(
|
||||
query_obj: ClientQuery, method_name: str, max_retries: int = 3
|
||||
) -> Any:
|
||||
"""
|
||||
Execute a SharePoint query with retry logic for rate limiting
|
||||
and automatic token refresh on 401 Unauthorized.
|
||||
Execute a SharePoint query with retry logic for rate limiting.
|
||||
"""
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return query_obj.execute_query()
|
||||
except ClientRequestException as e:
|
||||
status = e.response.status_code if e.response is not None else None
|
||||
|
||||
# 401 — token expired. Clear the cached token and retry immediately.
|
||||
if status == 401 and attempt < max_retries:
|
||||
cleared = _clear_cached_token(query_obj)
|
||||
logger.warning(
|
||||
f"Token expired on {method_name}, attempt {attempt + 1}/{max_retries + 1}, "
|
||||
f"cleared cached token={cleared}, retrying"
|
||||
)
|
||||
continue
|
||||
|
||||
# 429 / 503 — rate limit or transient error. Back off and retry.
|
||||
if status in (429, 503) and attempt < max_retries:
|
||||
if (
|
||||
e.response is not None
|
||||
and e.response.status_code in [429, 503]
|
||||
and attempt < max_retries
|
||||
):
|
||||
logger.warning(
|
||||
f"Rate limit exceeded on {method_name}, attempt {attempt + 1}/{max_retries + 1}, sleeping and retrying"
|
||||
)
|
||||
@@ -158,15 +131,13 @@ def sleep_and_retry(
|
||||
|
||||
logger.info(f"Sleeping for {sleep_time} seconds before retry")
|
||||
time.sleep(sleep_time)
|
||||
continue
|
||||
|
||||
# Non-retryable error or retries exhausted — log details and raise.
|
||||
if e.response is not None:
|
||||
logger.error(
|
||||
f"SharePoint request failed for {method_name}: "
|
||||
f"status={status}, "
|
||||
)
|
||||
raise e
|
||||
else:
|
||||
# Either not a rate limit error, or we've exhausted retries
|
||||
if e.response is not None and e.response.status_code == 429:
|
||||
logger.error(
|
||||
f"Rate limit retry exhausted for {method_name} after {max_retries} attempts"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
class SharepointConnectorCheckpoint(ConnectorCheckpoint):
|
||||
|
||||
@@ -308,18 +308,6 @@ def default_msg_filter(message: MessageType) -> SlackMessageFilterReason | None:
|
||||
return None
|
||||
|
||||
|
||||
def _bot_inclusive_msg_filter(
|
||||
message: MessageType,
|
||||
) -> SlackMessageFilterReason | None:
|
||||
"""Like default_msg_filter but allows bot/app messages through.
|
||||
Only filters out disallowed subtypes (channel_join, channel_leave, etc.).
|
||||
"""
|
||||
if message.get("subtype", "") in _DISALLOWED_MSG_SUBTYPES:
|
||||
return SlackMessageFilterReason.DISALLOWED
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def filter_channels(
|
||||
all_channels: list[ChannelType],
|
||||
channels_to_connect: list[str] | None,
|
||||
@@ -666,18 +654,12 @@ class SlackConnector(
|
||||
# if specified, will treat the specified channel strings as
|
||||
# regexes, and will only index channels that fully match the regexes
|
||||
channel_regex_enabled: bool = False,
|
||||
# if True, messages from bots/apps will be indexed instead of filtered out
|
||||
include_bot_messages: bool = False,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
num_threads: int = SLACK_NUM_THREADS,
|
||||
use_redis: bool = True,
|
||||
) -> None:
|
||||
self.channels = channels
|
||||
self.channel_regex_enabled = channel_regex_enabled
|
||||
self.include_bot_messages = include_bot_messages
|
||||
self.msg_filter_func = (
|
||||
_bot_inclusive_msg_filter if include_bot_messages else default_msg_filter
|
||||
)
|
||||
self.batch_size = batch_size
|
||||
self.num_threads = num_threads
|
||||
self.client: WebClient | None = None
|
||||
@@ -857,7 +839,6 @@ class SlackConnector(
|
||||
client=self.client,
|
||||
channels=self.channels,
|
||||
channel_name_regex_enabled=self.channel_regex_enabled,
|
||||
msg_filter_func=self.msg_filter_func,
|
||||
callback=callback,
|
||||
workspace_url=self._workspace_url,
|
||||
)
|
||||
@@ -945,7 +926,6 @@ class SlackConnector(
|
||||
|
||||
try:
|
||||
num_bot_filtered_messages = 0
|
||||
num_other_filtered_messages = 0
|
||||
|
||||
oldest = str(start) if start else None
|
||||
latest = str(end)
|
||||
@@ -1004,7 +984,6 @@ class SlackConnector(
|
||||
user_cache=self.user_cache,
|
||||
seen_thread_ts=seen_thread_ts,
|
||||
channel_access=checkpoint.current_channel_access,
|
||||
msg_filter_func=self.msg_filter_func,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1024,13 +1003,7 @@ class SlackConnector(
|
||||
|
||||
seen_thread_ts.add(thread_or_message_ts)
|
||||
elif processed_slack_message.filter_reason:
|
||||
if (
|
||||
processed_slack_message.filter_reason
|
||||
== SlackMessageFilterReason.BOT
|
||||
):
|
||||
num_bot_filtered_messages += 1
|
||||
else:
|
||||
num_other_filtered_messages += 1
|
||||
num_bot_filtered_messages += 1
|
||||
elif failure:
|
||||
yield failure
|
||||
|
||||
@@ -1050,14 +1023,10 @@ class SlackConnector(
|
||||
range_total = 1
|
||||
range_percent_complete = range_complete / range_total * 100.0
|
||||
|
||||
num_filtered = num_bot_filtered_messages + num_other_filtered_messages
|
||||
log_func = logger.warning if num_bot_filtered_messages > 0 else logger.info
|
||||
log_func(
|
||||
logger.info(
|
||||
f"Message processing stats: "
|
||||
f"batch_len={len(message_batch)} "
|
||||
f"batch_yielded={num_threads_processed} "
|
||||
f"filtered={num_filtered} "
|
||||
f"(bot={num_bot_filtered_messages} other={num_other_filtered_messages}) "
|
||||
f"total_threads_seen={len(seen_thread_ts)}"
|
||||
)
|
||||
|
||||
@@ -1071,8 +1040,7 @@ class SlackConnector(
|
||||
checkpoint.seen_thread_ts = list(seen_thread_ts)
|
||||
checkpoint.channel_completion_map[channel["id"]] = new_oldest
|
||||
|
||||
# bypass channels where the first set of messages seen are all
|
||||
# filtered (bots + disallowed subtypes like channel_join)
|
||||
# bypass channels where the first set of messages seen are all bots
|
||||
# check at least MIN_BOT_MESSAGE_THRESHOLD messages are in the batch
|
||||
# we shouldn't skip based on a small sampling of messages
|
||||
if (
|
||||
@@ -1080,7 +1048,7 @@ class SlackConnector(
|
||||
and len(message_batch) > SlackConnector.BOT_CHANNEL_MIN_BATCH_SIZE
|
||||
):
|
||||
if (
|
||||
num_filtered
|
||||
num_bot_filtered_messages
|
||||
> SlackConnector.BOT_CHANNEL_PERCENTAGE_THRESHOLD
|
||||
* len(message_batch)
|
||||
):
|
||||
|
||||
@@ -20,7 +20,7 @@ from onyx.onyxbot.slack.models import ChannelType
|
||||
from onyx.prompts.federated_search import SLACK_DATE_EXTRACTION_PROMPT
|
||||
from onyx.prompts.federated_search import SLACK_QUERY_EXPANSION_PROMPT
|
||||
from onyx.tracing.llm_utils import llm_generation_span
|
||||
from onyx.tracing.llm_utils import record_llm_response
|
||||
from onyx.tracing.llm_utils import record_llm_span_output
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -201,8 +201,8 @@ def extract_date_range_from_query(
|
||||
llm=llm, flow="slack_date_extraction", input_messages=[prompt_msg]
|
||||
) as span_generation:
|
||||
llm_response = llm.invoke(prompt_msg)
|
||||
record_llm_response(span_generation, llm_response)
|
||||
response = llm_response_to_string(llm_response)
|
||||
record_llm_span_output(span_generation, response, llm_response.usage)
|
||||
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
@@ -606,8 +606,8 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
llm=llm, flow="slack_query_expansion", input_messages=[prompt]
|
||||
) as span_generation:
|
||||
llm_response = llm.invoke(prompt)
|
||||
record_llm_response(span_generation, llm_response)
|
||||
response = llm_response_to_string(llm_response)
|
||||
record_llm_span_output(span_generation, response, llm_response.usage)
|
||||
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
|
||||
32
backend/onyx/context/search/search_settings.py
Normal file
32
backend/onyx/context/search/search_settings.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Note, this file and all SavedSearchSettings things are not being used in live code paths (at least at the time of this comment)
|
||||
# Kept around as it may be useful in the future
|
||||
from typing import cast
|
||||
|
||||
from onyx.configs.constants import KV_SEARCH_SETTINGS
|
||||
from onyx.context.search.models import SavedSearchSettings
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_kv_search_settings() -> SavedSearchSettings | None:
|
||||
"""Get all user configured search settings which affect the search pipeline
|
||||
Note: KV store is used in this case since there is no need to rollback the value or any need to audit past values
|
||||
|
||||
Note: for now we can't cache this value because if the API server is scaled, the cache could be out of sync
|
||||
if the value is updated by another process/instance of the API server. If this reads from an in memory cache like
|
||||
reddis then it will be ok. Until then this has some performance implications (though minor)
|
||||
"""
|
||||
kv_store = get_kv_store()
|
||||
try:
|
||||
return SavedSearchSettings(**cast(dict, kv_store.load(KV_SEARCH_SETTINGS)))
|
||||
except KvKeyNotFoundError:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading search settings: {e}")
|
||||
# Wiping it so that next server startup, it can load the defaults
|
||||
# or the user can set it via the API/UI
|
||||
kv_store.delete(KV_SEARCH_SETTINGS)
|
||||
return None
|
||||
@@ -270,8 +270,6 @@ def create_credential(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
|
||||
db_session.expire(credential)
|
||||
return credential
|
||||
|
||||
|
||||
@@ -299,21 +297,14 @@ def alter_credential(
|
||||
|
||||
credential.name = name
|
||||
|
||||
# Get existing credential_json and merge with new values
|
||||
existing_json = (
|
||||
credential.credential_json.get_value(apply_mask=False)
|
||||
if credential.credential_json
|
||||
else {}
|
||||
)
|
||||
credential.credential_json = { # type: ignore[assignment]
|
||||
**existing_json,
|
||||
# Assign a new dictionary to credential.credential_json
|
||||
credential.credential_json = {
|
||||
**credential.credential_json,
|
||||
**credential_json,
|
||||
}
|
||||
|
||||
credential.user_id = user.id
|
||||
db_session.commit()
|
||||
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
|
||||
db_session.expire(credential)
|
||||
return credential
|
||||
|
||||
|
||||
@@ -327,12 +318,10 @@ def update_credential(
|
||||
if credential is None:
|
||||
return None
|
||||
|
||||
credential.credential_json = credential_data.credential_json # type: ignore[assignment]
|
||||
credential.user_id = user.id if user is not None else None
|
||||
credential.credential_json = credential_data.credential_json
|
||||
credential.user_id = user.id
|
||||
|
||||
db_session.commit()
|
||||
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
|
||||
db_session.expire(credential)
|
||||
return credential
|
||||
|
||||
|
||||
@@ -346,10 +335,8 @@ def update_credential_json(
|
||||
if credential is None:
|
||||
return None
|
||||
|
||||
credential.credential_json = credential_json # type: ignore[assignment]
|
||||
credential.credential_json = credential_json
|
||||
db_session.commit()
|
||||
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
|
||||
db_session.expire(credential)
|
||||
return credential
|
||||
|
||||
|
||||
@@ -359,7 +346,7 @@ def backend_update_credential_json(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""This should not be used in any flows involving the frontend or users"""
|
||||
credential.credential_json = credential_json # type: ignore[assignment]
|
||||
credential.credential_json = credential_json
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -454,12 +441,7 @@ def create_initial_public_credential(db_session: Session) -> None:
|
||||
)
|
||||
|
||||
if first_credential is not None:
|
||||
credential_json_value = (
|
||||
first_credential.credential_json.get_value(apply_mask=False)
|
||||
if first_credential.credential_json
|
||||
else {}
|
||||
)
|
||||
if credential_json_value != {} or first_credential.user is not None:
|
||||
if first_credential.credential_json != {} or first_credential.user is not None:
|
||||
raise ValueError(error_msg)
|
||||
return
|
||||
|
||||
@@ -495,13 +477,8 @@ def delete_service_account_credentials(
|
||||
) -> None:
|
||||
credentials = fetch_credentials_for_user(db_session=db_session, user=user)
|
||||
for credential in credentials:
|
||||
credential_json = (
|
||||
credential.credential_json.get_value(apply_mask=False)
|
||||
if credential.credential_json
|
||||
else {}
|
||||
)
|
||||
if (
|
||||
credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)
|
||||
credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)
|
||||
and credential.source == source
|
||||
):
|
||||
db_session.delete(credential)
|
||||
|
||||
@@ -6,8 +6,6 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
@@ -228,50 +226,6 @@ def get_documents_by_ids(
|
||||
return list(documents)
|
||||
|
||||
|
||||
def get_documents_by_source(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
creator_id: UUID | None = None,
|
||||
) -> list[DbDocument]:
|
||||
"""Get all documents associated with a specific source type.
|
||||
|
||||
This queries through the connector relationship to find all documents
|
||||
that were indexed by connectors of the given source type.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
source: The document source type to filter by
|
||||
creator_id: If provided, only return documents from connectors
|
||||
created by this user. Filters via ConnectorCredentialPair.
|
||||
"""
|
||||
stmt = (
|
||||
select(DbDocument)
|
||||
.join(
|
||||
DocumentByConnectorCredentialPair,
|
||||
DbDocument.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id
|
||||
== ConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id
|
||||
== ConnectorCredentialPair.credential_id,
|
||||
),
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
ConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.where(Connector.source == source)
|
||||
)
|
||||
if creator_id is not None:
|
||||
stmt = stmt.where(ConnectorCredentialPair.creator_id == creator_id)
|
||||
stmt = stmt.distinct()
|
||||
documents = db_session.execute(stmt).scalars().all()
|
||||
return list(documents)
|
||||
|
||||
|
||||
def _apply_last_updated_cursor_filter(
|
||||
stmt: Select,
|
||||
cursor_last_modified: datetime | None,
|
||||
@@ -1573,40 +1527,3 @@ def get_document_kg_entities_and_relationships(
|
||||
def get_num_chunks_for_document(db_session: Session, document_id: str) -> int:
|
||||
stmt = select(DbDocument.chunk_count).where(DbDocument.id == document_id)
|
||||
return db_session.execute(stmt).scalar_one_or_none() or 0
|
||||
|
||||
|
||||
def update_document_metadata__no_commit(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
doc_metadata: dict[str, Any],
|
||||
) -> None:
|
||||
"""Update the doc_metadata field for a document.
|
||||
|
||||
Note: Does not commit. Caller is responsible for committing.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
document_id: The ID of the document to update
|
||||
doc_metadata: The new metadata dictionary to set
|
||||
"""
|
||||
stmt = (
|
||||
update(DbDocument)
|
||||
.where(DbDocument.id == document_id)
|
||||
.values(doc_metadata=doc_metadata)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
|
||||
|
||||
def delete_document_by_id__no_commit(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
) -> None:
|
||||
"""Delete a single document and its connector credential pair relationships.
|
||||
|
||||
Note: Does not commit. Caller is responsible for committing.
|
||||
|
||||
This uses delete_documents_complete__no_commit which handles
|
||||
all foreign key relationships (KG entities, relationships, chunk stats,
|
||||
cc pair associations, feedback, tags).
|
||||
"""
|
||||
delete_documents_complete__no_commit(db_session, [document_id])
|
||||
|
||||
@@ -60,8 +60,7 @@ class ProcessingMode(str, PyEnum):
|
||||
"""Determines how documents are processed after fetching."""
|
||||
|
||||
REGULAR = "REGULAR" # Full pipeline: chunk → embed → Vespa
|
||||
FILE_SYSTEM = "FILE_SYSTEM" # Write to file system only (JSON documents)
|
||||
RAW_BINARY = "RAW_BINARY" # Write raw binary to S3 (no text extraction)
|
||||
FILE_SYSTEM = "FILE_SYSTEM" # Write to file system only
|
||||
|
||||
|
||||
class SyncType(str, PyEnum):
|
||||
@@ -198,12 +197,6 @@ class ThemePreference(str, PyEnum):
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class DefaultAppMode(str, PyEnum):
|
||||
AUTO = "AUTO"
|
||||
CHAT = "CHAT"
|
||||
SEARCH = "SEARCH"
|
||||
|
||||
|
||||
class SwitchoverType(str, PyEnum):
|
||||
REINDEX = "reindex"
|
||||
ACTIVE_ONLY = "active_only"
|
||||
|
||||
@@ -111,7 +111,7 @@ def update_federated_connector_oauth_token(
|
||||
|
||||
if existing_token:
|
||||
# Update existing token
|
||||
existing_token.token = token # type: ignore[assignment]
|
||||
existing_token.token = token
|
||||
existing_token.expires_at = expires_at
|
||||
db_session.commit()
|
||||
return existing_token
|
||||
@@ -267,13 +267,7 @@ def update_federated_connector(
|
||||
# Use provided credentials if updating them, otherwise use existing credentials
|
||||
# This is needed to instantiate the connector for config validation when only config is being updated
|
||||
creds_to_use = (
|
||||
credentials
|
||||
if credentials is not None
|
||||
else (
|
||||
federated_connector.credentials.get_value(apply_mask=False)
|
||||
if federated_connector.credentials
|
||||
else {}
|
||||
)
|
||||
credentials if credentials is not None else federated_connector.credentials
|
||||
)
|
||||
|
||||
if credentials is not None:
|
||||
@@ -284,7 +278,7 @@ def update_federated_connector(
|
||||
raise ValueError(
|
||||
f"Invalid credentials for federated connector source: {federated_connector.source}"
|
||||
)
|
||||
federated_connector.credentials = credentials # type: ignore[assignment]
|
||||
federated_connector.credentials = credentials
|
||||
|
||||
if config is not None:
|
||||
# Validate config using connector-specific validation
|
||||
|
||||
@@ -231,11 +231,9 @@ def upsert_llm_provider(
|
||||
# Set to None if the dict is empty after filtering
|
||||
custom_config = custom_config or None
|
||||
|
||||
api_base = llm_provider_upsert_request.api_base or None
|
||||
existing_llm_provider.provider = llm_provider_upsert_request.provider
|
||||
# EncryptedString accepts str for writes, returns SensitiveValue for reads
|
||||
existing_llm_provider.api_key = llm_provider_upsert_request.api_key # type: ignore[assignment]
|
||||
existing_llm_provider.api_base = api_base
|
||||
existing_llm_provider.api_key = llm_provider_upsert_request.api_key
|
||||
existing_llm_provider.api_base = llm_provider_upsert_request.api_base
|
||||
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
|
||||
existing_llm_provider.custom_config = custom_config
|
||||
# TODO: Remove default model name on api change
|
||||
@@ -430,7 +428,7 @@ def fetch_existing_models(
|
||||
|
||||
def fetch_existing_llm_providers(
|
||||
db_session: Session,
|
||||
flow_types: list[LLMModelFlowType],
|
||||
flow_type_filter: list[LLMModelFlowType],
|
||||
only_public: bool = False,
|
||||
exclude_image_generation_providers: bool = True,
|
||||
) -> list[LLMProviderModel]:
|
||||
@@ -438,30 +436,27 @@ def fetch_existing_llm_providers(
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
flow_types: List of flow types to filter by
|
||||
flow_type_filter: List of flow types to filter by, empty list for no filter
|
||||
only_public: If True, only return public providers
|
||||
exclude_image_generation_providers: If True, exclude providers that are
|
||||
used for image generation configs
|
||||
"""
|
||||
providers_with_flows = (
|
||||
select(ModelConfiguration.llm_provider_id)
|
||||
.join(LLMModelFlow)
|
||||
.where(LLMModelFlow.llm_model_flow_type.in_(flow_types))
|
||||
.distinct()
|
||||
)
|
||||
stmt = select(LLMProviderModel)
|
||||
|
||||
if flow_type_filter:
|
||||
providers_with_flows = (
|
||||
select(ModelConfiguration.llm_provider_id)
|
||||
.join(LLMModelFlow)
|
||||
.where(LLMModelFlow.llm_model_flow_type.in_(flow_type_filter))
|
||||
.distinct()
|
||||
)
|
||||
stmt = stmt.where(LLMProviderModel.id.in_(providers_with_flows))
|
||||
|
||||
if exclude_image_generation_providers:
|
||||
stmt = select(LLMProviderModel).where(
|
||||
LLMProviderModel.id.in_(providers_with_flows)
|
||||
)
|
||||
else:
|
||||
image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
|
||||
ImageGenerationConfig
|
||||
)
|
||||
stmt = select(LLMProviderModel).where(
|
||||
LLMProviderModel.id.in_(providers_with_flows)
|
||||
| LLMProviderModel.id.in_(image_gen_provider_ids)
|
||||
)
|
||||
stmt = stmt.where(~LLMProviderModel.id.in_(image_gen_provider_ids))
|
||||
|
||||
stmt = stmt.options(
|
||||
selectinload(LLMProviderModel.model_configurations),
|
||||
@@ -724,13 +719,15 @@ def sync_auto_mode_models(
|
||||
changes += 1
|
||||
else:
|
||||
# Add new model - all models from GitHub config are visible
|
||||
new_model = ModelConfiguration(
|
||||
insert_new_model_configuration__no_commit(
|
||||
db_session=db_session,
|
||||
llm_provider_id=provider.id,
|
||||
name=model_config.name,
|
||||
display_name=model_config.display_name,
|
||||
model_name=model_config.name,
|
||||
supported_flows=[LLMModelFlowType.CHAT],
|
||||
is_visible=True,
|
||||
max_input_tokens=None,
|
||||
display_name=model_config.display_name,
|
||||
)
|
||||
db_session.add(new_model)
|
||||
changes += 1
|
||||
|
||||
# In Auto mode, default model is always set from GitHub config
|
||||
|
||||
@@ -19,7 +19,6 @@ from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.server.features.mcp.models import MCPConnectionData
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -205,21 +204,6 @@ def remove_user_from_mcp_server(
|
||||
|
||||
|
||||
# MCPConnectionConfig operations
|
||||
def extract_connection_data(
|
||||
config: MCPConnectionConfig | None, apply_mask: bool = False
|
||||
) -> MCPConnectionData:
|
||||
"""Extract MCPConnectionData from a connection config, with proper typing.
|
||||
|
||||
This helper encapsulates the cast from the JSON column's dict[str, Any]
|
||||
to the typed MCPConnectionData structure.
|
||||
"""
|
||||
if config is None or config.config is None:
|
||||
return MCPConnectionData(headers={})
|
||||
if isinstance(config.config, SensitiveValue):
|
||||
return cast(MCPConnectionData, config.config.get_value(apply_mask=apply_mask))
|
||||
return cast(MCPConnectionData, config.config)
|
||||
|
||||
|
||||
def get_connection_config_by_id(
|
||||
config_id: int, db_session: Session
|
||||
) -> MCPConnectionConfig:
|
||||
@@ -285,7 +269,7 @@ def update_connection_config(
|
||||
config = get_connection_config_by_id(config_id, db_session)
|
||||
|
||||
if config_data is not None:
|
||||
config.config = config_data # type: ignore[assignment]
|
||||
config.config = config_data
|
||||
# Force SQLAlchemy to detect the change by marking the field as modified
|
||||
flag_modified(config, "config")
|
||||
|
||||
@@ -303,7 +287,7 @@ def upsert_user_connection_config(
|
||||
existing_config = get_user_connection_config(server_id, user_email, db_session)
|
||||
|
||||
if existing_config:
|
||||
existing_config.config = config_data # type: ignore[assignment]
|
||||
existing_config.config = config_data
|
||||
db_session.flush() # Don't commit yet, let caller decide when to commit
|
||||
return existing_config
|
||||
else:
|
||||
|
||||
@@ -1,163 +1,22 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
from onyx.prompts.user_info import BASIC_INFORMATION_PROMPT
|
||||
from onyx.prompts.user_info import USER_MEMORIES_PROMPT
|
||||
from onyx.prompts.user_info import USER_PREFERENCES_PROMPT
|
||||
from onyx.prompts.user_info import USER_ROLE_PROMPT
|
||||
|
||||
MAX_MEMORIES_PER_USER = 10
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
name: str | None = None
|
||||
role: str | None = None
|
||||
email: str | None = None
|
||||
def get_memories(user: User, db_session: Session) -> list[str]:
|
||||
if not user.use_memories:
|
||||
return []
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"name": self.name,
|
||||
"role": self.role,
|
||||
"email": self.email,
|
||||
}
|
||||
|
||||
|
||||
class UserMemoryContext(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
user_id: UUID | None = None
|
||||
user_info: UserInfo
|
||||
user_preferences: str | None = None
|
||||
memories: tuple[str, ...] = ()
|
||||
|
||||
def as_formatted_list(self) -> list[str]:
|
||||
"""Returns combined list of user info, preferences, and memories."""
|
||||
result = []
|
||||
if self.user_info.name:
|
||||
result.append(f"User's name: {self.user_info.name}")
|
||||
if self.user_info.role:
|
||||
result.append(f"User's role: {self.user_info.role}")
|
||||
if self.user_info.email:
|
||||
result.append(f"User's email: {self.user_info.email}")
|
||||
if self.user_preferences:
|
||||
result.append(f"User preferences: {self.user_preferences}")
|
||||
result.extend(self.memories)
|
||||
return result
|
||||
|
||||
def as_formatted_prompt(self) -> str:
|
||||
"""Returns structured prompt sections for the system prompt."""
|
||||
has_basic_info = (
|
||||
self.user_info.name or self.user_info.email or self.user_info.role
|
||||
)
|
||||
if not has_basic_info and not self.user_preferences and not self.memories:
|
||||
return ""
|
||||
|
||||
sections: list[str] = []
|
||||
|
||||
if has_basic_info:
|
||||
role_line = (
|
||||
USER_ROLE_PROMPT.format(user_role=self.user_info.role).strip()
|
||||
if self.user_info.role
|
||||
else ""
|
||||
)
|
||||
if role_line:
|
||||
role_line = "\n" + role_line
|
||||
sections.append(
|
||||
BASIC_INFORMATION_PROMPT.format(
|
||||
user_name=self.user_info.name or "",
|
||||
user_email=self.user_info.email or "",
|
||||
user_role=role_line,
|
||||
)
|
||||
)
|
||||
|
||||
if self.user_preferences:
|
||||
sections.append(
|
||||
USER_PREFERENCES_PROMPT.format(user_preferences=self.user_preferences)
|
||||
)
|
||||
|
||||
if self.memories:
|
||||
formatted_memories = "\n".join(f"- {memory}" for memory in self.memories)
|
||||
sections.append(
|
||||
USER_MEMORIES_PROMPT.format(user_memories=formatted_memories)
|
||||
)
|
||||
|
||||
return "".join(sections)
|
||||
|
||||
|
||||
def get_memories(user: User, db_session: Session) -> UserMemoryContext:
|
||||
user_info = UserInfo(
|
||||
name=user.personal_name,
|
||||
role=user.personal_role,
|
||||
email=user.email,
|
||||
)
|
||||
|
||||
user_preferences = None
|
||||
if user.user_preferences:
|
||||
user_preferences = user.user_preferences
|
||||
user_info = [
|
||||
f"User's name: {user.personal_name}" if user.personal_name else "",
|
||||
f"User's role: {user.personal_role}" if user.personal_role else "",
|
||||
f"User's email: {user.email}" if user.email else "",
|
||||
]
|
||||
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user.id).order_by(Memory.id.asc())
|
||||
select(Memory).where(Memory.user_id == user.id)
|
||||
).all()
|
||||
memories = tuple(memory.memory_text for memory in memory_rows if memory.memory_text)
|
||||
|
||||
return UserMemoryContext(
|
||||
user_id=user.id,
|
||||
user_info=user_info,
|
||||
user_preferences=user_preferences,
|
||||
memories=memories,
|
||||
)
|
||||
|
||||
|
||||
def add_memory(
|
||||
user_id: UUID,
|
||||
memory_text: str,
|
||||
db_session: Session,
|
||||
) -> Memory:
|
||||
"""Insert a new Memory row for the given user.
|
||||
|
||||
If the user already has MAX_MEMORIES_PER_USER memories, the oldest
|
||||
one (lowest id) is deleted before inserting the new one.
|
||||
"""
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory
|
||||
|
||||
|
||||
def update_memory_at_index(
|
||||
user_id: UUID,
|
||||
index: int,
|
||||
new_text: str,
|
||||
db_session: Session,
|
||||
) -> Memory | None:
|
||||
"""Update the memory at the given 0-based index (ordered by id ASC, matching get_memories()).
|
||||
|
||||
Returns the updated Memory row, or None if the index is out of range.
|
||||
"""
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory
|
||||
memories = [memory.memory_text for memory in memory_rows if memory.memory_text]
|
||||
return user_info + memories
|
||||
|
||||
@@ -75,7 +75,6 @@ from onyx.db.enums import (
|
||||
MCPServerStatus,
|
||||
LLMModelFlowType,
|
||||
ThemePreference,
|
||||
DefaultAppMode,
|
||||
SwitchoverType,
|
||||
)
|
||||
from onyx.configs.constants import NotificationType
|
||||
@@ -96,10 +95,10 @@ from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.kg.models import KGStage
|
||||
from onyx.server.features.mcp.models import MCPConnectionData
|
||||
from onyx.tools.tool_implementations.web_search.models import WebContentProviderConfig
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string
|
||||
from onyx.utils.encryption import encrypt_string_to_bytes
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
@@ -123,35 +122,18 @@ class EncryptedString(TypeDecorator):
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(
|
||||
self, value: str | SensitiveValue[str] | None, dialect: Dialect # noqa: ARG002
|
||||
self, value: str | None, dialect: Dialect # noqa: ARG002
|
||||
) -> bytes | None:
|
||||
if value is not None:
|
||||
# Handle both raw strings and SensitiveValue wrappers
|
||||
if isinstance(value, SensitiveValue):
|
||||
# Get raw value for storage
|
||||
value = value.get_value(apply_mask=False)
|
||||
return encrypt_string_to_bytes(value)
|
||||
return value
|
||||
|
||||
def process_result_value(
|
||||
self, value: bytes | None, dialect: Dialect # noqa: ARG002
|
||||
) -> SensitiveValue[str] | None:
|
||||
) -> str | None:
|
||||
if value is not None:
|
||||
return SensitiveValue(
|
||||
encrypted_bytes=value,
|
||||
decrypt_fn=decrypt_bytes_to_string,
|
||||
is_json=False,
|
||||
)
|
||||
return None
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
return decrypt_bytes_to_string(value)
|
||||
return value
|
||||
|
||||
|
||||
class EncryptedJson(TypeDecorator):
|
||||
@@ -160,38 +142,20 @@ class EncryptedJson(TypeDecorator):
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(
|
||||
self,
|
||||
value: dict[str, Any] | SensitiveValue[dict[str, Any]] | None,
|
||||
dialect: Dialect, # noqa: ARG002
|
||||
self, value: dict | None, dialect: Dialect # noqa: ARG002
|
||||
) -> bytes | None:
|
||||
if value is not None:
|
||||
# Handle both raw dicts and SensitiveValue wrappers
|
||||
if isinstance(value, SensitiveValue):
|
||||
# Get raw value for storage
|
||||
value = value.get_value(apply_mask=False)
|
||||
json_str = json.dumps(value)
|
||||
return encrypt_string_to_bytes(json_str)
|
||||
return value
|
||||
|
||||
def process_result_value(
|
||||
self, value: bytes | None, dialect: Dialect # noqa: ARG002
|
||||
) -> SensitiveValue[dict[str, Any]] | None:
|
||||
) -> dict | None:
|
||||
if value is not None:
|
||||
return SensitiveValue(
|
||||
encrypted_bytes=value,
|
||||
decrypt_fn=decrypt_bytes_to_string,
|
||||
is_json=True,
|
||||
)
|
||||
return None
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
json_str = decrypt_bytes_to_string(value)
|
||||
return json.loads(json_str)
|
||||
return value
|
||||
|
||||
|
||||
class NullFilteredString(TypeDecorator):
|
||||
@@ -248,19 +212,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
default=None,
|
||||
)
|
||||
chat_background: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
default_app_mode: Mapped[DefaultAppMode] = mapped_column(
|
||||
Enum(DefaultAppMode, native_enum=False),
|
||||
nullable=False,
|
||||
default=DefaultAppMode.CHAT,
|
||||
)
|
||||
# personalization fields are exposed via the chat user settings "Personalization" tab
|
||||
personal_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
personal_role: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
use_memories: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
enable_memory_tool: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=True
|
||||
)
|
||||
user_preferences: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
chosen_assistants: Mapped[list[int] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
@@ -321,7 +276,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="selectin",
|
||||
order_by="desc(Memory.id)",
|
||||
)
|
||||
oauth_user_tokens: Mapped[list["OAuthUserToken"]] = relationship(
|
||||
"OAuthUserToken",
|
||||
@@ -1037,31 +991,6 @@ class OpenSearchTenantMigrationRecord(Base):
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
# Opaque continuation token from Vespa's Visit API.
|
||||
# NULL means "not started" or "visit completed".
|
||||
vespa_visit_continuation_token: Mapped[str | None] = mapped_column(
|
||||
Text, nullable=True
|
||||
)
|
||||
total_chunks_migrated: Mapped[int] = mapped_column(
|
||||
Integer, default=0, nullable=False
|
||||
)
|
||||
total_chunks_errored: Mapped[int] = mapped_column(
|
||||
Integer, default=0, nullable=False
|
||||
)
|
||||
total_chunks_in_vespa: Mapped[int] = mapped_column(
|
||||
Integer, default=0, nullable=False
|
||||
)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
migration_completed_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
enable_opensearch_retrieval: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
|
||||
|
||||
class KGEntityType(Base):
|
||||
@@ -1826,9 +1755,7 @@ class Credential(Base):
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
credential_json: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
|
||||
EncryptedJson()
|
||||
)
|
||||
credential_json: Mapped[dict[str, Any]] = mapped_column(EncryptedJson())
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
@@ -1866,9 +1793,7 @@ class FederatedConnector(Base):
|
||||
source: Mapped[FederatedConnectorSource] = mapped_column(
|
||||
Enum(FederatedConnectorSource, native_enum=False)
|
||||
)
|
||||
credentials: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
|
||||
EncryptedJson(), nullable=False
|
||||
)
|
||||
credentials: Mapped[dict[str, str]] = mapped_column(EncryptedJson(), nullable=False)
|
||||
config: Mapped[dict[str, Any]] = mapped_column(
|
||||
postgresql.JSONB(), default=dict, nullable=False, server_default="{}"
|
||||
)
|
||||
@@ -1895,9 +1820,7 @@ class FederatedConnectorOAuthToken(Base):
|
||||
user_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
token: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=False
|
||||
)
|
||||
token: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
|
||||
expires_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime, nullable=True
|
||||
)
|
||||
@@ -2041,9 +1964,7 @@ class SearchSettings(Base):
|
||||
|
||||
@property
|
||||
def api_key(self) -> str | None:
|
||||
if self.cloud_provider is None or self.cloud_provider.api_key is None:
|
||||
return None
|
||||
return self.cloud_provider.api_key.get_value(apply_mask=False)
|
||||
return self.cloud_provider.api_key if self.cloud_provider is not None else None
|
||||
|
||||
@property
|
||||
def large_chunks_enabled(self) -> bool:
|
||||
@@ -2805,9 +2726,7 @@ class LLMProvider(Base):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
provider: Mapped[str] = mapped_column(String)
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
|
||||
api_base: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
api_version: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
# custom configs that should be passed to the LLM provider at inference time
|
||||
@@ -2960,7 +2879,7 @@ class CloudEmbeddingProvider(Base):
|
||||
Enum(EmbeddingProvider), primary_key=True
|
||||
)
|
||||
api_url: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(EncryptedString())
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString())
|
||||
api_version: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
@@ -2979,9 +2898,7 @@ class InternetSearchProvider(Base):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
|
||||
provider_type: Mapped[str] = mapped_column(String, nullable=False)
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
|
||||
config: Mapped[dict[str, str] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
@@ -3003,9 +2920,7 @@ class InternetContentProvider(Base):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
|
||||
provider_type: Mapped[str] = mapped_column(String, nullable=False)
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
|
||||
config: Mapped[WebContentProviderConfig | None] = mapped_column(
|
||||
PydanticType(WebContentProviderConfig), nullable=True
|
||||
)
|
||||
@@ -3149,12 +3064,8 @@ class OAuthConfig(Base):
|
||||
token_url: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
|
||||
# Client credentials (encrypted)
|
||||
client_id: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=False
|
||||
)
|
||||
client_secret: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=False
|
||||
)
|
||||
client_id: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
|
||||
client_secret: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
|
||||
|
||||
# Optional configurations
|
||||
scopes: Mapped[list[str] | None] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
@@ -3201,9 +3112,7 @@ class OAuthUserToken(Base):
|
||||
# "expires_at": 1234567890, # Unix timestamp, optional
|
||||
# "scope": "repo user" # Optional
|
||||
# }
|
||||
token_data: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
|
||||
EncryptedJson(), nullable=False
|
||||
)
|
||||
token_data: Mapped[dict[str, Any]] = mapped_column(EncryptedJson(), nullable=False)
|
||||
|
||||
# Metadata
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -3536,15 +3445,9 @@ class SlackBot(Base):
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
bot_token: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), unique=True
|
||||
)
|
||||
app_token: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), unique=True
|
||||
)
|
||||
user_token: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
bot_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
|
||||
app_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
|
||||
user_token: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
|
||||
|
||||
slack_channel_configs: Mapped[list[SlackChannelConfig]] = relationship(
|
||||
"SlackChannelConfig",
|
||||
@@ -3565,9 +3468,7 @@ class DiscordBotConfig(Base):
|
||||
id: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, server_default=text("'SINGLETON'")
|
||||
)
|
||||
bot_token: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=False
|
||||
)
|
||||
bot_token: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
@@ -3723,9 +3624,7 @@ class KVStore(Base):
|
||||
|
||||
key: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
value: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
encrypted_value: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
|
||||
EncryptedJson(), nullable=True
|
||||
)
|
||||
encrypted_value: Mapped[JSON_ro] = mapped_column(EncryptedJson(), nullable=True)
|
||||
|
||||
|
||||
class FileRecord(Base):
|
||||
@@ -4445,7 +4344,7 @@ class MCPConnectionConfig(Base):
|
||||
# "registration_access_token": "<token>", # For managing registration
|
||||
# "registration_client_uri": "<uri>", # For managing registration
|
||||
# }
|
||||
config: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
|
||||
config: Mapped[MCPConnectionData] = mapped_column(
|
||||
EncryptedJson(), nullable=False, default=dict
|
||||
)
|
||||
|
||||
|
||||
@@ -87,13 +87,13 @@ def update_oauth_config(
|
||||
if token_url is not None:
|
||||
oauth_config.token_url = token_url
|
||||
if clear_client_id:
|
||||
oauth_config.client_id = "" # type: ignore[assignment]
|
||||
oauth_config.client_id = ""
|
||||
elif client_id is not None:
|
||||
oauth_config.client_id = client_id # type: ignore[assignment]
|
||||
oauth_config.client_id = client_id
|
||||
if clear_client_secret:
|
||||
oauth_config.client_secret = "" # type: ignore[assignment]
|
||||
oauth_config.client_secret = ""
|
||||
elif client_secret is not None:
|
||||
oauth_config.client_secret = client_secret # type: ignore[assignment]
|
||||
oauth_config.client_secret = client_secret
|
||||
if scopes is not None:
|
||||
oauth_config.scopes = scopes
|
||||
if additional_params is not None:
|
||||
@@ -154,7 +154,7 @@ def upsert_user_oauth_token(
|
||||
|
||||
if existing_token:
|
||||
# Update existing token
|
||||
existing_token.token_data = token_data # type: ignore[assignment]
|
||||
existing_token.token_data = token_data
|
||||
db_session.commit()
|
||||
return existing_token
|
||||
else:
|
||||
|
||||
@@ -4,9 +4,6 @@ This module provides functions to track the progress of migrating documents
|
||||
from Vespa to OpenSearch.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
@@ -15,14 +12,10 @@ from sqlalchemy.orm import Session
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
|
||||
)
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
|
||||
from onyx.db.enums import OpenSearchDocumentMigrationStatus
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import OpenSearchDocumentMigrationRecord
|
||||
from onyx.db.models import OpenSearchTenantMigrationRecord
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -183,7 +176,7 @@ def try_insert_opensearch_tenant_migration_record_with_commit(
|
||||
) -> None:
|
||||
"""Tries to insert the singleton row on OpenSearchTenantMigrationRecord.
|
||||
|
||||
Does nothing if the row already exists.
|
||||
If the row already exists, does nothing.
|
||||
"""
|
||||
stmt = insert(OpenSearchTenantMigrationRecord).on_conflict_do_nothing(
|
||||
index_elements=[text("(true)")]
|
||||
@@ -197,14 +190,25 @@ def increment_num_times_observed_no_additional_docs_to_migrate_with_commit(
|
||||
) -> None:
|
||||
"""Increments the number of times observed no additional docs to migrate.
|
||||
|
||||
Requires the OpenSearchTenantMigrationRecord to exist.
|
||||
Tries to insert the singleton row on OpenSearchTenantMigrationRecord with a
|
||||
starting count, and if the row already exists, increments the count.
|
||||
|
||||
Used to track when to stop the migration task.
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
|
||||
record.num_times_observed_no_additional_docs_to_migrate += 1
|
||||
stmt = (
|
||||
insert(OpenSearchTenantMigrationRecord)
|
||||
.values(num_times_observed_no_additional_docs_to_migrate=1)
|
||||
.on_conflict_do_update(
|
||||
index_elements=[text("(true)")],
|
||||
set_={
|
||||
"num_times_observed_no_additional_docs_to_migrate": (
|
||||
OpenSearchTenantMigrationRecord.num_times_observed_no_additional_docs_to_migrate
|
||||
+ 1
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -215,14 +219,25 @@ def increment_num_times_observed_no_additional_docs_to_populate_migration_table_
|
||||
Increments the number of times observed no additional docs to populate the
|
||||
migration table.
|
||||
|
||||
Requires the OpenSearchTenantMigrationRecord to exist.
|
||||
Tries to insert the singleton row on OpenSearchTenantMigrationRecord with a
|
||||
starting count, and if the row already exists, increments the count.
|
||||
|
||||
Used to track when to stop the migration check task.
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
|
||||
record.num_times_observed_no_additional_docs_to_populate_migration_table += 1
|
||||
stmt = (
|
||||
insert(OpenSearchTenantMigrationRecord)
|
||||
.values(num_times_observed_no_additional_docs_to_populate_migration_table=1)
|
||||
.on_conflict_do_update(
|
||||
index_elements=[text("(true)")],
|
||||
set_={
|
||||
"num_times_observed_no_additional_docs_to_populate_migration_table": (
|
||||
OpenSearchTenantMigrationRecord.num_times_observed_no_additional_docs_to_populate_migration_table
|
||||
+ 1
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -239,167 +254,3 @@ def should_document_migration_be_permanently_failed(
|
||||
>= TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_vespa_visit_state(
|
||||
db_session: Session,
|
||||
) -> tuple[str | None, int]:
|
||||
"""Gets the current Vespa migration state from the tenant migration record.
|
||||
|
||||
Requires the OpenSearchTenantMigrationRecord to exist.
|
||||
|
||||
Returns:
|
||||
Tuple of (continuation_token, total_chunks_migrated). continuation_token
|
||||
is None if not started or completed.
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
|
||||
return (
|
||||
record.vespa_visit_continuation_token,
|
||||
record.total_chunks_migrated,
|
||||
)
|
||||
|
||||
|
||||
def update_vespa_visit_progress_with_commit(
|
||||
db_session: Session,
|
||||
continuation_token: str | None,
|
||||
chunks_processed: int,
|
||||
chunks_errored: int,
|
||||
) -> None:
|
||||
"""Updates the Vespa migration progress and commits.
|
||||
|
||||
Requires the OpenSearchTenantMigrationRecord to exist.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session.
|
||||
continuation_token: The new continuation token. None means the visit
|
||||
is complete.
|
||||
chunks_processed: Number of chunks processed in this batch (added to
|
||||
the running total).
|
||||
chunks_errored: Number of chunks errored in this batch (added to the
|
||||
running errored total).
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
|
||||
record.vespa_visit_continuation_token = continuation_token
|
||||
record.total_chunks_migrated += chunks_processed
|
||||
record.total_chunks_errored += chunks_errored
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_migration_completed_time_if_not_set_with_commit(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Marks the migration completed time if not set.
|
||||
|
||||
Requires the OpenSearchTenantMigrationRecord to exist.
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
|
||||
if record.migration_completed_at is not None:
|
||||
return
|
||||
record.migration_completed_at = datetime.now(timezone.utc)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def build_sanitized_to_original_doc_id_mapping(
|
||||
db_session: Session,
|
||||
) -> dict[str, str]:
|
||||
"""Pre-computes a mapping of sanitized -> original document IDs.
|
||||
|
||||
Only includes documents whose ID contains single quotes (the only character
|
||||
that gets sanitized by replace_invalid_doc_id_characters). For all other
|
||||
documents, sanitized == original and no mapping entry is needed.
|
||||
|
||||
Scans over all documents.
|
||||
|
||||
Checks if the sanitized ID already exists as a genuine separate document in
|
||||
the Document table. If so, raises as there is no way of resolving the
|
||||
conflict in the migration. The user will need to reindex.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session.
|
||||
|
||||
Returns:
|
||||
Dict mapping sanitized_id -> original_id, only for documents where
|
||||
the IDs differ. Empty dict means no documents have single quotes
|
||||
in their IDs.
|
||||
"""
|
||||
# Find all documents with single quotes in their ID.
|
||||
stmt = select(Document.id).where(Document.id.contains("'"))
|
||||
ids_with_quotes = list(db_session.scalars(stmt).all())
|
||||
|
||||
result: dict[str, str] = {}
|
||||
for original_id in ids_with_quotes:
|
||||
sanitized_id = replace_invalid_doc_id_characters(original_id)
|
||||
if sanitized_id != original_id:
|
||||
result[sanitized_id] = original_id
|
||||
|
||||
# See if there are any documents whose ID is a sanitized ID of another
|
||||
# document. If there is even one match, we cannot proceed.
|
||||
stmt = select(Document.id).where(Document.id.in_(result.keys()))
|
||||
ids_with_matches = list(db_session.scalars(stmt).all())
|
||||
if ids_with_matches:
|
||||
raise RuntimeError(
|
||||
f"Documents with IDs {ids_with_matches} have sanitized IDs that match other documents. "
|
||||
"This is not supported and the user will need to reindex."
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_opensearch_migration_state(
|
||||
db_session: Session,
|
||||
) -> tuple[int, datetime | None, datetime | None]:
|
||||
"""Returns the state of the Vespa to OpenSearch migration.
|
||||
|
||||
If the tenant migration record is not found, returns defaults of 0, None,
|
||||
None.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session.
|
||||
|
||||
Returns:
|
||||
Tuple of (total_chunks_migrated, created_at, migration_completed_at).
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
return 0, None, None
|
||||
return (
|
||||
record.total_chunks_migrated,
|
||||
record.created_at,
|
||||
record.migration_completed_at,
|
||||
)
|
||||
|
||||
|
||||
def get_opensearch_retrieval_state(
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
"""Returns the state of the OpenSearch retrieval.
|
||||
|
||||
If the tenant migration record is not found, defaults to
|
||||
ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX.
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
return ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
|
||||
return record.enable_opensearch_retrieval
|
||||
|
||||
|
||||
def set_enable_opensearch_retrieval_with_commit(
|
||||
db_session: Session,
|
||||
enable: bool,
|
||||
) -> None:
|
||||
"""Sets the enable_opensearch_retrieval flag on the singleton record.
|
||||
|
||||
Creates the record if it doesn't exist yet.
|
||||
"""
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
|
||||
record.enable_opensearch_retrieval = enable
|
||||
db_session.commit()
|
||||
|
||||
@@ -43,9 +43,9 @@ def update_slack_bot(
|
||||
# update the app
|
||||
slack_bot.name = name
|
||||
slack_bot.enabled = enabled
|
||||
slack_bot.bot_token = bot_token # type: ignore[assignment]
|
||||
slack_bot.app_token = app_token # type: ignore[assignment]
|
||||
slack_bot.user_token = user_token # type: ignore[assignment]
|
||||
slack_bot.bot_token = bot_token
|
||||
slack_bot.app_token = app_token
|
||||
slack_bot.user_token = user_token
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -55,8 +55,6 @@ def get_tools(
|
||||
# To avoid showing rows that have JSON literal `null` stored in the column to the user.
|
||||
# tools from mcp servers will not have an openapi schema but it has `null`, so we need to exclude them.
|
||||
func.jsonb_typeof(Tool.openapi_schema) == "object",
|
||||
# Exclude built-in tools that happen to have an openapi_schema
|
||||
Tool.in_code_tool_id.is_(None),
|
||||
)
|
||||
|
||||
return list(db_session.scalars(query).all())
|
||||
|
||||
@@ -9,13 +9,11 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import DefaultAppMode
|
||||
from onyx.db.enums import ThemePreference
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import Assistant__UserSpecificConfig
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
from onyx.server.manage.models import MemoryItem
|
||||
from onyx.server.manage.models import UserSpecificAssistantPreference
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -155,29 +153,13 @@ def update_user_chat_background(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_user_default_app_mode(
|
||||
user_id: UUID,
|
||||
default_app_mode: DefaultAppMode,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update user's default app mode setting."""
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id) # type: ignore
|
||||
.values(default_app_mode=default_app_mode)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_user_personalization(
|
||||
user_id: UUID,
|
||||
*,
|
||||
personal_name: str | None,
|
||||
personal_role: str | None,
|
||||
use_memories: bool,
|
||||
enable_memory_tool: bool,
|
||||
memories: list[MemoryItem],
|
||||
user_preferences: str | None,
|
||||
memories: list[str],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
db_session.execute(
|
||||
@@ -187,39 +169,14 @@ def update_user_personalization(
|
||||
personal_name=personal_name,
|
||||
personal_role=personal_role,
|
||||
use_memories=use_memories,
|
||||
enable_memory_tool=enable_memory_tool,
|
||||
user_preferences=user_preferences,
|
||||
)
|
||||
)
|
||||
|
||||
# ID-based upsert: use real DB IDs from the frontend to match memories.
|
||||
incoming_ids = {m.id for m in memories if m.id is not None}
|
||||
db_session.execute(delete(Memory).where(Memory.user_id == user_id))
|
||||
|
||||
# Delete existing rows not in the incoming set (scoped to user_id)
|
||||
existing_memories = list(
|
||||
db_session.scalars(select(Memory).where(Memory.user_id == user_id)).all()
|
||||
)
|
||||
existing_ids = {mem.id for mem in existing_memories}
|
||||
ids_to_delete = existing_ids - incoming_ids
|
||||
if ids_to_delete:
|
||||
db_session.execute(
|
||||
delete(Memory).where(
|
||||
Memory.id.in_(ids_to_delete),
|
||||
Memory.user_id == user_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Update existing rows whose IDs match
|
||||
existing_by_id = {mem.id: mem for mem in existing_memories}
|
||||
for item in memories:
|
||||
if item.id is not None and item.id in existing_by_id:
|
||||
existing_by_id[item.id].memory_text = item.content
|
||||
|
||||
# Create new rows for items without an ID
|
||||
new_items = [m for m in memories if m.id is None]
|
||||
if new_items:
|
||||
if memories:
|
||||
db_session.add_all(
|
||||
[Memory(user_id=user_id, memory_text=item.content) for item in new_items]
|
||||
[Memory(user_id=user_id, memory_text=memory) for memory in memories]
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -73,8 +73,7 @@ def _apply_search_provider_updates(
|
||||
provider.provider_type = provider_type.value
|
||||
provider.config = config
|
||||
if api_key_changed or provider.api_key is None:
|
||||
# EncryptedString accepts str for writes, returns SensitiveValue for reads
|
||||
provider.api_key = api_key # type: ignore[assignment]
|
||||
provider.api_key = api_key
|
||||
|
||||
|
||||
def upsert_web_search_provider(
|
||||
@@ -229,8 +228,7 @@ def _apply_content_provider_updates(
|
||||
provider.provider_type = provider_type.value
|
||||
provider.config = config
|
||||
if api_key_changed or provider.api_key is None:
|
||||
# EncryptedString accepts str for writes, returns SensitiveValue for reads
|
||||
provider.api_key = api_key # type: ignore[assignment]
|
||||
provider.api_key = api_key
|
||||
|
||||
|
||||
def upsert_web_content_provider(
|
||||
|
||||
@@ -17,7 +17,6 @@ from onyx.chat.llm_loop import construct_message_history
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.llm_step import run_llm_step_pkt_generator
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.chat_configs import SKIP_DEEP_RESEARCH_CLARIFICATION
|
||||
@@ -110,7 +109,6 @@ def generate_final_report(
|
||||
user_identity: LLMUserIdentity | None,
|
||||
saved_reasoning: str | None = None,
|
||||
pre_answer_processing_time: float | None = None,
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata] | None = None,
|
||||
) -> bool:
|
||||
"""Generate the final research report.
|
||||
|
||||
@@ -132,7 +130,7 @@ def generate_final_report(
|
||||
reminder_message = ChatMessageSimple(
|
||||
message=final_reminder,
|
||||
token_count=token_counter(final_reminder),
|
||||
message_type=MessageType.USER_REMINDER,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
final_report_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
@@ -141,7 +139,6 @@ def generate_final_report(
|
||||
reminder_message=reminder_message,
|
||||
project_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
|
||||
citation_processor = DynamicCitationProcessor()
|
||||
@@ -197,7 +194,6 @@ def run_deep_research_llm_loop(
|
||||
skip_clarification: bool = False,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata] | None = None,
|
||||
) -> None:
|
||||
with trace(
|
||||
"run_deep_research_llm_loop",
|
||||
@@ -260,7 +256,6 @@ def run_deep_research_llm_loop(
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
|
||||
# Calculate tool processing duration for clarification step
|
||||
@@ -309,8 +304,6 @@ def run_deep_research_llm_loop(
|
||||
token_count=300,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
# Note this is fine to use a USER message type here as it can just be interpretered as a
|
||||
# user's message directly to the LLM.
|
||||
reminder_message = ChatMessageSimple(
|
||||
message=RESEARCH_PLAN_REMINDER,
|
||||
token_count=100,
|
||||
@@ -324,7 +317,6 @@ def run_deep_research_llm_loop(
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
|
||||
research_plan_generator = run_llm_step_pkt_generator(
|
||||
@@ -450,7 +442,6 @@ def run_deep_research_llm_loop(
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
pre_answer_processing_time=elapsed_seconds,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
@@ -459,9 +450,11 @@ def run_deep_research_llm_loop(
|
||||
first_cycle_reminder_message = ChatMessageSimple(
|
||||
message=FIRST_CYCLE_REMINDER,
|
||||
token_count=FIRST_CYCLE_REMINDER_TOKENS,
|
||||
message_type=MessageType.USER_REMINDER,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
first_cycle_tokens = FIRST_CYCLE_REMINDER_TOKENS
|
||||
else:
|
||||
first_cycle_tokens = 0
|
||||
first_cycle_reminder_message = None
|
||||
|
||||
research_agent_calls: list[ToolCallKickoff] = []
|
||||
@@ -484,13 +477,15 @@ def run_deep_research_llm_loop(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=first_cycle_reminder_message,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
available_tokens=available_tokens - first_cycle_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
|
||||
if first_cycle_reminder_message is not None:
|
||||
truncated_message_history.append(first_cycle_reminder_message)
|
||||
|
||||
# Use think tool processor for non-reasoning models to convert
|
||||
# think_tool calls to reasoning content
|
||||
custom_processor = (
|
||||
@@ -554,7 +549,6 @@ def run_deep_research_llm_loop(
|
||||
user_identity=user_identity,
|
||||
pre_answer_processing_time=time.monotonic()
|
||||
- processing_start_time,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
@@ -578,7 +572,6 @@ def run_deep_research_llm_loop(
|
||||
saved_reasoning=most_recent_reasoning,
|
||||
pre_answer_processing_time=time.monotonic()
|
||||
- processing_start_time,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
@@ -651,7 +644,6 @@ def run_deep_research_llm_loop(
|
||||
user_identity=user_identity,
|
||||
pre_answer_processing_time=time.monotonic()
|
||||
- processing_start_time,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
final_turn_index = report_turn_index + (
|
||||
1 if report_reasoned else 0
|
||||
|
||||
@@ -1,151 +0,0 @@
|
||||
"""A DocumentIndex implementation that raises on every operation.
|
||||
|
||||
Used as a safety net when DISABLE_VECTOR_DB is True. Any code path that
|
||||
accidentally reaches the vector DB layer will fail loudly instead of timing
|
||||
out against a nonexistent Vespa/OpenSearch instance.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import QueryExpansionType
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentInsertionRecord
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
VECTOR_DB_DISABLED_ERROR = (
|
||||
"Vector DB is disabled (DISABLE_VECTOR_DB=true). "
|
||||
"This operation requires a vector database."
|
||||
)
|
||||
|
||||
|
||||
class DisabledDocumentIndex(DocumentIndex):
|
||||
"""A DocumentIndex where every method raises RuntimeError.
|
||||
|
||||
Returned by the factory when DISABLE_VECTOR_DB is True so that any
|
||||
accidental vector-DB call surfaces immediately.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str = "disabled",
|
||||
secondary_index_name: str | None = None,
|
||||
*args: Any, # noqa: ARG002
|
||||
**kwargs: Any, # noqa: ARG002
|
||||
) -> None:
|
||||
self.index_name = index_name
|
||||
self.secondary_index_name = secondary_index_name
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Verifiable
|
||||
# ------------------------------------------------------------------
|
||||
def ensure_indices_exist(
|
||||
self,
|
||||
primary_embedding_dim: int, # noqa: ARG002
|
||||
primary_embedding_precision: EmbeddingPrecision, # noqa: ARG002
|
||||
secondary_index_embedding_dim: int | None, # noqa: ARG002
|
||||
secondary_index_embedding_precision: EmbeddingPrecision | None, # noqa: ARG002
|
||||
) -> None:
|
||||
# No-op: there are no indices to create when the vector DB is disabled.
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def register_multitenant_indices(
|
||||
indices: list[str], # noqa: ARG002, ARG004
|
||||
embedding_dims: list[int], # noqa: ARG002, ARG004
|
||||
embedding_precisions: list[EmbeddingPrecision], # noqa: ARG002, ARG004
|
||||
) -> None:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Indexable
|
||||
# ------------------------------------------------------------------
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
|
||||
index_batch_params: IndexBatchParams, # noqa: ARG002
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Deletable
|
||||
# ------------------------------------------------------------------
|
||||
def delete_single(
|
||||
self,
|
||||
doc_id: str, # noqa: ARG002
|
||||
*,
|
||||
tenant_id: str, # noqa: ARG002
|
||||
chunk_count: int | None, # noqa: ARG002
|
||||
) -> int:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Updatable
|
||||
# ------------------------------------------------------------------
|
||||
def update_single(
|
||||
self,
|
||||
doc_id: str, # noqa: ARG002
|
||||
*,
|
||||
tenant_id: str, # noqa: ARG002
|
||||
chunk_count: int | None, # noqa: ARG002
|
||||
fields: VespaDocumentFields | None, # noqa: ARG002
|
||||
user_fields: VespaDocumentUserFields | None, # noqa: ARG002
|
||||
) -> None:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# IdRetrievalCapable
|
||||
# ------------------------------------------------------------------
|
||||
def id_based_retrieval(
|
||||
self,
|
||||
chunk_requests: list[VespaChunkRequest], # noqa: ARG002
|
||||
filters: IndexFilters, # noqa: ARG002
|
||||
batch_retrieval: bool = False, # noqa: ARG002
|
||||
) -> list[InferenceChunk]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HybridCapable
|
||||
# ------------------------------------------------------------------
|
||||
def hybrid_retrieval(
|
||||
self,
|
||||
query: str, # noqa: ARG002
|
||||
query_embedding: Embedding, # noqa: ARG002
|
||||
final_keywords: list[str] | None, # noqa: ARG002
|
||||
filters: IndexFilters, # noqa: ARG002
|
||||
hybrid_alpha: float, # noqa: ARG002
|
||||
time_decay_multiplier: float, # noqa: ARG002
|
||||
num_to_retrieve: int, # noqa: ARG002
|
||||
ranking_profile_type: QueryExpansionType, # noqa: ARG002
|
||||
title_content_ratio: float | None = None, # noqa: ARG002
|
||||
) -> list[InferenceChunk]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# AdminCapable
|
||||
# ------------------------------------------------------------------
|
||||
def admin_retrieval(
|
||||
self,
|
||||
query: str, # noqa: ARG002
|
||||
query_embedding: Embedding, # noqa: ARG002
|
||||
filters: IndexFilters, # noqa: ARG002
|
||||
num_to_retrieve: int = 10, # noqa: ARG002
|
||||
) -> list[InferenceChunk]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# RandomCapable
|
||||
# ------------------------------------------------------------------
|
||||
def random_retrieval(
|
||||
self,
|
||||
filters: IndexFilters, # noqa: ARG002
|
||||
num_to_retrieve: int = 10, # noqa: ARG002
|
||||
) -> list[InferenceChunk]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
@@ -1,11 +1,8 @@
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.opensearch_migration import get_opensearch_retrieval_state
|
||||
from onyx.document_index.disabled import DisabledDocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
@@ -17,7 +14,6 @@ from shared_configs.configs import MULTI_TENANT
|
||||
def get_default_document_index(
|
||||
search_settings: SearchSettings,
|
||||
secondary_search_settings: SearchSettings | None,
|
||||
db_session: Session,
|
||||
httpx_client: httpx.Client | None = None,
|
||||
) -> DocumentIndex:
|
||||
"""Gets the default document index from env vars.
|
||||
@@ -31,24 +27,13 @@ def get_default_document_index(
|
||||
index is for when both the currently used index and the upcoming index both
|
||||
need to be updated, updates are applied to both indices.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
return DisabledDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=(
|
||||
secondary_search_settings.index_name
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
secondary_index_name: str | None = None
|
||||
secondary_large_chunks_enabled: bool | None = None
|
||||
if secondary_search_settings:
|
||||
secondary_index_name = secondary_search_settings.index_name
|
||||
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
|
||||
|
||||
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
|
||||
if opensearch_retrieval_enabled:
|
||||
if ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX:
|
||||
return OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=secondary_index_name,
|
||||
@@ -84,24 +69,7 @@ def get_all_document_indices(
|
||||
|
||||
Large chunks and secondary indices are not currently supported so we
|
||||
hardcode appropriate values.
|
||||
|
||||
NOTE: Make sure the Vespa index object is returned first. In the rare event
|
||||
that there is some conflict between indexing and the migration task, it is
|
||||
assumed that the state of Vespa is more up-to-date than the state of
|
||||
OpenSearch.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
return [
|
||||
DisabledDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=(
|
||||
secondary_search_settings.index_name
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
vespa_document_index = VespaIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user