mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-16 23:35:46 +00:00
Compare commits
123 Commits
experiment
...
dependabot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
43ed3e2a03 | ||
|
|
ddb14ec762 | ||
|
|
f31f589860 | ||
|
|
63b9a869af | ||
|
|
6aea36b573 | ||
|
|
3d8e8d0846 | ||
|
|
dea5be2185 | ||
|
|
d083973d4f | ||
|
|
df956888bf | ||
|
|
7c6062e7d5 | ||
|
|
89d2759021 | ||
|
|
d9feaf43a7 | ||
|
|
5bfffefa2f | ||
|
|
4d0b7e14d4 | ||
|
|
36c55d9e59 | ||
|
|
9f652108f9 | ||
|
|
d4e4c6b40e | ||
|
|
9c8deb5d0c | ||
|
|
58f57c43aa | ||
|
|
62106df753 | ||
|
|
45b3a5e945 | ||
|
|
e19a6b6789 | ||
|
|
2de7df4839 | ||
|
|
bd054bbad9 | ||
|
|
313e709d41 | ||
|
|
aeb1d6edac | ||
|
|
49a35f8aaa | ||
|
|
049e8ef0e2 | ||
|
|
3b61b495a3 | ||
|
|
5c5c9f0e1d | ||
|
|
f20d5c33b7 | ||
|
|
e898407f7b | ||
|
|
f802ff09a7 | ||
|
|
69ad712e09 | ||
|
|
98b69c0f2c | ||
|
|
1e5c87896f | ||
|
|
b6cc97a8c3 | ||
|
|
032fbf1058 | ||
|
|
fc32a9f92a | ||
|
|
9be13bbf63 | ||
|
|
9e7176eb82 | ||
|
|
c7faf8ce52 | ||
|
|
6230e36a63 | ||
|
|
7595b54f6b | ||
|
|
dc1bb426ee | ||
|
|
e9a0506183 | ||
|
|
4747c43889 | ||
|
|
27e676c48f | ||
|
|
6749f63f09 | ||
|
|
e404ffd443 | ||
|
|
c5b89b86c3 | ||
|
|
84bb3867b2 | ||
|
|
92cc1d83b5 | ||
|
|
e92d4a342f | ||
|
|
b4d596c957 | ||
|
|
d76d32003b | ||
|
|
007d2d109f | ||
|
|
08891b5242 | ||
|
|
846672a843 | ||
|
|
0f362457be | ||
|
|
283e8f4d3f | ||
|
|
fdf19d74bd | ||
|
|
7c702f8932 | ||
|
|
3fb06f6e8e | ||
|
|
9fcd999076 | ||
|
|
c937da65c4 | ||
|
|
abdbe89dd4 | ||
|
|
54f9c67522 | ||
|
|
31bcdc69ca | ||
|
|
b748e08029 | ||
|
|
11b279ad31 | ||
|
|
782082f818 | ||
|
|
c01b559bc6 | ||
|
|
3101a53855 | ||
|
|
ce6c210de1 | ||
|
|
15b372fea9 | ||
|
|
cf523cb467 | ||
|
|
344625b7e0 | ||
|
|
9bf8400cf8 | ||
|
|
09e86c2fda | ||
|
|
204328d52a | ||
|
|
3ce58c8450 | ||
|
|
67b5df255a | ||
|
|
33fa29e19f | ||
|
|
787f25a7c8 | ||
|
|
f10b994a27 | ||
|
|
d4089b1785 | ||
|
|
e122959854 | ||
|
|
93afb154ee | ||
|
|
e9be078268 | ||
|
|
61502751e8 | ||
|
|
cd26893b87 | ||
|
|
90dc6b16fa | ||
|
|
34b48763f4 | ||
|
|
094d7a2d02 | ||
|
|
faa97e92e8 | ||
|
|
358dc32fd2 | ||
|
|
f06465bfb2 | ||
|
|
8a51b00050 | ||
|
|
33de6dcd6a | ||
|
|
fe52f4e6d3 | ||
|
|
51de334732 | ||
|
|
cb72f84209 | ||
|
|
8b24c08467 | ||
|
|
0a1e043a97 | ||
|
|
466668fed5 | ||
|
|
41d105faa0 | ||
|
|
9e581f48e5 | ||
|
|
48d8e0955a | ||
|
|
a77780d67e | ||
|
|
d13511500c | ||
|
|
216d486323 | ||
|
|
a57d399ba5 | ||
|
|
07324ae0e4 | ||
|
|
c8ae07f7c2 | ||
|
|
f0fd19f110 | ||
|
|
6a62406042 | ||
|
|
d0be7dd914 | ||
|
|
6a045db72b | ||
|
|
e5e9dbe2f0 | ||
|
|
50e0a2cf90 | ||
|
|
50538ce5ac | ||
|
|
6fab7103bf |
16
.cursor/mcp.json
Normal file
16
.cursor/mcp.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"Playwright": {
|
||||
"command": "npx",
|
||||
"args": [
|
||||
"@playwright/mcp"
|
||||
]
|
||||
},
|
||||
"Linear": {
|
||||
"url": "https://mcp.linear.app/mcp"
|
||||
},
|
||||
"Figma": {
|
||||
"url": "https://mcp.figma.com/mcp"
|
||||
}
|
||||
}
|
||||
}
|
||||
6
.github/workflows/deployment.yml
vendored
6
.github/workflows/deployment.yml
vendored
@@ -91,8 +91,8 @@ jobs:
|
||||
BUILD_WEB_CLOUD=true
|
||||
else
|
||||
BUILD_WEB=true
|
||||
# Skip desktop builds on beta tags and nightly runs
|
||||
if [[ "$IS_BETA" != "true" ]] && [[ "$IS_NIGHTLY" != "true" ]]; then
|
||||
# Only build desktop for semver tags (excluding beta)
|
||||
if [[ "$IS_VERSION_TAG" == "true" ]] && [[ "$IS_BETA" != "true" ]]; then
|
||||
BUILD_DESKTOP=true
|
||||
fi
|
||||
fi
|
||||
@@ -640,6 +640,7 @@ jobs:
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
|
||||
NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
|
||||
NEXT_PUBLIC_GTM_ENABLED=true
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
@@ -721,6 +722,7 @@ jobs:
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
|
||||
NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
|
||||
NEXT_PUBLIC_GTM_ENABLED=true
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
|
||||
175
.github/workflows/pr-integration-tests.yml
vendored
175
.github/workflows/pr-integration-tests.yml
vendored
@@ -46,6 +46,7 @@ jobs:
|
||||
timeout-minutes: 45
|
||||
outputs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
editions: ${{ steps.set-editions.outputs.editions }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
@@ -56,7 +57,7 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" ! -name "no_vectordb" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
@@ -72,6 +73,16 @@ jobs:
|
||||
all_dirs="[${all_dirs%,}]"
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Determine editions to test
|
||||
id: set-editions
|
||||
run: |
|
||||
# On PRs, only run EE tests. On merge_group and tags, run both EE and MIT.
|
||||
if [ "${{ github.event_name }}" = "pull_request" ]; then
|
||||
echo 'editions=["ee"]' >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo 'editions=["ee","mit"]' >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
build-backend-image:
|
||||
runs-on:
|
||||
[
|
||||
@@ -267,7 +278,7 @@ jobs:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- ${{ format('run-id={0}-integration-tests-job-{1}', github.run_id, strategy['job-index']) }}
|
||||
- ${{ format('run-id={0}-integration-tests-{1}-job-{2}', github.run_id, matrix.edition, strategy['job-index']) }}
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
|
||||
@@ -275,6 +286,7 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
|
||||
edition: ${{ fromJson(needs.discover-test-dirs.outputs.editions) }}
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
@@ -298,12 +310,11 @@ jobs:
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
EDITION: ${{ matrix.edition }}
|
||||
run: |
|
||||
# Base config shared by both editions
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
@@ -312,11 +323,20 @@ jobs:
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
MCP_SERVER_ENABLED=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
EOF
|
||||
|
||||
# EE-only config
|
||||
if [ "$EDITION" = "ee" ]; then
|
||||
cat <<EOF >> deployment/docker_compose/.env
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
|
||||
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
|
||||
EOF
|
||||
fi
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
@@ -379,14 +399,14 @@ jobs:
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
|
||||
- name: Run Integration Tests (${{ matrix.edition }}) for ${{ matrix.test-dir.name }}
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
|
||||
echo "Running ${{ matrix.edition }} integration tests for ${{ matrix.test-dir.path }}..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
@@ -444,10 +464,143 @@ jobs:
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
name: docker-all-logs-${{ matrix.edition }}-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
no-vectordb-tests:
|
||||
needs: [build-backend-image, build-integration-image]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=4cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-no-vectordb-tests",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create .env file for no-vectordb Docker Compose
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
DISABLE_VECTOR_DB=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
USE_LIGHTWEIGHT_BACKGROUND_WORKER=true
|
||||
EOF
|
||||
|
||||
# Start only the services needed for no-vectordb mode (no Vespa, no model servers)
|
||||
- name: Start Docker containers (no-vectordb)
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml up \
|
||||
relational_db \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker_no_vectordb
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script (no-vectordb)..."
|
||||
start_time=$(date +%s)
|
||||
timeout=300
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in $timeout seconds."
|
||||
exit 1
|
||||
fi
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "API server is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error; retrying..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
sleep 5
|
||||
done
|
||||
|
||||
- name: Run No-VectorDB Integration Tests
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running no-vectordb integration tests..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
|
||||
/app/tests/integration/tests/no_vectordb
|
||||
|
||||
- name: Dump API server logs (no-vectordb)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
|
||||
logs --no-color api_server > $GITHUB_WORKSPACE/api_server_no_vectordb.log || true
|
||||
|
||||
- name: Dump all-container logs (no-vectordb)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
|
||||
logs --no-color > $GITHUB_WORKSPACE/docker-compose-no-vectordb.log || true
|
||||
|
||||
- name: Upload logs (no-vectordb)
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-no-vectordb
|
||||
path: ${{ github.workspace }}/docker-compose-no-vectordb.log
|
||||
|
||||
- name: Stop Docker containers (no-vectordb)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml down -v
|
||||
|
||||
multitenant-tests:
|
||||
needs:
|
||||
[build-backend-image, build-model-server-image, build-integration-image]
|
||||
@@ -587,7 +740,7 @@ jobs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
needs: [integration-tests, multitenant-tests]
|
||||
needs: [integration-tests, no-vectordb-tests, multitenant-tests]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- name: Check job status
|
||||
|
||||
443
.github/workflows/pr-mit-integration-tests.yml
vendored
443
.github/workflows/pr-mit-integration-tests.yml
vendored
@@ -1,443 +0,0 @@
|
||||
name: Run MIT Integration Tests v2
|
||||
concurrency:
|
||||
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
types: [checks_requested]
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
outputs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Discover test directories
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
all_dirs=""
|
||||
for dir in $tests_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"},"
|
||||
done
|
||||
for dir in $connector_dirs; do
|
||||
all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"},"
|
||||
done
|
||||
|
||||
# Remove trailing comma and wrap in array
|
||||
all_dirs="[${all_dirs%,}]"
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
build-backend-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-backend-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
push: true
|
||||
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
|
||||
type=registry,ref=onyxdotapp/onyx-backend:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-model-server-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-model-server-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
push: true
|
||||
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
|
||||
type=registry,ref=onyxdotapp/onyx-model-server:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
|
||||
build-integration-image:
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-integration-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling openapitools/openapi-generator-cli
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push integration test image with Docker Bake
|
||||
env:
|
||||
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
TAG: integration-test-${{ github.run_id }}
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
run: |
|
||||
docker buildx bake --push \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
|
||||
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
|
||||
integration
|
||||
|
||||
integration-tests-mit:
|
||||
needs:
|
||||
[
|
||||
discover-test-dirs,
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
]
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- ${{ format('run-id={0}-integration-tests-mit-job-{1}', github.run_id, strategy['job-index']) }}
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
# NOTE: don't need web server for integration tests
|
||||
- name: Create .env file for Docker Compose
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
MCP_SERVER_ENABLED=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for services to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
wait_for_service() {
|
||||
local url=$1
|
||||
local label=$2
|
||||
local timeout=${3:-300} # default 5 minutes
|
||||
local start_time
|
||||
start_time=$(date +%s)
|
||||
|
||||
while true; do
|
||||
local current_time
|
||||
current_time=$(date +%s)
|
||||
local elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. ${label} did not become ready in $timeout seconds."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
local response
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" "$url" || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "${label} is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error while checking ${label}. Retrying in 5 seconds..."
|
||||
else
|
||||
echo "${label} not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
}
|
||||
|
||||
wait_for_service "http://localhost:8080/health" "API server"
|
||||
echo "Finished waiting for services."
|
||||
|
||||
- name: Start Mock Services
|
||||
run: |
|
||||
cd backend/tests/integration/mock_services
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
|
||||
/app/tests/integration/${{ matrix.test-dir.path }}
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Always gather logs BEFORE "down":
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
required:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
needs: [integration-tests-mit]
|
||||
if: ${{ always() }}
|
||||
steps:
|
||||
- name: Check job status
|
||||
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
|
||||
run: exit 1
|
||||
269
.github/workflows/pr-playwright-tests.yml
vendored
269
.github/workflows/pr-playwright-tests.yml
vendored
@@ -52,6 +52,9 @@ env:
|
||||
MCP_SERVER_PUBLIC_HOST: host.docker.internal
|
||||
MCP_SERVER_PUBLIC_URL: http://host.docker.internal:8004/mcp
|
||||
|
||||
# Visual regression S3 bucket (shared across all jobs)
|
||||
PLAYWRIGHT_S3_BUCKET: onyx-playwright-artifacts
|
||||
|
||||
jobs:
|
||||
build-web-image:
|
||||
runs-on:
|
||||
@@ -239,6 +242,9 @@ jobs:
|
||||
playwright-tests:
|
||||
needs: [build-web-image, build-backend-image, build-model-server-image]
|
||||
name: Playwright Tests (${{ matrix.project }})
|
||||
permissions:
|
||||
id-token: write # Required for OIDC-based AWS credential exchange (S3 access)
|
||||
contents: read
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=8cpu-linux-arm64
|
||||
@@ -428,8 +434,6 @@ jobs:
|
||||
env:
|
||||
PROJECT: ${{ matrix.project }}
|
||||
run: |
|
||||
# Create test-results directory to ensure it exists for artifact upload
|
||||
mkdir -p test-results
|
||||
npx playwright test --project ${PROJECT}
|
||||
|
||||
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
@@ -437,9 +441,134 @@ jobs:
|
||||
with:
|
||||
# Includes test results and trace.zip files
|
||||
name: playwright-test-results-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ./web/test-results/
|
||||
path: ./web/output/playwright/
|
||||
retention-days: 30
|
||||
|
||||
- name: Upload screenshots
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-screenshots-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ./web/output/screenshots/
|
||||
retention-days: 30
|
||||
|
||||
# --- Visual Regression Diff ---
|
||||
- name: Configure AWS credentials
|
||||
if: always()
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: always()
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Determine baseline revision
|
||||
if: always()
|
||||
id: baseline-rev
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
BASE_REF: ${{ github.event.pull_request.base.ref }}
|
||||
MERGE_GROUP_BASE_REF: ${{ github.event.merge_group.base_ref }}
|
||||
GH_REF: ${{ github.ref }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ "${EVENT_NAME}" = "pull_request" ]; then
|
||||
# PRs compare against the base branch (e.g. main, release/2.5)
|
||||
echo "rev=${BASE_REF}" >> "$GITHUB_OUTPUT"
|
||||
elif [ "${EVENT_NAME}" = "merge_group" ]; then
|
||||
# Merge queue compares against the target branch (e.g. refs/heads/main -> main)
|
||||
echo "rev=${MERGE_GROUP_BASE_REF#refs/heads/}" >> "$GITHUB_OUTPUT"
|
||||
elif [[ "${GH_REF}" == refs/tags/* ]]; then
|
||||
# Tag builds compare against the tag name
|
||||
echo "rev=${REF_NAME}" >> "$GITHUB_OUTPUT"
|
||||
else
|
||||
# Push builds (main, release/*) compare against the branch name
|
||||
echo "rev=${REF_NAME}" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
- name: Generate screenshot diff report
|
||||
if: always()
|
||||
env:
|
||||
PROJECT: ${{ matrix.project }}
|
||||
PLAYWRIGHT_S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }}
|
||||
BASELINE_REV: ${{ steps.baseline-rev.outputs.rev }}
|
||||
run: |
|
||||
uv run --no-sync --with onyx-devtools ods screenshot-diff compare \
|
||||
--project "${PROJECT}" \
|
||||
--rev "${BASELINE_REV}"
|
||||
|
||||
- name: Upload visual diff report to S3
|
||||
if: always()
|
||||
env:
|
||||
PROJECT: ${{ matrix.project }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
SUMMARY_FILE="web/output/screenshot-diff/${PROJECT}/summary.json"
|
||||
if [ ! -f "${SUMMARY_FILE}" ]; then
|
||||
echo "No summary file found — skipping S3 upload."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
HAS_DIFF=$(jq -r '.has_differences' "${SUMMARY_FILE}")
|
||||
if [ "${HAS_DIFF}" != "true" ]; then
|
||||
echo "No visual differences for ${PROJECT} — skipping S3 upload."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
aws s3 sync "web/output/screenshot-diff/${PROJECT}/" \
|
||||
"s3://${PLAYWRIGHT_S3_BUCKET}/reports/pr-${PR_NUMBER}/${RUN_ID}/${PROJECT}/"
|
||||
|
||||
- name: Upload visual diff summary
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
if: always()
|
||||
with:
|
||||
name: screenshot-diff-summary-${{ matrix.project }}
|
||||
path: ./web/output/screenshot-diff/${{ matrix.project }}/summary.json
|
||||
if-no-files-found: ignore
|
||||
retention-days: 5
|
||||
|
||||
- name: Upload visual diff report artifact
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
if: always()
|
||||
with:
|
||||
name: screenshot-diff-report-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ./web/output/screenshot-diff/${{ matrix.project }}/
|
||||
if-no-files-found: ignore
|
||||
retention-days: 30
|
||||
|
||||
- name: Update S3 baselines
|
||||
if: >-
|
||||
success() && (
|
||||
github.ref == 'refs/heads/main' ||
|
||||
startsWith(github.ref, 'refs/heads/release/') ||
|
||||
startsWith(github.ref, 'refs/tags/v') ||
|
||||
(
|
||||
github.event_name == 'merge_group' && (
|
||||
github.event.merge_group.base_ref == 'refs/heads/main' ||
|
||||
startsWith(github.event.merge_group.base_ref, 'refs/heads/release/')
|
||||
)
|
||||
)
|
||||
)
|
||||
env:
|
||||
PROJECT: ${{ matrix.project }}
|
||||
PLAYWRIGHT_S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }}
|
||||
BASELINE_REV: ${{ steps.baseline-rev.outputs.rev }}
|
||||
run: |
|
||||
if [ -d "web/output/screenshots/" ] && [ "$(ls -A web/output/screenshots/)" ]; then
|
||||
uv run --no-sync --with onyx-devtools ods screenshot-diff upload-baselines \
|
||||
--project "${PROJECT}" \
|
||||
--rev "${BASELINE_REV}" \
|
||||
--delete
|
||||
else
|
||||
echo "No screenshots to upload for ${PROJECT} — skipping baseline update."
|
||||
fi
|
||||
|
||||
# save before stopping the containers so the logs can be captured
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
@@ -457,6 +586,95 @@ jobs:
|
||||
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
# Post a single combined visual regression comment after all matrix jobs finish
|
||||
visual-regression-comment:
|
||||
needs: [playwright-tests]
|
||||
if: always() && github.event_name == 'pull_request'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
permissions:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Download visual diff summaries
|
||||
uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # ratchet:actions/download-artifact@v4
|
||||
with:
|
||||
pattern: screenshot-diff-summary-*
|
||||
path: summaries/
|
||||
|
||||
- name: Post combined PR comment
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
REPO: ${{ github.repository }}
|
||||
S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }}
|
||||
run: |
|
||||
MARKER="<!-- visual-regression-report -->"
|
||||
|
||||
# Build the markdown table from all summary files
|
||||
TABLE_HEADER="| Project | Changed | Added | Removed | Unchanged | Report |"
|
||||
TABLE_DIVIDER="|---------|---------|-------|---------|-----------|--------|"
|
||||
TABLE_ROWS=""
|
||||
HAS_ANY_SUMMARY=false
|
||||
|
||||
for SUMMARY_DIR in summaries/screenshot-diff-summary-*/; do
|
||||
SUMMARY_FILE="${SUMMARY_DIR}summary.json"
|
||||
if [ ! -f "${SUMMARY_FILE}" ]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
HAS_ANY_SUMMARY=true
|
||||
PROJECT=$(jq -r '.project' "${SUMMARY_FILE}")
|
||||
CHANGED=$(jq -r '.changed' "${SUMMARY_FILE}")
|
||||
ADDED=$(jq -r '.added' "${SUMMARY_FILE}")
|
||||
REMOVED=$(jq -r '.removed' "${SUMMARY_FILE}")
|
||||
UNCHANGED=$(jq -r '.unchanged' "${SUMMARY_FILE}")
|
||||
TOTAL=$(jq -r '.total' "${SUMMARY_FILE}")
|
||||
HAS_DIFF=$(jq -r '.has_differences' "${SUMMARY_FILE}")
|
||||
|
||||
if [ "${TOTAL}" = "0" ]; then
|
||||
REPORT_LINK="_No screenshots_"
|
||||
elif [ "${HAS_DIFF}" = "true" ]; then
|
||||
REPORT_URL="https://${S3_BUCKET}.s3.us-east-2.amazonaws.com/reports/pr-${PR_NUMBER}/${RUN_ID}/${PROJECT}/index.html"
|
||||
REPORT_LINK="[View Report](${REPORT_URL})"
|
||||
else
|
||||
REPORT_LINK="✅ No changes"
|
||||
fi
|
||||
|
||||
TABLE_ROWS="${TABLE_ROWS}| \`${PROJECT}\` | ${CHANGED} | ${ADDED} | ${REMOVED} | ${UNCHANGED} | ${REPORT_LINK} |\n"
|
||||
done
|
||||
|
||||
if [ "${HAS_ANY_SUMMARY}" = "false" ]; then
|
||||
echo "No visual diff summaries found — skipping PR comment."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
BODY=$(printf '%s\n' \
|
||||
"${MARKER}" \
|
||||
"### 🖼️ Visual Regression Report" \
|
||||
"" \
|
||||
"${TABLE_HEADER}" \
|
||||
"${TABLE_DIVIDER}" \
|
||||
"$(printf '%b' "${TABLE_ROWS}")")
|
||||
|
||||
# Upsert: find existing comment with the marker, or create a new one
|
||||
EXISTING_COMMENT_ID=$(gh api \
|
||||
"repos/${REPO}/issues/${PR_NUMBER}/comments" \
|
||||
--jq ".[] | select(.body | startswith(\"${MARKER}\")) | .id" \
|
||||
2>/dev/null | head -1)
|
||||
|
||||
if [ -n "${EXISTING_COMMENT_ID}" ]; then
|
||||
gh api \
|
||||
--method PATCH \
|
||||
"repos/${REPO}/issues/comments/${EXISTING_COMMENT_ID}" \
|
||||
-f body="${BODY}"
|
||||
else
|
||||
gh api \
|
||||
--method POST \
|
||||
"repos/${REPO}/issues/${PR_NUMBER}/comments" \
|
||||
-f body="${BODY}"
|
||||
fi
|
||||
|
||||
playwright-required:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
@@ -467,48 +685,3 @@ jobs:
|
||||
- name: Check job status
|
||||
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
|
||||
run: exit 1
|
||||
|
||||
# NOTE: Chromatic UI diff testing is currently disabled.
|
||||
# We are using Playwright for local and CI testing without visual regression checks.
|
||||
# Chromatic may be reintroduced in the future for UI diff testing if needed.
|
||||
|
||||
# chromatic-tests:
|
||||
# name: Chromatic Tests
|
||||
|
||||
# needs: playwright-tests
|
||||
# runs-on:
|
||||
# [
|
||||
# runs-on,
|
||||
# runner=32cpu-linux-x64,
|
||||
# disk=large,
|
||||
# "run-id=${{ github.run_id }}",
|
||||
# ]
|
||||
# steps:
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
# with:
|
||||
# fetch-depth: 0
|
||||
|
||||
# - name: Setup node
|
||||
# uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
# with:
|
||||
# node-version: 22
|
||||
|
||||
# - name: Install node dependencies
|
||||
# working-directory: ./web
|
||||
# run: npm ci
|
||||
|
||||
# - name: Download Playwright test results
|
||||
# uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # ratchet:actions/download-artifact@v4
|
||||
# with:
|
||||
# name: test-results
|
||||
# path: ./web/test-results
|
||||
|
||||
# - name: Run Chromatic
|
||||
# uses: chromaui/action@latest
|
||||
# with:
|
||||
# playwright: true
|
||||
# projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
|
||||
# workingDir: ./web
|
||||
# env:
|
||||
# CHROMATIC_ARCHIVE_LOCATION: ./test-results
|
||||
|
||||
73
.github/workflows/preview.yml
vendored
Normal file
73
.github/workflows/preview.yml
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
name: Preview Deployment
|
||||
env:
|
||||
VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }}
|
||||
VERCEL_PROJECT_ID: ${{ secrets.VERCEL_PROJECT_ID }}
|
||||
VERCEL_CLI: vercel@50.14.1
|
||||
on:
|
||||
push:
|
||||
branches-ignore:
|
||||
- main
|
||||
paths:
|
||||
- "web/**"
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
jobs:
|
||||
Deploy-Preview:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Pull Vercel Environment Information
|
||||
run: npx --yes ${{ env.VERCEL_CLI }} pull --yes --environment=preview --token=${{ secrets.VERCEL_TOKEN }}
|
||||
|
||||
- name: Build Project Artifacts
|
||||
run: npx --yes ${{ env.VERCEL_CLI }} build --token=${{ secrets.VERCEL_TOKEN }}
|
||||
|
||||
- name: Deploy Project Artifacts to Vercel
|
||||
id: deploy
|
||||
run: |
|
||||
DEPLOYMENT_URL=$(npx --yes ${{ env.VERCEL_CLI }} deploy --prebuilt --token=${{ secrets.VERCEL_TOKEN }})
|
||||
echo "url=$DEPLOYMENT_URL" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Update PR comment with deployment URL
|
||||
if: always() && steps.deploy.outputs.url
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
DEPLOYMENT_URL: ${{ steps.deploy.outputs.url }}
|
||||
run: |
|
||||
# Find the PR for this branch
|
||||
PR_NUMBER=$(gh pr list --head "$GITHUB_REF_NAME" --json number --jq '.[0].number')
|
||||
if [ -z "$PR_NUMBER" ]; then
|
||||
echo "No open PR found for branch $GITHUB_REF_NAME, skipping comment."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
COMMENT_MARKER="<!-- preview-deployment -->"
|
||||
COMMENT_BODY="$COMMENT_MARKER
|
||||
**Preview Deployment**
|
||||
|
||||
| Status | Preview | Commit | Updated |
|
||||
| --- | --- | --- | --- |
|
||||
| ✅ | $DEPLOYMENT_URL | \`${GITHUB_SHA::7}\` | $(date -u '+%Y-%m-%d %H:%M:%S UTC') |"
|
||||
|
||||
# Find existing comment by marker
|
||||
EXISTING_COMMENT_ID=$(gh api "repos/$GITHUB_REPOSITORY/issues/$PR_NUMBER/comments" \
|
||||
--jq ".[] | select(.body | startswith(\"$COMMENT_MARKER\")) | .id" | head -1)
|
||||
|
||||
if [ -n "$EXISTING_COMMENT_ID" ]; then
|
||||
gh api "repos/$GITHUB_REPOSITORY/issues/comments/$EXISTING_COMMENT_ID" \
|
||||
--method PATCH --field body="$COMMENT_BODY"
|
||||
else
|
||||
gh pr comment "$PR_NUMBER" --body "$COMMENT_BODY"
|
||||
fi
|
||||
290
.github/workflows/sandbox-deployment.yml
vendored
Normal file
290
.github/workflows/sandbox-deployment.yml
vendored
Normal file
@@ -0,0 +1,290 @@
|
||||
name: Build and Push Sandbox Image on Tag
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "experimental-cc4a.*"
|
||||
|
||||
# Restrictive defaults; jobs declare what they need.
|
||||
permissions: {}
|
||||
|
||||
jobs:
|
||||
check-sandbox-changes:
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
permissions:
|
||||
contents: read
|
||||
outputs:
|
||||
sandbox-changed: ${{ steps.check.outputs.sandbox-changed }}
|
||||
new-version: ${{ steps.version.outputs.new-version }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Check for sandbox-relevant file changes
|
||||
id: check
|
||||
run: |
|
||||
# Get the previous tag to diff against
|
||||
CURRENT_TAG="${GITHUB_REF_NAME}"
|
||||
PREVIOUS_TAG=$(git tag --sort=-creatordate | grep '^experimental-cc4a\.' | grep -v "^${CURRENT_TAG}$" | head -n 1)
|
||||
|
||||
if [ -z "$PREVIOUS_TAG" ]; then
|
||||
echo "No previous experimental-cc4a tag found, building unconditionally"
|
||||
echo "sandbox-changed=true" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Comparing ${PREVIOUS_TAG}..${CURRENT_TAG}"
|
||||
|
||||
# Check if any sandbox-relevant files changed
|
||||
SANDBOX_PATHS=(
|
||||
"backend/onyx/server/features/build/sandbox/"
|
||||
)
|
||||
|
||||
CHANGED=false
|
||||
for path in "${SANDBOX_PATHS[@]}"; do
|
||||
if git diff --name-only "${PREVIOUS_TAG}..${CURRENT_TAG}" -- "$path" | grep -q .; then
|
||||
echo "Changes detected in: $path"
|
||||
CHANGED=true
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
echo "sandbox-changed=$CHANGED" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Determine new sandbox version
|
||||
id: version
|
||||
if: steps.check.outputs.sandbox-changed == 'true'
|
||||
run: |
|
||||
# Query Docker Hub for the latest versioned tag
|
||||
LATEST_TAG=$(curl -s "https://hub.docker.com/v2/repositories/onyxdotapp/sandbox/tags?page_size=100" \
|
||||
| jq -r '.results[].name' \
|
||||
| grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' \
|
||||
| sort -V \
|
||||
| tail -n 1)
|
||||
|
||||
if [ -z "$LATEST_TAG" ]; then
|
||||
echo "No existing version tags found on Docker Hub, starting at 0.1.1"
|
||||
NEW_VERSION="0.1.1"
|
||||
else
|
||||
CURRENT_VERSION="${LATEST_TAG#v}"
|
||||
echo "Latest version on Docker Hub: $CURRENT_VERSION"
|
||||
|
||||
# Increment patch version
|
||||
MAJOR=$(echo "$CURRENT_VERSION" | cut -d. -f1)
|
||||
MINOR=$(echo "$CURRENT_VERSION" | cut -d. -f2)
|
||||
PATCH=$(echo "$CURRENT_VERSION" | cut -d. -f3)
|
||||
NEW_PATCH=$((PATCH + 1))
|
||||
NEW_VERSION="${MAJOR}.${MINOR}.${NEW_PATCH}"
|
||||
fi
|
||||
|
||||
echo "New version: $NEW_VERSION"
|
||||
echo "new-version=$NEW_VERSION" >> "$GITHUB_OUTPUT"
|
||||
|
||||
build-sandbox-amd64:
|
||||
needs: check-sandbox-changes
|
||||
if: needs.check-sandbox-changes.outputs.sandbox-changed == 'true'
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-sandbox-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/sandbox
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend/onyx/server/features/build/sandbox/kubernetes/docker
|
||||
file: ./backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile
|
||||
platforms: linux/amd64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
build-sandbox-arm64:
|
||||
needs: check-sandbox-changes
|
||||
if: needs.check-sandbox-changes.outputs.sandbox-changed == 'true'
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- run-id=${{ github.run_id }}-sandbox-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/sandbox
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend/onyx/server/features/build/sandbox/kubernetes/docker
|
||||
file: ./backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile
|
||||
platforms: linux/arm64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: |
|
||||
type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
merge-sandbox:
|
||||
needs:
|
||||
- check-sandbox-changes
|
||||
- build-sandbox-amd64
|
||||
- build-sandbox-arm64
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-merge-sandbox
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 30
|
||||
environment: release
|
||||
permissions:
|
||||
id-token: write
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/sandbox
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=v${{ needs.check-sandbox-changes.outputs.new-version }}
|
||||
type=raw,value=latest
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
IMAGE_REPO: ${{ env.REGISTRY_IMAGE }}
|
||||
AMD64_DIGEST: ${{ needs.build-sandbox-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.build-sandbox-arm64.outputs.digest }}
|
||||
META_TAGS: ${{ steps.meta.outputs.tags }}
|
||||
run: |
|
||||
IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
|
||||
docker buildx imagetools create \
|
||||
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
|
||||
$IMAGES
|
||||
1 .gitignore vendored
@@ -6,6 +6,7 @@
|
||||
!/.vscode/tasks.template.jsonc
|
||||
.zed
|
||||
.cursor
|
||||
!/.cursor/mcp.json
|
||||
|
||||
# macos
|
||||
.DS_store
|
||||
|
||||
6 .vscode/launch.json vendored
@@ -246,7 +246,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,index_attempt_cleanup"
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,index_attempt_cleanup,opensearch_migration"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
@@ -275,7 +275,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=background@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,user_files_indexing,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete"
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
@@ -419,7 +419,7 @@
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docfetching@%n",
|
||||
"-Q",
|
||||
"connector_doc_fetching,user_files_indexing"
|
||||
"connector_doc_fetching"
|
||||
],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
|
||||
@@ -144,6 +144,10 @@ function.
If you make any updates to a celery worker and you want to test these changes, you will need
to ask me to restart the celery worker. There is no auto-restart on code-change mechanism.

**Task Time Limits**:
Since all tasks are executed in thread pools, the time limit features of Celery are silently
disabled and won't work. Timeout logic must be implemented within the task itself.
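As a concrete illustration, here is a minimal sketch of what in-task timeout enforcement looks like; the function name, item type, and limit below are illustrative placeholders, not code from this repository:

```python
import time

TASK_TIMEOUT_SECONDS = 3600  # illustrative limit, not a repo constant


def process_items(items: list[str], timeout_seconds: int = TASK_TIMEOUT_SECONDS) -> int:
    """Hypothetical task body that enforces its own wall-clock timeout."""
    start = time.monotonic()
    processed = 0
    for item in items:
        # Celery's soft_time_limit is silently ignored in thread pools,
        # so the task checks elapsed time itself and raises explicitly.
        if time.monotonic() - start > timeout_seconds:
            raise RuntimeError(
                f"task timed out after {timeout_seconds}s, processed={processed}"
            )
        processed += 1  # stand-in for the real per-item work
    return processed
```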
|
||||
### Code Quality
|
||||
|
||||
```bash
|
||||
|
||||
5 LICENSE
@@ -2,7 +2,10 @@ Copyright (c) 2023-present DanswerAI, Inc.

Portions of this software are licensed as follows:

- All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
- All content that resides under "ee" directories of this repository is licensed under the Onyx Enterprise License. Each ee directory contains an identical copy of this license at its root:
  - backend/ee/LICENSE
  - web/src/app/ee/LICENSE
  - web/src/ee/LICENSE
- All third party components incorporated into the Onyx Software are licensed under the original license provided by the owner of the applicable component.
- Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.
@@ -134,6 +134,7 @@ COPY --chown=onyx:onyx ./alembic_tenants /app/alembic_tenants
|
||||
COPY --chown=onyx:onyx ./alembic.ini /app/alembic.ini
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
COPY --chown=onyx:onyx ./static /app/static
|
||||
COPY --chown=onyx:onyx ./keys /app/keys
|
||||
|
||||
# Escape hatch scripts
|
||||
COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
|
||||
|
||||
@@ -474,7 +474,7 @@ def run_migrations_online() -> None:
|
||||
|
||||
if connectable is not None:
|
||||
# pytest-alembic is providing an engine - use it directly
|
||||
logger.info("run_migrations_online starting (pytest-alembic mode).")
|
||||
logger.debug("run_migrations_online starting (pytest-alembic mode).")
|
||||
|
||||
# For pytest-alembic, we use the default schema (public)
|
||||
schema_name = context.config.attributes.get(
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
"""add default_app_mode to user
|
||||
|
||||
Revision ID: 114a638452db
|
||||
Revises: feead2911109
|
||||
Create Date: 2026-02-09 18:57:08.274640
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "114a638452db"
|
||||
down_revision = "feead2911109"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"default_app_mode",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
server_default="CHAT",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "default_app_mode")
|
||||
@@ -11,7 +11,6 @@ import sqlalchemy as sa
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
from httpx import HTTPStatusError
|
||||
import httpx
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.db.search_settings import SearchSettings
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
@@ -519,15 +518,11 @@ def delete_document_from_db(current_doc_id: str, index_name: str) -> None:
|
||||
def upgrade() -> None:
|
||||
if SKIP_CANON_DRIVE_IDS:
|
||||
return
|
||||
current_search_settings, future_search_settings = active_search_settings()
|
||||
document_index = get_default_document_index(
|
||||
current_search_settings,
|
||||
future_search_settings,
|
||||
)
|
||||
current_search_settings, _ = active_search_settings()
|
||||
|
||||
# Get the index name
|
||||
if hasattr(document_index, "index_name"):
|
||||
index_name = document_index.index_name
|
||||
if hasattr(current_search_settings, "index_name"):
|
||||
index_name = current_search_settings.index_name
|
||||
else:
|
||||
# Default index name if we can't get it from the document_index
|
||||
index_name = "danswer_index"
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
"""Migrate to contextual rag model
|
||||
|
||||
Revision ID: 19c0ccb01687
|
||||
Revises: 9c54986124c6
|
||||
Create Date: 2026-02-12 11:21:41.798037
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "19c0ccb01687"
|
||||
down_revision = "9c54986124c6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Widen the column to fit 'CONTEXTUAL_RAG' (14 chars); was varchar(10)
# when the table was created with only CHAT/VISION values.
|
||||
op.alter_column(
|
||||
"llm_model_flow",
|
||||
"llm_model_flow_type",
|
||||
type_=sa.String(length=20),
|
||||
existing_type=sa.String(length=10),
|
||||
existing_nullable=False,
|
||||
)
|
||||
|
||||
# For every search_settings row that has contextual rag configured,
|
||||
# create an llm_model_flow entry. is_default is TRUE if the row
|
||||
# belongs to the PRESENT search settings, FALSE otherwise.
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO llm_model_flow (llm_model_flow_type, model_configuration_id, is_default)
|
||||
SELECT DISTINCT
|
||||
'CONTEXTUAL_RAG',
|
||||
mc.id,
|
||||
(ss.status = 'PRESENT')
|
||||
FROM search_settings ss
|
||||
JOIN llm_provider lp
|
||||
ON lp.name = ss.contextual_rag_llm_provider
|
||||
JOIN model_configuration mc
|
||||
ON mc.llm_provider_id = lp.id
|
||||
AND mc.name = ss.contextual_rag_llm_name
|
||||
WHERE ss.enable_contextual_rag = TRUE
|
||||
AND ss.contextual_rag_llm_name IS NOT NULL
|
||||
AND ss.contextual_rag_llm_provider IS NOT NULL
|
||||
ON CONFLICT (llm_model_flow_type, model_configuration_id)
|
||||
DO UPDATE SET is_default = EXCLUDED.is_default
|
||||
WHERE EXCLUDED.is_default = TRUE
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM llm_model_flow
|
||||
WHERE llm_model_flow_type = 'CONTEXTUAL_RAG'
|
||||
"""
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"llm_model_flow",
|
||||
"llm_model_flow_type",
|
||||
type_=sa.String(length=10),
|
||||
existing_type=sa.String(length=20),
|
||||
existing_nullable=False,
|
||||
)
|
||||
@@ -16,7 +16,6 @@ from typing import Generator
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.db.search_settings import SearchSettings
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
@@ -126,14 +125,11 @@ def remove_old_tags() -> None:
|
||||
the document got reindexed, the old tag would not be removed.
|
||||
This function removes those old tags by comparing them against the tags in Vespa.
|
||||
"""
|
||||
current_search_settings, future_search_settings = active_search_settings()
|
||||
document_index = get_default_document_index(
|
||||
current_search_settings, future_search_settings
|
||||
)
|
||||
current_search_settings, _ = active_search_settings()
|
||||
|
||||
# Get the index name
|
||||
if hasattr(document_index, "index_name"):
|
||||
index_name = document_index.index_name
|
||||
if hasattr(current_search_settings, "index_name"):
|
||||
index_name = current_search_settings.index_name
|
||||
else:
|
||||
# Default index name if we can't get it from the document_index
|
||||
index_name = "danswer_index"
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
"""add chunk error and vespa count columns to opensearch tenant migration
|
||||
|
||||
Revision ID: 93c15d6a6fbb
|
||||
Revises: d3fd499c829c
|
||||
Create Date: 2026-02-11 23:07:34.576725
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "93c15d6a6fbb"
|
||||
down_revision = "d3fd499c829c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column(
|
||||
"total_chunks_errored",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default="0",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column(
|
||||
"total_chunks_in_vespa",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default="0",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("opensearch_tenant_migration_record", "total_chunks_in_vespa")
|
||||
op.drop_column("opensearch_tenant_migration_record", "total_chunks_errored")
|
||||
124 backend/alembic/versions/9c54986124c6_add_scim_tables.py Normal file
@@ -0,0 +1,124 @@
|
||||
"""add_scim_tables
|
||||
|
||||
Revision ID: 9c54986124c6
|
||||
Revises: b51c6844d1df
|
||||
Create Date: 2026-02-12 20:29:47.448614
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9c54986124c6"
|
||||
down_revision = "b51c6844d1df"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"scim_token",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("hashed_token", sa.String(length=64), nullable=False),
|
||||
sa.Column("token_display", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"created_by_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_active",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("true"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(["created_by_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("hashed_token"),
|
||||
)
|
||||
op.create_table(
|
||||
"scim_group_mapping",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("external_id", sa.String(), nullable=False),
|
||||
sa.Column("user_group_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_group_id"], ["user_group.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("user_group_id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_scim_group_mapping_external_id"),
|
||||
"scim_group_mapping",
|
||||
["external_id"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_table(
|
||||
"scim_user_mapping",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("external_id", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
onupdate=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("user_id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_scim_user_mapping_external_id"),
|
||||
"scim_user_mapping",
|
||||
["external_id"],
|
||||
unique=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
op.f("ix_scim_user_mapping_external_id"),
|
||||
table_name="scim_user_mapping",
|
||||
)
|
||||
op.drop_table("scim_user_mapping")
|
||||
op.drop_index(
|
||||
op.f("ix_scim_group_mapping_external_id"),
|
||||
table_name="scim_group_mapping",
|
||||
)
|
||||
op.drop_table("scim_group_mapping")
|
||||
op.drop_table("scim_token")
|
||||
81 backend/alembic/versions/b51c6844d1df_seed_memory_tool.py Normal file
@@ -0,0 +1,81 @@
|
||||
"""seed_memory_tool and add enable_memory_tool to user
|
||||
|
||||
Revision ID: b51c6844d1df
|
||||
Revises: 93c15d6a6fbb
|
||||
Create Date: 2026-02-11 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b51c6844d1df"
|
||||
down_revision = "93c15d6a6fbb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
MEMORY_TOOL = {
|
||||
"name": "MemoryTool",
|
||||
"display_name": "Add Memory",
|
||||
"description": "Save memories about the user for future conversations.",
|
||||
"in_code_tool_id": "MemoryTool",
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
existing = conn.execute(
|
||||
sa.text(
|
||||
"SELECT in_code_tool_id FROM tool WHERE in_code_tool_id = :in_code_tool_id"
|
||||
),
|
||||
{"in_code_tool_id": MEMORY_TOOL["in_code_tool_id"]},
|
||||
).fetchone()
|
||||
|
||||
if existing:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
MEMORY_TOOL,
|
||||
)
|
||||
else:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
|
||||
"""
|
||||
),
|
||||
MEMORY_TOOL,
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"enable_memory_tool",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.true(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "enable_memory_tool")
|
||||
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text("DELETE FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
|
||||
{"in_code_tool_id": MEMORY_TOOL["in_code_tool_id"]},
|
||||
)
|
||||
102 backend/alembic/versions/d3fd499c829c_add_file_reader_tool.py Normal file
@@ -0,0 +1,102 @@
|
||||
"""add_file_reader_tool
|
||||
|
||||
Revision ID: d3fd499c829c
|
||||
Revises: 114a638452db
|
||||
Create Date: 2026-02-07 19:28:22.452337
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d3fd499c829c"
|
||||
down_revision = "114a638452db"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
FILE_READER_TOOL = {
|
||||
"name": "read_file",
|
||||
"display_name": "File Reader",
|
||||
"description": (
|
||||
"Read sections of user-uploaded files by character offset. "
|
||||
"Useful for inspecting large files that cannot fit entirely in context."
|
||||
),
|
||||
"in_code_tool_id": "FileReaderTool",
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Check if tool already exists
|
||||
existing = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
|
||||
{"in_code_tool_id": FILE_READER_TOOL["in_code_tool_id"]},
|
||||
).fetchone()
|
||||
|
||||
if existing:
|
||||
# Update existing tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
FILE_READER_TOOL,
|
||||
)
|
||||
tool_id = existing[0]
|
||||
else:
|
||||
# Insert new tool
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
|
||||
RETURNING id
|
||||
"""
|
||||
),
|
||||
FILE_READER_TOOL,
|
||||
)
|
||||
tool_id = result.scalar_one()
|
||||
|
||||
# Attach to the default persona (id=0) if not already attached
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": tool_id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
in_code_tool_id = FILE_READER_TOOL["in_code_tool_id"]
|
||||
|
||||
# Remove persona associations first (FK constraint)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM persona__tool
|
||||
WHERE tool_id IN (
|
||||
SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id
|
||||
)
|
||||
"""
|
||||
),
|
||||
{"in_code_tool_id": in_code_tool_id},
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
sa.text("DELETE FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
|
||||
{"in_code_tool_id": in_code_tool_id},
|
||||
)
|
||||
@@ -0,0 +1,69 @@
|
||||
"""add_opensearch_tenant_migration_columns
|
||||
|
||||
Revision ID: feead2911109
|
||||
Revises: d56ffa94ca32
|
||||
Create Date: 2026-02-10 17:46:34.029937
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "feead2911109"
|
||||
down_revision = "175ea04c7087"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column("vespa_visit_continuation_token", sa.Text(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column(
|
||||
"total_chunks_migrated",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default="0",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column(
|
||||
"migration_completed_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column(
|
||||
"enable_opensearch_retrieval",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("opensearch_tenant_migration_record", "enable_opensearch_retrieval")
|
||||
op.drop_column("opensearch_tenant_migration_record", "migration_completed_at")
|
||||
op.drop_column("opensearch_tenant_migration_record", "created_at")
|
||||
op.drop_column("opensearch_tenant_migration_record", "total_chunks_migrated")
|
||||
op.drop_column(
|
||||
"opensearch_tenant_migration_record", "vespa_visit_continuation_token"
|
||||
)
|
||||
@@ -1,20 +1,20 @@
|
||||
The DanswerAI Enterprise license (the “Enterprise License”)
|
||||
The Onyx Enterprise License (the "Enterprise License")
|
||||
Copyright (c) 2023-present DanswerAI, Inc.
|
||||
|
||||
With regard to the Onyx Software:
|
||||
|
||||
This software and associated documentation files (the "Software") may only be
|
||||
used in production, if you (and any entity that you represent) have agreed to,
|
||||
and are in compliance with, the DanswerAI Subscription Terms of Service, available
|
||||
at https://onyx.app/terms (the “Enterprise Terms”), or other
|
||||
and are in compliance with, the Onyx Subscription Terms of Service, available
|
||||
at https://www.onyx.app/legal/self-host (the "Enterprise Terms"), or other
|
||||
agreement governing the use of the Software, as agreed by you and DanswerAI,
|
||||
and otherwise have a valid Onyx Enterprise license for the
|
||||
and otherwise have a valid Onyx Enterprise License for the
|
||||
correct number of user seats. Subject to the foregoing sentence, you are free to
|
||||
modify this Software and publish patches to the Software. You agree that DanswerAI
|
||||
and/or its licensors (as applicable) retain all right, title and interest in and
|
||||
to all such modifications and/or patches, and all such modifications and/or
|
||||
patches may only be used, copied, modified, displayed, distributed, or otherwise
|
||||
exploited with a valid Onyx Enterprise license for the correct
|
||||
exploited with a valid Onyx Enterprise License for the correct
|
||||
number of user seats. Notwithstanding the foregoing, you may copy and modify
|
||||
the Software for development and testing purposes, without requiring a
|
||||
subscription. You agree that DanswerAI and/or its licensors (as applicable) retain
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.background import celery_app
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.heavy import celery_app
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cleanup",
|
||||
"ee.onyx.background.celery.tasks.query_history",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.light import celery_app
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
]
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.monitoring import celery_app
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from onyx.background.celery.apps import app_base
|
||||
from onyx.background.celery.apps.primary import celery_app
|
||||
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cloud",
|
||||
"ee.onyx.background.celery.tasks.ttl_management",
|
||||
"ee.onyx.background.celery.tasks.usage_reporting",
|
||||
]
|
||||
app_base.filter_task_modules(
|
||||
[
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"ee.onyx.background.celery.tasks.external_group_syncing",
|
||||
"ee.onyx.background.celery.tasks.cloud",
|
||||
"ee.onyx.background.celery.tasks.ttl_management",
|
||||
"ee.onyx.background.celery.tasks.usage_reporting",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -536,7 +536,9 @@ def connector_permission_sync_generator_task(
|
||||
)
|
||||
redis_connector.permissions.set_fence(new_payload)
|
||||
|
||||
callback = PermissionSyncCallback(redis_connector, lock, r)
|
||||
callback = PermissionSyncCallback(
|
||||
redis_connector, lock, r, timeout_seconds=JOB_TIMEOUT
|
||||
)
|
||||
|
||||
# pass in the capability to fetch all existing docs for the cc_pair
|
||||
# this can be used to determine documents that are "missing" and thus
|
||||
@@ -576,6 +578,13 @@ def connector_permission_sync_generator_task(
|
||||
tasks_generated = 0
|
||||
docs_with_errors = 0
|
||||
for doc_external_access in document_external_accesses:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
f"Permission sync task timed out or stop signal detected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
result = redis_connector.permissions.update_db(
|
||||
lock=lock,
|
||||
new_permissions=[doc_external_access],
|
||||
@@ -932,6 +941,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
|
||||
redis_connector: RedisConnector,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
timeout_seconds: int | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.redis_connector: RedisConnector = redis_connector
|
||||
@@ -944,11 +954,26 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
|
||||
self.last_tag: str = "PermissionSyncCallback.__init__"
|
||||
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
self.start_monotonic = time.monotonic()
|
||||
self.timeout_seconds = timeout_seconds
|
||||
|
||||
def should_stop(self) -> bool:
|
||||
if self.redis_connector.stop.fenced:
|
||||
return True
|
||||
|
||||
# Check if the task has exceeded its timeout
|
||||
# NOTE: Celery's soft_time_limit does not work with thread pools,
|
||||
# so we must enforce timeouts internally.
|
||||
if self.timeout_seconds is not None:
|
||||
elapsed = time.monotonic() - self.start_monotonic
|
||||
if elapsed > self.timeout_seconds:
|
||||
logger.warning(
|
||||
f"PermissionSyncCallback - task timeout exceeded: "
|
||||
f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s "
|
||||
f"cc_pair={self.redis_connector.cc_pair_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None: # noqa: ARG002
|
||||
|
||||
@@ -466,6 +466,7 @@ def connector_external_group_sync_generator_task(
|
||||
def _perform_external_group_sync(
|
||||
cc_pair_id: int,
|
||||
tenant_id: str,
|
||||
timeout_seconds: int = JOB_TIMEOUT,
|
||||
) -> None:
|
||||
# Create attempt record at the start
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -518,9 +519,23 @@ def _perform_external_group_sync(
|
||||
seen_users: set[str] = set() # Track unique users across all groups
|
||||
total_groups_processed = 0
|
||||
total_group_memberships_synced = 0
|
||||
start_time = time.monotonic()
|
||||
try:
|
||||
external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
|
||||
for external_user_group in external_user_group_generator:
|
||||
# Check if the task has exceeded its timeout
|
||||
# NOTE: Celery's soft_time_limit does not work with thread pools,
|
||||
# so we must enforce timeouts internally.
|
||||
elapsed = time.monotonic() - start_time
|
||||
if elapsed > timeout_seconds:
|
||||
raise RuntimeError(
|
||||
f"External group sync task timed out: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"elapsed={elapsed:.0f}s "
|
||||
f"timeout={timeout_seconds}s "
|
||||
f"groups_processed={total_groups_processed}"
|
||||
)
|
||||
|
||||
external_user_group_batch.append(external_user_group)
|
||||
|
||||
# Track progress
|
||||
|
||||
@@ -65,21 +65,7 @@ def github_doc_sync(
|
||||
# Get all repositories from GitHub API
|
||||
logger.info("Fetching all repositories from GitHub API")
|
||||
try:
|
||||
repos = []
|
||||
if github_connector.repositories:
|
||||
if "," in github_connector.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = github_connector.get_github_repos(
|
||||
github_connector.github_client
|
||||
)
|
||||
else:
|
||||
# Single repository
|
||||
repos = [
|
||||
github_connector.get_github_repo(github_connector.github_client)
|
||||
]
|
||||
else:
|
||||
# All repositories
|
||||
repos = github_connector.get_all_repos(github_connector.github_client)
|
||||
repos = github_connector.fetch_configured_repos()
|
||||
|
||||
logger.info(f"Found {len(repos)} repositories to check")
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
get_sharepoint_external_groups,
|
||||
)
|
||||
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -46,16 +43,11 @@ def sharepoint_group_sync(
|
||||
|
||||
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
|
||||
|
||||
msal_app = connector.msal_app
|
||||
sp_tenant_domain = connector.sp_tenant_domain
|
||||
# Process each site
|
||||
for site_descriptor in site_descriptors:
|
||||
logger.debug(f"Processing site: {site_descriptor.url}")
|
||||
|
||||
# Create client context for the site using connector's MSAL app
|
||||
ctx = ClientContext(site_descriptor.url).with_access_token(
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
|
||||
)
|
||||
ctx = connector._create_rest_client_context(site_descriptor.url)
|
||||
|
||||
# Get external groups for this site
|
||||
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
|
||||
|
||||
@@ -77,7 +77,7 @@ def stream_search_query(
|
||||
# Get document index
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
# This flow is for search so we do not get all indices.
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
document_index = get_default_document_index(search_settings, None, db_session)
|
||||
|
||||
# Determine queries to execute
|
||||
original_query = request.search_query
|
||||
|
||||
@@ -109,7 +109,9 @@ async def _make_billing_request(
|
||||
headers = _get_headers(license_data)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT) as client:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=_REQUEST_TIMEOUT, follow_redirects=True
|
||||
) as client:
|
||||
if method == "GET":
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
else:
|
||||
|
||||
@@ -27,6 +27,8 @@ class SearchFlowClassificationResponse(BaseModel):
|
||||
is_search_flow: bool
|
||||
|
||||
|
||||
# NOTE: This model is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
|
||||
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
|
||||
class SendSearchQueryRequest(BaseModel):
|
||||
search_query: str
|
||||
filters: BaseFilters | None = None
|
||||
|
||||
@@ -26,6 +26,7 @@ from onyx.db.models import User
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.server.utils_vector_db import require_vector_db
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -66,7 +67,13 @@ def search_flow_classification(
|
||||
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
|
||||
|
||||
|
||||
@router.post("/send-search-message", response_model=None)
|
||||
# NOTE: This endpoint is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
|
||||
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
|
||||
@router.post(
|
||||
"/send-search-message",
|
||||
response_model=None,
|
||||
dependencies=[Depends(require_vector_db)],
|
||||
)
|
||||
def handle_send_search_message(
|
||||
request: SendSearchQueryRequest,
|
||||
user: User = Depends(current_user),
|
||||
|
||||
0 backend/ee/onyx/server/scim/__init__.py Normal file
96 backend/ee/onyx/server/scim/filtering.py Normal file
@@ -0,0 +1,96 @@
|
||||
"""SCIM filter expression parser (RFC 7644 §3.4.2.2).
|
||||
|
||||
Identity providers (Okta, Azure AD, OneLogin, etc.) use filters to look up
|
||||
resources before deciding whether to create or update them. For example, when
|
||||
an admin assigns a user to the Onyx app, the IdP first checks whether that
|
||||
user already exists::
|
||||
|
||||
GET /scim/v2/Users?filter=userName eq "john@example.com"
|
||||
|
||||
If zero results come back the IdP creates the user (``POST``); if a match is
|
||||
found it links to the existing record and uses ``PUT``/``PATCH`` going forward.
|
||||
The same pattern applies to groups (``displayName eq "Engineering"``).
|
||||
|
||||
This module parses the subset of the SCIM filter grammar that identity
|
||||
providers actually send in practice:
|
||||
|
||||
attribute SP operator SP value
|
||||
|
||||
Supported operators: ``eq``, ``co`` (contains), ``sw`` (starts with).
|
||||
Compound filters (``and`` / ``or``) are not supported; if an IdP sends one
|
||||
the parser returns ``None`` and the caller falls back to an unfiltered list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ScimFilterOperator(str, Enum):
|
||||
"""Supported SCIM filter operators."""
|
||||
|
||||
EQUAL = "eq"
|
||||
CONTAINS = "co"
|
||||
STARTS_WITH = "sw"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ScimFilter:
|
||||
"""Parsed SCIM filter expression."""
|
||||
|
||||
attribute: str
|
||||
operator: ScimFilterOperator
|
||||
value: str
|
||||
|
||||
|
||||
# Matches: attribute operator "value" (with or without quotes around value)
|
||||
# Groups: (attribute) (operator) ("quoted value" | unquoted_value)
|
||||
_FILTER_RE = re.compile(
|
||||
r"^(\S+)\s+(eq|co|sw)\s+" # attribute + operator
|
||||
r'(?:"([^"]*)"' # quoted value
|
||||
r"|'([^']*)')" # or single-quoted value
|
||||
r"$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def parse_scim_filter(filter_string: str | None) -> ScimFilter | None:
|
||||
"""Parse a simple SCIM filter expression.
|
||||
|
||||
Args:
|
||||
filter_string: Raw filter query parameter value, e.g.
|
||||
``'userName eq "john@example.com"'``
|
||||
|
||||
Returns:
|
||||
A ``ScimFilter`` if the expression is valid and uses a supported
|
||||
operator, or ``None`` if the input is empty / missing.
|
||||
|
||||
Raises:
|
||||
ValueError: If the filter string is present but malformed or uses
|
||||
an unsupported operator.
|
||||
"""
|
||||
if not filter_string or not filter_string.strip():
|
||||
return None
|
||||
|
||||
match = _FILTER_RE.match(filter_string.strip())
|
||||
if not match:
|
||||
raise ValueError(f"Unsupported or malformed SCIM filter: {filter_string}")
|
||||
|
||||
return _build_filter(match, filter_string)
|
||||
|
||||
|
||||
def _build_filter(match: re.Match[str], raw: str) -> ScimFilter:
|
||||
"""Extract fields from a regex match and construct a ScimFilter."""
|
||||
attribute = match.group(1)
|
||||
op_str = match.group(2).lower()
|
||||
# Value is in group 3 (double-quoted) or group 4 (single-quoted)
|
||||
value = match.group(3) if match.group(3) is not None else match.group(4)
|
||||
|
||||
if value is None:
|
||||
raise ValueError(f"Unsupported or malformed SCIM filter: {raw}")
|
||||
|
||||
operator = ScimFilterOperator(op_str)
|
||||
|
||||
return ScimFilter(attribute=attribute, operator=operator, value=value)
|
||||
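To make the supported grammar concrete, a minimal usage sketch of the parser above (assuming the import path shown in the diff; the email value is illustrative):

```python
from ee.onyx.server.scim.filtering import ScimFilterOperator, parse_scim_filter

# Typical IdP lookup before deciding whether to create or link a user.
parsed = parse_scim_filter('userName eq "john@example.com"')

assert parsed is not None
assert parsed.attribute == "userName"
assert parsed.operator is ScimFilterOperator.EQUAL
assert parsed.value == "john@example.com"

# An empty or missing filter means "list everything".
assert parse_scim_filter(None) is None
```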
255 backend/ee/onyx/server/scim/models.py Normal file
@@ -0,0 +1,255 @@
|
||||
"""Pydantic schemas for SCIM 2.0 provisioning (RFC 7643 / RFC 7644).
|
||||
|
||||
SCIM protocol schemas follow the wire format defined in:
|
||||
- Core Schema: https://datatracker.ietf.org/doc/html/rfc7643
|
||||
- Protocol: https://datatracker.ietf.org/doc/html/rfc7644
|
||||
|
||||
Admin API schemas are internal to Onyx and used for SCIM token management.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SCIM Schema URIs (RFC 7643 §8)
|
||||
# Every SCIM JSON payload includes a "schemas" array identifying its type.
|
||||
# IdPs like Okta/Azure AD use these URIs to determine how to parse responses.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
SCIM_LIST_RESPONSE_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:ListResponse"
|
||||
SCIM_PATCH_OP_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
|
||||
SCIM_ERROR_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:Error"
|
||||
SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
|
||||
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"
|
||||
)
|
||||
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SCIM Protocol Schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScimName(BaseModel):
|
||||
"""User name components (RFC 7643 §4.1.1)."""
|
||||
|
||||
givenName: str | None = None
|
||||
familyName: str | None = None
|
||||
formatted: str | None = None
|
||||
|
||||
|
||||
class ScimEmail(BaseModel):
|
||||
"""Email sub-attribute (RFC 7643 §4.1.2)."""
|
||||
|
||||
value: str
|
||||
type: str | None = None
|
||||
primary: bool = False
|
||||
|
||||
|
||||
class ScimMeta(BaseModel):
|
||||
"""Resource metadata (RFC 7643 §3.1)."""
|
||||
|
||||
resourceType: str | None = None
|
||||
created: datetime | None = None
|
||||
lastModified: datetime | None = None
|
||||
location: str | None = None
|
||||
|
||||
|
||||
class ScimUserResource(BaseModel):
|
||||
"""SCIM User resource representation (RFC 7643 §4.1).
|
||||
|
||||
This is the JSON shape that IdPs send when creating/updating a user via
|
||||
SCIM, and the shape we return in GET responses. Field names use camelCase
|
||||
to match the SCIM wire format (not Python convention).
|
||||
"""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
|
||||
id: str | None = None # Onyx's internal user ID, set on responses
|
||||
externalId: str | None = None # IdP's identifier for this user
|
||||
userName: str # Typically the user's email address
|
||||
name: ScimName | None = None
|
||||
emails: list[ScimEmail] = Field(default_factory=list)
|
||||
active: bool = True
|
||||
meta: ScimMeta | None = None
|
||||
|
||||
|
||||
class ScimGroupMember(BaseModel):
|
||||
"""Group member reference (RFC 7643 §4.2).
|
||||
|
||||
Represents a user within a SCIM group. The IdP sends these when adding
|
||||
or removing users from groups. ``value`` is the Onyx user ID.
|
||||
"""
|
||||
|
||||
value: str # User ID of the group member
|
||||
display: str | None = None
|
||||
|
||||
|
||||
class ScimGroupResource(BaseModel):
|
||||
"""SCIM Group resource representation (RFC 7643 §4.2)."""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_GROUP_SCHEMA])
|
||||
id: str | None = None
|
||||
externalId: str | None = None
|
||||
displayName: str
|
||||
members: list[ScimGroupMember] = Field(default_factory=list)
|
||||
meta: ScimMeta | None = None
|
||||
|
||||
|
||||
class ScimListResponse(BaseModel):
|
||||
"""Paginated list response (RFC 7644 §3.4.2)."""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_LIST_RESPONSE_SCHEMA])
|
||||
totalResults: int
|
||||
startIndex: int = 1
|
||||
itemsPerPage: int = 100
|
||||
Resources: list[ScimUserResource | ScimGroupResource] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ScimPatchOperationType(str, Enum):
|
||||
"""Supported PATCH operations (RFC 7644 §3.5.2)."""
|
||||
|
||||
ADD = "add"
|
||||
REPLACE = "replace"
|
||||
REMOVE = "remove"
|
||||
|
||||
|
||||
class ScimPatchOperation(BaseModel):
|
||||
"""Single PATCH operation (RFC 7644 §3.5.2)."""
|
||||
|
||||
op: ScimPatchOperationType
|
||||
path: str | None = None
|
||||
value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
|
||||
|
||||
|
||||
class ScimPatchRequest(BaseModel):
|
||||
"""PATCH request body (RFC 7644 §3.5.2).
|
||||
|
||||
IdPs use PATCH to make incremental changes — e.g. deactivating a user
|
||||
(replace active=false) or adding/removing group members — instead of
|
||||
replacing the entire resource with PUT.
|
||||
"""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_PATCH_OP_SCHEMA])
|
||||
Operations: list[ScimPatchOperation]
|
||||
|
||||
|
||||
class ScimError(BaseModel):
|
||||
"""SCIM error response (RFC 7644 §3.12)."""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_ERROR_SCHEMA])
|
||||
status: str
|
||||
detail: str | None = None
|
||||
scimType: str | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service Provider Configuration (RFC 7643 §5)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScimSupported(BaseModel):
|
||||
"""Generic supported/not-supported flag used in ServiceProviderConfig."""
|
||||
|
||||
supported: bool
|
||||
|
||||
|
||||
class ScimFilterConfig(BaseModel):
|
||||
"""Filter configuration within ServiceProviderConfig (RFC 7643 §5)."""
|
||||
|
||||
supported: bool
|
||||
maxResults: int = 100
|
||||
|
||||
|
||||
class ScimServiceProviderConfig(BaseModel):
|
||||
"""SCIM ServiceProviderConfig resource (RFC 7643 §5).
|
||||
|
||||
Served at GET /scim/v2/ServiceProviderConfig. IdPs fetch this during
|
||||
initial setup to discover which SCIM features our server supports
|
||||
(e.g. PATCH yes, bulk no, filtering yes).
|
||||
"""
|
||||
|
||||
schemas: list[str] = Field(
|
||||
default_factory=lambda: [SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA]
|
||||
)
|
||||
patch: ScimSupported = ScimSupported(supported=True)
|
||||
bulk: ScimSupported = ScimSupported(supported=False)
|
||||
filter: ScimFilterConfig = ScimFilterConfig(supported=True)
|
||||
changePassword: ScimSupported = ScimSupported(supported=False)
|
||||
sort: ScimSupported = ScimSupported(supported=False)
|
||||
etag: ScimSupported = ScimSupported(supported=False)
|
||||
authenticationSchemes: list[dict[str, str]] = Field(
|
||||
default_factory=lambda: [
|
||||
{
|
||||
"type": "oauthbearertoken",
|
||||
"name": "OAuth Bearer Token",
|
||||
"description": "Authentication scheme using a SCIM bearer token",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ScimSchemaExtension(BaseModel):
|
||||
"""Schema extension reference within ResourceType (RFC 7643 §6)."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
schema_: str = Field(alias="schema")
|
||||
required: bool
|
||||
|
||||
|
||||
class ScimResourceType(BaseModel):
|
||||
"""SCIM ResourceType resource (RFC 7643 §6).
|
||||
|
||||
Served at GET /scim/v2/ResourceTypes. Tells the IdP which resource
|
||||
types are available (Users, Groups) and their respective endpoints.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_RESOURCE_TYPE_SCHEMA])
|
||||
id: str
|
||||
name: str
|
||||
endpoint: str
|
||||
description: str | None = None
|
||||
schema_: str = Field(alias="schema")
|
||||
schemaExtensions: list[ScimSchemaExtension] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Admin API Schemas (Onyx-internal, for SCIM token management)
|
||||
# These are NOT part of the SCIM protocol. They power the Onyx admin UI
|
||||
# where admins create/revoke the bearer tokens that IdPs use to authenticate.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ScimTokenCreate(BaseModel):
|
||||
"""Request to create a new SCIM bearer token."""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
class ScimTokenResponse(BaseModel):
|
||||
"""SCIM token metadata returned in list/get responses."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
token_display: str
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
last_used_at: datetime | None = None
|
||||
|
||||
|
||||
class ScimTokenCreatedResponse(ScimTokenResponse):
|
||||
"""Response returned when a new SCIM token is created.
|
||||
|
||||
Includes the raw token value which is only available at creation time.
|
||||
"""
|
||||
|
||||
raw_token: str
|
||||
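For reference, a hedged sketch of how these Pydantic models might be instantiated to produce a SCIM list response; the IDs and email address are placeholders, not real data:

```python
from ee.onyx.server.scim.models import ScimEmail, ScimListResponse, ScimUserResource

# Shape of a single user as it could appear in a GET /scim/v2/Users response.
user = ScimUserResource(
    id="3fa85f64-0000-0000-0000-000000000000",  # placeholder Onyx user ID
    externalId="okta|00u1abcd",                 # placeholder IdP identifier
    userName="john@example.com",
    emails=[ScimEmail(value="john@example.com", primary=True)],
    active=True,
)

listing = ScimListResponse(totalResults=1, Resources=[user])
print(listing.model_dump_json(exclude_none=True))
```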
256 backend/ee/onyx/server/scim/patch.py Normal file
@@ -0,0 +1,256 @@
|
||||
"""SCIM PATCH operation handler (RFC 7644 §3.5.2).
|
||||
|
||||
Identity providers use PATCH to make incremental changes to SCIM resources
|
||||
instead of replacing the entire resource with PUT. Common operations include:
|
||||
|
||||
- Deactivating a user: ``replace`` ``active`` with ``false``
|
||||
- Adding group members: ``add`` to ``members``
|
||||
- Removing group members: ``remove`` from ``members[value eq "..."]``
|
||||
|
||||
This module applies PATCH operations to Pydantic SCIM resource objects and
|
||||
returns the modified result. It does NOT touch the database — the caller is
|
||||
responsible for persisting changes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
|
||||
|
||||
class ScimPatchError(Exception):
|
||||
"""Raised when a PATCH operation cannot be applied."""
|
||||
|
||||
def __init__(self, detail: str, status: int = 400) -> None:
|
||||
self.detail = detail
|
||||
self.status = status
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
# Pattern for member removal path: members[value eq "user-id"]
|
||||
_MEMBER_FILTER_RE = re.compile(
|
||||
r'^members\[value\s+eq\s+"([^"]+)"\]$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def apply_user_patch(
|
||||
    operations: list[ScimPatchOperation],
    current: ScimUserResource,
) -> ScimUserResource:
    """Apply SCIM PATCH operations to a user resource.

    Returns a new ``ScimUserResource`` with the modifications applied.
    The original object is not mutated.

    Raises:
        ScimPatchError: If an operation targets an unsupported path.
    """
    data = current.model_dump()
    name_data = data.get("name") or {}

    for op in operations:
        if op.op == ScimPatchOperationType.REPLACE:
            _apply_user_replace(op, data, name_data)
        elif op.op == ScimPatchOperationType.ADD:
            _apply_user_replace(op, data, name_data)
        else:
            raise ScimPatchError(
                f"Unsupported operation '{op.op.value}' on User resource"
            )

    data["name"] = name_data
    return ScimUserResource.model_validate(data)


def _apply_user_replace(
    op: ScimPatchOperation,
    data: dict,
    name_data: dict,
) -> None:
    """Apply a replace/add operation to user data."""
    path = (op.path or "").lower()

    if not path:
        # No path: value is a dict of top-level attributes to set
        if isinstance(op.value, dict):
            for key, val in op.value.items():
                _set_user_field(key.lower(), val, data, name_data)
        else:
            raise ScimPatchError("Replace without path requires a dict value")
        return

    _set_user_field(path, op.value, data, name_data)


def _set_user_field(
    path: str,
    value: str | bool | dict | list | None,
    data: dict,
    name_data: dict,
) -> None:
    """Set a single field on user data by SCIM path."""
    if path == "active":
        data["active"] = value
    elif path == "username":
        data["userName"] = value
    elif path == "externalid":
        data["externalId"] = value
    elif path == "name.givenname":
        name_data["givenName"] = value
    elif path == "name.familyname":
        name_data["familyName"] = value
    elif path == "name.formatted":
        name_data["formatted"] = value
    elif path == "displayname":
        # Some IdPs send displayName on users; map to formatted name
        name_data["formatted"] = value
    else:
        raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")


def apply_group_patch(
    operations: list[ScimPatchOperation],
    current: ScimGroupResource,
) -> tuple[ScimGroupResource, list[str], list[str]]:
    """Apply SCIM PATCH operations to a group resource.

    Returns:
        A tuple of (modified group, added member IDs, removed member IDs).
        The caller uses the member ID lists to update the database.

    Raises:
        ScimPatchError: If an operation targets an unsupported path.
    """
    data = current.model_dump()
    current_members: list[dict] = list(data.get("members") or [])
    added_ids: list[str] = []
    removed_ids: list[str] = []

    for op in operations:
        if op.op == ScimPatchOperationType.REPLACE:
            _apply_group_replace(op, data, current_members, added_ids, removed_ids)
        elif op.op == ScimPatchOperationType.ADD:
            _apply_group_add(op, current_members, added_ids)
        elif op.op == ScimPatchOperationType.REMOVE:
            _apply_group_remove(op, current_members, removed_ids)
        else:
            raise ScimPatchError(
                f"Unsupported operation '{op.op.value}' on Group resource"
            )

    data["members"] = current_members
    group = ScimGroupResource.model_validate(data)
    return group, added_ids, removed_ids


def _apply_group_replace(
    op: ScimPatchOperation,
    data: dict,
    current_members: list[dict],
    added_ids: list[str],
    removed_ids: list[str],
) -> None:
    """Apply a replace operation to group data."""
    path = (op.path or "").lower()

    if not path:
        if isinstance(op.value, dict):
            for key, val in op.value.items():
                if key.lower() == "members":
                    _replace_members(val, current_members, added_ids, removed_ids)
                else:
                    _set_group_field(key.lower(), val, data)
        else:
            raise ScimPatchError("Replace without path requires a dict value")
        return

    if path == "members":
        _replace_members(op.value, current_members, added_ids, removed_ids)
        return

    _set_group_field(path, op.value, data)


def _replace_members(
    value: str | list | dict | bool | None,
    current_members: list[dict],
    added_ids: list[str],
    removed_ids: list[str],
) -> None:
    """Replace the entire group member list."""
    if not isinstance(value, list):
        raise ScimPatchError("Replace members requires a list value")

    old_ids = {m["value"] for m in current_members}
    new_ids = {m.get("value", "") for m in value}

    removed_ids.extend(old_ids - new_ids)
    added_ids.extend(new_ids - old_ids)

    current_members[:] = value


def _set_group_field(
    path: str,
    value: str | bool | dict | list | None,
    data: dict,
) -> None:
    """Set a single field on group data by SCIM path."""
    if path == "displayname":
        data["displayName"] = value
    elif path == "externalid":
        data["externalId"] = value
    else:
        raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")


def _apply_group_add(
    op: ScimPatchOperation,
    members: list[dict],
    added_ids: list[str],
) -> None:
    """Add members to a group."""
    path = (op.path or "").lower()

    if path and path != "members":
        raise ScimPatchError(f"Unsupported add path '{op.path}' for Group")

    if not isinstance(op.value, list):
        raise ScimPatchError("Add members requires a list value")

    existing_ids = {m["value"] for m in members}
    for member_data in op.value:
        member_id = member_data.get("value", "")
        if member_id and member_id not in existing_ids:
            members.append(member_data)
            added_ids.append(member_id)
            existing_ids.add(member_id)


def _apply_group_remove(
    op: ScimPatchOperation,
    members: list[dict],
    removed_ids: list[str],
) -> None:
    """Remove members from a group."""
    if not op.path:
        raise ScimPatchError("Remove operation requires a path")

    match = _MEMBER_FILTER_RE.match(op.path)
    if not match:
        raise ScimPatchError(
            f"Unsupported remove path '{op.path}'. "
            'Expected: members[value eq "user-id"]'
        )

    target_id = match.group(1)
    original_len = len(members)
    members[:] = [m for m in members if m.get("value") != target_id]

    if len(members) < original_len:
        removed_ids.append(target_id)
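A minimal usage sketch of the two entry points above. The ScimPatchOperation constructor arguments and the resource objects are assumptions for illustration; the actual request models may differ.

# Illustrative only: assumes ScimPatchOperation(op=..., path=..., value=...) and
# existing current_user_resource / current_group_resource objects.
ops = [
    ScimPatchOperation(op=ScimPatchOperationType.REPLACE, path="active", value=False),
    ScimPatchOperation(
        op=ScimPatchOperationType.REPLACE, path="name.givenName", value="Ada"
    ),
]
updated_user = apply_user_patch(ops, current_user_resource)

group_ops = [
    ScimPatchOperation(
        op=ScimPatchOperationType.ADD,
        path="members",
        value=[{"value": "user-id-123"}],
    ),
]
updated_group, added, removed = apply_group_patch(group_ops, current_group_resource)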
@@ -4,6 +4,7 @@ from redis.exceptions import RedisError

from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
@@ -89,7 +90,11 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
            # Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
            settings.ee_features_enabled = True
        else:
            # No license = community edition, disable EE features
            # No license found.
            if ENTERPRISE_EDITION_ENABLED:
                # Legacy EE flag is set → prior EE usage (e.g. permission
                # syncing) means indexed data may need protection.
                settings.application_status = _BLOCKING_STATUS
            settings.ee_features_enabled = False
    except RedisError as e:
        logger.warning(f"Failed to check license metadata for settings: {e}")
@@ -177,7 +177,7 @@ async def forward_to_control_plane(
    url = f"{CONTROL_PLANE_API_BASE_URL}{path}"

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
        async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
            if method == "GET":
                response = await client.get(url, headers=headers, params=params)
            elif method == "POST":
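A standalone sketch of what the follow_redirects change does: with it enabled, httpx re-issues the request when the control plane answers with a redirect instead of surfacing the 3xx response to the caller. The URL below is a placeholder.

import asyncio

import httpx


async def _demo() -> None:
    # Without follow_redirects, a 307/308 would be returned as-is; with it,
    # httpx transparently follows the Location header.
    async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
        resp = await client.get("https://example.com/api/health")
        print(resp.status_code, resp.url)


asyncio.run(_demo())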
@@ -12,12 +12,14 @@ from ee.onyx.db.user_group import prepare_user_group_for_deletion
from ee.onyx.db.user_group import update_user_curator_relationship
from ee.onyx.db.user_group import update_user_group
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
from ee.onyx.server.user_group.models import MinimalUserGroupSnapshot
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import UserGroup
from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
@@ -45,6 +47,23 @@ def list_user_groups(
    return [UserGroup.from_model(user_group) for user_group in user_groups]


@router.get("/user-groups/minimal")
def list_minimal_user_groups(
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[MinimalUserGroupSnapshot]:
    if user.role == UserRole.ADMIN:
        user_groups = fetch_user_groups(db_session, only_up_to_date=False)
    else:
        user_groups = fetch_user_groups_for_user(
            db_session=db_session,
            user_id=user.id,
        )
    return [
        MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
    ]


@router.post("/admin/user-group")
def create_user_group(
    user_group: UserGroupCreate,
@@ -76,6 +76,18 @@ class UserGroup(BaseModel):
    )


class MinimalUserGroupSnapshot(BaseModel):
    id: int
    name: str

    @classmethod
    def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
        return cls(
            id=user_group_model.id,
            name=user_group_model.name,
        )


class UserGroupCreate(BaseModel):
    name: str
    user_ids: list[UUID]
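A rough client-side sketch of calling the new minimal endpoint. The host, router prefix, and auth cookie are placeholders and depend on how this router is mounted in a given deployment.

import requests

# Placeholder base URL and session cookie; adjust for your deployment.
resp = requests.get(
    "http://localhost:8080/user-groups/minimal",
    cookies={"fastapiusersauth": "<session-cookie>"},
    timeout=10,
)
resp.raise_for_status()
for group in resp.json():  # e.g. [{"id": 1, "name": "Engineering"}, ...]
    print(group["id"], group["name"])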
@@ -1,7 +1,9 @@
import uuid
from enum import Enum
from typing import Any

from fastapi_users import schemas
from typing_extensions import override


class UserRole(str, Enum):
@@ -41,8 +43,21 @@ class UserCreate(schemas.BaseUserCreate):
    role: UserRole = UserRole.BASIC
    tenant_id: str | None = None
    # Captcha token for cloud signup protection (optional, only used when captcha is enabled)
    # Excluded from create_update_dict so it never reaches the DB layer
    captcha_token: str | None = None

    @override
    def create_update_dict(self) -> dict[str, Any]:
        d = super().create_update_dict()
        d.pop("captcha_token", None)
        return d

    @override
    def create_update_dict_superuser(self) -> dict[str, Any]:
        d = super().create_update_dict_superuser()
        d.pop("captcha_token", None)
        return d


class UserUpdateWithRole(schemas.BaseUserUpdate):
    role: UserRole
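A small sketch of the intended behavior: the captcha token rides along on the signup payload but is stripped before the dict that feeds the user row is built. Field values here are placeholders.

# Illustrative values only.
payload = UserCreate(
    email="new.user@example.com",
    password="a-strong-password",
    captcha_token="token-from-the-captcha-widget",
)
db_fields = payload.create_update_dict()
assert "captcha_token" not in db_fields  # never reaches the DB layer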
@@ -60,6 +60,7 @@ from sqlalchemy import nulls_last
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.disposable_email_validator import is_disposable_email
@@ -110,6 +111,7 @@ from onyx.db.auth import get_user_db
from onyx.db.auth import SQLAlchemyUserAdminDB
from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
@@ -272,6 +274,22 @@ def verify_email_domain(email: str) -> None:
        )


def enforce_seat_limit(db_session: Session, seats_needed: int = 1) -> None:
    """Raise HTTPException(402) if adding users would exceed the seat limit.

    No-op for multi-tenant or CE deployments.
    """
    if MULTI_TENANT:
        return

    result = fetch_ee_implementation_or_noop(
        "onyx.db.license", "check_seat_availability", None
    )(db_session, seats_needed=seats_needed)

    if result is not None and not result.available:
        raise HTTPException(status_code=402, detail=result.error_message)


class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
    reset_password_token_secret = USER_AUTH_SECRET
    verification_token_secret = USER_AUTH_SECRET
@@ -401,6 +419,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
        ):
            user_create.role = UserRole.ADMIN

        # Check seat availability for new users (single-tenant only)
        with get_session_with_current_tenant() as sync_db:
            existing = get_user_by_email(user_create.email, sync_db)
            if existing is None:
                enforce_seat_limit(sync_db)

        user_created = False
        try:
            user = await super().create(user_create, safe=safe, request=request)
@@ -610,6 +634,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
                raise exceptions.UserNotExists()

        except exceptions.UserNotExists:
            # Check seat availability before creating (single-tenant only)
            with get_session_with_current_tenant() as sync_db:
                enforce_seat_limit(sync_db)

            password = self.password_helper.generate()
            user_dict = {
                "email": account_email,
@@ -1431,6 +1459,7 @@ def get_anonymous_user() -> User:
        is_superuser=False,
        role=UserRole.LIMITED,
        use_memories=False,
        enable_memory_tool=False,
    )
    return user
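enforce_seat_limit only assumes the EE check returns an object exposing `available` and `error_message`. The stand-in below is hypothetical and is not the actual onyx.db.license implementation; it just shows the shape the caller relies on.

from dataclasses import dataclass


@dataclass
class SeatAvailability:  # hypothetical stand-in for the EE result object
    available: bool
    error_message: str | None = None


def check_seat_availability_stub(db_session, seats_needed: int = 1) -> SeatAvailability:
    used, licensed = 42, 50  # placeholder counts; real logic reads the license
    if used + seats_needed > licensed:
        return SeatAvailability(False, "Seat limit reached; upgrade your license.")
    return SeatAvailability(True)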
@@ -26,6 +26,7 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.celery_utils import make_probe_path
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
@@ -525,6 +526,12 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:  # noqa: ARG
    """Waits for Vespa to become ready subject to a timeout.
    Raises WorkerShutdown if the timeout is reached."""

    if DISABLE_VECTOR_DB:
        logger.info(
            "DISABLE_VECTOR_DB is set — skipping Vespa/OpenSearch readiness check."
        )
        return

    if not wait_for_vespa_with_timeout():
        msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
        logger.error(msg)
@@ -566,3 +573,31 @@ class LivenessProbe(bootsteps.StartStopStep):

def get_bootsteps() -> list[type]:
    return [LivenessProbe]


# Task modules that require a vector DB (Vespa/OpenSearch).
# When DISABLE_VECTOR_DB is True these are excluded from autodiscover lists.
_VECTOR_DB_TASK_MODULES: set[str] = {
    "onyx.background.celery.tasks.connector_deletion",
    "onyx.background.celery.tasks.docprocessing",
    "onyx.background.celery.tasks.docfetching",
    "onyx.background.celery.tasks.pruning",
    "onyx.background.celery.tasks.vespa",
    "onyx.background.celery.tasks.opensearch_migration",
    "onyx.background.celery.tasks.doc_permission_syncing",
    "onyx.background.celery.tasks.hierarchyfetching",
    # EE modules that are vector-DB-dependent
    "ee.onyx.background.celery.tasks.doc_permission_syncing",
    "ee.onyx.background.celery.tasks.external_group_syncing",
}
# NOTE: "onyx.background.celery.tasks.shared" is intentionally NOT in the set
# above. It contains celery_beat_heartbeat (which only writes to Redis) alongside
# document cleanup tasks. The cleanup tasks won't be invoked in minimal mode
# because the periodic tasks that trigger them are in other filtered modules.


def filter_task_modules(modules: list[str]) -> list[str]:
    """Remove vector-DB-dependent task modules when DISABLE_VECTOR_DB is True."""
    if not DISABLE_VECTOR_DB:
        return modules
    return [m for m in modules if m not in _VECTOR_DB_TASK_MODULES]
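A quick illustration of filter_task_modules based on the definition above: with DISABLE_VECTOR_DB set, vector-DB-dependent modules drop out of an autodiscover list and everything else passes through unchanged.

# Example inputs taken from the module set above.
modules = app_base.filter_task_modules(
    [
        "onyx.background.celery.tasks.monitoring",
        "onyx.background.celery.tasks.vespa",
    ]
)
# -> ["onyx.background.celery.tasks.monitoring"] when DISABLE_VECTOR_DB is True,
#    or the full list otherwise.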
@@ -118,23 +118,25 @@ for bootstep in base_bootsteps:
    celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
    [
        # Original background worker tasks
        "onyx.background.celery.tasks.pruning",
        "onyx.background.celery.tasks.monitoring",
        "onyx.background.celery.tasks.user_file_processing",
        "onyx.background.celery.tasks.llm_model_update",
        "onyx.background.celery.tasks.opensearch_migration",
        # Light worker tasks
        "onyx.background.celery.tasks.shared",
        "onyx.background.celery.tasks.vespa",
        "onyx.background.celery.tasks.connector_deletion",
        "onyx.background.celery.tasks.doc_permission_syncing",
        # Docprocessing worker tasks
        "onyx.background.celery.tasks.docprocessing",
        # Docfetching worker tasks
        "onyx.background.celery.tasks.docfetching",
        # Sandbox cleanup tasks (isolated in build feature)
        "onyx.server.features.build.sandbox.tasks",
    ]
    app_base.filter_task_modules(
        [
            # Original background worker tasks
            "onyx.background.celery.tasks.pruning",
            "onyx.background.celery.tasks.monitoring",
            "onyx.background.celery.tasks.user_file_processing",
            "onyx.background.celery.tasks.llm_model_update",
            # Light worker tasks
            "onyx.background.celery.tasks.shared",
            "onyx.background.celery.tasks.vespa",
            "onyx.background.celery.tasks.connector_deletion",
            "onyx.background.celery.tasks.doc_permission_syncing",
            "onyx.background.celery.tasks.opensearch_migration",
            # Docprocessing worker tasks
            "onyx.background.celery.tasks.docprocessing",
            # Docfetching worker tasks
            "onyx.background.celery.tasks.docfetching",
            # Sandbox cleanup tasks (isolated in build feature)
            "onyx.server.features.build.sandbox.tasks",
        ]
    )
)
@@ -96,7 +96,9 @@ for bootstep in base_bootsteps:
    celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.docfetching",
    ]
    app_base.filter_task_modules(
        [
            "onyx.background.celery.tasks.docfetching",
        ]
    )
)
@@ -107,7 +107,9 @@ for bootstep in base_bootsteps:
    celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.docprocessing",
    ]
    app_base.filter_task_modules(
        [
            "onyx.background.celery.tasks.docprocessing",
        ]
    )
)
@@ -96,10 +96,12 @@ for bootstep in base_bootsteps:
    celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.pruning",
        # Sandbox tasks (file sync, cleanup)
        "onyx.server.features.build.sandbox.tasks",
        "onyx.background.celery.tasks.hierarchyfetching",
    ]
    app_base.filter_task_modules(
        [
            "onyx.background.celery.tasks.pruning",
            # Sandbox tasks (file sync, cleanup)
            "onyx.server.features.build.sandbox.tasks",
            "onyx.background.celery.tasks.hierarchyfetching",
        ]
    )
)
@@ -110,13 +110,16 @@ for bootstep in base_bootsteps:
    celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.shared",
        "onyx.background.celery.tasks.vespa",
        "onyx.background.celery.tasks.connector_deletion",
        "onyx.background.celery.tasks.doc_permission_syncing",
        "onyx.background.celery.tasks.docprocessing",
        # Sandbox cleanup tasks (isolated in build feature)
        "onyx.server.features.build.sandbox.tasks",
    ]
    app_base.filter_task_modules(
        [
            "onyx.background.celery.tasks.shared",
            "onyx.background.celery.tasks.vespa",
            "onyx.background.celery.tasks.connector_deletion",
            "onyx.background.celery.tasks.doc_permission_syncing",
            "onyx.background.celery.tasks.docprocessing",
            "onyx.background.celery.tasks.opensearch_migration",
            # Sandbox cleanup tasks (isolated in build feature)
            "onyx.server.features.build.sandbox.tasks",
        ]
    )
)
@@ -94,7 +94,9 @@ for bootstep in base_bootsteps:
    celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.monitoring",
    ]
    app_base.filter_task_modules(
        [
            "onyx.background.celery.tasks.monitoring",
        ]
    )
)
@@ -314,17 +314,18 @@ for bootstep in base_bootsteps:
    celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.connector_deletion",
        "onyx.background.celery.tasks.docprocessing",
        "onyx.background.celery.tasks.evals",
        "onyx.background.celery.tasks.hierarchyfetching",
        "onyx.background.celery.tasks.periodic",
        "onyx.background.celery.tasks.pruning",
        "onyx.background.celery.tasks.shared",
        "onyx.background.celery.tasks.vespa",
        "onyx.background.celery.tasks.llm_model_update",
        "onyx.background.celery.tasks.user_file_processing",
        "onyx.background.celery.tasks.opensearch_migration",
    ]
    app_base.filter_task_modules(
        [
            "onyx.background.celery.tasks.connector_deletion",
            "onyx.background.celery.tasks.docprocessing",
            "onyx.background.celery.tasks.evals",
            "onyx.background.celery.tasks.hierarchyfetching",
            "onyx.background.celery.tasks.periodic",
            "onyx.background.celery.tasks.pruning",
            "onyx.background.celery.tasks.shared",
            "onyx.background.celery.tasks.vespa",
            "onyx.background.celery.tasks.llm_model_update",
            "onyx.background.celery.tasks.user_file_processing",
        ]
    )
)
@@ -107,7 +107,9 @@ for bootstep in base_bootsteps:
    celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.user_file_processing",
    ]
    app_base.filter_task_modules(
        [
            "onyx.background.celery.tasks.user_file_processing",
        ]
    )
)
@@ -5,17 +5,19 @@ from datetime import timezone
from pathlib import Path
from typing import Any
from typing import cast
from typing import TypeVar

import httpx

from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.connectors.connector_runner import batched_doc_ids
from onyx.connectors.connector_runner import CheckpointOutputWrapper
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
    rate_limit_builder,
)
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
@@ -31,6 +33,54 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
PRUNING_CHECKPOINTED_BATCH_SIZE = 32

CT = TypeVar("CT", bound=ConnectorCheckpoint)


def _checkpointed_batched_doc_ids(
    connector: CheckpointedConnector[CT],
    start: float,
    end: float,
    batch_size: int,
) -> Generator[set[str], None, None]:
    """Loop through all checkpoint steps and yield batched document IDs.

    Some checkpointed connectors (e.g. IMAP) are multi-step: the first
    checkpoint call may only initialize internal state without yielding
    any documents. This function loops until checkpoint.has_more is False
    to ensure all document IDs are collected across every step.
    """
    checkpoint = connector.build_dummy_checkpoint()
    while True:
        checkpoint_output = connector.load_from_checkpoint(
            start=start, end=end, checkpoint=checkpoint
        )
        wrapper: CheckpointOutputWrapper[CT] = CheckpointOutputWrapper()
        batch: set[str] = set()
        for document, _hierarchy_node, failure, next_checkpoint in wrapper(
            checkpoint_output
        ):
            if document is not None:
                batch.add(document.id)
            elif (
                failure
                and failure.failed_document
                and failure.failed_document.document_id
            ):
                batch.add(failure.failed_document.document_id)

            if next_checkpoint is not None:
                checkpoint = next_checkpoint

            if len(batch) >= batch_size:
                yield batch
                batch = set()

        if batch:
            yield batch

        if not checkpoint.has_more:
            break


def document_batch_to_ids(
    doc_batch: (
@@ -80,12 +130,8 @@ def extract_ids_from_runnable_connector(
    elif isinstance(runnable_connector, CheckpointedConnector):
        start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
        end = datetime.now(timezone.utc).timestamp()
        checkpoint = runnable_connector.build_dummy_checkpoint()
        checkpoint_generator = runnable_connector.load_from_checkpoint(
            start=start, end=end, checkpoint=checkpoint
        )
        doc_batch_id_generator = batched_doc_ids(
            checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
        doc_batch_id_generator = _checkpointed_batched_doc_ids(
            runnable_connector, start, end, PRUNING_CHECKPOINTED_BATCH_SIZE
        )
    else:
        raise RuntimeError("Pruning job could not find a valid runnable_connector.")
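A simplified, self-contained sketch of the multi-step checkpoint pattern the helper above handles. The FakeCheckpoint and load_step types are hypothetical; they only mimic a connector whose first step sets up state and yields no documents.

from collections.abc import Iterator


class FakeCheckpoint:
    def __init__(self, step: int) -> None:
        self.step = step
        self.has_more = step < 2  # two steps total


def load_step(checkpoint: FakeCheckpoint) -> tuple[list[str], FakeCheckpoint]:
    # Step 0 yields nothing (state setup only); step 1 yields the real IDs.
    ids = [] if checkpoint.step == 0 else ["doc-1", "doc-2", "doc-3"]
    return ids, FakeCheckpoint(checkpoint.step + 1)


def batched_ids(batch_size: int = 2) -> Iterator[set[str]]:
    checkpoint = FakeCheckpoint(0)
    batch: set[str] = set()
    while True:
        ids, checkpoint = load_step(checkpoint)
        for doc_id in ids:
            batch.add(doc_id)
            if len(batch) >= batch_size:
                yield batch
                batch = set()
        if batch:
            yield batch
            batch = set()
        if not checkpoint.has_more:
            break


print(list(batched_ids()))  # [{'doc-1', 'doc-2'}, {'doc-3'}]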
@@ -6,6 +6,7 @@ from celery.schedules import crontab

from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
@@ -215,36 +216,39 @@ if SCHEDULED_EVAL_DATASET_NAMES:
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
    beat_task_templates.append(
        {
            "name": "check-for-documents-for-opensearch-migration",
            "task": OnyxCeleryTask.CHECK_FOR_DOCUMENTS_FOR_OPENSEARCH_MIGRATION_TASK,
            "name": "migrate-chunks-from-vespa-to-opensearch",
            "task": OnyxCeleryTask.MIGRATE_CHUNKS_FROM_VESPA_TO_OPENSEARCH_TASK,
            # Try to enqueue an invocation of this task with this frequency.
            "schedule": timedelta(seconds=120),  # 2 minutes
            "options": {
                "priority": OnyxCeleryPriority.LOW,
                # If the task was not dequeued in this time, revoke it.
                "expires": BEAT_EXPIRES_DEFAULT,
                "queue": OnyxCeleryQueues.OPENSEARCH_MIGRATION,
            },
        }
    )
    beat_task_templates.append(
        {
            "name": "migrate-documents-from-vespa-to-opensearch",
            "task": OnyxCeleryTask.MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK,
            # Try to enqueue an invocation of this task with this frequency.
            # NOTE: If MIGRATION_TASK_SOFT_TIME_LIMIT_S is greater than this
            # value and the task is maximally busy, we can expect to see some
            # enqueued tasks be revoked over time. This is ok; by erring on the
            # side of "there will probably always be at least one task of this
            # type in the queue", we are minimizing this task's idleness while
            # still giving chances for other tasks to execute.
            "schedule": timedelta(seconds=120),  # 2 minutes
            "options": {
                "priority": OnyxCeleryPriority.LOW,
                # If the task was not dequeued in this time, revoke it.
                "expires": BEAT_EXPIRES_DEFAULT,
            },
        }
    )


# Beat task names that require a vector DB. Filtered out when DISABLE_VECTOR_DB.
_VECTOR_DB_BEAT_TASK_NAMES: set[str] = {
    "check-for-indexing",
    "check-for-connector-deletion",
    "check-for-vespa-sync",
    "check-for-pruning",
    "check-for-hierarchy-fetching",
    "check-for-checkpoint-cleanup",
    "check-for-index-attempt-cleanup",
    "check-for-doc-permissions-sync",
    "check-for-external-group-sync",
    "check-for-documents-for-opensearch-migration",
    "migrate-documents-from-vespa-to-opensearch",
}

if DISABLE_VECTOR_DB:
    beat_task_templates = [
        t for t in beat_task_templates if t["name"] not in _VECTOR_DB_BEAT_TASK_NAMES
    ]


def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
@@ -37,6 +37,7 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
        redis_connector: RedisConnector,
        redis_lock: RedisLock,
        redis_client: Redis,
        timeout_seconds: int | None = None,
    ):
        super().__init__()
        self.parent_pid = parent_pid
@@ -51,11 +52,29 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
        self.last_lock_monotonic = time.monotonic()

        self.last_parent_check = time.monotonic()
        self.start_monotonic = time.monotonic()
        self.timeout_seconds = timeout_seconds

    def should_stop(self) -> bool:
        # Check if the associated indexing attempt has been cancelled
        # TODO: Pass index_attempt_id to the callback and check cancellation using the db
        return bool(self.redis_connector.stop.fenced)
        if bool(self.redis_connector.stop.fenced):
            return True

        # Check if the task has exceeded its timeout
        # NOTE: Celery's soft_time_limit does not work with thread pools,
        # so we must enforce timeouts internally.
        if self.timeout_seconds is not None:
            elapsed = time.monotonic() - self.start_monotonic
            if elapsed > self.timeout_seconds:
                logger.warning(
                    f"IndexingCallback Docprocessing - task timeout exceeded: "
                    f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s "
                    f"cc_pair={self.redis_connector.cc_pair_id}"
                )
                return True

        return False

    def progress(self, tag: str, amount: int) -> None:  # noqa: ARG002
        """Amount isn't used yet."""
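A bare-bones illustration of the internal-timeout pattern used above, since Celery's soft_time_limit cannot interrupt thread-pool workers. The 5-second limit and sleep loop are placeholders for real indexing work.

import time

START = time.monotonic()
TIMEOUT_SECONDS = 5.0  # placeholder


def should_stop() -> bool:
    return time.monotonic() - START > TIMEOUT_SECONDS


while not should_stop():
    # ... do one small unit of work, then re-check the deadline ...
    time.sleep(0.5)
print("stopped after exceeding the internal timeout")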
@@ -0,0 +1,10 @@
"""Celery tasks for hierarchy fetching."""

from onyx.background.celery.tasks.hierarchyfetching.tasks import (  # noqa: F401
    check_for_hierarchy_fetching,
)
from onyx.background.celery.tasks.hierarchyfetching.tasks import (  # noqa: F401
    connector_hierarchy_fetching_task,
)

__all__ = ["check_for_hierarchy_fetching", "connector_hierarchy_fetching_task"]
@@ -146,14 +146,26 @@ def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
    """Collect metrics about queue lengths for different Celery queues"""
    metrics = []
    queue_mappings = {
        "celery_queue_length": "celery",
        "docprocessing_queue_length": "docprocessing",
        "sync_queue_length": "sync",
        "deletion_queue_length": "deletion",
        "pruning_queue_length": "pruning",
        "celery_queue_length": OnyxCeleryQueues.PRIMARY,
        "docprocessing_queue_length": OnyxCeleryQueues.DOCPROCESSING,
        "docfetching_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
        "sync_queue_length": OnyxCeleryQueues.VESPA_METADATA_SYNC,
        "deletion_queue_length": OnyxCeleryQueues.CONNECTOR_DELETION,
        "pruning_queue_length": OnyxCeleryQueues.CONNECTOR_PRUNING,
        "permissions_sync_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
        "external_group_sync_queue_length": OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
        "permissions_upsert_queue_length": OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
        "hierarchy_fetching_queue_length": OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING,
        "llm_model_update_queue_length": OnyxCeleryQueues.LLM_MODEL_UPDATE,
        "checkpoint_cleanup_queue_length": OnyxCeleryQueues.CHECKPOINT_CLEANUP,
        "index_attempt_cleanup_queue_length": OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP,
        "csv_generation_queue_length": OnyxCeleryQueues.CSV_GENERATION,
        "user_file_processing_queue_length": OnyxCeleryQueues.USER_FILE_PROCESSING,
        "user_file_project_sync_queue_length": OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
        "user_file_delete_queue_length": OnyxCeleryQueues.USER_FILE_DELETE,
        "monitoring_queue_length": OnyxCeleryQueues.MONITORING,
        "sandbox_queue_length": OnyxCeleryQueues.SANDBOX,
        "opensearch_migration_queue_length": OnyxCeleryQueues.OPENSEARCH_MIGRATION,
    }

    for name, queue in queue_mappings.items():
@@ -881,7 +893,7 @@ def monitor_celery_queues_helper(
    """A task to monitor all celery queue lengths."""

    r_celery = task.app.broker_connection().channel().client  # type: ignore
    n_celery = celery_get_queue_length("celery", r_celery)
    n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
    n_docfetching = celery_get_queue_length(
        OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
    )
@@ -908,6 +920,26 @@ def monitor_celery_queues_helper(
    n_permissions_upsert = celery_get_queue_length(
        OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
    )
    n_hierarchy_fetching = celery_get_queue_length(
        OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING, r_celery
    )
    n_llm_model_update = celery_get_queue_length(
        OnyxCeleryQueues.LLM_MODEL_UPDATE, r_celery
    )
    n_checkpoint_cleanup = celery_get_queue_length(
        OnyxCeleryQueues.CHECKPOINT_CLEANUP, r_celery
    )
    n_index_attempt_cleanup = celery_get_queue_length(
        OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP, r_celery
    )
    n_csv_generation = celery_get_queue_length(
        OnyxCeleryQueues.CSV_GENERATION, r_celery
    )
    n_monitoring = celery_get_queue_length(OnyxCeleryQueues.MONITORING, r_celery)
    n_sandbox = celery_get_queue_length(OnyxCeleryQueues.SANDBOX, r_celery)
    n_opensearch_migration = celery_get_queue_length(
        OnyxCeleryQueues.OPENSEARCH_MIGRATION, r_celery
    )

    n_docfetching_prefetched = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
@@ -931,6 +963,14 @@ def monitor_celery_queues_helper(
        f"permissions_sync={n_permissions_sync} "
        f"external_group_sync={n_external_group_sync} "
        f"permissions_upsert={n_permissions_upsert} "
        f"hierarchy_fetching={n_hierarchy_fetching} "
        f"llm_model_update={n_llm_model_update} "
        f"checkpoint_cleanup={n_checkpoint_cleanup} "
        f"index_attempt_cleanup={n_index_attempt_cleanup} "
        f"csv_generation={n_csv_generation} "
        f"monitoring={n_monitoring} "
        f"sandbox={n_sandbox} "
        f"opensearch_migration={n_opensearch_migration} "
    )
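For context on the queue-length probes above: with the Redis broker, Celery keeps each queue as a Redis list keyed by the queue name, so a length check is roughly an LLEN. The sketch below is a rough, assumption-laden equivalent (it ignores priority sub-queues and unacked messages) and is not the project's celery_get_queue_length implementation.

from redis import Redis


def queue_length(r: Redis, queue_name: str) -> int:
    # Each Celery queue on the Redis broker is a list under the queue name.
    return int(r.llen(queue_name))


r = Redis(host="localhost", port=6379, db=0)  # placeholder connection details
print(queue_length(r, "docprocessing"))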
@@ -2,27 +2,12 @@

import time
import traceback
from datetime import datetime
from datetime import timezone
from typing import Any

from celery import shared_task
from celery import Task
from redis.lock import Lock as RedisLock

from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.opensearch_migration.constants import (
    CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
    CHECK_FOR_DOCUMENTS_TASK_LOCK_TIMEOUT_S,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
    CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
    CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
    MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
@@ -42,225 +27,32 @@ from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import OpenSearchDocumentMigrationStatus
from onyx.db.opensearch_migration import create_opensearch_migration_records_with_commit
from onyx.db.opensearch_migration import get_last_opensearch_migration_document_id
from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapping
from onyx.db.opensearch_migration import get_vespa_visit_state
from onyx.db.opensearch_migration import (
    get_opensearch_migration_records_needing_migration,
    mark_migration_completed_time_if_not_set_with_commit,
)
from onyx.db.opensearch_migration import get_paginated_document_batch
from onyx.db.opensearch_migration import (
    increment_num_times_observed_no_additional_docs_to_migrate_with_commit,
)
from onyx.db.opensearch_migration import (
    increment_num_times_observed_no_additional_docs_to_populate_migration_table_with_commit,
)
from onyx.db.opensearch_migration import should_document_migration_be_permanently_failed
from onyx.db.opensearch_migration import (
    try_insert_opensearch_tenant_migration_record_with_commit,
)
from onyx.db.opensearch_migration import update_vespa_visit_progress_with_commit
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.opensearch_document_index import (
    OpenSearchDocumentIndex,
)
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id


def _migrate_single_document(
    document_id: str,
    opensearch_document_index: OpenSearchDocumentIndex,
    vespa_document_index: VespaDocumentIndex,
    tenant_state: TenantState,
) -> int:
    """Migrates a single document from Vespa to OpenSearch.

    Args:
        document_id: The ID of the document to migrate.
        opensearch_document_index: The OpenSearch document index to use.
        vespa_document_index: The Vespa document index to use.
        tenant_state: The tenant state to use.

    Raises:
        RuntimeError: If no chunks are found for the document in Vespa, or if
            the number of candidate chunks to migrate does not match the number
            of chunks in Vespa.

    Returns:
        The number of chunks migrated.
    """
    vespa_document_chunks: list[dict[str, Any]] = (
        vespa_document_index.get_raw_document_chunks(document_id=document_id)
    )
    if not vespa_document_chunks:
        raise RuntimeError(f"No chunks found for document {document_id} in Vespa.")

    opensearch_document_chunks: list[DocumentChunk] = (
        transform_vespa_chunks_to_opensearch_chunks(
            vespa_document_chunks, tenant_state, document_id
        )
    )
    if len(opensearch_document_chunks) != len(vespa_document_chunks):
        raise RuntimeError(
            f"Bug: Number of candidate chunks to migrate ({len(opensearch_document_chunks)}) does not match "
            f"number of chunks in Vespa ({len(vespa_document_chunks)})."
        )

    opensearch_document_index.index_raw_chunks(chunks=opensearch_document_chunks)

    return len(opensearch_document_chunks)
GET_VESPA_CHUNKS_PAGE_SIZE = 1000


# shared_task allows this task to be shared across celery app instances.
@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_DOCUMENTS_FOR_OPENSEARCH_MIGRATION_TASK,
    # Does not store the task's return value in the result backend.
    ignore_result=True,
    # WARNING: This is here just for rigor but since we use threads for Celery
    # this config is not respected and timeout logic must be implemented in the
    # task.
    soft_time_limit=CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S,
    # WARNING: This is here just for rigor but since we use threads for Celery
    # this config is not respected and timeout logic must be implemented in the
    # task.
    time_limit=CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S,
    # Passed in self to the task to get task metadata.
    bind=True,
)
def check_for_documents_for_opensearch_migration_task(
    self: Task, *, tenant_id: str  # noqa: ARG001
) -> bool | None:
    """
    Periodic task to check for and add documents to the OpenSearch migration
    table.

    Should not execute meaningful logic at the same time as
    migrate_documents_from_vespa_to_opensearch_task.

    Effectively tries to populate as many migration records as possible within
    CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S seconds. Does so in batches of
    1000 documents.

    Returns:
        None if OpenSearch migration is not enabled, or if the lock could not be
        acquired; effectively a no-op. True if the task completed
        successfully. False if the task failed.
    """
    if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
        task_logger.warning(
            "OpenSearch migration is not enabled, skipping check for documents for the OpenSearch migration task."
        )
        return None

    task_logger.info("Checking for documents for OpenSearch migration.")
    task_start_time = time.monotonic()
    r = get_redis_client()
    # Use a lock to prevent overlapping tasks. Only this task or
    # migrate_documents_from_vespa_to_opensearch_task can interact with the
    # OpenSearchMigration table at once.
    lock: RedisLock = r.lock(
        name=OnyxRedisLocks.OPENSEARCH_MIGRATION_BEAT_LOCK,
        # The maximum time the lock can be held for. Will automatically be
        # released after this time.
        timeout=CHECK_FOR_DOCUMENTS_TASK_LOCK_TIMEOUT_S,
        # .acquire will block until the lock is acquired.
        blocking=True,
        # Time to wait to acquire the lock.
        blocking_timeout=CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S,
    )
    if not lock.acquire():
        task_logger.warning(
            "The OpenSearch migration check task timed out waiting for the lock."
        )
        return None
    else:
        task_logger.info(
            f"Acquired the OpenSearch migration check lock. Took {time.monotonic() - task_start_time:.3f} seconds. "
            f"Token: {lock.local.token}"
        )

    num_documents_found_for_record_creation = 0
    try:
        # Double check that tenant info is correct.
        if tenant_id != get_current_tenant_id():
            err_str = (
                f"Tenant ID mismatch in the OpenSearch migration check task: "
                f"{tenant_id} != {get_current_tenant_id()}. This should never happen."
            )
            task_logger.error(err_str)
            return False
        while (
            time.monotonic() - task_start_time
            < CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S
            and lock.owned()
        ):
            with get_session_with_current_tenant() as db_session:
                # For pagination, get the last ID we've inserted into
                # OpenSearchMigration.
                last_opensearch_migration_document_id = (
                    get_last_opensearch_migration_document_id(db_session)
                )
                # Now get the next batch of doc IDs starting after the last ID.
                # We'll do 1000 documents per transaction/timeout check.
                document_ids = get_paginated_document_batch(
                    db_session,
                    limit=1000,
                    prev_ending_document_id=last_opensearch_migration_document_id,
                )

                if not document_ids:
                    task_logger.info(
                        "No more documents to insert for OpenSearch migration."
                    )
                    increment_num_times_observed_no_additional_docs_to_populate_migration_table_with_commit(
                        db_session
                    )
                    # TODO(andrei): Once we've done this enough times and the
                    # number of documents matches the number of migration
                    # records, we can be done with this task and update
                    # document_migration_record_table_population_status.
                    return True

                # Create the migration records for the next batch of documents
                # with status PENDING.
                create_opensearch_migration_records_with_commit(
                    db_session, document_ids
                )
                num_documents_found_for_record_creation += len(document_ids)

                # Try to create the singleton row in
                # OpenSearchTenantMigrationRecord if it doesn't already exist.
                # This is a reasonable place to put it because we already have a
                # lock, a session, and error handling, at the cost of running
                # this small set of logic for every batch.
                try_insert_opensearch_tenant_migration_record_with_commit(db_session)
    except Exception:
        task_logger.exception("Error in the OpenSearch migration check task.")
        return False
    finally:
        if lock.owned():
            lock.release()
        else:
            task_logger.warning(
                "The OpenSearch migration lock was not owned on completion of the check task."
            )

    task_logger.info(
        f"Finished checking for documents for OpenSearch migration. Found {num_documents_found_for_record_creation} documents "
        f"to create migration records for in {time.monotonic() - task_start_time:.3f} seconds. However, this may include "
        "documents for which there already exist records."
    )

    return True


# shared_task allows this task to be shared across celery app instances.
@shared_task(
    name=OnyxCeleryTask.MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK,
    name=OnyxCeleryTask.MIGRATE_CHUNKS_FROM_VESPA_TO_OPENSEARCH_TASK,
    # Does not store the task's return value in the result backend.
    ignore_result=True,
    # WARNING: This is here just for rigor but since we use threads for Celery
@@ -274,18 +66,21 @@ def check_for_documents_for_opensearch_migration_task(
    # Passed in self to the task to get task metadata.
    bind=True,
)
def migrate_documents_from_vespa_to_opensearch_task(
def migrate_chunks_from_vespa_to_opensearch_task(
    self: Task,  # noqa: ARG001
    *,
    tenant_id: str,
) -> bool | None:
    """Periodic task to migrate documents from Vespa to OpenSearch.
    """
    Periodic task to migrate chunks from Vespa to OpenSearch via the Visit API.

    Should not execute meaningful logic at the same time as
    check_for_documents_for_opensearch_migration_task.
    Uses Vespa's Visit API to iterate through ALL chunks in bulk (not
    per-document), transform them, and index them into OpenSearch. Progress is
    tracked via a continuation token stored in the
    OpenSearchTenantMigrationRecord.

    Effectively tries to migrate as many documents as possible within
    MIGRATION_TASK_SOFT_TIME_LIMIT_S seconds. Does so in batches of 5 documents.
    The first time we see no continuation token and non-zero chunks migrated, we
    consider the migration complete and all subsequent invocations are no-ops.

    Returns:
        None if OpenSearch migration is not enabled, or if the lock could not be
@@ -294,16 +89,13 @@ def migrate_documents_from_vespa_to_opensearch_task(
    """
    if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
        task_logger.warning(
            "OpenSearch migration is not enabled, skipping trying to migrate documents from Vespa to OpenSearch."
            "OpenSearch migration is not enabled, skipping chunk migration task."
        )
        return None

    task_logger.info("Trying a migration batch from Vespa to OpenSearch.")
    task_logger.info("Starting chunk-level migration from Vespa to OpenSearch.")
    task_start_time = time.monotonic()
    r = get_redis_client()
    # Use a lock to prevent overlapping tasks. Only this task or
    # check_for_documents_for_opensearch_migration_task can interact with the
    # OpenSearchMigration table at once.
    lock: RedisLock = r.lock(
        name=OnyxRedisLocks.OPENSEARCH_MIGRATION_BEAT_LOCK,
        # The maximum time the lock can be held for. Will automatically be
@@ -325,9 +117,8 @@ def migrate_documents_from_vespa_to_opensearch_task(
            f"Token: {lock.local.token}"
        )

    num_documents_migrated = 0
    num_chunks_migrated = 0
    num_documents_failed = 0
    total_chunks_migrated_this_task = 0
    total_chunks_errored_this_task = 0
    try:
        # Double check that tenant info is correct.
        if tenant_id != get_current_tenant_id():
@@ -337,97 +128,100 @@ def migrate_documents_from_vespa_to_opensearch_task(
            )
            task_logger.error(err_str)
            return False
        while (
            time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
            and lock.owned()
        ):
            with get_session_with_current_tenant() as db_session:
                # We'll do 5 documents per transaction/timeout check.
                records_needing_migration = (
                    get_opensearch_migration_records_needing_migration(
                        db_session, limit=5
                    )
                )
                if not records_needing_migration:

        with get_session_with_current_tenant() as db_session:
            try_insert_opensearch_tenant_migration_record_with_commit(db_session)
            search_settings = get_current_search_settings(db_session)
            tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
            opensearch_document_index = OpenSearchDocumentIndex(
                index_name=search_settings.index_name, tenant_state=tenant_state
            )
            vespa_document_index = VespaDocumentIndex(
                index_name=search_settings.index_name,
                tenant_state=tenant_state,
                large_chunks_enabled=False,
            )

            sanitized_doc_start_time = time.monotonic()
            # We reconstruct this mapping for every task invocation because a
            # document may have been added in the time between two tasks.
            sanitized_to_original_doc_id_mapping = (
                build_sanitized_to_original_doc_id_mapping(db_session)
            )
            task_logger.debug(
                f"Built sanitized_to_original_doc_id_mapping with {len(sanitized_to_original_doc_id_mapping)} entries "
                f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
            )

            while (
                time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
                and lock.owned()
            ):
                (
                    continuation_token,
                    total_chunks_migrated,
                ) = get_vespa_visit_state(db_session)
                if continuation_token is None and total_chunks_migrated > 0:
                    task_logger.info(
                        "No documents found that need to be migrated from Vespa to OpenSearch."
                        f"OpenSearch migration COMPLETED for tenant {tenant_id}. "
                        f"Total chunks migrated: {total_chunks_migrated}."
                    )
                    increment_num_times_observed_no_additional_docs_to_migrate_with_commit(
                        db_session
                    mark_migration_completed_time_if_not_set_with_commit(db_session)
                    break
                task_logger.debug(
                    f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
                    f"Continuation token: {continuation_token}"
                )

                get_vespa_chunks_start_time = time.monotonic()
                raw_vespa_chunks, next_continuation_token = (
                    vespa_document_index.get_all_raw_document_chunks_paginated(
                        continuation_token=continuation_token,
                        page_size=GET_VESPA_CHUNKS_PAGE_SIZE,
                    )
                    # TODO(andrei): Once we've done this enough times and
                    # document_migration_record_table_population_status is done, we
                    # can be done with this task and update
                    # overall_document_migration_status accordingly. Note that this
                    # includes marking connectors as needing reindexing if some
                    # migrations failed.
                    return True

                search_settings = get_current_search_settings(db_session)
                tenant_state = TenantState(
                    tenant_id=tenant_id, multitenant=MULTI_TENANT
                )
                opensearch_document_index = OpenSearchDocumentIndex(
                    index_name=search_settings.index_name, tenant_state=tenant_state
                )
                vespa_document_index = VespaDocumentIndex(
                    index_name=search_settings.index_name,
                    tenant_state=tenant_state,
                    large_chunks_enabled=False,
                task_logger.debug(
                    f"Read {len(raw_vespa_chunks)} chunks from Vespa in {time.monotonic() - get_vespa_chunks_start_time:.3f} "
                    f"seconds. Next continuation token: {next_continuation_token}"
                )

                for record in records_needing_migration:
                    try:
                        # If the Document's chunk count is not known, it was
                        # probably just indexed so fail here to give it a chance to
                        # sync. If in the rare event this Document has not been
                        # re-indexed in a very long time and is still under the
                        # "old" embedding/indexing logic where chunk count was never
                        # stored, we will eventually permanently fail and thus force
                        # a re-index of this doc, which is a desirable outcome.
                        if record.document.chunk_count is None:
                            raise RuntimeError(
                                f"Document {record.document_id} has no chunk count."
                            )
                opensearch_document_chunks, errored_chunks = (
                    transform_vespa_chunks_to_opensearch_chunks(
                        raw_vespa_chunks,
                        tenant_state,
                        sanitized_to_original_doc_id_mapping,
                    )
                )
                if len(opensearch_document_chunks) != len(raw_vespa_chunks):
                    task_logger.error(
                        f"Migration task error: Number of candidate chunks to migrate ({len(opensearch_document_chunks)}) does "
                        f"not match number of chunks in Vespa ({len(raw_vespa_chunks)}). {len(errored_chunks)} chunks "
                        "errored."
                    )

                        chunks_migrated = _migrate_single_document(
                            document_id=record.document_id,
                            opensearch_document_index=opensearch_document_index,
                            vespa_document_index=vespa_document_index,
                            tenant_state=tenant_state,
                        )
                index_opensearch_chunks_start_time = time.monotonic()
                opensearch_document_index.index_raw_chunks(
                    chunks=opensearch_document_chunks
                )
                task_logger.debug(
                    f"Indexed {len(opensearch_document_chunks)} chunks into OpenSearch in "
                    f"{time.monotonic() - index_opensearch_chunks_start_time:.3f} seconds."
                )

                        # If the number of chunks in Vespa is not in sync with the
                        # Document table for this doc let's not consider this
                        # completed and let's let a subsequent run take care of it.
                        if chunks_migrated != record.document.chunk_count:
                            raise RuntimeError(
                                f"Number of chunks migrated ({chunks_migrated}) does not match number of expected chunks "
                                f"in Vespa ({record.document.chunk_count}) for document {record.document_id}."
                            )
                total_chunks_migrated_this_task += len(opensearch_document_chunks)
                total_chunks_errored_this_task += len(errored_chunks)
                update_vespa_visit_progress_with_commit(
                    db_session,
                    continuation_token=next_continuation_token,
                    chunks_processed=len(opensearch_document_chunks),
                    chunks_errored=len(errored_chunks),
                )

                        record.status = OpenSearchDocumentMigrationStatus.COMPLETED
                        num_documents_migrated += 1
                        num_chunks_migrated += chunks_migrated
                    except Exception:
                        record.status = OpenSearchDocumentMigrationStatus.FAILED
                        record.error_message = f"Attempt {record.attempts_count + 1}:\n{traceback.format_exc()}"
                        task_logger.exception(
                            f"Error migrating document {record.document_id} from Vespa to OpenSearch."
                        )
                        num_documents_failed += 1
                    finally:
                        record.attempts_count += 1
                        record.last_attempt_at = datetime.now(timezone.utc)
                        if should_document_migration_be_permanently_failed(record):
                            record.status = (
                                OpenSearchDocumentMigrationStatus.PERMANENTLY_FAILED
                            )
                            # TODO(andrei): Not necessarily here but if this happens
                            # we'll need to mark the connector as needing reindex.

                        db_session.commit()
                if next_continuation_token is None and len(raw_vespa_chunks) == 0:
                    task_logger.info("Vespa reported no more chunks to migrate.")
                    break
    except Exception:
        traceback.print_exc()
        task_logger.exception("Error in the OpenSearch migration task.")
        return False
    finally:
@@ -439,9 +233,11 @@ def migrate_documents_from_vespa_to_opensearch_task(
        )

    task_logger.info(
        f"Finished a migration batch from Vespa to OpenSearch. Migrated {num_chunks_migrated} chunks "
        f"from {num_documents_migrated} documents in {time.monotonic() - task_start_time:.3f} seconds. "
        f"Failed to migrate {num_documents_failed} documents."
        f"OpenSearch chunk migration task pausing (time limit reached). "
        f"Total chunks migrated this task: {total_chunks_migrated_this_task}. "
        f"Total chunks errored this task: {total_chunks_errored_this_task}. "
        f"Elapsed: {time.monotonic() - task_start_time:.3f}s. "
        "Will resume from continuation token on next invocation."
    )

    return True
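A simplified shape of the resume-from-continuation-token loop above, with a hypothetical fetch_page(token) standing in for the Vespa Visit API call and a plain dict standing in for the tenant migration record; the page data is made up.

_PAGES: list[tuple[list[dict], str | None]] = [
    ([{"id": "c1"}, {"id": "c2"}], "tok-1"),
    ([{"id": "c3"}], None),
]


def fetch_page(token: str | None) -> tuple[list[dict], str | None]:
    # Placeholder for get_all_raw_document_chunks_paginated(continuation_token=...)
    index = 0 if token is None else int(token.split("-")[1])
    return _PAGES[index]


state = {"continuation_token": None, "chunks_migrated": 0}
while True:
    token = state["continuation_token"]
    if token is None and state["chunks_migrated"] > 0:
        break  # migration already completed on a previous invocation
    chunks, next_token = fetch_page(token)
    # ... transform and index `chunks` into OpenSearch here ...
    state["chunks_migrated"] += len(chunks)
    state["continuation_token"] = next_token
    if next_token is None and not chunks:
        break
print(state)  # {'continuation_token': None, 'chunks_migrated': 3}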
@@ -1,3 +1,4 @@
import traceback
from datetime import datetime
from datetime import timezone
from typing import Any
@@ -140,9 +141,7 @@ def _transform_vespa_acl_to_opensearch_acl(
    vespa_acl: dict[str, int] | None,
) -> tuple[bool, list[str]]:
    if not vespa_acl:
        raise ValueError(
            "Missing ACL in Vespa chunk. This does not make sense as it implies the document is never searchable by anyone ever."
        )
        return False, []
    acl_list = list(vespa_acl.keys())
    is_public = PUBLIC_DOC_PAT in acl_list
    if is_public:
@@ -153,133 +152,163 @@ def _transform_vespa_acl_to_opensearch_acl(
def transform_vespa_chunks_to_opensearch_chunks(
    vespa_chunks: list[dict[str, Any]],
    tenant_state: TenantState,
    document_id: str,
) -> list[DocumentChunk]:
    sanitized_to_original_doc_id_mapping: dict[str, str],
) -> tuple[list[DocumentChunk], list[dict[str, Any]]]:
    result: list[DocumentChunk] = []
    errored_chunks: list[dict[str, Any]] = []
    for vespa_chunk in vespa_chunks:
        # This should exist; fail loudly if it does not.
        vespa_document_id: str = vespa_chunk[DOCUMENT_ID]
        if not vespa_document_id:
            raise ValueError("Missing document_id in Vespa chunk.")
        # Vespa doc IDs were sanitized using replace_invalid_doc_id_characters.
        # This was a poor design choice and we don't want this in OpenSearch;
        # whatever restrictions there may be on indexed chunk ID should have no
        # bearing on the chunk's document ID field, even if document ID is an
        # argument to the chunk ID. Deliberately choose to use the real doc ID
        # supplied to this function.
        if vespa_document_id != document_id:
            logger.warning(
                f"Vespa document ID {vespa_document_id} does not match the document ID supplied {document_id}. "
                "The Vespa ID will be discarded."
        try:
            # This should exist; fail loudly if it does not.
            vespa_document_id: str = vespa_chunk[DOCUMENT_ID]
            if not vespa_document_id:
                raise ValueError("Missing document_id in Vespa chunk.")
            # Vespa doc IDs were sanitized using
            # replace_invalid_doc_id_characters. This was a poor design choice
            # and we don't want this in OpenSearch; whatever restrictions there
            # may be on indexed chunk ID should have no bearing on the chunk's
            # document ID field, even if document ID is an argument to the chunk
            # ID. Deliberately choose to use the real doc ID supplied to this
            # function.
            if vespa_document_id in sanitized_to_original_doc_id_mapping:
                logger.warning(
                    f"Migration warning: Vespa document ID {vespa_document_id} does not match the document ID supplied "
                    f"{sanitized_to_original_doc_id_mapping[vespa_document_id]}. "
                    "The Vespa ID will be discarded."
                )
            document_id = sanitized_to_original_doc_id_mapping.get(
                vespa_document_id, vespa_document_id
            )

        # This should exist; fail loudly if it does not.
        chunk_index: int = vespa_chunk[CHUNK_ID]
            # This should exist; fail loudly if it does not.
            chunk_index: int = vespa_chunk[CHUNK_ID]

            title: str | None = vespa_chunk.get(TITLE)
            # WARNING: Should supply format.tensors=short-value to the Vespa client
            # in order to get a supported format for the tensors.
            title_vector: list[float] | None = _extract_title_vector(
                vespa_chunk.get(TITLE_EMBEDDING)
            )

            # This should exist; fail loudly if it does not.
            content: str = vespa_chunk[CONTENT]
            if not content:
                raise ValueError("Missing content in Vespa chunk.")
            # This should exist; fail loudly if it does not.
            # WARNING: Should supply format.tensors=short-value to the Vespa client
            # in order to get a supported format for the tensors.
            content_vector: list[float] = _extract_content_vector(vespa_chunk[EMBEDDINGS])
|
||||
if not content_vector:
|
||||
raise ValueError("Missing content_vector in Vespa chunk.")
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
source_type: str = vespa_chunk[SOURCE_TYPE]
|
||||
if not source_type:
|
||||
raise ValueError("Missing source_type in Vespa chunk.")
|
||||
|
||||
metadata_list: list[str] | None = vespa_chunk.get(METADATA_LIST)
|
||||
|
||||
_raw_doc_updated_at: int | None = vespa_chunk.get(DOC_UPDATED_AT)
|
||||
last_updated: datetime | None = (
|
||||
datetime.fromtimestamp(_raw_doc_updated_at, tz=timezone.utc)
|
||||
if _raw_doc_updated_at is not None
|
||||
else None
|
||||
)
|
||||
|
||||
hidden: bool = vespa_chunk.get(HIDDEN, False)
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
global_boost: int = vespa_chunk[BOOST]
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
semantic_identifier: str = vespa_chunk[SEMANTIC_IDENTIFIER]
|
||||
if not semantic_identifier:
|
||||
raise ValueError("Missing semantic_identifier in Vespa chunk.")
|
||||
|
||||
image_file_id: str | None = vespa_chunk.get(IMAGE_FILE_NAME)
|
||||
source_links: str | None = vespa_chunk.get(SOURCE_LINKS)
|
||||
blurb: str = vespa_chunk.get(BLURB, "")
|
||||
doc_summary: str = vespa_chunk.get(DOC_SUMMARY, "")
|
||||
chunk_context: str = vespa_chunk.get(CHUNK_CONTEXT, "")
|
||||
metadata_suffix: str | None = vespa_chunk.get(METADATA_SUFFIX)
|
||||
document_sets: list[str] | None = (
|
||||
_transform_vespa_document_sets_to_opensearch_document_sets(
|
||||
vespa_chunk.get(DOCUMENT_SETS)
|
||||
title: str | None = vespa_chunk.get(TITLE)
|
||||
# WARNING: Should supply format.tensors=short-value to the Vespa
|
||||
# client in order to get a supported format for the tensors.
|
||||
title_vector: list[float] | None = _extract_title_vector(
|
||||
vespa_chunk.get(TITLE_EMBEDDING)
|
||||
)
|
||||
)
|
||||
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
|
||||
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
|
||||
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
|
||||
|
||||
# This should exist; fail loudly if it does not; this function will
|
||||
# raise in that event.
|
||||
is_public, acl_list = _transform_vespa_acl_to_opensearch_acl(
|
||||
vespa_chunk.get(ACCESS_CONTROL_LIST)
|
||||
)
|
||||
|
||||
chunk_tenant_id: str | None = vespa_chunk.get(TENANT_ID)
|
||||
if MULTI_TENANT:
|
||||
if not chunk_tenant_id:
|
||||
# This should exist; fail loudly if it does not.
|
||||
content: str = vespa_chunk[CONTENT]
|
||||
if not content:
|
||||
raise ValueError(
|
||||
"Missing tenant_id in Vespa chunk in a multi-tenant environment."
|
||||
f"Missing content in Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index}."
|
||||
)
|
||||
if chunk_tenant_id != tenant_state.tenant_id:
|
||||
# This should exist; fail loudly if it does not.
|
||||
# WARNING: Should supply format.tensors=short-value to the Vespa
|
||||
# client in order to get a supported format for the tensors.
|
||||
content_vector: list[float] = _extract_content_vector(
|
||||
vespa_chunk[EMBEDDINGS]
|
||||
)
|
||||
if not content_vector:
|
||||
raise ValueError(
|
||||
f"Chunk tenant_id {chunk_tenant_id} does not match expected tenant_id {tenant_state.tenant_id}"
|
||||
f"Missing content_vector in Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index}."
|
||||
)
|
||||
|
||||
opensearch_chunk = DocumentChunk(
|
||||
# We deliberately choose to use the doc ID supplied to this function
|
||||
# over the Vespa doc ID.
|
||||
document_id=document_id,
|
||||
chunk_index=chunk_index,
|
||||
title=title,
|
||||
title_vector=title_vector,
|
||||
content=content,
|
||||
content_vector=content_vector,
|
||||
source_type=source_type,
|
||||
metadata_list=metadata_list,
|
||||
last_updated=last_updated,
|
||||
public=is_public,
|
||||
access_control_list=acl_list,
|
||||
hidden=hidden,
|
||||
global_boost=global_boost,
|
||||
semantic_identifier=semantic_identifier,
|
||||
image_file_id=image_file_id,
|
||||
source_links=source_links,
|
||||
blurb=blurb,
|
||||
doc_summary=doc_summary,
|
||||
chunk_context=chunk_context,
|
||||
metadata_suffix=metadata_suffix,
|
||||
document_sets=document_sets,
|
||||
user_projects=user_projects,
|
||||
primary_owners=primary_owners,
|
||||
secondary_owners=secondary_owners,
|
||||
tenant_id=tenant_state,
|
||||
)
|
||||
# This should exist; fail loudly if it does not.
|
||||
source_type: str = vespa_chunk[SOURCE_TYPE]
|
||||
if not source_type:
|
||||
raise ValueError(
|
||||
f"Missing source_type in Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index}."
|
||||
)
|
||||
|
||||
result.append(opensearch_chunk)
|
||||
metadata_list: list[str] | None = vespa_chunk.get(METADATA_LIST)
|
||||
|
||||
return result
|
||||
_raw_doc_updated_at: int | None = vespa_chunk.get(DOC_UPDATED_AT)
|
||||
last_updated: datetime | None = (
|
||||
datetime.fromtimestamp(_raw_doc_updated_at, tz=timezone.utc)
|
||||
if _raw_doc_updated_at is not None
|
||||
else None
|
||||
)
|
||||
|
||||
hidden: bool = vespa_chunk.get(HIDDEN, False)
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
global_boost: int = vespa_chunk[BOOST]
|
||||
|
||||
# This should exist; fail loudly if it does not.
|
||||
semantic_identifier: str = vespa_chunk[SEMANTIC_IDENTIFIER]
|
||||
if not semantic_identifier:
|
||||
raise ValueError(
|
||||
f"Missing semantic_identifier in Vespa chunk with document ID {vespa_document_id} and chunk "
|
||||
f"index {chunk_index}."
|
||||
)
|
||||
|
||||
image_file_id: str | None = vespa_chunk.get(IMAGE_FILE_NAME)
|
||||
source_links: str | None = vespa_chunk.get(SOURCE_LINKS)
|
||||
blurb: str = vespa_chunk.get(BLURB, "")
|
||||
doc_summary: str = vespa_chunk.get(DOC_SUMMARY, "")
|
||||
chunk_context: str = vespa_chunk.get(CHUNK_CONTEXT, "")
|
||||
metadata_suffix: str | None = vespa_chunk.get(METADATA_SUFFIX)
|
||||
document_sets: list[str] | None = (
|
||||
_transform_vespa_document_sets_to_opensearch_document_sets(
|
||||
vespa_chunk.get(DOCUMENT_SETS)
|
||||
)
|
||||
)
|
||||
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
|
||||
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
|
||||
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
|
||||
|
||||
is_public, acl_list = _transform_vespa_acl_to_opensearch_acl(
|
||||
vespa_chunk.get(ACCESS_CONTROL_LIST)
|
||||
)
|
||||
if not is_public and not acl_list:
|
||||
logger.warning(
|
||||
f"Migration warning: Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index} has no "
|
||||
"public ACL and no access control list. This does not make sense as it implies the document is never "
|
||||
"searchable. Continuing with the migration..."
|
||||
)
|
||||
|
||||
chunk_tenant_id: str | None = vespa_chunk.get(TENANT_ID)
|
||||
if MULTI_TENANT:
|
||||
if not chunk_tenant_id:
|
||||
raise ValueError(
|
||||
"Missing tenant_id in Vespa chunk in a multi-tenant environment."
|
||||
)
|
||||
if chunk_tenant_id != tenant_state.tenant_id:
|
||||
raise ValueError(
|
||||
f"Chunk tenant_id {chunk_tenant_id} does not match expected tenant_id {tenant_state.tenant_id}"
|
||||
)
|
||||
|
||||
opensearch_chunk = DocumentChunk(
|
||||
# We deliberately choose to use the doc ID supplied to this function
|
||||
# over the Vespa doc ID.
|
||||
document_id=document_id,
|
||||
chunk_index=chunk_index,
|
||||
title=title,
|
||||
title_vector=title_vector,
|
||||
content=content,
|
||||
content_vector=content_vector,
|
||||
source_type=source_type,
|
||||
metadata_list=metadata_list,
|
||||
last_updated=last_updated,
|
||||
public=is_public,
|
||||
access_control_list=acl_list,
|
||||
hidden=hidden,
|
||||
global_boost=global_boost,
|
||||
semantic_identifier=semantic_identifier,
|
||||
image_file_id=image_file_id,
|
||||
source_links=source_links,
|
||||
blurb=blurb,
|
||||
doc_summary=doc_summary,
|
||||
chunk_context=chunk_context,
|
||||
metadata_suffix=metadata_suffix,
|
||||
document_sets=document_sets,
|
||||
user_projects=user_projects,
|
||||
primary_owners=primary_owners,
|
||||
secondary_owners=secondary_owners,
|
||||
tenant_id=tenant_state,
|
||||
)
|
||||
|
||||
result.append(opensearch_chunk)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
logger.exception(
|
||||
f"Migration error: Error transforming Vespa chunk with document ID {vespa_chunk.get(DOCUMENT_ID)} "
|
||||
f"and chunk index {vespa_chunk.get(CHUNK_ID)} into an OpenSearch chunk. Continuing with "
|
||||
"the migration..."
|
||||
)
|
||||
errored_chunks.append(vespa_chunk)
|
||||
|
||||
return result, errored_chunks
|
||||
|
||||
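The transform above isolates each chunk in its own try/except so a single malformed chunk is logged and collected rather than aborting the whole batch. A generic sketch of that error-isolation pattern, with a placeholder `transform_one` standing in for the real per-chunk conversion:

```python
import logging
from typing import Any, Callable, TypeVar

logger = logging.getLogger(__name__)
T = TypeVar("T")

def transform_all(
    raw_items: list[dict[str, Any]],
    transform_one: Callable[[dict[str, Any]], T],
) -> tuple[list[T], list[dict[str, Any]]]:
    """Transform every item, collecting failures instead of raising."""
    results: list[T] = []
    errored: list[dict[str, Any]] = []
    for raw in raw_items:
        try:
            results.append(transform_one(raw))
        except Exception:
            # Log and keep going; the caller decides how to retry or report errors.
            logger.exception("Failed to transform item %r", raw.get("document_id"))
            errored.append(raw)
    return results, errored
```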
8
backend/onyx/background/celery/tasks/pruning/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Celery tasks for connector pruning."""
|
||||
|
||||
from onyx.background.celery.tasks.pruning.tasks import check_for_pruning # noqa: F401
|
||||
from onyx.background.celery.tasks.pruning.tasks import ( # noqa: F401
|
||||
connector_pruning_generator_task,
|
||||
)
|
||||
|
||||
__all__ = ["check_for_pruning", "connector_pruning_generator_task"]
|
||||
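With these re-exports, callers can import the pruning tasks from the package root instead of the `tasks` submodule; for example:

```python
# Both imports resolve to the same task object.
from onyx.background.celery.tasks.pruning import check_for_pruning
from onyx.background.celery.tasks.pruning.tasks import check_for_pruning as check_for_pruning_direct

assert check_for_pruning is check_for_pruning_direct
```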
@@ -523,6 +523,7 @@ def connector_pruning_generator_task(
|
||||
redis_connector,
|
||||
lock,
|
||||
r,
|
||||
timeout_seconds=JOB_TIMEOUT,
|
||||
)
|
||||
|
||||
# a list of docs in the source
|
||||
|
||||
@@ -10,10 +10,12 @@ from celery import Task
|
||||
from redis.lock import Lock as RedisLock
|
||||
from retry import retry
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
@@ -37,6 +39,7 @@ from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.utils import store_user_file_plaintext
|
||||
from onyx.file_store.utils import user_file_id_to_plaintext_file_name
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
|
||||
@@ -163,6 +166,132 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def _process_user_file_without_vector_db(
|
||||
uf: UserFile,
|
||||
documents: list[Document],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Process a user file when the vector DB is disabled.
|
||||
|
||||
Extracts raw text and computes a token count, stores the plaintext in
|
||||
the file store, and marks the file as COMPLETED. Skips embedding and
|
||||
the indexing pipeline entirely.
|
||||
"""
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.llm.factory import get_llm_tokenizer_encode_func
|
||||
|
||||
# Combine section text from all document sections
|
||||
combined_text = " ".join(
|
||||
section.text for doc in documents for section in doc.sections if section.text
|
||||
)
|
||||
|
||||
# Compute token count using the user's default LLM tokenizer
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
encode = get_llm_tokenizer_encode_func(llm)
|
||||
token_count: int | None = len(encode(combined_text))
|
||||
except Exception:
|
||||
task_logger.warning(
|
||||
f"_process_user_file_without_vector_db - "
|
||||
f"Failed to compute token count for {uf.id}, falling back to None"
|
||||
)
|
||||
token_count = None
|
||||
|
||||
# Persist plaintext for fast FileReaderTool loads
|
||||
store_user_file_plaintext(
|
||||
user_file_id=uf.id,
|
||||
plaintext_content=combined_text,
|
||||
)
|
||||
|
||||
# Update the DB record
|
||||
if uf.status != UserFileStatus.DELETING:
|
||||
uf.status = UserFileStatus.COMPLETED
|
||||
uf.token_count = token_count
|
||||
uf.chunk_count = 0 # no chunks without vector DB
|
||||
uf.last_project_sync_at = datetime.datetime.now(datetime.timezone.utc)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
|
||||
task_logger.info(
|
||||
f"_process_user_file_without_vector_db - "
|
||||
f"Completed id={uf.id} tokens={token_count}"
|
||||
)
|
||||
|
||||
|
||||
def _process_user_file_with_indexing(
|
||||
uf: UserFile,
|
||||
user_file_id: str,
|
||||
documents: list[Document],
|
||||
tenant_id: str,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Process a user file through the full indexing pipeline (vector DB path)."""
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
search_settings_list = get_active_search_settings_list(db_session)
|
||||
current_search_settings = next(
|
||||
(ss for ss in search_settings_list if ss.status.is_current()),
|
||||
None,
|
||||
)
|
||||
if current_search_settings is None:
|
||||
raise RuntimeError(
|
||||
f"_process_user_file_with_indexing - "
|
||||
f"No current search settings found for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=tenant_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=current_search_settings,
|
||||
)
|
||||
|
||||
document_indices = get_all_document_indices(
|
||||
current_search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
document_indices=document_indices,
|
||||
ignore_time_skip=True,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
document_batch=documents,
|
||||
request_id=None,
|
||||
adapter=adapter,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"_process_user_file_with_indexing - "
|
||||
f"Indexing pipeline completed ={index_pipeline_result}"
|
||||
)
|
||||
|
||||
if (
|
||||
index_pipeline_result.failures
|
||||
or index_pipeline_result.total_docs != len(documents)
|
||||
or index_pipeline_result.total_chunks == 0
|
||||
):
|
||||
task_logger.error(
|
||||
f"_process_user_file_with_indexing - "
|
||||
f"Indexing pipeline failed id={user_file_id}"
|
||||
)
|
||||
if uf.status != UserFileStatus.DELETING:
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
raise RuntimeError(f"Indexing pipeline failed for user file {user_file_id}")
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
@@ -205,97 +334,34 @@ def process_single_user_file(
|
||||
connector = LocalFileConnector(
|
||||
file_locations=[uf.file_id],
|
||||
file_names=[uf.name] if uf.name else None,
|
||||
zip_metadata={},
|
||||
)
|
||||
connector.load_credentials({})
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
search_settings_list = get_active_search_settings_list(db_session)
|
||||
|
||||
current_search_settings = next(
|
||||
(
|
||||
search_settings_instance
|
||||
for search_settings_instance in search_settings_list
|
||||
if search_settings_instance.status.is_current()
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if current_search_settings is None:
|
||||
raise RuntimeError(
|
||||
f"process_single_user_file - No current search settings found for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
for batch in connector.load_from_state():
|
||||
documents.extend(
|
||||
[doc for doc in batch if not isinstance(doc, HierarchyNode)]
|
||||
)
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=tenant_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Set up indexing pipeline components
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=current_search_settings,
|
||||
)
|
||||
|
||||
# This flow is for indexing so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
current_search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
# update the doument id to userfile id in the documents
|
||||
# update the document id to userfile id in the documents
|
||||
for document in documents:
|
||||
document.id = str(user_file_id)
|
||||
document.source = DocumentSource.USER_FILE
|
||||
|
||||
# real work happens here!
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
document_indices=document_indices,
|
||||
ignore_time_skip=True,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
document_batch=documents,
|
||||
request_id=None,
|
||||
adapter=adapter,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Indexing pipeline completed ={index_pipeline_result}"
|
||||
)
|
||||
|
||||
if (
|
||||
index_pipeline_result.failures
|
||||
or index_pipeline_result.total_docs != len(documents)
|
||||
or index_pipeline_result.total_chunks == 0
|
||||
):
|
||||
task_logger.error(
|
||||
f"process_single_user_file - Indexing pipeline failed id={user_file_id}"
|
||||
if DISABLE_VECTOR_DB:
|
||||
_process_user_file_without_vector_db(
|
||||
uf=uf,
|
||||
documents=documents,
|
||||
db_session=db_session,
|
||||
)
|
||||
else:
|
||||
_process_user_file_with_indexing(
|
||||
uf=uf,
|
||||
user_file_id=user_file_id,
|
||||
documents=documents,
|
||||
tenant_id=tenant_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
# don't update the status if the user file is being deleted
|
||||
# Re-fetch to avoid mypy error
|
||||
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if (
|
||||
current_user_file
|
||||
and current_user_file.status != UserFileStatus.DELETING
|
||||
):
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
@@ -409,28 +475,6 @@ def process_single_user_file_delete(
|
||||
return None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
# This flow is for deletion so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
index_name = active_search_settings.primary.index_name
|
||||
selection = f"{index_name}.document_id=='{user_file_id}'"
|
||||
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
@@ -438,22 +482,43 @@ def process_single_user_file_delete(
|
||||
)
|
||||
return None
|
||||
|
||||
# 1) Delete Vespa chunks for the document
|
||||
chunk_count = 0
|
||||
if user_file.chunk_count is None or user_file.chunk_count == 0:
|
||||
chunk_count = _get_document_chunk_count(
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
)
|
||||
else:
|
||||
chunk_count = user_file.chunk_count
|
||||
# 1) Delete vector DB chunks (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.delete_single(
|
||||
doc_id=user_file_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
index_name = active_search_settings.primary.index_name
|
||||
selection = f"{index_name}.document_id=='{user_file_id}'"
|
||||
|
||||
chunk_count = 0
|
||||
if user_file.chunk_count is None or user_file.chunk_count == 0:
|
||||
chunk_count = _get_document_chunk_count(
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
)
|
||||
else:
|
||||
chunk_count = user_file.chunk_count
|
||||
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.delete_single(
|
||||
doc_id=user_file_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
|
||||
file_store = get_default_file_store()
|
||||
@@ -565,27 +630,6 @@ def process_single_user_file_project_sync(
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
# This flow is for updates so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
@@ -593,15 +637,35 @@ def process_single_user_file_project_sync(
|
||||
)
|
||||
return None
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
# Sync project metadata to vector DB (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file id={user_file_id}"
|
||||
|
||||
@@ -677,7 +677,6 @@ def connector_document_extraction(
|
||||
logger.debug(f"Indexing batch of documents: {batch_description}")
|
||||
memory_tracer.increment_and_maybe_trace()
|
||||
|
||||
# cc4a
|
||||
if processing_mode == ProcessingMode.FILE_SYSTEM:
|
||||
# File system only - write directly to persistent storage,
|
||||
# skip chunking/embedding/Vespa but still track documents in DB
|
||||
@@ -817,17 +816,19 @@ def connector_document_extraction(
|
||||
if processing_mode == ProcessingMode.FILE_SYSTEM:
|
||||
creator_id = index_attempt.connector_credential_pair.creator_id
|
||||
if creator_id:
|
||||
source_value = db_connector.source.value
|
||||
app.send_task(
|
||||
OnyxCeleryTask.SANDBOX_FILE_SYNC,
|
||||
kwargs={
|
||||
"user_id": str(creator_id),
|
||||
"tenant_id": tenant_id,
|
||||
"source": source_value,
|
||||
},
|
||||
queue=OnyxCeleryQueues.SANDBOX,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered sandbox file sync for user {creator_id} "
|
||||
f"after indexing complete"
|
||||
f"source={source_value} after indexing complete"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -9,10 +9,8 @@ Summaries are stored as `ChatMessage` records with two key fields:
|
||||
- `parent_message_id` → last message when compression triggered (places summary in the tree)
|
||||
- `last_summarized_message_id` → pointer to an older message up the chain (the cutoff). Messages after this are kept verbatim.
|
||||
|
||||
**Why store summary as a separate message?** If we embedded the summary in the `last_summarized_message_id` message itself, that message would contain context from messages that came after it—context that doesn't exist in other branches. By creating the summary as a new message attached to the branch tip, it only applies to the specific branch where compression occurred.
|
||||
|
||||
### Timestamp-Based Ordering
|
||||
Messages are filtered by `time_sent` (not ID) so the logic remains intact if IDs are changed to UUIDs in the future.
|
||||
**Why store summary as a separate message?** If we embedded the summary in the `last_summarized_message_id` message itself, that message would contain context from messages that came after it—context that doesn't exist in other branches. By creating the summary as a new message attached to the branch tip, it only applies to the specific branch where compression occurred, and it is back-pointed to only by that branch. All of this is necessary because we keep the last few messages verbatim and also to support branching logic.
|
||||
|
||||
### Progressive Summarization
|
||||
Subsequent compressions incorporate the existing summary text + new messages, preventing information loss in very long conversations.
|
||||
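A rough illustration of the two pointers described above, using a simplified stand-in for the real `ChatMessage` ORM model (only the fields discussed here are shown):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SummaryMessage:
    # Simplified stand-in for ChatMessage; not the actual ORM model.
    id: int
    message: str                                # the generated summary text
    parent_message_id: Optional[int]            # branch tip when compression ran
    last_summarized_message_id: Optional[int]   # cutoff: everything at/before this is summarized

# Example: messages 1..10 exist on a branch; messages 1-7 get compressed.
summary = SummaryMessage(
    id=11,
    message="<summary of messages 1-7>",
    parent_message_id=10,           # attaches the summary to this branch only
    last_summarized_message_id=7,   # messages 8-10 stay verbatim
)
```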
@@ -26,10 +24,11 @@ Context window breakdown:
|
||||
- `max_context_tokens` — LLM's total context window
|
||||
- `reserved_tokens` — space for system prompt, tools, files, etc.
|
||||
- Available for chat history = `max_context_tokens - reserved_tokens`
|
||||
Note: If there are a lot of reserved tokens, chat compression may happen fairly frequently, which is costly, slow, and leads to a bad user experience. This is a possible area of future improvement.
|
||||
|
||||
Configurable ratios:
|
||||
- `COMPRESSION_TRIGGER_RATIO` (default 0.75) — compress when chat history exceeds this ratio of available space
|
||||
- `RECENT_MESSAGES_RATIO` (default 0.25) — portion of chat history to keep verbatim when compressing
|
||||
- `RECENT_MESSAGES_RATIO` (default 0.2) — portion of chat history to keep verbatim when compressing
|
||||
|
||||
## Flow
|
||||
|
||||
|
||||
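To make the ratios concrete, a small sketch of how the trigger threshold and the verbatim budget could be derived from the numbers above (the helper and the example figures are illustrative, not the actual implementation):

```python
COMPRESSION_TRIGGER_RATIO = 0.75
RECENT_MESSAGES_RATIO = 0.2

def compression_budgets(max_context_tokens: int, reserved_tokens: int) -> tuple[int, int]:
    """Return (trigger_threshold, tokens_for_recent) for a given context window."""
    available = max_context_tokens - reserved_tokens
    trigger_threshold = int(available * COMPRESSION_TRIGGER_RATIO)
    tokens_for_recent = int(available * RECENT_MESSAGES_RATIO)
    return trigger_threshold, tokens_for_recent

# e.g. a 128k-token model with 28k reserved tokens: compress once history
# exceeds 75,000 tokens, keeping roughly the most recent 20,000 tokens verbatim.
print(compression_budgets(128_000, 28_000))  # (75000, 20000)
```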
@@ -3,32 +3,26 @@ from collections.abc import Callable
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.datastructures import Headers
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import is_user_admin
|
||||
from onyx.chat.models import ChatHistoryResult
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.kg_config import is_kg_config_settings_enabled_valid
|
||||
from onyx.db.llm import fetch_existing_doc_sets
|
||||
from onyx.db.llm import fetch_existing_tools
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import check_project_ownership
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
@@ -45,9 +39,6 @@ from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
@@ -276,70 +267,6 @@ def extract_headers(
|
||||
return extracted_headers
|
||||
|
||||
|
||||
def create_temporary_persona(
|
||||
persona_config: PersonaOverrideConfig, db_session: Session, user: User
|
||||
) -> Persona:
|
||||
if not is_user_admin(user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="User is not authorized to create a persona in one shot queries",
|
||||
)
|
||||
|
||||
"""Create a temporary Persona object from the provided configuration."""
|
||||
persona = Persona(
|
||||
name=persona_config.name,
|
||||
description=persona_config.description,
|
||||
num_chunks=persona_config.num_chunks,
|
||||
llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
recency_bias=RecencyBiasSetting.BASE_DECAY,
|
||||
llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
llm_model_version_override=persona_config.llm_model_version_override,
|
||||
)
|
||||
|
||||
if persona_config.prompts:
|
||||
# Use the first prompt from the override config for embedded prompt fields
|
||||
first_prompt = persona_config.prompts[0]
|
||||
persona.system_prompt = first_prompt.system_prompt
|
||||
persona.task_prompt = first_prompt.task_prompt
|
||||
persona.datetime_aware = first_prompt.datetime_aware
|
||||
|
||||
persona.tools = []
|
||||
if persona_config.custom_tools_openapi:
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
|
||||
for schema in persona_config.custom_tools_openapi:
|
||||
tools = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
tool_id=0, # dummy tool id
|
||||
openapi_schema=schema,
|
||||
emitter=get_default_emitter(),
|
||||
),
|
||||
)
|
||||
persona.tools.extend(tools)
|
||||
|
||||
if persona_config.tools:
|
||||
tool_ids = [tool.id for tool in persona_config.tools]
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
|
||||
)
|
||||
|
||||
if persona_config.tool_ids:
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(
|
||||
db_session=db_session, tool_ids=persona_config.tool_ids
|
||||
)
|
||||
)
|
||||
|
||||
fetched_docs = fetch_existing_doc_sets(
|
||||
db_session=db_session, doc_ids=persona_config.document_set_ids
|
||||
)
|
||||
persona.document_sets = fetched_docs
|
||||
|
||||
return persona
|
||||
|
||||
|
||||
def process_kg_commands(
|
||||
message: str, persona_name: str, tenant_id: str, db_session: Session # noqa: ARG001
|
||||
) -> None:
|
||||
@@ -502,15 +429,22 @@ def convert_chat_history(
|
||||
additional_context: str | None,
|
||||
token_counter: Callable[[str], int],
|
||||
tool_id_to_name_map: dict[int, str],
|
||||
) -> list[ChatMessageSimple]:
|
||||
) -> ChatHistoryResult:
|
||||
"""Convert ChatMessage history to ChatMessageSimple format.
|
||||
|
||||
For user messages: includes attached files (images attached to message, text files as separate messages)
|
||||
For assistant messages with tool calls: creates ONE ASSISTANT message with tool_calls array,
|
||||
followed by N TOOL_CALL_RESPONSE messages (OpenAI parallel tool calling format)
|
||||
For assistant messages without tool calls: creates a simple ASSISTANT message
|
||||
|
||||
Every injected text-file message is tagged with ``file_id`` and its
|
||||
metadata is collected in ``ChatHistoryResult.all_injected_file_metadata``.
|
||||
After context-window truncation, callers compare surviving ``file_id`` tags
|
||||
against this map to discover "forgotten" files and provide their metadata
|
||||
to the FileReaderTool.
|
||||
"""
|
||||
simple_messages: list[ChatMessageSimple] = []
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata] = {}
|
||||
|
||||
# Create a mapping of file IDs to loaded files for quick lookup
|
||||
file_map = {str(f.file_id): f for f in files}
|
||||
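The forgotten-file detection described in this docstring amounts to a set difference between the injected `file_id` tags and those that survive truncation. A hedged sketch, with `find_forgotten_files` as a hypothetical helper rather than the real caller logic:

```python
from typing import Any

def find_forgotten_files(
    all_injected_file_metadata: dict[str, Any],   # file_id -> FileToolMetadata
    surviving_messages: list[Any],                # ChatMessageSimple list after truncation
) -> list[Any]:
    """Return metadata for injected files whose tagged messages were truncated away."""
    surviving_ids = {
        getattr(msg, "file_id", None) for msg in surviving_messages
    } - {None}
    return [
        meta
        for file_id, meta in all_injected_file_metadata.items()
        if file_id not in surviving_ids
    ]
```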
@@ -539,7 +473,9 @@ def convert_chat_history(
|
||||
# Text files (DOC, PLAIN_TEXT, CSV) are added as separate messages
|
||||
text_files.append(loaded_file)
|
||||
|
||||
# Add text files as separate messages before the user message
|
||||
# Add text files as separate messages before the user message.
|
||||
# Each message is tagged with ``file_id`` so that forgotten files
|
||||
# can be detected after context-window truncation.
|
||||
for text_file in text_files:
|
||||
file_text = text_file.content_text or ""
|
||||
filename = text_file.filename
|
||||
@@ -554,8 +490,14 @@ def convert_chat_history(
|
||||
token_count=text_file.token_count,
|
||||
message_type=MessageType.USER,
|
||||
image_files=None,
|
||||
file_id=text_file.file_id,
|
||||
)
|
||||
)
|
||||
all_injected_file_metadata[text_file.file_id] = FileToolMetadata(
|
||||
file_id=text_file.file_id,
|
||||
filename=filename or "unknown",
|
||||
approx_char_count=len(file_text),
|
||||
)
|
||||
|
||||
# Sum token counts from image files (excluding project image files)
|
||||
image_token_count = (
|
||||
@@ -664,32 +606,41 @@ def convert_chat_history(
|
||||
f"Invalid message type when constructing simple history: {chat_message.message_type}"
|
||||
)
|
||||
|
||||
return simple_messages
|
||||
return ChatHistoryResult(
|
||||
simple_messages=simple_messages,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
|
||||
|
||||
def get_custom_agent_prompt(persona: Persona, chat_session: ChatSession) -> str | None:
|
||||
"""Get the custom agent prompt from persona or project instructions.
|
||||
"""Get the custom agent prompt from persona or project instructions. If it's replacing the base system prompt,
|
||||
it does not count as a custom agent prompt (logic exists later also to drop it in this case).
|
||||
|
||||
Chat Sessions in Projects that are using a custom agent will retain the custom agent prompt.
|
||||
Priority: persona.system_prompt > chat_session.project.instructions > None
|
||||
Priority: persona.system_prompt (if not default Agent) > chat_session.project.instructions
|
||||
|
||||
# NOTE: Logic elsewhere allows saving empty strings for potentially other purposes but for constructing the prompts
|
||||
# we never want to return an empty string for a prompt so it's translated into an explicit None.
|
||||
|
||||
Args:
|
||||
persona: The Persona object
|
||||
chat_session: The ChatSession object
|
||||
|
||||
Returns:
|
||||
The custom agent prompt string, or None if neither persona nor project has one
|
||||
The prompt to use for the custom Agent part of the prompt.
|
||||
"""
|
||||
# Not considered a custom agent if it's the default behavior persona
|
||||
if persona.id == DEFAULT_PERSONA_ID:
|
||||
return None
|
||||
# If using a custom Agent, always respect its prompt, even if in a Project, and even if it's an empty custom prompt.
|
||||
if persona.id != DEFAULT_PERSONA_ID:
|
||||
# Logic exists later also to drop it in this case but this is strictly correct anyhow.
|
||||
if persona.replace_base_system_prompt:
|
||||
return None
|
||||
return persona.system_prompt or None
|
||||
|
||||
if persona.system_prompt:
|
||||
return persona.system_prompt
|
||||
elif chat_session.project and chat_session.project.instructions:
|
||||
# If in a project and using the default Agent, respect the project instructions.
|
||||
if chat_session.project and chat_session.project.instructions:
|
||||
return chat_session.project.instructions
|
||||
else:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def is_last_assistant_message_clarification(chat_history: list[ChatMessage]) -> bool:
|
||||
|
||||
@@ -17,20 +17,26 @@ from onyx.configs.chat_configs import COMPRESSION_TRIGGER_RATIO
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import AssistantMessage
|
||||
from onyx.llm.models import ChatCompletionMessage
|
||||
from onyx.llm.models import SystemMessage
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.prompts.compression_prompts import PROGRESSIVE_SUMMARY_PROMPT
|
||||
from onyx.prompts.compression_prompts import PROGRESSIVE_SUMMARY_SYSTEM_PROMPT_BLOCK
|
||||
from onyx.prompts.compression_prompts import PROGRESSIVE_USER_REMINDER
|
||||
from onyx.prompts.compression_prompts import SUMMARIZATION_CUTOFF_MARKER
|
||||
from onyx.prompts.compression_prompts import SUMMARIZATION_PROMPT
|
||||
from onyx.prompts.compression_prompts import USER_FINAL_REMINDER
|
||||
from onyx.prompts.compression_prompts import USER_REMINDER
|
||||
from onyx.tracing.framework.create import ensure_trace
|
||||
from onyx.tracing.llm_utils import llm_generation_span
|
||||
from onyx.tracing.llm_utils import record_llm_response
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Ratio of available context to allocate for recent messages after compression
|
||||
RECENT_MESSAGES_RATIO = 0.25
|
||||
RECENT_MESSAGES_RATIO = 0.2
|
||||
|
||||
|
||||
class CompressionResult(BaseModel):
|
||||
@@ -187,6 +193,11 @@ def get_messages_to_summarize(
|
||||
recent_messages.insert(0, msg)
|
||||
tokens_used += msg_tokens
|
||||
|
||||
# Ensure cutoff is right before a user message by moving any leading
|
||||
# non-user messages from recent_messages to older_messages
|
||||
while recent_messages and recent_messages[0].message_type != MessageType.USER:
|
||||
recent_messages.pop(0)
|
||||
|
||||
# Everything else gets summarized
|
||||
recent_ids = {m.id for m in recent_messages}
|
||||
older_messages = [m for m in messages if m.id not in recent_ids]
|
||||
@@ -196,31 +207,47 @@ def get_messages_to_summarize(
|
||||
)
|
||||
|
||||
|
||||
def format_messages_for_summary(
|
||||
def _build_llm_messages_for_summarization(
|
||||
messages: list[ChatMessage],
|
||||
tool_id_to_name: dict[int, str],
|
||||
) -> str:
|
||||
"""Format messages into a string for the summarization prompt.
|
||||
) -> list[UserMessage | AssistantMessage]:
|
||||
"""Convert ChatMessage objects to LLM message format for summarization.
|
||||
|
||||
Tool call messages are formatted compactly to save tokens.
|
||||
This is intentionally different from translate_history_to_llm_format in llm_step.py:
|
||||
- Compacts tool calls to "[Used tools: tool1, tool2]" to save tokens in summaries
|
||||
- Skips TOOL_CALL_RESPONSE messages entirely (tool usage captured in assistant message)
|
||||
- No image/multimodal handling (summaries are text-only)
|
||||
- No caching or LLMConfig-specific behavior needed
|
||||
"""
|
||||
formatted = []
|
||||
result: list[UserMessage | AssistantMessage] = []
|
||||
|
||||
for msg in messages:
|
||||
# Format assistant messages with tool calls compactly
|
||||
if msg.message_type == MessageType.ASSISTANT and msg.tool_calls:
|
||||
tool_names = [
|
||||
tool_id_to_name.get(tc.tool_id, "unknown") for tc in msg.tool_calls
|
||||
]
|
||||
formatted.append(f"[assistant used tools: {', '.join(tool_names)}]")
|
||||
# Skip empty messages
|
||||
if not msg.message:
|
||||
continue
|
||||
|
||||
# Handle assistant messages with tool calls compactly
|
||||
if msg.message_type == MessageType.ASSISTANT:
|
||||
if msg.tool_calls:
|
||||
tool_names = [
|
||||
tool_id_to_name.get(tc.tool_id, "unknown") for tc in msg.tool_calls
|
||||
]
|
||||
result.append(
|
||||
AssistantMessage(content=f"[Used tools: {', '.join(tool_names)}]")
|
||||
)
|
||||
else:
|
||||
result.append(AssistantMessage(content=msg.message))
|
||||
continue
|
||||
|
||||
# Skip tool call response messages - tool calls are captured above via assistant messages
|
||||
if msg.message_type == MessageType.TOOL_CALL_RESPONSE:
|
||||
continue
|
||||
|
||||
role = msg.message_type.value
|
||||
formatted.append(f"[{role}]: {msg.message}")
|
||||
return "\n\n".join(formatted)
|
||||
# Handle user messages
|
||||
if msg.message_type == MessageType.USER:
|
||||
result.append(UserMessage(content=msg.message))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def generate_summary(
|
||||
@@ -236,6 +263,9 @@ def generate_summary(
|
||||
The cutoff marker tells the LLM to summarize only older messages,
|
||||
while using recent messages as context to inform what's important.
|
||||
|
||||
Messages are sent as separate UserMessage/AssistantMessage objects rather
|
||||
than being concatenated into a single message.
|
||||
|
||||
Args:
|
||||
older_messages: Messages to compress into summary (before cutoff)
|
||||
recent_messages: Messages kept verbatim (after cutoff, for context only)
|
||||
@@ -246,37 +276,54 @@ def generate_summary(
|
||||
Returns:
|
||||
Summary text
|
||||
"""
|
||||
older_messages_str = format_messages_for_summary(older_messages, tool_id_to_name)
|
||||
recent_messages_str = format_messages_for_summary(recent_messages, tool_id_to_name)
|
||||
|
||||
# Build user prompt with cutoff marker
|
||||
# Build system prompt
|
||||
system_content = SUMMARIZATION_PROMPT
|
||||
if existing_summary:
|
||||
# Progressive summarization: include existing summary
|
||||
user_prompt = PROGRESSIVE_SUMMARY_PROMPT.format(
|
||||
existing_summary=existing_summary
|
||||
# Progressive summarization: append existing summary to system prompt
|
||||
system_content += PROGRESSIVE_SUMMARY_SYSTEM_PROMPT_BLOCK.format(
|
||||
previous_summary=existing_summary
|
||||
)
|
||||
user_prompt += f"\n\n{older_messages_str}"
|
||||
final_reminder = PROGRESSIVE_USER_REMINDER
|
||||
else:
|
||||
# Initial summarization
|
||||
user_prompt = older_messages_str
|
||||
final_reminder = USER_FINAL_REMINDER
|
||||
final_reminder = USER_REMINDER
|
||||
|
||||
# Add cutoff marker and recent messages as context
|
||||
user_prompt += f"\n\n{SUMMARIZATION_CUTOFF_MARKER}"
|
||||
if recent_messages_str:
|
||||
user_prompt += f"\n\n{recent_messages_str}"
|
||||
# Convert messages to LLM format (using compression-specific conversion)
|
||||
older_llm_messages = _build_llm_messages_for_summarization(
|
||||
older_messages, tool_id_to_name
|
||||
)
|
||||
recent_llm_messages = _build_llm_messages_for_summarization(
|
||||
recent_messages, tool_id_to_name
|
||||
)
|
||||
|
||||
# Build message list with separate messages
|
||||
input_messages: list[ChatCompletionMessage] = [
|
||||
SystemMessage(content=system_content),
|
||||
]
|
||||
|
||||
# Add older messages (to be summarized)
|
||||
input_messages.extend(older_llm_messages)
|
||||
|
||||
# Add cutoff marker as a user message
|
||||
input_messages.append(UserMessage(content=SUMMARIZATION_CUTOFF_MARKER))
|
||||
|
||||
# Add recent messages (for context only)
|
||||
input_messages.extend(recent_llm_messages)
|
||||
|
||||
# Add final reminder
|
||||
user_prompt += f"\n\n{final_reminder}"
|
||||
input_messages.append(UserMessage(content=final_reminder))
|
||||
|
||||
response = llm.invoke(
|
||||
[
|
||||
SystemMessage(content=SUMMARIZATION_PROMPT),
|
||||
UserMessage(content=user_prompt),
|
||||
]
|
||||
)
|
||||
return response.choice.message.content or ""
|
||||
with llm_generation_span(
|
||||
llm=llm,
|
||||
flow="chat_history_summarization",
|
||||
input_messages=input_messages,
|
||||
) as span_generation:
|
||||
response = llm.invoke(input_messages)
|
||||
record_llm_response(span_generation, response)
|
||||
|
||||
content = response.choice.message.content
|
||||
if not (content and content.strip()):
|
||||
raise ValueError("LLM returned empty summary")
|
||||
return content.strip()
|
||||
|
||||
|
||||
def compress_chat_history(
|
||||
@@ -292,6 +339,19 @@ def compress_chat_history(
|
||||
The summary message's parent_message_id points to the last message in
|
||||
chat_history, making it branch-aware via the tree structure.
|
||||
|
||||
Note: This takes the entire chat history as input, splits it into older
|
||||
messages (to summarize) and recent messages (kept verbatim within the
|
||||
token budget), generates a summary of the older part, and persists the
|
||||
new summary message with its parent set to the last message in history.
|
||||
|
||||
Past summary is taken into context (progressive summarization): we find
|
||||
at most one existing summary for this branch. If present, only messages
|
||||
after that summary's last_summarized_message_id are considered; the
|
||||
existing summary text is passed into the LLM so the new summary
|
||||
incorporates it instead of summarizing from scratch.
|
||||
|
||||
For more details, see the COMPRESSION.md file.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
chat_history: Branch-aware list of messages
|
||||
@@ -305,74 +365,84 @@ def compress_chat_history(
|
||||
if not chat_history:
|
||||
return CompressionResult(summary_created=False, messages_summarized=0)
|
||||
|
||||
chat_session_id = chat_history[0].chat_session_id
|
||||
|
||||
logger.info(
|
||||
f"Starting compression for session {chat_history[0].chat_session_id}, "
|
||||
f"Starting compression for session {chat_session_id}, "
|
||||
f"history_len={len(chat_history)}, tokens_for_recent={compression_params.tokens_for_recent}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Find existing summary for this branch
|
||||
existing_summary = find_summary_for_branch(db_session, chat_history)
|
||||
with ensure_trace(
|
||||
"chat_history_compression",
|
||||
group_id=str(chat_session_id),
|
||||
metadata={
|
||||
"tenant_id": get_current_tenant_id(),
|
||||
"chat_session_id": str(chat_session_id),
|
||||
},
|
||||
):
|
||||
try:
|
||||
# Find existing summary for this branch
|
||||
existing_summary = find_summary_for_branch(db_session, chat_history)
|
||||
|
||||
# Get messages to summarize
|
||||
summary_content = get_messages_to_summarize(
|
||||
chat_history,
|
||||
existing_summary,
|
||||
tokens_for_recent=compression_params.tokens_for_recent,
|
||||
)
|
||||
# Get messages to summarize
|
||||
summary_content = get_messages_to_summarize(
|
||||
chat_history,
|
||||
existing_summary,
|
||||
tokens_for_recent=compression_params.tokens_for_recent,
|
||||
)
|
||||
|
||||
if not summary_content.older_messages:
|
||||
logger.debug("No messages to summarize, skipping compression")
|
||||
return CompressionResult(summary_created=False, messages_summarized=0)
|
||||
if not summary_content.older_messages:
|
||||
logger.debug("No messages to summarize, skipping compression")
|
||||
return CompressionResult(summary_created=False, messages_summarized=0)
|
||||
|
||||
# Generate summary (incorporate existing summary if present)
|
||||
existing_summary_text = existing_summary.message if existing_summary else None
|
||||
summary_text = generate_summary(
|
||||
older_messages=summary_content.older_messages,
|
||||
recent_messages=summary_content.recent_messages,
|
||||
llm=llm,
|
||||
tool_id_to_name=tool_id_to_name,
|
||||
existing_summary=existing_summary_text,
|
||||
)
|
||||
# Generate summary (incorporate existing summary if present)
|
||||
existing_summary_text = (
|
||||
existing_summary.message if existing_summary else None
|
||||
)
|
||||
summary_text = generate_summary(
|
||||
older_messages=summary_content.older_messages,
|
||||
recent_messages=summary_content.recent_messages,
|
||||
llm=llm,
|
||||
tool_id_to_name=tool_id_to_name,
|
||||
existing_summary=existing_summary_text,
|
||||
)
|
||||
|
||||
# Calculate token count for the summary
|
||||
tokenizer = get_tokenizer(None, None)
|
||||
summary_token_count = len(tokenizer.encode(summary_text))
|
||||
logger.debug(
|
||||
f"Generated summary ({summary_token_count} tokens): {summary_text[:200]}..."
|
||||
)
|
||||
# Calculate token count for the summary
|
||||
tokenizer = get_tokenizer(None, None)
|
||||
summary_token_count = len(tokenizer.encode(summary_text))
|
||||
logger.debug(
|
||||
f"Generated summary ({summary_token_count} tokens): {summary_text[:200]}..."
|
||||
)
|
||||
|
||||
# Create new summary as a ChatMessage
|
||||
# Parent is the last message in history - this makes the summary branch-aware
|
||||
summary_message = ChatMessage(
|
||||
chat_session_id=chat_history[0].chat_session_id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
message=summary_text,
|
||||
token_count=summary_token_count,
|
||||
parent_message_id=chat_history[-1].id,
|
||||
last_summarized_message_id=summary_content.older_messages[-1].id,
|
||||
)
|
||||
db_session.add(summary_message)
|
||||
db_session.commit()
|
||||
# Create new summary as a ChatMessage
|
||||
# Parent is the last message in history - this makes the summary branch-aware
|
||||
summary_message = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
message=summary_text,
|
||||
token_count=summary_token_count,
|
||||
parent_message_id=chat_history[-1].id,
|
||||
last_summarized_message_id=summary_content.older_messages[-1].id,
|
||||
)
|
||||
db_session.add(summary_message)
|
||||
db_session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Compressed {len(summary_content.older_messages)} messages into summary "
|
||||
f"(session_id={chat_history[0].chat_session_id}, "
|
||||
f"summary_tokens={summary_token_count})"
|
||||
)
|
||||
logger.info(
|
||||
f"Compressed {len(summary_content.older_messages)} messages into summary "
|
||||
f"(session_id={chat_session_id}, "
|
||||
f"summary_tokens={summary_token_count})"
|
||||
)
|
||||
|
||||
return CompressionResult(
|
||||
summary_created=True,
|
||||
messages_summarized=len(summary_content.older_messages),
|
||||
)
|
||||
return CompressionResult(
|
||||
summary_created=True,
|
||||
messages_summarized=len(summary_content.older_messages),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Compression failed for session {chat_history[0].chat_session_id}: {e}"
|
||||
)
|
||||
db_session.rollback()
|
||||
return CompressionResult(
|
||||
summary_created=False,
|
||||
messages_summarized=0,
|
||||
error=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Compression failed for session {chat_session_id}: {e}")
|
||||
db_session.rollback()
|
||||
return CompressionResult(
|
||||
summary_created=False,
|
||||
messages_summarized=0,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -14,6 +16,7 @@ from onyx.chat.llm_step import extract_tool_calls_from_response_text
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
@@ -27,12 +30,14 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.db.memory import add_memory
|
||||
from onyx.db.memory import update_memory_at_index
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.db.models import Persona
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.utils import model_needs_formatting_reenabled
|
||||
from onyx.prompts.chat_prompts import IMAGE_GEN_REMINDER
|
||||
from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -43,12 +48,14 @@ from onyx.server.query_and_chat.streaming_models import TopLevelBranching
|
||||
from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
|
||||
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import MemoryToolResponseSnapshot
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool_implementations.images.models import (
|
||||
FinalImageGenerationResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
@@ -60,6 +67,28 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _should_keep_bedrock_tool_definitions(
|
||||
llm: object, simple_chat_history: list[ChatMessageSimple]
|
||||
) -> bool:
|
||||
"""Bedrock requires tool config when history includes toolUse/toolResult blocks."""
|
||||
model_provider = getattr(getattr(llm, "config", None), "model_provider", None)
|
||||
if model_provider not in {
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.BEDROCK_CONVERSE,
|
||||
}:
|
||||
return False
|
||||
|
||||
return any(
|
||||
(
|
||||
msg.message_type == MessageType.ASSISTANT
|
||||
and msg.tool_calls
|
||||
and len(msg.tool_calls) > 0
|
||||
)
|
||||
or msg.message_type == MessageType.TOOL_CALL_RESPONSE
|
||||
for msg in simple_chat_history
|
||||
)
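# Illustrative sketch (not part of this diff): the same "does history contain tool
# activity?" check, shown with simplified stand-ins for ChatMessageSimple and
# MessageType so it can run on its own. All names below are assumptions for the example.
from dataclasses import dataclass, field

@dataclass
class FakeMsg:
    message_type: str                      # "assistant", "tool_call_response", "user", ...
    tool_calls: list = field(default_factory=list)

def history_has_tool_activity(history: list[FakeMsg]) -> bool:
    # Mirrors the condition above: any assistant message carrying tool calls, or any
    # tool-call response message, means Bedrock must keep the tool config in the request.
    return any(
        (m.message_type == "assistant" and len(m.tool_calls) > 0)
        or m.message_type == "tool_call_response"
        for m in history
    )

assert history_has_tool_activity([FakeMsg("assistant", tool_calls=[{"name": "search"}])])
assert not history_has_tool_activity([FakeMsg("user"), FakeMsg("assistant")])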
|
||||
|
||||
|
||||
def _try_fallback_tool_extraction(
|
||||
llm_step_result: LlmStepResult,
|
||||
tool_choice: ToolChoiceOptions,
|
||||
@@ -179,6 +208,35 @@ def _build_project_file_citation_mapping(
|
||||
return citation_mapping
|
||||
|
||||
|
||||
def _build_project_message(
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
token_counter: Callable[[str], int] | None,
|
||||
) -> list[ChatMessageSimple]:
|
||||
"""Build messages for project / tool-backed files.
|
||||
|
||||
Returns up to two messages:
|
||||
1. The full-text project files message (if project_file_texts is populated).
|
||||
2. A lightweight metadata message for files the LLM should access via the
|
||||
FileReaderTool (e.g. oversized chat-attached files or project files that
|
||||
don't fit in context).
|
||||
"""
|
||||
if not project_files:
|
||||
return []
|
||||
|
||||
messages: list[ChatMessageSimple] = []
|
||||
if project_files.project_file_texts:
|
||||
messages.append(
|
||||
_create_project_files_message(project_files, token_counter=None)
|
||||
)
|
||||
if project_files.file_metadata_for_tool and token_counter:
|
||||
messages.append(
|
||||
_create_file_tool_metadata_message(
|
||||
project_files.file_metadata_for_tool, token_counter
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
def construct_message_history(
|
||||
system_prompt: ChatMessageSimple | None,
|
||||
custom_agent_prompt: ChatMessageSimple | None,
|
||||
@@ -187,6 +245,8 @@ def construct_message_history(
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
available_tokens: int,
|
||||
last_n_user_messages: int | None = None,
|
||||
token_counter: Callable[[str], int] | None = None,
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata] | None = None,
|
||||
) -> list[ChatMessageSimple]:
|
||||
if last_n_user_messages is not None:
|
||||
if last_n_user_messages <= 0:
|
||||
@@ -194,13 +254,17 @@ def construct_message_history(
|
||||
"filtering chat history by last N user messages must be a value greater than 0"
|
||||
)
|
||||
|
||||
# Build the project / file-metadata messages up front so we can use their
|
||||
# actual token counts for the budget.
|
||||
project_messages = _build_project_message(project_files, token_counter)
|
||||
project_messages_tokens = sum(m.token_count for m in project_messages)
|
||||
|
||||
history_token_budget = available_tokens
|
||||
history_token_budget -= system_prompt.token_count if system_prompt else 0
|
||||
history_token_budget -= (
|
||||
custom_agent_prompt.token_count if custom_agent_prompt else 0
|
||||
)
|
||||
if project_files:
|
||||
history_token_budget -= project_files.total_token_count
|
||||
history_token_budget -= project_messages_tokens
|
||||
history_token_budget -= reminder_message.token_count if reminder_message else 0
|
||||
|
||||
if history_token_budget < 0:
|
||||
@@ -214,11 +278,7 @@ def construct_message_history(
|
||||
result = [system_prompt] if system_prompt else []
|
||||
if custom_agent_prompt:
|
||||
result.append(custom_agent_prompt)
|
||||
if project_files and project_files.project_file_texts:
|
||||
project_message = _create_project_files_message(
|
||||
project_files, token_counter=None
|
||||
)
|
||||
result.append(project_message)
|
||||
result.extend(project_messages)
|
||||
if reminder_message:
|
||||
result.append(reminder_message)
|
||||
return result
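# Rough worked example (illustrative only, all numbers made up): how the history
# token budget in the hunk above shrinks as each fixed component is subtracted.
available_tokens = 8000
system_prompt_tokens = 900
custom_agent_prompt_tokens = 300
project_messages_tokens = 1500
reminder_tokens = 50

history_token_budget_example = available_tokens
history_token_budget_example -= system_prompt_tokens
history_token_budget_example -= custom_agent_prompt_tokens
history_token_budget_example -= project_messages_tokens
history_token_budget_example -= reminder_tokens
assert history_token_budget_example == 5250  # what remains for the actual chat history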
|
||||
@@ -277,8 +337,11 @@ def construct_message_history(
|
||||
# Calculate remaining budget for history before the last user message
|
||||
remaining_budget = history_token_budget - required_tokens
|
||||
|
||||
# Truncate history_before_last_user from the top to fit in remaining budget
|
||||
# Truncate history_before_last_user from the top to fit in remaining budget.
|
||||
# Track dropped file messages so we can provide their metadata to the
|
||||
# FileReaderTool instead.
|
||||
truncated_history_before: list[ChatMessageSimple] = []
|
||||
dropped_file_ids: list[str] = []
|
||||
current_token_count = 0
|
||||
|
||||
for msg in reversed(history_before_last_user):
|
||||
@@ -287,9 +350,67 @@ def construct_message_history(
|
||||
truncated_history_before.insert(0, msg)
|
||||
current_token_count += msg.token_count
|
||||
else:
|
||||
# Can't fit this message, stop truncating
|
||||
# Can't fit this message, stop truncating.
|
||||
# This message and everything older is dropped.
|
||||
break
|
||||
|
||||
# Collect file_ids from ALL dropped messages (those not in
|
||||
# truncated_history_before). The truncation loop above keeps the most
|
||||
# recent messages, so the dropped ones are at the start of the original
|
||||
# list up to (len(history) - len(kept)).
|
||||
num_kept = len(truncated_history_before)
|
||||
for msg in history_before_last_user[: len(history_before_last_user) - num_kept]:
|
||||
if msg.file_id is not None:
|
||||
dropped_file_ids.append(msg.file_id)
|
||||
|
||||
# Also treat "orphaned" metadata entries as dropped -- these are files
|
||||
# from messages removed by summary truncation (before convert_chat_history
|
||||
# ran), so no ChatMessageSimple was ever tagged with their file_id.
|
||||
if all_injected_file_metadata:
|
||||
surviving_file_ids = {
|
||||
msg.file_id for msg in simple_chat_history if msg.file_id is not None
|
||||
}
|
||||
for fid in all_injected_file_metadata:
|
||||
if fid not in surviving_file_ids and fid not in dropped_file_ids:
|
||||
dropped_file_ids.append(fid)
|
||||
|
||||
# Build a forgotten-files metadata message if any file messages were
|
||||
# dropped AND we have metadata for them (meaning the FileReaderTool is
|
||||
# available). Reserve tokens for this message in the budget.
|
||||
forgotten_files_message: ChatMessageSimple | None = None
|
||||
if dropped_file_ids and all_injected_file_metadata and token_counter:
|
||||
forgotten_meta = [
|
||||
all_injected_file_metadata[fid]
|
||||
for fid in dropped_file_ids
|
||||
if fid in all_injected_file_metadata
|
||||
]
|
||||
if forgotten_meta:
|
||||
logger.debug(
|
||||
f"FileReader: building forgotten-files message for "
|
||||
f"{[(m.file_id, m.filename) for m in forgotten_meta]}"
|
||||
)
|
||||
forgotten_files_message = _create_file_tool_metadata_message(
|
||||
forgotten_meta, token_counter
|
||||
)
|
||||
# Shrink the remaining budget. If the metadata message doesn't
|
||||
# fit we may need to drop more history messages.
|
||||
remaining_budget -= forgotten_files_message.token_count
|
||||
while truncated_history_before and current_token_count > remaining_budget:
|
||||
evicted = truncated_history_before.pop(0)
|
||||
current_token_count -= evicted.token_count
|
||||
# If the evicted message is itself a file, add it to the
|
||||
# forgotten metadata (it's now dropped too).
|
||||
if (
|
||||
evicted.file_id is not None
|
||||
and evicted.file_id in all_injected_file_metadata
|
||||
and evicted.file_id not in {m.file_id for m in forgotten_meta}
|
||||
):
|
||||
forgotten_meta.append(all_injected_file_metadata[evicted.file_id])
|
||||
# Rebuild the message with the new entry
|
||||
forgotten_files_message = _create_file_tool_metadata_message(
|
||||
forgotten_meta, token_counter
|
||||
)
|
||||
|
||||
# Attach project images to the last user message
|
||||
if project_files and project_files.project_image_files:
|
||||
existing_images = last_user_message.image_files or []
|
||||
@@ -302,7 +423,7 @@ def construct_message_history(
|
||||
|
||||
# Build the final message list according to README ordering:
|
||||
# [system], [history_before_last_user], [custom_agent], [project_files],
|
||||
# [last_user_message], [messages_after_last_user], [reminder]
|
||||
# [forgotten_files], [last_user_message], [messages_after_last_user], [reminder]
|
||||
result = [system_prompt] if system_prompt else []
|
||||
|
||||
# 1. Add truncated history before last user message
|
||||
@@ -312,26 +433,52 @@ def construct_message_history(
|
||||
if custom_agent_prompt:
|
||||
result.append(custom_agent_prompt)
|
||||
|
||||
# 3. Add project files message (inserted before last user message)
|
||||
if project_files and project_files.project_file_texts:
|
||||
project_message = _create_project_files_message(
|
||||
project_files, token_counter=None
|
||||
)
|
||||
result.append(project_message)
|
||||
# 3. Add project files / file-metadata messages (inserted before last user message)
|
||||
result.extend(project_messages)
|
||||
|
||||
# 4. Add last user message (with project images attached)
|
||||
# 4. Add forgotten-files metadata (right before the user's question)
|
||||
if forgotten_files_message:
|
||||
result.append(forgotten_files_message)
|
||||
|
||||
# 5. Add last user message (with project images attached)
|
||||
result.append(last_user_message)
|
||||
|
||||
# 5. Add messages after last user message (tool calls, responses, etc.)
|
||||
# 6. Add messages after last user message (tool calls, responses, etc.)
|
||||
result.extend(messages_after_last_user)
|
||||
|
||||
# 6. Add reminder message at the very end
|
||||
# 7. Add reminder message at the very end
|
||||
if reminder_message:
|
||||
result.append(reminder_message)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _create_file_tool_metadata_message(
|
||||
file_metadata: list[FileToolMetadata],
|
||||
token_counter: Callable[[str], int],
|
||||
) -> ChatMessageSimple:
|
||||
"""Build a lightweight metadata-only message listing files available via FileReaderTool.
|
||||
|
||||
Used when files are too large to fit in context and the vector DB is
|
||||
disabled, so the LLM must use ``read_file`` to inspect them.
|
||||
"""
|
||||
lines = [
|
||||
"You have access to the following files. Use the read_file tool to "
|
||||
"read sections of any file:"
|
||||
]
|
||||
for meta in file_metadata:
|
||||
lines.append(
|
||||
f'- {meta.file_id}: "{meta.filename}" (~{meta.approx_char_count:,} chars)'
|
||||
)
|
||||
|
||||
message_content = "\n".join(lines)
|
||||
return ChatMessageSimple(
|
||||
message=message_content,
|
||||
token_count=token_counter(message_content),
|
||||
message_type=MessageType.USER,
|
||||
)
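# Quick sketch of the listing text this helper produces (illustrative only; the real
# message is wrapped in a ChatMessageSimple with a token count, and the file data
# below is made up).
def render_file_listing(files: list[tuple[str, str, int]]) -> str:
    lines = [
        "You have access to the following files. Use the read_file tool to "
        "read sections of any file:"
    ]
    for file_id, filename, approx_chars in files:
        lines.append(f'- {file_id}: "{filename}" (~{approx_chars:,} chars)')
    return "\n".join(lines)

print(render_file_listing([("uf-123", "q3_report.pdf", 48000)]))
# last line of output: - uf-123: "q3_report.pdf" (~48,000 chars)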
|
||||
|
||||
|
||||
def _create_project_files_message(
|
||||
project_files: ExtractedProjectFiles,
|
||||
token_counter: Callable[[str], int] | None, # noqa: ARG001
|
||||
@@ -379,6 +526,8 @@ def run_llm_loop(
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
include_citations: bool = True,
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata] | None = None,
|
||||
inject_memories_in_prompt: bool = True,
|
||||
) -> None:
|
||||
with trace(
|
||||
"run_llm_loop",
|
||||
@@ -444,6 +593,7 @@ def run_llm_loop(
|
||||
|
||||
reasoning_cycles = 0
|
||||
for llm_cycle_count in range(MAX_LLM_CYCLES):
|
||||
# Handling tool calls based on cycle count and past cycle conditions
|
||||
out_of_cycles = llm_cycle_count == MAX_LLM_CYCLES - 1
|
||||
if forced_tool_id:
|
||||
# Must be a single tool ID because the "required" tool-choice option currently can't name a specific tool, only force that some tool is used
|
||||
@@ -455,11 +605,17 @@ def run_llm_loop(
|
||||
elif out_of_cycles or ran_image_gen:
|
||||
# Last cycle, no tools allowed, just answer!
|
||||
tool_choice = ToolChoiceOptions.NONE
|
||||
final_tools = []
|
||||
# Bedrock requires tool config in requests that include toolUse/toolResult history.
|
||||
final_tools = (
|
||||
tools
|
||||
if _should_keep_bedrock_tool_definitions(llm, simple_chat_history)
|
||||
else []
|
||||
)
|
||||
else:
|
||||
tool_choice = ToolChoiceOptions.AUTO
|
||||
final_tools = tools
|
||||
|
||||
# Handling the system prompt and custom agent prompt
|
||||
# The section below calculates the available tokens for history a bit more accurately
|
||||
# now that project files are loaded in.
|
||||
if persona and persona.replace_base_system_prompt:
|
||||
@@ -477,18 +633,22 @@ def run_llm_loop(
|
||||
else:
|
||||
# If it's an empty string, we assume the user does not want to include it as an empty System message
|
||||
if default_base_system_prompt:
|
||||
open_ai_formatting_enabled = model_needs_formatting_reenabled(
|
||||
llm.config.model_name
|
||||
prompt_memory_context = (
|
||||
user_memory_context
|
||||
if inject_memories_in_prompt
|
||||
else (
|
||||
user_memory_context.without_memories()
|
||||
if user_memory_context
|
||||
else None
|
||||
)
|
||||
)
|
||||
|
||||
system_prompt_str = build_system_prompt(
|
||||
base_system_prompt=default_base_system_prompt,
|
||||
datetime_aware=persona.datetime_aware if persona else True,
|
||||
user_memory_context=user_memory_context,
|
||||
user_memory_context=prompt_memory_context,
|
||||
tools=tools,
|
||||
should_cite_documents=should_cite_documents
|
||||
or always_cite_documents,
|
||||
open_ai_formatting_enabled=open_ai_formatting_enabled,
|
||||
)
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=system_prompt_str,
|
||||
@@ -541,7 +701,7 @@ def run_llm_loop(
|
||||
ChatMessageSimple(
|
||||
message=reminder_message_text,
|
||||
token_count=token_counter(reminder_message_text),
|
||||
message_type=MessageType.USER,
|
||||
message_type=MessageType.USER_REMINDER,
|
||||
)
|
||||
if reminder_message_text
|
||||
else None
|
||||
@@ -554,6 +714,8 @@ def run_llm_loop(
|
||||
reminder_message=reminder_msg,
|
||||
project_files=project_files,
|
||||
available_tokens=available_tokens,
|
||||
token_counter=token_counter,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
|
||||
# This calls the LLM, yields packets (reasoning, answers, etc.) and returns the result
|
||||
@@ -645,6 +807,7 @@ def run_llm_loop(
|
||||
max_concurrent_tools=None,
|
||||
skip_search_query_expansion=has_called_search_tool,
|
||||
url_snippet_map=extract_url_snippet_map(gathered_documents or []),
|
||||
inject_memories_in_prompt=inject_memories_in_prompt,
|
||||
)
|
||||
tool_responses = parallel_tool_call_results.tool_responses
|
||||
citation_mapping = parallel_tool_call_results.updated_citation_mapping
|
||||
@@ -709,11 +872,44 @@ def run_llm_loop(
|
||||
):
|
||||
generated_images = tool_response.rich_response.generated_images
|
||||
|
||||
saved_response = (
|
||||
tool_response.rich_response
|
||||
if isinstance(tool_response.rich_response, str)
|
||||
else tool_response.llm_facing_response
|
||||
)
|
||||
# Persist memory if this is a memory tool response
|
||||
memory_snapshot: MemoryToolResponseSnapshot | None = None
|
||||
if isinstance(tool_response.rich_response, MemoryToolResponse):
|
||||
persisted_memory_id: int | None = None
|
||||
if user_memory_context and user_memory_context.user_id:
|
||||
if tool_response.rich_response.index_to_replace is not None:
|
||||
memory = update_memory_at_index(
|
||||
user_id=user_memory_context.user_id,
|
||||
index=tool_response.rich_response.index_to_replace,
|
||||
new_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id if memory else None
|
||||
else:
|
||||
memory = add_memory(
|
||||
user_id=user_memory_context.user_id,
|
||||
memory_text=tool_response.rich_response.memory_text,
|
||||
db_session=db_session,
|
||||
)
|
||||
persisted_memory_id = memory.id
|
||||
operation: Literal["add", "update"] = (
|
||||
"update"
|
||||
if tool_response.rich_response.index_to_replace is not None
|
||||
else "add"
|
||||
)
|
||||
memory_snapshot = MemoryToolResponseSnapshot(
|
||||
memory_text=tool_response.rich_response.memory_text,
|
||||
operation=operation,
|
||||
memory_id=persisted_memory_id,
|
||||
index=tool_response.rich_response.index_to_replace,
|
||||
)
|
||||
|
||||
if memory_snapshot:
|
||||
saved_response = json.dumps(memory_snapshot.model_dump())
|
||||
elif isinstance(tool_response.rich_response, str):
|
||||
saved_response = tool_response.rich_response
|
||||
else:
|
||||
saved_response = tool_response.llm_facing_response
|
||||
|
||||
tool_call_info = ToolCallInfo(
|
||||
parent_tool_call_id=None, # Top-level tool calls are attached to the chat message
|
||||
|
||||
@@ -36,6 +36,10 @@ from onyx.llm.models import ToolCall
|
||||
from onyx.llm.models import ToolMessage
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.prompt_cache.processor import process_with_prompt_cache
|
||||
from onyx.llm.utils import model_needs_formatting_reenabled
|
||||
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
|
||||
from onyx.prompts.constants import SYSTEM_REMINDER_TAG_CLOSE
|
||||
from onyx.prompts.constants import SYSTEM_REMINDER_TAG_OPEN
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
@@ -332,26 +336,48 @@ def extract_tool_calls_from_response_text(
|
||||
# Find all JSON objects in the response text
|
||||
json_objects = find_all_json_objects(response_text)
|
||||
|
||||
tool_calls: list[ToolCallKickoff] = []
|
||||
tab_index = 0
|
||||
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
|
||||
prev_json_obj: dict[str, Any] | None = None
|
||||
prev_tool_call: tuple[str, dict[str, Any]] | None = None
|
||||
|
||||
for json_obj in json_objects:
|
||||
matched_tool_call = _try_match_json_to_tool(json_obj, tool_name_to_def)
|
||||
if matched_tool_call:
|
||||
tool_name, tool_args = matched_tool_call
|
||||
tool_calls.append(
|
||||
ToolCallKickoff(
|
||||
tool_call_id=f"extracted_{uuid.uuid4().hex[:8]}",
|
||||
tool_name=tool_name,
|
||||
tool_args=tool_args,
|
||||
placement=Placement(
|
||||
turn_index=placement.turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=placement.sub_turn_index,
|
||||
),
|
||||
)
|
||||
if not matched_tool_call:
|
||||
continue
|
||||
|
||||
# `find_all_json_objects` can return both an outer tool-call object and
|
||||
# its nested arguments object. If both resolve to the same tool call,
|
||||
# drop only this nested duplicate artifact.
|
||||
if (
|
||||
prev_json_obj is not None
|
||||
and prev_tool_call is not None
|
||||
and matched_tool_call == prev_tool_call
|
||||
and _is_nested_arguments_duplicate(
|
||||
previous_json_obj=prev_json_obj,
|
||||
current_json_obj=json_obj,
|
||||
tool_name_to_def=tool_name_to_def,
|
||||
)
|
||||
tab_index += 1
|
||||
):
|
||||
continue
|
||||
|
||||
matched_tool_calls.append(matched_tool_call)
|
||||
prev_json_obj = json_obj
|
||||
prev_tool_call = matched_tool_call
|
||||
|
||||
tool_calls: list[ToolCallKickoff] = []
|
||||
for tab_index, (tool_name, tool_args) in enumerate(matched_tool_calls):
|
||||
tool_calls.append(
|
||||
ToolCallKickoff(
|
||||
tool_call_id=f"extracted_{uuid.uuid4().hex[:8]}",
|
||||
tool_name=tool_name,
|
||||
tool_args=tool_args,
|
||||
placement=Placement(
|
||||
turn_index=placement.turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=placement.sub_turn_index,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Extracted {len(tool_calls)} tool call(s) from response text as fallback"
|
||||
@@ -433,6 +459,42 @@ def _try_match_json_to_tool(
|
||||
return None
|
||||
|
||||
|
||||
def _is_nested_arguments_duplicate(
|
||||
previous_json_obj: dict[str, Any],
|
||||
current_json_obj: dict[str, Any],
|
||||
tool_name_to_def: dict[str, dict],
|
||||
) -> bool:
|
||||
"""Detect when current object is the nested args object from previous tool call."""
|
||||
extracted_args = _extract_nested_arguments_obj(previous_json_obj, tool_name_to_def)
|
||||
return extracted_args is not None and current_json_obj == extracted_args
|
||||
|
||||
|
||||
def _extract_nested_arguments_obj(
|
||||
json_obj: dict[str, Any],
|
||||
tool_name_to_def: dict[str, dict],
|
||||
) -> dict[str, Any] | None:
|
||||
# Format 1: {"name": "...", "arguments": {...}} or {"name": "...", "parameters": {...}}
|
||||
if "name" in json_obj and json_obj["name"] in tool_name_to_def:
|
||||
args_obj = json_obj.get("arguments", json_obj.get("parameters"))
|
||||
if isinstance(args_obj, dict):
|
||||
return args_obj
|
||||
|
||||
# Format 2: {"function": {"name": "...", "arguments": {...}}}
|
||||
if "function" in json_obj and isinstance(json_obj["function"], dict):
|
||||
function_obj = json_obj["function"]
|
||||
if "name" in function_obj and function_obj["name"] in tool_name_to_def:
|
||||
args_obj = function_obj.get("arguments", function_obj.get("parameters"))
|
||||
if isinstance(args_obj, dict):
|
||||
return args_obj
|
||||
|
||||
# Format 3: {"tool_name": {...arguments...}}
|
||||
for tool_name in tool_name_to_def:
|
||||
if tool_name in json_obj and isinstance(json_obj[tool_name], dict):
|
||||
return json_obj[tool_name]
|
||||
|
||||
return None
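# Illustrative examples (not from this diff) of the three tool-call JSON shapes the
# extractor above recognizes; "search" is a made-up tool name for the example.
examples = [
    {"name": "search", "arguments": {"query": "onyx"}},                 # format 1
    {"function": {"name": "search", "parameters": {"query": "onyx"}}},  # format 2
    {"search": {"query": "onyx"}},                                      # format 3
]

def nested_args(obj: dict, tools: set[str]) -> dict | None:
    # Simplified restatement of _extract_nested_arguments_obj for the example.
    if obj.get("name") in tools:
        args = obj.get("arguments", obj.get("parameters"))
        return args if isinstance(args, dict) else None
    fn = obj.get("function")
    if isinstance(fn, dict) and fn.get("name") in tools:
        args = fn.get("arguments", fn.get("parameters"))
        return args if isinstance(args, dict) else None
    for tool in tools:
        if isinstance(obj.get(tool), dict):
            return obj[tool]
    return None

assert all(nested_args(e, {"search"}) == {"query": "onyx"} for e in examples)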
|
||||
|
||||
|
||||
def translate_history_to_llm_format(
|
||||
history: list[ChatMessageSimple],
|
||||
llm_config: LLMConfig,
|
||||
@@ -451,6 +513,7 @@ def translate_history_to_llm_format(
|
||||
if PROMPT_CACHE_CHAT_HISTORY and msg.message_type in [
|
||||
MessageType.SYSTEM,
|
||||
MessageType.USER,
|
||||
MessageType.USER_REMINDER,
|
||||
MessageType.ASSISTANT,
|
||||
MessageType.TOOL_CALL_RESPONSE,
|
||||
]:
|
||||
@@ -512,6 +575,16 @@ def translate_history_to_llm_format(
|
||||
)
|
||||
messages.append(user_msg_text)
|
||||
|
||||
elif msg.message_type == MessageType.USER_REMINDER:
|
||||
# User reminder messages are wrapped with system-reminder tags
|
||||
# and converted to UserMessage (LLM APIs don't have a native reminder type)
|
||||
wrapped_content = f"{SYSTEM_REMINDER_TAG_OPEN}\n{msg.message}\n{SYSTEM_REMINDER_TAG_CLOSE}"
|
||||
reminder_msg = UserMessage(
|
||||
role="user",
|
||||
content=wrapped_content,
|
||||
)
|
||||
messages.append(reminder_msg)
|
||||
|
||||
elif msg.message_type == MessageType.ASSISTANT:
|
||||
tool_calls_list: list[ToolCall] | None = None
|
||||
if msg.tool_calls:
|
||||
@@ -552,6 +625,17 @@ def translate_history_to_llm_format(
|
||||
f"Unknown message type {msg.message_type} in history. Skipping message."
|
||||
)
|
||||
|
||||
# Apply model-specific formatting when translating to LLM format (e.g. OpenAI
|
||||
# reasoning models need CODE_BLOCK_MARKDOWN prefix for correct markdown generation)
|
||||
if model_needs_formatting_reenabled(llm_config.model_name):
|
||||
for i, m in enumerate(messages):
|
||||
if isinstance(m, SystemMessage):
|
||||
messages[i] = SystemMessage(
|
||||
role="system",
|
||||
content=CODE_BLOCK_MARKDOWN + m.content,
|
||||
)
|
||||
break
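# Tiny sketch of the prefixing behaviour above: only the first system message gets the
# markdown-formatting preamble. The constant text here is a stand-in, not the real
# CODE_BLOCK_MARKDOWN value, and the message dicts are simplified.
CODE_BLOCK_MARKDOWN_EXAMPLE = "Format code blocks as markdown.\n"

def reenable_formatting(messages: list[dict]) -> list[dict]:
    out = [dict(m) for m in messages]
    for m in out:
        if m["role"] == "system":
            m["content"] = CODE_BLOCK_MARKDOWN_EXAMPLE + m["content"]
            break
    return out

msgs = [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "hi"}]
assert reenable_formatting(msgs)[0]["content"].startswith("Format code blocks")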
|
||||
|
||||
# prompt caching: rely on should_cache in ChatMessageSimple to
|
||||
# pick the split point for the cacheable prefix and suffix
|
||||
if last_cacheable_msg_idx != -1:
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
@@ -20,54 +16,6 @@ from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
|
||||
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
FINISHED = "finished"
|
||||
|
||||
|
||||
class StreamType(Enum):
|
||||
SUB_QUESTIONS = "sub_questions"
|
||||
SUB_ANSWER = "sub_answer"
|
||||
MAIN_ANSWER = "main_answer"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
stop_reason: StreamStopReason
|
||||
|
||||
stream_type: StreamType = StreamType.MAIN_ANSWER
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
data["stop_reason"] = self.stop_reason.name
|
||||
return data
|
||||
|
||||
|
||||
class UserKnowledgeFilePacket(BaseModel):
|
||||
user_files: list[FileDescriptor]
|
||||
|
||||
|
||||
class RelevanceAnalysis(BaseModel):
|
||||
relevant: bool
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class DocumentRelevance(BaseModel):
|
||||
"""Contains all relevance information for a given search"""
|
||||
|
||||
relevance_summaries: dict[str, RelevanceAnalysis]
|
||||
|
||||
|
||||
class OnyxAnswerPiece(BaseModel):
|
||||
# A small piece of a complete answer. Used for streaming back answers.
|
||||
answer_piece: str | None # if None, specifies the end of an Answer
|
||||
|
||||
|
||||
class MessageResponseIDInfo(BaseModel):
|
||||
user_message_id: int | None
|
||||
reserved_assistant_message_id: int
|
||||
|
||||
|
||||
class StreamingError(BaseModel):
|
||||
error: str
|
||||
stack_trace: str | None = None
|
||||
@@ -78,23 +26,11 @@ class StreamingError(BaseModel):
|
||||
details: dict | None = None # Additional context (tool name, model name, etc.)
|
||||
|
||||
|
||||
class OnyxAnswer(BaseModel):
|
||||
answer: str | None
|
||||
|
||||
|
||||
class FileChatDisplay(BaseModel):
|
||||
file_ids: list[str]
|
||||
|
||||
|
||||
class CustomToolResponse(BaseModel):
|
||||
response: ToolResultType
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class ProjectSearchConfig(BaseModel):
|
||||
"""Configuration for search tool availability in project context."""
|
||||
|
||||
@@ -102,83 +38,15 @@ class ProjectSearchConfig(BaseModel):
|
||||
disable_forced_tool: bool
|
||||
|
||||
|
||||
class PromptOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
datetime_aware: bool = True
|
||||
include_citations: bool = True
|
||||
|
||||
|
||||
class PersonaOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
|
||||
# Note: prompt_ids removed - prompts are now embedded in personas
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
AnswerQuestionPossibleReturn = (
|
||||
OnyxAnswerPiece
|
||||
| CitationInfo
|
||||
| FileChatDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
class CreateChatSessionID(BaseModel):
|
||||
chat_session_id: UUID
|
||||
|
||||
|
||||
AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
|
||||
|
||||
|
||||
class LLMMetricsContainer(BaseModel):
|
||||
prompt_tokens: int
|
||||
response_tokens: int
|
||||
|
||||
|
||||
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
|
||||
|
||||
|
||||
AnswerStreamPart = (
|
||||
Packet
|
||||
| StreamStopInfo
|
||||
| MessageResponseIDInfo
|
||||
| StreamingError
|
||||
| UserKnowledgeFilePacket
|
||||
| CreateChatSessionID
|
||||
)
|
||||
AnswerStreamPart = Packet | MessageResponseIDInfo | StreamingError | CreateChatSessionID
|
||||
|
||||
AnswerStream = Iterator[AnswerStreamPart]
|
||||
|
||||
|
||||
class ChatBasicResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str
|
||||
answer_citationless: str
|
||||
|
||||
top_documents: list[SearchDoc]
|
||||
|
||||
error_msg: str | None
|
||||
message_id: int
|
||||
citation_info: list[CitationInfo]
|
||||
|
||||
|
||||
class ToolCallResponse(BaseModel):
|
||||
"""Tool call with full details for non-streaming response."""
|
||||
|
||||
@@ -191,8 +59,23 @@ class ToolCallResponse(BaseModel):
|
||||
pre_reasoning: str | None = None
|
||||
|
||||
|
||||
class ChatBasicResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str
|
||||
answer_citationless: str
|
||||
|
||||
top_documents: list[SearchDoc]
|
||||
|
||||
error_msg: str | None
|
||||
message_id: int
|
||||
citation_info: list[CitationInfo]
|
||||
|
||||
|
||||
class ChatFullResponse(BaseModel):
|
||||
"""Complete non-streaming response with all available data."""
|
||||
"""Complete non-streaming response with all available data.
|
||||
NOTE: This model is used for the core flow of the Onyx application; any changes to it should be reviewed and approved by an
experienced team member. It is very important to (1) avoid bloat and (2) keep this model backwards compatible across versions.
|
||||
"""
|
||||
|
||||
# Core response fields
|
||||
answer: str
|
||||
@@ -244,6 +127,9 @@ class ChatMessageSimple(BaseModel):
|
||||
# represents the end of the cacheable prefix
|
||||
# used for prompt caching
|
||||
should_cache: bool = False
|
||||
# When this message represents an injected text file, this is the file's ID.
|
||||
# Used to detect which file messages survive context-window truncation.
|
||||
file_id: str | None = None
|
||||
|
||||
|
||||
class ProjectFileMetadata(BaseModel):
|
||||
@@ -254,6 +140,33 @@ class ProjectFileMetadata(BaseModel):
|
||||
file_content: str
|
||||
|
||||
|
||||
class FileToolMetadata(BaseModel):
|
||||
"""Lightweight metadata for exposing files to the FileReaderTool.
|
||||
|
||||
Used when files cannot be loaded directly into context (project too large
|
||||
or persona-attached user_files without direct-load path). The LLM receives
|
||||
a listing of these so it knows which files it can read via ``read_file``.
|
||||
"""
|
||||
|
||||
file_id: str
|
||||
filename: str
|
||||
approx_char_count: int
|
||||
|
||||
|
||||
class ChatHistoryResult(BaseModel):
|
||||
"""Result of converting chat history to simple format.
|
||||
|
||||
Bundles the simple messages with metadata for every text file that was
|
||||
injected into the history. After context-window truncation drops older
|
||||
messages, callers compare surviving ``file_id`` tags against this map
|
||||
to discover "forgotten" files whose metadata should be provided to the
|
||||
FileReaderTool.
|
||||
"""
|
||||
|
||||
simple_messages: list[ChatMessageSimple]
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata]
|
||||
|
||||
|
||||
class ExtractedProjectFiles(BaseModel):
|
||||
project_file_texts: list[str]
|
||||
project_image_files: list[ChatLoadedFile]
|
||||
@@ -263,6 +176,9 @@ class ExtractedProjectFiles(BaseModel):
|
||||
project_file_metadata: list[ProjectFileMetadata]
|
||||
# None if not a project
|
||||
project_uncapped_token_count: int | None
|
||||
# Lightweight metadata for files exposed via FileReaderTool
|
||||
# (populated when files don't fit in context and vector DB is disabled)
|
||||
file_metadata_for_tool: list[FileToolMetadata] = []
|
||||
|
||||
|
||||
class LlmStepResult(BaseModel):
|
||||
|
||||
@@ -4,12 +4,12 @@ An overview can be found in the README.md file in this directory.
|
||||
"""
|
||||
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from contextvars import Token
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -35,7 +35,7 @@ from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import MessageResponseIDInfo
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ProjectSearchConfig
|
||||
from onyx.chat.models import StreamingError
|
||||
@@ -44,6 +44,7 @@ from onyx.chat.prompt_utils import calculate_reserved_tokens
|
||||
from onyx.chat.save_chat import save_chat_turn
|
||||
from onyx.chat.stop_signal_checker import is_connected as check_stop_signal
|
||||
from onyx.chat.stop_signal_checker import reset_cancel_status
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -60,6 +61,7 @@ from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.tools import get_tools
|
||||
@@ -77,8 +79,7 @@ from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import OptionalSearchSetting
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
@@ -90,7 +91,11 @@ from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import SearchToolUsage
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
from onyx.tools.tool_constructor import CustomToolConfig
|
||||
from onyx.tools.tool_constructor import FileReaderToolConfig
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import (
|
||||
FileReaderTool,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
@@ -100,6 +105,53 @@ logger = setup_logger()
|
||||
ERROR_TYPE_CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class _AvailableFiles(BaseModel):
|
||||
"""Separated file IDs for the FileReaderTool so it knows which loader to use."""
|
||||
|
||||
# IDs from the ``user_file`` table (project / persona-attached files).
|
||||
user_file_ids: list[UUID] = []
|
||||
# IDs from the ``file_record`` table (chat-attached files).
|
||||
chat_file_ids: list[UUID] = []
|
||||
|
||||
|
||||
def _collect_available_file_ids(
|
||||
chat_history: list[ChatMessage],
|
||||
project_id: int | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> _AvailableFiles:
|
||||
"""Collect all file IDs the FileReaderTool should be allowed to access.
|
||||
|
||||
Returns *separate* lists for chat-attached files (``file_record`` IDs) and
|
||||
project/user files (``user_file`` IDs) so the tool can pick the right
|
||||
loader without a try/except fallback."""
|
||||
chat_file_ids: set[UUID] = set()
|
||||
user_file_ids: set[UUID] = set()
|
||||
|
||||
for msg in chat_history:
|
||||
if not msg.files:
|
||||
continue
|
||||
for fd in msg.files:
|
||||
try:
|
||||
chat_file_ids.add(UUID(fd["id"]))
|
||||
except (ValueError, KeyError):
|
||||
pass
|
||||
|
||||
if project_id:
|
||||
project_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
for uf in project_files:
|
||||
user_file_ids.add(uf.id)
|
||||
|
||||
return _AvailableFiles(
|
||||
user_file_ids=list(user_file_ids),
|
||||
chat_file_ids=list(chat_file_ids),
|
||||
)
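# Illustrative sketch of the split above: chat-attached file IDs come from message file
# descriptors, project file IDs from the project's user_file rows. The data below is
# made up for the example.
from uuid import UUID, uuid4

chat_history_file_descriptors = [{"id": str(uuid4())}, {"id": "not-a-uuid"}]
project_file_rows = [uuid4(), uuid4()]

example_chat_file_ids: set[UUID] = set()
for fd in chat_history_file_descriptors:
    try:
        example_chat_file_ids.add(UUID(fd["id"]))
    except (ValueError, KeyError):
        pass                       # skip malformed descriptors, as the helper does

example_user_file_ids = set(project_file_rows)
assert len(example_chat_file_ids) == 1 and len(example_user_file_ids) == 2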
|
||||
|
||||
|
||||
def _should_enable_slack_search(
|
||||
persona: Persona,
|
||||
filters: BaseFilters | None,
|
||||
@@ -232,6 +284,24 @@ def _extract_project_file_texts_and_images(
|
||||
)
|
||||
project_image_files.append(chat_loaded_file)
|
||||
else:
|
||||
if DISABLE_VECTOR_DB:
|
||||
# Without a vector DB we can't use project-as-filter search.
|
||||
# Instead, build lightweight metadata so the LLM can call the
|
||||
# FileReaderTool to inspect individual files on demand.
|
||||
file_metadata_for_tool = _build_file_tool_metadata_for_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return ExtractedProjectFiles(
|
||||
project_file_texts=[],
|
||||
project_image_files=[],
|
||||
project_as_filter=False,
|
||||
total_token_count=0,
|
||||
project_file_metadata=[],
|
||||
project_uncapped_token_count=project_tokens,
|
||||
file_metadata_for_tool=file_metadata_for_tool,
|
||||
)
|
||||
project_as_filter = True
|
||||
|
||||
return ExtractedProjectFiles(
|
||||
@@ -244,6 +314,49 @@ def _extract_project_file_texts_and_images(
|
||||
)
|
||||
|
||||
|
||||
APPROX_CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_project(
|
||||
project_id: int,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[FileToolMetadata]:
|
||||
"""Build lightweight FileToolMetadata for every file in a project.
|
||||
|
||||
Used when files are too large to fit in context and the vector DB is
|
||||
disabled, so the LLM needs to know which files it can read via the
|
||||
FileReaderTool.
|
||||
"""
|
||||
project_user_files = get_user_files_from_project(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
return [
|
||||
FileToolMetadata(
|
||||
file_id=str(uf.id),
|
||||
filename=uf.name,
|
||||
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
|
||||
)
|
||||
for uf in project_user_files
|
||||
]
|
||||
|
||||
|
||||
def _build_file_tool_metadata_for_user_files(
|
||||
user_files: list[UserFile],
|
||||
) -> list[FileToolMetadata]:
|
||||
"""Build lightweight FileToolMetadata from a list of UserFile records."""
|
||||
return [
|
||||
FileToolMetadata(
|
||||
file_id=str(uf.id),
|
||||
filename=uf.name,
|
||||
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
|
||||
)
|
||||
for uf in user_files
|
||||
]
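# Quick check of the approximation used above: character count is estimated as
# token_count * APPROX_CHARS_PER_TOKEN (4), and a missing token_count falls back to 0.
APPROX_CHARS_PER_TOKEN_EXAMPLE = 4

def approx_chars(token_count: int | None) -> int:
    return (token_count or 0) * APPROX_CHARS_PER_TOKEN_EXAMPLE

assert approx_chars(12_000) == 48_000
assert approx_chars(None) == 0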
|
||||
|
||||
|
||||
def _get_project_search_availability(
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
@@ -317,7 +430,6 @@ def handle_stream_message_objects(
|
||||
external_state_container: ChatStateContainer | None = None,
|
||||
) -> AnswerStream:
|
||||
tenant_id = get_current_tenant_id()
|
||||
processing_start_time = time.monotonic()
|
||||
mock_response_token: Token[str | None] | None = None
|
||||
|
||||
llm: LLM | None = None
|
||||
@@ -330,12 +442,10 @@ def handle_stream_message_objects(
|
||||
else:
|
||||
llm_user_identifier = user.email or str(user_id)
|
||||
|
||||
if new_msg_req.mock_llm_response is not None:
|
||||
if not INTEGRATION_TESTS_MODE:
|
||||
raise ValueError(
|
||||
"mock_llm_response can only be used when INTEGRATION_TESTS_MODE=true"
|
||||
)
|
||||
mock_response_token = set_llm_mock_response(new_msg_req.mock_llm_response)
|
||||
if new_msg_req.mock_llm_response is not None and not INTEGRATION_TESTS_MODE:
|
||||
raise ValueError(
|
||||
"mock_llm_response can only be used when INTEGRATION_TESTS_MODE=true"
|
||||
)
|
||||
|
||||
try:
|
||||
if not new_msg_req.chat_session_id:
|
||||
@@ -463,24 +573,68 @@ def handle_stream_message_objects(
|
||||
|
||||
chat_history.append(user_message)
|
||||
|
||||
# Collect file IDs for the file reader tool *before* summary
|
||||
# truncation so that files attached to older (summarized-away)
|
||||
# messages are still accessible via the FileReaderTool.
|
||||
available_files = _collect_available_file_ids(
|
||||
chat_history=chat_history,
|
||||
project_id=chat_session.project_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Find applicable summary for the current branch
|
||||
# Summary applies if its parent_message_id is in current chat_history
|
||||
summary_message = find_summary_for_branch(db_session, chat_history)
|
||||
# Collect file metadata from messages that will be dropped by
|
||||
# summary truncation. These become "pre-summarized" file metadata
|
||||
# so the forgotten-file mechanism can still tell the LLM about them.
|
||||
summarized_file_metadata: dict[str, FileToolMetadata] = {}
|
||||
if summary_message and summary_message.last_summarized_message_id:
|
||||
cutoff_id = summary_message.last_summarized_message_id
|
||||
for msg in chat_history:
|
||||
if msg.id > cutoff_id or not msg.files:
|
||||
continue
|
||||
for fd in msg.files:
|
||||
file_id = fd.get("id")
|
||||
if not file_id:
|
||||
continue
|
||||
summarized_file_metadata[file_id] = FileToolMetadata(
|
||||
file_id=file_id,
|
||||
filename=fd.get("name") or "unknown",
|
||||
# We don't know the exact size without loading the
|
||||
# file, but 0 signals "unknown" to the LLM.
|
||||
approx_char_count=0,
|
||||
)
|
||||
# Filter chat_history to only messages after the cutoff
|
||||
chat_history = [m for m in chat_history if m.id > cutoff_id]
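# Illustrative sketch (made-up IDs): messages at or before the summary cutoff are
# dropped from the working history, and any files they carried survive only as
# lightweight metadata for the FileReaderTool.
example_messages = [
    {"id": 7, "files": [{"id": "f-old", "name": "old.txt"}]},
    {"id": 9, "files": []},
    {"id": 12, "files": [{"id": "f-new", "name": "new.txt"}]},
]
example_cutoff_id = 9

example_summarized_files = {
    fd["id"]: fd["name"]
    for m in example_messages
    if m["id"] <= example_cutoff_id
    for fd in m["files"]
}
example_remaining = [m for m in example_messages if m["id"] > example_cutoff_id]

assert example_summarized_files == {"f-old": "old.txt"}
assert [m["id"] for m in example_remaining] == [12]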
|
||||
|
||||
user_memory_context = get_memories(user, db_session)
|
||||
|
||||
# This is the custom prompt which may come from the Agent or Project. We fetch it earlier because the inner loop
|
||||
# (run_llm_loop and run_deep_research_llm_loop) should not need to be aware of the Chat History in the DB form processed
|
||||
# here, however we need this early for token reservation.
|
||||
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
# When use_memories is disabled, strip memories from the prompt context
|
||||
# but keep user info/preferences. The full context is still passed
|
||||
# to the LLM loop for memory tool persistence.
|
||||
prompt_memory_context = (
|
||||
user_memory_context
|
||||
if user.use_memories
|
||||
else user_memory_context.without_memories()
|
||||
)
|
||||
|
||||
max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + (
|
||||
custom_agent_prompt or ""
|
||||
)
|
||||
|
||||
reserved_token_count = calculate_reserved_tokens(
|
||||
db_session=db_session,
|
||||
persona_system_prompt=custom_agent_prompt or "",
|
||||
persona_system_prompt=max_reserved_system_prompt_tokens_str,
|
||||
token_counter=token_counter,
|
||||
files=new_msg_req.file_descriptors,
|
||||
user_memory_context=user_memory_context,
|
||||
user_memory_context=prompt_memory_context,
|
||||
)
|
||||
|
||||
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
|
||||
@@ -492,6 +646,16 @@ def handle_stream_message_objects(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# When the vector DB is disabled, persona-attached user_files have no
|
||||
# search pipeline path. Inject them as file_metadata_for_tool so the
|
||||
# LLM can read them via the FileReaderTool.
|
||||
if DISABLE_VECTOR_DB and persona.user_files:
|
||||
persona_file_metadata = _build_file_tool_metadata_for_user_files(
|
||||
persona.user_files
|
||||
)
|
||||
# Merge persona file metadata into the extracted project files
|
||||
extracted_project_files.file_metadata_for_tool.extend(persona_file_metadata)
|
||||
|
||||
# Build a mapping of tool_id to tool_name for history reconstruction
|
||||
all_tools = get_tools(db_session)
|
||||
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
|
||||
@@ -518,6 +682,13 @@ def handle_stream_message_objects(
|
||||
|
||||
emitter = get_default_emitter()
|
||||
|
||||
# Also grant access to persona-attached user files
|
||||
if persona.user_files:
|
||||
existing = set(available_files.user_file_ids)
|
||||
for uf in persona.user_files:
|
||||
if uf.id not in existing:
|
||||
available_files.user_file_ids.append(uf.id)
|
||||
|
||||
# Construct tools based on the persona configurations
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
@@ -544,6 +715,10 @@ def handle_stream_message_objects(
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
mcp_headers=mcp_headers,
|
||||
),
|
||||
file_reader_tool_config=FileReaderToolConfig(
|
||||
user_file_ids=available_files.user_file_ids,
|
||||
chat_file_ids=available_files.chat_file_ids,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=project_search_config.search_usage,
|
||||
)
|
||||
@@ -573,9 +748,12 @@ def handle_stream_message_objects(
|
||||
reserved_assistant_message_id=assistant_response.id,
|
||||
)
|
||||
|
||||
# Check whether the FileReaderTool is among the constructed tools.
|
||||
has_file_reader_tool = any(isinstance(t, FileReaderTool) for t in tools)
|
||||
|
||||
# Convert the chat history into a simple format that is free of any DB objects
|
||||
# and is easy to parse for the agent loop
|
||||
simple_chat_history = convert_chat_history(
|
||||
chat_history_result = convert_chat_history(
|
||||
chat_history=chat_history,
|
||||
files=files,
|
||||
project_image_files=extracted_project_files.project_image_files,
|
||||
@@ -583,6 +761,32 @@ def handle_stream_message_objects(
|
||||
token_counter=token_counter,
|
||||
tool_id_to_name_map=tool_id_to_name_map,
|
||||
)
|
||||
simple_chat_history = chat_history_result.simple_messages
|
||||
|
||||
# Metadata for every text file injected into the history. After
|
||||
# context-window truncation drops older messages, the LLM loop
|
||||
# compares surviving file_id tags against this map to discover
|
||||
# "forgotten" files and provide their metadata to FileReaderTool.
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata] = (
|
||||
chat_history_result.all_injected_file_metadata
|
||||
if has_file_reader_tool
|
||||
else {}
|
||||
)
|
||||
|
||||
# Merge in file metadata from messages dropped by summary
|
||||
# truncation. These files are no longer in simple_chat_history
|
||||
# so they would otherwise be invisible to the forgotten-file
|
||||
# mechanism. They will always appear as "forgotten" since no
|
||||
# surviving message carries their file_id tag.
|
||||
if summarized_file_metadata:
|
||||
for fid, meta in summarized_file_metadata.items():
|
||||
all_injected_file_metadata.setdefault(fid, meta)
|
||||
|
||||
if all_injected_file_metadata:
|
||||
logger.debug(
|
||||
"FileReader: file metadata for LLM: "
|
||||
f"{[(fid, m.filename) for fid, m in all_injected_file_metadata.items()]}"
|
||||
)
|
||||
|
||||
# Prepend summary message if compression exists
|
||||
if summary_message is not None:
|
||||
@@ -623,9 +827,13 @@ def handle_stream_message_objects(
|
||||
assistant_message=assistant_response,
|
||||
llm=llm,
|
||||
reserved_tokens=reserved_token_count,
|
||||
processing_start_time=processing_start_time,
|
||||
)
|
||||
|
||||
# The stream generator can resume on a different worker thread after early yields.
|
||||
# Set this right before launching the LLM loop so run_in_background copies the right context.
|
||||
if new_msg_req.mock_llm_response is not None:
|
||||
mock_response_token = set_llm_mock_response(new_msg_req.mock_llm_response)
|
||||
|
||||
# Run the LLM loop with explicit wrapper for stop signal handling
|
||||
# The wrapper runs run_llm_loop in a background thread and polls every 300ms
|
||||
# for stop signals. run_llm_loop itself doesn't know about stopping.
|
||||
@@ -654,6 +862,7 @@ def handle_stream_message_objects(
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
else:
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
@@ -675,6 +884,8 @@ def handle_stream_message_objects(
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
include_citations=new_msg_req.include_citations,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
inject_memories_in_prompt=user.use_memories,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
@@ -748,7 +959,6 @@ def llm_loop_completion_handle(
|
||||
assistant_message: ChatMessage,
|
||||
llm: LLM,
|
||||
reserved_tokens: int,
|
||||
processing_start_time: float | None = None, # noqa: ARG001
|
||||
) -> None:
|
||||
chat_session_id = assistant_message.chat_session_id
|
||||
|
||||
@@ -811,68 +1021,6 @@ def llm_loop_completion_handle(
|
||||
)
|
||||
|
||||
|
||||
def stream_chat_message_objects(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
# if specified, uses the last user message and does not create a new user message based
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
bypass_acl: bool = False,
|
||||
# Additional context that should be included in the chat history, for example:
|
||||
# Slack threads where the conversation cannot be represented by a chain of User/Assistant
|
||||
# messages. Both of the below are used for Slack
|
||||
# NOTE: is not stored in the database, only passed in to the LLM as context
|
||||
additional_context: str | None = None,
|
||||
# Slack context for federated Slack search
|
||||
slack_context: SlackContext | None = None,
|
||||
) -> AnswerStream:
|
||||
forced_tool_id = (
|
||||
new_msg_req.forced_tool_ids[0] if new_msg_req.forced_tool_ids else None
|
||||
)
|
||||
if (
|
||||
new_msg_req.retrieval_options
|
||||
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
|
||||
):
|
||||
all_tools = get_tools(db_session)
|
||||
|
||||
search_tool_id = next(
|
||||
(tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
|
||||
None,
|
||||
)
|
||||
forced_tool_id = search_tool_id
|
||||
|
||||
translated_new_msg_req = SendMessageRequest(
|
||||
message=new_msg_req.message,
|
||||
llm_override=new_msg_req.llm_override,
|
||||
mock_llm_response=new_msg_req.mock_llm_response,
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
forced_tool_id=forced_tool_id,
|
||||
file_descriptors=new_msg_req.file_descriptors,
|
||||
internal_search_filters=(
|
||||
new_msg_req.retrieval_options.filters
|
||||
if new_msg_req.retrieval_options
|
||||
else None
|
||||
),
|
||||
deep_research=new_msg_req.deep_research,
|
||||
parent_message_id=new_msg_req.parent_message_id,
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
origin=new_msg_req.origin,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
)
|
||||
return handle_stream_message_objects(
|
||||
new_msg_req=translated_new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
bypass_acl=bypass_acl,
|
||||
additional_context=additional_context,
|
||||
slack_context=slack_context,
|
||||
)
|
||||
|
||||
|
||||
def remove_answer_citations(answer: str) -> str:
|
||||
pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"
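# Example (illustrative) of what this pattern strips: inline citation links of the
# form [[n]](url). The rest of the function body is not shown in this diff, but a
# plain re.sub with the pattern behaves like this:
import re

_pattern_example = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"
text = "Onyx supports connectors. [[1]](https://docs.onyx.app/connectors) See docs."
assert re.sub(_pattern_example, "", text) == "Onyx supports connectors. See docs."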
|
||||
|
||||
|
||||
@@ -9,13 +9,13 @@ from onyx.db.persona import get_default_behavior_persona
from onyx.db.user_file import calculate_user_files_token_count
from onyx.file_store.models import FileDescriptor
from onyx.prompts.chat_prompts import CITATION_REMINDER
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
from onyx.prompts.prompt_utils import get_company_context
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
from onyx.prompts.prompt_utils import replace_reminder_tag
from onyx.prompts.tool_prompts import GENERATE_IMAGE_GUIDANCE
from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import MEMORY_GUIDANCE
@@ -25,7 +25,12 @@ from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import WEB_SEARCH_SITE_DISABLED_GUIDANCE
from onyx.prompts.user_info import BASIC_INFORMATION_PROMPT
from onyx.prompts.user_info import TEAM_INFORMATION_PROMPT
from onyx.prompts.user_info import USER_INFORMATION_HEADER
from onyx.prompts.user_info import USER_MEMORIES_PROMPT
from onyx.prompts.user_info import USER_PREFERENCES_PROMPT
from onyx.prompts.user_info import USER_ROLE_PROMPT
from onyx.tools.interface import Tool
from onyx.tools.tool_implementations.images.image_generation_tool import (
    ImageGenerationTool,
@@ -131,6 +136,59 @@ def build_reminder_message(
    return reminder if reminder else None


def _build_user_information_section(
    user_memory_context: UserMemoryContext | None,
    company_context: str | None,
) -> str:
    """Build the complete '# User Information' section with all sub-sections
    in the correct order: Basic Info → Team Info → Preferences → Memories."""
    sections: list[str] = []

    if user_memory_context:
        ctx = user_memory_context
        has_basic_info = ctx.user_info.name or ctx.user_info.email or ctx.user_info.role

        if has_basic_info:
            role_line = (
                USER_ROLE_PROMPT.format(user_role=ctx.user_info.role).strip()
                if ctx.user_info.role
                else ""
            )
            if role_line:
                role_line = "\n" + role_line
            sections.append(
                BASIC_INFORMATION_PROMPT.format(
                    user_name=ctx.user_info.name or "",
                    user_email=ctx.user_info.email or "",
                    user_role=role_line,
                )
            )

    if company_context:
        sections.append(
            TEAM_INFORMATION_PROMPT.format(team_information=company_context.strip())
        )

    if user_memory_context:
        ctx = user_memory_context

        if ctx.user_preferences:
            sections.append(
                USER_PREFERENCES_PROMPT.format(user_preferences=ctx.user_preferences)
            )

        if ctx.memories:
            formatted_memories = "\n".join(f"- {memory}" for memory in ctx.memories)
            sections.append(
                USER_MEMORIES_PROMPT.format(user_memories=formatted_memories)
            )

    if not sections:
        return ""

    return USER_INFORMATION_HEADER + "".join(sections)
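
# Hedged usage sketch: the exact rendered text depends on the prompt templates imported
# above, but the call shape follows the signatures in this module and the
# UserMemoryContext/UserInfo models further down in this diff. Example values are
# illustrative only.
#
#   ctx = UserMemoryContext(
#       user_info=UserInfo(name="Ada", email="ada@example.com", role="Engineer"),
#       user_preferences="Prefers concise answers",
#       memories=("Working on the billing service",),
#   )
#   section = _build_user_information_section(ctx, company_context="Acme Corp, fintech")
#   # section is "" when there is nothing to show; otherwise it starts with
#   # USER_INFORMATION_HEADER followed by Basic Info -> Team Info -> Preferences -> Memories.
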
def build_system_prompt(
    base_system_prompt: str,
    datetime_aware: bool = False,
@@ -138,18 +196,12 @@ def build_system_prompt(
    tools: Sequence[Tool] | None = None,
    should_cite_documents: bool = False,
    include_all_guidance: bool = False,
    open_ai_formatting_enabled: bool = False,
) -> str:
    """Should only be called with the default behavior system prompt.
    If the user has replaced the default behavior prompt with their custom agent prompt, do not call this function.
    """
    system_prompt = handle_onyx_date_awareness(base_system_prompt, datetime_aware)

    # See https://simonwillison.net/tags/markdown/ for context on why this is needed
    # for OpenAI reasoning models to have correct markdown generation
    if open_ai_formatting_enabled:
        system_prompt = CODE_BLOCK_MARKDOWN + system_prompt

    # Replace citation guidance placeholder if present
    system_prompt, should_append_citation_guidance = replace_citation_guidance_tag(
        system_prompt,
@@ -157,16 +209,14 @@
        include_all_guidance=include_all_guidance,
    )

    # Replace reminder tag placeholder if present
    system_prompt = replace_reminder_tag(system_prompt)

    company_context = get_company_context()
    formatted_user_context = (
        user_memory_context.as_formatted_prompt() if user_memory_context else ""
    user_info_section = _build_user_information_section(
        user_memory_context, company_context
    )
    if company_context or formatted_user_context:
        system_prompt += USER_INFORMATION_HEADER
        if company_context:
            system_prompt += company_context
        if formatted_user_context:
            system_prompt += formatted_user_context
    system_prompt += user_info_section

    # Append citation guidance after company context if placeholder was not present
    # This maintains backward compatibility and ensures citations are always enforced when needed

@@ -50,6 +50,17 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
# Controls whether users can use User Knowledge (personal documents) in assistants
DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() == "true"

# Disables vector DB (Vespa/OpenSearch) entirely. When True, connectors and RAG search
# are disabled but core chat, tools, user file uploads, and Projects still work.
DISABLE_VECTOR_DB = os.environ.get("DISABLE_VECTOR_DB", "").lower() == "true"

# Maximum token count for a single uploaded file. Files exceeding this are rejected.
# Defaults to 100k tokens (or 10M when vector DB is disabled).
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
FILE_TOKEN_COUNT_THRESHOLD = int(
    os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
)

# If set to true, will show extra/uncommon connectors in the "Other" category
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"

@@ -75,7 +86,7 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
# Auth Configs
#####
# Upgrades users from disabled auth to basic auth and shows warning.
_auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower()
_auth_type_str = (os.environ.get("AUTH_TYPE") or "basic").lower()
if _auth_type_str == "disabled":
    logger.warning(
        "AUTH_TYPE='disabled' is no longer supported. "
@@ -225,11 +236,32 @@ DOCUMENT_INDEX_NAME = "danswer_index"
# OpenSearch Configs
OPENSEARCH_HOST = os.environ.get("OPENSEARCH_HOST") or "localhost"
OPENSEARCH_REST_API_PORT = int(os.environ.get("OPENSEARCH_REST_API_PORT") or 9200)
# TODO(andrei): 60 seconds is too much, we're just setting a high default
# timeout for now to examine why queries are slow.
# NOTE: This timeout applies to all requests the client makes, including bulk
# indexing.
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S = int(
    os.environ.get("DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S") or 60
)
# TODO(andrei): 50 seconds is too much, we're just setting a high default
# timeout for now to examine why queries are slow.
# NOTE: To get useful partial results, this value should be less than the client
# timeout above.
DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S = int(
    os.environ.get("DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S") or 50
)
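
# A rough sketch of how these two values are typically used together, assuming the
# standard opensearch-py client (not necessarily how Onyx wires its own client):
#
#   from opensearchpy import OpenSearch
#
#   client = OpenSearch(
#       hosts=[{"host": OPENSEARCH_HOST, "port": OPENSEARCH_REST_API_PORT}],
#       timeout=DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,  # transport-level timeout
#   )
#   # The query timeout is set per-search and should stay below the client timeout
#   # so partial results can come back before the connection itself is cut off:
#   client.search(
#       index=DOCUMENT_INDEX_NAME,
#       body={"query": {"match_all": {}}, "timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s"},
#   )
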
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
USING_AWS_MANAGED_OPENSEARCH = (
    os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
)
# Profiling adds some overhead to OpenSearch operations. This overhead is
# unknown right now. It is enabled by default so we can get useful logs for
# investigating slow queries. We may never disable it if the overhead is
# minimal.
OPENSEARCH_PROFILING_DISABLED = (
    os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
)

# This is the "base" config for now, the idea is that at least for our dev
# environments we always want to be dual indexing into both OpenSearch and Vespa
@@ -900,6 +932,9 @@ MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"

ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"

# Limit on number of users a free trial tenant can invite (cloud only)
NUM_FREE_TRIAL_USER_INVITES = int(os.environ.get("NUM_FREE_TRIAL_USER_INVITES", "10"))

# Security and authentication
DATA_PLANE_SECRET = os.environ.get(
    "DATA_PLANE_SECRET", ""
@@ -942,6 +977,7 @@ API_KEY_HASH_ROUNDS = (
# MCP Server Configs
#####
MCP_SERVER_ENABLED = os.environ.get("MCP_SERVER_ENABLED", "").lower() == "true"
MCP_SERVER_HOST = os.environ.get("MCP_SERVER_HOST", "0.0.0.0")
MCP_SERVER_PORT = int(os.environ.get("MCP_SERVER_PORT") or 8090)

# CORS origins for MCP clients (comma-separated)
@@ -102,7 +102,6 @@ DISCORD_SERVICE_API_KEY_NAME = "discord-bot-service"
|
||||
|
||||
# Key-Value store keys
|
||||
KV_REINDEX_KEY = "needs_reindexing"
|
||||
KV_SEARCH_SETTINGS = "search_settings"
|
||||
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
|
||||
KV_USER_STORE_KEY = "INVITED_USERS"
|
||||
KV_PENDING_USERS_KEY = "PENDING_USERS"
|
||||
@@ -160,6 +159,8 @@ CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
|
||||
TMP_DRALPHA_PERSONA_NAME = "KG Beta"
|
||||
@@ -226,6 +227,9 @@ class DocumentSource(str, Enum):
|
||||
MOCK_CONNECTOR = "mock_connector"
|
||||
# Special case for user files
|
||||
USER_FILE = "user_file"
|
||||
# Raw files for Craft sandbox access (xlsx, pptx, docx, etc.)
|
||||
# Uses RAW_BINARY processing mode - no text extraction
|
||||
CRAFT_FILE = "craft_file"
|
||||
|
||||
|
||||
class FederatedConnectorSource(str, Enum):
|
||||
@@ -307,6 +311,7 @@ class MessageType(str, Enum):
|
||||
USER = "user" # HumanMessage
|
||||
ASSISTANT = "assistant" # AIMessage - Can include tool_calls field for parallel tool calling
|
||||
TOOL_CALL_RESPONSE = "tool_call_response"
|
||||
USER_REMINDER = "user_reminder" # Custom Onyx message type which is translated into a USER message when passed to the LLM
|
||||
|
||||
|
||||
class ChatMessageSimpleType(str, Enum):
|
||||
@@ -331,6 +336,7 @@ class FileOrigin(str, Enum):
|
||||
CHAT_UPLOAD = "chat_upload"
|
||||
CHAT_IMAGE_GEN = "chat_image_gen"
|
||||
CONNECTOR = "connector"
|
||||
CONNECTOR_METADATA = "connector_metadata"
|
||||
GENERATED_REPORT = "generated_report"
|
||||
INDEXING_CHECKPOINT = "indexing_checkpoint"
|
||||
PLAINTEXT_CACHE = "plaintext_cache"
|
||||
@@ -396,6 +402,8 @@ class OnyxCeleryQueues:
|
||||
# Sandbox processing queue
|
||||
SANDBOX = "sandbox"
|
||||
|
||||
OPENSEARCH_MIGRATION = "opensearch_migration"
|
||||
|
||||
|
||||
class OnyxRedisLocks:
|
||||
PRIMARY_WORKER = "da_lock:primary_worker"
|
||||
@@ -447,6 +455,9 @@ class OnyxRedisLocks:
|
||||
CLEANUP_IDLE_SANDBOXES_BEAT_LOCK = "da_lock:cleanup_idle_sandboxes_beat"
|
||||
CLEANUP_OLD_SNAPSHOTS_BEAT_LOCK = "da_lock:cleanup_old_snapshots_beat"
|
||||
|
||||
# Sandbox file sync
|
||||
SANDBOX_FILE_SYNC_LOCK_PREFIX = "da_lock:sandbox_file_sync"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
BLOCK_VALIDATE_INDEXING_FENCES = "signal:block_validate_indexing_fences"
|
||||
@@ -577,6 +588,9 @@ class OnyxCeleryTask:
|
||||
MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK = (
|
||||
"migrate_documents_from_vespa_to_opensearch_task"
|
||||
)
|
||||
MIGRATE_CHUNKS_FROM_VESPA_TO_OPENSEARCH_TASK = (
|
||||
"migrate_chunks_from_vespa_to_opensearch_task"
|
||||
)
|
||||
|
||||
|
||||
# this needs to correspond to the matching entry in supervisord
|
||||
|
||||
@@ -1,4 +1,5 @@
import contextvars
import re
from concurrent.futures import as_completed
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
@@ -14,6 +15,7 @@ from retry import retry

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
@@ -62,11 +64,44 @@ class AirtableClientNotSetUpError(PermissionError):
        super().__init__("Airtable Client is not set up, was load_credentials called?")


# Matches URLs like https://airtable.com/appXXX/tblYYY/viwZZZ?blocks=hide
# Captures: base_id (appXXX), table_id (tblYYY), and optionally view_id (viwZZZ)
_AIRTABLE_URL_PATTERN = re.compile(
    r"https?://airtable\.com/(app[A-Za-z0-9]+)/(tbl[A-Za-z0-9]+)(?:/(viw[A-Za-z0-9]+))?",
)


def parse_airtable_url(
    url: str,
) -> tuple[str, str, str | None]:
    """Parse an Airtable URL into (base_id, table_id, view_id).

    Accepts URLs like:
        https://airtable.com/appXXX/tblYYY
        https://airtable.com/appXXX/tblYYY/viwZZZ
        https://airtable.com/appXXX/tblYYY/viwZZZ?blocks=hide

    Returns:
        (base_id, table_id, view_id or None)

    Raises:
        ValueError if the URL doesn't match the expected format.
    """
    match = _AIRTABLE_URL_PATTERN.search(url.strip())
    if not match:
        raise ValueError(
            f"Could not parse Airtable URL: '{url}'. "
            "Expected format: https://airtable.com/appXXX/tblYYY[/viwZZZ]"
        )
    return match.group(1), match.group(2), match.group(3)
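
# Quick illustration of the parser above, using the URL shapes from its docstring:
#
#   parse_airtable_url("https://airtable.com/appXXX/tblYYY/viwZZZ?blocks=hide")
#   # -> ("appXXX", "tblYYY", "viwZZZ")
#   parse_airtable_url("https://airtable.com/appXXX/tblYYY")
#   # -> ("appXXX", "tblYYY", None)
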
class AirtableConnector(LoadConnector):
    def __init__(
        self,
        base_id: str,
        table_name_or_id: str,
        base_id: str = "",
        table_name_or_id: str = "",
        airtable_url: str = "",
        treat_all_non_attachment_fields_as_metadata: bool = False,
        view_id: str | None = None,
        share_id: str | None = None,
@@ -75,16 +110,33 @@ class AirtableConnector(LoadConnector):
        """Initialize an AirtableConnector.

        Args:
            base_id: The ID of the Airtable base to connect to
            table_name_or_id: The name or ID of the table to index
            base_id: The ID of the Airtable base (not required when airtable_url is set)
            table_name_or_id: The name or ID of the table (not required when airtable_url is set)
            airtable_url: An Airtable URL to parse base_id, table_id, and view_id from.
                Overrides base_id, table_name_or_id, and view_id if provided.
            treat_all_non_attachment_fields_as_metadata: If True, all fields except attachments will be treated as metadata.
                If False, only fields with types in DEFAULT_METADATA_FIELD_TYPES will be treated as metadata.
            view_id: Optional ID of a specific view to use
            share_id: Optional ID of a "share" to use for generating record URLs (https://airtable.com/developers/web/api/list-shares)
            share_id: Optional ID of a "share" to use for generating record URLs
            batch_size: Number of records to process in each batch

        Mode is auto-detected: if a specific table is identified (via URL or
        base_id + table_name_or_id), the connector indexes that single table.
        Otherwise, it discovers and indexes all accessible bases and tables.
        """
        # If a URL is provided, parse it to extract base_id, table_id, and view_id
        if airtable_url:
            parsed_base_id, parsed_table_id, parsed_view_id = parse_airtable_url(
                airtable_url
            )
            base_id = parsed_base_id
            table_name_or_id = parsed_table_id
            if parsed_view_id:
                view_id = parsed_view_id

        self.base_id = base_id
        self.table_name_or_id = table_name_or_id
        self.index_all = not (base_id and table_name_or_id)
        self.view_id = view_id
        self.share_id = share_id
        self.batch_size = batch_size
@@ -103,6 +155,33 @@ class AirtableConnector(LoadConnector):
            raise AirtableClientNotSetUpError()
        return self._airtable_client
    def validate_connector_settings(self) -> None:
        if self.index_all:
            try:
                bases = self.airtable_client.bases()
                if not bases:
                    raise ConnectorValidationError(
                        "No bases found. Ensure your API token has access to at least one base."
                    )
            except ConnectorValidationError:
                raise
            except Exception as e:
                raise ConnectorValidationError(f"Failed to list Airtable bases: {e}")
        else:
            if not self.base_id or not self.table_name_or_id:
                raise ConnectorValidationError(
                    "A valid Airtable URL or base_id and table_name_or_id are required "
                    "when not using index_all mode."
                )
            try:
                table = self.airtable_client.table(self.base_id, self.table_name_or_id)
                table.schema()
            except Exception as e:
                raise ConnectorValidationError(
                    f"Failed to access table '{self.table_name_or_id}' "
                    f"in base '{self.base_id}': {e}"
                )
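
    # Construction sketch for the two auto-detected modes described in __init__ above
    # (credentials still need to be supplied via load_credentials before indexing):
    #
    #   # Single-table mode, everything parsed from the URL:
    #   connector = AirtableConnector(airtable_url="https://airtable.com/appXXX/tblYYY/viwZZZ")
    #
    #   # Index-all mode: no base/table given, so every accessible base and table is indexed:
    #   connector = AirtableConnector()
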
@classmethod
|
||||
def _get_record_url(
|
||||
cls,
|
||||
@@ -267,6 +346,7 @@ class AirtableConnector(LoadConnector):
|
||||
field_name: str,
|
||||
field_info: Any,
|
||||
field_type: str,
|
||||
base_id: str,
|
||||
table_id: str,
|
||||
view_id: str | None,
|
||||
record_id: str,
|
||||
@@ -291,7 +371,7 @@ class AirtableConnector(LoadConnector):
|
||||
field_name=field_name,
|
||||
field_info=field_info,
|
||||
field_type=field_type,
|
||||
base_id=self.base_id,
|
||||
base_id=base_id,
|
||||
table_id=table_id,
|
||||
view_id=view_id,
|
||||
record_id=record_id,
|
||||
@@ -326,15 +406,17 @@ class AirtableConnector(LoadConnector):
|
||||
record: RecordDict,
|
||||
table_schema: TableSchema,
|
||||
primary_field_name: str | None,
|
||||
base_id: str,
|
||||
base_name: str | None = None,
|
||||
) -> Document | None:
|
||||
"""Process a single Airtable record into a Document.
|
||||
|
||||
Args:
|
||||
record: The Airtable record to process
|
||||
table_schema: Schema information for the table
|
||||
table_name: Name of the table
|
||||
table_id: ID of the table
|
||||
primary_field_name: Name of the primary field, if any
|
||||
base_id: The ID of the base this record belongs to
|
||||
base_name: The name of the base (used in semantic ID for index_all mode)
|
||||
|
||||
Returns:
|
||||
Document object representing the record
|
||||
@@ -367,6 +449,7 @@ class AirtableConnector(LoadConnector):
|
||||
field_name=field_name,
|
||||
field_info=field_val,
|
||||
field_type=field_type,
|
||||
base_id=base_id,
|
||||
table_id=table_id,
|
||||
view_id=view_id,
|
||||
record_id=record_id,
|
||||
@@ -379,11 +462,26 @@ class AirtableConnector(LoadConnector):
|
||||
logger.warning(f"No sections found for record {record_id}")
|
||||
return None
|
||||
|
||||
semantic_id = (
|
||||
f"{table_name}: {primary_field_value}"
|
||||
if primary_field_value
|
||||
else table_name
|
||||
)
|
||||
# Include base name in semantic ID only in index_all mode
|
||||
if self.index_all and base_name:
|
||||
semantic_id = (
|
||||
f"{base_name} > {table_name}: {primary_field_value}"
|
||||
if primary_field_value
|
||||
else f"{base_name} > {table_name}"
|
||||
)
|
||||
else:
|
||||
semantic_id = (
|
||||
f"{table_name}: {primary_field_value}"
|
||||
if primary_field_value
|
||||
else table_name
|
||||
)
|
||||
|
||||
# Build hierarchy source_path for Craft file system subdirectory structure.
|
||||
# This creates: airtable/{base_name}/{table_name}/record.json
|
||||
source_path: list[str] = []
|
||||
if base_name:
|
||||
source_path.append(base_name)
|
||||
source_path.append(table_name)
|
||||
|
||||
return Document(
|
||||
id=f"airtable__{record_id}",
|
||||
@@ -391,19 +489,39 @@ class AirtableConnector(LoadConnector):
|
||||
source=DocumentSource.AIRTABLE,
|
||||
semantic_identifier=semantic_id,
|
||||
metadata=metadata,
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": source_path,
|
||||
"base_id": base_id,
|
||||
"table_id": table_id,
|
||||
"table_name": table_name,
|
||||
**({"base_name": base_name} if base_name else {}),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Fetch all records from the table.
|
||||
def _resolve_base_name(self, base_id: str) -> str | None:
|
||||
"""Try to resolve a human-readable base name from the API."""
|
||||
try:
|
||||
for base_info in self.airtable_client.bases():
|
||||
if base_info.id == base_id:
|
||||
return base_info.name
|
||||
except Exception:
|
||||
logger.debug(f"Could not resolve base name for {base_id}")
|
||||
return None
|
||||
|
||||
NOTE: Airtable does not support filtering by time updated, so
|
||||
we have to fetch all records every time.
|
||||
"""
|
||||
if not self.airtable_client:
|
||||
raise AirtableClientNotSetUpError()
|
||||
def _index_table(
|
||||
self,
|
||||
base_id: str,
|
||||
table_name_or_id: str,
|
||||
base_name: str | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Index all records from a single table. Yields batches of Documents."""
|
||||
# Resolve base name for hierarchy if not provided
|
||||
if base_name is None:
|
||||
base_name = self._resolve_base_name(base_id)
|
||||
|
||||
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
|
||||
table = self.airtable_client.table(base_id, table_name_or_id)
|
||||
records = table.all()
|
||||
|
||||
table_schema = table.schema()
|
||||
@@ -415,21 +533,25 @@ class AirtableConnector(LoadConnector):
|
||||
primary_field_name = field.name
|
||||
break
|
||||
|
||||
logger.info(f"Starting to process Airtable records for {table.name}.")
|
||||
logger.info(
|
||||
f"Processing {len(records)} records from table "
|
||||
f"'{table_schema.name}' in base '{base_name or base_id}'."
|
||||
)
|
||||
|
||||
if not records:
|
||||
return
|
||||
|
||||
# Process records in parallel batches using ThreadPoolExecutor
|
||||
PARALLEL_BATCH_SIZE = 8
|
||||
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
|
||||
record_documents: list[Document | HierarchyNode] = []
|
||||
|
||||
# Process records in batches
|
||||
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
|
||||
batch_records = records[i : i + PARALLEL_BATCH_SIZE]
|
||||
record_documents = []
|
||||
record_documents: list[Document | HierarchyNode] = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit batch tasks
|
||||
future_to_record: dict[Future, RecordDict] = {}
|
||||
future_to_record: dict[Future[Document | None], RecordDict] = {}
|
||||
for record in batch_records:
|
||||
# Capture the current context so that the thread gets the current tenant ID
|
||||
current_context = contextvars.copy_context()
|
||||
@@ -440,6 +562,8 @@ class AirtableConnector(LoadConnector):
|
||||
record=record,
|
||||
table_schema=table_schema,
|
||||
primary_field_name=primary_field_name,
|
||||
base_id=base_id,
|
||||
base_name=base_name,
|
||||
)
|
||||
] = record
|
||||
|
||||
@@ -454,9 +578,58 @@ class AirtableConnector(LoadConnector):
|
||||
logger.exception(f"Failed to process record {record['id']}")
|
||||
raise e
|
||||
|
||||
yield record_documents
|
||||
record_documents = []
|
||||
if record_documents:
|
||||
yield record_documents
|
||||
|
||||
# Yield any remaining records
|
||||
if record_documents:
|
||||
yield record_documents
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Fetch all records from one or all tables.
|
||||
|
||||
NOTE: Airtable does not support filtering by time updated, so
|
||||
we have to fetch all records every time.
|
||||
"""
|
||||
if not self.airtable_client:
|
||||
raise AirtableClientNotSetUpError()
|
||||
|
||||
if self.index_all:
|
||||
yield from self._load_all()
|
||||
else:
|
||||
yield from self._index_table(
|
||||
base_id=self.base_id,
|
||||
table_name_or_id=self.table_name_or_id,
|
||||
)
|
||||
|
||||
def _load_all(self) -> GenerateDocumentsOutput:
|
||||
"""Discover all bases and tables, then index everything."""
|
||||
bases = self.airtable_client.bases()
|
||||
logger.info(f"Discovered {len(bases)} Airtable base(s).")
|
||||
|
||||
for base_info in bases:
|
||||
base_id = base_info.id
|
||||
base_name = base_info.name
|
||||
logger.info(f"Listing tables for base '{base_name}' ({base_id}).")
|
||||
|
||||
try:
|
||||
base = self.airtable_client.base(base_id)
|
||||
tables = base.tables()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to list tables for base '{base_name}' ({base_id}), skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"Found {len(tables)} table(s) in base '{base_name}'.")
|
||||
|
||||
for table in tables:
|
||||
try:
|
||||
yield from self._index_table(
|
||||
base_id=base_id,
|
||||
table_name_or_id=table.id,
|
||||
base_name=base_name,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to index table '{table.name}' ({table.id}) "
|
||||
f"in base '{base_name}' ({base_id}), skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -171,6 +171,7 @@ def process_onyx_metadata(
|
||||
|
||||
return (
|
||||
OnyxMetadata(
|
||||
document_id=metadata.get("id"),
|
||||
source_type=source_type,
|
||||
link=metadata.get("link"),
|
||||
file_display_name=metadata.get("file_display_name"),
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
@@ -107,7 +108,7 @@ def _process_file(
|
||||
# These metadata items are not settable by the user
|
||||
source_type = onyx_metadata.source_type or DocumentSource.FILE
|
||||
|
||||
doc_id = f"FILE_CONNECTOR__{file_id}"
|
||||
doc_id = onyx_metadata.document_id or f"FILE_CONNECTOR__{file_id}"
|
||||
title = metadata.get("title") or file_display_name
|
||||
|
||||
# 1) If the file itself is an image, handle that scenario quickly
|
||||
@@ -240,29 +241,49 @@ class LocalFileConnector(LoadConnector):
|
||||
self,
|
||||
file_locations: list[Path | str],
|
||||
file_names: list[str] | None = None, # noqa: ARG002
|
||||
zip_metadata: dict[str, Any] | None = None,
|
||||
zip_metadata_file_id: str | None = None,
|
||||
zip_metadata: dict[str, Any] | None = None, # Deprecated, for backwards compat
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.file_locations = [str(loc) for loc in file_locations]
|
||||
self.batch_size = batch_size
|
||||
self.pdf_pass: str | None = None
|
||||
self.zip_metadata = zip_metadata or {}
|
||||
self._zip_metadata_file_id = zip_metadata_file_id
|
||||
self._zip_metadata_deprecated = zip_metadata
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.pdf_pass = credentials.get("pdf_password")
|
||||
|
||||
return None
|
||||
|
||||
def _get_file_metadata(self, file_name: str) -> dict[str, Any]:
|
||||
return self.zip_metadata.get(file_name, {}) or self.zip_metadata.get(
|
||||
os.path.basename(file_name), {}
|
||||
)
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Iterates over each file path, fetches from Postgres, tries to parse text
|
||||
or images, and yields Document batches.
|
||||
"""
|
||||
# Load metadata dict at start (from file store or deprecated inline format)
|
||||
zip_metadata: dict[str, Any] = {}
|
||||
if self._zip_metadata_file_id:
|
||||
try:
|
||||
file_store = get_default_file_store()
|
||||
metadata_io = file_store.read_file(
|
||||
file_id=self._zip_metadata_file_id, mode="b"
|
||||
)
|
||||
metadata_bytes = metadata_io.read()
|
||||
loaded_metadata = json.loads(metadata_bytes)
|
||||
if isinstance(loaded_metadata, list):
|
||||
zip_metadata = {d["filename"]: d for d in loaded_metadata}
|
||||
else:
|
||||
zip_metadata = loaded_metadata
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load metadata from file store: {e}")
|
||||
elif self._zip_metadata_deprecated:
|
||||
logger.warning(
|
||||
"Using deprecated inline zip_metadata dict. "
|
||||
"Re-upload files to use the new file store format."
|
||||
)
|
||||
zip_metadata = self._zip_metadata_deprecated
|
||||
|
||||
documents: list[Document | HierarchyNode] = []
|
||||
|
||||
for file_id in self.file_locations:
|
||||
@@ -273,7 +294,9 @@ class LocalFileConnector(LoadConnector):
|
||||
logger.warning(f"No file record found for '{file_id}' in PG; skipping.")
|
||||
continue
|
||||
|
||||
metadata = self._get_file_metadata(file_record.display_name)
|
||||
metadata = zip_metadata.get(
|
||||
file_record.display_name, {}
|
||||
) or zip_metadata.get(os.path.basename(file_record.display_name), {})
|
||||
file_io = file_store.read_file(file_id=file_id, mode="b")
|
||||
new_docs = _process_file(
|
||||
file_id=file_id,
|
||||
@@ -298,7 +321,6 @@ if __name__ == "__main__":
|
||||
connector = LocalFileConnector(
|
||||
file_locations=[os.environ["TEST_FILE"]],
|
||||
file_names=[os.environ["TEST_FILE"]],
|
||||
zip_metadata={},
|
||||
)
|
||||
connector.load_credentials({"pdf_password": os.environ.get("PDF_PASSWORD")})
|
||||
doc_batches = connector.load_from_state()
|
||||
|
||||
@@ -523,6 +523,22 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
            sleep_after_rate_limit_exception(github_client)
            return self.get_all_repos(github_client, attempt_num + 1)

    def fetch_configured_repos(self) -> list[Repository.Repository]:
        """
        Fetch the configured repositories based on the connector settings.

        Returns:
            list[Repository.Repository]: The configured repositories.
        """
        assert self.github_client is not None  # mypy
        if self.repositories:
            if "," in self.repositories:
                return self.get_github_repos(self.github_client)
            else:
                return [self.get_github_repo(self.github_client)]
        else:
            return self.get_all_repos(self.github_client)

    def _pull_requests_func(
        self, repo: Repository.Repository
    ) -> Callable[[], PaginatedList[PullRequest]]:
@@ -551,17 +567,7 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin

        # First run of the connector, fetch all repos and store in checkpoint
        if checkpoint.cached_repo_ids is None:
            repos = []
            if self.repositories:
                if "," in self.repositories:
                    # Multiple repositories specified
                    repos = self.get_github_repos(self.github_client)
                else:
                    # Single repository (backward compatibility)
                    repos = [self.get_github_repo(self.github_client)]
            else:
                # All repositories
                repos = self.get_all_repos(self.github_client)
            repos = self.fetch_configured_repos()
            if not repos:
                checkpoint.has_more = False
                return checkpoint
@@ -474,8 +474,9 @@ class ConnectorStopSignal(Exception):


class OnyxMetadata(BaseModel):
    # Note that doc_id cannot be overriden here as it may cause issues
    # with the display functionalities in the UI. Ask @chris if clarification is needed.
    # Careful overriding the document_id, may cause visual issues in the UI.
    # Kept here for API based use cases mostly
    document_id: str | None = None
    source_type: DocumentSource | None = None
    link: str | None = None
    file_display_name: str | None = None
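
# Hedged example of the API-oriented document_id override mentioned in the comment
# above; the field values here are purely illustrative:
#
#   onyx_metadata = OnyxMetadata(
#       document_id="my-stable-external-id",  # replaces the default FILE_CONNECTOR__<file_id>
#       source_type=DocumentSource.FILE,
#       link="https://example.com/source-doc",
#       file_display_name="Quarterly Report.pdf",
#   )
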
@@ -79,6 +79,13 @@ SHARED_DOCUMENTS_MAP_REVERSE = {v: k for k, v in SHARED_DOCUMENTS_MAP.items()}
|
||||
|
||||
ASPX_EXTENSION = ".aspx"
|
||||
|
||||
# The office365 library's ClientContext caches the access token from
|
||||
# The office365 library's ClientContext caches the access token from its
|
||||
# first request and never re-invokes the token callback. Microsoft access
|
||||
# tokens live ~60-75 minutes, so we recreate the cached ClientContext every
|
||||
# 30 minutes to let MSAL transparently handle token refresh.
|
||||
_REST_CTX_MAX_AGE_S = 30 * 60
|
||||
|
||||
|
||||
class SiteDescriptor(BaseModel):
|
||||
"""Data class for storing SharePoint site information.
|
||||
@@ -114,11 +121,10 @@ def sleep_and_retry(
|
||||
try:
|
||||
return query_obj.execute_query()
|
||||
except ClientRequestException as e:
|
||||
if (
|
||||
e.response is not None
|
||||
and e.response.status_code in [429, 503]
|
||||
and attempt < max_retries
|
||||
):
|
||||
status = e.response.status_code if e.response is not None else None
|
||||
|
||||
# 429 / 503 — rate limit or transient error. Back off and retry.
|
||||
if status in (429, 503) and attempt < max_retries:
|
||||
logger.warning(
|
||||
f"Rate limit exceeded on {method_name}, attempt {attempt + 1}/{max_retries + 1}, sleeping and retrying"
|
||||
)
|
||||
@@ -131,13 +137,15 @@ def sleep_and_retry(
|
||||
|
||||
logger.info(f"Sleeping for {sleep_time} seconds before retry")
|
||||
time.sleep(sleep_time)
|
||||
else:
|
||||
# Either not a rate limit error, or we've exhausted retries
|
||||
if e.response is not None and e.response.status_code == 429:
|
||||
logger.error(
|
||||
f"Rate limit retry exhausted for {method_name} after {max_retries} attempts"
|
||||
)
|
||||
raise e
|
||||
continue
|
||||
|
||||
# Non-retryable error or retries exhausted — log details and raise.
|
||||
if e.response is not None:
|
||||
logger.error(
|
||||
f"SharePoint request failed for {method_name}: "
|
||||
f"status={status}, "
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
class SharepointConnectorCheckpoint(ConnectorCheckpoint):
|
||||
@@ -713,6 +721,10 @@ class SharepointConnector(
|
||||
self.include_site_pages = include_site_pages
|
||||
self.include_site_documents = include_site_documents
|
||||
self.sp_tenant_domain: str | None = None
|
||||
self._credential_json: dict[str, Any] | None = None
|
||||
self._cached_rest_ctx: ClientContext | None = None
|
||||
self._cached_rest_ctx_url: str | None = None
|
||||
self._cached_rest_ctx_created_at: float = 0.0
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
# Validate that at least one content type is enabled
|
||||
@@ -738,6 +750,44 @@ class SharepointConnector(
|
||||
|
||||
return self._graph_client
|
||||
|
||||
def _create_rest_client_context(self, site_url: str) -> ClientContext:
|
||||
"""Return a ClientContext for SharePoint REST API calls, with caching.
|
||||
|
||||
The office365 library's ClientContext caches the access token from its
|
||||
first request and never re-invokes the token callback. We cache the
|
||||
context and recreate it when the site URL changes or after
|
||||
``_REST_CTX_MAX_AGE_S``. On recreation we also call
|
||||
``load_credentials`` to build a fresh MSAL app with an empty token
|
||||
cache, guaranteeing a brand-new token from Azure AD."""
|
||||
elapsed = time.monotonic() - self._cached_rest_ctx_created_at
|
||||
if (
|
||||
self._cached_rest_ctx is not None
|
||||
and self._cached_rest_ctx_url == site_url
|
||||
and elapsed <= _REST_CTX_MAX_AGE_S
|
||||
):
|
||||
return self._cached_rest_ctx
|
||||
|
||||
if self._credential_json:
|
||||
logger.info(
|
||||
"Rebuilding SharePoint REST client context "
|
||||
"(elapsed=%.0fs, site_changed=%s)",
|
||||
elapsed,
|
||||
self._cached_rest_ctx_url != site_url,
|
||||
)
|
||||
self.load_credentials(self._credential_json)
|
||||
|
||||
if not self.msal_app or not self.sp_tenant_domain:
|
||||
raise RuntimeError("MSAL app or tenant domain is not set")
|
||||
|
||||
msal_app = self.msal_app
|
||||
sp_tenant_domain = self.sp_tenant_domain
|
||||
self._cached_rest_ctx = ClientContext(site_url).with_access_token(
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
|
||||
)
|
||||
self._cached_rest_ctx_url = site_url
|
||||
self._cached_rest_ctx_created_at = time.monotonic()
|
||||
return self._cached_rest_ctx
|
||||
|
||||
@staticmethod
|
||||
def _strip_share_link_tokens(path: str) -> list[str]:
|
||||
# Share links often include a token prefix like /:f:/r/ or /:x:/r/.
|
||||
@@ -1177,21 +1227,6 @@ class SharepointConnector(
|
||||
# goes over all urls, converts them into SlimDocument objects and then yields them in batches
|
||||
doc_batch: list[SlimDocument | HierarchyNode] = []
|
||||
for site_descriptor in site_descriptors:
|
||||
ctx: ClientContext | None = None
|
||||
|
||||
if self.msal_app and self.sp_tenant_domain:
|
||||
msal_app = self.msal_app
|
||||
sp_tenant_domain = self.sp_tenant_domain
|
||||
ctx = ClientContext(site_descriptor.url).with_access_token(
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("MSAL app or tenant domain is not set")
|
||||
|
||||
if ctx is None:
|
||||
logger.warning("ClientContext is not set, skipping permissions")
|
||||
continue
|
||||
|
||||
site_url = site_descriptor.url
|
||||
|
||||
# Yield site hierarchy node using helper
|
||||
@@ -1230,6 +1265,7 @@ class SharepointConnector(
|
||||
|
||||
try:
|
||||
logger.debug(f"Processing: {driveitem.web_url}")
|
||||
ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
doc_batch.append(
|
||||
_convert_driveitem_to_slim_document(
|
||||
driveitem, drive_name, ctx, self.graph_client
|
||||
@@ -1249,6 +1285,7 @@ class SharepointConnector(
|
||||
logger.debug(
|
||||
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
|
||||
)
|
||||
ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
doc_batch.append(
|
||||
_convert_sitepage_to_slim_document(
|
||||
site_page, ctx, self.graph_client
|
||||
@@ -1260,6 +1297,7 @@ class SharepointConnector(
|
||||
yield doc_batch
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._credential_json = credentials
|
||||
auth_method = credentials.get(
|
||||
"authentication_method", SharepointAuthMethod.CLIENT_SECRET.value
|
||||
)
|
||||
@@ -1676,17 +1714,6 @@ class SharepointConnector(
|
||||
)
|
||||
logger.debug(f"Time range: {start_dt} to {end_dt}")
|
||||
|
||||
ctx: ClientContext | None = None
|
||||
if include_permissions:
|
||||
if self.msal_app and self.sp_tenant_domain:
|
||||
msal_app = self.msal_app
|
||||
sp_tenant_domain = self.sp_tenant_domain
|
||||
ctx = ClientContext(site_descriptor.url).with_access_token(
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("MSAL app or tenant domain is not set")
|
||||
|
||||
# At this point current_drive_name should be set from popleft()
|
||||
current_drive_name = checkpoint.current_drive_name
|
||||
if current_drive_name is None:
|
||||
@@ -1781,6 +1808,10 @@ class SharepointConnector(
|
||||
)
|
||||
|
||||
try:
|
||||
ctx: ClientContext | None = None
|
||||
if include_permissions:
|
||||
ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
|
||||
doc = _convert_driveitem_to_document_with_permissions(
|
||||
driveitem,
|
||||
current_drive_name,
|
||||
@@ -1846,20 +1877,13 @@ class SharepointConnector(
|
||||
site_pages = self._fetch_site_pages(
|
||||
site_descriptor, start=start_dt, end=end_dt
|
||||
)
|
||||
client_ctx: ClientContext | None = None
|
||||
if include_permissions:
|
||||
if self.msal_app and self.sp_tenant_domain:
|
||||
msal_app = self.msal_app
|
||||
sp_tenant_domain = self.sp_tenant_domain
|
||||
client_ctx = ClientContext(site_descriptor.url).with_access_token(
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("MSAL app or tenant domain is not set")
|
||||
for site_page in site_pages:
|
||||
logger.debug(
|
||||
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
|
||||
)
|
||||
client_ctx: ClientContext | None = None
|
||||
if include_permissions:
|
||||
client_ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
yield (
|
||||
_convert_sitepage_to_document(
|
||||
site_page,
|
||||
|
||||
@@ -308,6 +308,18 @@ def default_msg_filter(message: MessageType) -> SlackMessageFilterReason | None:
    return None


def _bot_inclusive_msg_filter(
    message: MessageType,
) -> SlackMessageFilterReason | None:
    """Like default_msg_filter but allows bot/app messages through.
    Only filters out disallowed subtypes (channel_join, channel_leave, etc.).
    """
    if message.get("subtype", "") in _DISALLOWED_MSG_SUBTYPES:
        return SlackMessageFilterReason.DISALLOWED

    return None


def filter_channels(
    all_channels: list[ChannelType],
    channels_to_connect: list[str] | None,
@@ -654,12 +666,18 @@ class SlackConnector(
        # if specified, will treat the specified channel strings as
        # regexes, and will only index channels that fully match the regexes
        channel_regex_enabled: bool = False,
        # if True, messages from bots/apps will be indexed instead of filtered out
        include_bot_messages: bool = False,
        batch_size: int = INDEX_BATCH_SIZE,
        num_threads: int = SLACK_NUM_THREADS,
        use_redis: bool = True,
    ) -> None:
        self.channels = channels
        self.channel_regex_enabled = channel_regex_enabled
        self.include_bot_messages = include_bot_messages
        self.msg_filter_func = (
            _bot_inclusive_msg_filter if include_bot_messages else default_msg_filter
        )
        self.batch_size = batch_size
        self.num_threads = num_threads
        self.client: WebClient | None = None
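
# Hedged construction sketch for the new include_bot_messages flag; parameter names
# follow the __init__ shown above, and the channels value is illustrative:
#
#   connector = SlackConnector(
#       channels=["engineering", "support"],
#       include_bot_messages=True,  # keep bot/app posts instead of filtering them out
#   )
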
@@ -839,6 +857,7 @@ class SlackConnector(
|
||||
client=self.client,
|
||||
channels=self.channels,
|
||||
channel_name_regex_enabled=self.channel_regex_enabled,
|
||||
msg_filter_func=self.msg_filter_func,
|
||||
callback=callback,
|
||||
workspace_url=self._workspace_url,
|
||||
)
|
||||
@@ -926,6 +945,7 @@ class SlackConnector(
|
||||
|
||||
try:
|
||||
num_bot_filtered_messages = 0
|
||||
num_other_filtered_messages = 0
|
||||
|
||||
oldest = str(start) if start else None
|
||||
latest = str(end)
|
||||
@@ -984,6 +1004,7 @@ class SlackConnector(
|
||||
user_cache=self.user_cache,
|
||||
seen_thread_ts=seen_thread_ts,
|
||||
channel_access=checkpoint.current_channel_access,
|
||||
msg_filter_func=self.msg_filter_func,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1003,7 +1024,13 @@ class SlackConnector(
|
||||
|
||||
seen_thread_ts.add(thread_or_message_ts)
|
||||
elif processed_slack_message.filter_reason:
|
||||
num_bot_filtered_messages += 1
|
||||
if (
|
||||
processed_slack_message.filter_reason
|
||||
== SlackMessageFilterReason.BOT
|
||||
):
|
||||
num_bot_filtered_messages += 1
|
||||
else:
|
||||
num_other_filtered_messages += 1
|
||||
elif failure:
|
||||
yield failure
|
||||
|
||||
@@ -1023,10 +1050,14 @@ class SlackConnector(
|
||||
range_total = 1
|
||||
range_percent_complete = range_complete / range_total * 100.0
|
||||
|
||||
logger.info(
|
||||
num_filtered = num_bot_filtered_messages + num_other_filtered_messages
|
||||
log_func = logger.warning if num_bot_filtered_messages > 0 else logger.info
|
||||
log_func(
|
||||
f"Message processing stats: "
|
||||
f"batch_len={len(message_batch)} "
|
||||
f"batch_yielded={num_threads_processed} "
|
||||
f"filtered={num_filtered} "
|
||||
f"(bot={num_bot_filtered_messages} other={num_other_filtered_messages}) "
|
||||
f"total_threads_seen={len(seen_thread_ts)}"
|
||||
)
|
||||
|
||||
@@ -1040,7 +1071,8 @@ class SlackConnector(
|
||||
checkpoint.seen_thread_ts = list(seen_thread_ts)
|
||||
checkpoint.channel_completion_map[channel["id"]] = new_oldest
|
||||
|
||||
# bypass channels where the first set of messages seen are all bots
|
||||
# bypass channels where the first set of messages seen are all
|
||||
# filtered (bots + disallowed subtypes like channel_join)
|
||||
# check at least MIN_BOT_MESSAGE_THRESHOLD messages are in the batch
|
||||
# we shouldn't skip based on a small sampling of messages
|
||||
if (
|
||||
@@ -1048,7 +1080,7 @@ class SlackConnector(
|
||||
and len(message_batch) > SlackConnector.BOT_CHANNEL_MIN_BATCH_SIZE
|
||||
):
|
||||
if (
|
||||
num_bot_filtered_messages
|
||||
num_filtered
|
||||
> SlackConnector.BOT_CHANNEL_PERCENTAGE_THRESHOLD
|
||||
* len(message_batch)
|
||||
):
|
||||
|
||||
@@ -20,7 +20,7 @@ from onyx.onyxbot.slack.models import ChannelType
|
||||
from onyx.prompts.federated_search import SLACK_DATE_EXTRACTION_PROMPT
|
||||
from onyx.prompts.federated_search import SLACK_QUERY_EXPANSION_PROMPT
|
||||
from onyx.tracing.llm_utils import llm_generation_span
|
||||
from onyx.tracing.llm_utils import record_llm_span_output
|
||||
from onyx.tracing.llm_utils import record_llm_response
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -201,8 +201,8 @@ def extract_date_range_from_query(
|
||||
llm=llm, flow="slack_date_extraction", input_messages=[prompt_msg]
|
||||
) as span_generation:
|
||||
llm_response = llm.invoke(prompt_msg)
|
||||
record_llm_response(span_generation, llm_response)
|
||||
response = llm_response_to_string(llm_response)
|
||||
record_llm_span_output(span_generation, response, llm_response.usage)
|
||||
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
@@ -606,8 +606,8 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
llm=llm, flow="slack_query_expansion", input_messages=[prompt]
|
||||
) as span_generation:
|
||||
llm_response = llm.invoke(prompt)
|
||||
record_llm_response(span_generation, llm_response)
|
||||
response = llm_response_to_string(llm_response)
|
||||
record_llm_span_output(span_generation, response, llm_response.usage)
|
||||
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import SearchSettings
|
||||
@@ -97,21 +96,6 @@ class IndexFilters(BaseFilters, UserFileFilters, AssistantKnowledgeFilters):
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
class ChunkContext(BaseModel):
|
||||
# If not specified (None), picked up from Persona settings if there is space
|
||||
# if specified (even if 0), it always uses the specified number of chunks above and below
|
||||
chunks_above: int | None = None
|
||||
chunks_below: int | None = None
|
||||
full_doc: bool = False
|
||||
|
||||
@field_validator("chunks_above", "chunks_below")
|
||||
@classmethod
|
||||
def check_non_negative(cls, value: int, field: Any) -> int:
|
||||
if value is not None and value < 0:
|
||||
raise ValueError(f"{field.name} must be non-negative")
|
||||
return value
|
||||
|
||||
|
||||
class BasicChunkRequest(BaseModel):
|
||||
query: str
|
||||
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
# Note, this file and all SavedSearchSettings things are not being used in live code paths (at least at the time of this comment)
|
||||
# Kept around as it may be useful in the future
|
||||
from typing import cast
|
||||
|
||||
from onyx.configs.constants import KV_SEARCH_SETTINGS
|
||||
from onyx.context.search.models import SavedSearchSettings
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_kv_search_settings() -> SavedSearchSettings | None:
|
||||
"""Get all user configured search settings which affect the search pipeline
|
||||
Note: KV store is used in this case since there is no need to rollback the value or any need to audit past values
|
||||
|
||||
Note: for now we can't cache this value because if the API server is scaled, the cache could be out of sync
|
||||
if the value is updated by another process/instance of the API server. If this reads from an in memory cache like
|
||||
reddis then it will be ok. Until then this has some performance implications (though minor)
|
||||
"""
|
||||
kv_store = get_kv_store()
|
||||
try:
|
||||
return SavedSearchSettings(**cast(dict, kv_store.load(KV_SEARCH_SETTINGS)))
|
||||
except KvKeyNotFoundError:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading search settings: {e}")
|
||||
# Wiping it so that next server startup, it can load the defaults
|
||||
# or the user can set it via the API/UI
|
||||
kv_store.delete(KV_SEARCH_SETTINGS)
|
||||
return None
|
||||
@@ -19,7 +19,6 @@ from sqlalchemy.exc import MultipleResultsFound
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import DocumentRelevance
|
||||
from onyx.configs.chat_configs import HARD_DELETE_CHATS
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import InferenceSection
|
||||
@@ -672,27 +671,6 @@ def set_as_latest_chat_message(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_search_docs_table_with_relevance(
|
||||
db_session: Session,
|
||||
reference_db_search_docs: list[DBSearchDoc],
|
||||
relevance_summary: DocumentRelevance,
|
||||
) -> None:
|
||||
for search_doc in reference_db_search_docs:
|
||||
relevance_data = relevance_summary.relevance_summaries.get(
|
||||
search_doc.document_id
|
||||
)
|
||||
if relevance_data is not None:
|
||||
db_session.execute(
|
||||
update(DBSearchDoc)
|
||||
.where(DBSearchDoc.id == search_doc.id)
|
||||
.values(
|
||||
is_relevant=relevance_data.relevant,
|
||||
relevance_explanation=relevance_data.content,
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _sanitize_for_postgres(value: str) -> str:
|
||||
"""Remove NUL (0x00) characters from strings as PostgreSQL doesn't allow them."""
|
||||
sanitized = value.replace("\x00", "")
|
||||
|
||||
@@ -6,6 +6,8 @@ from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from uuid import UUID

from sqlalchemy import and_
from sqlalchemy import delete
@@ -226,6 +228,50 @@ def get_documents_by_ids(
    return list(documents)


def get_documents_by_source(
    db_session: Session,
    source: DocumentSource,
    creator_id: UUID | None = None,
) -> list[DbDocument]:
    """Get all documents associated with a specific source type.

    This queries through the connector relationship to find all documents
    that were indexed by connectors of the given source type.

    Args:
        db_session: Database session
        source: The document source type to filter by
        creator_id: If provided, only return documents from connectors
            created by this user. Filters via ConnectorCredentialPair.
    """
    stmt = (
        select(DbDocument)
        .join(
            DocumentByConnectorCredentialPair,
            DbDocument.id == DocumentByConnectorCredentialPair.id,
        )
        .join(
            ConnectorCredentialPair,
            and_(
                DocumentByConnectorCredentialPair.connector_id
                == ConnectorCredentialPair.connector_id,
                DocumentByConnectorCredentialPair.credential_id
                == ConnectorCredentialPair.credential_id,
            ),
        )
        .join(
            Connector,
            ConnectorCredentialPair.connector_id == Connector.id,
        )
        .where(Connector.source == source)
    )
    if creator_id is not None:
        stmt = stmt.where(ConnectorCredentialPair.creator_id == creator_id)
    stmt = stmt.distinct()
    documents = db_session.execute(stmt).scalars().all()
    return list(documents)
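
# Usage sketch, assuming an already-open SQLAlchemy Session bound to the Onyx database:
#
#   docs = get_documents_by_source(
#       db_session=db_session,
#       source=DocumentSource.AIRTABLE,
#       creator_id=None,  # or a specific user's UUID to scope to their connectors
#   )
#   doc_ids = [doc.id for doc in docs]
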
def _apply_last_updated_cursor_filter(
|
||||
stmt: Select,
|
||||
cursor_last_modified: datetime | None,
|
||||
@@ -1527,3 +1573,40 @@ def get_document_kg_entities_and_relationships(
|
||||
def get_num_chunks_for_document(db_session: Session, document_id: str) -> int:
|
||||
stmt = select(DbDocument.chunk_count).where(DbDocument.id == document_id)
|
||||
return db_session.execute(stmt).scalar_one_or_none() or 0
|
||||
|
||||
|
||||
def update_document_metadata__no_commit(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
doc_metadata: dict[str, Any],
|
||||
) -> None:
|
||||
"""Update the doc_metadata field for a document.
|
||||
|
||||
Note: Does not commit. Caller is responsible for committing.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
document_id: The ID of the document to update
|
||||
doc_metadata: The new metadata dictionary to set
|
||||
"""
|
||||
stmt = (
|
||||
update(DbDocument)
|
||||
.where(DbDocument.id == document_id)
|
||||
.values(doc_metadata=doc_metadata)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
|
||||
|
||||
def delete_document_by_id__no_commit(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
) -> None:
|
||||
"""Delete a single document and its connector credential pair relationships.
|
||||
|
||||
Note: Does not commit. Caller is responsible for committing.
|
||||
|
||||
This uses delete_documents_complete__no_commit which handles
|
||||
all foreign key relationships (KG entities, relationships, chunk stats,
|
||||
cc pair associations, feedback, tags).
|
||||
"""
|
||||
delete_documents_complete__no_commit(db_session, [document_id])
|
||||
|
||||
@@ -60,7 +60,8 @@ class ProcessingMode(str, PyEnum):
|
||||
"""Determines how documents are processed after fetching."""
|
||||
|
||||
REGULAR = "REGULAR" # Full pipeline: chunk → embed → Vespa
|
||||
FILE_SYSTEM = "FILE_SYSTEM" # Write to file system only
|
||||
FILE_SYSTEM = "FILE_SYSTEM" # Write to file system only (JSON documents)
|
||||
RAW_BINARY = "RAW_BINARY" # Write raw binary to S3 (no text extraction)
|
||||
|
||||
|
||||
class SyncType(str, PyEnum):
|
||||
@@ -197,6 +198,12 @@ class ThemePreference(str, PyEnum):
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class DefaultAppMode(str, PyEnum):
|
||||
AUTO = "AUTO"
|
||||
CHAT = "CHAT"
|
||||
SEARCH = "SEARCH"
|
||||
|
||||
|
||||
class SwitchoverType(str, PyEnum):
|
||||
REINDEX = "reindex"
|
||||
ACTIVE_ONLY = "active_only"
|
||||
@@ -289,4 +296,4 @@ class HierarchyNodeType(str, PyEnum):
|
||||
class LLMModelFlowType(str, PyEnum):
|
||||
CHAT = "chat"
|
||||
VISION = "vision"
|
||||
EMBEDDINGS = "embeddings"
|
||||
CONTEXTUAL_RAG = "contextual_rag"
|
||||
|
||||
@@ -231,10 +231,11 @@ def upsert_llm_provider(
|
||||
# Set to None if the dict is empty after filtering
|
||||
custom_config = custom_config or None
|
||||
|
||||
api_base = llm_provider_upsert_request.api_base or None
|
||||
existing_llm_provider.provider = llm_provider_upsert_request.provider
|
||||
# EncryptedString accepts str for writes, returns SensitiveValue for reads
|
||||
existing_llm_provider.api_key = llm_provider_upsert_request.api_key # type: ignore[assignment]
|
||||
existing_llm_provider.api_base = llm_provider_upsert_request.api_base
|
||||
existing_llm_provider.api_base = api_base
|
||||
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
|
||||
existing_llm_provider.custom_config = custom_config
|
||||
# TODO: Remove default model name on api change
|
||||
@@ -508,6 +509,12 @@ def fetch_default_vision_model(db_session: Session) -> ModelConfiguration | None
|
||||
return fetch_default_model(db_session, LLMModelFlowType.VISION)
|
||||
|
||||
|
||||
def fetch_default_contextual_rag_model(
|
||||
db_session: Session,
|
||||
) -> ModelConfiguration | None:
|
||||
return fetch_default_model(db_session, LLMModelFlowType.CONTEXTUAL_RAG)
|
||||
|
||||
|
||||
def fetch_default_model(
|
||||
db_session: Session,
|
||||
flow_type: LLMModelFlowType,
|
||||
@@ -645,6 +652,73 @@ def update_default_vision_provider(
|
||||
)
|
||||
|
||||
|
||||
def update_no_default_contextual_rag_provider(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
db_session.execute(
|
||||
update(LLMModelFlow)
|
||||
.where(
|
||||
LLMModelFlow.llm_model_flow_type == LLMModelFlowType.CONTEXTUAL_RAG,
|
||||
LLMModelFlow.is_default == True, # noqa: E712
|
||||
)
|
||||
.values(is_default=False)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_default_contextual_model(
|
||||
db_session: Session,
|
||||
enable_contextual_rag: bool,
|
||||
contextual_rag_llm_provider: str | None,
|
||||
contextual_rag_llm_name: str | None,
|
||||
) -> None:
|
||||
"""Sets or clears the default contextual RAG model.
|
||||
|
||||
Should be called whenever the PRESENT search settings change
|
||||
(e.g. inline update or FUTURE → PRESENT swap).
|
||||
"""
|
||||
if (
|
||||
not enable_contextual_rag
|
||||
or not contextual_rag_llm_name
|
||||
or not contextual_rag_llm_provider
|
||||
):
|
||||
update_no_default_contextual_rag_provider(db_session=db_session)
|
||||
return
|
||||
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=contextual_rag_llm_provider, db_session=db_session
|
||||
)
|
||||
if not provider:
|
||||
raise ValueError(f"Provider '{contextual_rag_llm_provider}' not found")
|
||||
|
||||
model_config = next(
|
||||
(
|
||||
mc
|
||||
for mc in provider.model_configurations
|
||||
if mc.name == contextual_rag_llm_name
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not model_config:
|
||||
raise ValueError(
|
||||
f"Model '{contextual_rag_llm_name}' not found for provider '{contextual_rag_llm_provider}'"
|
||||
)
|
||||
|
||||
add_model_to_flow(
|
||||
db_session=db_session,
|
||||
model_configuration_id=model_config.id,
|
||||
flow_type=LLMModelFlowType.CONTEXTUAL_RAG,
|
||||
)
|
||||
_update_default_model(
|
||||
db_session=db_session,
|
||||
provider_id=provider.id,
|
||||
model=contextual_rag_llm_name,
|
||||
flow_type=LLMModelFlowType.CONTEXTUAL_RAG,
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
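A minimal usage sketch of the helper above; the provider and model names here are placeholders (not values from this change), and db_session is assumed to be an active SQLAlchemy Session:

# Hypothetical example: names below are placeholders, not values from this change.
try:
    update_default_contextual_model(
        db_session=db_session,
        enable_contextual_rag=True,
        contextual_rag_llm_provider="my-openai-provider",
        contextual_rag_llm_name="gpt-4o-mini",
    )
except ValueError:
    # Provider or model not found; fall back to no contextual RAG default.
    update_no_default_contextual_rag_provider(db_session=db_session)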
def fetch_auto_mode_providers(db_session: Session) -> list[LLMProviderModel]:
|
||||
"""Fetch all LLM providers that are in Auto mode."""
|
||||
query = (
|
||||
@@ -759,9 +833,18 @@ def create_new_flow_mapping__no_commit(
|
||||
)
|
||||
|
||||
flow = result.scalar()
|
||||
if not flow:
|
||||
# Row already exists — fetch it
|
||||
flow = db_session.scalar(
|
||||
select(LLMModelFlow).where(
|
||||
LLMModelFlow.model_configuration_id == model_configuration_id,
|
||||
LLMModelFlow.llm_model_flow_type == flow_type,
|
||||
)
|
||||
)
|
||||
if not flow:
|
||||
raise ValueError(
|
||||
f"Failed to create new flow mapping for model_configuration_id={model_configuration_id} and flow_type={flow_type}"
|
||||
f"Failed to create or find flow mapping for "
|
||||
f"model_configuration_id={model_configuration_id} and flow_type={flow_type}"
|
||||
)
|
||||
|
||||
return flow
|
||||
@@ -899,3 +982,18 @@ def _update_default_model(
|
||||
model_config.is_visible = True
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def add_model_to_flow(
|
||||
db_session: Session,
|
||||
model_configuration_id: int,
|
||||
flow_type: LLMModelFlowType,
|
||||
) -> None:
|
||||
# Function does nothing on conflict
|
||||
create_new_flow_mapping__no_commit(
|
||||
db_session=db_session,
|
||||
model_configuration_id=model_configuration_id,
|
||||
flow_type=flow_type,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy import select
|
||||
@@ -5,10 +7,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
from onyx.prompts.user_info import BASIC_INFORMATION_PROMPT
|
||||
from onyx.prompts.user_info import USER_MEMORIES_PROMPT
|
||||
from onyx.prompts.user_info import USER_PREFERENCES_PROMPT
|
||||
from onyx.prompts.user_info import USER_ROLE_PROMPT
|
||||
|
||||
MAX_MEMORIES_PER_USER = 10
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
@@ -27,10 +27,20 @@ class UserInfo(BaseModel):
|
||||
class UserMemoryContext(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
user_id: UUID | None = None
|
||||
user_info: UserInfo
|
||||
user_preferences: str | None = None
|
||||
memories: tuple[str, ...] = ()
|
||||
|
||||
def without_memories(self) -> "UserMemoryContext":
|
||||
"""Return a copy with memories cleared but user info/preferences intact."""
|
||||
return UserMemoryContext(
|
||||
user_id=self.user_id,
|
||||
user_info=self.user_info,
|
||||
user_preferences=self.user_preferences,
|
||||
memories=(),
|
||||
)
|
||||
|
||||
def as_formatted_list(self) -> list[str]:
|
||||
"""Returns combined list of user info, preferences, and memories."""
|
||||
result = []
|
||||
@@ -45,50 +55,8 @@ class UserMemoryContext(BaseModel):
|
||||
result.extend(self.memories)
|
||||
return result
|
||||
|
||||
def as_formatted_prompt(self) -> str:
|
||||
"""Returns structured prompt sections for the system prompt."""
|
||||
has_basic_info = (
|
||||
self.user_info.name or self.user_info.email or self.user_info.role
|
||||
)
|
||||
if not has_basic_info and not self.user_preferences and not self.memories:
|
||||
return ""
|
||||
|
||||
sections: list[str] = []
|
||||
|
||||
if has_basic_info:
|
||||
role_line = (
|
||||
USER_ROLE_PROMPT.format(user_role=self.user_info.role).strip()
|
||||
if self.user_info.role
|
||||
else ""
|
||||
)
|
||||
if role_line:
|
||||
role_line = "\n" + role_line
|
||||
sections.append(
|
||||
BASIC_INFORMATION_PROMPT.format(
|
||||
user_name=self.user_info.name or "",
|
||||
user_email=self.user_info.email or "",
|
||||
user_role=role_line,
|
||||
)
|
||||
)
|
||||
|
||||
if self.user_preferences:
|
||||
sections.append(
|
||||
USER_PREFERENCES_PROMPT.format(user_preferences=self.user_preferences)
|
||||
)
|
||||
|
||||
if self.memories:
|
||||
formatted_memories = "\n".join(f"- {memory}" for memory in self.memories)
|
||||
sections.append(
|
||||
USER_MEMORIES_PROMPT.format(user_memories=formatted_memories)
|
||||
)
|
||||
|
||||
return "".join(sections)
|
||||
|
||||
|
||||
def get_memories(user: User, db_session: Session) -> UserMemoryContext:
|
||||
if not user.use_memories:
|
||||
return UserMemoryContext(user_info=UserInfo())
|
||||
|
||||
user_info = UserInfo(
|
||||
name=user.personal_name,
|
||||
role=user.personal_role,
|
||||
@@ -105,7 +73,57 @@ def get_memories(user: User, db_session: Session) -> UserMemoryContext:
|
||||
memories = tuple(memory.memory_text for memory in memory_rows if memory.memory_text)
|
||||
|
||||
return UserMemoryContext(
|
||||
user_id=user.id,
|
||||
user_info=user_info,
|
||||
user_preferences=user_preferences,
|
||||
memories=memories,
|
||||
)
|
||||
|
||||
|
||||
def add_memory(
|
||||
user_id: UUID,
|
||||
memory_text: str,
|
||||
db_session: Session,
|
||||
) -> Memory:
|
||||
"""Insert a new Memory row for the given user.
|
||||
|
||||
If the user already has MAX_MEMORIES_PER_USER memories, the oldest
|
||||
one (lowest id) is deleted before inserting the new one.
|
||||
"""
|
||||
existing = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if len(existing) >= MAX_MEMORIES_PER_USER:
|
||||
db_session.delete(existing[0])
|
||||
|
||||
memory = Memory(
|
||||
user_id=user_id,
|
||||
memory_text=memory_text,
|
||||
)
|
||||
db_session.add(memory)
|
||||
db_session.commit()
|
||||
return memory
|
||||
|
||||
|
||||
def update_memory_at_index(
|
||||
user_id: UUID,
|
||||
index: int,
|
||||
new_text: str,
|
||||
db_session: Session,
|
||||
) -> Memory | None:
|
||||
"""Update the memory at the given 0-based index (ordered by id ASC, matching get_memories()).
|
||||
|
||||
Returns the updated Memory row, or None if the index is out of range.
|
||||
"""
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
|
||||
if index < 0 or index >= len(memory_rows):
|
||||
return None
|
||||
|
||||
memory = memory_rows[index]
|
||||
memory.memory_text = new_text
|
||||
db_session.commit()
|
||||
return memory
|
||||
|
||||
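A short usage sketch of the two memory helpers above; user_id and db_session are assumed to come from the caller:

# Hypothetical example: user_id and db_session are assumed to be provided by the caller.
add_memory(user_id=user_id, memory_text="Prefers metric units", db_session=db_session)

# Indices are 0-based and ordered by id ASC, matching get_memories().
updated = update_memory_at_index(
    user_id=user_id,
    index=0,
    new_text="Prefers metric units and 24h time",
    db_session=db_session,
)
if updated is None:
    # Index was out of range for this user.
    pass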
@@ -75,6 +75,7 @@ from onyx.db.enums import (
|
||||
MCPServerStatus,
|
||||
LLMModelFlowType,
|
||||
ThemePreference,
|
||||
DefaultAppMode,
|
||||
SwitchoverType,
|
||||
)
|
||||
from onyx.configs.constants import NotificationType
|
||||
@@ -247,10 +248,18 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
default=None,
|
||||
)
|
||||
chat_background: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
default_app_mode: Mapped[DefaultAppMode] = mapped_column(
|
||||
Enum(DefaultAppMode, native_enum=False),
|
||||
nullable=False,
|
||||
default=DefaultAppMode.CHAT,
|
||||
)
|
||||
# personalization fields are exposed via the chat user settings "Personalization" tab
|
||||
personal_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
personal_role: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
use_memories: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
enable_memory_tool: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=True
|
||||
)
|
||||
user_preferences: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
chosen_assistants: Mapped[list[int] | None] = mapped_column(
|
||||
@@ -312,6 +321,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="selectin",
|
||||
order_by="desc(Memory.id)",
|
||||
)
|
||||
oauth_user_tokens: Mapped[list["OAuthUserToken"]] = relationship(
|
||||
"OAuthUserToken",
|
||||
@@ -1027,6 +1037,31 @@ class OpenSearchTenantMigrationRecord(Base):
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
# Opaque continuation token from Vespa's Visit API.
|
||||
# NULL means "not started" or "visit completed".
|
||||
vespa_visit_continuation_token: Mapped[str | None] = mapped_column(
|
||||
Text, nullable=True
|
||||
)
|
||||
total_chunks_migrated: Mapped[int] = mapped_column(
|
||||
Integer, default=0, nullable=False
|
||||
)
|
||||
total_chunks_errored: Mapped[int] = mapped_column(
|
||||
Integer, default=0, nullable=False
|
||||
)
|
||||
total_chunks_in_vespa: Mapped[int] = mapped_column(
|
||||
Integer, default=0, nullable=False
|
||||
)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
migration_completed_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
enable_opensearch_retrieval: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
|
||||
|
||||
class KGEntityType(Base):
|
||||
@@ -4842,3 +4877,90 @@ class BuildMessage(Base):
|
||||
"ix_build_message_session_turn", "session_id", "turn_index", "created_at"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
SCIM 2.0 Provisioning Models (Enterprise Edition only)
|
||||
Used for automated user/group provisioning from identity providers (Okta, Azure AD).
|
||||
"""
|
||||
|
||||
|
||||
class ScimToken(Base):
|
||||
"""Bearer tokens for IdP SCIM authentication."""
|
||||
|
||||
__tablename__ = "scim_token"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
hashed_token: Mapped[str] = mapped_column(
|
||||
String(64), unique=True, nullable=False
|
||||
) # SHA256 = 64 hex chars
|
||||
token_display: Mapped[str] = mapped_column(
|
||||
String, nullable=False
|
||||
) # Last 4 chars for UI identification
|
||||
|
||||
created_by_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean, server_default=text("true"), nullable=False
|
||||
)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
last_used_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
created_by: Mapped[User] = relationship("User", foreign_keys=[created_by_id])
|
||||
|
||||
|
||||
class ScimUserMapping(Base):
|
||||
"""Maps SCIM externalId from the IdP to an Onyx User."""
|
||||
|
||||
__tablename__ = "scim_user_mapping"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||
user_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
|
||||
)
|
||||
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship("User", foreign_keys=[user_id])
|
||||
|
||||
|
||||
class ScimGroupMapping(Base):
|
||||
"""Maps SCIM externalId from the IdP to an Onyx UserGroup."""
|
||||
|
||||
__tablename__ = "scim_group_mapping"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
external_id: Mapped[str] = mapped_column(String, unique=True, index=True)
|
||||
user_group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_group.id", ondelete="CASCADE"), unique=True, nullable=False
|
||||
)
|
||||
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
user_group: Mapped[UserGroup] = relationship(
|
||||
"UserGroup", foreign_keys=[user_group_id]
|
||||
)
|
||||
|
||||
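The ScimToken columns above imply a hash-and-display scheme (SHA-256 hex digest plus the last four characters for UI identification). A minimal sketch of how such a token might be generated for storage; the helper name and the raw-token format are assumptions, not part of this change:

import hashlib
import secrets


def generate_scim_token() -> tuple[str, str, str]:
    """Hypothetical helper: returns (raw_token, hashed_token, token_display).

    Only the raw token is shown to the admin once; the DB would store the
    SHA-256 hex digest (64 chars) and the last 4 characters for identification.
    """
    raw_token = secrets.token_urlsafe(32)
    hashed_token = hashlib.sha256(raw_token.encode("utf-8")).hexdigest()
    token_display = raw_token[-4:]
    return raw_token, hashed_token, token_display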
@@ -4,6 +4,9 @@ This module provides functions to track the progress of migrating documents
|
||||
from Vespa to OpenSearch.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
@@ -12,10 +15,14 @@ from sqlalchemy.orm import Session
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
|
||||
)
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
|
||||
from onyx.db.enums import OpenSearchDocumentMigrationStatus
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import OpenSearchDocumentMigrationRecord
|
||||
from onyx.db.models import OpenSearchTenantMigrationRecord
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -176,7 +183,7 @@ def try_insert_opensearch_tenant_migration_record_with_commit(
|
||||
) -> None:
|
||||
"""Tries to insert the singleton row on OpenSearchTenantMigrationRecord.
|
||||
|
||||
If the row already exists, does nothing.
|
||||
Does nothing if the row already exists.
|
||||
"""
|
||||
stmt = insert(OpenSearchTenantMigrationRecord).on_conflict_do_nothing(
|
||||
index_elements=[text("(true)")]
|
||||
@@ -190,25 +197,14 @@ def increment_num_times_observed_no_additional_docs_to_migrate_with_commit(
|
||||
) -> None:
|
||||
"""Increments the number of times observed no additional docs to migrate.
|
||||
|
||||
Tries to insert the singleton row on OpenSearchTenantMigrationRecord with a
|
||||
starting count, and if the row already exists, increments the count.
|
||||
Requires the OpenSearchTenantMigrationRecord to exist.
|
||||
|
||||
Used to track when to stop the migration task.
|
||||
"""
|
||||
stmt = (
|
||||
insert(OpenSearchTenantMigrationRecord)
|
||||
.values(num_times_observed_no_additional_docs_to_migrate=1)
|
||||
.on_conflict_do_update(
|
||||
index_elements=[text("(true)")],
|
||||
set_={
|
||||
"num_times_observed_no_additional_docs_to_migrate": (
|
||||
OpenSearchTenantMigrationRecord.num_times_observed_no_additional_docs_to_migrate
|
||||
+ 1
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
|
||||
record.num_times_observed_no_additional_docs_to_migrate += 1
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -219,25 +215,14 @@ def increment_num_times_observed_no_additional_docs_to_populate_migration_table_
|
||||
Increments the number of times observed no additional docs to populate the
|
||||
migration table.
|
||||
|
||||
Tries to insert the singleton row on OpenSearchTenantMigrationRecord with a
|
||||
starting count, and if the row already exists, increments the count.
|
||||
Requires the OpenSearchTenantMigrationRecord to exist.
|
||||
|
||||
Used to track when to stop the migration check task.
|
||||
"""
|
||||
stmt = (
|
||||
insert(OpenSearchTenantMigrationRecord)
|
||||
.values(num_times_observed_no_additional_docs_to_populate_migration_table=1)
|
||||
.on_conflict_do_update(
|
||||
index_elements=[text("(true)")],
|
||||
set_={
|
||||
"num_times_observed_no_additional_docs_to_populate_migration_table": (
|
||||
OpenSearchTenantMigrationRecord.num_times_observed_no_additional_docs_to_populate_migration_table
|
||||
+ 1
|
||||
)
|
||||
},
|
||||
)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
|
||||
record.num_times_observed_no_additional_docs_to_populate_migration_table += 1
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -254,3 +239,167 @@ def should_document_migration_be_permanently_failed(
|
||||
>= TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_vespa_visit_state(
|
||||
db_session: Session,
|
||||
) -> tuple[str | None, int]:
|
||||
"""Gets the current Vespa migration state from the tenant migration record.
|
||||
|
||||
Requires the OpenSearchTenantMigrationRecord to exist.
|
||||
|
||||
Returns:
|
||||
Tuple of (continuation_token, total_chunks_migrated). continuation_token
|
||||
is None if not started or completed.
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
|
||||
return (
|
||||
record.vespa_visit_continuation_token,
|
||||
record.total_chunks_migrated,
|
||||
)
|
||||
|
||||
|
||||
def update_vespa_visit_progress_with_commit(
|
||||
db_session: Session,
|
||||
continuation_token: str | None,
|
||||
chunks_processed: int,
|
||||
chunks_errored: int,
|
||||
) -> None:
|
||||
"""Updates the Vespa migration progress and commits.
|
||||
|
||||
Requires the OpenSearchTenantMigrationRecord to exist.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session.
|
||||
continuation_token: The new continuation token. None means the visit
|
||||
is complete.
|
||||
chunks_processed: Number of chunks processed in this batch (added to
|
||||
the running total).
|
||||
chunks_errored: Number of chunks errored in this batch (added to the
|
||||
running errored total).
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
|
||||
record.vespa_visit_continuation_token = continuation_token
|
||||
record.total_chunks_migrated += chunks_processed
|
||||
record.total_chunks_errored += chunks_errored
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_migration_completed_time_if_not_set_with_commit(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Marks the migration completed time if not set.
|
||||
|
||||
Requires the OpenSearchTenantMigrationRecord to exist.
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
|
||||
if record.migration_completed_at is not None:
|
||||
return
|
||||
record.migration_completed_at = datetime.now(timezone.utc)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
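A sketch of how the visit-state helpers above could drive the chunk migration loop; fetch_vespa_chunk_batch is a hypothetical stand-in for the actual Vespa Visit API call, which is not part of this diff:

# Sketch only: fetch_vespa_chunk_batch is hypothetical; the real Visit API call
# lives elsewhere. db_session is assumed to be an active Session.
def migrate_one_batch(db_session: Session) -> bool:
    """Returns True if more batches remain, False once the visit is complete."""
    token, _total_migrated = get_vespa_visit_state(db_session)

    chunks, errored, next_token = fetch_vespa_chunk_batch(continuation_token=token)

    update_vespa_visit_progress_with_commit(
        db_session=db_session,
        continuation_token=next_token,
        chunks_processed=len(chunks),
        chunks_errored=errored,
    )

    if next_token is None:
        # A None token means the visit finished; record completion once.
        mark_migration_completed_time_if_not_set_with_commit(db_session)
        return False
    return True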
def build_sanitized_to_original_doc_id_mapping(
|
||||
db_session: Session,
|
||||
) -> dict[str, str]:
|
||||
"""Pre-computes a mapping of sanitized -> original document IDs.
|
||||
|
||||
Only includes documents whose ID contains single quotes (the only character
|
||||
that gets sanitized by replace_invalid_doc_id_characters). For all other
|
||||
documents, sanitized == original and no mapping entry is needed.
|
||||
|
||||
Scans over all documents.
|
||||
|
||||
Checks if the sanitized ID already exists as a genuine separate document in
|
||||
the Document table. If so, raises, since there is no way to resolve the
|
||||
conflict in the migration. The user will need to reindex.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session.
|
||||
|
||||
Returns:
|
||||
Dict mapping sanitized_id -> original_id, only for documents where
|
||||
the IDs differ. Empty dict means no documents have single quotes
|
||||
in their IDs.
|
||||
"""
|
||||
# Find all documents with single quotes in their ID.
|
||||
stmt = select(Document.id).where(Document.id.contains("'"))
|
||||
ids_with_quotes = list(db_session.scalars(stmt).all())
|
||||
|
||||
result: dict[str, str] = {}
|
||||
for original_id in ids_with_quotes:
|
||||
sanitized_id = replace_invalid_doc_id_characters(original_id)
|
||||
if sanitized_id != original_id:
|
||||
result[sanitized_id] = original_id
|
||||
|
||||
# See if there are any documents whose ID is a sanitized ID of another
|
||||
# document. If there is even one match, we cannot proceed.
|
||||
stmt = select(Document.id).where(Document.id.in_(result.keys()))
|
||||
ids_with_matches = list(db_session.scalars(stmt).all())
|
||||
if ids_with_matches:
|
||||
raise RuntimeError(
|
||||
f"Documents with IDs {ids_with_matches} have sanitized IDs that match other documents. "
|
||||
"This is not supported and the user will need to reindex."
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_opensearch_migration_state(
|
||||
db_session: Session,
|
||||
) -> tuple[int, datetime | None, datetime | None]:
|
||||
"""Returns the state of the Vespa to OpenSearch migration.
|
||||
|
||||
If the tenant migration record is not found, returns defaults of 0, None,
|
||||
None.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session.
|
||||
|
||||
Returns:
|
||||
Tuple of (total_chunks_migrated, created_at, migration_completed_at).
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
return 0, None, None
|
||||
return (
|
||||
record.total_chunks_migrated,
|
||||
record.created_at,
|
||||
record.migration_completed_at,
|
||||
)
|
||||
|
||||
|
||||
def get_opensearch_retrieval_state(
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
"""Returns the state of the OpenSearch retrieval.
|
||||
|
||||
If the tenant migration record is not found, defaults to
|
||||
ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX.
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
return ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
|
||||
return record.enable_opensearch_retrieval
|
||||
|
||||
|
||||
def set_enable_opensearch_retrieval_with_commit(
|
||||
db_session: Session,
|
||||
enable: bool,
|
||||
) -> None:
|
||||
"""Sets the enable_opensearch_retrieval flag on the singleton record.
|
||||
|
||||
Creates the record if it doesn't exist yet.
|
||||
"""
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
|
||||
record.enable_opensearch_retrieval = enable
|
||||
db_session.commit()
|
||||
|
||||
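A brief usage sketch of the toggle and read helpers above, assuming an active db_session:

# Hypothetical example: db_session is assumed to be an active Session.
set_enable_opensearch_retrieval_with_commit(db_session, enable=True)
assert get_opensearch_retrieval_state(db_session) is True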
@@ -15,6 +15,8 @@ from onyx.db.index_attempt import (
|
||||
count_unique_active_cc_pairs_with_successful_index_attempts,
|
||||
)
|
||||
from onyx.db.index_attempt import count_unique_cc_pairs_with_successful_index_attempts
|
||||
from onyx.db.llm import update_default_contextual_model
|
||||
from onyx.db.llm import update_no_default_contextual_rag_provider
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
@@ -80,6 +82,24 @@ def _perform_index_swap(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Update the default contextual model to match the newly promoted settings
|
||||
try:
|
||||
update_default_contextual_model(
|
||||
db_session=db_session,
|
||||
enable_contextual_rag=new_search_settings.enable_contextual_rag,
|
||||
contextual_rag_llm_provider=new_search_settings.contextual_rag_llm_provider,
|
||||
contextual_rag_llm_name=new_search_settings.contextual_rag_llm_name,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Model not found, defaulting to no contextual model: {e}")
|
||||
update_no_default_contextual_rag_provider(
|
||||
db_session=db_session,
|
||||
)
|
||||
new_search_settings.enable_contextual_rag = False
|
||||
new_search_settings.contextual_rag_llm_provider = None
|
||||
new_search_settings.contextual_rag_llm_name = None
|
||||
db_session.commit()
|
||||
|
||||
# This flow is for checking and possibly creating an index so we get all
|
||||
# indices.
|
||||
document_indices = get_all_document_indices(new_search_settings, None, None)
|
||||
|
||||
@@ -55,6 +55,8 @@ def get_tools(
|
||||
# To avoid showing rows that have JSON literal `null` stored in the column to the user.
|
||||
# Tools from MCP servers do not have an OpenAPI schema, but the column stores a JSON `null`, so we need to exclude them.
|
||||
func.jsonb_typeof(Tool.openapi_schema) == "object",
|
||||
# Exclude built-in tools that happen to have an openapi_schema
|
||||
Tool.in_code_tool_id.is_(None),
|
||||
)
|
||||
|
||||
return list(db_session.scalars(query).all())
|
||||
|
||||
@@ -9,11 +9,13 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import DefaultAppMode
|
||||
from onyx.db.enums import ThemePreference
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import Assistant__UserSpecificConfig
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
from onyx.server.manage.models import MemoryItem
|
||||
from onyx.server.manage.models import UserSpecificAssistantPreference
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -153,13 +155,28 @@ def update_user_chat_background(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_user_default_app_mode(
|
||||
user_id: UUID,
|
||||
default_app_mode: DefaultAppMode,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update user's default app mode setting."""
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id) # type: ignore
|
||||
.values(default_app_mode=default_app_mode)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_user_personalization(
|
||||
user_id: UUID,
|
||||
*,
|
||||
personal_name: str | None,
|
||||
personal_role: str | None,
|
||||
use_memories: bool,
|
||||
memories: list[str],
|
||||
enable_memory_tool: bool,
|
||||
memories: list[MemoryItem],
|
||||
user_preferences: str | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
@@ -170,15 +187,39 @@ def update_user_personalization(
|
||||
personal_name=personal_name,
|
||||
personal_role=personal_role,
|
||||
use_memories=use_memories,
|
||||
enable_memory_tool=enable_memory_tool,
|
||||
user_preferences=user_preferences,
|
||||
)
|
||||
)
|
||||
|
||||
db_session.execute(delete(Memory).where(Memory.user_id == user_id))
|
||||
# ID-based upsert: use real DB IDs from the frontend to match memories.
|
||||
incoming_ids = {m.id for m in memories if m.id is not None}
|
||||
|
||||
if memories:
|
||||
# Delete existing rows not in the incoming set (scoped to user_id)
|
||||
existing_memories = list(
|
||||
db_session.scalars(select(Memory).where(Memory.user_id == user_id)).all()
|
||||
)
|
||||
existing_ids = {mem.id for mem in existing_memories}
|
||||
ids_to_delete = existing_ids - incoming_ids
|
||||
if ids_to_delete:
|
||||
db_session.execute(
|
||||
delete(Memory).where(
|
||||
Memory.id.in_(ids_to_delete),
|
||||
Memory.user_id == user_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Update existing rows whose IDs match
|
||||
existing_by_id = {mem.id: mem for mem in existing_memories}
|
||||
for item in memories:
|
||||
if item.id is not None and item.id in existing_by_id:
|
||||
existing_by_id[item.id].memory_text = item.content
|
||||
|
||||
# Create new rows for items without an ID
|
||||
new_items = [m for m in memories if m.id is None]
|
||||
if new_items:
|
||||
db_session.add_all(
|
||||
[Memory(user_id=user_id, memory_text=memory) for memory in memories]
|
||||
[Memory(user_id=user_id, memory_text=item.content) for item in new_items]
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
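A usage sketch of the ID-based upsert above; MemoryItem is assumed to carry id and content fields as used in the code, and all values are placeholders:

# Hypothetical example: values are placeholders; MemoryItem fields (id, content)
# are inferred from the code above.
update_user_personalization(
    user_id=user_id,
    personal_name="Ada",
    personal_role="Backend engineer",
    use_memories=True,
    enable_memory_tool=True,
    memories=[
        MemoryItem(id=3, content="Prefers concise answers"),  # updates existing row 3
        MemoryItem(id=None, content="Works in UTC+1"),        # creates a new row
    ],
    user_preferences=None,
    db_session=db_session,
)
# Existing memory rows for this user whose IDs are not in the list are deleted.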
@@ -17,6 +17,7 @@ from onyx.chat.llm_loop import construct_message_history
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.llm_step import run_llm_step_pkt_generator
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import FileToolMetadata
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.chat_configs import SKIP_DEEP_RESEARCH_CLARIFICATION
|
||||
@@ -109,6 +110,7 @@ def generate_final_report(
|
||||
user_identity: LLMUserIdentity | None,
|
||||
saved_reasoning: str | None = None,
|
||||
pre_answer_processing_time: float | None = None,
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata] | None = None,
|
||||
) -> bool:
|
||||
"""Generate the final research report.
|
||||
|
||||
@@ -130,7 +132,7 @@ def generate_final_report(
|
||||
reminder_message = ChatMessageSimple(
|
||||
message=final_reminder,
|
||||
token_count=token_counter(final_reminder),
|
||||
message_type=MessageType.USER,
|
||||
message_type=MessageType.USER_REMINDER,
|
||||
)
|
||||
final_report_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
@@ -139,6 +141,7 @@ def generate_final_report(
|
||||
reminder_message=reminder_message,
|
||||
project_files=None,
|
||||
available_tokens=llm.config.max_input_tokens,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
|
||||
citation_processor = DynamicCitationProcessor()
|
||||
@@ -194,6 +197,7 @@ def run_deep_research_llm_loop(
|
||||
skip_clarification: bool = False,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
all_injected_file_metadata: dict[str, FileToolMetadata] | None = None,
|
||||
) -> None:
|
||||
with trace(
|
||||
"run_deep_research_llm_loop",
|
||||
@@ -256,6 +260,7 @@ def run_deep_research_llm_loop(
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
|
||||
# Calculate tool processing duration for clarification step
|
||||
@@ -304,6 +309,8 @@ def run_deep_research_llm_loop(
|
||||
token_count=300,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
# Note: it is fine to use a USER message type here, as it can just be interpreted as a
|
||||
# user's message directly to the LLM.
|
||||
reminder_message = ChatMessageSimple(
|
||||
message=RESEARCH_PLAN_REMINDER,
|
||||
token_count=100,
|
||||
@@ -317,6 +324,7 @@ def run_deep_research_llm_loop(
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
|
||||
research_plan_generator = run_llm_step_pkt_generator(
|
||||
@@ -442,6 +450,7 @@ def run_deep_research_llm_loop(
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
pre_answer_processing_time=elapsed_seconds,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
@@ -450,11 +459,9 @@ def run_deep_research_llm_loop(
|
||||
first_cycle_reminder_message = ChatMessageSimple(
|
||||
message=FIRST_CYCLE_REMINDER,
|
||||
token_count=FIRST_CYCLE_REMINDER_TOKENS,
|
||||
message_type=MessageType.USER,
|
||||
message_type=MessageType.USER_REMINDER,
|
||||
)
|
||||
first_cycle_tokens = FIRST_CYCLE_REMINDER_TOKENS
|
||||
else:
|
||||
first_cycle_tokens = 0
|
||||
first_cycle_reminder_message = None
|
||||
|
||||
research_agent_calls: list[ToolCallKickoff] = []
|
||||
@@ -477,15 +484,13 @@ def run_deep_research_llm_loop(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
reminder_message=first_cycle_reminder_message,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens - first_cycle_tokens,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
|
||||
if first_cycle_reminder_message is not None:
|
||||
truncated_message_history.append(first_cycle_reminder_message)
|
||||
|
||||
# Use think tool processor for non-reasoning models to convert
|
||||
# think_tool calls to reasoning content
|
||||
custom_processor = (
|
||||
@@ -549,6 +554,7 @@ def run_deep_research_llm_loop(
|
||||
user_identity=user_identity,
|
||||
pre_answer_processing_time=time.monotonic()
|
||||
- processing_start_time,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
@@ -572,6 +578,7 @@ def run_deep_research_llm_loop(
|
||||
saved_reasoning=most_recent_reasoning,
|
||||
pre_answer_processing_time=time.monotonic()
|
||||
- processing_start_time,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
@@ -644,6 +651,7 @@ def run_deep_research_llm_loop(
|
||||
user_identity=user_identity,
|
||||
pre_answer_processing_time=time.monotonic()
|
||||
- processing_start_time,
|
||||
all_injected_file_metadata=all_injected_file_metadata,
|
||||
)
|
||||
final_turn_index = report_turn_index + (
|
||||
1 if report_reasoned else 0
|
||||
|
||||
backend/onyx/document_index/disabled.py (new file, 151 lines)
@@ -0,0 +1,151 @@
|
||||
"""A DocumentIndex implementation that raises on every operation.
|
||||
|
||||
Used as a safety net when DISABLE_VECTOR_DB is True. Any code path that
|
||||
accidentally reaches the vector DB layer will fail loudly instead of timing
|
||||
out against a nonexistent Vespa/OpenSearch instance.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import QueryExpansionType
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentInsertionRecord
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
VECTOR_DB_DISABLED_ERROR = (
|
||||
"Vector DB is disabled (DISABLE_VECTOR_DB=true). "
|
||||
"This operation requires a vector database."
|
||||
)
|
||||
|
||||
|
||||
class DisabledDocumentIndex(DocumentIndex):
|
||||
"""A DocumentIndex where every method raises RuntimeError.
|
||||
|
||||
Returned by the factory when DISABLE_VECTOR_DB is True so that any
|
||||
accidental vector-DB call surfaces immediately.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str = "disabled",
|
||||
secondary_index_name: str | None = None,
|
||||
*args: Any, # noqa: ARG002
|
||||
**kwargs: Any, # noqa: ARG002
|
||||
) -> None:
|
||||
self.index_name = index_name
|
||||
self.secondary_index_name = secondary_index_name
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Verifiable
|
||||
# ------------------------------------------------------------------
|
||||
def ensure_indices_exist(
|
||||
self,
|
||||
primary_embedding_dim: int, # noqa: ARG002
|
||||
primary_embedding_precision: EmbeddingPrecision, # noqa: ARG002
|
||||
secondary_index_embedding_dim: int | None, # noqa: ARG002
|
||||
secondary_index_embedding_precision: EmbeddingPrecision | None, # noqa: ARG002
|
||||
) -> None:
|
||||
# No-op: there are no indices to create when the vector DB is disabled.
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def register_multitenant_indices(
|
||||
indices: list[str], # noqa: ARG002, ARG004
|
||||
embedding_dims: list[int], # noqa: ARG002, ARG004
|
||||
embedding_precisions: list[EmbeddingPrecision], # noqa: ARG002, ARG004
|
||||
) -> None:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Indexable
|
||||
# ------------------------------------------------------------------
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
|
||||
index_batch_params: IndexBatchParams, # noqa: ARG002
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Deletable
|
||||
# ------------------------------------------------------------------
|
||||
def delete_single(
|
||||
self,
|
||||
doc_id: str, # noqa: ARG002
|
||||
*,
|
||||
tenant_id: str, # noqa: ARG002
|
||||
chunk_count: int | None, # noqa: ARG002
|
||||
) -> int:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Updatable
|
||||
# ------------------------------------------------------------------
|
||||
def update_single(
|
||||
self,
|
||||
doc_id: str, # noqa: ARG002
|
||||
*,
|
||||
tenant_id: str, # noqa: ARG002
|
||||
chunk_count: int | None, # noqa: ARG002
|
||||
fields: VespaDocumentFields | None, # noqa: ARG002
|
||||
user_fields: VespaDocumentUserFields | None, # noqa: ARG002
|
||||
) -> None:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# IdRetrievalCapable
|
||||
# ------------------------------------------------------------------
|
||||
def id_based_retrieval(
|
||||
self,
|
||||
chunk_requests: list[VespaChunkRequest], # noqa: ARG002
|
||||
filters: IndexFilters, # noqa: ARG002
|
||||
batch_retrieval: bool = False, # noqa: ARG002
|
||||
) -> list[InferenceChunk]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HybridCapable
|
||||
# ------------------------------------------------------------------
|
||||
def hybrid_retrieval(
|
||||
self,
|
||||
query: str, # noqa: ARG002
|
||||
query_embedding: Embedding, # noqa: ARG002
|
||||
final_keywords: list[str] | None, # noqa: ARG002
|
||||
filters: IndexFilters, # noqa: ARG002
|
||||
hybrid_alpha: float, # noqa: ARG002
|
||||
time_decay_multiplier: float, # noqa: ARG002
|
||||
num_to_retrieve: int, # noqa: ARG002
|
||||
ranking_profile_type: QueryExpansionType, # noqa: ARG002
|
||||
title_content_ratio: float | None = None, # noqa: ARG002
|
||||
) -> list[InferenceChunk]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# AdminCapable
|
||||
# ------------------------------------------------------------------
|
||||
def admin_retrieval(
|
||||
self,
|
||||
query: str, # noqa: ARG002
|
||||
query_embedding: Embedding, # noqa: ARG002
|
||||
filters: IndexFilters, # noqa: ARG002
|
||||
num_to_retrieve: int = 10, # noqa: ARG002
|
||||
) -> list[InferenceChunk]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# RandomCapable
|
||||
# ------------------------------------------------------------------
|
||||
def random_retrieval(
|
||||
self,
|
||||
filters: IndexFilters, # noqa: ARG002
|
||||
num_to_retrieve: int = 10, # noqa: ARG002
|
||||
) -> list[InferenceChunk]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
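A tiny illustration of the fail-loudly contract described in the module docstring; the index name is a placeholder:

# Illustration only: the index name is a placeholder.
index = DisabledDocumentIndex(index_name="onyx_index")
try:
    DisabledDocumentIndex.register_multitenant_indices([], [], [])
except RuntimeError as err:
    # Surfaces VECTOR_DB_DISABLED_ERROR instead of timing out against a missing backend.
    print(err)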
@@ -1,8 +1,11 @@
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.opensearch_migration import get_opensearch_retrieval_state
|
||||
from onyx.document_index.disabled import DisabledDocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
@@ -14,6 +17,7 @@ from shared_configs.configs import MULTI_TENANT
|
||||
def get_default_document_index(
|
||||
search_settings: SearchSettings,
|
||||
secondary_search_settings: SearchSettings | None,
|
||||
db_session: Session,
|
||||
httpx_client: httpx.Client | None = None,
|
||||
) -> DocumentIndex:
|
||||
"""Gets the default document index from env vars.
|
||||
@@ -27,13 +31,24 @@ def get_default_document_index(
|
||||
index is for when both the currently used index and the upcoming index both
|
||||
need to be updated, updates are applied to both indices.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
return DisabledDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=(
|
||||
secondary_search_settings.index_name
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
secondary_index_name: str | None = None
|
||||
secondary_large_chunks_enabled: bool | None = None
|
||||
if secondary_search_settings:
|
||||
secondary_index_name = secondary_search_settings.index_name
|
||||
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
|
||||
|
||||
if ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX:
|
||||
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
|
||||
if opensearch_retrieval_enabled:
|
||||
return OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=secondary_index_name,
|
||||
@@ -69,7 +84,24 @@ def get_all_document_indices(
|
||||
|
||||
Large chunks and secondary indices are not currently supported so we
|
||||
hardcode appropriate values.
|
||||
|
||||
NOTE: Make sure the Vespa index object is returned first. In the rare event
|
||||
that there is some conflict between indexing and the migration task, it is
|
||||
assumed that the state of Vespa is more up-to-date than the state of
|
||||
OpenSearch.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
return [
|
||||
DisabledDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=(
|
||||
secondary_search_settings.index_name
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
vespa_document_index = VespaIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=(
|
||||
|
||||
@@ -9,6 +9,7 @@ from opensearchpy import TransportError
|
||||
from opensearchpy.helpers import bulk
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S
|
||||
from onyx.configs.app_configs import OPENSEARCH_ADMIN_PASSWORD
|
||||
from onyx.configs.app_configs import OPENSEARCH_ADMIN_USERNAME
|
||||
from onyx.configs.app_configs import OPENSEARCH_HOST
|
||||
@@ -21,6 +22,9 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
CLIENT_THRESHOLD_TO_LOG_SLOW_SEARCH_MS = 2000
|
||||
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
# Set the logging level to WARNING to ignore INFO and DEBUG logs from
|
||||
# opensearch. By default it emits INFO-level logs for every request.
|
||||
@@ -52,6 +56,30 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
|
||||
match_highlights: dict[str, list[str]] = {}
|
||||
|
||||
|
||||
def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Recursively replaces vectors in the body with their length.
|
||||
|
||||
TODO(andrei): Do better.
|
||||
|
||||
Args:
|
||||
body: The body to replace the vectors.
|
||||
|
||||
Returns:
|
||||
A copy of body with vectors replaced with their length.
|
||||
"""
|
||||
new_body: dict[str, Any] = {}
|
||||
for k, v in body.items():
|
||||
if k == "vector":
|
||||
new_body[k] = len(v)
|
||||
elif isinstance(v, dict):
|
||||
new_body[k] = get_new_body_without_vectors(v)
|
||||
elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], dict):
|
||||
new_body[k] = [get_new_body_without_vectors(item) for item in v]
|
||||
else:
|
||||
new_body[k] = v
|
||||
return new_body
|
||||
|
||||
|
||||
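A quick example of what the helper above does to a query body (values are illustrative):

# Illustrative input/output for get_new_body_without_vectors:
body = {"query": {"knn": {"embedding": {"vector": [0.1, 0.2, 0.3], "k": 10}}}}
print(get_new_body_without_vectors(body))
# {'query': {'knn': {'embedding': {'vector': 3, 'k': 10}}}}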
class OpenSearchClient:
|
||||
"""Client for interacting with OpenSearch.
|
||||
|
||||
@@ -74,10 +102,11 @@ class OpenSearchClient:
|
||||
use_ssl: bool = True,
|
||||
verify_certs: bool = False,
|
||||
ssl_show_warn: bool = False,
|
||||
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
|
||||
):
|
||||
self._index_name = index_name
|
||||
logger.debug(
|
||||
f"Creating OpenSearch client for index {index_name} with host {host} and port {port}."
|
||||
f"Creating OpenSearch client for index {index_name} with host {host} and port {port} and timeout {timeout} seconds."
|
||||
)
|
||||
self._client = OpenSearch(
|
||||
hosts=[{"host": host, "port": port}],
|
||||
@@ -85,6 +114,13 @@ class OpenSearchClient:
|
||||
use_ssl=use_ssl,
|
||||
verify_certs=verify_certs,
|
||||
ssl_show_warn=ssl_show_warn,
|
||||
# NOTE: This timeout applies to all requests the client makes,
|
||||
# including bulk indexing. When exceeded, the client will raise a
|
||||
# ConnectionTimeout and return no useful results. The OpenSearch
|
||||
# server will log that the client cancelled the request. To get
|
||||
# partial results from OpenSearch, pass in a timeout parameter to
|
||||
# your request body that is less than this value.
|
||||
timeout=timeout,
|
||||
)
|
||||
logger.debug(
|
||||
f"OpenSearch client created successfully for index {self._index_name}."
|
||||
@@ -635,14 +671,31 @@ class OpenSearchClient:
|
||||
f"Trying to search index {self._index_name} with search pipeline {search_pipeline_id}."
|
||||
)
|
||||
result: dict[str, Any]
|
||||
params = {"phase_took": "true"}
|
||||
if search_pipeline_id:
|
||||
result = self._client.search(
|
||||
index=self._index_name, search_pipeline=search_pipeline_id, body=body
|
||||
index=self._index_name,
|
||||
search_pipeline=search_pipeline_id,
|
||||
body=body,
|
||||
params=params,
|
||||
)
|
||||
else:
|
||||
result = self._client.search(index=self._index_name, body=body)
|
||||
result = self._client.search(
|
||||
index=self._index_name, body=body, params=params
|
||||
)
|
||||
|
||||
hits = self._get_hits_from_search_result(result)
|
||||
hits, time_took, timed_out, phase_took, profile = (
|
||||
self._get_hits_and_profile_from_search_result(result)
|
||||
)
|
||||
self._log_search_result_perf(
|
||||
time_took=time_took,
|
||||
timed_out=timed_out,
|
||||
phase_took=phase_took,
|
||||
profile=profile,
|
||||
body=body,
|
||||
search_pipeline_id=search_pipeline_id,
|
||||
raise_on_timeout=True,
|
||||
)
|
||||
|
||||
search_hits: list[SearchHit[DocumentChunk]] = []
|
||||
for hit in hits:
|
||||
@@ -698,9 +751,22 @@ class OpenSearchClient:
|
||||
'"_source": False. This query will therefore be inefficient.'
|
||||
)
|
||||
|
||||
result: dict[str, Any] = self._client.search(index=self._index_name, body=body)
|
||||
params = {"phase_took": "true"}
|
||||
result: dict[str, Any] = self._client.search(
|
||||
index=self._index_name, body=body, params=params
|
||||
)
|
||||
|
||||
hits = self._get_hits_from_search_result(result)
|
||||
hits, time_took, timed_out, phase_took, profile = (
|
||||
self._get_hits_and_profile_from_search_result(result)
|
||||
)
|
||||
self._log_search_result_perf(
|
||||
time_took=time_took,
|
||||
timed_out=timed_out,
|
||||
phase_took=phase_took,
|
||||
profile=profile,
|
||||
body=body,
|
||||
raise_on_timeout=True,
|
||||
)
|
||||
|
||||
# TODO(andrei): Implement scroll/point in time for results so that we
|
||||
# can return arbitrarily-many IDs.
|
||||
@@ -737,34 +803,24 @@ class OpenSearchClient:
|
||||
self._client.indices.refresh(index=self._index_name)
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True, include_args=True)
|
||||
def set_cluster_auto_create_index_setting(self, enabled: bool) -> bool:
|
||||
"""Sets the cluster auto create index setting.
|
||||
|
||||
By default, when you index a document to a non-existent index,
|
||||
OpenSearch will automatically create the index. This behavior is
|
||||
undesirable so this function exposes the ability to disable it.
|
||||
|
||||
See
|
||||
https://docs.opensearch.org/latest/install-and-configure/configuring-opensearch/index/#updating-cluster-settings-using-the-api
|
||||
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
|
||||
"""Puts cluster settings.
|
||||
|
||||
Args:
|
||||
enabled: Whether to enable the auto create index setting.
|
||||
settings: The settings to put.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error putting the cluster settings.
|
||||
|
||||
Returns:
|
||||
True if the setting was updated successfully, False otherwise. Does
|
||||
not raise.
|
||||
True if the settings were put successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
body = {"persistent": {"action.auto_create_index": enabled}}
|
||||
response = self._client.cluster.put_settings(body=body)
|
||||
if response.get("acknowledged", False):
|
||||
logger.info(f"Successfully set action.auto_create_index to {enabled}.")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to update setting: {response}.")
|
||||
return False
|
||||
except Exception:
|
||||
logger.exception("Error setting auto_create_index.")
|
||||
response = self._client.cluster.put_settings(body=settings)
|
||||
if response.get("acknowledged", False):
|
||||
logger.info("Successfully put cluster settings.")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to put cluster settings: {response}.")
|
||||
return False
|
||||
|
||||
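A usage sketch replicating the removed set_cluster_auto_create_index_setting via the new generic method; client is assumed to be an OpenSearchClient instance and logger the module logger:

# Equivalent of the removed auto_create_index helper, using the generic method.
ok = client.put_cluster_settings({"persistent": {"action.auto_create_index": False}})
if not ok:
    logger.error("Could not disable automatic index creation.")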
@log_function_time(print_only=True, debug_only=True)
|
||||
@@ -788,28 +844,78 @@ class OpenSearchClient:
|
||||
"""
|
||||
self._client.close()
|
||||
|
||||
def _get_hits_from_search_result(self, result: dict[str, Any]) -> list[Any]:
|
||||
"""Extracts the hits from a search result.
|
||||
def _get_hits_and_profile_from_search_result(
|
||||
self, result: dict[str, Any]
|
||||
) -> tuple[list[Any], int | None, bool | None, dict[str, Any], dict[str, Any]]:
|
||||
"""Extracts the hits and profiling information from a search result.
|
||||
|
||||
Args:
|
||||
result: The search result to extract the hits from.
|
||||
|
||||
Raises:
|
||||
Exception: There was an error extracting the hits from the search
|
||||
result. This includes the case where the search timed out.
|
||||
result.
|
||||
|
||||
Returns:
|
||||
The hits from the search result.
|
||||
A tuple containing the hits from the search result, the time taken
|
||||
to execute the search in milliseconds, whether the search timed
|
||||
out, the time taken to execute each phase of the search, and the
|
||||
profile.
|
||||
"""
|
||||
if result.get("timed_out", False):
|
||||
raise RuntimeError(f"Search timed out for index {self._index_name}.")
|
||||
time_took: int | None = result.get("took")
|
||||
timed_out: bool | None = result.get("timed_out")
|
||||
phase_took: dict[str, Any] = result.get("phase_took", {})
|
||||
profile: dict[str, Any] = result.get("profile", {})
|
||||
|
||||
hits_first_layer: dict[str, Any] = result.get("hits", {})
|
||||
if not hits_first_layer:
|
||||
raise RuntimeError(
|
||||
f"Hits field missing from response when trying to search index {self._index_name}."
|
||||
)
|
||||
hits_second_layer: list[Any] = hits_first_layer.get("hits", [])
|
||||
return hits_second_layer
|
||||
|
||||
return hits_second_layer, time_took, timed_out, phase_took, profile
|
||||
|
||||
def _log_search_result_perf(
|
||||
self,
|
||||
time_took: int | None,
|
||||
timed_out: bool | None,
|
||||
phase_took: dict[str, Any],
|
||||
profile: dict[str, Any],
|
||||
body: dict[str, Any],
|
||||
search_pipeline_id: str | None = None,
|
||||
raise_on_timeout: bool = False,
|
||||
) -> None:
|
||||
"""Logs the performance of a search result.
|
||||
|
||||
Args:
|
||||
time_took: The time taken to execute the search in milliseconds.
|
||||
timed_out: Whether the search timed out.
|
||||
phase_took: The time taken to execute each phase of the search.
|
||||
profile: The profile for the search.
|
||||
body: The body of the search request for logging.
|
||||
search_pipeline_id: The ID of the search pipeline used for the
|
||||
search, if any, for logging. Defaults to None.
|
||||
raise_on_timeout: Whether to raise an exception if the search timed
|
||||
out. Note that the result may still contain useful partial
|
||||
results. Defaults to False.
|
||||
|
||||
Raises:
|
||||
Exception: If raise_on_timeout is True and the search timed out.
|
||||
"""
|
||||
if time_took and time_took > CLIENT_THRESHOLD_TO_LOG_SLOW_SEARCH_MS:
|
||||
logger.warning(
|
||||
f"OpenSearch client warning: Search for index {self._index_name} took {time_took} milliseconds.\n"
|
||||
f"Body: {get_new_body_without_vectors(body)}\n"
|
||||
f"Search pipeline ID: {search_pipeline_id}\n"
|
||||
f"Phase took: {phase_took}\n"
|
||||
f"Profile: {profile}\n"
|
||||
)
|
||||
if timed_out:
|
||||
error_str = f"OpenSearch client error: Search timed out for index {self._index_name}."
|
||||
logger.error(error_str)
|
||||
if raise_on_timeout:
|
||||
raise RuntimeError(error_str)
|
||||
|
||||
|
||||
def wait_for_opensearch_with_timeout(
|
||||
|
||||