Compare commits

..

6 Commits

Author SHA1 Message Date
Dane Urban
5848975679 Remove comment 2026-01-08 19:21:24 -08:00
Dane Urban
dcc330010e Remove comment 2026-01-08 19:21:08 -08:00
Dane Urban
d0f5f1f5ae Handle error and log 2026-01-08 19:20:28 -08:00
Dane Urban
3e475993ff Change which event loop we get 2026-01-08 19:16:12 -08:00
Dane Urban
7c2b5fa822 Change loggin 2026-01-08 17:29:00 -08:00
Dane Urban
409cfdc788 nits 2026-01-08 17:23:08 -08:00
1377 changed files with 29167 additions and 131577 deletions

View File

@@ -8,5 +8,4 @@
## Additional Options
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
- [ ] [Optional] Override Linear Check

File diff suppressed because it is too large Load Diff

View File

@@ -21,7 +21,7 @@ jobs:
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3

View File

@@ -21,7 +21,7 @@ jobs:
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3

View File

@@ -29,7 +29,6 @@ jobs:
run: |
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add opensearch https://opensearch-project.github.io/helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/

View File

@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # ratchet:actions/stale@v10
- uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # ratchet:actions/stale@v10
with:
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'

View File

@@ -94,7 +94,7 @@ jobs:
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3

View File

@@ -1,28 +0,0 @@
name: Require beta cherry-pick consideration
concurrency:
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
pull_request:
types: [opened, edited, reopened, synchronize]
permissions:
contents: read
jobs:
beta-cherrypick-check:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Check PR body for beta cherry-pick consideration
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
echo "Cherry-pick consideration box is checked. Check passed."
exit 0
fi
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
exit 1

View File

@@ -45,9 +45,6 @@ env:
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000
# OpenSearch
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -128,13 +125,11 @@ jobs:
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
-f docker-compose.opensearch.yml \
up -d \
minio \
relational_db \
cache \
index \
opensearch \
code-interpreter
- name: Run migrations
@@ -163,7 +158,7 @@ jobs:
cd deployment/docker_compose
# Get list of running containers
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
# Collect logs from each container
for container in $containers; do
@@ -177,7 +172,7 @@ jobs:
- name: Upload Docker logs
if: failure()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v5
with:
name: docker-logs-${{ matrix.test-dir }}
path: docker-logs/

View File

@@ -88,7 +88,6 @@ jobs:
echo "=== Adding Helm repositories ==="
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add opensearch https://opensearch-project.github.io/helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
@@ -181,11 +180,6 @@ jobs:
trap cleanup EXIT
# Run the actual installation with detailed logging
# Note that opensearch.enabled is true whereas others in this install
# are false. There is some work that needs to be done to get this
# entire step working in CI, enabling opensearch here is a small step
# in that direction. If this is causing issues, disabling it in this
# step should be ok in the short term.
echo "=== Starting ct install ==="
set +e
ct install --all \
@@ -193,8 +187,6 @@ jobs:
--set=nginx.enabled=false \
--set=minio.enabled=false \
--set=vespa.enabled=false \
--set=opensearch.enabled=true \
--set=auth.opensearch.enabled=true \
--set=slackbot.enabled=false \
--set=postgresql.enabled=true \
--set=postgresql.nameOverride=cloudnative-pg \

View File

@@ -103,7 +103,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -163,7 +163,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -208,7 +208,7 @@ jobs:
persist-credentials: false
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
@@ -310,9 +310,8 @@ jobs:
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
MCP_SERVER_ENABLED=true
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
EOF
- name: Start Docker containers
@@ -439,7 +438,7 @@ jobs:
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
@@ -568,7 +567,7 @@ jobs:
- name: Upload logs (multi-tenant)
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: docker-all-logs-multitenant
path: ${{ github.workspace }}/docker-compose-multitenant.log

View File

@@ -44,7 +44,7 @@ jobs:
- name: Upload coverage reports
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: jest-coverage-${{ github.run_id }}
path: ./web/coverage

View File

@@ -95,7 +95,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -155,7 +155,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -214,7 +214,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
@@ -301,7 +301,7 @@ jobs:
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
MCP_SERVER_ENABLED=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
EOF
- name: Start Docker containers
@@ -424,7 +424,7 @@ jobs:
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log

View File

@@ -85,7 +85,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -146,7 +146,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -207,7 +207,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -435,7 +435,7 @@ jobs:
fi
npx playwright test --project ${PROJECT}
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
if: always()
with:
# Includes test results and trace.zip files
@@ -455,7 +455,7 @@ jobs:
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log

View File

@@ -50,9 +50,8 @@ jobs:
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: backend/.mypy_cache
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
restore-keys: |
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
mypy-${{ runner.os }}-
- name: Run MyPy

View File

@@ -5,6 +5,11 @@ on:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
workflow_dispatch:
inputs:
branch:
description: 'Branch to run the workflow on'
required: false
default: 'main'
permissions:
contents: read
@@ -26,11 +31,7 @@ env:
jobs:
model-check:
# See https://runs-on.com/runners/linux/
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-model-check"
- "extras=ecr-cache"
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}-model-check"]
timeout-minutes: 45
env:
@@ -42,87 +43,108 @@ jobs:
with:
persist-credentials: false
- name: Setup Python and Install Dependencies
uses: ./.github/actions/setup-python-and-install-dependencies
with:
requirements: |
backend/requirements/default.txt
backend/requirements/dev.txt
- name: Format branch name for cache
id: format-branch
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
REF_NAME: ${{ github.ref_name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
- name: Build and load
uses: docker/bake-action@5be5f02ff8819ecd3092ea6b2e6261c31774f2b4 # ratchet:docker/bake-action@v6
env:
TAG: model-server-${{ github.run_id }}
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Model Server Docker image
run: |
docker pull onyxdotapp/onyx-model-server:latest
docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test
- name: Set up Python
uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # ratchet:actions/setup-python@v6
with:
load: true
targets: model-server
set: |
model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
model-server.cache-from=type=registry,ref=onyxdotapp/onyx-model-server:latest
model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
python-version: "3.11"
cache: "pip"
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Start Docker containers
id: start_docker
env:
IMAGE_TAG: model-server-${{ github.run_id }}
run: |
cd deployment/docker_compose
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
up -d --wait \
inference_model_server
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.model-server-test.yml up -d indexing_model_server
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:9000/api/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/llm
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/embedding
- name: Alert on Failure
if: failure() && github.event_name == 'schedule'
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
failed-jobs: model-check
title: "🚨 Scheduled Model Tests failed!"
ref-name: ${{ github.ref_name }}
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
REPO: ${{ github.repository }}
RUN_ID: ${{ github.run_id }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data "{\"text\":\"Scheduled Model Tests failed! Check the run at: https://github.com/${REPO}/actions/runs/${RUN_ID}\"}" \
$SLACK_WEBHOOK
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
docker compose -f docker-compose.model-server-test.yml logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log

4
.gitignore vendored
View File

@@ -1,8 +1,5 @@
# editors
.vscode
!/.vscode/env_template.txt
!/.vscode/launch.json
!/.vscode/tasks.template.jsonc
.zed
.cursor
@@ -24,7 +21,6 @@ backend/tests/regression/search_quality/*.json
backend/onyx/evals/data/
backend/onyx/evals/one_off/*.json
*.log
*.csv
# secret files
.env

View File

@@ -11,6 +11,7 @@ repos:
- id: uv-sync
args: ["--locked", "--all-extras"]
- id: uv-lock
files: ^pyproject\.toml$
- id: uv-export
name: uv-export default.txt
args:
@@ -66,8 +67,7 @@ repos:
- id: uv-run
name: Check lazy imports
args: ["--active", "--with=onyx-devtools", "ods", "check-lazy-imports"]
pass_filenames: true
files: ^backend/(?!\.venv/|scripts/).*\.py$
files: ^backend/(?!\.venv/).*\.py$
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
# - id: uv-run
# name: mypy
@@ -75,13 +75,6 @@ repos:
# pass_filenames: true
# files: ^backend/.*\.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
hooks:
- id: check-added-large-files
name: Check for added large files
args: ["--maxkb=1500"]
- repo: https://github.com/rhysd/actionlint
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
hooks:
@@ -154,22 +147,6 @@ repos:
pass_filenames: false
files: \.tf$
- id: npm-install
name: npm install
description: "Automatically run 'npm install' after a checkout, pull or rebase"
language: system
entry: bash -c 'cd web && npm install --no-save'
pass_filenames: false
files: ^web/package(-lock)?\.json$
stages: [post-checkout, post-merge, post-rewrite]
- id: npm-install-check
name: npm install --package-lock-only
description: "Check the 'web/package-lock.json' is updated"
language: system
entry: bash -c 'cd web && npm install --package-lock-only'
pass_filenames: false
files: ^web/package(-lock)?\.json$
# Uses tsgo (TypeScript's native Go compiler) for ~10x faster type checking.
# This is a preview package - if it breaks:
# 1. Try updating: cd web && npm update @typescript/native-preview

View File

@@ -17,6 +17,12 @@ LOG_ONYX_MODEL_INTERACTIONS=True
LOG_LEVEL=debug
# This passes top N results to LLM an additional time for reranking prior to
# answer generation.
# This step is quite heavy on token usage so we disable it for dev generally.
DISABLE_LLM_DOC_RELEVANCE=False
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically).
OAUTH_CLIENT_ID=<REPLACE THIS>
OAUTH_CLIENT_SECRET=<REPLACE THIS>

View File

@@ -1,3 +1,5 @@
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
@@ -22,7 +24,7 @@
"Slack Bot",
"Celery primary",
"Celery light",
"Celery heavy",
"Celery background",
"Celery docfetching",
"Celery docprocessing",
"Celery beat"
@@ -149,24 +151,6 @@
},
"consoleTitle": "Slack Bot Console"
},
{
"name": "Discord Bot",
"consoleName": "Discord Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/discord/client.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2"
},
"consoleTitle": "Discord Bot Console"
},
{
"name": "MCP Server",
"consoleName": "MCP Server",
@@ -415,6 +399,7 @@
"onyx.background.celery.versioned_apps.docfetching",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docfetching@%n",
@@ -445,6 +430,7 @@
"onyx.background.celery.versioned_apps.docprocessing",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docprocessing@%n",
@@ -593,137 +579,6 @@
"group": "3"
}
},
{
"name": "Build Sandbox Templates",
"type": "debugpy",
"request": "launch",
"module": "onyx.server.features.build.sandbox.build_templates",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"console": "integratedTerminal",
"presentation": {
"group": "3"
},
"consoleTitle": "Build Sandbox Templates"
},
{
// Dummy entry used to label the group
"name": "--- Database ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "4",
"order": 0
}
},
{
"name": "Restore seeded database dump",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"restore",
"--fetch-seeded",
"--yes"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Clean restore seeded database dump (destructive)",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"restore",
"--fetch-seeded",
"--clean",
"--yes"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Create database snapshot",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"dump",
"backup.dump"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Clean restore database snapshot (destructive)",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"restore",
"--clean",
"--yes",
"backup.dump"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Upgrade database to head revision",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"upgrade"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
// script to generate the openapi schema
"name": "Onyx OpenAPI Schema Generator",

View File

@@ -1,31 +1,262 @@
<!-- ONYX_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/CONTRIBUTING.md"} -->
# Contributing to Onyx
Hey there! We are so excited that you're interested in Onyx.
As an open source project in a rapidly changing space, we welcome all contributions.
## Contribution Opportunities
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to look for and share contribution ideas.
## 💃 Guidelines
If you have your own feature that you would like to build please create an issue and community members can provide feedback and
thumb it up if they feel a common need.
### Contribution Opportunities
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
## Contributing Code
Please reference the documents in contributing_guides folder to ensure that the code base is kept to a high standard.
1. dev_setup.md (start here): gives you a guide to setting up a local development environment.
2. contribution_process.md: how to ensure you are building valuable features that will get reviewed and merged.
3. best_practices.md: before asking for reviews, ensure your changes meet the repo code quality standards.
To ensure that your contribution is aligned with the project's direction, please reach out to any maintainer on the Onyx team
via [Discord](https://discord.gg/4NA5SbzrWb) or [email](mailto:hello@onyx.app).
To contribute, please follow the
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
will be marked with the `approved by maintainers` label.
Issues marked `good first issue` are an especially great place to start.
**Connectors** to other tools are another great place to contribute. For details on how, refer to this
[README.md](https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/connectors/README.md).
If you have a new/different contribution in mind, we'd love to hear about it!
Your input is vital to making sure that Onyx moves in the right direction.
Before starting on implementation, please raise a GitHub issue.
Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
[Discord](https://discord.gg/4NA5SbzrWb) directly about anything at all.
### Contributing Code
To contribute to this project, please follow the
["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow.
When opening a pull request, mention related issues and feel free to tag relevant maintainers.
Before creating a pull request please make sure that the new changes conform to the formatting and linting requirements.
See the [Formatting and Linting](#formatting-and-linting) section for how to run these checks locally.
### Getting Help 🙋
Our goal is to make contributing as easy as possible. If you run into any issues please don't hesitate to reach out.
That way we can help future contributors and users can avoid the same issue.
We also have support channels and generally interesting discussions on our
[Discord](https://discord.gg/4NA5SbzrWb).
We would love to see you there!
## Get Started 🚀
Onyx being a fully functional app, relies on some external software, specifically:
- [Postgres](https://www.postgresql.org/) (Relational DB)
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
- [Redis](https://redis.io/) (Cache)
- [MinIO](https://min.io/) (File Store)
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
> **Note:**
> This guide provides instructions to build and run Onyx locally from source with Docker containers providing the above external software. We believe this combination is easier for
> development purposes. If you prefer to use pre-built container images, we provide instructions on running the full Onyx stack within Docker below.
### Local Set Up
Be sure to use Python version 3.11. For instructions on installing Python 3.11 on macOS, refer to the [CONTRIBUTING_MACOS.md](./CONTRIBUTING_MACOS.md) readme.
If using a lower version, modifications will have to be made to the code.
If using a higher version, sometimes some libraries will not be available (i.e. we had problems with Tensorflow in the past with higher versions of python).
#### Backend: Python requirements
Currently, we use [uv](https://docs.astral.sh/uv/) and recommend creating a [virtual environment](https://docs.astral.sh/uv/pip/environments/#using-a-virtual-environment).
For convenience here's a command for it:
```bash
uv venv .venv --python 3.11
source .venv/bin/activate
```
_For Windows, activate the virtual environment using Command Prompt:_
```bash
.venv\Scripts\activate
```
If using PowerShell, the command slightly differs:
```powershell
.venv\Scripts\Activate.ps1
```
Install the required python dependencies:
```bash
uv sync --all-extras
```
Install Playwright for Python (headless browser required by the Web Connector):
```bash
uv run playwright install
```
#### Frontend: Node dependencies
Onyx uses Node v22.20.0. We highly recommend you use [Node Version Manager (nvm)](https://github.com/nvm-sh/nvm)
to manage your Node installations. Once installed, you can run
```bash
nvm install 22 && nvm use 22
node -v # verify your active version
```
Navigate to `onyx/web` and run:
```bash
npm i
```
## Formatting and Linting
### Backend
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
Then run:
```bash
uv run pre-commit install
```
Additionally, we use `mypy` for static type checking.
Onyx is fully type-annotated, and we want to keep it that way!
To run the mypy checks manually, run `uv run mypy .` from the `onyx/backend` directory.
### Web
We use `prettier` for formatting. The desired version will be installed via a `npm i` from the `onyx/web` directory.
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
Pre-commit will also run prettier automatically on files you've recently touched. If re-formatted, your commit will fail.
Re-stage your changes and commit again.
# Running the application for development
## Developing using VSCode Debugger (recommended)
**We highly recommend using VSCode debugger for development.**
See [CONTRIBUTING_VSCODE.md](./CONTRIBUTING_VSCODE.md) for more details.
Otherwise, you can follow the instructions below to run the application for development.
## Manually running the application for development
### Docker containers for external software
You will need Docker installed to run these containers.
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:
```bash
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d index relational_db cache minio
```
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
### Running Onyx locally
To start the frontend, navigate to `onyx/web` and run:
```bash
npm run dev
```
Next, start the model server which runs the local NLP models.
Navigate to `onyx/backend` and run:
```bash
uvicorn model_server.main:app --reload --port 9000
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
```
The first time running Onyx, you will need to run the DB migrations for Postgres.
After the first time, this is no longer required unless the DB models change.
Navigate to `onyx/backend` and with the venv active, run:
```bash
alembic upgrade head
```
Next, start the task queue which orchestrates the background jobs.
Jobs that take more time are run async from the API server.
Still in `onyx/backend`, run:
```bash
python ./scripts/dev_run_background_jobs.py
```
To run the backend API server, navigate back to `onyx/backend` and run:
```bash
AUTH_TYPE=disabled uvicorn onyx.main:app --reload --port 8080
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "
$env:AUTH_TYPE='disabled'
uvicorn onyx.main:app --reload --port 8080
"
```
> **Note:**
> If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
#### Wrapping up
You should now have 4 servers running:
- Web server
- Backend API
- Model server
- Background jobs
Now, visit `http://localhost:3000` in your browser. You should see the Onyx onboarding wizard where you can connect your external LLM provider to Onyx.
You've successfully set up a local Onyx instance! 🏁
#### Running the Onyx application in a container
You can run the full Onyx application stack from pre-built images including all external software dependencies.
Navigate to `onyx/deployment/docker_compose` and run:
```bash
docker compose up -d
```
After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Onyx.
If you want to make changes to Onyx and run those changes in Docker, you can also build a local version of the Onyx container images that incorporates your changes like so:
```bash
docker compose up -d --build
```
## Getting Help 🙋
We have support channels and generally interesting discussions on our [Discord](https://discord.gg/4NA5SbzrWb).
### Release Process
See you there!
## Release Process
Onyx loosely follows the SemVer versioning standard.
Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
A set of Docker containers will be pushed automatically to DockerHub with every tag.

View File

@@ -7,6 +7,8 @@ This guide explains how to set up and use VSCode's debugging capabilities with t
1. **Environment Setup**:
- Copy `.vscode/env_template.txt` to `.vscode/.env`
- Fill in the necessary environment variables in `.vscode/.env`
2. **launch.json**:
- Copy `.vscode/launch.template.jsonc` to `.vscode/launch.json`
## Using the Debugger

View File

@@ -16,8 +16,3 @@ dist/
.coverage
htmlcov/
model_server/legacy/
# Craft: demo_data directory should be unzipped at container startup, not copied
**/demo_data/
# Craft: templates/outputs/venv is created at container startup
**/templates/outputs/venv

View File

@@ -37,6 +37,10 @@ CVE-2023-50868
CVE-2023-52425
CVE-2024-28757
# sqlite, only used by NLTK library to grab word lemmatizer and stopwords
# No impact in our settings
CVE-2023-7104
# libharfbuzz0b, O(n^2) growth, worst case is denial of service
# Accept the risk
CVE-2023-25193

View File

@@ -7,10 +7,6 @@ have a contract or agreement with DanswerAI, you are not permitted to use the En
Edition features outside of personal development or testing purposes. Please reach out to \
founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx"
# Build argument for Craft support (disabled by default)
# Use --build-arg ENABLE_CRAFT=true to include Node.js and opencode CLI
ARG ENABLE_CRAFT=false
# DO_NOT_TRACK is used to disable telemetry for Unstructured
ENV DANSWER_RUNNING_IN_DOCKER="true" \
DO_NOT_TRACK="true" \
@@ -50,23 +46,7 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
# Conditionally install Node.js 20 for Craft (required for Next.js)
# Only installed when ENABLE_CRAFT=true
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
echo "Installing Node.js 20 for Craft support..." && \
curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
apt-get install -y nodejs && \
rm -rf /var/lib/apt/lists/*; \
fi
# Conditionally install opencode CLI for Craft agent functionality
# Only installed when ENABLE_CRAFT=true
# TODO: download a specific, versioned release of the opencode CLI
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
echo "Installing opencode CLI for Craft support..." && \
curl -fsSL https://opencode.ai/install | bash; \
fi
ENV PATH="/root/.opencode/bin:${PATH}"
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
@@ -111,8 +91,8 @@ Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \
nltk.download('punkt_tab', quiet=True);"
nltk.download('stopwords', quiet=True); \
nltk.download('punkt_tab', quiet=True);"
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
# Pre-downloading tiktoken for setups with limited egress
@@ -139,15 +119,7 @@ COPY --chown=onyx:onyx ./static /app/static
COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh
# Run Craft template setup at build time when ENABLE_CRAFT=true
# This pre-bakes demo data, Python venv, and npm dependencies into the image
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
echo "Running Craft template setup at build time..." && \
ENABLE_CRAFT=true /app/scripts/setup_craft_templates.sh; \
fi
RUN chmod +x /app/scripts/supervisord_entrypoint.sh
# Put logo in assets
COPY --chown=onyx:onyx ./assets /app/assets

View File

@@ -225,6 +225,7 @@ def do_run_migrations(
) -> None:
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))
connection.execute(text(f'SET search_path TO "{schema_name}"'))
@@ -308,7 +309,6 @@ async def run_async_migrations() -> None:
schema_name=schema,
create_schema=create_schema,
)
await connection.commit()
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
if not continue_on_error:
@@ -346,7 +346,6 @@ async def run_async_migrations() -> None:
schema_name=schema,
create_schema=create_schema,
)
await connection.commit()
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
if not continue_on_error:

View File

@@ -1,351 +0,0 @@
"""single onyx craft migration
Consolidates all buildmode/onyx craft tables into a single migration.
Tables created:
- build_session: User build sessions with status tracking
- sandbox: User-owned containerized environments (one per user)
- artifact: Build output files (web apps, documents, images)
- snapshot: Sandbox filesystem snapshots
- build_message: Conversation messages for build sessions
Existing table modified:
- connector_credential_pair: Added processing_mode column
Revision ID: 2020d417ec84
Revises: 41fa44bef321
Create Date: 2026-01-26 14:43:54.641405
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "2020d417ec84"
down_revision = "41fa44bef321"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ==========================================================================
# ENUMS
# ==========================================================================
# Build session status enum
build_session_status_enum = sa.Enum(
"active",
"idle",
name="buildsessionstatus",
native_enum=False,
)
# Sandbox status enum
sandbox_status_enum = sa.Enum(
"provisioning",
"running",
"idle",
"sleeping",
"terminated",
"failed",
name="sandboxstatus",
native_enum=False,
)
# Artifact type enum
artifact_type_enum = sa.Enum(
"web_app",
"pptx",
"docx",
"markdown",
"excel",
"image",
name="artifacttype",
native_enum=False,
)
# ==========================================================================
# BUILD_SESSION TABLE
# ==========================================================================
op.create_table(
"build_session",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("user.id", ondelete="CASCADE"),
nullable=True,
),
sa.Column("name", sa.String(), nullable=True),
sa.Column(
"status",
build_session_status_enum,
nullable=False,
server_default="active",
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"last_activity_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("nextjs_port", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_build_session_user_created",
"build_session",
["user_id", sa.text("created_at DESC")],
unique=False,
)
op.create_index(
"ix_build_session_status",
"build_session",
["status"],
unique=False,
)
# ==========================================================================
# SANDBOX TABLE (user-owned, one per user)
# ==========================================================================
op.create_table(
"sandbox",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("container_id", sa.String(), nullable=True),
sa.Column(
"status",
sandbox_status_enum,
nullable=False,
server_default="provisioning",
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("last_heartbeat", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("user_id", name="sandbox_user_id_key"),
)
op.create_index(
"ix_sandbox_status",
"sandbox",
["status"],
unique=False,
)
op.create_index(
"ix_sandbox_container_id",
"sandbox",
["container_id"],
unique=False,
)
# ==========================================================================
# ARTIFACT TABLE
# ==========================================================================
op.create_table(
"artifact",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"session_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("type", artifact_type_enum, nullable=False),
sa.Column("path", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_artifact_session_created",
"artifact",
["session_id", sa.text("created_at DESC")],
unique=False,
)
op.create_index(
"ix_artifact_type",
"artifact",
["type"],
unique=False,
)
# ==========================================================================
# SNAPSHOT TABLE
# ==========================================================================
op.create_table(
"snapshot",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"session_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("storage_path", sa.String(), nullable=False),
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_snapshot_session_created",
"snapshot",
["session_id", sa.text("created_at DESC")],
unique=False,
)
# ==========================================================================
# BUILD_MESSAGE TABLE
# ==========================================================================
op.create_table(
"build_message",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"session_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"turn_index",
sa.Integer(),
nullable=False,
),
sa.Column(
"type",
sa.Enum(
"SYSTEM",
"USER",
"ASSISTANT",
"DANSWER",
name="messagetype",
create_type=False,
native_enum=False,
),
nullable=False,
),
sa.Column(
"message_metadata",
postgresql.JSONB(),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_build_message_session_turn",
"build_message",
["session_id", "turn_index", sa.text("created_at ASC")],
unique=False,
)
# ==========================================================================
# CONNECTOR_CREDENTIAL_PAIR MODIFICATION
# ==========================================================================
op.add_column(
"connector_credential_pair",
sa.Column(
"processing_mode",
sa.String(),
nullable=False,
server_default="regular",
),
)
def downgrade() -> None:
# ==========================================================================
# CONNECTOR_CREDENTIAL_PAIR MODIFICATION
# ==========================================================================
op.drop_column("connector_credential_pair", "processing_mode")
# ==========================================================================
# BUILD_MESSAGE TABLE
# ==========================================================================
op.drop_index("ix_build_message_session_turn", table_name="build_message")
op.drop_table("build_message")
# ==========================================================================
# SNAPSHOT TABLE
# ==========================================================================
op.drop_index("ix_snapshot_session_created", table_name="snapshot")
op.drop_table("snapshot")
# ==========================================================================
# ARTIFACT TABLE
# ==========================================================================
op.drop_index("ix_artifact_type", table_name="artifact")
op.drop_index("ix_artifact_session_created", table_name="artifact")
op.drop_table("artifact")
sa.Enum(name="artifacttype").drop(op.get_bind(), checkfirst=True)
# ==========================================================================
# SANDBOX TABLE
# ==========================================================================
op.drop_index("ix_sandbox_container_id", table_name="sandbox")
op.drop_index("ix_sandbox_status", table_name="sandbox")
op.drop_table("sandbox")
sa.Enum(name="sandboxstatus").drop(op.get_bind(), checkfirst=True)
# ==========================================================================
# BUILD_SESSION TABLE
# ==========================================================================
op.drop_index("ix_build_session_status", table_name="build_session")
op.drop_index("ix_build_session_user_created", table_name="build_session")
op.drop_table("build_session")
sa.Enum(name="buildsessionstatus").drop(op.get_bind(), checkfirst=True)

View File

@@ -1,42 +0,0 @@
"""add_unique_constraint_to_inputprompt_prompt_user_id
Revision ID: 2c2430828bdf
Revises: fb80bdd256de
Create Date: 2026-01-20 16:01:54.314805
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "2c2430828bdf"
down_revision = "fb80bdd256de"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create unique constraint on (prompt, user_id) for user-owned prompts
# This ensures each user can only have one shortcut with a given name
op.create_unique_constraint(
"uq_inputprompt_prompt_user_id",
"inputprompt",
["prompt", "user_id"],
)
# Create partial unique index for public prompts (where user_id IS NULL)
# PostgreSQL unique constraints don't enforce uniqueness for NULL values,
# so we need a partial index to ensure public prompt names are also unique
op.execute(
"""
CREATE UNIQUE INDEX uq_inputprompt_prompt_public
ON inputprompt (prompt)
WHERE user_id IS NULL
"""
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS uq_inputprompt_prompt_public")
op.drop_constraint("uq_inputprompt_prompt_user_id", "inputprompt", type_="unique")

View File

@@ -1,29 +0,0 @@
"""remove default prompt shortcuts
Revision ID: 41fa44bef321
Revises: 2c2430828bdf
Create Date: 2025-01-21
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "41fa44bef321"
down_revision = "2c2430828bdf"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Delete any user associations for the default prompts first (foreign key constraint)
op.execute(
"DELETE FROM inputprompt__user WHERE input_prompt_id IN (SELECT id FROM inputprompt WHERE id < 0)"
)
# Delete the pre-seeded default prompt shortcuts (they have negative IDs)
op.execute("DELETE FROM inputprompt WHERE id < 0")
def downgrade() -> None:
# We don't restore the default prompts on downgrade
pass

View File

@@ -85,122 +85,103 @@ class UserRow(NamedTuple):
def upgrade() -> None:
conn = op.get_bind()
# Step 1: Create or update the unified assistant (ID 0)
search_assistant = conn.execute(
sa.text("SELECT * FROM persona WHERE id = 0")
).fetchone()
# Start transaction
conn.execute(sa.text("BEGIN"))
if search_assistant:
# Update existing Search assistant to be the unified assistant
try:
# Step 1: Create or update the unified assistant (ID 0)
search_assistant = conn.execute(
sa.text("SELECT * FROM persona WHERE id = 0")
).fetchone()
if search_assistant:
# Update existing Search assistant to be the unified assistant
conn.execute(
sa.text(
"""
UPDATE persona
SET name = :name,
description = :description,
system_prompt = :system_prompt,
num_chunks = :num_chunks,
is_default_persona = true,
is_visible = true,
deleted = false,
display_priority = :display_priority,
llm_filter_extraction = :llm_filter_extraction,
llm_relevance_filter = :llm_relevance_filter,
recency_bias = :recency_bias,
chunks_above = :chunks_above,
chunks_below = :chunks_below,
datetime_aware = :datetime_aware,
starter_messages = null
WHERE id = 0
"""
),
INSERT_DICT,
)
else:
# Create new unified assistant with ID 0
conn.execute(
sa.text(
"""
INSERT INTO persona (
id, name, description, system_prompt, num_chunks,
is_default_persona, is_visible, deleted, display_priority,
llm_filter_extraction, llm_relevance_filter, recency_bias,
chunks_above, chunks_below, datetime_aware, starter_messages,
builtin_persona
) VALUES (
0, :name, :description, :system_prompt, :num_chunks,
true, true, false, :display_priority, :llm_filter_extraction,
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
:datetime_aware, null, true
)
"""
),
INSERT_DICT,
)
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
conn.execute(
sa.text(
"""
UPDATE persona
SET name = :name,
description = :description,
system_prompt = :system_prompt,
num_chunks = :num_chunks,
is_default_persona = true,
is_visible = true,
deleted = false,
display_priority = :display_priority,
llm_filter_extraction = :llm_filter_extraction,
llm_relevance_filter = :llm_relevance_filter,
recency_bias = :recency_bias,
chunks_above = :chunks_above,
chunks_below = :chunks_below,
datetime_aware = :datetime_aware,
starter_messages = null
WHERE id = 0
SET deleted = true, is_visible = false, is_default_persona = false
WHERE builtin_persona = true AND id != 0
"""
),
INSERT_DICT,
)
else:
# Create new unified assistant with ID 0
conn.execute(
sa.text(
"""
INSERT INTO persona (
id, name, description, system_prompt, num_chunks,
is_default_persona, is_visible, deleted, display_priority,
llm_filter_extraction, llm_relevance_filter, recency_bias,
chunks_above, chunks_below, datetime_aware, starter_messages,
builtin_persona
) VALUES (
0, :name, :description, :system_prompt, :num_chunks,
true, true, false, :display_priority, :llm_filter_extraction,
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
:datetime_aware, null, true
)
"""
),
INSERT_DICT,
)
)
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = true, is_visible = false, is_default_persona = false
WHERE builtin_persona = true AND id != 0
"""
)
)
# Step 3: Add all built-in tools to the unified assistant
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
).fetchone()
# Step 3: Add all built-in tools to the unified assistant
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
).fetchone()
if not search_tool:
raise ValueError(
"SearchTool not found in database. Ensure tools migration has run first."
)
if not search_tool:
raise ValueError(
"SearchTool not found in database. Ensure tools migration has run first."
)
image_gen_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
).fetchone()
image_gen_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
).fetchone()
if not image_gen_tool:
raise ValueError(
"ImageGenerationTool not found in database. Ensure tools migration has run first."
)
if not image_gen_tool:
raise ValueError(
"ImageGenerationTool not found in database. Ensure tools migration has run first."
)
# WebSearchTool is optional - may not be configured
web_search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
).fetchone()
# WebSearchTool is optional - may not be configured
web_search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
).fetchone()
# Clear existing tool associations for persona 0
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
# Clear existing tool associations for persona 0
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
# Add tools to the unified assistant
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": search_tool[0]},
)
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": image_gen_tool[0]},
)
if web_search_tool:
# Add tools to the unified assistant
conn.execute(
sa.text(
"""
@@ -209,148 +190,191 @@ def upgrade() -> None:
ON CONFLICT DO NOTHING
"""
),
{"tool_id": web_search_tool[0]},
{"tool_id": search_tool[0]},
)
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
conn.execute(
sa.text(
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
UPDATE chat_session
SET persona_id = 0
WHERE persona_id IN (
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
)
"""
),
{"tool_id": image_gen_tool[0]},
)
)
# Step 5: Migrate user preferences - remove references to all builtin assistants
# First, get all builtin assistant IDs (except 0)
builtin_assistants_result = conn.execute(
sa.text(
"""
SELECT id FROM persona
WHERE builtin_persona = true AND id != 0
"""
)
).fetchall()
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
# Get all users with preferences
users_result = conn.execute(
sa.text(
"""
SELECT id, chosen_assistants, visible_assistants,
hidden_assistants, pinned_assistants
FROM "user"
"""
)
).fetchall()
for user_row in users_result:
user = UserRow(*user_row)
user_id: UUID = user.id
updates: dict[str, Any] = {}
# Remove all builtin assistants from chosen_assistants
if user.chosen_assistants:
new_chosen: list[int] = [
assistant_id
for assistant_id in user.chosen_assistants
if assistant_id not in builtin_assistant_ids
]
if new_chosen != user.chosen_assistants:
updates["chosen_assistants"] = json.dumps(new_chosen)
# Remove all builtin assistants from visible_assistants
if user.visible_assistants:
new_visible: list[int] = [
assistant_id
for assistant_id in user.visible_assistants
if assistant_id not in builtin_assistant_ids
]
if new_visible != user.visible_assistants:
updates["visible_assistants"] = json.dumps(new_visible)
# Add all builtin assistants to hidden_assistants
if user.hidden_assistants:
new_hidden: list[int] = list(user.hidden_assistants)
for old_id in builtin_assistant_ids:
if old_id not in new_hidden:
new_hidden.append(old_id)
if new_hidden != user.hidden_assistants:
updates["hidden_assistants"] = json.dumps(new_hidden)
else:
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
# Remove all builtin assistants from pinned_assistants
if user.pinned_assistants:
new_pinned: list[int] = [
assistant_id
for assistant_id in user.pinned_assistants
if assistant_id not in builtin_assistant_ids
]
if new_pinned != user.pinned_assistants:
updates["pinned_assistants"] = json.dumps(new_pinned)
# Apply updates if any
if updates:
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
if web_search_tool:
conn.execute(
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
updates,
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": web_search_tool[0]},
)
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
conn.execute(
sa.text(
"""
UPDATE chat_session
SET persona_id = 0
WHERE persona_id IN (
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
)
"""
)
)
# Step 5: Migrate user preferences - remove references to all builtin assistants
# First, get all builtin assistant IDs (except 0)
builtin_assistants_result = conn.execute(
sa.text(
"""
SELECT id FROM persona
WHERE builtin_persona = true AND id != 0
"""
)
).fetchall()
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
# Get all users with preferences
users_result = conn.execute(
sa.text(
"""
SELECT id, chosen_assistants, visible_assistants,
hidden_assistants, pinned_assistants
FROM "user"
"""
)
).fetchall()
for user_row in users_result:
user = UserRow(*user_row)
user_id: UUID = user.id
updates: dict[str, Any] = {}
# Remove all builtin assistants from chosen_assistants
if user.chosen_assistants:
new_chosen: list[int] = [
assistant_id
for assistant_id in user.chosen_assistants
if assistant_id not in builtin_assistant_ids
]
if new_chosen != user.chosen_assistants:
updates["chosen_assistants"] = json.dumps(new_chosen)
# Remove all builtin assistants from visible_assistants
if user.visible_assistants:
new_visible: list[int] = [
assistant_id
for assistant_id in user.visible_assistants
if assistant_id not in builtin_assistant_ids
]
if new_visible != user.visible_assistants:
updates["visible_assistants"] = json.dumps(new_visible)
# Add all builtin assistants to hidden_assistants
if user.hidden_assistants:
new_hidden: list[int] = list(user.hidden_assistants)
for old_id in builtin_assistant_ids:
if old_id not in new_hidden:
new_hidden.append(old_id)
if new_hidden != user.hidden_assistants:
updates["hidden_assistants"] = json.dumps(new_hidden)
else:
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
# Remove all builtin assistants from pinned_assistants
if user.pinned_assistants:
new_pinned: list[int] = [
assistant_id
for assistant_id in user.pinned_assistants
if assistant_id not in builtin_assistant_ids
]
if new_pinned != user.pinned_assistants:
updates["pinned_assistants"] = json.dumps(new_pinned)
# Apply updates if any
if updates:
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
conn.execute(
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
updates,
)
# Commit transaction
conn.execute(sa.text("COMMIT"))
except Exception as e:
# Rollback on error
conn.execute(sa.text("ROLLBACK"))
raise e
def downgrade() -> None:
conn = op.get_bind()
# Only restore General (ID -1) and Art (ID -3) assistants
# Step 1: Keep Search assistant (ID 0) as default but restore original state
conn.execute(
sa.text(
# Start transaction
conn.execute(sa.text("BEGIN"))
try:
# Only restore General (ID -1) and Art (ID -3) assistants
# Step 1: Keep Search assistant (ID 0) as default but restore original state
conn.execute(
sa.text(
"""
UPDATE persona
SET is_default_persona = true,
is_visible = true,
deleted = false
WHERE id = 0
"""
UPDATE persona
SET is_default_persona = true,
is_visible = true,
deleted = false
WHERE id = 0
"""
)
)
)
# Step 2: Restore General assistant (ID -1)
conn.execute(
sa.text(
# Step 2: Restore General assistant (ID -1)
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :general_assistant_id
"""
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :general_assistant_id
"""
),
{"general_assistant_id": GENERAL_ASSISTANT_ID},
)
),
{"general_assistant_id": GENERAL_ASSISTANT_ID},
)
# Step 3: Restore Art assistant (ID -3)
conn.execute(
sa.text(
# Step 3: Restore Art assistant (ID -3)
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :art_assistant_id
"""
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :art_assistant_id
"""
),
{"art_assistant_id": ART_ASSISTANT_ID},
)
),
{"art_assistant_id": ART_ASSISTANT_ID},
)
# Note: We don't restore the original tool associations, names, or descriptions
# as those would require more complex logic to determine original state.
# We also cannot restore original chat session persona_ids as we don't
# have the original mappings.
# Other builtin assistants remain deleted as per the requirement.
# Note: We don't restore the original tool associations, names, or descriptions
# as those would require more complex logic to determine original state.
# We also cannot restore original chat session persona_ids as we don't
# have the original mappings.
# Other builtin assistants remain deleted as per the requirement.
# Commit transaction
conn.execute(sa.text("COMMIT"))
except Exception as e:
# Rollback on error
conn.execute(sa.text("ROLLBACK"))
raise e

View File

@@ -1,45 +0,0 @@
"""make processing mode default all caps
Revision ID: 72aa7de2e5cf
Revises: 2020d417ec84
Create Date: 2026-01-26 18:58:47.705253
This migration fixes the ProcessingMode enum value mismatch:
- SQLAlchemy's Enum with native_enum=False uses enum member NAMES as valid values
- The original migration stored lowercase VALUES ('regular', 'file_system')
- This converts existing data to uppercase NAMES ('REGULAR', 'FILE_SYSTEM')
- Also drops any spurious native PostgreSQL enum type that may have been auto-created
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "72aa7de2e5cf"
down_revision = "2020d417ec84"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Convert existing lowercase values to uppercase to match enum member names
op.execute(
"UPDATE connector_credential_pair SET processing_mode = 'REGULAR' "
"WHERE processing_mode = 'regular'"
)
op.execute(
"UPDATE connector_credential_pair SET processing_mode = 'FILE_SYSTEM' "
"WHERE processing_mode = 'file_system'"
)
# Update the server default to use uppercase
op.alter_column(
"connector_credential_pair",
"processing_mode",
server_default="REGULAR",
)
def downgrade() -> None:
# State prior to this was broken, so we don't want to revert back to it
pass

View File

@@ -1,47 +0,0 @@
"""add_search_query_table
Revision ID: 73e9983e5091
Revises: d1b637d7050a
Create Date: 2026-01-14 14:16:52.837489
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "73e9983e5091"
down_revision = "d1b637d7050a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"search_query",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("user.id"),
nullable=False,
),
sa.Column("query", sa.String(), nullable=False),
sa.Column("query_expansions", postgresql.ARRAY(sa.String()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
op.create_index("ix_search_query_user_id", "search_query", ["user_id"])
op.create_index("ix_search_query_created_at", "search_query", ["created_at"])
def downgrade() -> None:
op.drop_index("ix_search_query_created_at", table_name="search_query")
op.drop_index("ix_search_query_user_id", table_name="search_query")
op.drop_table("search_query")

View File

@@ -10,7 +10,8 @@ from alembic import op
import sqlalchemy as sa
from onyx.db.models import IndexModelStatus
from onyx.context.search.enums import RecencyBiasSetting, SearchType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.enums import SearchType
# revision identifiers, used by Alembic.
revision = "776b3bbe9092"

View File

@@ -1,349 +0,0 @@
"""hierarchy_nodes_v1
Revision ID: 81c22b1e2e78
Revises: 72aa7de2e5cf
Create Date: 2026-01-13 18:10:01.021451
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from onyx.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
revision = "81c22b1e2e78"
down_revision = "72aa7de2e5cf"
branch_labels = None
depends_on = None
# Human-readable display names for each source
SOURCE_DISPLAY_NAMES: dict[str, str] = {
"ingestion_api": "Ingestion API",
"slack": "Slack",
"web": "Web",
"google_drive": "Google Drive",
"gmail": "Gmail",
"requesttracker": "Request Tracker",
"github": "GitHub",
"gitbook": "GitBook",
"gitlab": "GitLab",
"guru": "Guru",
"bookstack": "BookStack",
"outline": "Outline",
"confluence": "Confluence",
"jira": "Jira",
"slab": "Slab",
"productboard": "Productboard",
"file": "File",
"coda": "Coda",
"notion": "Notion",
"zulip": "Zulip",
"linear": "Linear",
"hubspot": "HubSpot",
"document360": "Document360",
"gong": "Gong",
"google_sites": "Google Sites",
"zendesk": "Zendesk",
"loopio": "Loopio",
"dropbox": "Dropbox",
"sharepoint": "SharePoint",
"teams": "Teams",
"salesforce": "Salesforce",
"discourse": "Discourse",
"axero": "Axero",
"clickup": "ClickUp",
"mediawiki": "MediaWiki",
"wikipedia": "Wikipedia",
"asana": "Asana",
"s3": "S3",
"r2": "R2",
"google_cloud_storage": "Google Cloud Storage",
"oci_storage": "OCI Storage",
"xenforo": "XenForo",
"not_applicable": "Not Applicable",
"discord": "Discord",
"freshdesk": "Freshdesk",
"fireflies": "Fireflies",
"egnyte": "Egnyte",
"airtable": "Airtable",
"highspot": "Highspot",
"drupal_wiki": "Drupal Wiki",
"imap": "IMAP",
"bitbucket": "Bitbucket",
"testrail": "TestRail",
"mock_connector": "Mock Connector",
"user_file": "User File",
}
def upgrade() -> None:
# 1. Create hierarchy_node table
op.create_table(
"hierarchy_node",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("raw_node_id", sa.String(), nullable=False),
sa.Column("display_name", sa.String(), nullable=False),
sa.Column("link", sa.String(), nullable=True),
sa.Column("source", sa.String(), nullable=False),
sa.Column("node_type", sa.String(), nullable=False),
sa.Column("document_id", sa.String(), nullable=True),
sa.Column("parent_id", sa.Integer(), nullable=True),
# Permission fields - same pattern as Document table
sa.Column(
"external_user_emails",
postgresql.ARRAY(sa.String()),
nullable=True,
),
sa.Column(
"external_user_group_ids",
postgresql.ARRAY(sa.String()),
nullable=True,
),
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="false"),
sa.PrimaryKeyConstraint("id"),
# When document is deleted, just unlink (node can exist without document)
sa.ForeignKeyConstraint(["document_id"], ["document.id"], ondelete="SET NULL"),
# When parent node is deleted, orphan children (cleanup via pruning)
sa.ForeignKeyConstraint(
["parent_id"], ["hierarchy_node.id"], ondelete="SET NULL"
),
sa.UniqueConstraint(
"raw_node_id", "source", name="uq_hierarchy_node_raw_id_source"
),
)
op.create_index("ix_hierarchy_node_parent_id", "hierarchy_node", ["parent_id"])
op.create_index(
"ix_hierarchy_node_source_type", "hierarchy_node", ["source", "node_type"]
)
# Add partial unique index to ensure only one SOURCE-type node per source
# This prevents duplicate source root nodes from being created
# NOTE: node_type stores enum NAME ('SOURCE'), not value ('source')
op.execute(
sa.text(
"""
CREATE UNIQUE INDEX uq_hierarchy_node_one_source_per_type
ON hierarchy_node (source)
WHERE node_type = 'SOURCE'
"""
)
)
# 2. Create hierarchy_fetch_attempt table
op.create_table(
"hierarchy_fetch_attempt",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
sa.Column("status", sa.String(), nullable=False),
sa.Column("nodes_fetched", sa.Integer(), nullable=True, server_default="0"),
sa.Column("nodes_updated", sa.Integer(), nullable=True, server_default="0"),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column("full_exception_trace", sa.Text(), nullable=True),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("time_started", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["connector_credential_pair_id"],
["connector_credential_pair.id"],
ondelete="CASCADE",
),
)
op.create_index(
"ix_hierarchy_fetch_attempt_status", "hierarchy_fetch_attempt", ["status"]
)
op.create_index(
"ix_hierarchy_fetch_attempt_time_created",
"hierarchy_fetch_attempt",
["time_created"],
)
op.create_index(
"ix_hierarchy_fetch_attempt_cc_pair",
"hierarchy_fetch_attempt",
["connector_credential_pair_id"],
)
# 3. Insert SOURCE-type hierarchy nodes for each DocumentSource
# We insert these so every existing document can have a parent hierarchy node
# NOTE: SQLAlchemy's Enum with native_enum=False stores the enum NAME (e.g., 'GOOGLE_DRIVE'),
# not the VALUE (e.g., 'google_drive'). We must use .name for source and node_type columns.
# SOURCE nodes are always public since they're just categorical roots.
for source in DocumentSource:
source_name = (
source.name
) # e.g., 'GOOGLE_DRIVE' - what SQLAlchemy stores/expects
source_value = source.value # e.g., 'google_drive' - the raw_node_id
display_name = SOURCE_DISPLAY_NAMES.get(
source_value, source_value.replace("_", " ").title()
)
op.execute(
sa.text(
"""
INSERT INTO hierarchy_node (raw_node_id, display_name, source, node_type, parent_id, is_public)
VALUES (:raw_node_id, :display_name, :source, 'SOURCE', NULL, true)
ON CONFLICT (raw_node_id, source) DO NOTHING
"""
).bindparams(
raw_node_id=source_value, # Use .value for raw_node_id (human-readable identifier)
display_name=display_name,
source=source_name, # Use .name for source column (SQLAlchemy enum storage)
)
)
# 4. Add parent_hierarchy_node_id column to document table
op.add_column(
"document",
sa.Column("parent_hierarchy_node_id", sa.Integer(), nullable=True),
)
# When hierarchy node is deleted, just unlink the document (SET NULL)
op.create_foreign_key(
"fk_document_parent_hierarchy_node",
"document",
"hierarchy_node",
["parent_hierarchy_node_id"],
["id"],
ondelete="SET NULL",
)
op.create_index(
"ix_document_parent_hierarchy_node_id",
"document",
["parent_hierarchy_node_id"],
)
# 5. Set all existing documents' parent_hierarchy_node_id to their source's SOURCE node
# For documents with multiple connectors, we pick one source deterministically (MIN connector_id)
# NOTE: Both connector.source and hierarchy_node.source store enum NAMEs (e.g., 'GOOGLE_DRIVE')
# because SQLAlchemy Enum(native_enum=False) uses the enum name for storage.
op.execute(
sa.text(
"""
UPDATE document d
SET parent_hierarchy_node_id = hn.id
FROM (
-- Get the source for each document (pick MIN connector_id for determinism)
SELECT DISTINCT ON (dbcc.id)
dbcc.id as doc_id,
c.source as source
FROM document_by_connector_credential_pair dbcc
JOIN connector c ON dbcc.connector_id = c.id
ORDER BY dbcc.id, dbcc.connector_id
) doc_source
JOIN hierarchy_node hn ON hn.source = doc_source.source AND hn.node_type = 'SOURCE'
WHERE d.id = doc_source.doc_id
"""
)
)
# Create the persona__hierarchy_node association table
op.create_table(
"persona__hierarchy_node",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("hierarchy_node_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["hierarchy_node_id"],
["hierarchy_node.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("persona_id", "hierarchy_node_id"),
)
# Add index for efficient lookups
op.create_index(
"ix_persona__hierarchy_node_hierarchy_node_id",
"persona__hierarchy_node",
["hierarchy_node_id"],
)
# Create the persona__document association table for attaching individual
# documents directly to assistants
op.create_table(
"persona__document",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("document_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["document_id"],
["document.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("persona_id", "document_id"),
)
# Add index for efficient lookups by document_id
op.create_index(
"ix_persona__document_document_id",
"persona__document",
["document_id"],
)
# 6. Add last_time_hierarchy_fetch column to connector_credential_pair table
op.add_column(
"connector_credential_pair",
sa.Column(
"last_time_hierarchy_fetch", sa.DateTime(timezone=True), nullable=True
),
)
def downgrade() -> None:
# Remove last_time_hierarchy_fetch from connector_credential_pair
op.drop_column("connector_credential_pair", "last_time_hierarchy_fetch")
# Drop persona__document table
op.drop_index("ix_persona__document_document_id", table_name="persona__document")
op.drop_table("persona__document")
# Drop persona__hierarchy_node table
op.drop_index(
"ix_persona__hierarchy_node_hierarchy_node_id",
table_name="persona__hierarchy_node",
)
op.drop_table("persona__hierarchy_node")
# Remove parent_hierarchy_node_id from document
op.drop_index("ix_document_parent_hierarchy_node_id", table_name="document")
op.drop_constraint(
"fk_document_parent_hierarchy_node", "document", type_="foreignkey"
)
op.drop_column("document", "parent_hierarchy_node_id")
# Drop hierarchy_fetch_attempt table
op.drop_index(
"ix_hierarchy_fetch_attempt_cc_pair", table_name="hierarchy_fetch_attempt"
)
op.drop_index(
"ix_hierarchy_fetch_attempt_time_created", table_name="hierarchy_fetch_attempt"
)
op.drop_index(
"ix_hierarchy_fetch_attempt_status", table_name="hierarchy_fetch_attempt"
)
op.drop_table("hierarchy_fetch_attempt")
# Drop hierarchy_node table
op.drop_index("uq_hierarchy_node_one_source_per_type", table_name="hierarchy_node")
op.drop_index("ix_hierarchy_node_source_type", table_name="hierarchy_node")
op.drop_index("ix_hierarchy_node_parent_id", table_name="hierarchy_node")
op.drop_table("hierarchy_node")

View File

@@ -1,49 +0,0 @@
"""notifications constraint, sort index, and cleanup old notifications
Revision ID: 8405ca81cc83
Revises: a3c1a7904cd0
Create Date: 2026-01-07 16:43:44.855156
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "8405ca81cc83"
down_revision = "a3c1a7904cd0"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create unique index for notification deduplication.
# This enables atomic ON CONFLICT DO NOTHING inserts in batch_create_notifications.
#
# Uses COALESCE to handle NULL additional_data (NULLs are normally distinct
# in unique constraints, but we want NULL == NULL for deduplication).
# The '{}' represents an empty JSONB object as the NULL replacement.
# Clean up legacy notifications first
op.execute("DELETE FROM notification WHERE title = 'New Notification'")
op.execute(
"""
CREATE UNIQUE INDEX IF NOT EXISTS ix_notification_user_type_data
ON notification (user_id, notif_type, COALESCE(additional_data, '{}'::jsonb))
"""
)
# Create index for efficient notification sorting by user
# Covers: WHERE user_id = ? ORDER BY dismissed, first_shown DESC
op.execute(
"""
CREATE INDEX IF NOT EXISTS ix_notification_user_sort
ON notification (user_id, dismissed, first_shown DESC)
"""
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS ix_notification_user_type_data")
op.execute("DROP INDEX IF EXISTS ix_notification_user_sort")

View File

@@ -1,116 +0,0 @@
"""Add Discord bot tables
Revision ID: 8b5ce697290e
Revises: a1b2c3d4e5f7
Create Date: 2025-01-14
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8b5ce697290e"
down_revision = "a1b2c3d4e5f7"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# DiscordBotConfig (singleton table - one per tenant)
op.create_table(
"discord_bot_config",
sa.Column(
"id",
sa.String(),
primary_key=True,
server_default=sa.text("'SINGLETON'"),
),
sa.Column("bot_token", sa.LargeBinary(), nullable=False), # EncryptedString
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.CheckConstraint("id = 'SINGLETON'", name="ck_discord_bot_config_singleton"),
)
# DiscordGuildConfig
op.create_table(
"discord_guild_config",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("guild_id", sa.BigInteger(), nullable=True, unique=True),
sa.Column("guild_name", sa.String(), nullable=True),
sa.Column("registration_key", sa.String(), nullable=False, unique=True),
sa.Column("registered_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"default_persona_id",
sa.Integer(),
sa.ForeignKey("persona.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column(
"enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
),
)
# DiscordChannelConfig
op.create_table(
"discord_channel_config",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column(
"guild_config_id",
sa.Integer(),
sa.ForeignKey("discord_guild_config.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("channel_id", sa.BigInteger(), nullable=False),
sa.Column("channel_name", sa.String(), nullable=False),
sa.Column(
"channel_type",
sa.String(20),
server_default=sa.text("'text'"),
nullable=False,
),
sa.Column(
"is_private",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
),
sa.Column(
"thread_only_mode",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
),
sa.Column(
"require_bot_invocation",
sa.Boolean(),
server_default=sa.text("true"),
nullable=False,
),
sa.Column(
"persona_override_id",
sa.Integer(),
sa.ForeignKey("persona.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column(
"enabled", sa.Boolean(), server_default=sa.text("false"), nullable=False
),
)
# Unique constraint: one config per channel per guild
op.create_unique_constraint(
"uq_discord_channel_guild_channel",
"discord_channel_config",
["guild_config_id", "channel_id"],
)
def downgrade() -> None:
op.drop_table("discord_channel_config")
op.drop_table("discord_guild_config")
op.drop_table("discord_bot_config")

View File

@@ -42,13 +42,20 @@ TOOL_DESCRIPTIONS = {
def upgrade() -> None:
conn = op.get_bind()
for tool_id, description in TOOL_DESCRIPTIONS.items():
conn.execute(
sa.text(
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
),
{"description": description, "tool_id": tool_id},
)
conn.execute(sa.text("BEGIN"))
try:
for tool_id, description in TOOL_DESCRIPTIONS.items():
conn.execute(
sa.text(
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
),
{"description": description, "tool_id": tool_id},
)
conn.execute(sa.text("COMMIT"))
except Exception as e:
conn.execute(sa.text("ROLLBACK"))
raise e
def downgrade() -> None:

View File

@@ -1,47 +0,0 @@
"""drop agent_search_metrics table
Revision ID: a1b2c3d4e5f7
Revises: 73e9983e5091
Create Date: 2026-01-17
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f7"
down_revision = "73e9983e5091"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_table("agent__search_metrics")
def downgrade() -> None:
op.create_table(
"agent__search_metrics",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("agent_type", sa.String(), nullable=False),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("base_duration_s", sa.Float(), nullable=False),
sa.Column("full_duration_s", sa.Float(), nullable=False),
sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("id"),
)

View File

@@ -7,6 +7,7 @@ Create Date: 2025-12-18 16:00:00.000000
"""
from alembic import op
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
import sqlalchemy as sa
@@ -18,7 +19,7 @@ depends_on = None
DEEP_RESEARCH_TOOL = {
"name": "ResearchAgent",
"name": RESEARCH_AGENT_DB_NAME,
"display_name": "Research Agent",
"description": "The Research Agent is a sub-agent that conducts research on a specific topic.",
"in_code_tool_id": "ResearchAgent",

View File

@@ -70,66 +70,80 @@ BUILT_IN_TOOLS = [
def upgrade() -> None:
conn = op.get_bind()
# Get existing tools to check what already exists
existing_tools = conn.execute(
sa.text("SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL")
).fetchall()
existing_tool_ids = {row[0] for row in existing_tools}
# Start transaction
conn.execute(sa.text("BEGIN"))
# Insert or update built-in tools
for tool in BUILT_IN_TOOLS:
in_code_id = tool["in_code_tool_id"]
try:
# Get existing tools to check what already exists
existing_tools = conn.execute(
sa.text(
"SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL"
)
).fetchall()
existing_tool_ids = {row[0] for row in existing_tools}
# Handle historical rename: InternetSearchTool -> WebSearchTool
if (
in_code_id == "WebSearchTool"
and "WebSearchTool" not in existing_tool_ids
and "InternetSearchTool" in existing_tool_ids
):
# Rename the existing InternetSearchTool row in place and update fields
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description,
in_code_tool_id = :in_code_tool_id
WHERE in_code_tool_id = 'InternetSearchTool'
"""
),
tool,
)
# Keep the local view of existing ids in sync to avoid duplicate insert
existing_tool_ids.discard("InternetSearchTool")
existing_tool_ids.add("WebSearchTool")
continue
# Insert or update built-in tools
for tool in BUILT_IN_TOOLS:
in_code_id = tool["in_code_tool_id"]
if in_code_id in existing_tool_ids:
# Update existing tool
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
tool,
)
else:
# Insert new tool
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id)
VALUES (:name, :display_name, :description, :in_code_tool_id)
"""
),
tool,
)
# Handle historical rename: InternetSearchTool -> WebSearchTool
if (
in_code_id == "WebSearchTool"
and "WebSearchTool" not in existing_tool_ids
and "InternetSearchTool" in existing_tool_ids
):
# Rename the existing InternetSearchTool row in place and update fields
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description,
in_code_tool_id = :in_code_tool_id
WHERE in_code_tool_id = 'InternetSearchTool'
"""
),
tool,
)
# Keep the local view of existing ids in sync to avoid duplicate insert
existing_tool_ids.discard("InternetSearchTool")
existing_tool_ids.add("WebSearchTool")
continue
if in_code_id in existing_tool_ids:
# Update existing tool
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
tool,
)
else:
# Insert new tool
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id)
VALUES (:name, :display_name, :description, :in_code_tool_id)
"""
),
tool,
)
# Commit transaction
conn.execute(sa.text("COMMIT"))
except Exception as e:
# Rollback on error
conn.execute(sa.text("ROLLBACK"))
raise e
def downgrade() -> None:

View File

@@ -1,64 +0,0 @@
"""sync_exa_api_key_to_content_provider
Revision ID: d1b637d7050a
Revises: d25168c2beee
Create Date: 2026-01-09 15:54:15.646249
"""
from alembic import op
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision = "d1b637d7050a"
down_revision = "d25168c2beee"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Exa uses a shared API key between search and content providers.
# For existing Exa search providers with API keys, create the corresponding
# content provider if it doesn't exist yet.
connection = op.get_bind()
# Check if Exa search provider exists with an API key
result = connection.execute(
text(
"""
SELECT api_key FROM internet_search_provider
WHERE provider_type = 'exa' AND api_key IS NOT NULL
LIMIT 1
"""
)
)
row = result.fetchone()
if row:
api_key = row[0]
# Create Exa content provider with the shared key
connection.execute(
text(
"""
INSERT INTO internet_content_provider
(name, provider_type, api_key, is_active)
VALUES ('Exa', 'exa', :api_key, false)
ON CONFLICT (name) DO NOTHING
"""
),
{"api_key": api_key},
)
def downgrade() -> None:
# Remove the Exa content provider that was created by this migration
connection = op.get_bind()
connection.execute(
text(
"""
DELETE FROM internet_content_provider
WHERE provider_type = 'exa'
"""
)
)

View File

@@ -1,86 +0,0 @@
"""tool_name_consistency
Revision ID: d25168c2beee
Revises: 8405ca81cc83
Create Date: 2026-01-11 17:54:40.135777
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d25168c2beee"
down_revision = "8405ca81cc83"
branch_labels = None
depends_on = None
# Currently the seeded tools have the in_code_tool_id == name
CURRENT_TOOL_NAME_MAPPING = [
"SearchTool",
"WebSearchTool",
"ImageGenerationTool",
"PythonTool",
"OpenURLTool",
"KnowledgeGraphTool",
"ResearchAgent",
]
# Mapping of in_code_tool_id -> name
# These are the expected names that we want in the database
EXPECTED_TOOL_NAME_MAPPING = {
"SearchTool": "internal_search",
"WebSearchTool": "web_search",
"ImageGenerationTool": "generate_image",
"PythonTool": "python",
"OpenURLTool": "open_url",
"KnowledgeGraphTool": "run_kg_search",
"ResearchAgent": "research_agent",
}
def upgrade() -> None:
conn = op.get_bind()
# Mapping of in_code_tool_id to the NAME constant from each tool class
# These match the .name property of each tool implementation
tool_name_mapping = EXPECTED_TOOL_NAME_MAPPING
# Update the name column for each tool based on its in_code_tool_id
for in_code_tool_id, expected_name in tool_name_mapping.items():
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :expected_name
WHERE in_code_tool_id = :in_code_tool_id
"""
),
{
"expected_name": expected_name,
"in_code_tool_id": in_code_tool_id,
},
)
def downgrade() -> None:
conn = op.get_bind()
# Reverse the migration by setting name back to in_code_tool_id
# This matches the original pattern where name was the class name
for in_code_tool_id in CURRENT_TOOL_NAME_MAPPING:
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :current_name
WHERE in_code_tool_id = :in_code_tool_id
"""
),
{
"current_name": in_code_tool_id,
"in_code_tool_id": in_code_tool_id,
},
)

View File

@@ -1,31 +0,0 @@
"""add chat_background to user
Revision ID: fb80bdd256de
Revises: 8b5ce697290e
Create Date: 2026-01-16 16:15:59.222617
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "fb80bdd256de"
down_revision = "8b5ce697290e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"chat_background",
sa.String(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("user", "chat_background")

View File

@@ -109,6 +109,7 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
# JWT Public Key URL
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
@@ -122,23 +123,9 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
POSTHOG_DEBUG_LOGS_ENABLED = (
os.environ.get("POSTHOG_DEBUG_LOGS_ENABLED", "").lower() == "true"
)
MARKETING_POSTHOG_API_KEY = os.environ.get("MARKETING_POSTHOG_API_KEY")
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
GATED_TENANTS_KEY = "gated_tenants"
# License enforcement - when True, blocks API access for gated/expired licenses
LICENSE_ENFORCEMENT_ENABLED = (
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
)
# Cloud data plane URL - self-hosted instances call this to reach cloud proxy endpoints
# Used when MULTI_TENANT=false (self-hosted mode)
CLOUD_DATA_PLANE_URL = os.environ.get(
"CLOUD_DATA_PLANE_URL", "https://cloud.onyx.app/api"
)

View File

@@ -1,73 +0,0 @@
"""Constants for license enforcement.
This file is the single source of truth for:
1. Paths that bypass license enforcement (always accessible)
2. Paths that require an EE license (EE-only features)
Import these constants in both production code and tests to ensure consistency.
"""
# Paths that are ALWAYS accessible, even when license is expired/gated.
# These enable users to:
# /auth - Log in/out (users can't fix billing if locked out of auth)
# /license - Fetch, upload, or check license status
# /health - Health checks for load balancers/orchestrators
# /me - Basic user info needed for UI rendering
# /settings, /enterprise-settings - View app status and branding
# /billing - Unified billing API
# /proxy - Self-hosted proxy endpoints (have own license-based auth)
# /tenants/billing-* - Legacy billing endpoints (backwards compatibility)
# /manage/users, /users - User management (needed for seat limit resolution)
# /notifications - Needed for UI to load properly
LICENSE_ENFORCEMENT_ALLOWED_PREFIXES: frozenset[str] = frozenset(
{
"/auth",
"/license",
"/health",
"/me",
"/settings",
"/enterprise-settings",
# Billing endpoints (unified API for both MT and self-hosted)
"/billing",
"/admin/billing",
# Proxy endpoints for self-hosted billing (no tenant context)
"/proxy",
# Legacy tenant billing endpoints (kept for backwards compatibility)
"/tenants/billing-information",
"/tenants/create-customer-portal-session",
"/tenants/create-subscription-session",
# User management - needed to remove users when seat limit exceeded
"/manage/users",
"/manage/admin/users",
"/manage/admin/valid-domains",
"/manage/admin/deactivate-user",
"/manage/admin/delete-user",
"/users",
# Notifications - needed for UI to load properly
"/notifications",
}
)
# EE-only paths that require a valid license.
# Users without a license (community edition) cannot access these.
# These are blocked even when user has never subscribed (no license).
EE_ONLY_PATH_PREFIXES: frozenset[str] = frozenset(
{
# User groups and access control
"/manage/admin/user-group",
# Analytics and reporting
"/analytics",
# Query history (admin chat session endpoints)
"/admin/chat-sessions",
"/admin/chat-session-history",
"/admin/query-history",
# Usage reporting/export
"/admin/usage-report",
# Standard answers (canned responses)
"/manage/admin/standard-answer",
# Token rate limits
"/admin/token-rate-limits",
# Evals
"/evals",
}
)

View File

@@ -1,7 +1,6 @@
"""Database and cache operations for the license table."""
from datetime import datetime
from typing import NamedTuple
from sqlalchemy import func
from sqlalchemy import select
@@ -10,7 +9,6 @@ from sqlalchemy.orm import Session
from ee.onyx.server.license.models import LicenseMetadata
from ee.onyx.server.license.models import LicensePayload
from ee.onyx.server.license.models import LicenseSource
from onyx.auth.schemas import UserRole
from onyx.db.models import License
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
@@ -25,13 +23,6 @@ LICENSE_METADATA_KEY = "license:metadata"
LICENSE_CACHE_TTL_SECONDS = 86400 # 24 hours
class SeatAvailabilityResult(NamedTuple):
"""Result of a seat availability check."""
available: bool
error_message: str | None = None
# -----------------------------------------------------------------------------
# Database CRUD Operations
# -----------------------------------------------------------------------------
@@ -104,30 +95,23 @@ def delete_license(db_session: Session) -> bool:
def get_used_seats(tenant_id: str | None = None) -> int:
"""
Get current seat usage directly from database.
Get current seat usage.
For multi-tenant: counts users in UserTenantMapping for this tenant.
For self-hosted: counts all active users (excludes EXT_PERM_USER role).
TODO: Exclude API key dummy users from seat counting. API keys create
users with emails like `__DANSWER_API_KEY_*` that should not count toward
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
For self-hosted: counts all active users (includes both Onyx UI users
and Slack users who have been converted to Onyx users).
"""
if MULTI_TENANT:
from ee.onyx.server.tenants.user_mapping import get_tenant_count
return get_tenant_count(tenant_id or get_current_tenant_id())
else:
# Self-hosted: count all active users (Onyx + converted Slack users)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
with get_session_with_current_tenant() as db_session:
result = db_session.execute(
select(func.count())
.select_from(User)
.where(
User.is_active == True, # type: ignore # noqa: E712
User.role != UserRole.EXT_PERM_USER,
)
select(func.count()).select_from(User).where(User.is_active) # type: ignore
)
return result.scalar() or 0
@@ -292,43 +276,3 @@ def get_license_metadata(
# Refresh from database
return refresh_license_cache(db_session, tenant_id)
def check_seat_availability(
db_session: Session,
seats_needed: int = 1,
tenant_id: str | None = None,
) -> SeatAvailabilityResult:
"""
Check if there are enough seats available to add users.
Args:
db_session: Database session
seats_needed: Number of seats needed (default 1)
tenant_id: Tenant ID (for multi-tenant deployments)
Returns:
SeatAvailabilityResult with available=True if seats are available,
or available=False with error_message if limit would be exceeded.
Returns available=True if no license exists (self-hosted = unlimited).
"""
metadata = get_license_metadata(db_session, tenant_id)
# No license = no enforcement (self-hosted without license)
if metadata is None:
return SeatAvailabilityResult(available=True)
# Calculate current usage directly from DB (not cache) for accuracy
current_used = get_used_seats(tenant_id)
total_seats = metadata.seats
# Use > (not >=) to allow filling to exactly 100% capacity
would_exceed_limit = current_used + seats_needed > total_seats
if would_exceed_limit:
return SeatAvailabilityResult(
available=False,
error_message=f"Seat limit would be exceeded: {current_used} of {total_seats} seats used, "
f"cannot add {seats_needed} more user(s).",
)
return SeatAvailabilityResult(available=True)

View File

@@ -3,42 +3,30 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.configs.constants import NotificationType
from onyx.db.models import Persona
from onyx.db.models import Persona__User
from onyx.db.models import Persona__UserGroup
from onyx.db.notification import create_notification
from onyx.server.features.persona.models import PersonaSharedNotificationData
def update_persona_access(
def make_persona_private(
persona_id: int,
creator_user_id: UUID | None,
user_ids: list[UUID] | None,
group_ids: list[int] | None,
db_session: Session,
is_public: bool | None = None,
user_ids: list[UUID] | None = None,
group_ids: list[int] | None = None,
) -> None:
"""Updates the access settings for a persona including public status, user shares,
and group shares.
"""NOTE(rkuo): This function batches all updates into a single commit. If we don't
dedupe the inputs, the commit will exception."""
NOTE: This function batches all updates. If we don't dedupe the inputs,
the commit will exception.
NOTE: Callers are responsible for committing."""
if is_public is not None:
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
if persona:
persona.is_public = is_public
# NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
# and a non-empty list means "replace with these shares".
if user_ids is not None:
db_session.query(Persona__User).filter(
Persona__User.persona_id == persona_id
).delete(synchronize_session="fetch")
db_session.query(Persona__User).filter(
Persona__User.persona_id == persona_id
).delete(synchronize_session="fetch")
db_session.query(Persona__UserGroup).filter(
Persona__UserGroup.persona_id == persona_id
).delete(synchronize_session="fetch")
if user_ids:
user_ids_set = set(user_ids)
for user_id in user_ids_set:
db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
@@ -53,13 +41,11 @@ def update_persona_access(
).model_dump(),
)
if group_ids is not None:
db_session.query(Persona__UserGroup).filter(
Persona__UserGroup.persona_id == persona_id
).delete(synchronize_session="fetch")
if group_ids:
group_ids_set = set(group_ids)
for group_id in group_ids_set:
db_session.add(
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
)
db_session.commit()

View File

@@ -1,64 +0,0 @@
import uuid
from datetime import timedelta
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.models import SearchQuery
def create_search_query(
db_session: Session,
user_id: UUID,
query: str,
query_expansions: list[str] | None = None,
) -> SearchQuery:
"""Create and persist a `SearchQuery` row.
Notes:
- `SearchQuery.id` is a UUID PK without a server-side default, so we generate it.
- `created_at` is filled by the DB (server_default=now()).
"""
search_query = SearchQuery(
id=uuid.uuid4(),
user_id=user_id,
query=query,
query_expansions=query_expansions,
)
db_session.add(search_query)
db_session.commit()
db_session.refresh(search_query)
return search_query
def fetch_search_queries_for_user(
db_session: Session,
user_id: UUID,
filter_days: int | None = None,
limit: int | None = None,
) -> list[SearchQuery]:
"""Fetch `SearchQuery` rows for a user.
Args:
user_id: User UUID.
filter_days: Optional time filter. If provided, only rows created within
the last `filter_days` days are returned.
limit: Optional max number of rows to return.
"""
if filter_days is not None and filter_days <= 0:
raise ValueError("filter_days must be > 0")
stmt = select(SearchQuery).where(SearchQuery.user_id == user_id)
if filter_days is not None and filter_days > 0:
cutoff = get_db_current_time(db_session) - timedelta(days=filter_days)
stmt = stmt.where(SearchQuery.created_at >= cutoff)
stmt = stmt.order_by(SearchQuery.created_at.desc())
if limit is not None:
stmt = stmt.limit(limit)
return list(db_session.scalars(stmt).all())

View File

@@ -7,7 +7,6 @@ from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFun
from onyx.access.models import DocExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -61,9 +60,6 @@ def gmail_doc_sync(
callback.progress("gmail_doc_sync", 1)
if isinstance(slim_doc, HierarchyNode):
# TODO: handle hierarchynodes during sync
continue
if slim_doc.external_access is None:
logger.warning(f"No permissions found for document {slim_doc.id}")
continue

View File

@@ -15,7 +15,6 @@ from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -196,9 +195,7 @@ def gdrive_doc_sync(
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
callback.progress("gdrive_doc_sync", 1)
if isinstance(slim_doc, HierarchyNode):
# TODO: handle hierarchynodes during sync
continue
if slim_doc.external_access is None:
raise ValueError(
f"Drive perm sync: No external access for document {slim_doc.id}"

View File

@@ -8,7 +8,6 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.models import HierarchyNode
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call
from onyx.connectors.slack.connector import SlackConnector
@@ -112,9 +111,6 @@ def _get_slack_document_access(
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if isinstance(doc_metadata, HierarchyNode):
# TODO: handle hierarchynodes during sync
continue
if doc_metadata.external_access is None:
raise ValueError(
f"No external access for document {doc_metadata.id}. "

View File

@@ -5,7 +5,6 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -50,9 +49,6 @@ def generic_doc_sync(
callback.progress(label, 1)
for doc in doc_batch:
if isinstance(doc, HierarchyNode):
# TODO: handle hierarchynodes during sync
continue
if not doc.external_access:
raise RuntimeError(
f"No external access found for document ID; {cc_pair.id=} {doc_source=} {doc.id=}"

View File

@@ -4,10 +4,8 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
from ee.onyx.server.billing.api import router as billing_router
from ee.onyx.server.documents.cc_pair import router as ee_document_cc_pair_router
from ee.onyx.server.enterprise_settings.api import (
admin_router as enterprise_settings_admin_router,
@@ -18,17 +16,16 @@ from ee.onyx.server.enterprise_settings.api import (
from ee.onyx.server.evals.api import router as evals_router
from ee.onyx.server.license.api import router as license_router
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.license_enforcement import (
add_license_enforcement_middleware,
)
from ee.onyx.server.middleware.tenant_tracking import (
add_api_server_tenant_id_middleware,
)
from ee.onyx.server.oauth.api import router as ee_oauth_router
from ee.onyx.server.query_and_chat.chat_backend import (
router as chat_router,
)
from ee.onyx.server.query_and_chat.query_backend import (
basic_router as ee_query_router,
)
from ee.onyx.server.query_and_chat.search_backend import router as search_router
from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.seeding import seed_db
@@ -87,11 +84,6 @@ def get_application() -> FastAPI:
if MULTI_TENANT:
add_api_server_tenant_id_middleware(application, logger)
else:
# License enforcement middleware for self-hosted deployments only
# Checks LICENSE_ENFORCEMENT_ENABLED at runtime (can be toggled without restart)
# MT deployments use control plane gating via is_tenant_gated() instead
add_license_enforcement_middleware(application, logger)
if AUTH_TYPE == AuthType.CLOUD:
# For Google OAuth, refresh tokens are requested by:
@@ -132,7 +124,7 @@ def get_application() -> FastAPI:
# EE only backend APIs
include_router_with_global_prefix_prepended(application, query_router)
include_router_with_global_prefix_prepended(application, ee_query_router)
include_router_with_global_prefix_prepended(application, search_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, ee_oauth_router)
include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
@@ -151,13 +143,6 @@ def get_application() -> FastAPI:
# License management
include_router_with_global_prefix_prepended(application, license_router)
# Unified billing API - available when license system is enabled
# Works for both self-hosted and cloud deployments
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
# primary billing API and /tenants/* billing endpoints can be removed
if LICENSE_ENFORCEMENT_ENABLED:
include_router_with_global_prefix_prepended(application, billing_router)
if MULTI_TENANT:
# Tenant management
include_router_with_global_prefix_prepended(application, tenants_router)

View File

@@ -1,27 +0,0 @@
# Single message is likely most reliable and generally better for this task
# No final reminders at the end since the user query is expected to be short
# If it is not short, it should go into the chat flow so we do not need to account for this.
KEYWORD_EXPANSION_PROMPT = """
Generate a set of keyword-only queries to help find relevant documents for the provided query. \
These queries will be passed to a bm25-based keyword search engine. \
Provide a single query per line (where each query consists of one or more keywords). \
The queries must be purely keywords and not contain any filler natural language. \
The each query should have as few keywords as necessary to represent the user's search intent. \
If there are no useful expansions, simply return the original query with no additional keyword queries. \
CRITICAL: Do not include any additional formatting, comments, or anything aside from the keyword queries.
The user query is:
{user_query}
""".strip()
QUERY_TYPE_PROMPT = """
Determine if the provided query is better suited for a keyword search or a semantic search.
Respond with "keyword" or "semantic" literally and nothing else.
Do not provide any additional text or reasoning to your response.
CRITICAL: It must only be 1 single word - EITHER "keyword" or "semantic".
The user query is:
{user_query}
""".strip()

View File

@@ -1,42 +0,0 @@
# ruff: noqa: E501, W605 start
SEARCH_CLASS = "search"
CHAT_CLASS = "chat"
# Will note that with many larger LLMs the latency on running this prompt via third party APIs is as high as 2 seconds which is too slow for many
# use cases.
SEARCH_CHAT_PROMPT = f"""
Determine if the following query is better suited for a search UI or a chat UI. Respond with "{SEARCH_CLASS}" or "{CHAT_CLASS}" literally and nothing else. \
Do not provide any additional text or reasoning to your response. CRITICAL, IT MUST ONLY BE 1 SINGLE WORD - EITHER "{SEARCH_CLASS}" or "{CHAT_CLASS}".
# Classification Guidelines:
## {SEARCH_CLASS}
- If the query consists entirely of keywords or query doesn't require any answer from the AI
- If the query is a short statement that seems like a search query rather than a question
- If the query feels nonsensical or is a short phrase that possibly describes a document or information that could be found in a internal document
### Examples of {SEARCH_CLASS} queries:
- Find me the document that goes over the onboarding process for a new hire
- Pull requests since last week
- Sales Runbook AMEA Region
- Procurement process
- Retrieve the PRD for project X
## {CHAT_CLASS}
- If the query is asking a question that requires an answer rather than a document
- If the query is asking for a solution, suggestion, or general help
- If the query is seeking information that is on the web and likely not in a company internal document
- If the query should be answered without any context from additional documents or searches
### Examples of {CHAT_CLASS} queries:
- What led us to win the deal with company X? (seeking answer)
- Google Drive not sync-ing files to my computer (seeking solution)
- Review my email: <whatever the email is> (general help)
- Write me a script to... (general help)
- Cheap flights Europe to Tokyo (information likely found on the web, not internal)
# User Query:
{{user_query}}
REMEMBER TO ONLY RESPOND WITH "{SEARCH_CLASS}" OR "{CHAT_CLASS}" AND NOTHING ELSE.
""".strip()
# ruff: noqa: E501, W605 end

View File

@@ -1,286 +0,0 @@
from collections.abc import Generator
from sqlalchemy.orm import Session
from ee.onyx.db.search import create_search_query
from ee.onyx.secondary_llm_flows.query_expansion import expand_keywords
from ee.onyx.server.query_and_chat.models import SearchDocWithContent
from ee.onyx.server.query_and_chat.models import SearchFullResponse
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
from ee.onyx.server.query_and_chat.streaming_models import LLMSelectedDocsPacket
from ee.onyx.server.query_and_chat.streaming_models import SearchDocsPacket
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
from ee.onyx.server.query_and_chat.streaming_models import SearchQueriesPacket
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import ChunkSearchRequest
from onyx.context.search.models import InferenceChunk
from onyx.context.search.pipeline import merge_individual_chunks
from onyx.context.search.pipeline import search_pipeline
from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import DocumentIndex
from onyx.llm.factory import get_default_llm
from onyx.secondary_llm_flows.document_filter import select_sections_for_expansion
from onyx.tools.tool_implementations.search.search_utils import (
weighted_reciprocal_rank_fusion,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
# This is just a heuristic that also happens to work well for the UI/UX
# Users would not find it useful to see a huge list of suggested docs
# but more than 1 is also likely good as many questions may target more than 1 doc.
TARGET_NUM_SECTIONS_FOR_LLM_SELECTION = 3
def _run_single_search(
query: str,
filters: BaseFilters | None,
document_index: DocumentIndex,
user: User | None,
db_session: Session,
num_hits: int | None = None,
) -> list[InferenceChunk]:
"""Execute a single search query and return chunks."""
chunk_search_request = ChunkSearchRequest(
query=query,
user_selected_filters=filters,
limit=num_hits,
)
return search_pipeline(
chunk_search_request=chunk_search_request,
document_index=document_index,
user=user,
persona=None, # No persona for direct search
db_session=db_session,
)
def stream_search_query(
request: SendSearchQueryRequest,
user: User | None,
db_session: Session,
) -> Generator[
SearchQueriesPacket | SearchDocsPacket | LLMSelectedDocsPacket | SearchErrorPacket,
None,
None,
]:
"""
Core search function that yields streaming packets.
Used by both streaming and non-streaming endpoints.
"""
# Get document index
search_settings = get_current_search_settings(db_session)
# This flow is for search so we do not get all indices.
document_index = get_default_document_index(search_settings, None)
# Determine queries to execute
original_query = request.search_query
keyword_expansions: list[str] = []
if request.run_query_expansion:
try:
llm = get_default_llm()
keyword_expansions = expand_keywords(
user_query=original_query,
llm=llm,
)
if keyword_expansions:
logger.debug(
f"Query expansion generated {len(keyword_expansions)} keyword queries"
)
except Exception as e:
logger.warning(f"Query expansion failed: {e}; using original query only.")
keyword_expansions = []
# Build list of all executed queries for tracking
all_executed_queries = [original_query] + keyword_expansions
# TODO remove this check, user should not be None
if user is not None:
create_search_query(
db_session=db_session,
user_id=user.id,
query=request.search_query,
query_expansions=keyword_expansions if keyword_expansions else None,
)
# Execute search(es)
if not keyword_expansions:
# Single query (original only) - no threading needed
chunks = _run_single_search(
query=original_query,
filters=request.filters,
document_index=document_index,
user=user,
db_session=db_session,
num_hits=request.num_hits,
)
else:
# Multiple queries - run in parallel and merge with RRF
# First query is the original (semantic), rest are keyword expansions
search_functions = [
(
_run_single_search,
(
query,
request.filters,
document_index,
user,
db_session,
request.num_hits,
),
)
for query in all_executed_queries
]
# Run all searches in parallel
all_search_results: list[list[InferenceChunk]] = (
run_functions_tuples_in_parallel(
search_functions,
allow_failures=True,
)
)
# Separate original query results from keyword expansion results
# Note that in rare cases, the original query may have failed and so we may be
# just overweighting one set of keyword results, should be not a big deal though.
original_result = all_search_results[0] if all_search_results else []
keyword_results = all_search_results[1:] if len(all_search_results) > 1 else []
# Build valid results and weights
# Original query (semantic): weight 2.0
# Keyword expansions: weight 1.0 each
valid_results: list[list[InferenceChunk]] = []
weights: list[float] = []
if original_result:
valid_results.append(original_result)
weights.append(2.0)
for keyword_result in keyword_results:
if keyword_result:
valid_results.append(keyword_result)
weights.append(1.0)
if not valid_results:
logger.warning("All parallel searches returned empty results")
chunks = []
else:
chunks = weighted_reciprocal_rank_fusion(
ranked_results=valid_results,
weights=weights,
id_extractor=lambda chunk: f"{chunk.document_id}_{chunk.chunk_id}",
)
# Merge chunks into sections
sections = merge_individual_chunks(chunks)
# Truncate to the requested number of hits
sections = sections[: request.num_hits]
# Apply LLM document selection if requested
# num_docs_fed_to_llm_selection specifies how many sections to feed to the LLM for selection
# The LLM will always try to select TARGET_NUM_SECTIONS_FOR_LLM_SELECTION sections from those fed to it
# llm_selected_doc_ids will be:
# - None if LLM selection was not requested or failed
# - Empty list if LLM selection ran but selected nothing
# - List of doc IDs if LLM selection succeeded
run_llm_selection = (
request.num_docs_fed_to_llm_selection is not None
and request.num_docs_fed_to_llm_selection >= 1
)
llm_selected_doc_ids: list[str] | None = None
llm_selection_failed = False
if run_llm_selection and sections:
try:
llm = get_default_llm()
sections_to_evaluate = sections[: request.num_docs_fed_to_llm_selection]
selected_sections, _ = select_sections_for_expansion(
sections=sections_to_evaluate,
user_query=original_query,
llm=llm,
max_sections=TARGET_NUM_SECTIONS_FOR_LLM_SELECTION,
try_to_fill_to_max=True,
)
# Extract unique document IDs from selected sections (may be empty)
llm_selected_doc_ids = list(
dict.fromkeys(
section.center_chunk.document_id for section in selected_sections
)
)
logger.debug(
f"LLM document selection evaluated {len(sections_to_evaluate)} sections, "
f"selected {len(selected_sections)} sections with doc IDs: {llm_selected_doc_ids}"
)
except Exception as e:
# Allowing a blanket exception here as this step is not critical and the rest of the results are still valid
logger.warning(f"LLM document selection failed: {e}")
llm_selection_failed = True
elif run_llm_selection and not sections:
# LLM selection requested but no sections to evaluate
llm_selected_doc_ids = []
# Convert to SearchDocWithContent list, optionally including content
search_docs = SearchDocWithContent.from_inference_sections(
sections,
include_content=request.include_content,
is_internet=False,
)
# Yield queries packet
yield SearchQueriesPacket(all_executed_queries=all_executed_queries)
# Yield docs packet
yield SearchDocsPacket(search_docs=search_docs)
# Yield LLM selected docs packet if LLM selection was requested
# - llm_selected_doc_ids is None if selection failed
# - llm_selected_doc_ids is empty list if no docs were selected
# - llm_selected_doc_ids is list of IDs if docs were selected
if run_llm_selection:
yield LLMSelectedDocsPacket(
llm_selected_doc_ids=None if llm_selection_failed else llm_selected_doc_ids
)
def gather_search_stream(
packets: Generator[
SearchQueriesPacket
| SearchDocsPacket
| LLMSelectedDocsPacket
| SearchErrorPacket,
None,
None,
],
) -> SearchFullResponse:
"""
Aggregate all streaming packets into SearchFullResponse.
"""
all_executed_queries: list[str] = []
search_docs: list[SearchDocWithContent] = []
llm_selected_doc_ids: list[str] | None = None
error: str | None = None
for packet in packets:
if isinstance(packet, SearchQueriesPacket):
all_executed_queries = packet.all_executed_queries
elif isinstance(packet, SearchDocsPacket):
search_docs = packet.search_docs
elif isinstance(packet, LLMSelectedDocsPacket):
llm_selected_doc_ids = packet.llm_selected_doc_ids
elif isinstance(packet, SearchErrorPacket):
error = packet.error
return SearchFullResponse(
all_executed_queries=all_executed_queries,
search_docs=search_docs,
doc_selection_reasoning=None,
llm_selected_doc_ids=llm_selected_doc_ids,
error=error,
)

View File

@@ -1,92 +0,0 @@
import re
from ee.onyx.prompts.query_expansion import KEYWORD_EXPANSION_PROMPT
from onyx.llm.interfaces import LLM
from onyx.llm.models import LanguageModelInput
from onyx.llm.models import ReasoningEffort
from onyx.llm.models import UserMessage
from onyx.llm.utils import llm_response_to_string
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Pattern to remove common LLM artifacts: brackets, quotes, list markers, etc.
CLEANUP_PATTERN = re.compile(r'[\[\]"\'`]')
def _clean_keyword_line(line: str) -> str:
"""Clean a keyword line by removing common LLM artifacts.
Removes brackets, quotes, and other characters that LLMs may accidentally
include in their output.
"""
# Remove common artifacts
cleaned = CLEANUP_PATTERN.sub("", line)
# Remove leading list markers like "1.", "2.", "-", "*"
cleaned = re.sub(r"^\s*(?:\d+[\.\)]\s*|[-*]\s*)", "", cleaned)
return cleaned.strip()
def expand_keywords(
user_query: str,
llm: LLM,
) -> list[str]:
"""Expand a user query into multiple keyword-only queries for BM25 search.
Uses an LLM to generate keyword-based search queries that capture different
aspects of the user's search intent. Returns only the expanded queries,
not the original query.
Args:
user_query: The original search query from the user
llm: Language model to use for keyword expansion
Returns:
List of expanded keyword queries (excluding the original query).
Returns empty list if expansion fails or produces no useful expansions.
"""
messages: LanguageModelInput = [
UserMessage(content=KEYWORD_EXPANSION_PROMPT.format(user_query=user_query))
]
try:
response = llm.invoke(
prompt=messages,
reasoning_effort=ReasoningEffort.OFF,
# Limit output - we only expect a few short keyword queries
max_tokens=150,
)
content = llm_response_to_string(response).strip()
if not content:
logger.warning("Keyword expansion returned empty response.")
return []
# Parse response - each line is a separate keyword query
# Clean each line to remove LLM artifacts and drop empty lines
parsed_queries = []
for line in content.strip().split("\n"):
cleaned = _clean_keyword_line(line)
if cleaned:
parsed_queries.append(cleaned)
if not parsed_queries:
logger.warning("Keyword expansion parsing returned no queries.")
return []
# Filter out duplicates and queries that match the original
expanded_queries: list[str] = []
seen_lower: set[str] = {user_query.lower()}
for query in parsed_queries:
query_lower = query.lower()
if query_lower not in seen_lower:
seen_lower.add(query_lower)
expanded_queries.append(query)
logger.debug(f"Keyword expansion generated {len(expanded_queries)} queries")
return expanded_queries
except Exception as e:
logger.warning(f"Keyword expansion failed: {e}")
return []

View File

@@ -1,50 +0,0 @@
from ee.onyx.prompts.search_flow_classification import CHAT_CLASS
from ee.onyx.prompts.search_flow_classification import SEARCH_CHAT_PROMPT
from ee.onyx.prompts.search_flow_classification import SEARCH_CLASS
from onyx.llm.interfaces import LLM
from onyx.llm.models import LanguageModelInput
from onyx.llm.models import ReasoningEffort
from onyx.llm.models import UserMessage
from onyx.llm.utils import llm_response_to_string
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
@log_function_time(print_only=True)
def classify_is_search_flow(
query: str,
llm: LLM,
) -> bool:
messages: LanguageModelInput = [
UserMessage(content=SEARCH_CHAT_PROMPT.format(user_query=query))
]
response = llm.invoke(
prompt=messages,
reasoning_effort=ReasoningEffort.OFF,
# Nothing can happen in the UI until this call finishes so we need to be aggressive with the timeout
timeout_override=2,
# Well more than necessary but just to ensure completion and in case it succeeds with classifying but
# ends up rambling
max_tokens=20,
)
content = llm_response_to_string(response).strip().lower()
if not content:
logger.warning(
"Search flow classification returned empty response; defaulting to chat flow."
)
return False
# Prefer chat if both appear.
if CHAT_CLASS in content:
return False
if SEARCH_CLASS in content:
return True
logger.warning(
"Search flow classification returned unexpected response; defaulting to chat flow. Response=%r",
content,
)
return False

View File

@@ -19,9 +19,9 @@ from ee.onyx.db.analytics import fetch_query_analytics
from ee.onyx.db.analytics import user_can_view_assistant_stats
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.server.utils import PUBLIC_API_TAGS
router = APIRouter(prefix="/analytics", tags=PUBLIC_API_TAGS)

View File

@@ -10,16 +10,6 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
("/enterprise-settings/logo", {"GET"}),
("/enterprise-settings/logotype", {"GET"}),
("/enterprise-settings/custom-analytics-script", {"GET"}),
# Stripe publishable key is safe to expose publicly
("/tenants/stripe-publishable-key", {"GET"}),
("/admin/billing/stripe-publishable-key", {"GET"}),
# Proxy endpoints use license-based auth, not user auth
("/proxy/create-checkout-session", {"POST"}),
("/proxy/claim-license", {"POST"}),
("/proxy/create-customer-portal-session", {"POST"}),
("/proxy/billing-information", {"GET"}),
("/proxy/license/{tenant_id}", {"GET"}),
("/proxy/seats/update", {"POST"}),
]

View File

@@ -1,264 +0,0 @@
"""Unified Billing API endpoints.
These endpoints provide Stripe billing functionality for both cloud and
self-hosted deployments. The service layer routes requests appropriately:
- Self-hosted: Routes through cloud data plane proxy
Flow: Backend /admin/billing/* → Cloud DP /proxy/* → Control plane
- Cloud (MULTI_TENANT): Routes directly to control plane
Flow: Backend /admin/billing/* → Control plane
License claiming is handled separately by /license/claim endpoint (self-hosted only).
Migration Note (ENG-3533):
This /admin/billing/* API replaces the older /tenants/* billing endpoints:
- /tenants/billing-information -> /admin/billing/billing-information
- /tenants/create-customer-portal-session -> /admin/billing/create-customer-portal-session
- /tenants/create-subscription-session -> /admin/billing/create-checkout-session
- /tenants/stripe-publishable-key -> /admin/billing/stripe-publishable-key
See: https://linear.app/onyx-app/issue/ENG-3533/migrate-tenantsbilling-adminbilling
"""
import asyncio
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_admin_user
from ee.onyx.db.license import get_license
from ee.onyx.server.billing.models import BillingInformationResponse
from ee.onyx.server.billing.models import CreateCheckoutSessionRequest
from ee.onyx.server.billing.models import CreateCheckoutSessionResponse
from ee.onyx.server.billing.models import CreateCustomerPortalSessionRequest
from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse
from ee.onyx.server.billing.models import SeatUpdateRequest
from ee.onyx.server.billing.models import SeatUpdateResponse
from ee.onyx.server.billing.models import StripePublishableKeyResponse
from ee.onyx.server.billing.models import SubscriptionStatusResponse
from ee.onyx.server.billing.service import BillingServiceError
from ee.onyx.server.billing.service import (
create_checkout_session as create_checkout_service,
)
from ee.onyx.server.billing.service import (
create_customer_portal_session as create_portal_service,
)
from ee.onyx.server.billing.service import (
get_billing_information as get_billing_service,
)
from ee.onyx.server.billing.service import update_seat_count as update_seat_service
from onyx.auth.users import User
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.engine.sql_engine import get_session
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/admin/billing")
# Cache for Stripe publishable key to avoid hitting S3 on every request
_stripe_publishable_key_cache: str | None = None
_stripe_key_lock = asyncio.Lock()
def _get_license_data(db_session: Session) -> str | None:
"""Get license data from database if exists (self-hosted only)."""
if MULTI_TENANT:
return None
license_record = get_license(db_session)
return license_record.license_data if license_record else None
def _get_tenant_id() -> str | None:
"""Get tenant ID for cloud deployments."""
if MULTI_TENANT:
return get_current_tenant_id()
return None
@router.post("/create-checkout-session")
async def create_checkout_session(
request: CreateCheckoutSessionRequest | None = None,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> CreateCheckoutSessionResponse:
"""Create a Stripe checkout session for new subscription or renewal.
For new customers, no license/tenant is required.
For renewals, existing license (self-hosted) or tenant_id (cloud) is used.
After checkout completion:
- Self-hosted: Use /license/claim to retrieve the license
- Cloud: Subscription is automatically activated
"""
license_data = _get_license_data(db_session)
tenant_id = _get_tenant_id()
billing_period = request.billing_period if request else "monthly"
email = request.email if request else None
# Build redirect URL for after checkout completion
redirect_url = f"{WEB_DOMAIN}/admin/billing?checkout=success"
try:
return await create_checkout_service(
billing_period=billing_period,
email=email,
license_data=license_data,
redirect_url=redirect_url,
tenant_id=tenant_id,
)
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
request: CreateCustomerPortalSessionRequest | None = None,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> CreateCustomerPortalSessionResponse:
"""Create a Stripe customer portal session for managing subscription.
Requires existing license (self-hosted) or active tenant (cloud).
"""
license_data = _get_license_data(db_session)
tenant_id = _get_tenant_id()
# Self-hosted requires license
if not MULTI_TENANT and not license_data:
raise HTTPException(status_code=400, detail="No license found")
return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing"
try:
return await create_portal_service(
license_data=license_data,
return_url=return_url,
tenant_id=tenant_id,
)
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
@router.get("/billing-information")
async def get_billing_information(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> BillingInformationResponse | SubscriptionStatusResponse:
"""Get billing information for the current subscription.
Returns subscription status and details from Stripe.
"""
license_data = _get_license_data(db_session)
tenant_id = _get_tenant_id()
# Self-hosted without license = no subscription
if not MULTI_TENANT and not license_data:
return SubscriptionStatusResponse(subscribed=False)
try:
return await get_billing_service(
license_data=license_data,
tenant_id=tenant_id,
)
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
@router.post("/seats/update")
async def update_seats(
request: SeatUpdateRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> SeatUpdateResponse:
"""Update the seat count for the current subscription.
Handles Stripe proration and license regeneration via control plane.
"""
license_data = _get_license_data(db_session)
tenant_id = _get_tenant_id()
# Self-hosted requires license
if not MULTI_TENANT and not license_data:
raise HTTPException(status_code=400, detail="No license found")
try:
return await update_seat_service(
new_seat_count=request.new_seat_count,
license_data=license_data,
tenant_id=tenant_id,
)
except BillingServiceError as e:
raise HTTPException(status_code=e.status_code, detail=e.message)
@router.get("/stripe-publishable-key")
async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
"""Fetch the Stripe publishable key.
Priority: env var override (for testing) > S3 bucket (production).
This endpoint is public (no auth required) since publishable keys are safe to expose.
The key is cached in memory to avoid hitting S3 on every request.
"""
global _stripe_publishable_key_cache
# Fast path: return cached value without lock
if _stripe_publishable_key_cache:
return StripePublishableKeyResponse(
publishable_key=_stripe_publishable_key_cache
)
# Use lock to prevent concurrent S3 requests
async with _stripe_key_lock:
# Double-check after acquiring lock (another request may have populated cache)
if _stripe_publishable_key_cache:
return StripePublishableKeyResponse(
publishable_key=_stripe_publishable_key_cache
)
# Check for env var override first (for local testing with pk_test_* keys)
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
if not key.startswith("pk_"):
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
# Fall back to S3 bucket
if not STRIPE_PUBLISHABLE_KEY_URL:
raise HTTPException(
status_code=500,
detail="Stripe publishable key is not configured",
)
try:
async with httpx.AsyncClient() as client:
response = await client.get(STRIPE_PUBLISHABLE_KEY_URL)
response.raise_for_status()
key = response.text.strip()
# Validate key format
if not key.startswith("pk_"):
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
except httpx.HTTPError:
raise HTTPException(
status_code=500,
detail="Failed to fetch Stripe publishable key",
)

View File

@@ -1,75 +0,0 @@
"""Pydantic models for the billing API."""
from datetime import datetime
from typing import Literal
from pydantic import BaseModel
class CreateCheckoutSessionRequest(BaseModel):
"""Request to create a Stripe checkout session."""
billing_period: Literal["monthly", "annual"] = "monthly"
email: str | None = None
class CreateCheckoutSessionResponse(BaseModel):
"""Response containing the Stripe checkout session URL."""
stripe_checkout_url: str
class CreateCustomerPortalSessionRequest(BaseModel):
"""Request to create a Stripe customer portal session."""
return_url: str | None = None
class CreateCustomerPortalSessionResponse(BaseModel):
"""Response containing the Stripe customer portal URL."""
stripe_customer_portal_url: str
class BillingInformationResponse(BaseModel):
"""Billing information for the current subscription."""
tenant_id: str
status: str | None = None
plan_type: str | None = None
seats: int | None = None
billing_period: str | None = None
current_period_start: datetime | None = None
current_period_end: datetime | None = None
cancel_at_period_end: bool = False
canceled_at: datetime | None = None
trial_start: datetime | None = None
trial_end: datetime | None = None
payment_method_enabled: bool = False
class SubscriptionStatusResponse(BaseModel):
"""Response when no subscription exists."""
subscribed: bool = False
class SeatUpdateRequest(BaseModel):
"""Request to update seat count."""
new_seat_count: int
class SeatUpdateResponse(BaseModel):
"""Response from seat update operation."""
success: bool
current_seats: int
used_seats: int
message: str | None = None
class StripePublishableKeyResponse(BaseModel):
"""Response containing the Stripe publishable key."""
publishable_key: str

View File

@@ -1,267 +0,0 @@
"""Service layer for billing operations.
This module provides functions for billing operations that route differently
based on deployment type:
- Self-hosted (not MULTI_TENANT): Routes through cloud data plane proxy
Flow: Self-hosted backend → Cloud DP /proxy/* → Control plane
- Cloud (MULTI_TENANT): Routes directly to control plane
Flow: Cloud backend → Control plane
"""
from typing import Literal
import httpx
from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL
from ee.onyx.server.billing.models import BillingInformationResponse
from ee.onyx.server.billing.models import CreateCheckoutSessionResponse
from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse
from ee.onyx.server.billing.models import SeatUpdateResponse
from ee.onyx.server.billing.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.access import generate_data_plane_token
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
# HTTP request timeout for billing service calls
_REQUEST_TIMEOUT = 30.0
class BillingServiceError(Exception):
"""Exception raised for billing service errors."""
def __init__(self, message: str, status_code: int = 500):
self.message = message
self.status_code = status_code
super().__init__(self.message)
def _get_proxy_headers(license_data: str | None) -> dict[str, str]:
"""Build headers for proxy requests (self-hosted).
Self-hosted instances authenticate with their license.
"""
headers = {"Content-Type": "application/json"}
if license_data:
headers["Authorization"] = f"Bearer {license_data}"
return headers
def _get_direct_headers() -> dict[str, str]:
"""Build headers for direct control plane requests (cloud).
Cloud instances authenticate with JWT.
"""
token = generate_data_plane_token()
return {
"Content-Type": "application/json",
"Authorization": f"Bearer {token}",
}
def _get_base_url() -> str:
"""Get the base URL based on deployment type."""
if MULTI_TENANT:
return CONTROL_PLANE_API_BASE_URL
return f"{CLOUD_DATA_PLANE_URL}/proxy"
def _get_headers(license_data: str | None) -> dict[str, str]:
"""Get appropriate headers based on deployment type."""
if MULTI_TENANT:
return _get_direct_headers()
return _get_proxy_headers(license_data)
async def _make_billing_request(
method: Literal["GET", "POST"],
path: str,
license_data: str | None = None,
body: dict | None = None,
params: dict | None = None,
error_message: str = "Billing service request failed",
) -> dict:
"""Make an HTTP request to the billing service.
Consolidates the common HTTP request pattern used by all billing operations.
Args:
method: HTTP method (GET or POST)
path: URL path (appended to base URL)
license_data: License for authentication (self-hosted)
body: Request body for POST requests
params: Query parameters for GET requests
error_message: Default error message if request fails
Returns:
Response JSON as dict
Raises:
BillingServiceError: If request fails
"""
base_url = _get_base_url()
url = f"{base_url}{path}"
headers = _get_headers(license_data)
try:
async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT) as client:
if method == "GET":
response = await client.get(url, headers=headers, params=params)
else:
response = await client.post(url, headers=headers, json=body)
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
detail = error_message
try:
error_data = e.response.json()
detail = error_data.get("detail", detail)
except Exception:
pass
logger.error(f"{error_message}: {e.response.status_code} - {detail}")
raise BillingServiceError(detail, e.response.status_code)
except httpx.RequestError:
logger.exception("Failed to connect to billing service")
raise BillingServiceError("Failed to connect to billing service", 502)
async def create_checkout_session(
billing_period: str = "monthly",
email: str | None = None,
license_data: str | None = None,
redirect_url: str | None = None,
tenant_id: str | None = None,
) -> CreateCheckoutSessionResponse:
"""Create a Stripe checkout session.
Args:
billing_period: "monthly" or "annual"
email: Customer email for new subscriptions
license_data: Existing license for renewals (self-hosted)
redirect_url: URL to redirect after successful checkout
tenant_id: Tenant ID (cloud only, for renewals)
Returns:
CreateCheckoutSessionResponse with checkout URL
"""
body: dict = {"billing_period": billing_period}
if email:
body["email"] = email
if redirect_url:
body["redirect_url"] = redirect_url
if tenant_id and MULTI_TENANT:
body["tenant_id"] = tenant_id
data = await _make_billing_request(
method="POST",
path="/create-checkout-session",
license_data=license_data,
body=body,
error_message="Failed to create checkout session",
)
return CreateCheckoutSessionResponse(stripe_checkout_url=data["url"])
async def create_customer_portal_session(
license_data: str | None = None,
return_url: str | None = None,
tenant_id: str | None = None,
) -> CreateCustomerPortalSessionResponse:
"""Create a Stripe customer portal session.
Args:
license_data: License blob for authentication (self-hosted)
return_url: URL to return to after portal session
tenant_id: Tenant ID (cloud only)
Returns:
CreateCustomerPortalSessionResponse with portal URL
"""
body: dict = {}
if return_url:
body["return_url"] = return_url
if tenant_id and MULTI_TENANT:
body["tenant_id"] = tenant_id
data = await _make_billing_request(
method="POST",
path="/create-customer-portal-session",
license_data=license_data,
body=body,
error_message="Failed to create customer portal session",
)
return CreateCustomerPortalSessionResponse(stripe_customer_portal_url=data["url"])
async def get_billing_information(
license_data: str | None = None,
tenant_id: str | None = None,
) -> BillingInformationResponse | SubscriptionStatusResponse:
"""Fetch billing information.
Args:
license_data: License blob for authentication (self-hosted)
tenant_id: Tenant ID (cloud only)
Returns:
BillingInformationResponse or SubscriptionStatusResponse if no subscription
"""
params = {}
if tenant_id and MULTI_TENANT:
params["tenant_id"] = tenant_id
data = await _make_billing_request(
method="GET",
path="/billing-information",
license_data=license_data,
params=params or None,
error_message="Failed to fetch billing information",
)
# Check if no subscription
if isinstance(data, dict) and data.get("subscribed") is False:
return SubscriptionStatusResponse(subscribed=False)
return BillingInformationResponse(**data)
async def update_seat_count(
new_seat_count: int,
license_data: str | None = None,
tenant_id: str | None = None,
) -> SeatUpdateResponse:
"""Update the seat count for the current subscription.
Args:
new_seat_count: New number of seats
license_data: License blob for authentication (self-hosted)
tenant_id: Tenant ID (cloud only)
Returns:
SeatUpdateResponse with updated seat information
"""
body: dict = {"new_seat_count": new_seat_count}
if tenant_id and MULTI_TENANT:
body["tenant_id"] = tenant_id
data = await _make_billing_request(
method="POST",
path="/seats/update",
license_data=license_data,
body=body,
error_message="Failed to update seat count",
)
return SeatUpdateResponse(
success=data.get("success", False),
current_seats=data.get("current_seats", 0),
used_seats=data.get("used_seats", 0),
message=data.get("message"),
)

View File

@@ -1,14 +1,4 @@
"""License API endpoints for self-hosted deployments.
These endpoints allow self-hosted Onyx instances to:
1. Claim a license after Stripe checkout (via cloud data plane proxy)
2. Upload a license file manually (for air-gapped deployments)
3. View license status and seat usage
4. Refresh/delete the local license
NOTE: Cloud (MULTI_TENANT) deployments do NOT use these endpoints.
Cloud licensing is managed via the control plane and gated_tenants Redis key.
"""
"""License API endpoints."""
import requests
from fastapi import APIRouter
@@ -19,7 +9,6 @@ from fastapi import UploadFile
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL
from ee.onyx.db.license import delete_license as db_delete_license
from ee.onyx.db.license import get_license_metadata
from ee.onyx.db.license import invalidate_license_cache
@@ -31,11 +20,13 @@ from ee.onyx.server.license.models import LicenseSource
from ee.onyx.server.license.models import LicenseStatusResponse
from ee.onyx.server.license.models import LicenseUploadResponse
from ee.onyx.server.license.models import SeatUsageResponse
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.utils.license import verify_license_signature
from onyx.auth.users import User
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.db.engine.sql_engine import get_session
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -88,80 +79,81 @@ async def get_seat_usage(
)
@router.post("/claim")
async def claim_license(
session_id: str,
@router.post("/fetch")
async def fetch_license(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseResponse:
"""
Claim a license after Stripe checkout (self-hosted only).
After a user completes Stripe checkout, they're redirected back with a
session_id. This endpoint exchanges that session_id for a signed license
via the cloud data plane proxy.
Flow:
1. Self-hosted frontend redirects to Stripe checkout (via cloud proxy)
2. User completes payment
3. Stripe redirects back to self-hosted instance with session_id
4. Frontend calls this endpoint with session_id
5. We call cloud data plane /proxy/claim-license to get the signed license
6. License is stored locally and cached
Fetch license from control plane.
Used after Stripe checkout completion to retrieve the new license.
"""
if MULTI_TENANT:
tenant_id = get_current_tenant_id()
try:
token = generate_data_plane_token()
except ValueError as e:
logger.error(f"Failed to generate data plane token: {e}")
raise HTTPException(
status_code=400,
detail="License claiming is only available for self-hosted deployments",
status_code=500, detail="Authentication configuration error"
)
try:
# Call cloud data plane to claim the license
url = f"{CLOUD_DATA_PLANE_URL}/proxy/claim-license"
response = requests.post(
url,
json={"session_id": session_id},
headers={"Content-Type": "application/json"},
timeout=30,
)
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/license/{tenant_id}"
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
data = response.json()
license_data = data.get("license")
if not isinstance(data, dict) or "license" not in data:
raise HTTPException(
status_code=502, detail="Invalid response from control plane"
)
license_data = data["license"]
if not license_data:
raise HTTPException(status_code=404, detail="No license in response")
raise HTTPException(status_code=404, detail="No license found")
# Verify signature before persisting
payload = verify_license_signature(license_data)
# Store in DB
upsert_license(db_session, license_data)
# Verify the fetched license is for this tenant
if payload.tenant_id != tenant_id:
logger.error(
f"License tenant mismatch: expected {tenant_id}, got {payload.tenant_id}"
)
raise HTTPException(
status_code=400,
detail="License tenant ID mismatch - control plane returned wrong license",
)
# Persist to DB and update cache atomically
upsert_license(db_session, license_data)
try:
update_license_cache(payload, source=LicenseSource.AUTO_FETCH)
except Exception as cache_error:
# Log but don't fail - DB is source of truth, cache will refresh on next read
logger.warning(f"Failed to update license cache: {cache_error}")
logger.info(
f"License claimed: seats={payload.seats}, expires={payload.expires_at.date()}"
)
return LicenseResponse(success=True, license=payload)
except requests.HTTPError as e:
status_code = e.response.status_code if e.response is not None else 502
detail = "Failed to claim license"
try:
error_data = e.response.json() if e.response is not None else {}
detail = error_data.get("detail", detail)
except Exception:
pass
raise HTTPException(status_code=status_code, detail=detail)
logger.error(f"Control plane returned error: {status_code}")
raise HTTPException(
status_code=status_code,
detail="Failed to fetch license from control plane",
)
except ValueError as e:
logger.error(f"License verification failed: {type(e).__name__}")
raise HTTPException(status_code=400, detail=str(e))
except requests.RequestException:
logger.exception("Failed to fetch license from control plane")
raise HTTPException(
status_code=502, detail="Failed to connect to license server"
status_code=502, detail="Failed to connect to control plane"
)
@@ -172,36 +164,33 @@ async def upload_license(
db_session: Session = Depends(get_session),
) -> LicenseUploadResponse:
"""
Upload a license file manually (self-hosted only).
Used for air-gapped deployments where the cloud data plane is not accessible.
The license file must be cryptographically signed by Onyx.
Upload a license file manually.
Used for air-gapped deployments where control plane is not accessible.
"""
if MULTI_TENANT:
raise HTTPException(
status_code=400,
detail="License upload is only available for self-hosted deployments",
)
try:
content = await license_file.read()
license_data = content.decode("utf-8").strip()
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="Invalid license file format")
# Verify cryptographic signature - this is the only validation needed
# The license's tenant_id identifies the customer in control plane, not locally
try:
payload = verify_license_signature(license_data)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
tenant_id = get_current_tenant_id()
if payload.tenant_id != tenant_id:
raise HTTPException(
status_code=400,
detail=f"License tenant ID mismatch. Expected {tenant_id}, got {payload.tenant_id}",
)
# Persist to DB and update cache
upsert_license(db_session, license_data)
try:
update_license_cache(payload, source=LicenseSource.MANUAL_UPLOAD)
except Exception as cache_error:
# Log but don't fail - DB is source of truth, cache will refresh on next read
logger.warning(f"Failed to update license cache: {cache_error}")
return LicenseUploadResponse(
@@ -216,10 +205,8 @@ async def refresh_license_cache_endpoint(
db_session: Session = Depends(get_session),
) -> LicenseStatusResponse:
"""
Force refresh the license cache from the local database.
Force refresh the license cache from the database.
Useful after manual database changes or to verify license validity.
Does NOT fetch from control plane - use /claim for that.
"""
metadata = refresh_license_cache(db_session)
@@ -246,15 +233,9 @@ async def delete_license(
) -> dict[str, bool]:
"""
Delete the current license.
Admin only - removes license from database and invalidates cache.
Admin only - removes license and invalidates cache.
"""
if MULTI_TENANT:
raise HTTPException(
status_code=400,
detail="License deletion is only available for self-hosted deployments",
)
# Invalidate cache first - if DB delete fails, stale cache is worse than no cache
try:
invalidate_license_cache()
except Exception as cache_error:

View File

@@ -1,187 +0,0 @@
"""Middleware to enforce license status for SELF-HOSTED deployments only.
NOTE: This middleware is NOT used for multi-tenant (cloud) deployments.
Multi-tenant gating is handled separately by the control plane via the
/tenants/product-gating endpoint and is_tenant_gated() checks.
IMPORTANT: Mutual Exclusivity with ENTERPRISE_EDITION_ENABLED
============================================================
This middleware is controlled by LICENSE_ENFORCEMENT_ENABLED env var.
It works alongside the legacy ENTERPRISE_EDITION_ENABLED system:
- LICENSE_ENFORCEMENT_ENABLED=false (default):
Middleware is disabled. EE features are controlled solely by
ENTERPRISE_EDITION_ENABLED. This preserves legacy behavior.
- LICENSE_ENFORCEMENT_ENABLED=true:
Middleware actively enforces license status. EE features require
a valid license, regardless of ENTERPRISE_EDITION_ENABLED.
Eventually, ENTERPRISE_EDITION_ENABLED will be removed and license
enforcement will be the only mechanism for gating EE features.
License Enforcement States (when enabled)
=========================================
For self-hosted deployments:
1. No license (never subscribed):
- Allow community features (basic connectors, search, chat)
- Block EE-only features (analytics, user groups, etc.)
2. GATED_ACCESS (fully expired):
- Block all routes except billing/auth/license
- User must renew subscription to continue
3. Valid license (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER):
- Full access to all EE features
- Seat limits enforced
- GRACE_PERIOD/PAYMENT_REMINDER are for notifications only, not blocking
"""
import logging
from collections.abc import Awaitable
from collections.abc import Callable
from fastapi import FastAPI
from fastapi import Request
from fastapi import Response
from fastapi.responses import JSONResponse
from redis.exceptions import RedisError
from sqlalchemy.exc import SQLAlchemyError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.configs.license_enforcement_config import EE_ONLY_PATH_PREFIXES
from ee.onyx.configs.license_enforcement_config import (
LICENSE_ENFORCEMENT_ALLOWED_PREFIXES,
)
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.db.license import refresh_license_cache
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.settings.models import ApplicationStatus
from shared_configs.contextvars import get_current_tenant_id
def _is_path_allowed(path: str) -> bool:
"""Check if path is in allowlist (prefix match)."""
return any(
path.startswith(prefix) for prefix in LICENSE_ENFORCEMENT_ALLOWED_PREFIXES
)
def _is_ee_only_path(path: str) -> bool:
"""Check if path requires EE license (prefix match)."""
return any(path.startswith(prefix) for prefix in EE_ONLY_PATH_PREFIXES)
def add_license_enforcement_middleware(
app: FastAPI, logger: logging.LoggerAdapter
) -> None:
logger.info("License enforcement middleware registered")
@app.middleware("http")
async def enforce_license(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Block requests when license is expired/gated."""
if not LICENSE_ENFORCEMENT_ENABLED:
return await call_next(request)
path = request.url.path
if path.startswith("/api"):
path = path[4:]
if _is_path_allowed(path):
return await call_next(request)
is_gated = False
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
# If no cached metadata, check database (cache may have been cleared)
if not metadata:
logger.debug(
"[license_enforcement] No cached license, checking database..."
)
try:
with get_session_with_current_tenant() as db_session:
metadata = refresh_license_cache(db_session, tenant_id)
if metadata:
logger.info(
"[license_enforcement] Loaded license from database"
)
except SQLAlchemyError as db_error:
logger.warning(
f"[license_enforcement] Failed to check database for license: {db_error}"
)
if metadata:
# User HAS a license (current or expired)
if metadata.status == ApplicationStatus.GATED_ACCESS:
# License fully expired - gate the user
# Note: GRACE_PERIOD and PAYMENT_REMINDER are for notifications only,
# they don't block access
is_gated = True
else:
# License is active - check seat limit
# used_seats in cache is kept accurate via invalidation
# when users are added/removed
if metadata.used_seats > metadata.seats:
logger.info(
f"[license_enforcement] Blocking request: "
f"seat limit exceeded ({metadata.used_seats}/{metadata.seats})"
)
return JSONResponse(
status_code=402,
content={
"detail": {
"error": "seat_limit_exceeded",
"message": f"Seat limit exceeded: {metadata.used_seats} of {metadata.seats} seats used.",
"used_seats": metadata.used_seats,
"seats": metadata.seats,
}
},
)
else:
# No license in cache OR database = never subscribed
# Allow community features, but block EE-only features
if _is_ee_only_path(path):
logger.info(
f"[license_enforcement] Blocking EE-only path (no license): {path}"
)
return JSONResponse(
status_code=402,
content={
"detail": {
"error": "enterprise_license_required",
"message": "This feature requires an Enterprise license. "
"Please upgrade to access this functionality.",
}
},
)
logger.debug(
"[license_enforcement] No license, allowing community features"
)
is_gated = False
except RedisError as e:
logger.warning(f"Failed to check license metadata: {e}")
# Fail open - don't block users due to Redis connectivity issues
is_gated = False
if is_gated:
logger.info(
f"[license_enforcement] Blocking request (license expired): {path}"
)
return JSONResponse(
status_code=402,
content={
"detail": {
"error": "license_expired",
"message": "Your subscription has expired. Please update your billing.",
}
},
)
return await call_next(request)

View File

@@ -0,0 +1,214 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
)
from onyx.auth.users import current_user
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.chat.models import ChatBasicResponse
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.constants import MessageType
from onyx.context.search.models import OptionalSearchSetting
from onyx.context.search.models import RetrievalDetails
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.llm.factory import get_llm_for_persona
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/chat")
@router.post("/send-message-simple-api")
def handle_simplified_chat_message(
chat_message_req: BasicCreateChatMessageRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ChatBasicResponse:
"""This is a Non-Streaming version that only gives back a minimal set of information"""
logger.notice(f"Received new simple api chat message: {chat_message_req.message}")
if not chat_message_req.message:
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
# Handle chat session creation if chat_session_id is not provided
if chat_message_req.chat_session_id is None:
if chat_message_req.persona_id is None:
raise HTTPException(
status_code=400,
detail="Either chat_session_id or persona_id must be provided",
)
# Create a new chat session with the provided persona_id
try:
new_chat_session = create_chat_session(
db_session=db_session,
description="", # Leave empty for simple API
user_id=user.id if user else None,
persona_id=chat_message_req.persona_id,
)
chat_session_id = new_chat_session.id
except Exception as e:
logger.exception(e)
raise HTTPException(status_code=400, detail="Invalid Persona provided.")
else:
chat_session_id = chat_message_req.chat_session_id
try:
parent_message = create_chat_history_chain(
chat_session_id=chat_session_id, db_session=db_session
)[-1]
except Exception:
parent_message = get_or_create_root_message(
chat_session_id=chat_session_id, db_session=db_session
)
if (
chat_message_req.retrieval_options is None
and chat_message_req.search_doc_ids is None
):
retrieval_options: RetrievalDetails | None = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=False,
)
else:
retrieval_options = chat_message_req.retrieval_options
full_chat_msg_info = CreateChatMessageRequest(
chat_session_id=chat_session_id,
parent_message_id=parent_message.id,
message=chat_message_req.message,
file_descriptors=[],
search_doc_ids=chat_message_req.search_doc_ids,
retrieval_options=retrieval_options,
# Simple API does not support reranking, hide complexity from user
rerank_settings=None,
query_override=chat_message_req.query_override,
# Currently only applies to search flow not chat
chunks_above=0,
chunks_below=0,
full_doc=chat_message_req.full_doc,
structured_response_format=chat_message_req.structured_response_format,
)
packets = stream_chat_message_objects(
new_msg_req=full_chat_msg_info,
user=user,
db_session=db_session,
)
return gather_stream(packets)
@router.post("/send-message-simple-with-history")
def handle_send_message_simple_with_history(
req: BasicCreateChatMessageWithHistoryRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ChatBasicResponse:
"""This is a Non-Streaming version that only gives back a minimal set of information.
takes in chat history maintained by the caller
and does query rephrasing similar to answer-with-quote"""
if len(req.messages) == 0:
raise HTTPException(status_code=400, detail="Messages cannot be zero length")
# This is a sanity check to make sure the chat history is valid
# It must start with a user message and alternate beteen user and assistant
expected_role = MessageType.USER
for msg in req.messages:
if not msg.message:
raise HTTPException(
status_code=400, detail="One or more chat messages were empty"
)
if msg.role != expected_role:
raise HTTPException(
status_code=400,
detail="Message roles must start and end with MessageType.USER and alternate in-between.",
)
if expected_role == MessageType.USER:
expected_role = MessageType.ASSISTANT
else:
expected_role = MessageType.USER
query = req.messages[-1].message
msg_history = req.messages[:-1]
logger.notice(f"Received new simple with history chat message: {query}")
user_id = user.id if user is not None else None
chat_session = create_chat_session(
db_session=db_session,
description="handle_send_message_simple_with_history",
user_id=user_id,
persona_id=req.persona_id,
)
llm = get_llm_for_persona(persona=chat_session.persona, user=user)
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
provider_type=llm.config.model_provider,
)
# Every chat Session begins with an empty root message
root_message = get_or_create_root_message(
chat_session_id=chat_session.id, db_session=db_session
)
chat_message = root_message
for msg in msg_history:
chat_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=chat_message,
message=msg.message,
token_count=len(llm_tokenizer.encode(msg.message)),
message_type=msg.role,
db_session=db_session,
commit=False,
)
db_session.commit()
if req.retrieval_options is None and req.search_doc_ids is None:
retrieval_options: RetrievalDetails | None = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=False,
)
else:
retrieval_options = req.retrieval_options
full_chat_msg_info = CreateChatMessageRequest(
chat_session_id=chat_session.id,
parent_message_id=chat_message.id,
message=query,
file_descriptors=[],
search_doc_ids=req.search_doc_ids,
retrieval_options=retrieval_options,
# Simple API does not support reranking, hide complexity from user
rerank_settings=None,
query_override=None,
chunks_above=0,
chunks_below=0,
full_doc=req.full_doc,
structured_response_format=req.structured_response_format,
)
packets = stream_chat_message_objects(
new_msg_req=full_chat_msg_info,
user=user,
db_session=db_session,
)
return gather_stream(packets)

View File

@@ -1,12 +1,18 @@
from collections.abc import Sequence
from datetime import datetime
from collections import OrderedDict
from typing import Literal
from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from pydantic import model_validator
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import BasicChunkRequest
from onyx.context.search.models import ChunkContext
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import RetrievalDetails
from onyx.server.manage.models import StandardAnswer
@@ -19,89 +25,119 @@ class StandardAnswerResponse(BaseModel):
standard_answers: list[StandardAnswer] = Field(default_factory=list)
class SearchFlowClassificationRequest(BaseModel):
user_query: str
class DocumentSearchRequest(BasicChunkRequest):
user_selected_filters: BaseFilters | None = None
class SearchFlowClassificationResponse(BaseModel):
is_search_flow: bool
class DocumentSearchResponse(BaseModel):
top_documents: list[InferenceChunk]
class SendSearchQueryRequest(BaseModel):
search_query: str
filters: BaseFilters | None = None
num_docs_fed_to_llm_selection: int | None = None
run_query_expansion: bool = False
num_hits: int = 50
class BasicCreateChatMessageRequest(ChunkContext):
"""If a chat_session_id is not provided, a persona_id must be provided to automatically create a new chat session
Note, for simplicity this option only allows for a single linear chain of messages
"""
include_content: bool = False
stream: bool = False
chat_session_id: UUID | None = None
# Optional persona_id to create a new chat session if chat_session_id is not provided
persona_id: int | None = None
# New message contents
message: str
# Defaults to using retrieval with no additional filters
retrieval_options: RetrievalDetails | None = None
# Allows the caller to specify the exact search query they want to use
# will disable Query Rewording if specified
query_override: str | None = None
# If search_doc_ids provided, then retrieval options are unused
search_doc_ids: list[int] | None = None
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
@model_validator(mode="after")
def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
if self.chat_session_id is None and self.persona_id is None:
raise ValueError("Either chat_session_id or persona_id must be provided")
return self
class SearchDocWithContent(SearchDoc):
# Allows None because this is determined by a flag but the object used in code
# of the search path uses this type
content: str | None
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# Last element is the new query. All previous elements are historical context
messages: list[ThreadMessage]
persona_id: int
retrieval_options: RetrievalDetails | None = None
query_override: str | None = None
skip_rerank: bool | None = None
# If search_doc_ids provided, then retrieval options are unused
search_doc_ids: list[int] | None = None
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
@classmethod
def from_inference_sections(
cls,
sections: Sequence[InferenceSection],
include_content: bool = False,
is_internet: bool = False,
) -> list["SearchDocWithContent"]:
"""Convert InferenceSections to SearchDocWithContent objects.
Args:
sections: Sequence of InferenceSection objects
include_content: If True, populate content field with combined_content
is_internet: Whether these are internet search results
class SimpleDoc(BaseModel):
id: str
semantic_identifier: str
link: str | None
blurb: str
match_highlights: list[str]
source_type: DocumentSource
metadata: dict | None
Returns:
List of SearchDocWithContent with optional content
class AgentSubQuestion(BaseModel):
sub_question: str
document_ids: list[str]
class AgentAnswer(BaseModel):
answer: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class AgentSubQuery(BaseModel):
sub_query: str
query_id: int
@staticmethod
def make_dict_by_level_and_question_index(
original_dict: dict[tuple[int, int, int], "AgentSubQuery"],
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
returns a dict of level to dict[question num to list of query_id's]
Ordering is asc for readability.
"""
if not sections:
return []
# In this function, when we sort int | None, we deliberately push None to the end
return [
cls(
document_id=(chunk := section.center_chunk).document_id,
chunk_ind=chunk.chunk_id,
semantic_identifier=chunk.semantic_identifier or "Unknown",
link=chunk.source_links[0] if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,
boost=chunk.boost,
hidden=chunk.hidden,
metadata=chunk.metadata,
score=chunk.score,
match_highlights=chunk.match_highlights,
updated_at=chunk.updated_at,
primary_owners=chunk.primary_owners,
secondary_owners=chunk.secondary_owners,
is_internet=is_internet,
content=section.combined_content if include_content else None,
# map entries to the level_question_dict
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
for k1, obj in original_dict.items():
level = k1[0]
question = k1[1]
if level not in level_question_dict:
level_question_dict[level] = {}
if question not in level_question_dict[level]:
level_question_dict[level][question] = []
level_question_dict[level][question].append(obj)
# sort each query_id list and question_index
for key1, obj1 in level_question_dict.items():
for key2, value2 in obj1.items():
# sort the query_id list of each question_index
level_question_dict[key1][key2] = sorted(
value2, key=lambda o: o.query_id
)
# sort the question_index dict of level
level_question_dict[key1] = OrderedDict(
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
)
for section in sections
]
class SearchFullResponse(BaseModel):
all_executed_queries: list[str]
search_docs: list[SearchDocWithContent]
# Reasoning tokens output by the LLM for the document selection
doc_selection_reasoning: str | None = None
# This a list of document ids that are in the search_docs list
llm_selected_doc_ids: list[str] | None = None
# Error message if the search failed partway through
error: str | None = None
class SearchQueryResponse(BaseModel):
query: str
query_expansions: list[str] | None
created_at: datetime
class SearchHistoryResponse(BaseModel):
search_queries: list[SearchQueryResponse]
# sort the top dict of levels
sorted_dict = OrderedDict(
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
)
return sorted_dict

View File

@@ -1,170 +0,0 @@
from collections.abc import Generator
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from ee.onyx.db.search import fetch_search_queries_for_user
from ee.onyx.search.process_search_query import gather_search_stream
from ee.onyx.search.process_search_query import stream_search_query
from ee.onyx.secondary_llm_flows.search_flow_classification import (
classify_is_search_flow,
)
from ee.onyx.server.query_and_chat.models import SearchFlowClassificationRequest
from ee.onyx.server.query_and_chat.models import SearchFlowClassificationResponse
from ee.onyx.server.query_and_chat.models import SearchFullResponse
from ee.onyx.server.query_and_chat.models import SearchHistoryResponse
from ee.onyx.server.query_and_chat.models import SearchQueryResponse
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
from onyx.auth.users import current_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import User
from onyx.llm.factory import get_default_llm
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
from onyx.server.utils import get_json_line
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/search")
@router.post("/search-flow-classification")
def search_flow_classification(
request: SearchFlowClassificationRequest,
# This is added just to ensure this endpoint isn't spammed by non-authorized users since there's an LLM call underneath it
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SearchFlowClassificationResponse:
query = request.user_query
# This is a heuristic that if the user is typing a lot of text, it's unlikely they're looking for some specific document
# Most likely something needs to be done with the text included so we'll just classify it as a chat flow
if len(query) > 200:
return SearchFlowClassificationResponse(is_search_flow=False)
llm = get_default_llm()
check_llm_cost_limit_for_provider(
db_session=db_session,
tenant_id=get_current_tenant_id(),
llm_provider_api_key=llm.config.api_key,
)
try:
is_search_flow = classify_is_search_flow(query=query, llm=llm)
except Exception as e:
logger.exception(
"Search flow classification failed; defaulting to chat flow",
exc_info=e,
)
is_search_flow = False
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
@router.post("/send-search-message", response_model=None)
def handle_send_search_message(
request: SendSearchQueryRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | SearchFullResponse:
"""
Execute a search query with optional streaming.
When stream=True: Returns StreamingResponse with SSE
When stream=False: Returns SearchFullResponse
"""
logger.debug(f"Received search query: {request.search_query}")
# Non-streaming path
if not request.stream:
try:
packets = stream_search_query(request, user, db_session)
return gather_search_stream(packets)
except NotImplementedError as e:
return SearchFullResponse(
all_executed_queries=[],
search_docs=[],
error=str(e),
)
# Streaming path
def stream_generator() -> Generator[str, None, None]:
try:
with get_session_with_current_tenant() as streaming_db_session:
for packet in stream_search_query(request, user, streaming_db_session):
yield get_json_line(packet.model_dump())
except NotImplementedError as e:
yield get_json_line(SearchErrorPacket(error=str(e)).model_dump())
except HTTPException:
raise
except Exception as e:
logger.exception("Error in search streaming")
yield get_json_line(SearchErrorPacket(error=str(e)).model_dump())
return StreamingResponse(stream_generator(), media_type="text/event-stream")
@router.get("/search-history")
def get_search_history(
limit: int = 100,
filter_days: int | None = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SearchHistoryResponse:
"""
Fetch past search queries for the authenticated user.
Args:
limit: Maximum number of queries to return (default 100)
filter_days: Only return queries from the last N days (optional)
Returns:
SearchHistoryResponse with list of search queries, ordered by most recent first.
"""
# Validate limit
if limit <= 0:
raise HTTPException(
status_code=400,
detail="limit must be greater than 0",
)
if limit > 1000:
raise HTTPException(
status_code=400,
detail="limit must be at most 1000",
)
# Validate filter_days
if filter_days is not None and filter_days <= 0:
raise HTTPException(
status_code=400,
detail="filter_days must be greater than 0",
)
# TODO(yuhong) remove this
if user is None:
# Return empty list for unauthenticated users
return SearchHistoryResponse(search_queries=[])
search_queries = fetch_search_queries_for_user(
db_session=db_session,
user_id=user.id,
filter_days=filter_days,
limit=limit,
)
return SearchHistoryResponse(
search_queries=[
SearchQueryResponse(
query=sq.query,
query_expansions=sq.query_expansions,
created_at=sq.created_at,
)
for sq in search_queries
]
)

View File

@@ -1,35 +0,0 @@
from typing import Literal
from pydantic import BaseModel
from pydantic import ConfigDict
from ee.onyx.server.query_and_chat.models import SearchDocWithContent
class SearchQueriesPacket(BaseModel):
model_config = ConfigDict(frozen=True)
type: Literal["search_queries"] = "search_queries"
all_executed_queries: list[str]
class SearchDocsPacket(BaseModel):
model_config = ConfigDict(frozen=True)
type: Literal["search_docs"] = "search_docs"
search_docs: list[SearchDocWithContent]
class SearchErrorPacket(BaseModel):
model_config = ConfigDict(frozen=True)
type: Literal["search_error"] = "search_error"
error: str
class LLMSelectedDocsPacket(BaseModel):
model_config = ConfigDict(frozen=True)
type: Literal["llm_selected_docs"] = "llm_selected_docs"
# None if LLM selection failed, empty list if no docs selected, list of IDs otherwise
llm_selected_doc_ids: list[str] | None

View File

@@ -32,7 +32,6 @@ from onyx.configs.constants import MessageType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import QueryHistoryType
from onyx.configs.constants import SessionType
@@ -49,6 +48,7 @@ from onyx.file_store.file_store import get_default_file_store
from onyx.server.documents.models import PaginatedReturn
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse
from onyx.server.utils import PUBLIC_API_TAGS
from onyx.utils.threadpool_concurrency import parallel_yield
from shared_configs.contextvars import get_current_tenant_id

View File

@@ -1,83 +0,0 @@
"""EE Settings API - provides license-aware settings override."""
from redis.exceptions import RedisError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
# Only GATED_ACCESS actually blocks access - other statuses are for notifications
_BLOCKING_STATUS = ApplicationStatus.GATED_ACCESS
def check_ee_features_enabled() -> bool:
"""EE version: checks if EE features should be available.
Returns True if:
- LICENSE_ENFORCEMENT_ENABLED is False (legacy/rollout mode)
- Cloud mode (MULTI_TENANT) - cloud handles its own gating
- Self-hosted with a valid (non-expired) license
Returns False if:
- Self-hosted with no license (never subscribed)
- Self-hosted with expired license
"""
if not LICENSE_ENFORCEMENT_ENABLED:
# License enforcement disabled - allow EE features (legacy behavior)
return True
if MULTI_TENANT:
# Cloud mode - EE features always available (gating handled by is_tenant_gated)
return True
# Self-hosted with enforcement - check for valid license
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if metadata and metadata.status != _BLOCKING_STATUS:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
return True
except RedisError as e:
logger.warning(f"Failed to check license for EE features: {e}")
# Fail closed - if Redis is down, other things will break anyway
return False
# No license or GATED_ACCESS - no EE features
return False
def apply_license_status_to_settings(settings: Settings) -> Settings:
"""EE version: checks license status for self-hosted deployments.
For self-hosted, looks up license metadata and overrides application_status
if the license indicates GATED_ACCESS (fully expired).
For multi-tenant (cloud), the settings already have the correct status
from the control plane, so no override is needed.
If LICENSE_ENFORCEMENT_ENABLED is false, settings are returned unchanged,
allowing the product to function normally without license checks.
"""
if not LICENSE_ENFORCEMENT_ENABLED:
return settings
if MULTI_TENANT:
return settings
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if metadata and metadata.status == _BLOCKING_STATUS:
settings.application_status = metadata.status
# No license = user hasn't purchased yet, allow access for upgrade flow
# GRACE_PERIOD/PAYMENT_REMINDER don't block - they're for notifications
except RedisError as e:
logger.warning(f"Failed to check license metadata for settings: {e}")
return settings

View File

@@ -1,14 +1,10 @@
"""Tenant-specific usage limit overrides from the control plane (EE version)."""
import time
import requests
from ee.onyx.server.tenants.access import generate_data_plane_token
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.server.tenant_usage_limits import TenantUsageLimitOverrides
from onyx.server.usage_limits import NO_LIMIT
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -16,12 +12,9 @@ logger = setup_logger()
# In-memory storage for tenant overrides (populated at startup)
_tenant_usage_limit_overrides: dict[str, TenantUsageLimitOverrides] | None = None
_last_fetch_time: float = 0.0
_FETCH_INTERVAL = 60 * 60 * 24 # 24 hours
_ERROR_FETCH_INTERVAL = 30 * 60 # 30 minutes (if the last fetch failed)
def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides] | None:
def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
"""
Fetch tenant-specific usage limit overrides from the control plane.
@@ -52,52 +45,33 @@ def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides] | None
f"Failed to parse usage limit overrides for tenant {tenant_id}: {e}"
)
return (
result or None
) # if empty dictionary, something went wrong and we shouldn't enforce limits
return result
except requests.exceptions.RequestException as e:
logger.warning(f"Failed to fetch usage limit overrides from control plane: {e}")
return None
return {}
except Exception as e:
logger.error(f"Error parsing usage limit overrides: {e}")
return None
return {}
def load_usage_limit_overrides() -> None:
def load_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
"""
Load tenant usage limit overrides from the control plane.
Called at server startup to populate the in-memory cache.
"""
global _tenant_usage_limit_overrides
global _last_fetch_time
logger.info("Loading tenant usage limit overrides from control plane...")
overrides = fetch_usage_limit_overrides()
_last_fetch_time = time.time()
# use the new result if it exists, otherwise use the old result
# (prevents us from updating to a failed fetch result)
_tenant_usage_limit_overrides = overrides or _tenant_usage_limit_overrides
_tenant_usage_limit_overrides = overrides
if overrides:
logger.info(f"Loaded usage limit overrides for {len(overrides)} tenants")
else:
logger.info("No tenant-specific usage limit overrides found")
def unlimited(tenant_id: str) -> TenantUsageLimitOverrides:
return TenantUsageLimitOverrides(
tenant_id=tenant_id,
llm_cost_cents_trial=NO_LIMIT,
llm_cost_cents_paid=NO_LIMIT,
chunks_indexed_trial=NO_LIMIT,
chunks_indexed_paid=NO_LIMIT,
api_calls_trial=NO_LIMIT,
api_calls_paid=NO_LIMIT,
non_streaming_calls_trial=NO_LIMIT,
non_streaming_calls_paid=NO_LIMIT,
)
return overrides
def get_tenant_usage_limit_overrides(
@@ -112,22 +86,7 @@ def get_tenant_usage_limit_overrides(
Returns:
TenantUsageLimitOverrides if the tenant has overrides, None otherwise.
"""
if DEV_MODE: # in dev mode, we return unlimited limits for all tenants
return unlimited(tenant_id)
global _tenant_usage_limit_overrides
time_since = time.time() - _last_fetch_time
if (
_tenant_usage_limit_overrides is None and time_since > _ERROR_FETCH_INTERVAL
) or (time_since > _FETCH_INTERVAL):
logger.debug(
f"Last fetch time: {_last_fetch_time}, time since last fetch: {time_since}"
)
load_usage_limit_overrides()
# If we have failed to fetch from the control plane or we're in dev mode, don't usage limit anyone.
if _tenant_usage_limit_overrides is None or DEV_MODE:
return unlimited(tenant_id)
if _tenant_usage_limit_overrides is None:
_tenant_usage_limit_overrides = load_usage_limit_overrides()
return _tenant_usage_limit_overrides.get(tenant_id)

View File

@@ -3,7 +3,6 @@ from fastapi import APIRouter
from ee.onyx.server.tenants.admin_api import router as admin_router
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
from ee.onyx.server.tenants.billing_api import router as billing_router
from ee.onyx.server.tenants.proxy import router as proxy_router
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
from ee.onyx.server.tenants.tenant_management_api import (
router as tenant_management_router,
@@ -23,4 +22,3 @@ router.include_router(billing_router)
router.include_router(team_membership_router)
router.include_router(tenant_management_router)
router.include_router(user_invitations_router)
router.include_router(proxy_router)

View File

@@ -1,9 +1,9 @@
from typing import cast
from typing import Literal
import requests
import stripe
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
@@ -16,21 +16,15 @@ stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
def fetch_stripe_checkout_session(
tenant_id: str,
billing_period: Literal["monthly", "annual"] = "monthly",
) -> str:
def fetch_stripe_checkout_session(tenant_id: str) -> str:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
payload = {
"tenant_id": tenant_id,
"billing_period": billing_period,
}
response = requests.post(url, headers=headers, json=payload)
params = {"tenant_id": tenant_id}
response = requests.post(url, headers=headers, params=params)
response.raise_for_status()
return response.json()["sessionId"]
@@ -76,46 +70,24 @@ def fetch_billing_information(
return BillingInformation(**response_data)
def fetch_customer_portal_session(tenant_id: str, return_url: str | None = None) -> str:
"""
Fetch a Stripe customer portal session URL from the control plane.
NOTE: This is currently only used for multi-tenant (cloud) deployments.
Self-hosted proxy endpoints will be added in a future phase.
"""
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/create-customer-portal-session"
payload = {"tenant_id": tenant_id}
if return_url:
payload["return_url"] = return_url
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
return response.json()["url"]
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
"""
Update the number of seats for a tenant's subscription.
Preserves the existing price (monthly, annual, or grandfathered).
Send a request to the control service to register the number of users for a tenant.
"""
if not STRIPE_PRICE_ID:
raise Exception("STRIPE_PRICE_ID is not set")
response = fetch_tenant_stripe_information(tenant_id)
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))
subscription = stripe.Subscription.retrieve(stripe_subscription_id)
subscription_item = subscription["items"]["data"][0]
# Use existing price to preserve the customer's current plan
current_price_id = subscription_item.price.id
updated_subscription = stripe.Subscription.modify(
stripe_subscription_id,
items=[
{
"id": subscription_item.id,
"price": current_price_id,
"id": subscription["items"]["data"][0].id,
"price": STRIPE_PRICE_ID,
"quantity": number_of_users,
}
],

View File

@@ -1,59 +1,33 @@
"""Billing API endpoints for cloud multi-tenant deployments.
DEPRECATED: These /tenants/* billing endpoints are being replaced by /admin/billing/*
which provides a unified API for both self-hosted and cloud deployments.
TODO(ENG-3533): Migrate frontend to use /admin/billing/* endpoints and remove this file.
https://linear.app/onyx-app/issue/ENG-3533/migrate-tenantsbilling-adminbilling
Current endpoints to migrate:
- GET /tenants/billing-information -> GET /admin/billing/information
- POST /tenants/create-customer-portal-session -> POST /admin/billing/portal-session
- POST /tenants/create-subscription-session -> POST /admin/billing/checkout-session
- GET /tenants/stripe-publishable-key -> (keep as-is, shared endpoint)
Note: /tenants/product-gating/* endpoints are control-plane-to-data-plane calls
and are NOT part of this migration - they stay here.
"""
import asyncio
import httpx
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_customer_portal_session
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import StripePublishableKeyResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
# Cache for Stripe publishable key to avoid hitting S3 on every request
_stripe_publishable_key_cache: str | None = None
_stripe_key_lock = asyncio.Lock()
@router.post("/product-gating")
def gate_product(
@@ -108,13 +82,21 @@ async def billing_information(
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
"""Create a Stripe customer portal session via the control plane."""
tenant_id = get_current_tenant_id()
return_url = f"{WEB_DOMAIN}/admin/billing"
try:
portal_url = fetch_customer_portal_session(tenant_id, return_url)
return {"url": portal_url}
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@@ -122,82 +104,15 @@ async def create_customer_portal_session(
@router.post("/create-subscription-session")
async def create_subscription_session(
request: CreateSubscriptionSessionRequest | None = None,
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
billing_period = request.billing_period if request else "monthly"
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create subscription session")
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/stripe-publishable-key")
async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
"""
Fetch the Stripe publishable key.
Priority: env var override (for testing) > S3 bucket (production).
This endpoint is public (no auth required) since publishable keys are safe to expose.
The key is cached in memory to avoid hitting S3 on every request.
"""
global _stripe_publishable_key_cache
# Fast path: return cached value without lock
if _stripe_publishable_key_cache:
return StripePublishableKeyResponse(
publishable_key=_stripe_publishable_key_cache
)
# Use lock to prevent concurrent S3 requests
async with _stripe_key_lock:
# Double-check after acquiring lock (another request may have populated cache)
if _stripe_publishable_key_cache:
return StripePublishableKeyResponse(
publishable_key=_stripe_publishable_key_cache
)
# Check for env var override first (for local testing with pk_test_* keys)
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
if not key.startswith("pk_"):
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
# Fall back to S3 bucket
if not STRIPE_PUBLISHABLE_KEY_URL:
raise HTTPException(
status_code=500,
detail="Stripe publishable key is not configured",
)
try:
async with httpx.AsyncClient() as client:
response = await client.get(STRIPE_PUBLISHABLE_KEY_URL)
response.raise_for_status()
key = response.text.strip()
# Validate key format
if not key.startswith("pk_"):
raise HTTPException(
status_code=500,
detail="Invalid Stripe publishable key format",
)
_stripe_publishable_key_cache = key
return StripePublishableKeyResponse(publishable_key=key)
except httpx.HTTPError:
raise HTTPException(
status_code=500,
detail="Failed to fetch Stripe publishable key",
)

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from typing import Literal
from pydantic import BaseModel
@@ -74,12 +73,6 @@ class SubscriptionSessionResponse(BaseModel):
sessionId: str
class CreateSubscriptionSessionRequest(BaseModel):
"""Request to create a subscription checkout session."""
billing_period: Literal["monthly", "annual"] = "monthly"
class TenantByDomainResponse(BaseModel):
tenant_id: str
number_of_users: int
@@ -105,7 +98,3 @@ class PendingUserSnapshot(BaseModel):
class ApproveUserRequest(BaseModel):
email: str
class StripePublishableKeyResponse(BaseModel):
publishable_key: str

View File

@@ -65,9 +65,3 @@ def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}
def is_tenant_gated(tenant_id: str) -> bool:
"""Fast O(1) check if tenant is in gated set (multi-tenant only)."""
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
return bool(redis_client.sismember(GATED_TENANTS_KEY, tenant_id))

View File

@@ -1,485 +0,0 @@
"""Proxy endpoints for billing operations.
These endpoints run on the CLOUD DATA PLANE (cloud.onyx.app) and serve as a proxy
for self-hosted instances to reach the control plane.
Flow:
Self-hosted backend → Cloud DP /proxy/* (license auth) → Control plane (JWT auth)
Self-hosted instances call these endpoints with their license in the Authorization
header. The cloud data plane validates the license signature and forwards the
request to the control plane using JWT authentication.
Auth levels by endpoint:
- /create-checkout-session: No auth (new customer) or expired license OK (renewal)
- /claim-license: Session ID based (one-time after Stripe payment)
- /create-customer-portal-session: Expired license OK (need portal to fix payment)
- /billing-information: Valid license required
- /license/{tenant_id}: Valid license required
- /seats/update: Valid license required
"""
from typing import Literal
import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Header
from fastapi import HTTPException
from pydantic import BaseModel
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import update_license_cache
from ee.onyx.db.license import upsert_license
from ee.onyx.server.billing.models import SeatUpdateRequest
from ee.onyx.server.billing.models import SeatUpdateResponse
from ee.onyx.server.license.models import LicensePayload
from ee.onyx.server.license.models import LicenseSource
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.utils.license import is_license_valid
from ee.onyx.utils.license import verify_license_signature
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/proxy")
def _check_license_enforcement_enabled() -> None:
"""Ensure LICENSE_ENFORCEMENT_ENABLED is true (proxy endpoints only work on cloud DP)."""
if not LICENSE_ENFORCEMENT_ENABLED:
raise HTTPException(
status_code=501,
detail="Proxy endpoints are only available on cloud data plane",
)
def _extract_license_from_header(
authorization: str | None,
required: bool = True,
) -> str | None:
"""Extract license data from Authorization header.
Self-hosted instances authenticate to these proxy endpoints by sending their
license as a Bearer token: `Authorization: Bearer <base64-encoded-license>`.
We use the Bearer scheme (RFC 6750) because:
1. It's the standard HTTP auth scheme for token-based authentication
2. The license blob is cryptographically signed (RSA), so it's self-validating
3. No other auth schemes (Basic, Digest, etc.) are supported for license auth
The license data is the base64-encoded signed blob that contains tenant_id,
seats, expiration, etc. We verify the signature to authenticate the caller.
Args:
authorization: The Authorization header value (e.g., "Bearer <license>")
required: If True, raise 401 when header is missing/invalid
Returns:
License data string (base64-encoded), or None if not required and missing
Raises:
HTTPException: 401 if required and header is missing/invalid
"""
if not authorization or not authorization.startswith("Bearer "):
if required:
raise HTTPException(
status_code=401, detail="Missing or invalid authorization header"
)
return None
return authorization.split(" ", 1)[1]
def verify_license_auth(
license_data: str,
allow_expired: bool = False,
) -> LicensePayload:
"""Verify license signature and optionally check expiry.
Args:
license_data: Base64-encoded signed license blob
allow_expired: If True, accept expired licenses (for renewal flows)
Returns:
LicensePayload if valid
Raises:
HTTPException: If license is invalid or expired (when not allowed)
"""
_check_license_enforcement_enabled()
try:
payload = verify_license_signature(license_data)
except ValueError as e:
raise HTTPException(status_code=401, detail=f"Invalid license: {e}")
if not allow_expired and not is_license_valid(payload):
raise HTTPException(status_code=401, detail="License has expired")
return payload
async def get_license_payload(
authorization: str | None = Header(None, alias="Authorization"),
) -> LicensePayload:
"""Dependency: Require valid (non-expired) license.
Used for endpoints that require an active subscription.
"""
license_data = _extract_license_from_header(authorization, required=True)
# license_data is guaranteed non-None when required=True
assert license_data is not None
return verify_license_auth(license_data, allow_expired=False)
async def get_license_payload_allow_expired(
authorization: str | None = Header(None, alias="Authorization"),
) -> LicensePayload:
"""Dependency: Require license with valid signature, expired OK.
Used for endpoints needed to fix payment issues (portal, renewal checkout).
"""
license_data = _extract_license_from_header(authorization, required=True)
# license_data is guaranteed non-None when required=True
assert license_data is not None
return verify_license_auth(license_data, allow_expired=True)
async def get_optional_license_payload(
authorization: str | None = Header(None, alias="Authorization"),
) -> LicensePayload | None:
"""Dependency: Optional license auth (for checkout - new customers have none).
Returns None if no license provided, otherwise validates and returns payload.
Expired licenses are allowed for renewal flows.
"""
_check_license_enforcement_enabled()
license_data = _extract_license_from_header(authorization, required=False)
if license_data is None:
return None
return verify_license_auth(license_data, allow_expired=True)
async def forward_to_control_plane(
method: str,
path: str,
body: dict | None = None,
params: dict | None = None,
) -> dict:
"""Forward a request to the control plane with proper authentication."""
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}{path}"
try:
async with httpx.AsyncClient(timeout=30.0) as client:
if method == "GET":
response = await client.get(url, headers=headers, params=params)
elif method == "POST":
response = await client.post(url, headers=headers, json=body)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
response.raise_for_status()
return response.json()
except httpx.HTTPStatusError as e:
status_code = e.response.status_code
detail = "Control plane request failed"
try:
error_data = e.response.json()
detail = error_data.get("detail", detail)
except Exception:
pass
logger.error(f"Control plane returned {status_code}: {detail}")
raise HTTPException(status_code=status_code, detail=detail)
except httpx.RequestError:
logger.exception("Failed to connect to control plane")
raise HTTPException(
status_code=502, detail="Failed to connect to control plane"
)
def fetch_and_store_license(tenant_id: str, license_data: str) -> None:
"""Store license in database and update Redis cache.
Args:
tenant_id: The tenant ID
license_data: Base64-encoded signed license blob
"""
try:
# Verify before storing
payload = verify_license_signature(license_data)
# Store in database using the specific tenant's schema
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
upsert_license(db_session, license_data)
# Update Redis cache
update_license_cache(
payload,
source=LicenseSource.AUTO_FETCH,
tenant_id=tenant_id,
)
except ValueError as e:
logger.error(f"Failed to verify license: {e}")
raise
except Exception:
logger.exception("Failed to store license")
raise
# -----------------------------------------------------------------------------
# Endpoints
# -----------------------------------------------------------------------------
class CreateCheckoutSessionRequest(BaseModel):
billing_period: Literal["monthly", "annual"] = "monthly"
email: str | None = None
# Redirect URL after successful checkout - self-hosted passes their instance URL
redirect_url: str | None = None
# Cancel URL when user exits checkout - returns to upgrade page
cancel_url: str | None = None
class CreateCheckoutSessionResponse(BaseModel):
url: str
@router.post("/create-checkout-session")
async def proxy_create_checkout_session(
request_body: CreateCheckoutSessionRequest,
license_payload: LicensePayload | None = Depends(get_optional_license_payload),
) -> CreateCheckoutSessionResponse:
"""Proxy checkout session creation to control plane.
Auth: Optional license (new customers don't have one yet).
If license provided, expired is OK (for renewals).
"""
# license_payload is None for new customers who don't have a license yet.
# In that case, tenant_id is omitted from the request body and the control
# plane will create a new tenant during checkout completion.
tenant_id = license_payload.tenant_id if license_payload else None
body: dict = {
"billing_period": request_body.billing_period,
}
if tenant_id:
body["tenant_id"] = tenant_id
if request_body.email:
body["email"] = request_body.email
if request_body.redirect_url:
body["redirect_url"] = request_body.redirect_url
if request_body.cancel_url:
body["cancel_url"] = request_body.cancel_url
result = await forward_to_control_plane(
"POST", "/create-checkout-session", body=body
)
return CreateCheckoutSessionResponse(url=result["url"])
class ClaimLicenseRequest(BaseModel):
session_id: str
class ClaimLicenseResponse(BaseModel):
tenant_id: str
license: str
message: str | None = None
@router.post("/claim-license")
async def proxy_claim_license(
request_body: ClaimLicenseRequest,
) -> ClaimLicenseResponse:
"""Claim a license after successful Stripe checkout.
Auth: Session ID based (one-time use after payment).
The control plane verifies the session_id is valid and unclaimed.
Returns the license to the caller. For self-hosted instances, they will
store the license locally. The cloud DP doesn't need to store it.
"""
_check_license_enforcement_enabled()
result = await forward_to_control_plane(
"POST",
"/claim-license",
body={"session_id": request_body.session_id},
)
tenant_id = result.get("tenant_id")
license_data = result.get("license")
if not tenant_id or not license_data:
logger.error(f"Control plane returned incomplete claim response: {result}")
raise HTTPException(
status_code=502,
detail="Control plane returned incomplete license data",
)
return ClaimLicenseResponse(
tenant_id=tenant_id,
license=license_data,
message="License claimed successfully",
)
class CreateCustomerPortalSessionRequest(BaseModel):
return_url: str | None = None
class CreateCustomerPortalSessionResponse(BaseModel):
url: str
@router.post("/create-customer-portal-session")
async def proxy_create_customer_portal_session(
request_body: CreateCustomerPortalSessionRequest | None = None,
license_payload: LicensePayload = Depends(get_license_payload_allow_expired),
) -> CreateCustomerPortalSessionResponse:
"""Proxy customer portal session creation to control plane.
Auth: License required, expired OK (need portal to fix payment issues).
"""
# tenant_id is a required field in LicensePayload (Pydantic validates this),
# but we check explicitly for defense in depth
if not license_payload.tenant_id:
raise HTTPException(status_code=401, detail="License missing tenant_id")
tenant_id = license_payload.tenant_id
body: dict = {"tenant_id": tenant_id}
if request_body and request_body.return_url:
body["return_url"] = request_body.return_url
result = await forward_to_control_plane(
"POST", "/create-customer-portal-session", body=body
)
return CreateCustomerPortalSessionResponse(url=result["url"])
class BillingInformationResponse(BaseModel):
tenant_id: str
status: str | None = None
plan_type: str | None = None
seats: int | None = None
billing_period: str | None = None
current_period_start: str | None = None
current_period_end: str | None = None
cancel_at_period_end: bool = False
canceled_at: str | None = None
trial_start: str | None = None
trial_end: str | None = None
payment_method_enabled: bool = False
stripe_subscription_id: str | None = None
@router.get("/billing-information")
async def proxy_billing_information(
license_payload: LicensePayload = Depends(get_license_payload),
) -> BillingInformationResponse:
"""Proxy billing information request to control plane.
Auth: Valid (non-expired) license required.
"""
# tenant_id is a required field in LicensePayload (Pydantic validates this),
# but we check explicitly for defense in depth
if not license_payload.tenant_id:
raise HTTPException(status_code=401, detail="License missing tenant_id")
tenant_id = license_payload.tenant_id
result = await forward_to_control_plane(
"GET", "/billing-information", params={"tenant_id": tenant_id}
)
# Add tenant_id from license if not in response (control plane may not include it)
if "tenant_id" not in result:
result["tenant_id"] = tenant_id
return BillingInformationResponse(**result)
class LicenseFetchResponse(BaseModel):
license: str
tenant_id: str
@router.get("/license/{tenant_id}")
async def proxy_license_fetch(
tenant_id: str,
license_payload: LicensePayload = Depends(get_license_payload),
) -> LicenseFetchResponse:
"""Proxy license fetch to control plane.
Auth: Valid license required.
The tenant_id in path must match the authenticated tenant.
"""
# tenant_id is a required field in LicensePayload (Pydantic validates this),
# but we check explicitly for defense in depth
if not license_payload.tenant_id:
raise HTTPException(status_code=401, detail="License missing tenant_id")
if tenant_id != license_payload.tenant_id:
raise HTTPException(
status_code=403,
detail="Cannot fetch license for a different tenant",
)
result = await forward_to_control_plane("GET", f"/license/{tenant_id}")
# Auto-store the refreshed license
license_data = result.get("license")
if not license_data:
logger.error(f"Control plane returned incomplete license response: {result}")
raise HTTPException(
status_code=502,
detail="Control plane returned incomplete license data",
)
fetch_and_store_license(tenant_id, license_data)
return LicenseFetchResponse(license=license_data, tenant_id=tenant_id)
@router.post("/seats/update")
async def proxy_seat_update(
request_body: SeatUpdateRequest,
license_payload: LicensePayload = Depends(get_license_payload),
) -> SeatUpdateResponse:
"""Proxy seat update to control plane.
Auth: Valid (non-expired) license required.
Handles Stripe proration and license regeneration.
"""
if not license_payload.tenant_id:
raise HTTPException(status_code=401, detail="License missing tenant_id")
tenant_id = license_payload.tenant_id
result = await forward_to_control_plane(
"POST",
"/seats/update",
body={
"tenant_id": tenant_id,
"new_seat_count": request_body.new_seat_count,
},
)
return SeatUpdateResponse(
success=result.get("success", False),
current_seats=result.get("current_seats", 0),
used_seats=result.get("used_seats", 0),
message=result.get("message"),
)

View File

@@ -1,7 +1,6 @@
from fastapi_users import exceptions
from sqlalchemy import select
from ee.onyx.db.license import invalidate_license_cache
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import get_pending_users
from onyx.auth.invited_users import write_invited_users
@@ -48,8 +47,6 @@ def get_tenant_id_for_email(email: str) -> str:
mapping.active = True
db_session.commit()
tenant_id = mapping.tenant_id
# Invalidate license cache so used_seats reflects the new count
invalidate_license_cache(tenant_id)
except Exception as e:
logger.exception(f"Error getting tenant id for email {email}: {e}")
raise exceptions.UserNotExists()
@@ -73,104 +70,49 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
"""
Add users to a tenant with proper transaction handling.
Checks if users already have a tenant mapping to avoid duplicates.
If a user already has an active mapping to a different tenant, they receive
an inactive mapping (invitation) to this tenant. They can accept the
invitation later to switch tenants.
Raises:
HTTPException: 402 if adding active users would exceed seat limit
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
"""
from fastapi import HTTPException
from ee.onyx.db.license import check_seat_availability
from onyx.db.engine.sql_engine import get_session_with_tenant as get_tenant_session
unique_emails = set(emails)
if not unique_emails:
return
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
# Start a transaction
db_session.begin()
# Batch query 1: Get all existing mappings for these emails to this tenant
# Lock rows to prevent concurrent modifications
existing_mappings = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email.in_(unique_emails),
UserTenantMapping.tenant_id == tenant_id,
)
.with_for_update()
.all()
)
emails_with_mapping = {m.email for m in existing_mappings}
# Batch query 2: Get all active mappings for these emails (any tenant)
active_mappings = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email.in_(unique_emails),
UserTenantMapping.active == True, # noqa: E712
)
.all()
)
emails_with_active_mapping = {m.email for m in active_mappings}
# Determine which users will consume a new seat.
# Users with active mappings elsewhere get INACTIVE mappings (invitations)
# and don't consume seats until they accept. Only users without any active
# mapping will get an ACTIVE mapping and consume a seat immediately.
emails_consuming_seats = {
email
for email in unique_emails
if email not in emails_with_mapping
and email not in emails_with_active_mapping
}
# Check seat availability inside the transaction to prevent race conditions.
# Note: ALL users in unique_emails still get added below - this check only
# validates we have capacity for users who will consume seats immediately.
if emails_consuming_seats:
with get_tenant_session(tenant_id=tenant_id) as tenant_session:
result = check_seat_availability(
tenant_session,
seats_needed=len(emails_consuming_seats),
tenant_id=tenant_id,
for email in emails:
# Check if the user already has a mapping to this tenant
existing_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
)
if not result.available:
raise HTTPException(
status_code=402,
detail=result.error_message or "Seat limit exceeded",
.with_for_update()
.first()
)
# If user already has an active mapping, add this one as inactive
if not existing_mapping:
# Check if the user already has an active mapping to any tenant
has_active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
db_session.add(
UserTenantMapping(
email=email,
tenant_id=tenant_id,
active=False if has_active_mapping else True,
)
# Add mappings for emails that don't already have one to this tenant
for email in unique_emails:
if email in emails_with_mapping:
continue
# Create mapping: inactive if user belongs to another tenant (invitation),
# active otherwise
db_session.add(
UserTenantMapping(
email=email,
tenant_id=tenant_id,
active=email not in emails_with_active_mapping,
)
)
# Commit the transaction
db_session.commit()
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
# Invalidate license cache so used_seats reflects the new count
invalidate_license_cache(tenant_id)
except HTTPException:
db_session.rollback()
raise
except Exception:
logger.exception(f"Failed to add users to tenant {tenant_id}")
db_session.rollback()
@@ -193,9 +135,6 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
db_session.delete(mapping)
db_session.commit()
# Invalidate license cache so used_seats reflects the new count
invalidate_license_cache(tenant_id)
except Exception as e:
logger.exception(
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
@@ -210,9 +149,6 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
).delete()
db_session.commit()
# Invalidate license cache so used_seats reflects the new count
invalidate_license_cache(tenant_id)
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -241,9 +177,6 @@ def approve_user_invite(email: str, tenant_id: str) -> None:
db_session.add(new_mapping)
db_session.commit()
# Invalidate license cache so used_seats reflects the new count
invalidate_license_cache(tenant_id)
# Also remove the user from pending users list
# Remove from pending users
pending_users = get_pending_users()
@@ -262,42 +195,19 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
"""
Accept an invitation to join a tenant.
This activates the user's mapping to the tenant.
Raises:
HTTPException: 402 if accepting would exceed seat limit
"""
from fastapi import HTTPException
from ee.onyx.db.license import check_seat_availability
from onyx.db.engine.sql_engine import get_session_with_tenant
with get_session_with_shared_schema() as db_session:
try:
# Lock the user's mappings first to prevent race conditions.
# This ensures no concurrent request can modify this user's mappings
# while we check seats and activate.
# First check if there's an active mapping for this user and tenant
active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.with_for_update()
.first()
)
# Check seat availability within the same logical operation.
# Note: This queries fresh data from DB, not cache.
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
result = check_seat_availability(
tenant_session, seats_needed=1, tenant_id=tenant_id
)
if not result.available:
raise HTTPException(
status_code=402,
detail=result.error_message or "Seat limit exceeded",
)
# If an active mapping exists, delete it
if active_mapping:
db_session.delete(active_mapping)
@@ -327,9 +237,6 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
mapping.active = True
db_session.commit()
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
# Invalidate license cache so used_seats reflects the new count
invalidate_license_cache(tenant_id)
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
@@ -390,41 +297,16 @@ def deny_user_invite(email: str, tenant_id: str) -> None:
def get_tenant_count(tenant_id: str) -> int:
"""
Get the number of active users for this tenant.
A user counts toward the seat count if:
1. They have an active mapping to this tenant (UserTenantMapping.active == True)
2. AND the User is active (User.is_active == True)
TODO: Exclude API key dummy users from seat counting. API keys create
users with emails like `__DANSWER_API_KEY_*` that should not count toward
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
Get the number of active users for this tenant
"""
from onyx.db.models import User
# First get all emails with active mappings to this tenant
with get_session_with_shared_schema() as db_session:
active_mapping_emails = (
db_session.query(UserTenantMapping.email)
# Count the number of active users for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.all()
)
emails = [email for (email,) in active_mapping_emails]
if not emails:
return 0
# Now count how many of those users are actually active in the tenant's User table
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
user_count = (
db_session.query(User)
.filter(
User.email.in_(emails), # type: ignore
User.is_active == True, # type: ignore # noqa: E712
)
.count()
)

View File

@@ -9,7 +9,6 @@ from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.token_limit import fetch_all_user_token_rate_limits
@@ -17,6 +16,7 @@ from onyx.db.token_limit import insert_user_token_rate_limit
from onyx.server.query_and_chat.token_limit import any_rate_limit_exists
from onyx.server.token_rate_limits.models import TokenRateLimitArgs
from onyx.server.token_rate_limits.models import TokenRateLimitDisplay
from onyx.server.utils import PUBLIC_API_TAGS
router = APIRouter(prefix="/admin/token-rate-limits", tags=PUBLIC_API_TAGS)

View File

@@ -1,5 +1,8 @@
"""EE Usage limits - trial detection via billing information."""
from datetime import datetime
from datetime import timezone
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
@@ -28,7 +31,13 @@ def is_tenant_on_trial(tenant_id: str) -> bool:
return True
if isinstance(billing_info, BillingInformation):
return billing_info.status == "trialing"
# Check if trial is active
if billing_info.trial_end is not None:
now = datetime.now(timezone.utc)
# Trial active if trial_end is in the future
# and subscription status indicates trialing
if billing_info.trial_end > now and billing_info.status == "trialing":
return True
return False

View File

@@ -18,10 +18,10 @@ from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.server.utils import PUBLIC_API_TAGS
from onyx.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -5,7 +5,6 @@ import json
import os
from datetime import datetime
from datetime import timezone
from pathlib import Path
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes
@@ -20,27 +19,21 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
# Path to the license public key file
_LICENSE_PUBLIC_KEY_PATH = (
Path(__file__).parent.parent.parent.parent / "keys" / "license_public_key.pem"
)
# RSA-4096 Public Key for license verification
# Load from environment variable - key is generated on the control plane
# In production, inject via Kubernetes secrets or secrets manager
LICENSE_PUBLIC_KEY_PEM = os.environ.get("LICENSE_PUBLIC_KEY_PEM", "")
def _get_public_key() -> RSAPublicKey:
"""Load the public key from file, with env var override."""
# Allow env var override for flexibility
key_pem = os.environ.get("LICENSE_PUBLIC_KEY_PEM")
if not key_pem:
# Read from file
if not _LICENSE_PUBLIC_KEY_PATH.exists():
raise ValueError(
f"License public key not found at {_LICENSE_PUBLIC_KEY_PATH}. "
"License verification requires the control plane public key."
)
key_pem = _LICENSE_PUBLIC_KEY_PATH.read_text()
key = serialization.load_pem_public_key(key_pem.encode())
"""Load the public key from environment variable."""
if not LICENSE_PUBLIC_KEY_PEM:
raise ValueError(
"LICENSE_PUBLIC_KEY_PEM environment variable not set. "
"License verification requires the control plane public key."
)
key = serialization.load_pem_public_key(LICENSE_PUBLIC_KEY_PEM.encode())
if not isinstance(key, RSAPublicKey):
raise ValueError("Expected RSA public key")
return key
@@ -60,21 +53,17 @@ def verify_license_signature(license_data: str) -> LicensePayload:
ValueError: If license data is invalid or signature verification fails
"""
try:
# Decode the license data
decoded = json.loads(base64.b64decode(license_data))
# Parse into LicenseData to validate structure
license_obj = LicenseData(**decoded)
# IMPORTANT: Use the ORIGINAL payload JSON for signature verification,
# not re-serialized through Pydantic. Pydantic may format fields differently
# (e.g., datetime "+00:00" vs "Z") which would break signature verification.
original_payload = decoded.get("payload", {})
payload_json = json.dumps(original_payload, sort_keys=True)
payload_json = json.dumps(
license_obj.payload.model_dump(mode="json"), sort_keys=True
)
signature_bytes = base64.b64decode(license_obj.signature)
# Verify signature using PSS padding (modern standard)
public_key = _get_public_key()
public_key.verify(
signature_bytes,
payload_json.encode(),
@@ -88,18 +77,16 @@ def verify_license_signature(license_data: str) -> LicensePayload:
return license_obj.payload
except InvalidSignature:
logger.error("[verify_license] FAILED: Signature verification failed")
logger.error("License signature verification failed")
raise ValueError("Invalid license signature")
except json.JSONDecodeError as e:
logger.error(f"[verify_license] FAILED: JSON decode error: {e}")
except json.JSONDecodeError:
logger.error("Failed to decode license JSON")
raise ValueError("Invalid license format: not valid JSON")
except (ValueError, KeyError, TypeError) as e:
logger.error(
f"[verify_license] FAILED: Validation error: {type(e).__name__}: {e}"
)
raise ValueError(f"Invalid license format: {type(e).__name__}: {e}")
logger.error(f"License data validation error: {type(e).__name__}")
raise ValueError(f"Invalid license format: {type(e).__name__}")
except Exception:
logger.exception("[verify_license] FAILED: Unexpected error")
logger.exception("Unexpected error during license verification")
raise ValueError("License verification failed: unexpected error")

View File

@@ -6,7 +6,6 @@ from posthog import Posthog
from ee.onyx.configs.app_configs import MARKETING_POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_DEBUG_LOGS_ENABLED
from ee.onyx.configs.app_configs import POSTHOG_HOST
from onyx.utils.logger import setup_logger
@@ -21,7 +20,7 @@ def posthog_on_error(error: Any, items: Any) -> None:
posthog = Posthog(
project_api_key=POSTHOG_API_KEY,
host=POSTHOG_HOST,
debug=POSTHOG_DEBUG_LOGS_ENABLED,
debug=True,
on_error=posthog_on_error,
)
@@ -34,7 +33,7 @@ if MARKETING_POSTHOG_API_KEY:
marketing_posthog = Posthog(
project_api_key=MARKETING_POSTHOG_API_KEY,
host=POSTHOG_HOST,
debug=POSTHOG_DEBUG_LOGS_ENABLED,
debug=True,
on_error=posthog_on_error,
)

View File

@@ -1,14 +0,0 @@
-----BEGIN PUBLIC KEY-----
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA5DpchQujdxjCwpc4/RQP
Hej6rc3SS/5ENCXL0I8NAfMogel0fqG6PKRhonyEh/Bt3P4q18y8vYzAShwf4b6Q
aS0WwshbvnkjyWlsK0BY4HLBKPkTpes7kaz8MwmPZDeelvGJ7SNv3FvyJR4QsoSQ
GSoB5iTH7hi63TjzdxtckkXoNG+GdVd/koxVDUv2uWcAoWIFTTcbKWyuq2SS/5Sf
xdVaIArqfAhLpnNbnM9OS7lZ1xP+29ZXpHxDoeluz35tJLMNBYn9u0y+puo1kW1E
TOGizlAq5kmEMsTJ55e9ZuyIV3gZAUaUKe8CxYJPkOGt0Gj6e1jHoHZCBJmaq97Y
stKj//84HNBzajaryEZuEfRecJ94ANEjkD8u9cGmW+9VxRe5544zWguP5WMT/nv1
0Q+jkOBW2hkY5SS0Rug4cblxiB7bDymWkaX6+sC0VWd5g6WXp36EuP2T0v3mYuHU
GDEiWbD44ToREPVwE/M07ny8qhLo/HYk2l8DKFt83hXe7ePBnyQdcsrVbQWOO1na
j43OkoU5gOFyOkrk2RmmtCjA8jSnw+tGCTpRaRcshqoWC1MjZyU+8/kDteXNkmv9
/B5VxzYSyX+abl7yAu5wLiUPW8l+mOazzWu0nPkmiA160ArxnRyxbGnmp4dUIrt5
azYku4tQYLSsSabfhcpeiCsCAwEAAQ==
-----END PUBLIC KEY-----

View File

@@ -97,14 +97,10 @@ def get_access_for_documents(
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
"""Returns a list of ACL entries that the user has access to.
This is meant to be used downstream to filter out documents that the user
does not have access to. The user should have access to a document if at
least one entry in the document's ACL matches one entry in the returned set.
NOTE: These strings must be formatted in the same way as the output of
DocumentAccess::to_acl.
"""Returns a list of ACL entries that the user has access to. This is meant to be
used downstream to filter out documents that the user does not have access to. The
user should have access to a document if at least one entry in the document's ACL
matches one entry in the returned set.
"""
if user:
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}

View File

@@ -105,8 +105,6 @@ class DocExternalAccess:
)
# TODO(andrei): First refactor this into a pydantic model, then get rid of
# duplicate fields.
@dataclass(frozen=True, init=False)
class DocumentAccess(ExternalAccess):
# User emails for Onyx users, None indicates admin
@@ -125,11 +123,9 @@ class DocumentAccess(ExternalAccess):
)
def to_acl(self) -> set[str]:
"""Converts the access state to a set of formatted ACL strings.
# the acl's emitted by this function are prefixed by type
# to get the native objects, access the member variables directly
NOTE: When querying for documents, the supplied ACL filter strings must
be formatted in the same way as this function.
"""
acl_set: set[str] = set()
for user_email in self.user_emails:
if user_email:

View File

@@ -11,7 +11,6 @@ from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Protocol
from typing import Tuple
@@ -1457,9 +1456,6 @@ def get_default_admin_user_emails_() -> list[str]:
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
STATE_TOKEN_LIFETIME_SECONDS = 3600
CSRF_TOKEN_KEY = "csrftoken"
CSRF_TOKEN_COOKIE_NAME = "fastapiusersoauthcsrf"
class OAuth2AuthorizeResponse(BaseModel):
@@ -1467,19 +1463,13 @@ class OAuth2AuthorizeResponse(BaseModel):
def generate_state_token(
data: Dict[str, str],
secret: SecretType,
lifetime_seconds: int = STATE_TOKEN_LIFETIME_SECONDS,
data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
) -> str:
data["aud"] = STATE_TOKEN_AUDIENCE
return generate_jwt(data, secret, lifetime_seconds)
def generate_csrf_token() -> str:
return secrets.token_urlsafe(32)
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_onyx_oauth_router(
oauth_client: BaseOAuth2,
@@ -1508,13 +1498,6 @@ def get_oauth_router(
redirect_url: Optional[str] = None,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
*,
csrf_token_cookie_name: str = CSRF_TOKEN_COOKIE_NAME,
csrf_token_cookie_path: str = "/",
csrf_token_cookie_domain: Optional[str] = None,
csrf_token_cookie_secure: Optional[bool] = None,
csrf_token_cookie_httponly: bool = True,
csrf_token_cookie_samesite: Optional[Literal["lax", "strict", "none"]] = "lax",
) -> APIRouter:
"""Generate a router with the OAuth routes."""
router = APIRouter()
@@ -1531,9 +1514,6 @@ def get_oauth_router(
route_name=callback_route_name,
)
if csrf_token_cookie_secure is None:
csrf_token_cookie_secure = WEB_DOMAIN.startswith("https")
@router.get(
"/authorize",
name=f"oauth:{oauth_client.name}.{backend.name}.authorize",
@@ -1541,10 +1521,8 @@ def get_oauth_router(
)
async def authorize(
request: Request,
response: Response,
redirect: bool = Query(False),
scopes: List[str] = Query(None),
) -> Response | OAuth2AuthorizeResponse:
) -> OAuth2AuthorizeResponse:
referral_source = request.cookies.get("referral_source", None)
if redirect_url is not None:
@@ -1554,11 +1532,9 @@ def get_oauth_router(
next_url = request.query_params.get("next", "/")
csrf_token = generate_csrf_token()
state_data: Dict[str, str] = {
"next_url": next_url,
"referral_source": referral_source or "default_referral",
CSRF_TOKEN_KEY: csrf_token,
}
state = generate_state_token(state_data, state_secret)
@@ -1575,31 +1551,6 @@ def get_oauth_router(
authorization_url, {"access_type": "offline", "prompt": "consent"}
)
if redirect:
redirect_response = RedirectResponse(authorization_url, status_code=302)
redirect_response.set_cookie(
key=csrf_token_cookie_name,
value=csrf_token,
max_age=STATE_TOKEN_LIFETIME_SECONDS,
path=csrf_token_cookie_path,
domain=csrf_token_cookie_domain,
secure=csrf_token_cookie_secure,
httponly=csrf_token_cookie_httponly,
samesite=csrf_token_cookie_samesite,
)
return redirect_response
response.set_cookie(
key=csrf_token_cookie_name,
value=csrf_token,
max_age=STATE_TOKEN_LIFETIME_SECONDS,
path=csrf_token_cookie_path,
domain=csrf_token_cookie_domain,
secure=csrf_token_cookie_secure,
httponly=csrf_token_cookie_httponly,
samesite=csrf_token_cookie_samesite,
)
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
@log_function_time(print_only=True)
@@ -1649,33 +1600,7 @@ def get_oauth_router(
try:
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
except jwt.DecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=getattr(
ErrorCode, "ACCESS_TOKEN_DECODE_ERROR", "ACCESS_TOKEN_DECODE_ERROR"
),
)
except jwt.ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=getattr(
ErrorCode,
"ACCESS_TOKEN_ALREADY_EXPIRED",
"ACCESS_TOKEN_ALREADY_EXPIRED",
),
)
cookie_csrf_token = request.cookies.get(csrf_token_cookie_name)
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
if (
not cookie_csrf_token
or not state_csrf_token
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=getattr(ErrorCode, "OAUTH_INVALID_STATE", "OAUTH_INVALID_STATE"),
)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)

View File

@@ -1,98 +0,0 @@
# Overview of Onyx Background Jobs
The background jobs take care of:
1. Pulling/Indexing documents (from connectors)
2. Updating document metadata (from connectors)
3. Cleaning up checkpoints and logic around indexing work (indexing indexing checkpoints and index attempt metadata)
4. Handling user uploaded files and deletions (from the Projects feature and uploads via the Chat)
5. Reporting metrics on things like queue length for monitoring purposes
## Worker → Queue Mapping
| Worker | File | Queues |
|--------|------|--------|
| Primary | `apps/primary.py` | `celery` |
| Light | `apps/light.py` | `vespa_metadata_sync`, `connector_deletion`, `doc_permissions_upsert`, `checkpoint_cleanup`, `index_attempt_cleanup` |
| Heavy | `apps/heavy.py` | `connector_pruning`, `connector_doc_permissions_sync`, `connector_external_group_sync`, `csv_generation`, `sandbox` |
| Docprocessing | `apps/docprocessing.py` | `docprocessing` |
| Docfetching | `apps/docfetching.py` | `connector_doc_fetching` |
| User File Processing | `apps/user_file_processing.py` | `user_file_processing`, `user_file_project_sync`, `user_file_delete` |
| Monitoring | `apps/monitoring.py` | `monitoring` |
| Background (consolidated) | `apps/background.py` | All queues above except `celery` |
## Non-Worker Apps
| App | File | Purpose |
|-----|------|---------|
| **Beat** | `beat.py` | Celery beat scheduler with `DynamicTenantScheduler` that generates per-tenant periodic task schedules |
| **Client** | `client.py` | Minimal app for task submission from non-worker processes (e.g., API server) |
### Shared Module
`app_base.py` provides:
- `TenantAwareTask` - Base task class that sets tenant context
- Signal handlers for logging, cleanup, and lifecycle events
- Readiness probes and health checks
## Worker Details
### Primary (Coordinator and task dispatcher)
It is the single worker which handles tasks from the default celery queue. It is a singleton worker ensured by the `PRIMARY_WORKER` Redis lock
which it touches every `CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8` seconds (using Celery Bootsteps)
On startup:
- waits for redis, postgres, document index to all be healthy
- acquires the singleton lock
- cleans all the redis states associated with background jobs
- mark orphaned index attempts failed
Then it cycles through its tasks as scheduled by Celery Beat:
| Task | Frequency | Description |
|------|-----------|-------------|
| `check_for_indexing` | 15s | Scans for connectors needing indexing → dispatches to `DOCFETCHING` queue |
| `check_for_vespa_sync_task` | 20s | Finds stale documents/document sets → dispatches sync tasks to `VESPA_METADATA_SYNC` queue |
| `check_for_pruning` | 20s | Finds connectors due for pruning → dispatches to `CONNECTOR_PRUNING` queue |
| `check_for_connector_deletion` | 20s | Processes deletion requests → dispatches to `CONNECTOR_DELETION` queue |
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
| `kombu_message_cleanup_task` | periodic | Cleans orphaned Kombu messages from DB (Kombu being the messaging framework used by Celery) |
| `celery_beat_heartbeat` | 1m | Heartbeat for Beat watchdog |
Watchdog is a separate Python process managed by supervisord which runs alongside celery workers. It checks the ONYX_CELERY_BEAT_HEARTBEAT_KEY in
Redis to ensure Celery Beat is not dead. Beat schedules the celery_beat_heartbeat for Primary to touch the key and share that it's still alive.
See supervisord.conf for watchdog config.
### Light
Fast and short living tasks that are not resource intensive. High concurrency:
Can have 24 concurrent workers, each with a prefetch of 8 for a total of 192 tasks in flight at once.
Tasks it handles:
- Syncs access/permissions, document sets, boosts, hidden state
- Deletes documents that are marked for deletion in Postgres
- Cleanup of checkpoints and index attempts
### Heavy
Long running, resource intensive tasks, handles pruning and sandbox operations. Low concurrency - max concurrency of 4 with 1 prefetch.
Does not interact with the Document Index, it handles the syncs with external systems. Large volume API calls to handle pruning and fetching permissions, etc.
Generates CSV exports which may take a long time with significant data in Postgres.
Sandbox (new feature) for running Next.js, Python virtual env, OpenCode AI Agent, and access to knowledge files
### Docprocessing, Docfetching, User File Processing
Docprocessing and Docfetching are for indexing documents:
- Docfetching runs connectors to pull documents from external APIs (Google Drive, Confluence, etc.), stores batches to file storage, and dispatches docprocessing tasks
- Docprocessing retrieves batches, runs the indexing pipeline (chunking, embedding), and indexes into the Document Index
User Files come from uploads directly via the input bar
### Monitoring
Observability and metrics collections:
- Queue lengths, connector success/failure, lconnector latencies
- Memory of supervisor managed processes (workers, beat, slack)
- Cloud and multitenant specific monitorings

View File

@@ -26,13 +26,10 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.celery_utils import make_probe_path
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ENABLE_OPENSEARCH_FOR_ONYX
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.document_index.opensearch.client import (
wait_for_opensearch_with_timeout,
)
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector import RedisConnector
@@ -43,7 +40,6 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.tracing.braintrust_tracing import setup_braintrust_if_creds_available
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import PlainFormatter
@@ -238,9 +234,6 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
f"Multiprocessing selected start method: {multiprocessing.get_start_method()}"
)
# Initialize Braintrust tracing in workers if credentials are available.
setup_braintrust_if_creds_available()
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
"""Waits for redis to become ready subject to a hardcoded timeout.
@@ -523,17 +516,14 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""
if ENABLE_OPENSEARCH_FOR_ONYX:
return
if not wait_for_vespa_with_timeout():
msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)
raise WorkerShutdown(msg)
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
if not wait_for_opensearch_with_timeout():
msg = "[OpenSearch] Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)
raise WorkerShutdown(msg)
# File for validating worker liveness
class LivenessProbe(bootsteps.StartStopStep):

View File

@@ -121,9 +121,9 @@ celery_app.autodiscover_tasks(
[
# Original background worker tasks
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.kg_processing",
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.user_file_processing",
"onyx.background.celery.tasks.llm_model_update",
# Light worker tasks
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
@@ -133,7 +133,5 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.docprocessing",
# Docfetching worker tasks
"onyx.background.celery.tasks.docfetching",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)

View File

@@ -98,7 +98,5 @@ for bootstep in base_bootsteps:
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.pruning",
# Sandbox tasks (file sync, cleanup)
"onyx.server.features.build.sandbox.tasks",
]
)

View File

@@ -0,0 +1,109 @@
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.kg_processing")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME)
pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.kg_processing",
]
)

View File

@@ -116,7 +116,5 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.docprocessing",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)

View File

@@ -323,6 +323,7 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.kg_processing",
"onyx.background.celery.tasks.user_file_processing",
]
)

View File

@@ -21,7 +21,6 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -33,16 +32,10 @@ PRUNING_CHECKPOINTED_BATCH_SIZE = 32
def document_batch_to_ids(
doc_batch: (
Iterator[list[Document | HierarchyNode]]
| Iterator[list[SlimDocument | HierarchyNode]]
),
doc_batch: Iterator[list[Document]] | Iterator[list[SlimDocument]],
) -> Generator[set[str], None, None]:
for doc_list in doc_batch:
yield {
doc.raw_node_id if isinstance(doc, HierarchyNode) else doc.id
for doc in doc_list
}
yield {doc.id for doc in doc_list}
def extract_ids_from_runnable_connector(

Some files were not shown because too many files have changed in this diff Show More