Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 07:45:47 +00:00)

Compare commits: search-2...thread_sen (6 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 5848975679 | |
| | dcc330010e | |
| | d0f5f1f5ae | |
| | 3e475993ff | |
| | 7c2b5fa822 | |
| | 409cfdc788 | |
1 .github/pull_request_template.md vendored

@@ -8,5 +8,4 @@

## Additional Options

- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
- [ ] [Optional] Override Linear Check
644 .github/workflows/deployment.yml vendored

File diff suppressed because it is too large
2 .github/workflows/docker-tag-beta.yml vendored

@@ -21,7 +21,7 @@ jobs:
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
2 .github/workflows/docker-tag-latest.yml vendored

@@ -21,7 +21,7 @@ jobs:
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
1 .github/workflows/helm-chart-releases.yml vendored

@@ -29,7 +29,6 @@ jobs:
run: |
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add opensearch https://opensearch-project.github.io/helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/

@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # ratchet:actions/stale@v10
- uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # ratchet:actions/stale@v10
with:
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
2 .github/workflows/nightly-scan-licenses.yml vendored

@@ -94,7 +94,7 @@ jobs:

steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
28 .github/workflows/pr-beta-cherrypick-check.yml vendored

@@ -1,28 +0,0 @@
name: Require beta cherry-pick consideration
concurrency:
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true

on:
pull_request:
types: [opened, edited, reopened, synchronize]

permissions:
contents: read

jobs:
beta-cherrypick-check:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Check PR body for beta cherry-pick consideration
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
echo "Cherry-pick consideration box is checked. Check passed."
exit 0
fi

echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
exit 1
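For reference, the check that this removed workflow performed can be reproduced locally against a sample PR description. This is a minimal sketch; the `PR_BODY` value below is made up for illustration and the regex is written with single backslashes since it is plain bash rather than YAML:

```bash
# Simulate the removed workflow's grep check against a hypothetical PR body.
PR_BODY='- [x] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.'

if echo "$PR_BODY" | grep -qiE "\[x\][[:space:]]*\[Required\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
  echo "Cherry-pick consideration box is checked. Check passed."
else
  echo "Box not checked; the workflow would have failed here."
fi
```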
@@ -45,9 +45,6 @@ env:
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000

# OpenSearch
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"

jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -128,13 +125,11 @@ jobs:
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
-f docker-compose.opensearch.yml \
up -d \
minio \
relational_db \
cache \
index \
opensearch \
code-interpreter

- name: Run migrations
@@ -163,7 +158,7 @@ jobs:
cd deployment/docker_compose

# Get list of running containers
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)

# Collect logs from each container
for container in $containers; do
@@ -177,7 +172,7 @@ jobs:

- name: Upload Docker logs
if: failure()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v5
with:
name: docker-logs-${{ matrix.test-dir }}
path: docker-logs/
8 .github/workflows/pr-helm-chart-testing.yml vendored

@@ -88,7 +88,6 @@ jobs:
echo "=== Adding Helm repositories ==="
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add opensearch https://opensearch-project.github.io/helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
@@ -181,11 +180,6 @@ jobs:
trap cleanup EXIT

# Run the actual installation with detailed logging
# Note that opensearch.enabled is true whereas others in this install
# are false. There is some work that needs to be done to get this
# entire step working in CI, enabling opensearch here is a small step
# in that direction. If this is causing issues, disabling it in this
# step should be ok in the short term.
echo "=== Starting ct install ==="
set +e
ct install --all \
@@ -193,8 +187,6 @@ jobs:
--set=nginx.enabled=false \
--set=minio.enabled=false \
--set=vespa.enabled=false \
--set=opensearch.enabled=true \
--set=auth.opensearch.enabled=true \
--set=slackbot.enabled=false \
--set=postgresql.enabled=true \
--set=postgresql.nameOverride=cloudnative-pg \
13 .github/workflows/pr-integration-tests.yml vendored

@@ -103,7 +103,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -163,7 +163,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -208,7 +208,7 @@ jobs:
persist-credentials: false

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
@@ -310,9 +310,8 @@ jobs:
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
MCP_SERVER_ENABLED=true
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
EOF

- name: Start Docker containers
@@ -439,7 +438,7 @@ jobs:

- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
@@ -568,7 +567,7 @@ jobs:

- name: Upload logs (multi-tenant)
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: docker-all-logs-multitenant
path: ${{ github.workspace }}/docker-compose-multitenant.log
2 .github/workflows/pr-jest-tests.yml vendored

@@ -44,7 +44,7 @@ jobs:

- name: Upload coverage reports
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: jest-coverage-${{ github.run_id }}
path: ./web/coverage
10 .github/workflows/pr-mit-integration-tests.yml vendored

@@ -95,7 +95,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -155,7 +155,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -214,7 +214,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
@@ -301,7 +301,7 @@ jobs:
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
MCP_SERVER_ENABLED=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
EOF

- name: Start Docker containers
@@ -424,7 +424,7 @@ jobs:

- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
10 .github/workflows/pr-playwright-tests.yml vendored

@@ -85,7 +85,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -146,7 +146,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -207,7 +207,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -435,7 +435,7 @@ jobs:
fi
npx playwright test --project ${PROJECT}

- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
if: always()
with:
# Includes test results and trace.zip files
@@ -455,7 +455,7 @@ jobs:

- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
with:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
3 .github/workflows/pr-python-checks.yml vendored

@@ -50,9 +50,8 @@ jobs:
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: backend/.mypy_cache
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
restore-keys: |
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
mypy-${{ runner.os }}-

- name: Run MyPy
140 .github/workflows/pr-python-model-tests.yml vendored
@@ -5,6 +5,11 @@ on:
|
||||
# This cron expression runs the job daily at 16:00 UTC (9am PT)
|
||||
- cron: "0 16 * * *"
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
branch:
|
||||
description: 'Branch to run the workflow on'
|
||||
required: false
|
||||
default: 'main'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -26,11 +31,7 @@ env:
|
||||
jobs:
|
||||
model-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-model-check"
|
||||
- "extras=ecr-cache"
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}-model-check"]
|
||||
timeout-minutes: 45
|
||||
|
||||
env:
|
||||
@@ -42,87 +43,108 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Python and Install Dependencies
|
||||
uses: ./.github/actions/setup-python-and-install-dependencies
|
||||
with:
|
||||
requirements: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
- name: Build and load
|
||||
uses: docker/bake-action@5be5f02ff8819ecd3092ea6b2e6261c31774f2b4 # ratchet:docker/bake-action@v6
|
||||
env:
|
||||
TAG: model-server-${{ github.run_id }}
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Model Server Docker image
|
||||
run: |
|
||||
docker pull onyxdotapp/onyx-model-server:latest
|
||||
docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # ratchet:actions/setup-python@v6
|
||||
with:
|
||||
load: true
|
||||
targets: model-server
|
||||
set: |
|
||||
model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
|
||||
model-server.cache-from=type=registry,ref=onyxdotapp/onyx-model-server:latest
|
||||
model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
|
||||
- name: Start Docker containers
|
||||
id: start_docker
|
||||
env:
|
||||
IMAGE_TAG: model-server-${{ github.run_id }}
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose \
|
||||
-f docker-compose.yml \
|
||||
-f docker-compose.dev.yml \
|
||||
up -d --wait \
|
||||
inference_model_server
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.model-server-test.yml up -d indexing_model_server
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Use curl with error handling to ignore specific exit code 56
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:9000/api/health || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/llm
|
||||
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/embedding
|
||||
|
||||
- name: Alert on Failure
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: model-check
|
||||
title: "🚨 Scheduled Model Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
env:
|
||||
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
|
||||
REPO: ${{ github.repository }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
curl -X POST \
|
||||
-H 'Content-type: application/json' \
|
||||
--data "{\"text\":\"Scheduled Model Tests failed! Check the run at: https://github.com/${REPO}/actions/runs/${RUN_ID}\"}" \
|
||||
$SLACK_WEBHOOK
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
docker compose -f docker-compose.model-server-test.yml logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
4 .gitignore vendored

@@ -1,8 +1,5 @@
# editors
.vscode
!/.vscode/env_template.txt
!/.vscode/launch.json
!/.vscode/tasks.template.jsonc
.zed
.cursor

@@ -24,7 +21,6 @@ backend/tests/regression/search_quality/*.json
backend/onyx/evals/data/
backend/onyx/evals/one_off/*.json
*.log
*.csv

# secret files
.env
@@ -11,6 +11,7 @@ repos:
- id: uv-sync
args: ["--locked", "--all-extras"]
- id: uv-lock
files: ^pyproject\.toml$
- id: uv-export
name: uv-export default.txt
args:
@@ -66,8 +67,7 @@ repos:
- id: uv-run
name: Check lazy imports
args: ["--active", "--with=onyx-devtools", "ods", "check-lazy-imports"]
pass_filenames: true
files: ^backend/(?!\.venv/|scripts/).*\.py$
files: ^backend/(?!\.venv/).*\.py$
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
# - id: uv-run
# name: mypy
@@ -75,13 +75,6 @@ repos:
# pass_filenames: true
# files: ^backend/.*\.py$

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
hooks:
- id: check-added-large-files
name: Check for added large files
args: ["--maxkb=1500"]

- repo: https://github.com/rhysd/actionlint
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
hooks:
@@ -154,22 +147,6 @@ repos:
pass_filenames: false
files: \.tf$

- id: npm-install
name: npm install
description: "Automatically run 'npm install' after a checkout, pull or rebase"
language: system
entry: bash -c 'cd web && npm install --no-save'
pass_filenames: false
files: ^web/package(-lock)?\.json$
stages: [post-checkout, post-merge, post-rewrite]
- id: npm-install-check
name: npm install --package-lock-only
description: "Check the 'web/package-lock.json' is updated"
language: system
entry: bash -c 'cd web && npm install --package-lock-only'
pass_filenames: false
files: ^web/package(-lock)?\.json$

# Uses tsgo (TypeScript's native Go compiler) for ~10x faster type checking.
# This is a preview package - if it breaks:
# 1. Try updating: cd web && npm update @typescript/native-preview
6 .vscode/env_template.txt vendored

@@ -17,6 +17,12 @@ LOG_ONYX_MODEL_INTERACTIONS=True
LOG_LEVEL=debug


# This passes top N results to LLM an additional time for reranking prior to
# answer generation.
# This step is quite heavy on token usage so we disable it for dev generally.
DISABLE_LLM_DOC_RELEVANCE=False


# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically).
OAUTH_CLIENT_ID=<REPLACE THIS>
OAUTH_CLIENT_SECRET=<REPLACE THIS>
155 .vscode/launch.json → .vscode/launch.template.jsonc vendored
@@ -1,3 +1,5 @@
|
||||
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */
|
||||
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
@@ -22,7 +24,7 @@
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery heavy",
|
||||
"Celery background",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat"
|
||||
@@ -149,24 +151,6 @@
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Discord Bot",
|
||||
"consoleName": "Discord Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/discord/client.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Discord Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "MCP Server",
|
||||
"consoleName": "MCP Server",
|
||||
@@ -415,6 +399,7 @@
|
||||
"onyx.background.celery.versioned_apps.docfetching",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docfetching@%n",
|
||||
@@ -445,6 +430,7 @@
|
||||
"onyx.background.celery.versioned_apps.docprocessing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=6",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docprocessing@%n",
|
||||
@@ -593,137 +579,6 @@
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Build Sandbox Templates",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "onyx.server.features.build.sandbox.build_templates",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
},
|
||||
"consoleTitle": "Build Sandbox Templates"
|
||||
},
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Database ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "4",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Restore seeded database dump",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"restore",
|
||||
"--fetch-seeded",
|
||||
"--yes"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clean restore seeded database dump (destructive)",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"restore",
|
||||
"--fetch-seeded",
|
||||
"--clean",
|
||||
"--yes"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Create database snapshot",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"dump",
|
||||
"backup.dump"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clean restore database snapshot (destructive)",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"restore",
|
||||
"--clean",
|
||||
"--yes",
|
||||
"backup.dump"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Upgrade database to head revision",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"upgrade"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
// script to generate the openapi schema
|
||||
"name": "Onyx OpenAPI Schema Generator",
|
||||
263 CONTRIBUTING.md

@@ -1,31 +1,262 @@
<!-- ONYX_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/CONTRIBUTING.md"} -->

# Contributing to Onyx

Hey there! We are so excited that you're interested in Onyx.

As an open source project in a rapidly changing space, we welcome all contributions.

## Contribution Opportunities
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to look for and share contribution ideas.
## 💃 Guidelines

If you have your own feature that you would like to build please create an issue and community members can provide feedback and
thumb it up if they feel a common need.
### Contribution Opportunities

The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.

## Contributing Code
Please reference the documents in contributing_guides folder to ensure that the code base is kept to a high standard.
1. dev_setup.md (start here): gives you a guide to setting up a local development environment.
2. contribution_process.md: how to ensure you are building valuable features that will get reviewed and merged.
3. best_practices.md: before asking for reviews, ensure your changes meet the repo code quality standards.
To ensure that your contribution is aligned with the project's direction, please reach out to any maintainer on the Onyx team
via [Discord](https://discord.gg/4NA5SbzrWb) or [email](mailto:hello@onyx.app).

To contribute, please follow the
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
will be marked with the `approved by maintainers` label.
Issues marked `good first issue` are an especially great place to start.

**Connectors** to other tools are another great place to contribute. For details on how, refer to this
[README.md](https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/connectors/README.md).

If you have a new/different contribution in mind, we'd love to hear about it!
Your input is vital to making sure that Onyx moves in the right direction.
Before starting on implementation, please raise a GitHub issue.

Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
[Discord](https://discord.gg/4NA5SbzrWb) directly about anything at all.

### Contributing Code

To contribute to this project, please follow the
["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow.
When opening a pull request, mention related issues and feel free to tag relevant maintainers.

Before creating a pull request please make sure that the new changes conform to the formatting and linting requirements.
See the [Formatting and Linting](#formatting-and-linting) section for how to run these checks locally.

### Getting Help 🙋

Our goal is to make contributing as easy as possible. If you run into any issues please don't hesitate to reach out.
That way we can help future contributors and users can avoid the same issue.

We also have support channels and generally interesting discussions on our
[Discord](https://discord.gg/4NA5SbzrWb).

We would love to see you there!

## Get Started 🚀

Onyx being a fully functional app, relies on some external software, specifically:

- [Postgres](https://www.postgresql.org/) (Relational DB)
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
- [Redis](https://redis.io/) (Cache)
- [MinIO](https://min.io/) (File Store)
- [Nginx](https://nginx.org/) (Not needed for development flows generally)

> **Note:**
> This guide provides instructions to build and run Onyx locally from source with Docker containers providing the above external software. We believe this combination is easier for
> development purposes. If you prefer to use pre-built container images, we provide instructions on running the full Onyx stack within Docker below.

### Local Set Up

Be sure to use Python version 3.11. For instructions on installing Python 3.11 on macOS, refer to the [CONTRIBUTING_MACOS.md](./CONTRIBUTING_MACOS.md) readme.

If using a lower version, modifications will have to be made to the code.
If using a higher version, sometimes some libraries will not be available (i.e. we had problems with Tensorflow in the past with higher versions of python).

#### Backend: Python requirements

Currently, we use [uv](https://docs.astral.sh/uv/) and recommend creating a [virtual environment](https://docs.astral.sh/uv/pip/environments/#using-a-virtual-environment).

For convenience here's a command for it:

```bash
uv venv .venv --python 3.11
source .venv/bin/activate
```

_For Windows, activate the virtual environment using Command Prompt:_

```bash
.venv\Scripts\activate
```

If using PowerShell, the command slightly differs:

```powershell
.venv\Scripts\Activate.ps1
```

Install the required python dependencies:

```bash
uv sync --all-extras
```

Install Playwright for Python (headless browser required by the Web Connector):

```bash
uv run playwright install
```

#### Frontend: Node dependencies

Onyx uses Node v22.20.0. We highly recommend you use [Node Version Manager (nvm)](https://github.com/nvm-sh/nvm)
to manage your Node installations. Once installed, you can run

```bash
nvm install 22 && nvm use 22
node -v # verify your active version
```

Navigate to `onyx/web` and run:

```bash
npm i
```

## Formatting and Linting

### Backend

For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).

Then run:

```bash
uv run pre-commit install
```

Additionally, we use `mypy` for static type checking.
Onyx is fully type-annotated, and we want to keep it that way!
To run the mypy checks manually, run `uv run mypy .` from the `onyx/backend` directory.
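For example, with the virtual environment from the setup steps above active:

```bash
cd backend
uv run mypy .
```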
### Web

We use `prettier` for formatting. The desired version will be installed via a `npm i` from the `onyx/web` directory.
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.

Pre-commit will also run prettier automatically on files you've recently touched. If re-formatted, your commit will fail.
Re-stage your changes and commit again.
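A typical sequence after prettier rewrites files during a commit might look like this (a minimal sketch; the commit message is just an illustration):

```bash
cd web
npx prettier --write .   # or let the pre-commit hook reformat the files for you
git add -u               # re-stage the files prettier touched
git commit -m "Fix formatting"
```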
# Running the application for development

## Developing using VSCode Debugger (recommended)

**We highly recommend using VSCode debugger for development.**
See [CONTRIBUTING_VSCODE.md](./CONTRIBUTING_VSCODE.md) for more details.

Otherwise, you can follow the instructions below to run the application for development.

## Manually running the application for development
### Docker containers for external software

You will need Docker installed to run these containers.

First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:

```bash
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d index relational_db cache minio
```

(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)

### Running Onyx locally

To start the frontend, navigate to `onyx/web` and run:

```bash
npm run dev
```

Next, start the model server which runs the local NLP models.
Navigate to `onyx/backend` and run:

```bash
uvicorn model_server.main:app --reload --port 9000
```

_For Windows (for compatibility with both PowerShell and Command Prompt):_

```bash
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
```

The first time running Onyx, you will need to run the DB migrations for Postgres.
After the first time, this is no longer required unless the DB models change.

Navigate to `onyx/backend` and with the venv active, run:

```bash
alembic upgrade head
```

Next, start the task queue which orchestrates the background jobs.
Jobs that take more time are run async from the API server.

Still in `onyx/backend`, run:

```bash
python ./scripts/dev_run_background_jobs.py
```

To run the backend API server, navigate back to `onyx/backend` and run:

```bash
AUTH_TYPE=disabled uvicorn onyx.main:app --reload --port 8080
```

_For Windows (for compatibility with both PowerShell and Command Prompt):_

```bash
powershell -Command "
$env:AUTH_TYPE='disabled'
uvicorn onyx.main:app --reload --port 8080
"
```

> **Note:**
> If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
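For example, to run the API server with debug logging, a sketch building on the command above:

```bash
LOG_LEVEL=DEBUG AUTH_TYPE=disabled uvicorn onyx.main:app --reload --port 8080
```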
#### Wrapping up

You should now have 4 servers running:

- Web server
- Backend API
- Model server
- Background jobs

Now, visit `http://localhost:3000` in your browser. You should see the Onyx onboarding wizard where you can connect your external LLM provider to Onyx.

You've successfully set up a local Onyx instance! 🏁

#### Running the Onyx application in a container

You can run the full Onyx application stack from pre-built images including all external software dependencies.

Navigate to `onyx/deployment/docker_compose` and run:

```bash
docker compose up -d
```

After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Onyx.

If you want to make changes to Onyx and run those changes in Docker, you can also build a local version of the Onyx container images that incorporates your changes like so:

```bash
docker compose up -d --build
```


## Getting Help 🙋
We have support channels and generally interesting discussions on our [Discord](https://discord.gg/4NA5SbzrWb).
### Release Process

See you there!


## Release Process
Onyx loosely follows the SemVer versioning standard.
Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
A set of Docker containers will be pushed automatically to DockerHub with every tag.
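As an illustration of how a release tag kicks off those image pushes, pushing a tag to the repository is enough; the version number and tag format here are hypothetical:

```bash
# Tag the release commit and push the tag; CI builds and pushes the Docker images.
git tag v1.2.3
git push origin v1.2.3
```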
@@ -7,6 +7,8 @@ This guide explains how to set up and use VSCode's debugging capabilities with t
1. **Environment Setup**:
- Copy `.vscode/env_template.txt` to `.vscode/.env`
- Fill in the necessary environment variables in `.vscode/.env`
2. **launch.json**:
- Copy `.vscode/launch.template.jsonc` to `.vscode/launch.json`

## Using the Debugger
@@ -16,8 +16,3 @@ dist/
.coverage
htmlcov/
model_server/legacy/

# Craft: demo_data directory should be unzipped at container startup, not copied
**/demo_data/
# Craft: templates/outputs/venv is created at container startup
**/templates/outputs/venv
@@ -37,6 +37,10 @@ CVE-2023-50868
CVE-2023-52425
CVE-2024-28757

# sqlite, only used by NLTK library to grab word lemmatizer and stopwords
# No impact in our settings
CVE-2023-7104

# libharfbuzz0b, O(n^2) growth, worst case is denial of service
# Accept the risk
CVE-2023-25193
@@ -7,10 +7,6 @@ have a contract or agreement with DanswerAI, you are not permitted to use the En
Edition features outside of personal development or testing purposes. Please reach out to \
founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx"

# Build argument for Craft support (disabled by default)
# Use --build-arg ENABLE_CRAFT=true to include Node.js and opencode CLI
ARG ENABLE_CRAFT=false

# DO_NOT_TRACK is used to disable telemetry for Unstructured
ENV DANSWER_RUNNING_IN_DOCKER="true" \
DO_NOT_TRACK="true" \
@@ -50,23 +46,7 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean

# Conditionally install Node.js 20 for Craft (required for Next.js)
# Only installed when ENABLE_CRAFT=true
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
echo "Installing Node.js 20 for Craft support..." && \
curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
apt-get install -y nodejs && \
rm -rf /var/lib/apt/lists/*; \
fi

# Conditionally install opencode CLI for Craft agent functionality
# Only installed when ENABLE_CRAFT=true
# TODO: download a specific, versioned release of the opencode CLI
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
echo "Installing opencode CLI for Craft support..." && \
curl -fsSL https://opencode.ai/install | bash; \
fi
ENV PATH="/root/.opencode/bin:${PATH}"

# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
@@ -111,8 +91,8 @@ Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"

# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \
nltk.download('punkt_tab', quiet=True);"
nltk.download('stopwords', quiet=True); \
nltk.download('punkt_tab', quiet=True);"
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed

# Pre-downloading tiktoken for setups with limited egress
@@ -139,15 +119,7 @@ COPY --chown=onyx:onyx ./static /app/static
COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh

# Run Craft template setup at build time when ENABLE_CRAFT=true
# This pre-bakes demo data, Python venv, and npm dependencies into the image
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
echo "Running Craft template setup at build time..." && \
ENABLE_CRAFT=true /app/scripts/setup_craft_templates.sh; \
fi
RUN chmod +x /app/scripts/supervisord_entrypoint.sh

# Put logo in assets
COPY --chown=onyx:onyx ./assets /app/assets
@@ -225,6 +225,7 @@ def do_run_migrations(
) -> None:
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))

connection.execute(text(f'SET search_path TO "{schema_name}"'))

@@ -308,7 +309,6 @@ async def run_async_migrations() -> None:
schema_name=schema,
create_schema=create_schema,
)
await connection.commit()
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
if not continue_on_error:
@@ -346,7 +346,6 @@ async def run_async_migrations() -> None:
schema_name=schema,
create_schema=create_schema,
)
await connection.commit()
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
if not continue_on_error:
@@ -1,351 +0,0 @@
|
||||
"""single onyx craft migration
|
||||
|
||||
Consolidates all buildmode/onyx craft tables into a single migration.
|
||||
|
||||
Tables created:
|
||||
- build_session: User build sessions with status tracking
|
||||
- sandbox: User-owned containerized environments (one per user)
|
||||
- artifact: Build output files (web apps, documents, images)
|
||||
- snapshot: Sandbox filesystem snapshots
|
||||
- build_message: Conversation messages for build sessions
|
||||
|
||||
Existing table modified:
|
||||
- connector_credential_pair: Added processing_mode column
|
||||
|
||||
Revision ID: 2020d417ec84
|
||||
Revises: 41fa44bef321
|
||||
Create Date: 2026-01-26 14:43:54.641405
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2020d417ec84"
|
||||
down_revision = "41fa44bef321"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ==========================================================================
|
||||
# ENUMS
|
||||
# ==========================================================================
|
||||
|
||||
# Build session status enum
|
||||
build_session_status_enum = sa.Enum(
|
||||
"active",
|
||||
"idle",
|
||||
name="buildsessionstatus",
|
||||
native_enum=False,
|
||||
)
|
||||
|
||||
# Sandbox status enum
|
||||
sandbox_status_enum = sa.Enum(
|
||||
"provisioning",
|
||||
"running",
|
||||
"idle",
|
||||
"sleeping",
|
||||
"terminated",
|
||||
"failed",
|
||||
name="sandboxstatus",
|
||||
native_enum=False,
|
||||
)
|
||||
|
||||
# Artifact type enum
|
||||
artifact_type_enum = sa.Enum(
|
||||
"web_app",
|
||||
"pptx",
|
||||
"docx",
|
||||
"markdown",
|
||||
"excel",
|
||||
"image",
|
||||
name="artifacttype",
|
||||
native_enum=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# BUILD_SESSION TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.create_table(
|
||||
"build_session",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
build_session_status_enum,
|
||||
nullable=False,
|
||||
server_default="active",
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"last_activity_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("nextjs_port", sa.Integer(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_build_session_user_created",
|
||||
"build_session",
|
||||
["user_id", sa.text("created_at DESC")],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_build_session_status",
|
||||
"build_session",
|
||||
["status"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# SANDBOX TABLE (user-owned, one per user)
|
||||
# ==========================================================================
|
||||
|
||||
op.create_table(
|
||||
"sandbox",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("container_id", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
sandbox_status_enum,
|
||||
nullable=False,
|
||||
server_default="provisioning",
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("last_heartbeat", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("user_id", name="sandbox_user_id_key"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_sandbox_status",
|
||||
"sandbox",
|
||||
["status"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_sandbox_container_id",
|
||||
"sandbox",
|
||||
["container_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# ARTIFACT TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.create_table(
|
||||
"artifact",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"session_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("type", artifact_type_enum, nullable=False),
|
||||
sa.Column("path", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_artifact_session_created",
|
||||
"artifact",
|
||||
["session_id", sa.text("created_at DESC")],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_artifact_type",
|
||||
"artifact",
|
||||
["type"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# SNAPSHOT TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.create_table(
|
||||
"snapshot",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"session_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("storage_path", sa.String(), nullable=False),
|
||||
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_snapshot_session_created",
|
||||
"snapshot",
|
||||
["session_id", sa.text("created_at DESC")],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# BUILD_MESSAGE TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.create_table(
|
||||
"build_message",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"session_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"turn_index",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"type",
|
||||
sa.Enum(
|
||||
"SYSTEM",
|
||||
"USER",
|
||||
"ASSISTANT",
|
||||
"DANSWER",
|
||||
name="messagetype",
|
||||
create_type=False,
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"message_metadata",
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_build_message_session_turn",
|
||||
"build_message",
|
||||
["session_id", "turn_index", sa.text("created_at ASC")],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# CONNECTOR_CREDENTIAL_PAIR MODIFICATION
|
||||
# ==========================================================================
|
||||
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"processing_mode",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
server_default="regular",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ==========================================================================
|
||||
# CONNECTOR_CREDENTIAL_PAIR MODIFICATION
|
||||
# ==========================================================================
|
||||
|
||||
op.drop_column("connector_credential_pair", "processing_mode")
|
||||
|
||||
# ==========================================================================
|
||||
# BUILD_MESSAGE TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.drop_index("ix_build_message_session_turn", table_name="build_message")
|
||||
op.drop_table("build_message")
|
||||
|
||||
# ==========================================================================
|
||||
# SNAPSHOT TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.drop_index("ix_snapshot_session_created", table_name="snapshot")
|
||||
op.drop_table("snapshot")
|
||||
|
||||
# ==========================================================================
|
||||
# ARTIFACT TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.drop_index("ix_artifact_type", table_name="artifact")
|
||||
op.drop_index("ix_artifact_session_created", table_name="artifact")
|
||||
op.drop_table("artifact")
|
||||
sa.Enum(name="artifacttype").drop(op.get_bind(), checkfirst=True)
|
||||
|
||||
# ==========================================================================
|
||||
# SANDBOX TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.drop_index("ix_sandbox_container_id", table_name="sandbox")
|
||||
op.drop_index("ix_sandbox_status", table_name="sandbox")
|
||||
op.drop_table("sandbox")
|
||||
sa.Enum(name="sandboxstatus").drop(op.get_bind(), checkfirst=True)
|
||||
|
||||
# ==========================================================================
|
||||
# BUILD_SESSION TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.drop_index("ix_build_session_status", table_name="build_session")
|
||||
op.drop_index("ix_build_session_user_created", table_name="build_session")
|
||||
op.drop_table("build_session")
|
||||
sa.Enum(name="buildsessionstatus").drop(op.get_bind(), checkfirst=True)
|
||||
@@ -1,42 +0,0 @@
|
||||
"""add_unique_constraint_to_inputprompt_prompt_user_id
|
||||
|
||||
Revision ID: 2c2430828bdf
|
||||
Revises: fb80bdd256de
|
||||
Create Date: 2026-01-20 16:01:54.314805
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2c2430828bdf"
|
||||
down_revision = "fb80bdd256de"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create unique constraint on (prompt, user_id) for user-owned prompts
|
||||
# This ensures each user can only have one shortcut with a given name
|
||||
op.create_unique_constraint(
|
||||
"uq_inputprompt_prompt_user_id",
|
||||
"inputprompt",
|
||||
["prompt", "user_id"],
|
||||
)
|
||||
|
||||
# Create partial unique index for public prompts (where user_id IS NULL)
|
||||
# PostgreSQL unique constraints don't enforce uniqueness for NULL values,
|
||||
# so we need a partial index to ensure public prompt names are also unique
|
||||
op.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX uq_inputprompt_prompt_public
|
||||
ON inputprompt (prompt)
|
||||
WHERE user_id IS NULL
|
||||
"""
|
||||
)
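# For illustration only (the sample value below is hypothetical, not seeded data):
# without this partial index, two public shortcuts could share a name, e.g.
#   INSERT INTO inputprompt (prompt, user_id) VALUES ('standup', NULL);
#   INSERT INTO inputprompt (prompt, user_id) VALUES ('standup', NULL);
# would both succeed, because the (prompt, user_id) unique constraint treats the
# two NULL user_id values as distinct. The index above applies only to rows
# WHERE user_id IS NULL, so duplicate public prompt names are rejected as well.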


def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS uq_inputprompt_prompt_public")
op.drop_constraint("uq_inputprompt_prompt_user_id", "inputprompt", type_="unique")
@@ -1,29 +0,0 @@
"""remove default prompt shortcuts

Revision ID: 41fa44bef321
Revises: 2c2430828bdf
Create Date: 2025-01-21

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "41fa44bef321"
down_revision = "2c2430828bdf"
branch_labels = None
depends_on = None


def upgrade() -> None:
# Delete any user associations for the default prompts first (foreign key constraint)
op.execute(
"DELETE FROM inputprompt__user WHERE input_prompt_id IN (SELECT id FROM inputprompt WHERE id < 0)"
)
# Delete the pre-seeded default prompt shortcuts (they have negative IDs)
op.execute("DELETE FROM inputprompt WHERE id < 0")


def downgrade() -> None:
# We don't restore the default prompts on downgrade
pass
@@ -85,122 +85,103 @@ class UserRow(NamedTuple):
def upgrade() -> None:
conn = op.get_bind()

# Step 1: Create or update the unified assistant (ID 0)
search_assistant = conn.execute(
sa.text("SELECT * FROM persona WHERE id = 0")
).fetchone()
# Start transaction
conn.execute(sa.text("BEGIN"))

if search_assistant:
# Update existing Search assistant to be the unified assistant
try:
# Step 1: Create or update the unified assistant (ID 0)
search_assistant = conn.execute(
sa.text("SELECT * FROM persona WHERE id = 0")
).fetchone()

if search_assistant:
# Update existing Search assistant to be the unified assistant
conn.execute(
sa.text(
"""
UPDATE persona
SET name = :name,
description = :description,
system_prompt = :system_prompt,
num_chunks = :num_chunks,
is_default_persona = true,
is_visible = true,
deleted = false,
display_priority = :display_priority,
llm_filter_extraction = :llm_filter_extraction,
llm_relevance_filter = :llm_relevance_filter,
recency_bias = :recency_bias,
chunks_above = :chunks_above,
chunks_below = :chunks_below,
datetime_aware = :datetime_aware,
starter_messages = null
WHERE id = 0
"""
),
INSERT_DICT,
)
else:
# Create new unified assistant with ID 0
conn.execute(
sa.text(
"""
INSERT INTO persona (
id, name, description, system_prompt, num_chunks,
is_default_persona, is_visible, deleted, display_priority,
llm_filter_extraction, llm_relevance_filter, recency_bias,
chunks_above, chunks_below, datetime_aware, starter_messages,
builtin_persona
) VALUES (
0, :name, :description, :system_prompt, :num_chunks,
true, true, false, :display_priority, :llm_filter_extraction,
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
:datetime_aware, null, true
)
"""
),
INSERT_DICT,
)

# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
conn.execute(
sa.text(
"""
UPDATE persona
SET name = :name,
description = :description,
system_prompt = :system_prompt,
num_chunks = :num_chunks,
is_default_persona = true,
is_visible = true,
deleted = false,
display_priority = :display_priority,
llm_filter_extraction = :llm_filter_extraction,
llm_relevance_filter = :llm_relevance_filter,
recency_bias = :recency_bias,
chunks_above = :chunks_above,
chunks_below = :chunks_below,
datetime_aware = :datetime_aware,
starter_messages = null
WHERE id = 0
SET deleted = true, is_visible = false, is_default_persona = false
WHERE builtin_persona = true AND id != 0
"""
),
INSERT_DICT,
)
else:
# Create new unified assistant with ID 0
conn.execute(
sa.text(
"""
INSERT INTO persona (
id, name, description, system_prompt, num_chunks,
is_default_persona, is_visible, deleted, display_priority,
llm_filter_extraction, llm_relevance_filter, recency_bias,
chunks_above, chunks_below, datetime_aware, starter_messages,
builtin_persona
) VALUES (
0, :name, :description, :system_prompt, :num_chunks,
true, true, false, :display_priority, :llm_filter_extraction,
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
:datetime_aware, null, true
)
"""
),
INSERT_DICT,
)
)

# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = true, is_visible = false, is_default_persona = false
WHERE builtin_persona = true AND id != 0
"""
)
)
# Step 3: Add all built-in tools to the unified assistant
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
).fetchone()

# Step 3: Add all built-in tools to the unified assistant
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
).fetchone()
if not search_tool:
raise ValueError(
"SearchTool not found in database. Ensure tools migration has run first."
)

if not search_tool:
raise ValueError(
"SearchTool not found in database. Ensure tools migration has run first."
)
image_gen_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
).fetchone()

image_gen_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
).fetchone()
if not image_gen_tool:
raise ValueError(
"ImageGenerationTool not found in database. Ensure tools migration has run first."
)

if not image_gen_tool:
raise ValueError(
"ImageGenerationTool not found in database. Ensure tools migration has run first."
)
# WebSearchTool is optional - may not be configured
web_search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
).fetchone()

# WebSearchTool is optional - may not be configured
web_search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
).fetchone()
# Clear existing tool associations for persona 0
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))

# Clear existing tool associations for persona 0
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))

# Add tools to the unified assistant
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": search_tool[0]},
)

conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": image_gen_tool[0]},
)

if web_search_tool:
# Add tools to the unified assistant
conn.execute(
sa.text(
"""
@@ -209,148 +190,191 @@ def upgrade() -> None:
ON CONFLICT DO NOTHING
"""
),
{"tool_id": web_search_tool[0]},
{"tool_id": search_tool[0]},
)

# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
conn.execute(
sa.text(
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
UPDATE chat_session
SET persona_id = 0
WHERE persona_id IN (
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
)
"""
),
{"tool_id": image_gen_tool[0]},
)
)

# Step 5: Migrate user preferences - remove references to all builtin assistants
# First, get all builtin assistant IDs (except 0)
builtin_assistants_result = conn.execute(
sa.text(
"""
SELECT id FROM persona
WHERE builtin_persona = true AND id != 0
"""
)
).fetchall()
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]

# Get all users with preferences
users_result = conn.execute(
sa.text(
"""
SELECT id, chosen_assistants, visible_assistants,
hidden_assistants, pinned_assistants
FROM "user"
"""
)
).fetchall()

for user_row in users_result:
user = UserRow(*user_row)
user_id: UUID = user.id
updates: dict[str, Any] = {}

# Remove all builtin assistants from chosen_assistants
if user.chosen_assistants:
new_chosen: list[int] = [
assistant_id
for assistant_id in user.chosen_assistants
if assistant_id not in builtin_assistant_ids
]
if new_chosen != user.chosen_assistants:
updates["chosen_assistants"] = json.dumps(new_chosen)

# Remove all builtin assistants from visible_assistants
if user.visible_assistants:
new_visible: list[int] = [
assistant_id
for assistant_id in user.visible_assistants
if assistant_id not in builtin_assistant_ids
]
if new_visible != user.visible_assistants:
updates["visible_assistants"] = json.dumps(new_visible)

# Add all builtin assistants to hidden_assistants
if user.hidden_assistants:
new_hidden: list[int] = list(user.hidden_assistants)
for old_id in builtin_assistant_ids:
if old_id not in new_hidden:
new_hidden.append(old_id)
if new_hidden != user.hidden_assistants:
updates["hidden_assistants"] = json.dumps(new_hidden)
else:
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)

# Remove all builtin assistants from pinned_assistants
if user.pinned_assistants:
new_pinned: list[int] = [
assistant_id
for assistant_id in user.pinned_assistants
if assistant_id not in builtin_assistant_ids
]
if new_pinned != user.pinned_assistants:
updates["pinned_assistants"] = json.dumps(new_pinned)

# Apply updates if any
if updates:
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
if web_search_tool:
conn.execute(
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
updates,
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": web_search_tool[0]},
)

# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
conn.execute(
sa.text(
"""
UPDATE chat_session
SET persona_id = 0
WHERE persona_id IN (
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
)
"""
)
)

# Step 5: Migrate user preferences - remove references to all builtin assistants
# First, get all builtin assistant IDs (except 0)
builtin_assistants_result = conn.execute(
sa.text(
"""
SELECT id FROM persona
WHERE builtin_persona = true AND id != 0
"""
)
).fetchall()
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]

# Get all users with preferences
users_result = conn.execute(
sa.text(
"""
SELECT id, chosen_assistants, visible_assistants,
hidden_assistants, pinned_assistants
FROM "user"
"""
)
).fetchall()

for user_row in users_result:
user = UserRow(*user_row)
user_id: UUID = user.id
updates: dict[str, Any] = {}

# Remove all builtin assistants from chosen_assistants
if user.chosen_assistants:
new_chosen: list[int] = [
assistant_id
for assistant_id in user.chosen_assistants
if assistant_id not in builtin_assistant_ids
]
if new_chosen != user.chosen_assistants:
updates["chosen_assistants"] = json.dumps(new_chosen)

# Remove all builtin assistants from visible_assistants
if user.visible_assistants:
new_visible: list[int] = [
assistant_id
for assistant_id in user.visible_assistants
if assistant_id not in builtin_assistant_ids
]
if new_visible != user.visible_assistants:
updates["visible_assistants"] = json.dumps(new_visible)

# Add all builtin assistants to hidden_assistants
if user.hidden_assistants:
new_hidden: list[int] = list(user.hidden_assistants)
for old_id in builtin_assistant_ids:
if old_id not in new_hidden:
new_hidden.append(old_id)
if new_hidden != user.hidden_assistants:
updates["hidden_assistants"] = json.dumps(new_hidden)
else:
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)

# Remove all builtin assistants from pinned_assistants
if user.pinned_assistants:
new_pinned: list[int] = [
assistant_id
for assistant_id in user.pinned_assistants
if assistant_id not in builtin_assistant_ids
]
if new_pinned != user.pinned_assistants:
updates["pinned_assistants"] = json.dumps(new_pinned)

# Apply updates if any
if updates:
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
conn.execute(
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
updates,
)

# Commit transaction
conn.execute(sa.text("COMMIT"))

except Exception as e:
# Rollback on error
conn.execute(sa.text("ROLLBACK"))
raise e


def downgrade() -> None:
conn = op.get_bind()

# Only restore General (ID -1) and Art (ID -3) assistants
# Step 1: Keep Search assistant (ID 0) as default but restore original state
conn.execute(
sa.text(
# Start transaction
conn.execute(sa.text("BEGIN"))

try:
# Only restore General (ID -1) and Art (ID -3) assistants
# Step 1: Keep Search assistant (ID 0) as default but restore original state
conn.execute(
sa.text(
"""
UPDATE persona
SET is_default_persona = true,
is_visible = true,
deleted = false
WHERE id = 0
"""
UPDATE persona
SET is_default_persona = true,
is_visible = true,
deleted = false
WHERE id = 0
"""
)
)
)

# Step 2: Restore General assistant (ID -1)
conn.execute(
sa.text(
# Step 2: Restore General assistant (ID -1)
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :general_assistant_id
"""
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :general_assistant_id
"""
),
{"general_assistant_id": GENERAL_ASSISTANT_ID},
)
),
{"general_assistant_id": GENERAL_ASSISTANT_ID},
)

# Step 3: Restore Art assistant (ID -3)
conn.execute(
sa.text(
# Step 3: Restore Art assistant (ID -3)
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :art_assistant_id
"""
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :art_assistant_id
"""
),
{"art_assistant_id": ART_ASSISTANT_ID},
)
),
{"art_assistant_id": ART_ASSISTANT_ID},
)

# Note: We don't restore the original tool associations, names, or descriptions
# as those would require more complex logic to determine original state.
# We also cannot restore original chat session persona_ids as we don't
# have the original mappings.
# Other builtin assistants remain deleted as per the requirement.
# Note: We don't restore the original tool associations, names, or descriptions
# as those would require more complex logic to determine original state.
# We also cannot restore original chat session persona_ids as we don't
# have the original mappings.
# Other builtin assistants remain deleted as per the requirement.

# Commit transaction
conn.execute(sa.text("COMMIT"))

except Exception as e:
# Rollback on error
conn.execute(sa.text("ROLLBACK"))
raise e

@@ -1,45 +0,0 @@
"""make processing mode default all caps

Revision ID: 72aa7de2e5cf
Revises: 2020d417ec84
Create Date: 2026-01-26 18:58:47.705253

This migration fixes the ProcessingMode enum value mismatch:
- SQLAlchemy's Enum with native_enum=False uses enum member NAMES as valid values
- The original migration stored lowercase VALUES ('regular', 'file_system')
- This converts existing data to uppercase NAMES ('REGULAR', 'FILE_SYSTEM')
- Also drops any spurious native PostgreSQL enum type that may have been auto-created
"""

from alembic import op


# revision identifiers, used by Alembic.
revision = "72aa7de2e5cf"
down_revision = "2020d417ec84"
branch_labels = None
depends_on = None


def upgrade() -> None:
# Convert existing lowercase values to uppercase to match enum member names
op.execute(
"UPDATE connector_credential_pair SET processing_mode = 'REGULAR' "
"WHERE processing_mode = 'regular'"
)
op.execute(
"UPDATE connector_credential_pair SET processing_mode = 'FILE_SYSTEM' "
"WHERE processing_mode = 'file_system'"
)

# Update the server default to use uppercase
op.alter_column(
"connector_credential_pair",
"processing_mode",
server_default="REGULAR",
)


def downgrade() -> None:
# State prior to this was broken, so we don't want to revert back to it
pass
@@ -1,47 +0,0 @@
"""add_search_query_table

Revision ID: 73e9983e5091
Revises: d1b637d7050a
Create Date: 2026-01-14 14:16:52.837489

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "73e9983e5091"
down_revision = "d1b637d7050a"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table(
"search_query",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("user.id"),
nullable=False,
),
sa.Column("query", sa.String(), nullable=False),
sa.Column("query_expansions", postgresql.ARRAY(sa.String()), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)

op.create_index("ix_search_query_user_id", "search_query", ["user_id"])
op.create_index("ix_search_query_created_at", "search_query", ["created_at"])


def downgrade() -> None:
op.drop_index("ix_search_query_created_at", table_name="search_query")
op.drop_index("ix_search_query_user_id", table_name="search_query")
op.drop_table("search_query")
@@ -10,7 +10,8 @@ from alembic import op
import sqlalchemy as sa

from onyx.db.models import IndexModelStatus
from onyx.context.search.enums import RecencyBiasSetting, SearchType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.enums import SearchType

# revision identifiers, used by Alembic.
revision = "776b3bbe9092"

@@ -1,349 +0,0 @@
"""hierarchy_nodes_v1

Revision ID: 81c22b1e2e78
Revises: 72aa7de2e5cf
Create Date: 2026-01-13 18:10:01.021451

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from onyx.configs.constants import DocumentSource


# revision identifiers, used by Alembic.
revision = "81c22b1e2e78"
down_revision = "72aa7de2e5cf"
branch_labels = None
depends_on = None


# Human-readable display names for each source
SOURCE_DISPLAY_NAMES: dict[str, str] = {
"ingestion_api": "Ingestion API",
"slack": "Slack",
"web": "Web",
"google_drive": "Google Drive",
"gmail": "Gmail",
"requesttracker": "Request Tracker",
"github": "GitHub",
"gitbook": "GitBook",
"gitlab": "GitLab",
"guru": "Guru",
"bookstack": "BookStack",
"outline": "Outline",
"confluence": "Confluence",
"jira": "Jira",
"slab": "Slab",
"productboard": "Productboard",
"file": "File",
"coda": "Coda",
"notion": "Notion",
"zulip": "Zulip",
"linear": "Linear",
"hubspot": "HubSpot",
"document360": "Document360",
"gong": "Gong",
"google_sites": "Google Sites",
"zendesk": "Zendesk",
"loopio": "Loopio",
"dropbox": "Dropbox",
"sharepoint": "SharePoint",
"teams": "Teams",
"salesforce": "Salesforce",
"discourse": "Discourse",
"axero": "Axero",
"clickup": "ClickUp",
"mediawiki": "MediaWiki",
"wikipedia": "Wikipedia",
"asana": "Asana",
"s3": "S3",
"r2": "R2",
"google_cloud_storage": "Google Cloud Storage",
"oci_storage": "OCI Storage",
"xenforo": "XenForo",
"not_applicable": "Not Applicable",
"discord": "Discord",
"freshdesk": "Freshdesk",
"fireflies": "Fireflies",
"egnyte": "Egnyte",
"airtable": "Airtable",
"highspot": "Highspot",
"drupal_wiki": "Drupal Wiki",
"imap": "IMAP",
"bitbucket": "Bitbucket",
"testrail": "TestRail",
"mock_connector": "Mock Connector",
"user_file": "User File",
}


def upgrade() -> None:
# 1. Create hierarchy_node table
op.create_table(
"hierarchy_node",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("raw_node_id", sa.String(), nullable=False),
sa.Column("display_name", sa.String(), nullable=False),
sa.Column("link", sa.String(), nullable=True),
sa.Column("source", sa.String(), nullable=False),
sa.Column("node_type", sa.String(), nullable=False),
sa.Column("document_id", sa.String(), nullable=True),
sa.Column("parent_id", sa.Integer(), nullable=True),
# Permission fields - same pattern as Document table
sa.Column(
"external_user_emails",
postgresql.ARRAY(sa.String()),
nullable=True,
),
sa.Column(
"external_user_group_ids",
postgresql.ARRAY(sa.String()),
nullable=True,
),
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="false"),
sa.PrimaryKeyConstraint("id"),
# When document is deleted, just unlink (node can exist without document)
sa.ForeignKeyConstraint(["document_id"], ["document.id"], ondelete="SET NULL"),
# When parent node is deleted, orphan children (cleanup via pruning)
sa.ForeignKeyConstraint(
["parent_id"], ["hierarchy_node.id"], ondelete="SET NULL"
),
sa.UniqueConstraint(
"raw_node_id", "source", name="uq_hierarchy_node_raw_id_source"
),
)
op.create_index("ix_hierarchy_node_parent_id", "hierarchy_node", ["parent_id"])
op.create_index(
"ix_hierarchy_node_source_type", "hierarchy_node", ["source", "node_type"]
)

# Add partial unique index to ensure only one SOURCE-type node per source
# This prevents duplicate source root nodes from being created
# NOTE: node_type stores enum NAME ('SOURCE'), not value ('source')
op.execute(
sa.text(
"""
CREATE UNIQUE INDEX uq_hierarchy_node_one_source_per_type
ON hierarchy_node (source)
WHERE node_type = 'SOURCE'
"""
)
)

# 2. Create hierarchy_fetch_attempt table
op.create_table(
"hierarchy_fetch_attempt",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
sa.Column("status", sa.String(), nullable=False),
sa.Column("nodes_fetched", sa.Integer(), nullable=True, server_default="0"),
sa.Column("nodes_updated", sa.Integer(), nullable=True, server_default="0"),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column("full_exception_trace", sa.Text(), nullable=True),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("time_started", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["connector_credential_pair_id"],
["connector_credential_pair.id"],
ondelete="CASCADE",
),
)
op.create_index(
"ix_hierarchy_fetch_attempt_status", "hierarchy_fetch_attempt", ["status"]
)
op.create_index(
"ix_hierarchy_fetch_attempt_time_created",
"hierarchy_fetch_attempt",
["time_created"],
)
op.create_index(
"ix_hierarchy_fetch_attempt_cc_pair",
"hierarchy_fetch_attempt",
["connector_credential_pair_id"],
)

# 3. Insert SOURCE-type hierarchy nodes for each DocumentSource
# We insert these so every existing document can have a parent hierarchy node
# NOTE: SQLAlchemy's Enum with native_enum=False stores the enum NAME (e.g., 'GOOGLE_DRIVE'),
# not the VALUE (e.g., 'google_drive'). We must use .name for source and node_type columns.
# SOURCE nodes are always public since they're just categorical roots.
for source in DocumentSource:
source_name = (
source.name
) # e.g., 'GOOGLE_DRIVE' - what SQLAlchemy stores/expects
source_value = source.value # e.g., 'google_drive' - the raw_node_id
display_name = SOURCE_DISPLAY_NAMES.get(
source_value, source_value.replace("_", " ").title()
)
op.execute(
sa.text(
"""
INSERT INTO hierarchy_node (raw_node_id, display_name, source, node_type, parent_id, is_public)
VALUES (:raw_node_id, :display_name, :source, 'SOURCE', NULL, true)
ON CONFLICT (raw_node_id, source) DO NOTHING
"""
).bindparams(
raw_node_id=source_value, # Use .value for raw_node_id (human-readable identifier)
display_name=display_name,
source=source_name, # Use .name for source column (SQLAlchemy enum storage)
)
)

# 4. Add parent_hierarchy_node_id column to document table
op.add_column(
"document",
sa.Column("parent_hierarchy_node_id", sa.Integer(), nullable=True),
)
# When hierarchy node is deleted, just unlink the document (SET NULL)
op.create_foreign_key(
"fk_document_parent_hierarchy_node",
"document",
"hierarchy_node",
["parent_hierarchy_node_id"],
["id"],
ondelete="SET NULL",
)
op.create_index(
"ix_document_parent_hierarchy_node_id",
"document",
["parent_hierarchy_node_id"],
)

# 5. Set all existing documents' parent_hierarchy_node_id to their source's SOURCE node
# For documents with multiple connectors, we pick one source deterministically (MIN connector_id)
# NOTE: Both connector.source and hierarchy_node.source store enum NAMEs (e.g., 'GOOGLE_DRIVE')
# because SQLAlchemy Enum(native_enum=False) uses the enum name for storage.
op.execute(
sa.text(
"""
UPDATE document d
SET parent_hierarchy_node_id = hn.id
FROM (
-- Get the source for each document (pick MIN connector_id for determinism)
SELECT DISTINCT ON (dbcc.id)
dbcc.id as doc_id,
c.source as source
FROM document_by_connector_credential_pair dbcc
JOIN connector c ON dbcc.connector_id = c.id
ORDER BY dbcc.id, dbcc.connector_id
) doc_source
JOIN hierarchy_node hn ON hn.source = doc_source.source AND hn.node_type = 'SOURCE'
WHERE d.id = doc_source.doc_id
"""
)
)

# Create the persona__hierarchy_node association table
op.create_table(
"persona__hierarchy_node",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("hierarchy_node_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["hierarchy_node_id"],
["hierarchy_node.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("persona_id", "hierarchy_node_id"),
)

# Add index for efficient lookups
op.create_index(
"ix_persona__hierarchy_node_hierarchy_node_id",
"persona__hierarchy_node",
["hierarchy_node_id"],
)

# Create the persona__document association table for attaching individual
# documents directly to assistants
op.create_table(
"persona__document",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("document_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["document_id"],
["document.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("persona_id", "document_id"),
)

# Add index for efficient lookups by document_id
op.create_index(
"ix_persona__document_document_id",
"persona__document",
["document_id"],
)

# 6. Add last_time_hierarchy_fetch column to connector_credential_pair table
op.add_column(
"connector_credential_pair",
sa.Column(
"last_time_hierarchy_fetch", sa.DateTime(timezone=True), nullable=True
),
)


def downgrade() -> None:
# Remove last_time_hierarchy_fetch from connector_credential_pair
op.drop_column("connector_credential_pair", "last_time_hierarchy_fetch")

# Drop persona__document table
op.drop_index("ix_persona__document_document_id", table_name="persona__document")
op.drop_table("persona__document")

# Drop persona__hierarchy_node table
op.drop_index(
"ix_persona__hierarchy_node_hierarchy_node_id",
table_name="persona__hierarchy_node",
)
op.drop_table("persona__hierarchy_node")

# Remove parent_hierarchy_node_id from document
op.drop_index("ix_document_parent_hierarchy_node_id", table_name="document")
op.drop_constraint(
"fk_document_parent_hierarchy_node", "document", type_="foreignkey"
)
op.drop_column("document", "parent_hierarchy_node_id")

# Drop hierarchy_fetch_attempt table
op.drop_index(
"ix_hierarchy_fetch_attempt_cc_pair", table_name="hierarchy_fetch_attempt"
)
op.drop_index(
"ix_hierarchy_fetch_attempt_time_created", table_name="hierarchy_fetch_attempt"
)
op.drop_index(
"ix_hierarchy_fetch_attempt_status", table_name="hierarchy_fetch_attempt"
)
op.drop_table("hierarchy_fetch_attempt")

# Drop hierarchy_node table
op.drop_index("uq_hierarchy_node_one_source_per_type", table_name="hierarchy_node")
op.drop_index("ix_hierarchy_node_source_type", table_name="hierarchy_node")
op.drop_index("ix_hierarchy_node_parent_id", table_name="hierarchy_node")
op.drop_table("hierarchy_node")
@@ -1,49 +0,0 @@
"""notifications constraint, sort index, and cleanup old notifications

Revision ID: 8405ca81cc83
Revises: a3c1a7904cd0
Create Date: 2026-01-07 16:43:44.855156

"""

from alembic import op


# revision identifiers, used by Alembic.
revision = "8405ca81cc83"
down_revision = "a3c1a7904cd0"
branch_labels = None
depends_on = None


def upgrade() -> None:
# Create unique index for notification deduplication.
# This enables atomic ON CONFLICT DO NOTHING inserts in batch_create_notifications.
#
# Uses COALESCE to handle NULL additional_data (NULLs are normally distinct
# in unique constraints, but we want NULL == NULL for deduplication).
# The '{}' represents an empty JSONB object as the NULL replacement.

# Clean up legacy notifications first
op.execute("DELETE FROM notification WHERE title = 'New Notification'")

op.execute(
"""
CREATE UNIQUE INDEX IF NOT EXISTS ix_notification_user_type_data
ON notification (user_id, notif_type, COALESCE(additional_data, '{}'::jsonb))
"""
)
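# For illustration, the kind of statement this index makes safe to run concurrently
# (the column list is an assumption, not taken from batch_create_notifications):
#
#   INSERT INTO notification (user_id, notif_type, dismissed, first_shown, additional_data)
#   VALUES (:user_id, :notif_type, false, now(), :additional_data)
#   ON CONFLICT (user_id, notif_type, COALESCE(additional_data, '{}'::jsonb))
#   DO NOTHING
#
# Duplicate notifications for the same user, type, and payload collapse into one row.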

# Create index for efficient notification sorting by user
# Covers: WHERE user_id = ? ORDER BY dismissed, first_shown DESC
op.execute(
"""
CREATE INDEX IF NOT EXISTS ix_notification_user_sort
ON notification (user_id, dismissed, first_shown DESC)
"""
)


def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS ix_notification_user_type_data")
op.execute("DROP INDEX IF EXISTS ix_notification_user_sort")
@@ -1,116 +0,0 @@
"""Add Discord bot tables

Revision ID: 8b5ce697290e
Revises: a1b2c3d4e5f7
Create Date: 2025-01-14

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "8b5ce697290e"
down_revision = "a1b2c3d4e5f7"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
# DiscordBotConfig (singleton table - one per tenant)
op.create_table(
"discord_bot_config",
sa.Column(
"id",
sa.String(),
primary_key=True,
server_default=sa.text("'SINGLETON'"),
),
sa.Column("bot_token", sa.LargeBinary(), nullable=False), # EncryptedString
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.CheckConstraint("id = 'SINGLETON'", name="ck_discord_bot_config_singleton"),
)

# DiscordGuildConfig
op.create_table(
"discord_guild_config",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("guild_id", sa.BigInteger(), nullable=True, unique=True),
sa.Column("guild_name", sa.String(), nullable=True),
sa.Column("registration_key", sa.String(), nullable=False, unique=True),
sa.Column("registered_at", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"default_persona_id",
sa.Integer(),
sa.ForeignKey("persona.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column(
"enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
),
)

# DiscordChannelConfig
op.create_table(
"discord_channel_config",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column(
"guild_config_id",
sa.Integer(),
sa.ForeignKey("discord_guild_config.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("channel_id", sa.BigInteger(), nullable=False),
sa.Column("channel_name", sa.String(), nullable=False),
sa.Column(
"channel_type",
sa.String(20),
server_default=sa.text("'text'"),
nullable=False,
),
sa.Column(
"is_private",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
),
sa.Column(
"thread_only_mode",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
),
sa.Column(
"require_bot_invocation",
sa.Boolean(),
server_default=sa.text("true"),
nullable=False,
),
sa.Column(
"persona_override_id",
sa.Integer(),
sa.ForeignKey("persona.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column(
"enabled", sa.Boolean(), server_default=sa.text("false"), nullable=False
),
)

# Unique constraint: one config per channel per guild
op.create_unique_constraint(
"uq_discord_channel_guild_channel",
"discord_channel_config",
["guild_config_id", "channel_id"],
)


def downgrade() -> None:
op.drop_table("discord_channel_config")
op.drop_table("discord_guild_config")
op.drop_table("discord_bot_config")
@@ -42,13 +42,20 @@ TOOL_DESCRIPTIONS = {

def upgrade() -> None:
conn = op.get_bind()
for tool_id, description in TOOL_DESCRIPTIONS.items():
conn.execute(
sa.text(
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
),
{"description": description, "tool_id": tool_id},
)
conn.execute(sa.text("BEGIN"))

try:
for tool_id, description in TOOL_DESCRIPTIONS.items():
conn.execute(
sa.text(
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
),
{"description": description, "tool_id": tool_id},
)
conn.execute(sa.text("COMMIT"))
except Exception as e:
conn.execute(sa.text("ROLLBACK"))
raise e


def downgrade() -> None:

@@ -1,47 +0,0 @@
"""drop agent_search_metrics table

Revision ID: a1b2c3d4e5f7
Revises: 73e9983e5091
Create Date: 2026-01-17

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f7"
down_revision = "73e9983e5091"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.drop_table("agent__search_metrics")


def downgrade() -> None:
op.create_table(
"agent__search_metrics",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("agent_type", sa.String(), nullable=False),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("base_duration_s", sa.Float(), nullable=False),
sa.Column("full_duration_s", sa.Float(), nullable=False),
sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("id"),
)
@@ -7,6 +7,7 @@ Create Date: 2025-12-18 16:00:00.000000
"""

from alembic import op
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
import sqlalchemy as sa


@@ -18,7 +19,7 @@ depends_on = None


DEEP_RESEARCH_TOOL = {
"name": "ResearchAgent",
"name": RESEARCH_AGENT_DB_NAME,
"display_name": "Research Agent",
"description": "The Research Agent is a sub-agent that conducts research on a specific topic.",
"in_code_tool_id": "ResearchAgent",

@@ -70,66 +70,80 @@ BUILT_IN_TOOLS = [
def upgrade() -> None:
conn = op.get_bind()

# Get existing tools to check what already exists
existing_tools = conn.execute(
sa.text("SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL")
).fetchall()
existing_tool_ids = {row[0] for row in existing_tools}
# Start transaction
conn.execute(sa.text("BEGIN"))

# Insert or update built-in tools
for tool in BUILT_IN_TOOLS:
in_code_id = tool["in_code_tool_id"]
try:
# Get existing tools to check what already exists
existing_tools = conn.execute(
sa.text(
"SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL"
)
).fetchall()
existing_tool_ids = {row[0] for row in existing_tools}

# Handle historical rename: InternetSearchTool -> WebSearchTool
if (
in_code_id == "WebSearchTool"
and "WebSearchTool" not in existing_tool_ids
and "InternetSearchTool" in existing_tool_ids
):
# Rename the existing InternetSearchTool row in place and update fields
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description,
in_code_tool_id = :in_code_tool_id
WHERE in_code_tool_id = 'InternetSearchTool'
"""
),
tool,
)
# Keep the local view of existing ids in sync to avoid duplicate insert
existing_tool_ids.discard("InternetSearchTool")
existing_tool_ids.add("WebSearchTool")
continue
# Insert or update built-in tools
for tool in BUILT_IN_TOOLS:
in_code_id = tool["in_code_tool_id"]

if in_code_id in existing_tool_ids:
# Update existing tool
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
tool,
)
else:
# Insert new tool
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id)
VALUES (:name, :display_name, :description, :in_code_tool_id)
"""
),
tool,
)
# Handle historical rename: InternetSearchTool -> WebSearchTool
if (
in_code_id == "WebSearchTool"
and "WebSearchTool" not in existing_tool_ids
and "InternetSearchTool" in existing_tool_ids
):
# Rename the existing InternetSearchTool row in place and update fields
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description,
in_code_tool_id = :in_code_tool_id
WHERE in_code_tool_id = 'InternetSearchTool'
"""
),
tool,
)
# Keep the local view of existing ids in sync to avoid duplicate insert
existing_tool_ids.discard("InternetSearchTool")
existing_tool_ids.add("WebSearchTool")
continue

if in_code_id in existing_tool_ids:
# Update existing tool
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
tool,
)
else:
# Insert new tool
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id)
VALUES (:name, :display_name, :description, :in_code_tool_id)
"""
),
tool,
)

# Commit transaction
conn.execute(sa.text("COMMIT"))

except Exception as e:
# Rollback on error
conn.execute(sa.text("ROLLBACK"))
raise e


def downgrade() -> None:

@@ -1,64 +0,0 @@
"""sync_exa_api_key_to_content_provider

Revision ID: d1b637d7050a
Revises: d25168c2beee
Create Date: 2026-01-09 15:54:15.646249

"""

from alembic import op
from sqlalchemy import text


# revision identifiers, used by Alembic.
revision = "d1b637d7050a"
down_revision = "d25168c2beee"
branch_labels = None
depends_on = None


def upgrade() -> None:
# Exa uses a shared API key between search and content providers.
# For existing Exa search providers with API keys, create the corresponding
# content provider if it doesn't exist yet.
connection = op.get_bind()

# Check if Exa search provider exists with an API key
result = connection.execute(
text(
"""
SELECT api_key FROM internet_search_provider
WHERE provider_type = 'exa' AND api_key IS NOT NULL
LIMIT 1
"""
)
)
row = result.fetchone()

if row:
api_key = row[0]
# Create Exa content provider with the shared key
connection.execute(
text(
"""
INSERT INTO internet_content_provider
(name, provider_type, api_key, is_active)
VALUES ('Exa', 'exa', :api_key, false)
ON CONFLICT (name) DO NOTHING
"""
),
{"api_key": api_key},
)


def downgrade() -> None:
# Remove the Exa content provider that was created by this migration
connection = op.get_bind()
connection.execute(
text(
"""
DELETE FROM internet_content_provider
WHERE provider_type = 'exa'
"""
)
)
@@ -1,86 +0,0 @@
"""tool_name_consistency

Revision ID: d25168c2beee
Revises: 8405ca81cc83
Create Date: 2026-01-11 17:54:40.135777

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "d25168c2beee"
down_revision = "8405ca81cc83"
branch_labels = None
depends_on = None


# Currently the seeded tools have the in_code_tool_id == name
CURRENT_TOOL_NAME_MAPPING = [
"SearchTool",
"WebSearchTool",
"ImageGenerationTool",
"PythonTool",
"OpenURLTool",
"KnowledgeGraphTool",
"ResearchAgent",
]

# Mapping of in_code_tool_id -> name
# These are the expected names that we want in the database
EXPECTED_TOOL_NAME_MAPPING = {
"SearchTool": "internal_search",
"WebSearchTool": "web_search",
"ImageGenerationTool": "generate_image",
"PythonTool": "python",
"OpenURLTool": "open_url",
"KnowledgeGraphTool": "run_kg_search",
"ResearchAgent": "research_agent",
}


def upgrade() -> None:
conn = op.get_bind()

# Mapping of in_code_tool_id to the NAME constant from each tool class
# These match the .name property of each tool implementation
tool_name_mapping = EXPECTED_TOOL_NAME_MAPPING

# Update the name column for each tool based on its in_code_tool_id
for in_code_tool_id, expected_name in tool_name_mapping.items():
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :expected_name
WHERE in_code_tool_id = :in_code_tool_id
"""
),
{
"expected_name": expected_name,
"in_code_tool_id": in_code_tool_id,
},
)


def downgrade() -> None:
conn = op.get_bind()

# Reverse the migration by setting name back to in_code_tool_id
# This matches the original pattern where name was the class name
for in_code_tool_id in CURRENT_TOOL_NAME_MAPPING:
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :current_name
WHERE in_code_tool_id = :in_code_tool_id
"""
),
{
"current_name": in_code_tool_id,
"in_code_tool_id": in_code_tool_id,
},
)
@@ -1,31 +0,0 @@
"""add chat_background to user

Revision ID: fb80bdd256de
Revises: 8b5ce697290e
Create Date: 2026-01-16 16:15:59.222617

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "fb80bdd256de"
down_revision = "8b5ce697290e"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"chat_background",
sa.String(),
nullable=True,
),
)


def downgrade() -> None:
op.drop_column("user", "chat_background")
@@ -109,6 +109,7 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(


STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")

# JWT Public Key URL
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
@@ -122,23 +123,9 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
POSTHOG_DEBUG_LOGS_ENABLED = (
os.environ.get("POSTHOG_DEBUG_LOGS_ENABLED", "").lower() == "true"
)

MARKETING_POSTHOG_API_KEY = os.environ.get("MARKETING_POSTHOG_API_KEY")

HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")

GATED_TENANTS_KEY = "gated_tenants"

# License enforcement - when True, blocks API access for gated/expired licenses
LICENSE_ENFORCEMENT_ENABLED = (
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
)

# Cloud data plane URL - self-hosted instances call this to reach cloud proxy endpoints
# Used when MULTI_TENANT=false (self-hosted mode)
CLOUD_DATA_PLANE_URL = os.environ.get(
"CLOUD_DATA_PLANE_URL", "https://cloud.onyx.app/api"
)

@@ -1,73 +0,0 @@
"""Constants for license enforcement.

This file is the single source of truth for:
1. Paths that bypass license enforcement (always accessible)
2. Paths that require an EE license (EE-only features)

Import these constants in both production code and tests to ensure consistency.
"""

# Paths that are ALWAYS accessible, even when license is expired/gated.
# These enable users to:
# /auth - Log in/out (users can't fix billing if locked out of auth)
# /license - Fetch, upload, or check license status
# /health - Health checks for load balancers/orchestrators
# /me - Basic user info needed for UI rendering
# /settings, /enterprise-settings - View app status and branding
# /billing - Unified billing API
# /proxy - Self-hosted proxy endpoints (have own license-based auth)
# /tenants/billing-* - Legacy billing endpoints (backwards compatibility)
# /manage/users, /users - User management (needed for seat limit resolution)
# /notifications - Needed for UI to load properly
LICENSE_ENFORCEMENT_ALLOWED_PREFIXES: frozenset[str] = frozenset(
{
"/auth",
"/license",
"/health",
"/me",
"/settings",
"/enterprise-settings",
# Billing endpoints (unified API for both MT and self-hosted)
"/billing",
"/admin/billing",
# Proxy endpoints for self-hosted billing (no tenant context)
"/proxy",
# Legacy tenant billing endpoints (kept for backwards compatibility)
"/tenants/billing-information",
"/tenants/create-customer-portal-session",
"/tenants/create-subscription-session",
# User management - needed to remove users when seat limit exceeded
"/manage/users",
"/manage/admin/users",
"/manage/admin/valid-domains",
"/manage/admin/deactivate-user",
"/manage/admin/delete-user",
"/users",
# Notifications - needed for UI to load properly
"/notifications",
}
)

# EE-only paths that require a valid license.
# Users without a license (community edition) cannot access these.
# These are blocked even when user has never subscribed (no license).
EE_ONLY_PATH_PREFIXES: frozenset[str] = frozenset(
{
# User groups and access control
"/manage/admin/user-group",
|
||||
# Analytics and reporting
|
||||
"/analytics",
|
||||
# Query history (admin chat session endpoints)
|
||||
"/admin/chat-sessions",
|
||||
"/admin/chat-session-history",
|
||||
"/admin/query-history",
|
||||
# Usage reporting/export
|
||||
"/admin/usage-report",
|
||||
# Standard answers (canned responses)
|
||||
"/manage/admin/standard-answer",
|
||||
# Token rate limits
|
||||
"/admin/token-rate-limits",
|
||||
# Evals
|
||||
"/evals",
|
||||
}
|
||||
)
|
||||
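For orientation, a minimal sketch of how these prefix sets are typically consumed by an enforcement check. This is not the production middleware (which lives in ee.onyx.server.middleware.license_enforcement); the helper name and flags here are hypothetical:

def is_path_blocked(path: str, has_valid_license: bool, is_gated: bool) -> bool:
    """Illustrative prefix check only."""
    # Always-allowed paths bypass enforcement entirely (auth, billing, health, etc.).
    if any(path.startswith(prefix) for prefix in LICENSE_ENFORCEMENT_ALLOWED_PREFIXES):
        return False
    # EE-only paths require a valid license even when the deployment is not gated.
    if not has_valid_license and any(
        path.startswith(prefix) for prefix in EE_ONLY_PATH_PREFIXES
    ):
        return True
    # Everything else is blocked only while the license is gated/expired.
    return is_gated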
@@ -1,7 +1,6 @@
|
||||
"""Database and cache operations for the license table."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import NamedTuple
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
@@ -10,7 +9,6 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -25,13 +23,6 @@ LICENSE_METADATA_KEY = "license:metadata"
|
||||
LICENSE_CACHE_TTL_SECONDS = 86400 # 24 hours
|
||||
|
||||
|
||||
class SeatAvailabilityResult(NamedTuple):
|
||||
"""Result of a seat availability check."""
|
||||
|
||||
available: bool
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Database CRUD Operations
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -104,30 +95,23 @@ def delete_license(db_session: Session) -> bool:
|
||||
|
||||
def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
"""
|
||||
Get current seat usage directly from database.
|
||||
Get current seat usage.
|
||||
|
||||
For multi-tenant: counts users in UserTenantMapping for this tenant.
|
||||
For self-hosted: counts all active users (excludes EXT_PERM_USER role).
|
||||
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
For self-hosted: counts all active users (includes both Onyx UI users
|
||||
and Slack users who have been converted to Onyx users).
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_count
|
||||
|
||||
return get_tenant_count(tenant_id or get_current_tenant_id())
|
||||
else:
|
||||
# Self-hosted: count all active users (Onyx + converted Slack users)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
result = db_session.execute(
|
||||
select(func.count())
|
||||
.select_from(User)
|
||||
.where(
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
User.role != UserRole.EXT_PERM_USER,
|
||||
)
|
||||
select(func.count()).select_from(User).where(User.is_active) # type: ignore
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
@@ -292,43 +276,3 @@ def get_license_metadata(
|
||||
|
||||
# Refresh from database
|
||||
return refresh_license_cache(db_session, tenant_id)
|
||||
|
||||
|
||||
def check_seat_availability(
|
||||
db_session: Session,
|
||||
seats_needed: int = 1,
|
||||
tenant_id: str | None = None,
|
||||
) -> SeatAvailabilityResult:
|
||||
"""
|
||||
Check if there are enough seats available to add users.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
seats_needed: Number of seats needed (default 1)
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
|
||||
Returns:
|
||||
SeatAvailabilityResult with available=True if seats are available,
|
||||
or available=False with error_message if limit would be exceeded.
|
||||
Returns available=True if no license exists (self-hosted = unlimited).
|
||||
"""
|
||||
metadata = get_license_metadata(db_session, tenant_id)
|
||||
|
||||
# No license = no enforcement (self-hosted without license)
|
||||
if metadata is None:
|
||||
return SeatAvailabilityResult(available=True)
|
||||
|
||||
# Calculate current usage directly from DB (not cache) for accuracy
|
||||
current_used = get_used_seats(tenant_id)
|
||||
total_seats = metadata.seats
|
||||
|
||||
# Use > (not >=) to allow filling to exactly 100% capacity
|
||||
would_exceed_limit = current_used + seats_needed > total_seats
|
||||
if would_exceed_limit:
|
||||
return SeatAvailabilityResult(
|
||||
available=False,
|
||||
error_message=f"Seat limit would be exceeded: {current_used} of {total_seats} seats used, "
|
||||
f"cannot add {seats_needed} more user(s).",
|
||||
)
|
||||
|
||||
return SeatAvailabilityResult(available=True)
|
||||
|
||||
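To make the boundary comment above concrete: with a hypothetical 10-seat license and 9 seats already used, adding 1 user fills capacity exactly and is allowed, while adding 2 is rejected.

# Illustrative numbers: metadata.seats == 10, get_used_seats() == 9
result = check_seat_availability(db_session, seats_needed=1)
assert result.available          # 9 + 1 > 10 is False -> filling to 100% is allowed

result = check_seat_availability(db_session, seats_needed=2)
assert not result.available      # 9 + 2 > 10 -> rejected with an error_message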
@@ -3,42 +3,30 @@ from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
||||
|
||||
|
||||
def update_persona_access(
|
||||
def make_persona_private(
|
||||
persona_id: int,
|
||||
creator_user_id: UUID | None,
|
||||
user_ids: list[UUID] | None,
|
||||
group_ids: list[int] | None,
|
||||
db_session: Session,
|
||||
is_public: bool | None = None,
|
||||
user_ids: list[UUID] | None = None,
|
||||
group_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""Updates the access settings for a persona including public status, user shares,
|
||||
and group shares.
|
||||
"""NOTE(rkuo): This function batches all updates into a single commit. If we don't
|
||||
dedupe the inputs, the commit will exception."""
|
||||
|
||||
NOTE: This function batches all updates. If we don't dedupe the inputs,
|
||||
the commit will exception.
|
||||
|
||||
NOTE: Callers are responsible for committing."""
|
||||
|
||||
if is_public is not None:
|
||||
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
|
||||
if persona:
|
||||
persona.is_public = is_public
|
||||
|
||||
# NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
|
||||
# and a non-empty list means "replace with these shares".
|
||||
|
||||
if user_ids is not None:
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
if user_ids:
|
||||
user_ids_set = set(user_ids)
|
||||
for user_id in user_ids_set:
|
||||
db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
|
||||
@@ -53,13 +41,11 @@ def update_persona_access(
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
if group_ids is not None:
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
if group_ids:
|
||||
group_ids_set = set(group_ids)
|
||||
for group_id in group_ids_set:
|
||||
db_session.add(
|
||||
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
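Assuming the keyword-argument variant shown in this hunk (`update_persona_access`, with an existing `db_session`), the `None` / `[]` / non-empty semantics described in the comments translate to call sites roughly like this (illustrative only):

# Make the persona public without touching existing user or group shares.
update_persona_access(persona_id=1, creator_user_id=None, db_session=db_session, is_public=True)

# Clear all user shares; group shares are left unchanged because group_ids is None.
update_persona_access(persona_id=1, creator_user_id=None, db_session=db_session, user_ids=[])

# Replace the group shares with exactly these two groups.
update_persona_access(persona_id=1, creator_user_id=None, db_session=db_session, group_ids=[4, 7])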
@@ -1,64 +0,0 @@
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.models import SearchQuery
|
||||
|
||||
|
||||
def create_search_query(
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
query: str,
|
||||
query_expansions: list[str] | None = None,
|
||||
) -> SearchQuery:
|
||||
"""Create and persist a `SearchQuery` row.
|
||||
|
||||
Notes:
|
||||
- `SearchQuery.id` is a UUID PK without a server-side default, so we generate it.
|
||||
- `created_at` is filled by the DB (server_default=now()).
|
||||
"""
|
||||
search_query = SearchQuery(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user_id,
|
||||
query=query,
|
||||
query_expansions=query_expansions,
|
||||
)
|
||||
db_session.add(search_query)
|
||||
db_session.commit()
|
||||
db_session.refresh(search_query)
|
||||
return search_query
|
||||
|
||||
|
||||
def fetch_search_queries_for_user(
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
filter_days: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[SearchQuery]:
|
||||
"""Fetch `SearchQuery` rows for a user.
|
||||
|
||||
Args:
|
||||
user_id: User UUID.
|
||||
filter_days: Optional time filter. If provided, only rows created within
|
||||
the last `filter_days` days are returned.
|
||||
limit: Optional max number of rows to return.
|
||||
"""
|
||||
if filter_days is not None and filter_days <= 0:
|
||||
raise ValueError("filter_days must be > 0")
|
||||
|
||||
stmt = select(SearchQuery).where(SearchQuery.user_id == user_id)
|
||||
|
||||
if filter_days is not None and filter_days > 0:
|
||||
cutoff = get_db_current_time(db_session) - timedelta(days=filter_days)
|
||||
stmt = stmt.where(SearchQuery.created_at >= cutoff)
|
||||
|
||||
stmt = stmt.order_by(SearchQuery.created_at.desc())
|
||||
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
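A hypothetical call pattern for the two helpers above; the `db_session` and `user` objects are assumed to come from the usual request scaffolding:

# Record the query the user just ran, along with any keyword expansions.
saved = create_search_query(
    db_session=db_session,
    user_id=user.id,
    query="onboarding checklist",
    query_expansions=["onboarding", "new hire checklist"],
)

# Later, show the user's recent searches: last 7 days, newest first, capped at 20.
recent = fetch_search_queries_for_user(
    db_session=db_session,
    user_id=user.id,
    filter_days=7,
    limit=20,
)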
@@ -7,7 +7,6 @@ from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFun
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -61,9 +60,6 @@ def gmail_doc_sync(
|
||||
|
||||
callback.progress("gmail_doc_sync", 1)
|
||||
|
||||
if isinstance(slim_doc, HierarchyNode):
|
||||
# TODO: handle hierarchynodes during sync
|
||||
continue
|
||||
if slim_doc.external_access is None:
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
continue
|
||||
|
||||
@@ -15,7 +15,6 @@ from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.connectors.google_utils.resources import GoogleDriveService
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -196,9 +195,7 @@ def gdrive_doc_sync(
|
||||
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("gdrive_doc_sync", 1)
|
||||
if isinstance(slim_doc, HierarchyNode):
|
||||
# TODO: handle hierarchynodes during sync
|
||||
continue
|
||||
|
||||
if slim_doc.external_access is None:
|
||||
raise ValueError(
|
||||
f"Drive perm sync: No external access for document {slim_doc.id}"
|
||||
|
||||
@@ -8,7 +8,6 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.slack.connector import get_channels
|
||||
from onyx.connectors.slack.connector import make_paginated_slack_api_call
|
||||
from onyx.connectors.slack.connector import SlackConnector
|
||||
@@ -112,9 +111,6 @@ def _get_slack_document_access(
|
||||
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
if isinstance(doc_metadata, HierarchyNode):
|
||||
# TODO: handle hierarchynodes during sync
|
||||
continue
|
||||
if doc_metadata.external_access is None:
|
||||
raise ValueError(
|
||||
f"No external access for document {doc_metadata.id}. "
|
||||
|
||||
@@ -5,7 +5,6 @@ from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -50,9 +49,6 @@ def generic_doc_sync(
|
||||
callback.progress(label, 1)
|
||||
|
||||
for doc in doc_batch:
|
||||
if isinstance(doc, HierarchyNode):
|
||||
# TODO: handle hierarchynodes during sync
|
||||
continue
|
||||
if not doc.external_access:
|
||||
raise RuntimeError(
|
||||
f"No external access found for document ID; {cc_pair.id=} {doc_source=} {doc.id=}"
|
||||
|
||||
@@ -4,10 +4,8 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.server.analytics.api import router as analytics_router
|
||||
from ee.onyx.server.auth_check import check_ee_router_auth
|
||||
from ee.onyx.server.billing.api import router as billing_router
|
||||
from ee.onyx.server.documents.cc_pair import router as ee_document_cc_pair_router
|
||||
from ee.onyx.server.enterprise_settings.api import (
|
||||
admin_router as enterprise_settings_admin_router,
|
||||
@@ -18,17 +16,16 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
from ee.onyx.server.evals.api import router as evals_router
|
||||
from ee.onyx.server.license.api import router as license_router
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.license_enforcement import (
|
||||
add_license_enforcement_middleware,
|
||||
)
|
||||
from ee.onyx.server.middleware.tenant_tracking import (
|
||||
add_api_server_tenant_id_middleware,
|
||||
)
|
||||
from ee.onyx.server.oauth.api import router as ee_oauth_router
|
||||
from ee.onyx.server.query_and_chat.chat_backend import (
|
||||
router as chat_router,
|
||||
)
|
||||
from ee.onyx.server.query_and_chat.query_backend import (
|
||||
basic_router as ee_query_router,
|
||||
)
|
||||
from ee.onyx.server.query_and_chat.search_backend import router as search_router
|
||||
from ee.onyx.server.query_history.api import router as query_history_router
|
||||
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
|
||||
from ee.onyx.server.seeding import seed_db
|
||||
@@ -87,11 +84,6 @@ def get_application() -> FastAPI:
|
||||
|
||||
if MULTI_TENANT:
|
||||
add_api_server_tenant_id_middleware(application, logger)
|
||||
else:
|
||||
# License enforcement middleware for self-hosted deployments only
|
||||
# Checks LICENSE_ENFORCEMENT_ENABLED at runtime (can be toggled without restart)
|
||||
# MT deployments use control plane gating via is_tenant_gated() instead
|
||||
add_license_enforcement_middleware(application, logger)
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
# For Google OAuth, refresh tokens are requested by:
|
||||
@@ -132,7 +124,7 @@ def get_application() -> FastAPI:
|
||||
# EE only backend APIs
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_query_router)
|
||||
include_router_with_global_prefix_prepended(application, search_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, standard_answer_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_oauth_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
|
||||
@@ -151,13 +143,6 @@ def get_application() -> FastAPI:
|
||||
# License management
|
||||
include_router_with_global_prefix_prepended(application, license_router)
|
||||
|
||||
# Unified billing API - available when license system is enabled
|
||||
# Works for both self-hosted and cloud deployments
|
||||
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
|
||||
# primary billing API and /tenants/* billing endpoints can be removed
|
||||
if LICENSE_ENFORCEMENT_ENABLED:
|
||||
include_router_with_global_prefix_prepended(application, billing_router)
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Tenant management
|
||||
include_router_with_global_prefix_prepended(application, tenants_router)
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
# Single message is likely most reliable and generally better for this task
|
||||
# No final reminders at the end since the user query is expected to be short
|
||||
# If it is not short, it should go into the chat flow so we do not need to account for this.
|
||||
KEYWORD_EXPANSION_PROMPT = """
|
||||
Generate a set of keyword-only queries to help find relevant documents for the provided query. \
|
||||
These queries will be passed to a bm25-based keyword search engine. \
|
||||
Provide a single query per line (where each query consists of one or more keywords). \
|
||||
The queries must be purely keywords and not contain any filler natural language. \
|
||||
Each query should have as few keywords as necessary to represent the user's search intent. \
|
||||
If there are no useful expansions, simply return the original query with no additional keyword queries. \
|
||||
CRITICAL: Do not include any additional formatting, comments, or anything aside from the keyword queries.
|
||||
|
||||
The user query is:
|
||||
{user_query}
|
||||
""".strip()
|
||||
|
||||
|
||||
QUERY_TYPE_PROMPT = """
|
||||
Determine if the provided query is better suited for a keyword search or a semantic search.
|
||||
Respond with "keyword" or "semantic" literally and nothing else.
|
||||
Do not provide any additional text or reasoning to your response.
|
||||
|
||||
CRITICAL: It must only be 1 single word - EITHER "keyword" or "semantic".
|
||||
|
||||
The user query is:
|
||||
{user_query}
|
||||
""".strip()
|
||||
@@ -1,42 +0,0 @@
|
||||
# ruff: noqa: E501, W605 start
|
||||
SEARCH_CLASS = "search"
|
||||
CHAT_CLASS = "chat"
|
||||
|
||||
# Note that with many larger LLMs, the latency of running this prompt via third-party APIs is as high
# as 2 seconds, which is too slow for many use cases.
|
||||
SEARCH_CHAT_PROMPT = f"""
|
||||
Determine if the following query is better suited for a search UI or a chat UI. Respond with "{SEARCH_CLASS}" or "{CHAT_CLASS}" literally and nothing else. \
|
||||
Do not provide any additional text or reasoning to your response. CRITICAL, IT MUST ONLY BE 1 SINGLE WORD - EITHER "{SEARCH_CLASS}" or "{CHAT_CLASS}".
|
||||
|
||||
# Classification Guidelines:
|
||||
## {SEARCH_CLASS}
|
||||
- If the query consists entirely of keywords or the query doesn't require any answer from the AI
- If the query is a short statement that seems like a search query rather than a question
- If the query feels nonsensical or is a short phrase that possibly describes a document or information that could be found in an internal document
|
||||
|
||||
### Examples of {SEARCH_CLASS} queries:
|
||||
- Find me the document that goes over the onboarding process for a new hire
|
||||
- Pull requests since last week
|
||||
- Sales Runbook AMEA Region
|
||||
- Procurement process
|
||||
- Retrieve the PRD for project X
|
||||
|
||||
## {CHAT_CLASS}
|
||||
- If the query is asking a question that requires an answer rather than a document
|
||||
- If the query is asking for a solution, suggestion, or general help
|
||||
- If the query is seeking information that is on the web and likely not in a company internal document
|
||||
- If the query should be answered without any context from additional documents or searches
|
||||
|
||||
### Examples of {CHAT_CLASS} queries:
|
||||
- What led us to win the deal with company X? (seeking answer)
|
||||
- Google Drive not sync-ing files to my computer (seeking solution)
|
||||
- Review my email: <whatever the email is> (general help)
|
||||
- Write me a script to... (general help)
|
||||
- Cheap flights Europe to Tokyo (information likely found on the web, not internal)
|
||||
|
||||
# User Query:
|
||||
{{user_query}}
|
||||
|
||||
REMEMBER TO ONLY RESPOND WITH "{SEARCH_CLASS}" OR "{CHAT_CLASS}" AND NOTHING ELSE.
|
||||
""".strip()
|
||||
# ruff: noqa: E501, W605 end
|
||||
@@ -1,286 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.search import create_search_query
|
||||
from ee.onyx.secondary_llm_flows.query_expansion import expand_keywords
|
||||
from ee.onyx.server.query_and_chat.models import SearchDocWithContent
|
||||
from ee.onyx.server.query_and_chat.models import SearchFullResponse
|
||||
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
|
||||
from ee.onyx.server.query_and_chat.streaming_models import LLMSelectedDocsPacket
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchDocsPacket
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchQueriesPacket
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.pipeline import merge_individual_chunks
|
||||
from onyx.context.search.pipeline import search_pipeline
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.secondary_llm_flows.document_filter import select_sections_for_expansion
|
||||
from onyx.tools.tool_implementations.search.search_utils import (
|
||||
weighted_reciprocal_rank_fusion,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# This is just a heuristic that also happens to work well for the UI/UX
|
||||
# Users would not find it useful to see a huge list of suggested docs
|
||||
# but more than 1 is also likely good, as many questions may target more than one doc.
|
||||
TARGET_NUM_SECTIONS_FOR_LLM_SELECTION = 3
|
||||
|
||||
|
||||
def _run_single_search(
|
||||
query: str,
|
||||
filters: BaseFilters | None,
|
||||
document_index: DocumentIndex,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
num_hits: int | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Execute a single search query and return chunks."""
|
||||
chunk_search_request = ChunkSearchRequest(
|
||||
query=query,
|
||||
user_selected_filters=filters,
|
||||
limit=num_hits,
|
||||
)
|
||||
|
||||
return search_pipeline(
|
||||
chunk_search_request=chunk_search_request,
|
||||
document_index=document_index,
|
||||
user=user,
|
||||
persona=None, # No persona for direct search
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def stream_search_query(
|
||||
request: SendSearchQueryRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> Generator[
|
||||
SearchQueriesPacket | SearchDocsPacket | LLMSelectedDocsPacket | SearchErrorPacket,
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
Core search function that yields streaming packets.
|
||||
Used by both streaming and non-streaming endpoints.
|
||||
"""
|
||||
# Get document index
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
# This flow is for search so we do not get all indices.
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
|
||||
# Determine queries to execute
|
||||
original_query = request.search_query
|
||||
keyword_expansions: list[str] = []
|
||||
|
||||
if request.run_query_expansion:
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
keyword_expansions = expand_keywords(
|
||||
user_query=original_query,
|
||||
llm=llm,
|
||||
)
|
||||
if keyword_expansions:
|
||||
logger.debug(
|
||||
f"Query expansion generated {len(keyword_expansions)} keyword queries"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Query expansion failed: {e}; using original query only.")
|
||||
keyword_expansions = []
|
||||
|
||||
# Build list of all executed queries for tracking
|
||||
all_executed_queries = [original_query] + keyword_expansions
|
||||
|
||||
# TODO remove this check, user should not be None
|
||||
if user is not None:
|
||||
create_search_query(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
query=request.search_query,
|
||||
query_expansions=keyword_expansions if keyword_expansions else None,
|
||||
)
|
||||
|
||||
# Execute search(es)
|
||||
if not keyword_expansions:
|
||||
# Single query (original only) - no threading needed
|
||||
chunks = _run_single_search(
|
||||
query=original_query,
|
||||
filters=request.filters,
|
||||
document_index=document_index,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
num_hits=request.num_hits,
|
||||
)
|
||||
else:
|
||||
# Multiple queries - run in parallel and merge with RRF
|
||||
# First query is the original (semantic), rest are keyword expansions
|
||||
search_functions = [
|
||||
(
|
||||
_run_single_search,
|
||||
(
|
||||
query,
|
||||
request.filters,
|
||||
document_index,
|
||||
user,
|
||||
db_session,
|
||||
request.num_hits,
|
||||
),
|
||||
)
|
||||
for query in all_executed_queries
|
||||
]
|
||||
|
||||
# Run all searches in parallel
|
||||
all_search_results: list[list[InferenceChunk]] = (
|
||||
run_functions_tuples_in_parallel(
|
||||
search_functions,
|
||||
allow_failures=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Separate original query results from keyword expansion results
|
||||
# Note that in rare cases the original query may have failed, so we may just be
# overweighting one set of keyword results; this should not be a big deal though.
|
||||
original_result = all_search_results[0] if all_search_results else []
|
||||
keyword_results = all_search_results[1:] if len(all_search_results) > 1 else []
|
||||
|
||||
# Build valid results and weights
|
||||
# Original query (semantic): weight 2.0
|
||||
# Keyword expansions: weight 1.0 each
|
||||
valid_results: list[list[InferenceChunk]] = []
|
||||
weights: list[float] = []
|
||||
|
||||
if original_result:
|
||||
valid_results.append(original_result)
|
||||
weights.append(2.0)
|
||||
|
||||
for keyword_result in keyword_results:
|
||||
if keyword_result:
|
||||
valid_results.append(keyword_result)
|
||||
weights.append(1.0)
|
||||
|
||||
if not valid_results:
|
||||
logger.warning("All parallel searches returned empty results")
|
||||
chunks = []
|
||||
else:
|
||||
chunks = weighted_reciprocal_rank_fusion(
|
||||
ranked_results=valid_results,
|
||||
weights=weights,
|
||||
id_extractor=lambda chunk: f"{chunk.document_id}_{chunk.chunk_id}",
|
||||
)
|
||||
|
||||
# Merge chunks into sections
|
||||
sections = merge_individual_chunks(chunks)
|
||||
|
||||
# Truncate to the requested number of hits
|
||||
sections = sections[: request.num_hits]
|
||||
|
||||
# Apply LLM document selection if requested
|
||||
# num_docs_fed_to_llm_selection specifies how many sections to feed to the LLM for selection
|
||||
# The LLM will always try to select TARGET_NUM_SECTIONS_FOR_LLM_SELECTION sections from those fed to it
|
||||
# llm_selected_doc_ids will be:
|
||||
# - None if LLM selection was not requested or failed
|
||||
# - Empty list if LLM selection ran but selected nothing
|
||||
# - List of doc IDs if LLM selection succeeded
|
||||
run_llm_selection = (
|
||||
request.num_docs_fed_to_llm_selection is not None
|
||||
and request.num_docs_fed_to_llm_selection >= 1
|
||||
)
|
||||
llm_selected_doc_ids: list[str] | None = None
|
||||
llm_selection_failed = False
|
||||
if run_llm_selection and sections:
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
sections_to_evaluate = sections[: request.num_docs_fed_to_llm_selection]
|
||||
selected_sections, _ = select_sections_for_expansion(
|
||||
sections=sections_to_evaluate,
|
||||
user_query=original_query,
|
||||
llm=llm,
|
||||
max_sections=TARGET_NUM_SECTIONS_FOR_LLM_SELECTION,
|
||||
try_to_fill_to_max=True,
|
||||
)
|
||||
# Extract unique document IDs from selected sections (may be empty)
|
||||
llm_selected_doc_ids = list(
|
||||
dict.fromkeys(
|
||||
section.center_chunk.document_id for section in selected_sections
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f"LLM document selection evaluated {len(sections_to_evaluate)} sections, "
|
||||
f"selected {len(selected_sections)} sections with doc IDs: {llm_selected_doc_ids}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Catching a blanket exception here, as this step is not critical and the rest of the results are still valid
|
||||
logger.warning(f"LLM document selection failed: {e}")
|
||||
llm_selection_failed = True
|
||||
elif run_llm_selection and not sections:
|
||||
# LLM selection requested but no sections to evaluate
|
||||
llm_selected_doc_ids = []
|
||||
|
||||
# Convert to SearchDocWithContent list, optionally including content
|
||||
search_docs = SearchDocWithContent.from_inference_sections(
|
||||
sections,
|
||||
include_content=request.include_content,
|
||||
is_internet=False,
|
||||
)
|
||||
|
||||
# Yield queries packet
|
||||
yield SearchQueriesPacket(all_executed_queries=all_executed_queries)
|
||||
|
||||
# Yield docs packet
|
||||
yield SearchDocsPacket(search_docs=search_docs)
|
||||
|
||||
# Yield LLM selected docs packet if LLM selection was requested
|
||||
# - llm_selected_doc_ids is None if selection failed
|
||||
# - llm_selected_doc_ids is empty list if no docs were selected
|
||||
# - llm_selected_doc_ids is list of IDs if docs were selected
|
||||
if run_llm_selection:
|
||||
yield LLMSelectedDocsPacket(
|
||||
llm_selected_doc_ids=None if llm_selection_failed else llm_selected_doc_ids
|
||||
)
|
||||
|
||||
|
||||
def gather_search_stream(
|
||||
packets: Generator[
|
||||
SearchQueriesPacket
|
||||
| SearchDocsPacket
|
||||
| LLMSelectedDocsPacket
|
||||
| SearchErrorPacket,
|
||||
None,
|
||||
None,
|
||||
],
|
||||
) -> SearchFullResponse:
|
||||
"""
|
||||
Aggregate all streaming packets into SearchFullResponse.
|
||||
"""
|
||||
all_executed_queries: list[str] = []
|
||||
search_docs: list[SearchDocWithContent] = []
|
||||
llm_selected_doc_ids: list[str] | None = None
|
||||
error: str | None = None
|
||||
|
||||
for packet in packets:
|
||||
if isinstance(packet, SearchQueriesPacket):
|
||||
all_executed_queries = packet.all_executed_queries
|
||||
elif isinstance(packet, SearchDocsPacket):
|
||||
search_docs = packet.search_docs
|
||||
elif isinstance(packet, LLMSelectedDocsPacket):
|
||||
llm_selected_doc_ids = packet.llm_selected_doc_ids
|
||||
elif isinstance(packet, SearchErrorPacket):
|
||||
error = packet.error
|
||||
|
||||
return SearchFullResponse(
|
||||
all_executed_queries=all_executed_queries,
|
||||
search_docs=search_docs,
|
||||
doc_selection_reasoning=None,
|
||||
llm_selected_doc_ids=llm_selected_doc_ids,
|
||||
error=error,
|
||||
)
|
||||
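The merge step above weights the original (semantic) query at 2.0 and each keyword expansion at 1.0 before handing the ranked lists to `weighted_reciprocal_rank_fusion`. As a standalone illustration of the general technique (not the project's implementation), weighted RRF scores each item as a weighted sum of 1 / (k + rank) across the result lists:

from collections import defaultdict


def weighted_rrf_sketch(
    ranked_results: list[list[str]],  # each inner list holds doc IDs, best first
    weights: list[float],
    k: int = 60,  # common RRF damping constant
) -> list[str]:
    scores: dict[str, float] = defaultdict(float)
    for result_list, weight in zip(ranked_results, weights):
        for rank, doc_id in enumerate(result_list, start=1):
            scores[doc_id] += weight / (k + rank)
    return sorted(scores, key=scores.__getitem__, reverse=True)


# Original semantic results weighted 2.0, one keyword expansion weighted 1.0.
fused = weighted_rrf_sketch(
    ranked_results=[["doc_a", "doc_b", "doc_c"], ["doc_b", "doc_d"]],
    weights=[2.0, 1.0],
)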
@@ -1,92 +0,0 @@
|
||||
import re
|
||||
|
||||
from ee.onyx.prompts.query_expansion import KEYWORD_EXPANSION_PROMPT
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import LanguageModelInput
|
||||
from onyx.llm.models import ReasoningEffort
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Pattern to remove common LLM artifacts: brackets, quotes, list markers, etc.
|
||||
CLEANUP_PATTERN = re.compile(r'[\[\]"\'`]')
|
||||
|
||||
|
||||
def _clean_keyword_line(line: str) -> str:
|
||||
"""Clean a keyword line by removing common LLM artifacts.
|
||||
|
||||
Removes brackets, quotes, and other characters that LLMs may accidentally
|
||||
include in their output.
|
||||
"""
|
||||
# Remove common artifacts
|
||||
cleaned = CLEANUP_PATTERN.sub("", line)
|
||||
# Remove leading list markers like "1.", "2.", "-", "*"
|
||||
cleaned = re.sub(r"^\s*(?:\d+[\.\)]\s*|[-*]\s*)", "", cleaned)
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
def expand_keywords(
|
||||
user_query: str,
|
||||
llm: LLM,
|
||||
) -> list[str]:
|
||||
"""Expand a user query into multiple keyword-only queries for BM25 search.
|
||||
|
||||
Uses an LLM to generate keyword-based search queries that capture different
|
||||
aspects of the user's search intent. Returns only the expanded queries,
|
||||
not the original query.
|
||||
|
||||
Args:
|
||||
user_query: The original search query from the user
|
||||
llm: Language model to use for keyword expansion
|
||||
|
||||
Returns:
|
||||
List of expanded keyword queries (excluding the original query).
|
||||
Returns empty list if expansion fails or produces no useful expansions.
|
||||
"""
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content=KEYWORD_EXPANSION_PROMPT.format(user_query=user_query))
|
||||
]
|
||||
|
||||
try:
|
||||
response = llm.invoke(
|
||||
prompt=messages,
|
||||
reasoning_effort=ReasoningEffort.OFF,
|
||||
# Limit output - we only expect a few short keyword queries
|
||||
max_tokens=150,
|
||||
)
|
||||
|
||||
content = llm_response_to_string(response).strip()
|
||||
|
||||
if not content:
|
||||
logger.warning("Keyword expansion returned empty response.")
|
||||
return []
|
||||
|
||||
# Parse response - each line is a separate keyword query
|
||||
# Clean each line to remove LLM artifacts and drop empty lines
|
||||
parsed_queries = []
|
||||
for line in content.strip().split("\n"):
|
||||
cleaned = _clean_keyword_line(line)
|
||||
if cleaned:
|
||||
parsed_queries.append(cleaned)
|
||||
|
||||
if not parsed_queries:
|
||||
logger.warning("Keyword expansion parsing returned no queries.")
|
||||
return []
|
||||
|
||||
# Filter out duplicates and queries that match the original
|
||||
expanded_queries: list[str] = []
|
||||
seen_lower: set[str] = {user_query.lower()}
|
||||
for query in parsed_queries:
|
||||
query_lower = query.lower()
|
||||
if query_lower not in seen_lower:
|
||||
seen_lower.add(query_lower)
|
||||
expanded_queries.append(query)
|
||||
|
||||
logger.debug(f"Keyword expansion generated {len(expanded_queries)} queries")
|
||||
return expanded_queries
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Keyword expansion failed: {e}")
|
||||
return []
|
||||
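Given the two regexes above, `_clean_keyword_line` strips the quoting and list markers that models often emit around keyword lines; for example:

assert _clean_keyword_line('1. "onboarding checklist"') == "onboarding checklist"
assert _clean_keyword_line("- [vacation policy]") == "vacation policy"
assert _clean_keyword_line("   ") == ""  # blank results are then dropped by expand_keywords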
@@ -1,50 +0,0 @@
|
||||
from ee.onyx.prompts.search_flow_classification import CHAT_CLASS
|
||||
from ee.onyx.prompts.search_flow_classification import SEARCH_CHAT_PROMPT
|
||||
from ee.onyx.prompts.search_flow_classification import SEARCH_CLASS
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import LanguageModelInput
|
||||
from onyx.llm.models import ReasoningEffort
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def classify_is_search_flow(
|
||||
query: str,
|
||||
llm: LLM,
|
||||
) -> bool:
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content=SEARCH_CHAT_PROMPT.format(user_query=query))
|
||||
]
|
||||
response = llm.invoke(
|
||||
prompt=messages,
|
||||
reasoning_effort=ReasoningEffort.OFF,
|
||||
# Nothing can happen in the UI until this call finishes so we need to be aggressive with the timeout
|
||||
timeout_override=2,
|
||||
# Well more than necessary, but this ensures completion in case the model classifies correctly
# and then ends up rambling
|
||||
max_tokens=20,
|
||||
)
|
||||
|
||||
content = llm_response_to_string(response).strip().lower()
|
||||
if not content:
|
||||
logger.warning(
|
||||
"Search flow classification returned empty response; defaulting to chat flow."
|
||||
)
|
||||
return False
|
||||
|
||||
# Prefer chat if both appear.
|
||||
if CHAT_CLASS in content:
|
||||
return False
|
||||
if SEARCH_CLASS in content:
|
||||
return True
|
||||
|
||||
logger.warning(
|
||||
"Search flow classification returned unexpected response; defaulting to chat flow. Response=%r",
|
||||
content,
|
||||
)
|
||||
return False
|
||||
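A hypothetical call site for the classifier above, routing between the two UIs (the handler names are placeholders; the fallback-to-chat behavior on empty or unexpected output already lives inside the function):

llm = get_default_llm()  # assumed available to the caller, as elsewhere in this code

if classify_is_search_flow(query=user_query, llm=llm):
    render_search_ui(user_query)  # hypothetical handler
else:
    render_chat_ui(user_query)    # hypothetical handler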
@@ -19,9 +19,9 @@ from ee.onyx.db.analytics import fetch_query_analytics
|
||||
from ee.onyx.db.analytics import user_can_view_assistant_stats
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
|
||||
router = APIRouter(prefix="/analytics", tags=PUBLIC_API_TAGS)
|
||||
|
||||
|
||||
@@ -10,16 +10,6 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
|
||||
("/enterprise-settings/logo", {"GET"}),
|
||||
("/enterprise-settings/logotype", {"GET"}),
|
||||
("/enterprise-settings/custom-analytics-script", {"GET"}),
|
||||
# Stripe publishable key is safe to expose publicly
|
||||
("/tenants/stripe-publishable-key", {"GET"}),
|
||||
("/admin/billing/stripe-publishable-key", {"GET"}),
|
||||
# Proxy endpoints use license-based auth, not user auth
|
||||
("/proxy/create-checkout-session", {"POST"}),
|
||||
("/proxy/claim-license", {"POST"}),
|
||||
("/proxy/create-customer-portal-session", {"POST"}),
|
||||
("/proxy/billing-information", {"GET"}),
|
||||
("/proxy/license/{tenant_id}", {"GET"}),
|
||||
("/proxy/seats/update", {"POST"}),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,264 +0,0 @@
|
||||
"""Unified Billing API endpoints.
|
||||
|
||||
These endpoints provide Stripe billing functionality for both cloud and
|
||||
self-hosted deployments. The service layer routes requests appropriately:
|
||||
|
||||
- Self-hosted: Routes through cloud data plane proxy
|
||||
Flow: Backend /admin/billing/* → Cloud DP /proxy/* → Control plane
|
||||
|
||||
- Cloud (MULTI_TENANT): Routes directly to control plane
|
||||
Flow: Backend /admin/billing/* → Control plane
|
||||
|
||||
License claiming is handled separately by /license/claim endpoint (self-hosted only).
|
||||
|
||||
Migration Note (ENG-3533):
|
||||
This /admin/billing/* API replaces the older /tenants/* billing endpoints:
|
||||
- /tenants/billing-information -> /admin/billing/billing-information
|
||||
- /tenants/create-customer-portal-session -> /admin/billing/create-customer-portal-session
|
||||
- /tenants/create-subscription-session -> /admin/billing/create-checkout-session
|
||||
- /tenants/stripe-publishable-key -> /admin/billing/stripe-publishable-key
|
||||
|
||||
See: https://linear.app/onyx-app/issue/ENG-3533/migrate-tenantsbilling-adminbilling
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.db.license import get_license
|
||||
from ee.onyx.server.billing.models import BillingInformationResponse
|
||||
from ee.onyx.server.billing.models import CreateCheckoutSessionRequest
|
||||
from ee.onyx.server.billing.models import CreateCheckoutSessionResponse
|
||||
from ee.onyx.server.billing.models import CreateCustomerPortalSessionRequest
|
||||
from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse
|
||||
from ee.onyx.server.billing.models import SeatUpdateRequest
|
||||
from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.billing.models import StripePublishableKeyResponse
|
||||
from ee.onyx.server.billing.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.billing.service import BillingServiceError
|
||||
from ee.onyx.server.billing.service import (
|
||||
create_checkout_session as create_checkout_service,
|
||||
)
|
||||
from ee.onyx.server.billing.service import (
|
||||
create_customer_portal_session as create_portal_service,
|
||||
)
|
||||
from ee.onyx.server.billing.service import (
|
||||
get_billing_information as get_billing_service,
|
||||
)
|
||||
from ee.onyx.server.billing.service import update_seat_count as update_seat_service
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/admin/billing")
|
||||
|
||||
# Cache for Stripe publishable key to avoid hitting S3 on every request
|
||||
_stripe_publishable_key_cache: str | None = None
|
||||
_stripe_key_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def _get_license_data(db_session: Session) -> str | None:
|
||||
"""Get license data from database if exists (self-hosted only)."""
|
||||
if MULTI_TENANT:
|
||||
return None
|
||||
license_record = get_license(db_session)
|
||||
return license_record.license_data if license_record else None
|
||||
|
||||
|
||||
def _get_tenant_id() -> str | None:
|
||||
"""Get tenant ID for cloud deployments."""
|
||||
if MULTI_TENANT:
|
||||
return get_current_tenant_id()
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/create-checkout-session")
|
||||
async def create_checkout_session(
|
||||
request: CreateCheckoutSessionRequest | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CreateCheckoutSessionResponse:
|
||||
"""Create a Stripe checkout session for new subscription or renewal.
|
||||
|
||||
For new customers, no license/tenant is required.
|
||||
For renewals, existing license (self-hosted) or tenant_id (cloud) is used.
|
||||
|
||||
After checkout completion:
|
||||
- Self-hosted: Use /license/claim to retrieve the license
|
||||
- Cloud: Subscription is automatically activated
|
||||
"""
|
||||
license_data = _get_license_data(db_session)
|
||||
tenant_id = _get_tenant_id()
|
||||
billing_period = request.billing_period if request else "monthly"
|
||||
email = request.email if request else None
|
||||
|
||||
# Build redirect URL for after checkout completion
|
||||
redirect_url = f"{WEB_DOMAIN}/admin/billing?checkout=success"
|
||||
|
||||
try:
|
||||
return await create_checkout_service(
|
||||
billing_period=billing_period,
|
||||
email=email,
|
||||
license_data=license_data,
|
||||
redirect_url=redirect_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(
|
||||
request: CreateCustomerPortalSessionRequest | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CreateCustomerPortalSessionResponse:
|
||||
"""Create a Stripe customer portal session for managing subscription.
|
||||
|
||||
Requires existing license (self-hosted) or active tenant (cloud).
|
||||
"""
|
||||
license_data = _get_license_data(db_session)
|
||||
tenant_id = _get_tenant_id()
|
||||
|
||||
# Self-hosted requires license
|
||||
if not MULTI_TENANT and not license_data:
|
||||
raise HTTPException(status_code=400, detail="No license found")
|
||||
|
||||
return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing"
|
||||
|
||||
try:
|
||||
return await create_portal_service(
|
||||
license_data=license_data,
|
||||
return_url=return_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def get_billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> BillingInformationResponse | SubscriptionStatusResponse:
|
||||
"""Get billing information for the current subscription.
|
||||
|
||||
Returns subscription status and details from Stripe.
|
||||
"""
|
||||
license_data = _get_license_data(db_session)
|
||||
tenant_id = _get_tenant_id()
|
||||
|
||||
# Self-hosted without license = no subscription
|
||||
if not MULTI_TENANT and not license_data:
|
||||
return SubscriptionStatusResponse(subscribed=False)
|
||||
|
||||
try:
|
||||
return await get_billing_service(
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.post("/seats/update")
|
||||
async def update_seats(
|
||||
request: SeatUpdateRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SeatUpdateResponse:
|
||||
"""Update the seat count for the current subscription.
|
||||
|
||||
Handles Stripe proration and license regeneration via control plane.
|
||||
"""
|
||||
license_data = _get_license_data(db_session)
|
||||
tenant_id = _get_tenant_id()
|
||||
|
||||
# Self-hosted requires license
|
||||
if not MULTI_TENANT and not license_data:
|
||||
raise HTTPException(status_code=400, detail="No license found")
|
||||
|
||||
try:
|
||||
return await update_seat_service(
|
||||
new_seat_count=request.new_seat_count,
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
"""Fetch the Stripe publishable key.
|
||||
|
||||
Priority: env var override (for testing) > S3 bucket (production).
|
||||
This endpoint is public (no auth required) since publishable keys are safe to expose.
|
||||
The key is cached in memory to avoid hitting S3 on every request.
|
||||
"""
|
||||
global _stripe_publishable_key_cache
|
||||
|
||||
# Fast path: return cached value without lock
|
||||
if _stripe_publishable_key_cache:
|
||||
return StripePublishableKeyResponse(
|
||||
publishable_key=_stripe_publishable_key_cache
|
||||
)
|
||||
|
||||
# Use lock to prevent concurrent S3 requests
|
||||
async with _stripe_key_lock:
|
||||
# Double-check after acquiring lock (another request may have populated cache)
|
||||
if _stripe_publishable_key_cache:
|
||||
return StripePublishableKeyResponse(
|
||||
publishable_key=_stripe_publishable_key_cache
|
||||
)
|
||||
|
||||
# Check for env var override first (for local testing with pk_test_* keys)
|
||||
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(STRIPE_PUBLISHABLE_KEY_URL)
|
||||
response.raise_for_status()
|
||||
key = response.text.strip()
|
||||
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
)
|
||||
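The publishable-key endpoint above uses an async double-checked cache: an unlocked fast path, then a re-check after acquiring the lock so that concurrent cold-cache requests trigger at most one upstream fetch. The same pattern in isolation (generic sketch, unrelated to the Stripe specifics):

import asyncio

_cache: str | None = None
_lock = asyncio.Lock()


async def get_cached_value(fetch) -> str:  # `fetch` is any async callable returning str
    global _cache
    if _cache is not None:       # fast path: no lock needed once populated
        return _cache
    async with _lock:
        if _cache is not None:   # re-check: another task may have filled it meanwhile
            return _cache
        _cache = await fetch()   # only one task performs the slow fetch
        return _cache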
@@ -1,75 +0,0 @@
|
||||
"""Pydantic models for the billing API."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CreateCheckoutSessionRequest(BaseModel):
|
||||
"""Request to create a Stripe checkout session."""
|
||||
|
||||
billing_period: Literal["monthly", "annual"] = "monthly"
|
||||
email: str | None = None
|
||||
|
||||
|
||||
class CreateCheckoutSessionResponse(BaseModel):
|
||||
"""Response containing the Stripe checkout session URL."""
|
||||
|
||||
stripe_checkout_url: str
|
||||
|
||||
|
||||
class CreateCustomerPortalSessionRequest(BaseModel):
|
||||
"""Request to create a Stripe customer portal session."""
|
||||
|
||||
return_url: str | None = None
|
||||
|
||||
|
||||
class CreateCustomerPortalSessionResponse(BaseModel):
|
||||
"""Response containing the Stripe customer portal URL."""
|
||||
|
||||
stripe_customer_portal_url: str
|
||||
|
||||
|
||||
class BillingInformationResponse(BaseModel):
|
||||
"""Billing information for the current subscription."""
|
||||
|
||||
tenant_id: str
|
||||
status: str | None = None
|
||||
plan_type: str | None = None
|
||||
seats: int | None = None
|
||||
billing_period: str | None = None
|
||||
current_period_start: datetime | None = None
|
||||
current_period_end: datetime | None = None
|
||||
cancel_at_period_end: bool = False
|
||||
canceled_at: datetime | None = None
|
||||
trial_start: datetime | None = None
|
||||
trial_end: datetime | None = None
|
||||
payment_method_enabled: bool = False
|
||||
|
||||
|
||||
class SubscriptionStatusResponse(BaseModel):
|
||||
"""Response when no subscription exists."""
|
||||
|
||||
subscribed: bool = False
|
||||
|
||||
|
||||
class SeatUpdateRequest(BaseModel):
|
||||
"""Request to update seat count."""
|
||||
|
||||
new_seat_count: int
|
||||
|
||||
|
||||
class SeatUpdateResponse(BaseModel):
|
||||
"""Response from seat update operation."""
|
||||
|
||||
success: bool
|
||||
current_seats: int
|
||||
used_seats: int
|
||||
message: str | None = None
|
||||
|
||||
|
||||
class StripePublishableKeyResponse(BaseModel):
|
||||
"""Response containing the Stripe publishable key."""
|
||||
|
||||
publishable_key: str
|
||||
@@ -1,267 +0,0 @@
"""Service layer for billing operations.

This module provides functions for billing operations that route differently
based on deployment type:

- Self-hosted (not MULTI_TENANT): Routes through cloud data plane proxy
  Flow: Self-hosted backend → Cloud DP /proxy/* → Control plane

- Cloud (MULTI_TENANT): Routes directly to control plane
  Flow: Cloud backend → Control plane
"""

from typing import Literal

import httpx

from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL
from ee.onyx.server.billing.models import BillingInformationResponse
from ee.onyx.server.billing.models import CreateCheckoutSessionResponse
from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse
from ee.onyx.server.billing.models import SeatUpdateResponse
from ee.onyx.server.billing.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.access import generate_data_plane_token
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()

# HTTP request timeout for billing service calls
_REQUEST_TIMEOUT = 30.0


class BillingServiceError(Exception):
    """Exception raised for billing service errors."""

    def __init__(self, message: str, status_code: int = 500):
        self.message = message
        self.status_code = status_code
        super().__init__(self.message)

def _get_proxy_headers(license_data: str | None) -> dict[str, str]:
|
||||
"""Build headers for proxy requests (self-hosted).
|
||||
|
||||
Self-hosted instances authenticate with their license.
|
||||
"""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if license_data:
|
||||
headers["Authorization"] = f"Bearer {license_data}"
|
||||
return headers
|
||||
|
||||
|
||||
def _get_direct_headers() -> dict[str, str]:
|
||||
"""Build headers for direct control plane requests (cloud).
|
||||
|
||||
Cloud instances authenticate with JWT.
|
||||
"""
|
||||
token = generate_data_plane_token()
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {token}",
|
||||
}
|
||||
|
||||
|
||||
def _get_base_url() -> str:
|
||||
"""Get the base URL based on deployment type."""
|
||||
if MULTI_TENANT:
|
||||
return CONTROL_PLANE_API_BASE_URL
|
||||
return f"{CLOUD_DATA_PLANE_URL}/proxy"
|
||||
|
||||
|
||||
def _get_headers(license_data: str | None) -> dict[str, str]:
|
||||
"""Get appropriate headers based on deployment type."""
|
||||
if MULTI_TENANT:
|
||||
return _get_direct_headers()
|
||||
return _get_proxy_headers(license_data)
|
||||
|
||||
|
||||
async def _make_billing_request(
|
||||
method: Literal["GET", "POST"],
|
||||
path: str,
|
||||
license_data: str | None = None,
|
||||
body: dict | None = None,
|
||||
params: dict | None = None,
|
||||
error_message: str = "Billing service request failed",
|
||||
) -> dict:
|
||||
"""Make an HTTP request to the billing service.
|
||||
|
||||
Consolidates the common HTTP request pattern used by all billing operations.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET or POST)
|
||||
path: URL path (appended to base URL)
|
||||
license_data: License for authentication (self-hosted)
|
||||
body: Request body for POST requests
|
||||
params: Query parameters for GET requests
|
||||
error_message: Default error message if request fails
|
||||
|
||||
Returns:
|
||||
Response JSON as dict
|
||||
|
||||
Raises:
|
||||
BillingServiceError: If request fails
|
||||
"""
|
||||
base_url = _get_base_url()
|
||||
url = f"{base_url}{path}"
|
||||
headers = _get_headers(license_data)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT) as client:
|
||||
if method == "GET":
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
else:
|
||||
response = await client.post(url, headers=headers, json=body)
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
detail = error_message
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
detail = error_data.get("detail", detail)
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(f"{error_message}: {e.response.status_code} - {detail}")
|
||||
raise BillingServiceError(detail, e.response.status_code)
|
||||
|
||||
except httpx.RequestError:
|
||||
logger.exception("Failed to connect to billing service")
|
||||
raise BillingServiceError("Failed to connect to billing service", 502)
|
||||
|
||||
|
||||
async def create_checkout_session(
|
||||
billing_period: str = "monthly",
|
||||
email: str | None = None,
|
||||
license_data: str | None = None,
|
||||
redirect_url: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> CreateCheckoutSessionResponse:
|
||||
"""Create a Stripe checkout session.
|
||||
|
||||
Args:
|
||||
billing_period: "monthly" or "annual"
|
||||
email: Customer email for new subscriptions
|
||||
license_data: Existing license for renewals (self-hosted)
|
||||
redirect_url: URL to redirect after successful checkout
|
||||
tenant_id: Tenant ID (cloud only, for renewals)
|
||||
|
||||
Returns:
|
||||
CreateCheckoutSessionResponse with checkout URL
|
||||
"""
|
||||
body: dict = {"billing_period": billing_period}
|
||||
if email:
|
||||
body["email"] = email
|
||||
if redirect_url:
|
||||
body["redirect_url"] = redirect_url
|
||||
if tenant_id and MULTI_TENANT:
|
||||
body["tenant_id"] = tenant_id
|
||||
|
||||
data = await _make_billing_request(
|
||||
method="POST",
|
||||
path="/create-checkout-session",
|
||||
license_data=license_data,
|
||||
body=body,
|
||||
error_message="Failed to create checkout session",
|
||||
)
|
||||
return CreateCheckoutSessionResponse(stripe_checkout_url=data["url"])
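As a rough illustration of how a route handler might have consumed this helper, the sketch below wraps it in a FastAPI endpoint; the router prefix, route path, and error mapping are assumptions for illustration only, and the code is assumed to live alongside the functions above.

from fastapi import APIRouter
from fastapi import HTTPException

router = APIRouter(prefix="/billing")


@router.post("/checkout")
async def start_checkout(billing_period: str = "monthly") -> dict[str, str]:
    # Hypothetical wrapper: call the service layer and surface failures as HTTP errors.
    try:
        result = await create_checkout_session(billing_period=billing_period)
    except BillingServiceError as e:
        raise HTTPException(status_code=e.status_code, detail=e.message)
    return {"url": result.stripe_checkout_url}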
|
||||
|
||||
|
||||
async def create_customer_portal_session(
|
||||
license_data: str | None = None,
|
||||
return_url: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> CreateCustomerPortalSessionResponse:
|
||||
"""Create a Stripe customer portal session.
|
||||
|
||||
Args:
|
||||
license_data: License blob for authentication (self-hosted)
|
||||
return_url: URL to return to after portal session
|
||||
tenant_id: Tenant ID (cloud only)
|
||||
|
||||
Returns:
|
||||
CreateCustomerPortalSessionResponse with portal URL
|
||||
"""
|
||||
body: dict = {}
|
||||
if return_url:
|
||||
body["return_url"] = return_url
|
||||
if tenant_id and MULTI_TENANT:
|
||||
body["tenant_id"] = tenant_id
|
||||
|
||||
data = await _make_billing_request(
|
||||
method="POST",
|
||||
path="/create-customer-portal-session",
|
||||
license_data=license_data,
|
||||
body=body,
|
||||
error_message="Failed to create customer portal session",
|
||||
)
|
||||
return CreateCustomerPortalSessionResponse(stripe_customer_portal_url=data["url"])
|
||||
|
||||
|
||||
async def get_billing_information(
|
||||
license_data: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> BillingInformationResponse | SubscriptionStatusResponse:
|
||||
"""Fetch billing information.
|
||||
|
||||
Args:
|
||||
license_data: License blob for authentication (self-hosted)
|
||||
tenant_id: Tenant ID (cloud only)
|
||||
|
||||
Returns:
|
||||
BillingInformationResponse or SubscriptionStatusResponse if no subscription
|
||||
"""
|
||||
params = {}
|
||||
if tenant_id and MULTI_TENANT:
|
||||
params["tenant_id"] = tenant_id
|
||||
|
||||
data = await _make_billing_request(
|
||||
method="GET",
|
||||
path="/billing-information",
|
||||
license_data=license_data,
|
||||
params=params or None,
|
||||
error_message="Failed to fetch billing information",
|
||||
)
|
||||
|
||||
# Check if no subscription
|
||||
if isinstance(data, dict) and data.get("subscribed") is False:
|
||||
return SubscriptionStatusResponse(subscribed=False)
|
||||
|
||||
return BillingInformationResponse(**data)
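Because this function returns a union type, callers have to branch on it; a minimal sketch follows (the helper name and display strings are made up for illustration, and it is assumed to sit next to the functions above so the imports at the top of this module apply).

async def describe_subscription(license_blob: str | None, tenant: str | None) -> str:
    # Hypothetical helper: summarize the union return type for display or logging.
    info = await get_billing_information(license_data=license_blob, tenant_id=tenant)
    if isinstance(info, SubscriptionStatusResponse):
        return "no active subscription"
    return f"{info.seats} seats, current period ends {info.current_period_end}"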
|
||||
|
||||
|
||||
async def update_seat_count(
|
||||
new_seat_count: int,
|
||||
license_data: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> SeatUpdateResponse:
|
||||
"""Update the seat count for the current subscription.
|
||||
|
||||
Args:
|
||||
new_seat_count: New number of seats
|
||||
license_data: License blob for authentication (self-hosted)
|
||||
tenant_id: Tenant ID (cloud only)
|
||||
|
||||
Returns:
|
||||
SeatUpdateResponse with updated seat information
|
||||
"""
|
||||
body: dict = {"new_seat_count": new_seat_count}
|
||||
if tenant_id and MULTI_TENANT:
|
||||
body["tenant_id"] = tenant_id
|
||||
|
||||
data = await _make_billing_request(
|
||||
method="POST",
|
||||
path="/seats/update",
|
||||
license_data=license_data,
|
||||
body=body,
|
||||
error_message="Failed to update seat count",
|
||||
)
|
||||
|
||||
return SeatUpdateResponse(
|
||||
success=data.get("success", False),
|
||||
current_seats=data.get("current_seats", 0),
|
||||
used_seats=data.get("used_seats", 0),
|
||||
message=data.get("message"),
|
||||
)
|
||||
@@ -1,14 +1,4 @@
|
||||
"""License API endpoints for self-hosted deployments.
|
||||
|
||||
These endpoints allow self-hosted Onyx instances to:
|
||||
1. Claim a license after Stripe checkout (via cloud data plane proxy)
|
||||
2. Upload a license file manually (for air-gapped deployments)
|
||||
3. View license status and seat usage
|
||||
4. Refresh/delete the local license
|
||||
|
||||
NOTE: Cloud (MULTI_TENANT) deployments do NOT use these endpoints.
|
||||
Cloud licensing is managed via the control plane and gated_tenants Redis key.
|
||||
"""
|
||||
"""License API endpoints."""
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter
|
||||
@@ -19,7 +9,6 @@ from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL
|
||||
from ee.onyx.db.license import delete_license as db_delete_license
|
||||
from ee.onyx.db.license import get_license_metadata
|
||||
from ee.onyx.db.license import invalidate_license_cache
|
||||
@@ -31,11 +20,13 @@ from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.license.models import LicenseStatusResponse
|
||||
from ee.onyx.server.license.models import LicenseUploadResponse
|
||||
from ee.onyx.server.license.models import SeatUsageResponse
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -88,80 +79,81 @@ async def get_seat_usage(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/claim")
|
||||
async def claim_license(
|
||||
session_id: str,
|
||||
@router.post("/fetch")
|
||||
async def fetch_license(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseResponse:
|
||||
"""
|
||||
Claim a license after Stripe checkout (self-hosted only).
|
||||
|
||||
After a user completes Stripe checkout, they're redirected back with a
|
||||
session_id. This endpoint exchanges that session_id for a signed license
|
||||
via the cloud data plane proxy.
|
||||
|
||||
Flow:
|
||||
1. Self-hosted frontend redirects to Stripe checkout (via cloud proxy)
|
||||
2. User completes payment
|
||||
3. Stripe redirects back to self-hosted instance with session_id
|
||||
4. Frontend calls this endpoint with session_id
|
||||
5. We call cloud data plane /proxy/claim-license to get the signed license
|
||||
6. License is stored locally and cached
|
||||
Fetch license from control plane.
|
||||
Used after Stripe checkout completion to retrieve the new license.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
try:
|
||||
token = generate_data_plane_token()
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to generate data plane token: {e}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License claiming is only available for self-hosted deployments",
|
||||
status_code=500, detail="Authentication configuration error"
|
||||
)
|
||||
|
||||
try:
|
||||
# Call cloud data plane to claim the license
|
||||
url = f"{CLOUD_DATA_PLANE_URL}/proxy/claim-license"
|
||||
response = requests.post(
|
||||
url,
|
||||
json={"session_id": session_id},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=30,
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/license/{tenant_id}"
|
||||
response = requests.get(url, headers=headers, timeout=10)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
license_data = data.get("license")
|
||||
if not isinstance(data, dict) or "license" not in data:
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Invalid response from control plane"
|
||||
)
|
||||
|
||||
license_data = data["license"]
|
||||
if not license_data:
|
||||
raise HTTPException(status_code=404, detail="No license in response")
|
||||
raise HTTPException(status_code=404, detail="No license found")
|
||||
|
||||
# Verify signature before persisting
|
||||
payload = verify_license_signature(license_data)
|
||||
|
||||
# Store in DB
|
||||
upsert_license(db_session, license_data)
|
||||
# Verify the fetched license is for this tenant
|
||||
if payload.tenant_id != tenant_id:
|
||||
logger.error(
|
||||
f"License tenant mismatch: expected {tenant_id}, got {payload.tenant_id}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License tenant ID mismatch - control plane returned wrong license",
|
||||
)
|
||||
|
||||
# Persist to DB and update cache atomically
|
||||
upsert_license(db_session, license_data)
|
||||
try:
|
||||
update_license_cache(payload, source=LicenseSource.AUTO_FETCH)
|
||||
except Exception as cache_error:
|
||||
# Log but don't fail - DB is source of truth, cache will refresh on next read
|
||||
logger.warning(f"Failed to update license cache: {cache_error}")
|
||||
|
||||
logger.info(
|
||||
f"License claimed: seats={payload.seats}, expires={payload.expires_at.date()}"
|
||||
)
|
||||
return LicenseResponse(success=True, license=payload)
|
||||
|
||||
except requests.HTTPError as e:
|
||||
status_code = e.response.status_code if e.response is not None else 502
|
||||
detail = "Failed to claim license"
|
||||
try:
|
||||
error_data = e.response.json() if e.response is not None else {}
|
||||
detail = error_data.get("detail", detail)
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
logger.error(f"Control plane returned error: {status_code}")
|
||||
raise HTTPException(
|
||||
status_code=status_code,
|
||||
detail="Failed to fetch license from control plane",
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"License verification failed: {type(e).__name__}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except requests.RequestException:
|
||||
logger.exception("Failed to fetch license from control plane")
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Failed to connect to license server"
|
||||
status_code=502, detail="Failed to connect to control plane"
|
||||
)
|
||||
|
||||
|
||||
@@ -172,36 +164,33 @@ async def upload_license(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseUploadResponse:
|
||||
"""
|
||||
Upload a license file manually (self-hosted only).
|
||||
|
||||
Used for air-gapped deployments where the cloud data plane is not accessible.
|
||||
The license file must be cryptographically signed by Onyx.
|
||||
Upload a license file manually.
|
||||
Used for air-gapped deployments where control plane is not accessible.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License upload is only available for self-hosted deployments",
|
||||
)
|
||||
|
||||
try:
|
||||
content = await license_file.read()
|
||||
license_data = content.decode("utf-8").strip()
|
||||
except UnicodeDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Invalid license file format")
|
||||
|
||||
# Verify cryptographic signature - this is the only validation needed
|
||||
# The license's tenant_id identifies the customer in the control plane, not locally
|
||||
try:
|
||||
payload = verify_license_signature(license_data)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
if payload.tenant_id != tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"License tenant ID mismatch. Expected {tenant_id}, got {payload.tenant_id}",
|
||||
)
|
||||
|
||||
# Persist to DB and update cache
|
||||
upsert_license(db_session, license_data)
|
||||
|
||||
try:
|
||||
update_license_cache(payload, source=LicenseSource.MANUAL_UPLOAD)
|
||||
except Exception as cache_error:
|
||||
# Log but don't fail - DB is source of truth, cache will refresh on next read
|
||||
logger.warning(f"Failed to update license cache: {cache_error}")
|
||||
|
||||
return LicenseUploadResponse(
|
||||
@@ -216,10 +205,8 @@ async def refresh_license_cache_endpoint(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseStatusResponse:
|
||||
"""
|
||||
Force refresh the license cache from the local database.
|
||||
|
||||
Force refresh the license cache from the database.
|
||||
Useful after manual database changes or to verify license validity.
|
||||
Does NOT fetch from control plane - use /claim for that.
|
||||
"""
|
||||
metadata = refresh_license_cache(db_session)
|
||||
|
||||
@@ -246,15 +233,9 @@ async def delete_license(
|
||||
) -> dict[str, bool]:
|
||||
"""
|
||||
Delete the current license.
|
||||
|
||||
Admin only - removes license from database and invalidates cache.
|
||||
Admin only - removes license and invalidates cache.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License deletion is only available for self-hosted deployments",
|
||||
)
|
||||
|
||||
# Invalidate cache first - if DB delete fails, stale cache is worse than no cache
|
||||
try:
|
||||
invalidate_license_cache()
|
||||
except Exception as cache_error:
|
||||
|
||||
@@ -1,187 +0,0 @@
"""Middleware to enforce license status for SELF-HOSTED deployments only.

NOTE: This middleware is NOT used for multi-tenant (cloud) deployments.
Multi-tenant gating is handled separately by the control plane via the
/tenants/product-gating endpoint and is_tenant_gated() checks.

IMPORTANT: Interaction with ENTERPRISE_EDITION_ENABLED
=======================================================
This middleware is controlled by the LICENSE_ENFORCEMENT_ENABLED env var.
It works alongside the legacy ENTERPRISE_EDITION_ENABLED system:

- LICENSE_ENFORCEMENT_ENABLED=false (default):
  Middleware is disabled. EE features are controlled solely by
  ENTERPRISE_EDITION_ENABLED. This preserves legacy behavior.

- LICENSE_ENFORCEMENT_ENABLED=true:
  Middleware actively enforces license status. EE features require
  a valid license, regardless of ENTERPRISE_EDITION_ENABLED.

Eventually, ENTERPRISE_EDITION_ENABLED will be removed and license
enforcement will be the only mechanism for gating EE features.

License Enforcement States (when enabled)
=========================================
For self-hosted deployments:

1. No license (never subscribed):
   - Allow community features (basic connectors, search, chat)
   - Block EE-only features (analytics, user groups, etc.)

2. GATED_ACCESS (fully expired):
   - Block all routes except billing/auth/license
   - User must renew subscription to continue

3. Valid license (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER):
   - Full access to all EE features
   - Seat limits enforced
   - GRACE_PERIOD/PAYMENT_REMINDER are for notifications only, not blocking
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.configs.license_enforcement_config import EE_ONLY_PATH_PREFIXES
|
||||
from ee.onyx.configs.license_enforcement_config import (
|
||||
LICENSE_ENFORCEMENT_ALLOWED_PREFIXES,
|
||||
)
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
def _is_path_allowed(path: str) -> bool:
|
||||
"""Check if path is in allowlist (prefix match)."""
|
||||
return any(
|
||||
path.startswith(prefix) for prefix in LICENSE_ENFORCEMENT_ALLOWED_PREFIXES
|
||||
)
|
||||
|
||||
|
||||
def _is_ee_only_path(path: str) -> bool:
|
||||
"""Check if path requires EE license (prefix match)."""
|
||||
return any(path.startswith(prefix) for prefix in EE_ONLY_PATH_PREFIXES)
|
||||
|
||||
|
||||
def add_license_enforcement_middleware(
|
||||
app: FastAPI, logger: logging.LoggerAdapter
|
||||
) -> None:
|
||||
logger.info("License enforcement middleware registered")
|
||||
|
||||
@app.middleware("http")
|
||||
async def enforce_license(
|
||||
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
"""Block requests when license is expired/gated."""
|
||||
if not LICENSE_ENFORCEMENT_ENABLED:
|
||||
return await call_next(request)
|
||||
|
||||
path = request.url.path
|
||||
if path.startswith("/api"):
|
||||
path = path[4:]
|
||||
|
||||
if _is_path_allowed(path):
|
||||
return await call_next(request)
|
||||
|
||||
is_gated = False
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
|
||||
# If no cached metadata, check database (cache may have been cleared)
|
||||
if not metadata:
|
||||
logger.debug(
|
||||
"[license_enforcement] No cached license, checking database..."
|
||||
)
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
metadata = refresh_license_cache(db_session, tenant_id)
|
||||
if metadata:
|
||||
logger.info(
|
||||
"[license_enforcement] Loaded license from database"
|
||||
)
|
||||
except SQLAlchemyError as db_error:
|
||||
logger.warning(
|
||||
f"[license_enforcement] Failed to check database for license: {db_error}"
|
||||
)
|
||||
|
||||
if metadata:
|
||||
# User HAS a license (current or expired)
|
||||
if metadata.status == ApplicationStatus.GATED_ACCESS:
|
||||
# License fully expired - gate the user
|
||||
# Note: GRACE_PERIOD and PAYMENT_REMINDER are for notifications only,
|
||||
# they don't block access
|
||||
is_gated = True
|
||||
else:
|
||||
# License is active - check seat limit
|
||||
# used_seats in cache is kept accurate via invalidation
|
||||
# when users are added/removed
|
||||
if metadata.used_seats > metadata.seats:
|
||||
logger.info(
|
||||
f"[license_enforcement] Blocking request: "
|
||||
f"seat limit exceeded ({metadata.used_seats}/{metadata.seats})"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=402,
|
||||
content={
|
||||
"detail": {
|
||||
"error": "seat_limit_exceeded",
|
||||
"message": f"Seat limit exceeded: {metadata.used_seats} of {metadata.seats} seats used.",
|
||||
"used_seats": metadata.used_seats,
|
||||
"seats": metadata.seats,
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
# No license in cache OR database = never subscribed
|
||||
# Allow community features, but block EE-only features
|
||||
if _is_ee_only_path(path):
|
||||
logger.info(
|
||||
f"[license_enforcement] Blocking EE-only path (no license): {path}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=402,
|
||||
content={
|
||||
"detail": {
|
||||
"error": "enterprise_license_required",
|
||||
"message": "This feature requires an Enterprise license. "
|
||||
"Please upgrade to access this functionality.",
|
||||
}
|
||||
},
|
||||
)
|
||||
logger.debug(
|
||||
"[license_enforcement] No license, allowing community features"
|
||||
)
|
||||
is_gated = False
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata: {e}")
|
||||
# Fail open - don't block users due to Redis connectivity issues
|
||||
is_gated = False
|
||||
|
||||
if is_gated:
|
||||
logger.info(
|
||||
f"[license_enforcement] Blocking request (license expired): {path}"
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=402,
|
||||
content={
|
||||
"detail": {
|
||||
"error": "license_expired",
|
||||
"message": "Your subscription has expired. Please update your billing.",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
214
backend/ee/onyx/server/query_and_chat/chat_backend.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
|
||||
from ee.onyx.server.query_and_chat.models import (
|
||||
BasicCreateChatMessageWithHistoryRequest,
|
||||
)
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import create_chat_history_chain
|
||||
from onyx.chat.models import ChatBasicResponse
|
||||
from onyx.chat.process_message import gather_stream
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import OptionalSearchSetting
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/chat")
|
||||
|
||||
|
||||
@router.post("/send-message-simple-api")
|
||||
def handle_simplified_chat_message(
|
||||
chat_message_req: BasicCreateChatMessageRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatBasicResponse:
|
||||
"""This is a Non-Streaming version that only gives back a minimal set of information"""
|
||||
logger.notice(f"Received new simple api chat message: {chat_message_req.message}")
|
||||
|
||||
if not chat_message_req.message:
|
||||
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
|
||||
|
||||
# Handle chat session creation if chat_session_id is not provided
|
||||
if chat_message_req.chat_session_id is None:
|
||||
if chat_message_req.persona_id is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either chat_session_id or persona_id must be provided",
|
||||
)
|
||||
|
||||
# Create a new chat session with the provided persona_id
|
||||
try:
|
||||
new_chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="", # Leave empty for simple API
|
||||
user_id=user.id if user else None,
|
||||
persona_id=chat_message_req.persona_id,
|
||||
)
|
||||
chat_session_id = new_chat_session.id
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise HTTPException(status_code=400, detail="Invalid Persona provided.")
|
||||
else:
|
||||
chat_session_id = chat_message_req.chat_session_id
|
||||
|
||||
try:
|
||||
parent_message = create_chat_history_chain(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)[-1]
|
||||
except Exception:
|
||||
parent_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
|
||||
if (
|
||||
chat_message_req.retrieval_options is None
|
||||
and chat_message_req.search_doc_ids is None
|
||||
):
|
||||
retrieval_options: RetrievalDetails | None = RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
real_time=False,
|
||||
)
|
||||
else:
|
||||
retrieval_options = chat_message_req.retrieval_options
|
||||
|
||||
full_chat_msg_info = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message.id,
|
||||
message=chat_message_req.message,
|
||||
file_descriptors=[],
|
||||
search_doc_ids=chat_message_req.search_doc_ids,
|
||||
retrieval_options=retrieval_options,
|
||||
# Simple API does not support reranking, hide complexity from user
|
||||
rerank_settings=None,
|
||||
query_override=chat_message_req.query_override,
|
||||
# Currently only applies to search flow not chat
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
full_doc=chat_message_req.full_doc,
|
||||
structured_response_format=chat_message_req.structured_response_format,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=full_chat_msg_info,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return gather_stream(packets)
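For reference, a minimal request to this endpoint might look like the sketch below; the base URL, API key, and persona id are placeholders, and the printed response is simply whatever ChatBasicResponse serializes to.

import requests

payload = {
    "message": "What is our refund policy?",
    # persona_id is used to create a new chat session when chat_session_id is omitted
    "persona_id": 0,
}
resp = requests.post(
    "http://localhost:8080/api/chat/send-message-simple-api",
    json=payload,
    headers={"Authorization": "Bearer <api-key>"},  # placeholder credentials
    timeout=60,
)
resp.raise_for_status()
print(resp.json())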
|
||||
|
||||
|
||||
@router.post("/send-message-simple-with-history")
|
||||
def handle_send_message_simple_with_history(
|
||||
req: BasicCreateChatMessageWithHistoryRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatBasicResponse:
|
||||
"""This is a Non-Streaming version that only gives back a minimal set of information.
|
||||
takes in chat history maintained by the caller
|
||||
and does query rephrasing similar to answer-with-quote"""
|
||||
|
||||
if len(req.messages) == 0:
|
||||
raise HTTPException(status_code=400, detail="Messages cannot be zero length")
|
||||
|
||||
# This is a sanity check to make sure the chat history is valid
|
||||
# It must start with a user message and alternate between user and assistant
|
||||
expected_role = MessageType.USER
|
||||
for msg in req.messages:
|
||||
if not msg.message:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="One or more chat messages were empty"
|
||||
)
|
||||
|
||||
if msg.role != expected_role:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Message roles must start and end with MessageType.USER and alternate in-between.",
|
||||
)
|
||||
if expected_role == MessageType.USER:
|
||||
expected_role = MessageType.ASSISTANT
|
||||
else:
|
||||
expected_role = MessageType.USER
|
||||
|
||||
query = req.messages[-1].message
|
||||
msg_history = req.messages[:-1]
|
||||
|
||||
logger.notice(f"Received new simple with history chat message: {query}")
|
||||
|
||||
user_id = user.id if user is not None else None
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="handle_send_message_simple_with_history",
|
||||
user_id=user_id,
|
||||
persona_id=req.persona_id,
|
||||
)
|
||||
|
||||
llm = get_llm_for_persona(persona=chat_session.persona, user=user)
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
provider_type=llm.config.model_provider,
|
||||
)
|
||||
|
||||
# Every chat Session begins with an empty root message
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
)
|
||||
|
||||
chat_message = root_message
|
||||
for msg in msg_history:
|
||||
chat_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=chat_message,
|
||||
message=msg.message,
|
||||
token_count=len(llm_tokenizer.encode(msg.message)),
|
||||
message_type=msg.role,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
if req.retrieval_options is None and req.search_doc_ids is None:
|
||||
retrieval_options: RetrievalDetails | None = RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
real_time=False,
|
||||
)
|
||||
else:
|
||||
retrieval_options = req.retrieval_options
|
||||
|
||||
full_chat_msg_info = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message_id=chat_message.id,
|
||||
message=query,
|
||||
file_descriptors=[],
|
||||
search_doc_ids=req.search_doc_ids,
|
||||
retrieval_options=retrieval_options,
|
||||
# Simple API does not support reranking, hide complexity from user
|
||||
rerank_settings=None,
|
||||
query_override=None,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
full_doc=req.full_doc,
|
||||
structured_response_format=req.structured_response_format,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=full_chat_msg_info,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return gather_stream(packets)
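The validation above requires the history to start with a user turn, alternate strictly, and end on the new user query; a sketch of a payload that satisfies it is shown below (the ThreadMessage field names are assumed to mirror the MessageType values, and the persona id is a placeholder).

payload = {
    "persona_id": 0,  # placeholder persona
    "messages": [
        {"message": "Summarize the Q3 report.", "role": "user"},
        {"message": "Here is a short summary of the Q3 report ...", "role": "assistant"},
        {"message": "Now compare it to Q2.", "role": "user"},  # last element is the new query
    ],
}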
|
||||
@@ -1,12 +1,18 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from collections import OrderedDict
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import BasicChunkRequest
|
||||
from onyx.context.search.models import ChunkContext
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.server.manage.models import StandardAnswer
|
||||
|
||||
|
||||
@@ -19,89 +25,119 @@ class StandardAnswerResponse(BaseModel):
|
||||
standard_answers: list[StandardAnswer] = Field(default_factory=list)
|
||||
|
||||
|
||||
class SearchFlowClassificationRequest(BaseModel):
|
||||
user_query: str
|
||||
class DocumentSearchRequest(BasicChunkRequest):
|
||||
user_selected_filters: BaseFilters | None = None
|
||||
|
||||
|
||||
class SearchFlowClassificationResponse(BaseModel):
|
||||
is_search_flow: bool
|
||||
class DocumentSearchResponse(BaseModel):
|
||||
top_documents: list[InferenceChunk]
|
||||
|
||||
|
||||
class SendSearchQueryRequest(BaseModel):
|
||||
search_query: str
|
||||
filters: BaseFilters | None = None
|
||||
num_docs_fed_to_llm_selection: int | None = None
|
||||
run_query_expansion: bool = False
|
||||
num_hits: int = 50
|
||||
class BasicCreateChatMessageRequest(ChunkContext):
|
||||
"""If a chat_session_id is not provided, a persona_id must be provided to automatically create a new chat session
|
||||
Note, for simplicity this option only allows for a single linear chain of messages
|
||||
"""
|
||||
|
||||
include_content: bool = False
|
||||
stream: bool = False
|
||||
chat_session_id: UUID | None = None
|
||||
# Optional persona_id to create a new chat session if chat_session_id is not provided
|
||||
persona_id: int | None = None
|
||||
# New message contents
|
||||
message: str
|
||||
# Defaults to using retrieval with no additional filters
|
||||
retrieval_options: RetrievalDetails | None = None
|
||||
# Allows the caller to specify the exact search query they want to use
|
||||
# will disable Query Rewording if specified
|
||||
query_override: str | None = None
|
||||
# If search_doc_ids provided, then retrieval options are unused
|
||||
search_doc_ids: list[int] | None = None
|
||||
# only works if using an OpenAI model. See the following for more details:
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
|
||||
if self.chat_session_id is None and self.persona_id is None:
|
||||
raise ValueError("Either chat_session_id or persona_id must be provided")
|
||||
return self
|
||||
|
||||
|
||||
class SearchDocWithContent(SearchDoc):
|
||||
# Allows None because this is determined by a flag but the object used in code
|
||||
# of the search path uses this type
|
||||
content: str | None
|
||||
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
|
||||
# Last element is the new query. All previous elements are historical context
|
||||
messages: list[ThreadMessage]
|
||||
persona_id: int
|
||||
retrieval_options: RetrievalDetails | None = None
|
||||
query_override: str | None = None
|
||||
skip_rerank: bool | None = None
|
||||
# If search_doc_ids provided, then retrieval options are unused
|
||||
search_doc_ids: list[int] | None = None
|
||||
# only works if using an OpenAI model. See the following for more details:
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
@classmethod
|
||||
def from_inference_sections(
|
||||
cls,
|
||||
sections: Sequence[InferenceSection],
|
||||
include_content: bool = False,
|
||||
is_internet: bool = False,
|
||||
) -> list["SearchDocWithContent"]:
|
||||
"""Convert InferenceSections to SearchDocWithContent objects.
|
||||
|
||||
Args:
|
||||
sections: Sequence of InferenceSection objects
|
||||
include_content: If True, populate content field with combined_content
|
||||
is_internet: Whether these are internet search results
|
||||
class SimpleDoc(BaseModel):
|
||||
id: str
|
||||
semantic_identifier: str
|
||||
link: str | None
|
||||
blurb: str
|
||||
match_highlights: list[str]
|
||||
source_type: DocumentSource
|
||||
metadata: dict | None
|
||||
|
||||
Returns:
|
||||
List of SearchDocWithContent with optional content
|
||||
|
||||
class AgentSubQuestion(BaseModel):
|
||||
sub_question: str
|
||||
document_ids: list[str]
|
||||
|
||||
|
||||
class AgentAnswer(BaseModel):
|
||||
answer: str
|
||||
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
|
||||
|
||||
|
||||
class AgentSubQuery(BaseModel):
|
||||
sub_query: str
|
||||
query_id: int
|
||||
|
||||
@staticmethod
|
||||
def make_dict_by_level_and_question_index(
|
||||
original_dict: dict[tuple[int, int, int], "AgentSubQuery"],
|
||||
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
|
||||
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
|
||||
|
||||
returns a dict of level to dict[question num to list of query_id's]
|
||||
Ordering is asc for readability.
|
||||
"""
|
||||
if not sections:
|
||||
return []
|
||||
# In this function, when we sort int | None, we deliberately push None to the end
|
||||
|
||||
return [
|
||||
cls(
|
||||
document_id=(chunk := section.center_chunk).document_id,
|
||||
chunk_ind=chunk.chunk_id,
|
||||
semantic_identifier=chunk.semantic_identifier or "Unknown",
|
||||
link=chunk.source_links[0] if chunk.source_links else None,
|
||||
blurb=chunk.blurb,
|
||||
source_type=chunk.source_type,
|
||||
boost=chunk.boost,
|
||||
hidden=chunk.hidden,
|
||||
metadata=chunk.metadata,
|
||||
score=chunk.score,
|
||||
match_highlights=chunk.match_highlights,
|
||||
updated_at=chunk.updated_at,
|
||||
primary_owners=chunk.primary_owners,
|
||||
secondary_owners=chunk.secondary_owners,
|
||||
is_internet=is_internet,
|
||||
content=section.combined_content if include_content else None,
|
||||
# map entries to the level_question_dict
|
||||
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
|
||||
for k1, obj in original_dict.items():
|
||||
level = k1[0]
|
||||
question = k1[1]
|
||||
|
||||
if level not in level_question_dict:
|
||||
level_question_dict[level] = {}
|
||||
|
||||
if question not in level_question_dict[level]:
|
||||
level_question_dict[level][question] = []
|
||||
|
||||
level_question_dict[level][question].append(obj)
|
||||
|
||||
# sort each query_id list and question_index
|
||||
for key1, obj1 in level_question_dict.items():
|
||||
for key2, value2 in obj1.items():
|
||||
# sort the query_id list of each question_index
|
||||
level_question_dict[key1][key2] = sorted(
|
||||
value2, key=lambda o: o.query_id
|
||||
)
|
||||
# sort the question_index dict of level
|
||||
level_question_dict[key1] = OrderedDict(
|
||||
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
|
||||
)
|
||||
for section in sections
|
||||
]
|
||||
|
||||
|
||||
class SearchFullResponse(BaseModel):
|
||||
all_executed_queries: list[str]
|
||||
search_docs: list[SearchDocWithContent]
|
||||
# Reasoning tokens output by the LLM for the document selection
|
||||
doc_selection_reasoning: str | None = None
|
||||
# This is a list of document ids that are in the search_docs list
|
||||
llm_selected_doc_ids: list[str] | None = None
|
||||
# Error message if the search failed partway through
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class SearchQueryResponse(BaseModel):
|
||||
query: str
|
||||
query_expansions: list[str] | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class SearchHistoryResponse(BaseModel):
|
||||
search_queries: list[SearchQueryResponse]
|
||||
# sort the top dict of levels
|
||||
sorted_dict = OrderedDict(
|
||||
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
|
||||
)
|
||||
return sorted_dict
|
||||
|
||||
@@ -1,170 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.search import fetch_search_queries_for_user
|
||||
from ee.onyx.search.process_search_query import gather_search_stream
|
||||
from ee.onyx.search.process_search_query import stream_search_query
|
||||
from ee.onyx.secondary_llm_flows.search_flow_classification import (
|
||||
classify_is_search_flow,
|
||||
)
|
||||
from ee.onyx.server.query_and_chat.models import SearchFlowClassificationRequest
|
||||
from ee.onyx.server.query_and_chat.models import SearchFlowClassificationResponse
|
||||
from ee.onyx.server.query_and_chat.models import SearchFullResponse
|
||||
from ee.onyx.server.query_and_chat.models import SearchHistoryResponse
|
||||
from ee.onyx.server.query_and_chat.models import SearchQueryResponse
|
||||
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/search")
|
||||
|
||||
|
||||
@router.post("/search-flow-classification")
|
||||
def search_flow_classification(
|
||||
request: SearchFlowClassificationRequest,
|
||||
# This is added just to ensure this endpoint isn't spammed by non-authorized users since there's an LLM call underneath it
|
||||
_: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SearchFlowClassificationResponse:
|
||||
query = request.user_query
|
||||
# Heuristic: if the user is typing a lot of text, it's unlikely they're looking for a specific document.
# More likely something needs to be done with that text, so classify it as a chat flow.
|
||||
if len(query) > 200:
|
||||
return SearchFlowClassificationResponse(is_search_flow=False)
|
||||
|
||||
llm = get_default_llm()
|
||||
|
||||
check_llm_cost_limit_for_provider(
|
||||
db_session=db_session,
|
||||
tenant_id=get_current_tenant_id(),
|
||||
llm_provider_api_key=llm.config.api_key,
|
||||
)
|
||||
|
||||
try:
|
||||
is_search_flow = classify_is_search_flow(query=query, llm=llm)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Search flow classification failed; defaulting to chat flow",
|
||||
exc_info=e,
|
||||
)
|
||||
is_search_flow = False
|
||||
|
||||
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
|
||||
|
||||
|
||||
@router.post("/send-search-message", response_model=None)
|
||||
def handle_send_search_message(
|
||||
request: SendSearchQueryRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse | SearchFullResponse:
|
||||
"""
|
||||
Execute a search query with optional streaming.
|
||||
|
||||
When stream=True: Returns StreamingResponse with SSE
|
||||
When stream=False: Returns SearchFullResponse
|
||||
"""
|
||||
logger.debug(f"Received search query: {request.search_query}")
|
||||
|
||||
# Non-streaming path
|
||||
if not request.stream:
|
||||
try:
|
||||
packets = stream_search_query(request, user, db_session)
|
||||
return gather_search_stream(packets)
|
||||
except NotImplementedError as e:
|
||||
return SearchFullResponse(
|
||||
all_executed_queries=[],
|
||||
search_docs=[],
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
# Streaming path
|
||||
def stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as streaming_db_session:
|
||||
for packet in stream_search_query(request, user, streaming_db_session):
|
||||
yield get_json_line(packet.model_dump())
|
||||
except NotImplementedError as e:
|
||||
yield get_json_line(SearchErrorPacket(error=str(e)).model_dump())
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error in search streaming")
|
||||
yield get_json_line(SearchErrorPacket(error=str(e)).model_dump())
|
||||
|
||||
return StreamingResponse(stream_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/search-history")
|
||||
def get_search_history(
|
||||
limit: int = 100,
|
||||
filter_days: int | None = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SearchHistoryResponse:
|
||||
"""
|
||||
Fetch past search queries for the authenticated user.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of queries to return (default 100)
|
||||
filter_days: Only return queries from the last N days (optional)
|
||||
|
||||
Returns:
|
||||
SearchHistoryResponse with list of search queries, ordered by most recent first.
|
||||
"""
|
||||
# Validate limit
|
||||
if limit <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="limit must be greater than 0",
|
||||
)
|
||||
if limit > 1000:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="limit must be at most 1000",
|
||||
)
|
||||
|
||||
# Validate filter_days
|
||||
if filter_days is not None and filter_days <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="filter_days must be greater than 0",
|
||||
)
|
||||
|
||||
# TODO(yuhong) remove this
|
||||
if user is None:
|
||||
# Return empty list for unauthenticated users
|
||||
return SearchHistoryResponse(search_queries=[])
|
||||
|
||||
search_queries = fetch_search_queries_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
filter_days=filter_days,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return SearchHistoryResponse(
|
||||
search_queries=[
|
||||
SearchQueryResponse(
|
||||
query=sq.query,
|
||||
query_expansions=sq.query_expansions,
|
||||
created_at=sq.created_at,
|
||||
)
|
||||
for sq in search_queries
|
||||
]
|
||||
)
|
||||
@@ -1,35 +0,0 @@
from typing import Literal

from pydantic import BaseModel
from pydantic import ConfigDict

from ee.onyx.server.query_and_chat.models import SearchDocWithContent


class SearchQueriesPacket(BaseModel):
    model_config = ConfigDict(frozen=True)

    type: Literal["search_queries"] = "search_queries"
    all_executed_queries: list[str]


class SearchDocsPacket(BaseModel):
    model_config = ConfigDict(frozen=True)

    type: Literal["search_docs"] = "search_docs"
    search_docs: list[SearchDocWithContent]


class SearchErrorPacket(BaseModel):
    model_config = ConfigDict(frozen=True)

    type: Literal["search_error"] = "search_error"
    error: str


class LLMSelectedDocsPacket(BaseModel):
    model_config = ConfigDict(frozen=True)

    type: Literal["llm_selected_docs"] = "llm_selected_docs"
    # None if LLM selection failed, empty list if no docs selected, list of IDs otherwise
    llm_selected_doc_ids: list[str] | None
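These frozen packet models are what the streaming search endpoint emits as JSON lines; a rough client-side sketch that dispatches on the type field follows (the URL, auth, and exact request body are assumptions).

import json

import httpx

with httpx.stream(
    "POST",
    "http://localhost:8080/api/search/send-search-message",
    json={"search_query": "onboarding checklist", "stream": True},
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        packet = json.loads(line)
        if packet["type"] == "search_queries":
            print("executed:", packet["all_executed_queries"])
        elif packet["type"] == "search_docs":
            print("docs returned:", len(packet["search_docs"]))
        elif packet["type"] == "llm_selected_docs":
            print("llm selected:", packet["llm_selected_doc_ids"])
        elif packet["type"] == "search_error":
            print("error:", packet["error"])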
@@ -32,7 +32,6 @@ from onyx.configs.constants import MessageType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import QueryHistoryType
from onyx.configs.constants import SessionType
@@ -49,6 +48,7 @@ from onyx.file_store.file_store import get_default_file_store
from onyx.server.documents.models import PaginatedReturn
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse
from onyx.server.utils import PUBLIC_API_TAGS
from onyx.utils.threadpool_concurrency import parallel_yield
from shared_configs.contextvars import get_current_tenant_id

@@ -1,83 +0,0 @@
|
||||
"""EE Settings API - provides license-aware settings override."""
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Only GATED_ACCESS actually blocks access - other statuses are for notifications
|
||||
_BLOCKING_STATUS = ApplicationStatus.GATED_ACCESS
|
||||
|
||||
|
||||
def check_ee_features_enabled() -> bool:
|
||||
"""EE version: checks if EE features should be available.
|
||||
|
||||
Returns True if:
|
||||
- LICENSE_ENFORCEMENT_ENABLED is False (legacy/rollout mode)
|
||||
- Cloud mode (MULTI_TENANT) - cloud handles its own gating
|
||||
- Self-hosted with a valid (non-expired) license
|
||||
|
||||
Returns False if:
|
||||
- Self-hosted with no license (never subscribed)
|
||||
- Self-hosted with expired license
|
||||
"""
|
||||
if not LICENSE_ENFORCEMENT_ENABLED:
|
||||
# License enforcement disabled - allow EE features (legacy behavior)
|
||||
return True
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Cloud mode - EE features always available (gating handled by is_tenant_gated)
|
||||
return True
|
||||
|
||||
# Self-hosted with enforcement - check for valid license
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if metadata and metadata.status != _BLOCKING_STATUS:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
return True
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license for EE features: {e}")
|
||||
# Fail closed - if Redis is down, other things will break anyway
|
||||
return False
|
||||
|
||||
# No license or GATED_ACCESS - no EE features
|
||||
return False
|
||||
|
||||
|
||||
def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
"""EE version: checks license status for self-hosted deployments.
|
||||
|
||||
For self-hosted, looks up license metadata and overrides application_status
|
||||
if the license indicates GATED_ACCESS (fully expired).
|
||||
|
||||
For multi-tenant (cloud), the settings already have the correct status
|
||||
from the control plane, so no override is needed.
|
||||
|
||||
If LICENSE_ENFORCEMENT_ENABLED is false, settings are returned unchanged,
|
||||
allowing the product to function normally without license checks.
|
||||
"""
|
||||
if not LICENSE_ENFORCEMENT_ENABLED:
|
||||
return settings
|
||||
|
||||
if MULTI_TENANT:
|
||||
return settings
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if metadata and metadata.status == _BLOCKING_STATUS:
|
||||
settings.application_status = metadata.status
|
||||
# No license = user hasn't purchased yet, allow access for upgrade flow
|
||||
# GRACE_PERIOD/PAYMENT_REMINDER don't block - they're for notifications
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata for settings: {e}")
|
||||
|
||||
return settings
|
||||
@@ -1,14 +1,10 @@
|
||||
"""Tenant-specific usage limit overrides from the control plane (EE version)."""
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.server.tenant_usage_limits import TenantUsageLimitOverrides
|
||||
from onyx.server.usage_limits import NO_LIMIT
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -16,12 +12,9 @@ logger = setup_logger()
|
||||
|
||||
# In-memory storage for tenant overrides (populated at startup)
|
||||
_tenant_usage_limit_overrides: dict[str, TenantUsageLimitOverrides] | None = None
|
||||
_last_fetch_time: float = 0.0
|
||||
_FETCH_INTERVAL = 60 * 60 * 24 # 24 hours
|
||||
_ERROR_FETCH_INTERVAL = 30 * 60 # 30 minutes (if the last fetch failed)
|
||||
|
||||
|
||||
def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides] | None:
|
||||
def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
|
||||
"""
|
||||
Fetch tenant-specific usage limit overrides from the control plane.
|
||||
|
||||
@@ -52,52 +45,33 @@ def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides] | None
|
||||
f"Failed to parse usage limit overrides for tenant {tenant_id}: {e}"
|
||||
)
|
||||
|
||||
return (
|
||||
result or None
|
||||
) # if empty dictionary, something went wrong and we shouldn't enforce limits
|
||||
return result
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(f"Failed to fetch usage limit overrides from control plane: {e}")
|
||||
return None
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing usage limit overrides: {e}")
|
||||
return None
|
||||
return {}
|
||||
|
||||
|
||||
def load_usage_limit_overrides() -> None:
|
||||
def load_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
|
||||
"""
|
||||
Load tenant usage limit overrides from the control plane.
|
||||
|
||||
Called at server startup to populate the in-memory cache.
|
||||
"""
|
||||
global _tenant_usage_limit_overrides
|
||||
global _last_fetch_time
|
||||
|
||||
logger.info("Loading tenant usage limit overrides from control plane...")
|
||||
overrides = fetch_usage_limit_overrides()
|
||||
|
||||
_last_fetch_time = time.time()
|
||||
|
||||
# use the new result if it exists, otherwise use the old result
|
||||
# (prevents us from updating to a failed fetch result)
|
||||
_tenant_usage_limit_overrides = overrides or _tenant_usage_limit_overrides
|
||||
_tenant_usage_limit_overrides = overrides
|
||||
|
||||
if overrides:
|
||||
logger.info(f"Loaded usage limit overrides for {len(overrides)} tenants")
|
||||
else:
|
||||
logger.info("No tenant-specific usage limit overrides found")
|
||||
|
||||
|
||||
def unlimited(tenant_id: str) -> TenantUsageLimitOverrides:
|
||||
return TenantUsageLimitOverrides(
|
||||
tenant_id=tenant_id,
|
||||
llm_cost_cents_trial=NO_LIMIT,
|
||||
llm_cost_cents_paid=NO_LIMIT,
|
||||
chunks_indexed_trial=NO_LIMIT,
|
||||
chunks_indexed_paid=NO_LIMIT,
|
||||
api_calls_trial=NO_LIMIT,
|
||||
api_calls_paid=NO_LIMIT,
|
||||
non_streaming_calls_trial=NO_LIMIT,
|
||||
non_streaming_calls_paid=NO_LIMIT,
|
||||
)
|
||||
return overrides
|
||||
|
||||
|
||||
def get_tenant_usage_limit_overrides(
|
||||
@@ -112,22 +86,7 @@ def get_tenant_usage_limit_overrides(
|
||||
Returns:
|
||||
TenantUsageLimitOverrides if the tenant has overrides, None otherwise.
|
||||
"""
|
||||
|
||||
if DEV_MODE: # in dev mode, we return unlimited limits for all tenants
|
||||
return unlimited(tenant_id)
|
||||
|
||||
global _tenant_usage_limit_overrides
|
||||
time_since = time.time() - _last_fetch_time
|
||||
if (
|
||||
_tenant_usage_limit_overrides is None and time_since > _ERROR_FETCH_INTERVAL
|
||||
) or (time_since > _FETCH_INTERVAL):
|
||||
logger.debug(
|
||||
f"Last fetch time: {_last_fetch_time}, time since last fetch: {time_since}"
|
||||
)
|
||||
|
||||
load_usage_limit_overrides()
|
||||
|
||||
# If we have failed to fetch from the control plane or we're in dev mode, don't usage limit anyone.
|
||||
if _tenant_usage_limit_overrides is None or DEV_MODE:
|
||||
return unlimited(tenant_id)
|
||||
if _tenant_usage_limit_overrides is None:
|
||||
_tenant_usage_limit_overrides = load_usage_limit_overrides()
|
||||
return _tenant_usage_limit_overrides.get(tenant_id)
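
The refresh rule above re-fetches overrides once a day, but retries every 30 minutes while the cache is empty (i.e., the last fetch failed). A minimal standalone sketch of that staleness rule, using hypothetical names rather than the module globals:

# Minimal sketch (not from the repo) of the staleness rule described above,
# assuming the same 24h / 30min interval semantics.
import time

FETCH_INTERVAL = 60 * 60 * 24   # happy path: refresh once a day
ERROR_FETCH_INTERVAL = 30 * 60  # retry sooner if the last fetch failed

def needs_refresh(cache: dict | None, last_fetch_time: float) -> bool:
    """Return True when the in-memory override cache should be re-fetched."""
    time_since = time.time() - last_fetch_time
    if cache is None:
        # Last fetch failed (or never ran): retry on the shorter interval.
        return time_since > ERROR_FETCH_INTERVAL
    # Cache is populated: only refresh once the daily interval has elapsed.
    return time_since > FETCH_INTERVAL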
|
||||
|
||||
@@ -3,7 +3,6 @@ from fastapi import APIRouter
|
||||
from ee.onyx.server.tenants.admin_api import router as admin_router
|
||||
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
|
||||
from ee.onyx.server.tenants.billing_api import router as billing_router
|
||||
from ee.onyx.server.tenants.proxy import router as proxy_router
|
||||
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
|
||||
from ee.onyx.server.tenants.tenant_management_api import (
|
||||
router as tenant_management_router,
|
||||
@@ -23,4 +22,3 @@ router.include_router(billing_router)
|
||||
router.include_router(team_membership_router)
|
||||
router.include_router(tenant_management_router)
|
||||
router.include_router(user_invitations_router)
|
||||
router.include_router(proxy_router)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
import requests
|
||||
import stripe
|
||||
|
||||
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
@@ -16,21 +16,15 @@ stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def fetch_stripe_checkout_session(
|
||||
tenant_id: str,
|
||||
billing_period: Literal["monthly", "annual"] = "monthly",
|
||||
) -> str:
|
||||
def fetch_stripe_checkout_session(tenant_id: str) -> str:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
|
||||
payload = {
|
||||
"tenant_id": tenant_id,
|
||||
"billing_period": billing_period,
|
||||
}
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.post(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
return response.json()["sessionId"]
|
||||
|
||||
@@ -76,46 +70,24 @@ def fetch_billing_information(
|
||||
return BillingInformation(**response_data)
|
||||
|
||||
|
||||
def fetch_customer_portal_session(tenant_id: str, return_url: str | None = None) -> str:
|
||||
"""
|
||||
Fetch a Stripe customer portal session URL from the control plane.
|
||||
NOTE: This is currently only used for multi-tenant (cloud) deployments.
|
||||
Self-hosted proxy endpoints will be added in a future phase.
|
||||
"""
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/create-customer-portal-session"
|
||||
payload = {"tenant_id": tenant_id}
|
||||
if return_url:
|
||||
payload["return_url"] = return_url
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()["url"]
|
||||
|
||||
|
||||
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
|
||||
"""
|
||||
Update the number of seats for a tenant's subscription.
|
||||
Preserves the existing price (monthly, annual, or grandfathered).
|
||||
Send a request to the control service to register the number of users for a tenant.
|
||||
"""
|
||||
|
||||
if not STRIPE_PRICE_ID:
|
||||
raise Exception("STRIPE_PRICE_ID is not set")
|
||||
|
||||
response = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))
|
||||
|
||||
subscription = stripe.Subscription.retrieve(stripe_subscription_id)
|
||||
subscription_item = subscription["items"]["data"][0]
|
||||
|
||||
# Use existing price to preserve the customer's current plan
|
||||
current_price_id = subscription_item.price.id
|
||||
|
||||
updated_subscription = stripe.Subscription.modify(
|
||||
stripe_subscription_id,
|
||||
items=[
|
||||
{
|
||||
"id": subscription_item.id,
|
||||
"price": current_price_id,
|
||||
"id": subscription["items"]["data"][0].id,
|
||||
"price": STRIPE_PRICE_ID,
|
||||
"quantity": number_of_users,
|
||||
}
|
||||
],
|
||||
|
||||
@@ -1,59 +1,33 @@
"""Billing API endpoints for cloud multi-tenant deployments.

DEPRECATED: These /tenants/* billing endpoints are being replaced by /admin/billing/*
which provides a unified API for both self-hosted and cloud deployments.

TODO(ENG-3533): Migrate frontend to use /admin/billing/* endpoints and remove this file.
https://linear.app/onyx-app/issue/ENG-3533/migrate-tenantsbilling-adminbilling

Current endpoints to migrate:
- GET /tenants/billing-information -> GET /admin/billing/information
- POST /tenants/create-customer-portal-session -> POST /admin/billing/portal-session
- POST /tenants/create-subscription-session -> POST /admin/billing/checkout-session
- GET /tenants/stripe-publishable-key -> (keep as-is, shared endpoint)

Note: /tenants/product-gating/* endpoints are control-plane-to-data-plane calls
and are NOT part of this migration - they stay here.
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
import stripe
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_customer_portal_session
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
from ee.onyx.server.tenants.models import StripePublishableKeyResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
# Cache for Stripe publishable key to avoid hitting S3 on every request
|
||||
_stripe_publishable_key_cache: str | None = None
|
||||
_stripe_key_lock = asyncio.Lock()
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
@@ -108,13 +82,21 @@ async def billing_information(
|
||||
async def create_customer_portal_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> dict:
|
||||
"""Create a Stripe customer portal session via the control plane."""
|
||||
tenant_id = get_current_tenant_id()
|
||||
return_url = f"{WEB_DOMAIN}/admin/billing"
|
||||
|
||||
try:
|
||||
portal_url = fetch_customer_portal_session(tenant_id, return_url)
|
||||
return {"url": portal_url}
|
||||
stripe_info = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_customer_id = stripe_info.get("stripe_customer_id")
|
||||
if not stripe_customer_id:
|
||||
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
|
||||
logger.info(stripe_customer_id)
|
||||
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=stripe_customer_id,
|
||||
return_url=f"{WEB_DOMAIN}/admin/billing",
|
||||
)
|
||||
logger.info(portal_session)
|
||||
return {"url": portal_session.url}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -122,82 +104,15 @@ async def create_customer_portal_session(
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
request: CreateSubscriptionSessionRequest | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
|
||||
billing_period = request.billing_period if request else "monthly"
|
||||
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create subscription session")
|
||||
logger.exception("Failed to create resubscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
"""
|
||||
Fetch the Stripe publishable key.
|
||||
Priority: env var override (for testing) > S3 bucket (production).
|
||||
This endpoint is public (no auth required) since publishable keys are safe to expose.
|
||||
The key is cached in memory to avoid hitting S3 on every request.
|
||||
"""
|
||||
global _stripe_publishable_key_cache
|
||||
|
||||
# Fast path: return cached value without lock
|
||||
if _stripe_publishable_key_cache:
|
||||
return StripePublishableKeyResponse(
|
||||
publishable_key=_stripe_publishable_key_cache
|
||||
)
|
||||
|
||||
# Use lock to prevent concurrent S3 requests
|
||||
async with _stripe_key_lock:
|
||||
# Double-check after acquiring lock (another request may have populated cache)
|
||||
if _stripe_publishable_key_cache:
|
||||
return StripePublishableKeyResponse(
|
||||
publishable_key=_stripe_publishable_key_cache
|
||||
)
|
||||
|
||||
# Check for env var override first (for local testing with pk_test_* keys)
|
||||
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(STRIPE_PUBLISHABLE_KEY_URL)
|
||||
response.raise_for_status()
|
||||
key = response.text.strip()
|
||||
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -74,12 +73,6 @@ class SubscriptionSessionResponse(BaseModel):
|
||||
sessionId: str
|
||||
|
||||
|
||||
class CreateSubscriptionSessionRequest(BaseModel):
|
||||
"""Request to create a subscription checkout session."""
|
||||
|
||||
billing_period: Literal["monthly", "annual"] = "monthly"
|
||||
|
||||
|
||||
class TenantByDomainResponse(BaseModel):
|
||||
tenant_id: str
|
||||
number_of_users: int
|
||||
@@ -105,7 +98,3 @@ class PendingUserSnapshot(BaseModel):
|
||||
|
||||
class ApproveUserRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class StripePublishableKeyResponse(BaseModel):
|
||||
publishable_key: str
|
||||
|
||||
@@ -65,9 +65,3 @@ def get_gated_tenants() -> set[str]:
|
||||
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
|
||||
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}
|
||||
|
||||
|
||||
def is_tenant_gated(tenant_id: str) -> bool:
|
||||
"""Fast O(1) check if tenant is in gated set (multi-tenant only)."""
|
||||
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
return bool(redis_client.sismember(GATED_TENANTS_KEY, tenant_id))
|
||||
|
||||
@@ -1,485 +0,0 @@
"""Proxy endpoints for billing operations.

These endpoints run on the CLOUD DATA PLANE (cloud.onyx.app) and serve as a proxy
for self-hosted instances to reach the control plane.

Flow:
Self-hosted backend → Cloud DP /proxy/* (license auth) → Control plane (JWT auth)

Self-hosted instances call these endpoints with their license in the Authorization
header. The cloud data plane validates the license signature and forwards the
request to the control plane using JWT authentication.

Auth levels by endpoint:
- /create-checkout-session: No auth (new customer) or expired license OK (renewal)
- /claim-license: Session ID based (one-time after Stripe payment)
- /create-customer-portal-session: Expired license OK (need portal to fix payment)
- /billing-information: Valid license required
- /license/{tenant_id}: Valid license required
- /seats/update: Valid license required
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Header
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import update_license_cache
|
||||
from ee.onyx.db.license import upsert_license
|
||||
from ee.onyx.server.billing.models import SeatUpdateRequest
|
||||
from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.utils.license import is_license_valid
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/proxy")
|
||||
|
||||
|
||||
def _check_license_enforcement_enabled() -> None:
|
||||
"""Ensure LICENSE_ENFORCEMENT_ENABLED is true (proxy endpoints only work on cloud DP)."""
|
||||
if not LICENSE_ENFORCEMENT_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail="Proxy endpoints are only available on cloud data plane",
|
||||
)
|
||||
|
||||
|
||||
def _extract_license_from_header(
authorization: str | None,
required: bool = True,
) -> str | None:
"""Extract license data from Authorization header.

Self-hosted instances authenticate to these proxy endpoints by sending their
license as a Bearer token: `Authorization: Bearer <base64-encoded-license>`.

We use the Bearer scheme (RFC 6750) because:
1. It's the standard HTTP auth scheme for token-based authentication
2. The license blob is cryptographically signed (RSA), so it's self-validating
3. No other auth schemes (Basic, Digest, etc.) are supported for license auth

The license data is the base64-encoded signed blob that contains tenant_id,
seats, expiration, etc. We verify the signature to authenticate the caller.

Args:
authorization: The Authorization header value (e.g., "Bearer <license>")
required: If True, raise 401 when header is missing/invalid

Returns:
License data string (base64-encoded), or None if not required and missing

Raises:
HTTPException: 401 if required and header is missing/invalid
"""
if not authorization or not authorization.startswith("Bearer "):
if required:
raise HTTPException(
status_code=401, detail="Missing or invalid authorization header"
)
return None

return authorization.split(" ", 1)[1]
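
For illustration, a hedged sketch of what a self-hosted caller of these proxy endpoints might look like. The base URL and variable names are assumptions; only the `Authorization: Bearer <license>` convention comes from the code above.

# Hedged sketch of a self-hosted instance calling a license-protected proxy
# endpoint (illustrative URL and variable names, not taken from the repo).
import requests

CLOUD_PROXY_BASE = "https://cloud.onyx.app/api/proxy"  # assumed base URL
license_blob = "..."  # base64-encoded signed license stored by the instance

resp = requests.get(
    f"{CLOUD_PROXY_BASE}/billing-information",
    headers={"Authorization": f"Bearer {license_blob}"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json())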
|
||||
|
||||
|
||||
def verify_license_auth(
|
||||
license_data: str,
|
||||
allow_expired: bool = False,
|
||||
) -> LicensePayload:
|
||||
"""Verify license signature and optionally check expiry.
|
||||
|
||||
Args:
|
||||
license_data: Base64-encoded signed license blob
|
||||
allow_expired: If True, accept expired licenses (for renewal flows)
|
||||
|
||||
Returns:
|
||||
LicensePayload if valid
|
||||
|
||||
Raises:
|
||||
HTTPException: If license is invalid or expired (when not allowed)
|
||||
"""
|
||||
_check_license_enforcement_enabled()
|
||||
|
||||
try:
|
||||
payload = verify_license_signature(license_data)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=f"Invalid license: {e}")
|
||||
|
||||
if not allow_expired and not is_license_valid(payload):
|
||||
raise HTTPException(status_code=401, detail="License has expired")
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
async def get_license_payload(
|
||||
authorization: str | None = Header(None, alias="Authorization"),
|
||||
) -> LicensePayload:
|
||||
"""Dependency: Require valid (non-expired) license.
|
||||
|
||||
Used for endpoints that require an active subscription.
|
||||
"""
|
||||
license_data = _extract_license_from_header(authorization, required=True)
|
||||
# license_data is guaranteed non-None when required=True
|
||||
assert license_data is not None
|
||||
return verify_license_auth(license_data, allow_expired=False)
|
||||
|
||||
|
||||
async def get_license_payload_allow_expired(
|
||||
authorization: str | None = Header(None, alias="Authorization"),
|
||||
) -> LicensePayload:
|
||||
"""Dependency: Require license with valid signature, expired OK.
|
||||
|
||||
Used for endpoints needed to fix payment issues (portal, renewal checkout).
|
||||
"""
|
||||
license_data = _extract_license_from_header(authorization, required=True)
|
||||
# license_data is guaranteed non-None when required=True
|
||||
assert license_data is not None
|
||||
return verify_license_auth(license_data, allow_expired=True)
|
||||
|
||||
|
||||
async def get_optional_license_payload(
|
||||
authorization: str | None = Header(None, alias="Authorization"),
|
||||
) -> LicensePayload | None:
|
||||
"""Dependency: Optional license auth (for checkout - new customers have none).
|
||||
|
||||
Returns None if no license provided, otherwise validates and returns payload.
|
||||
Expired licenses are allowed for renewal flows.
|
||||
"""
|
||||
_check_license_enforcement_enabled()
|
||||
|
||||
license_data = _extract_license_from_header(authorization, required=False)
|
||||
if license_data is None:
|
||||
return None
|
||||
|
||||
return verify_license_auth(license_data, allow_expired=True)
|
||||
|
||||
|
||||
async def forward_to_control_plane(
|
||||
method: str,
|
||||
path: str,
|
||||
body: dict | None = None,
|
||||
params: dict | None = None,
|
||||
) -> dict:
|
||||
"""Forward a request to the control plane with proper authentication."""
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}{path}"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
if method == "GET":
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
elif method == "POST":
|
||||
response = await client.post(url, headers=headers, json=body)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
status_code = e.response.status_code
|
||||
detail = "Control plane request failed"
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
detail = error_data.get("detail", detail)
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(f"Control plane returned {status_code}: {detail}")
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
except httpx.RequestError:
|
||||
logger.exception("Failed to connect to control plane")
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Failed to connect to control plane"
|
||||
)
|
||||
|
||||
|
||||
def fetch_and_store_license(tenant_id: str, license_data: str) -> None:
|
||||
"""Store license in database and update Redis cache.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID
|
||||
license_data: Base64-encoded signed license blob
|
||||
"""
|
||||
try:
|
||||
# Verify before storing
|
||||
payload = verify_license_signature(license_data)
|
||||
|
||||
# Store in database using the specific tenant's schema
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
upsert_license(db_session, license_data)
|
||||
|
||||
# Update Redis cache
|
||||
update_license_cache(
|
||||
payload,
|
||||
source=LicenseSource.AUTO_FETCH,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to verify license: {e}")
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to store license")
|
||||
raise
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CreateCheckoutSessionRequest(BaseModel):
|
||||
billing_period: Literal["monthly", "annual"] = "monthly"
|
||||
email: str | None = None
|
||||
# Redirect URL after successful checkout - self-hosted passes their instance URL
|
||||
redirect_url: str | None = None
|
||||
# Cancel URL when user exits checkout - returns to upgrade page
|
||||
cancel_url: str | None = None
|
||||
|
||||
|
||||
class CreateCheckoutSessionResponse(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
@router.post("/create-checkout-session")
|
||||
async def proxy_create_checkout_session(
|
||||
request_body: CreateCheckoutSessionRequest,
|
||||
license_payload: LicensePayload | None = Depends(get_optional_license_payload),
|
||||
) -> CreateCheckoutSessionResponse:
|
||||
"""Proxy checkout session creation to control plane.
|
||||
|
||||
Auth: Optional license (new customers don't have one yet).
|
||||
If license provided, expired is OK (for renewals).
|
||||
"""
|
||||
# license_payload is None for new customers who don't have a license yet.
|
||||
# In that case, tenant_id is omitted from the request body and the control
|
||||
# plane will create a new tenant during checkout completion.
|
||||
tenant_id = license_payload.tenant_id if license_payload else None
|
||||
|
||||
body: dict = {
|
||||
"billing_period": request_body.billing_period,
|
||||
}
|
||||
if tenant_id:
|
||||
body["tenant_id"] = tenant_id
|
||||
if request_body.email:
|
||||
body["email"] = request_body.email
|
||||
if request_body.redirect_url:
|
||||
body["redirect_url"] = request_body.redirect_url
|
||||
if request_body.cancel_url:
|
||||
body["cancel_url"] = request_body.cancel_url
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST", "/create-checkout-session", body=body
|
||||
)
|
||||
return CreateCheckoutSessionResponse(url=result["url"])
|
||||
|
||||
|
||||
class ClaimLicenseRequest(BaseModel):
|
||||
session_id: str
|
||||
|
||||
|
||||
class ClaimLicenseResponse(BaseModel):
|
||||
tenant_id: str
|
||||
license: str
|
||||
message: str | None = None
|
||||
|
||||
|
||||
@router.post("/claim-license")
|
||||
async def proxy_claim_license(
|
||||
request_body: ClaimLicenseRequest,
|
||||
) -> ClaimLicenseResponse:
|
||||
"""Claim a license after successful Stripe checkout.
|
||||
|
||||
Auth: Session ID based (one-time use after payment).
|
||||
The control plane verifies the session_id is valid and unclaimed.
|
||||
|
||||
Returns the license to the caller. For self-hosted instances, they will
|
||||
store the license locally. The cloud DP doesn't need to store it.
|
||||
"""
|
||||
_check_license_enforcement_enabled()
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST",
|
||||
"/claim-license",
|
||||
body={"session_id": request_body.session_id},
|
||||
)
|
||||
|
||||
tenant_id = result.get("tenant_id")
|
||||
license_data = result.get("license")
|
||||
|
||||
if not tenant_id or not license_data:
|
||||
logger.error(f"Control plane returned incomplete claim response: {result}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Control plane returned incomplete license data",
|
||||
)
|
||||
|
||||
return ClaimLicenseResponse(
|
||||
tenant_id=tenant_id,
|
||||
license=license_data,
|
||||
message="License claimed successfully",
|
||||
)
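
A hedged sketch of the self-hosted side of this claim flow (illustrative URL and session id; only the request/response shape comes from the endpoint above). Note that no license header is sent here, matching the "session ID based" auth level described earlier.

# Hedged sketch: exchange the one-time Stripe checkout session id for a license.
import requests

CLOUD_PROXY_BASE = "https://cloud.onyx.app/api/proxy"  # assumed base URL

resp = requests.post(
    f"{CLOUD_PROXY_BASE}/claim-license",
    json={"session_id": "cs_test_..."},  # id returned by the checkout redirect
    timeout=30,
)
resp.raise_for_status()
claim = resp.json()
tenant_id, license_blob = claim["tenant_id"], claim["license"]
# A real instance would persist the signed blob so later proxy calls can send
# it as the Bearer token; here we just show the fields that come back.
print(tenant_id, license_blob[:16], "...")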
|
||||
|
||||
|
||||
class CreateCustomerPortalSessionRequest(BaseModel):
|
||||
return_url: str | None = None
|
||||
|
||||
|
||||
class CreateCustomerPortalSessionResponse(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def proxy_create_customer_portal_session(
|
||||
request_body: CreateCustomerPortalSessionRequest | None = None,
|
||||
license_payload: LicensePayload = Depends(get_license_payload_allow_expired),
|
||||
) -> CreateCustomerPortalSessionResponse:
|
||||
"""Proxy customer portal session creation to control plane.
|
||||
|
||||
Auth: License required, expired OK (need portal to fix payment issues).
|
||||
"""
|
||||
# tenant_id is a required field in LicensePayload (Pydantic validates this),
|
||||
# but we check explicitly for defense in depth
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
tenant_id = license_payload.tenant_id
|
||||
|
||||
body: dict = {"tenant_id": tenant_id}
|
||||
if request_body and request_body.return_url:
|
||||
body["return_url"] = request_body.return_url
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST", "/create-customer-portal-session", body=body
|
||||
)
|
||||
return CreateCustomerPortalSessionResponse(url=result["url"])
|
||||
|
||||
|
||||
class BillingInformationResponse(BaseModel):
|
||||
tenant_id: str
|
||||
status: str | None = None
|
||||
plan_type: str | None = None
|
||||
seats: int | None = None
|
||||
billing_period: str | None = None
|
||||
current_period_start: str | None = None
|
||||
current_period_end: str | None = None
|
||||
cancel_at_period_end: bool = False
|
||||
canceled_at: str | None = None
|
||||
trial_start: str | None = None
|
||||
trial_end: str | None = None
|
||||
payment_method_enabled: bool = False
|
||||
stripe_subscription_id: str | None = None
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def proxy_billing_information(
|
||||
license_payload: LicensePayload = Depends(get_license_payload),
|
||||
) -> BillingInformationResponse:
|
||||
"""Proxy billing information request to control plane.
|
||||
|
||||
Auth: Valid (non-expired) license required.
|
||||
"""
|
||||
# tenant_id is a required field in LicensePayload (Pydantic validates this),
|
||||
# but we check explicitly for defense in depth
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
tenant_id = license_payload.tenant_id
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"GET", "/billing-information", params={"tenant_id": tenant_id}
|
||||
)
|
||||
# Add tenant_id from license if not in response (control plane may not include it)
|
||||
if "tenant_id" not in result:
|
||||
result["tenant_id"] = tenant_id
|
||||
return BillingInformationResponse(**result)
|
||||
|
||||
|
||||
class LicenseFetchResponse(BaseModel):
|
||||
license: str
|
||||
tenant_id: str
|
||||
|
||||
|
||||
@router.get("/license/{tenant_id}")
|
||||
async def proxy_license_fetch(
|
||||
tenant_id: str,
|
||||
license_payload: LicensePayload = Depends(get_license_payload),
|
||||
) -> LicenseFetchResponse:
|
||||
"""Proxy license fetch to control plane.
|
||||
|
||||
Auth: Valid license required.
|
||||
The tenant_id in path must match the authenticated tenant.
|
||||
"""
|
||||
# tenant_id is a required field in LicensePayload (Pydantic validates this),
|
||||
# but we check explicitly for defense in depth
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
if tenant_id != license_payload.tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Cannot fetch license for a different tenant",
|
||||
)
|
||||
|
||||
result = await forward_to_control_plane("GET", f"/license/{tenant_id}")
|
||||
|
||||
# Auto-store the refreshed license
|
||||
license_data = result.get("license")
|
||||
if not license_data:
|
||||
logger.error(f"Control plane returned incomplete license response: {result}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Control plane returned incomplete license data",
|
||||
)
|
||||
|
||||
fetch_and_store_license(tenant_id, license_data)
|
||||
|
||||
return LicenseFetchResponse(license=license_data, tenant_id=tenant_id)
|
||||
|
||||
|
||||
@router.post("/seats/update")
|
||||
async def proxy_seat_update(
|
||||
request_body: SeatUpdateRequest,
|
||||
license_payload: LicensePayload = Depends(get_license_payload),
|
||||
) -> SeatUpdateResponse:
|
||||
"""Proxy seat update to control plane.
|
||||
|
||||
Auth: Valid (non-expired) license required.
|
||||
Handles Stripe proration and license regeneration.
|
||||
"""
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
tenant_id = license_payload.tenant_id
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST",
|
||||
"/seats/update",
|
||||
body={
|
||||
"tenant_id": tenant_id,
|
||||
"new_seat_count": request_body.new_seat_count,
|
||||
},
|
||||
)
|
||||
|
||||
return SeatUpdateResponse(
|
||||
success=result.get("success", False),
|
||||
current_seats=result.get("current_seats", 0),
|
||||
used_seats=result.get("used_seats", 0),
|
||||
message=result.get("message"),
|
||||
)
|
||||
@@ -1,7 +1,6 @@
|
||||
from fastapi_users import exceptions
|
||||
from sqlalchemy import select
|
||||
|
||||
from ee.onyx.db.license import invalidate_license_cache
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import get_pending_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
@@ -48,8 +47,6 @@ def get_tenant_id_for_email(email: str) -> str:
|
||||
mapping.active = True
|
||||
db_session.commit()
|
||||
tenant_id = mapping.tenant_id
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error getting tenant id for email {email}: {e}")
|
||||
raise exceptions.UserNotExists()
|
||||
@@ -73,104 +70,49 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
"""
|
||||
Add users to a tenant with proper transaction handling.
|
||||
Checks if users already have a tenant mapping to avoid duplicates.
|
||||
|
||||
If a user already has an active mapping to a different tenant, they receive
|
||||
an inactive mapping (invitation) to this tenant. They can accept the
|
||||
invitation later to switch tenants.
|
||||
|
||||
Raises:
|
||||
HTTPException: 402 if adding active users would exceed seat limit
|
||||
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
|
||||
"""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.db.license import check_seat_availability
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant as get_tenant_session
|
||||
|
||||
unique_emails = set(emails)
|
||||
if not unique_emails:
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
# Start a transaction
|
||||
db_session.begin()
|
||||
|
||||
# Batch query 1: Get all existing mappings for these emails to this tenant
|
||||
# Lock rows to prevent concurrent modifications
|
||||
existing_mappings = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email.in_(unique_emails),
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.with_for_update()
|
||||
.all()
|
||||
)
|
||||
emails_with_mapping = {m.email for m in existing_mappings}
|
||||
|
||||
# Batch query 2: Get all active mappings for these emails (any tenant)
|
||||
active_mappings = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email.in_(unique_emails),
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.all()
|
||||
)
|
||||
emails_with_active_mapping = {m.email for m in active_mappings}
|
||||
|
||||
# Determine which users will consume a new seat.
|
||||
# Users with active mappings elsewhere get INACTIVE mappings (invitations)
|
||||
# and don't consume seats until they accept. Only users without any active
|
||||
# mapping will get an ACTIVE mapping and consume a seat immediately.
|
||||
emails_consuming_seats = {
|
||||
email
|
||||
for email in unique_emails
|
||||
if email not in emails_with_mapping
|
||||
and email not in emails_with_active_mapping
|
||||
}
|
||||
|
||||
# Check seat availability inside the transaction to prevent race conditions.
|
||||
# Note: ALL users in unique_emails still get added below - this check only
|
||||
# validates we have capacity for users who will consume seats immediately.
|
||||
if emails_consuming_seats:
|
||||
with get_tenant_session(tenant_id=tenant_id) as tenant_session:
|
||||
result = check_seat_availability(
|
||||
tenant_session,
|
||||
seats_needed=len(emails_consuming_seats),
|
||||
tenant_id=tenant_id,
|
||||
for email in emails:
|
||||
# Check if the user already has a mapping to this tenant
|
||||
existing_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
if not result.available:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=result.error_message or "Seat limit exceeded",
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
|
||||
# If user already has an active mapping, add this one as inactive
|
||||
if not existing_mapping:
|
||||
# Check if the user already has an active mapping to any tenant
|
||||
has_active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
db_session.add(
|
||||
UserTenantMapping(
|
||||
email=email,
|
||||
tenant_id=tenant_id,
|
||||
active=False if has_active_mapping else True,
|
||||
)
|
||||
|
||||
# Add mappings for emails that don't already have one to this tenant
|
||||
for email in unique_emails:
|
||||
if email in emails_with_mapping:
|
||||
continue
|
||||
|
||||
# Create mapping: inactive if user belongs to another tenant (invitation),
|
||||
# active otherwise
|
||||
db_session.add(
|
||||
UserTenantMapping(
|
||||
email=email,
|
||||
tenant_id=tenant_id,
|
||||
active=email not in emails_with_active_mapping,
|
||||
)
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
db_session.commit()
|
||||
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
|
||||
except HTTPException:
|
||||
db_session.rollback()
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(f"Failed to add users to tenant {tenant_id}")
|
||||
db_session.rollback()
|
||||
@@ -193,9 +135,6 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
db_session.delete(mapping)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
|
||||
@@ -210,9 +149,6 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
|
||||
|
||||
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
@@ -241,9 +177,6 @@ def approve_user_invite(email: str, tenant_id: str) -> None:
|
||||
db_session.add(new_mapping)
|
||||
db_session.commit()
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
|
||||
# Also remove the user from pending users list
|
||||
# Remove from pending users
|
||||
pending_users = get_pending_users()
|
||||
@@ -262,42 +195,19 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Accept an invitation to join a tenant.
|
||||
This activates the user's mapping to the tenant.
|
||||
|
||||
Raises:
|
||||
HTTPException: 402 if accepting would exceed seat limit
|
||||
"""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.db.license import check_seat_availability
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
# Lock the user's mappings first to prevent race conditions.
|
||||
# This ensures no concurrent request can modify this user's mappings
|
||||
# while we check seats and activate.
|
||||
# First check if there's an active mapping for this user and tenant
|
||||
active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
|
||||
# Check seat availability within the same logical operation.
|
||||
# Note: This queries fresh data from DB, not cache.
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
|
||||
result = check_seat_availability(
|
||||
tenant_session, seats_needed=1, tenant_id=tenant_id
|
||||
)
|
||||
if not result.available:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=result.error_message or "Seat limit exceeded",
|
||||
)
|
||||
|
||||
# If an active mapping exists, delete it
|
||||
if active_mapping:
|
||||
db_session.delete(active_mapping)
|
||||
@@ -327,9 +237,6 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
|
||||
mapping.active = True
|
||||
db_session.commit()
|
||||
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
else:
|
||||
logger.warning(
|
||||
f"No invitation found for user {email} in tenant {tenant_id}"
|
||||
@@ -390,41 +297,16 @@ def deny_user_invite(email: str, tenant_id: str) -> None:
|
||||
|
||||
def get_tenant_count(tenant_id: str) -> int:
|
||||
"""
|
||||
Get the number of active users for this tenant.
|
||||
|
||||
A user counts toward the seat count if:
|
||||
1. They have an active mapping to this tenant (UserTenantMapping.active == True)
|
||||
2. AND the User is active (User.is_active == True)
|
||||
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
Get the number of active users for this tenant
|
||||
"""
|
||||
from onyx.db.models import User
|
||||
|
||||
# First get all emails with active mappings to this tenant
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
active_mapping_emails = (
|
||||
db_session.query(UserTenantMapping.email)
|
||||
# Count the number of active users for this tenant
|
||||
user_count = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.all()
|
||||
)
|
||||
emails = [email for (email,) in active_mapping_emails]
|
||||
|
||||
if not emails:
|
||||
return 0
|
||||
|
||||
# Now count how many of those users are actually active in the tenant's User table
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
user_count = (
|
||||
db_session.query(User)
|
||||
.filter(
|
||||
User.email.in_(emails), # type: ignore
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
|
||||
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.token_limit import fetch_all_user_token_rate_limits
|
||||
@@ -17,6 +16,7 @@ from onyx.db.token_limit import insert_user_token_rate_limit
|
||||
from onyx.server.query_and_chat.token_limit import any_rate_limit_exists
|
||||
from onyx.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
from onyx.server.token_rate_limits.models import TokenRateLimitDisplay
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
|
||||
router = APIRouter(prefix="/admin/token-rate-limits", tags=PUBLIC_API_TAGS)
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
"""EE Usage limits - trial detection via billing information."""

from datetime import datetime
from datetime import timezone

from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
@@ -28,7 +31,13 @@ def is_tenant_on_trial(tenant_id: str) -> bool:
return True

if isinstance(billing_info, BillingInformation):
return billing_info.status == "trialing"
# Check if trial is active
if billing_info.trial_end is not None:
now = datetime.now(timezone.utc)
# Trial active if trial_end is in the future
# and subscription status indicates trialing
if billing_info.trial_end > now and billing_info.status == "trialing":
return True

return False
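
A tiny illustration of the check above with a stand-in object (not the real BillingInformation model): a "trialing" subscription whose trial_end is still in the future counts as on trial.

# Stand-in object used only to illustrate the trial check shown above.
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone

@dataclass
class FakeBillingInfo:
    status: str
    trial_end: datetime | None

info = FakeBillingInfo(
    status="trialing",
    trial_end=datetime.now(timezone.utc) + timedelta(days=7),
)
on_trial = (
    info.trial_end is not None
    and info.trial_end > datetime.now(timezone.utc)
    and info.status == "trialing"
)
print(on_trial)  # True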
|
||||
|
||||
|
||||
@@ -18,10 +18,10 @@ from ee.onyx.server.user_group.models import UserGroupCreate
|
||||
from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -5,7 +5,6 @@ import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from pathlib import Path
|
||||
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
@@ -20,27 +19,21 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Path to the license public key file
|
||||
_LICENSE_PUBLIC_KEY_PATH = (
|
||||
Path(__file__).parent.parent.parent.parent / "keys" / "license_public_key.pem"
|
||||
)
|
||||
|
||||
# RSA-4096 Public Key for license verification
|
||||
# Load from environment variable - key is generated on the control plane
|
||||
# In production, inject via Kubernetes secrets or secrets manager
|
||||
LICENSE_PUBLIC_KEY_PEM = os.environ.get("LICENSE_PUBLIC_KEY_PEM", "")
|
||||
|
||||
|
||||
def _get_public_key() -> RSAPublicKey:
|
||||
"""Load the public key from file, with env var override."""
|
||||
# Allow env var override for flexibility
|
||||
key_pem = os.environ.get("LICENSE_PUBLIC_KEY_PEM")
|
||||
|
||||
if not key_pem:
|
||||
# Read from file
|
||||
if not _LICENSE_PUBLIC_KEY_PATH.exists():
|
||||
raise ValueError(
|
||||
f"License public key not found at {_LICENSE_PUBLIC_KEY_PATH}. "
|
||||
"License verification requires the control plane public key."
|
||||
)
|
||||
key_pem = _LICENSE_PUBLIC_KEY_PATH.read_text()
|
||||
|
||||
key = serialization.load_pem_public_key(key_pem.encode())
|
||||
"""Load the public key from environment variable."""
|
||||
if not LICENSE_PUBLIC_KEY_PEM:
|
||||
raise ValueError(
|
||||
"LICENSE_PUBLIC_KEY_PEM environment variable not set. "
|
||||
"License verification requires the control plane public key."
|
||||
)
|
||||
key = serialization.load_pem_public_key(LICENSE_PUBLIC_KEY_PEM.encode())
|
||||
if not isinstance(key, RSAPublicKey):
|
||||
raise ValueError("Expected RSA public key")
|
||||
return key
|
||||
@@ -60,21 +53,17 @@ def verify_license_signature(license_data: str) -> LicensePayload:
ValueError: If license data is invalid or signature verification fails
"""
try:
# Decode the license data
decoded = json.loads(base64.b64decode(license_data))

# Parse into LicenseData to validate structure
license_obj = LicenseData(**decoded)

# IMPORTANT: Use the ORIGINAL payload JSON for signature verification,
# not re-serialized through Pydantic. Pydantic may format fields differently
# (e.g., datetime "+00:00" vs "Z") which would break signature verification.
original_payload = decoded.get("payload", {})
payload_json = json.dumps(original_payload, sort_keys=True)
payload_json = json.dumps(
license_obj.payload.model_dump(mode="json"), sort_keys=True
)
signature_bytes = base64.b64decode(license_obj.signature)

# Verify signature using PSS padding (modern standard)
public_key = _get_public_key()

public_key.verify(
signature_bytes,
payload_json.encode(),
@@ -88,18 +77,16 @@ def verify_license_signature(license_data: str) -> LicensePayload:
return license_obj.payload

except InvalidSignature:
logger.error("[verify_license] FAILED: Signature verification failed")
logger.error("License signature verification failed")
raise ValueError("Invalid license signature")
except json.JSONDecodeError as e:
logger.error(f"[verify_license] FAILED: JSON decode error: {e}")
except json.JSONDecodeError:
logger.error("Failed to decode license JSON")
raise ValueError("Invalid license format: not valid JSON")
except (ValueError, KeyError, TypeError) as e:
logger.error(
f"[verify_license] FAILED: Validation error: {type(e).__name__}: {e}"
)
raise ValueError(f"Invalid license format: {type(e).__name__}: {e}")
logger.error(f"License data validation error: {type(e).__name__}")
raise ValueError(f"Invalid license format: {type(e).__name__}")
except Exception:
logger.exception("[verify_license] FAILED: Unexpected error")
logger.exception("Unexpected error during license verification")
raise ValueError("License verification failed: unexpected error")
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from posthog import Posthog
|
||||
|
||||
from ee.onyx.configs.app_configs import MARKETING_POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_DEBUG_LOGS_ENABLED
|
||||
from ee.onyx.configs.app_configs import POSTHOG_HOST
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -21,7 +20,7 @@ def posthog_on_error(error: Any, items: Any) -> None:
|
||||
posthog = Posthog(
|
||||
project_api_key=POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=POSTHOG_DEBUG_LOGS_ENABLED,
|
||||
debug=True,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
|
||||
@@ -34,7 +33,7 @@ if MARKETING_POSTHOG_API_KEY:
|
||||
marketing_posthog = Posthog(
|
||||
project_api_key=MARKETING_POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=POSTHOG_DEBUG_LOGS_ENABLED,
|
||||
debug=True,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA5DpchQujdxjCwpc4/RQP
|
||||
Hej6rc3SS/5ENCXL0I8NAfMogel0fqG6PKRhonyEh/Bt3P4q18y8vYzAShwf4b6Q
|
||||
aS0WwshbvnkjyWlsK0BY4HLBKPkTpes7kaz8MwmPZDeelvGJ7SNv3FvyJR4QsoSQ
|
||||
GSoB5iTH7hi63TjzdxtckkXoNG+GdVd/koxVDUv2uWcAoWIFTTcbKWyuq2SS/5Sf
|
||||
xdVaIArqfAhLpnNbnM9OS7lZ1xP+29ZXpHxDoeluz35tJLMNBYn9u0y+puo1kW1E
|
||||
TOGizlAq5kmEMsTJ55e9ZuyIV3gZAUaUKe8CxYJPkOGt0Gj6e1jHoHZCBJmaq97Y
|
||||
stKj//84HNBzajaryEZuEfRecJ94ANEjkD8u9cGmW+9VxRe5544zWguP5WMT/nv1
|
||||
0Q+jkOBW2hkY5SS0Rug4cblxiB7bDymWkaX6+sC0VWd5g6WXp36EuP2T0v3mYuHU
|
||||
GDEiWbD44ToREPVwE/M07ny8qhLo/HYk2l8DKFt83hXe7ePBnyQdcsrVbQWOO1na
|
||||
j43OkoU5gOFyOkrk2RmmtCjA8jSnw+tGCTpRaRcshqoWC1MjZyU+8/kDteXNkmv9
|
||||
/B5VxzYSyX+abl7yAu5wLiUPW8l+mOazzWu0nPkmiA160ArxnRyxbGnmp4dUIrt5
|
||||
azYku4tQYLSsSabfhcpeiCsCAwEAAQ==
|
||||
-----END PUBLIC KEY-----
|
||||
@@ -97,14 +97,10 @@ def get_access_for_documents(
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to.
|
||||
|
||||
This is meant to be used downstream to filter out documents that the user
|
||||
does not have access to. The user should have access to a document if at
|
||||
least one entry in the document's ACL matches one entry in the returned set.
|
||||
|
||||
NOTE: These strings must be formatted in the same way as the output of
|
||||
DocumentAccess::to_acl.
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
user should have access to a document if at least one entry in the document's ACL
|
||||
matches one entry in the returned set.
|
||||
"""
|
||||
if user:
|
||||
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
|
||||
|
||||
@@ -105,8 +105,6 @@ class DocExternalAccess:
|
||||
)
|
||||
|
||||
|
||||
# TODO(andrei): First refactor this into a pydantic model, then get rid of
|
||||
# duplicate fields.
|
||||
@dataclass(frozen=True, init=False)
|
||||
class DocumentAccess(ExternalAccess):
|
||||
# User emails for Onyx users, None indicates admin
|
||||
@@ -125,11 +123,9 @@ class DocumentAccess(ExternalAccess):
|
||||
)
|
||||
|
||||
def to_acl(self) -> set[str]:
|
||||
"""Converts the access state to a set of formatted ACL strings.
|
||||
# the acl's emitted by this function are prefixed by type
|
||||
# to get the native objects, access the member variables directly
|
||||
|
||||
NOTE: When querying for documents, the supplied ACL filter strings must
|
||||
be formatted in the same way as this function.
|
||||
"""
|
||||
acl_set: set[str] = set()
|
||||
for user_email in self.user_emails:
|
||||
if user_email:
|
||||
|
||||
@@ -11,7 +11,6 @@ from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Protocol
from typing import Tuple
@@ -1457,9 +1456,6 @@ def get_default_admin_user_emails_() -> list[str]:


STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
STATE_TOKEN_LIFETIME_SECONDS = 3600
CSRF_TOKEN_KEY = "csrftoken"
CSRF_TOKEN_COOKIE_NAME = "fastapiusersoauthcsrf"


class OAuth2AuthorizeResponse(BaseModel):
@@ -1467,19 +1463,13 @@ class OAuth2AuthorizeResponse(BaseModel):


def generate_state_token(
data: Dict[str, str],
secret: SecretType,
lifetime_seconds: int = STATE_TOKEN_LIFETIME_SECONDS,
data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
) -> str:
data["aud"] = STATE_TOKEN_AUDIENCE

return generate_jwt(data, secret, lifetime_seconds)


def generate_csrf_token() -> str:
return secrets.token_urlsafe(32)


# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_onyx_oauth_router(
oauth_client: BaseOAuth2,
@@ -1508,13 +1498,6 @@ def get_oauth_router(
redirect_url: Optional[str] = None,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
*,
csrf_token_cookie_name: str = CSRF_TOKEN_COOKIE_NAME,
csrf_token_cookie_path: str = "/",
csrf_token_cookie_domain: Optional[str] = None,
csrf_token_cookie_secure: Optional[bool] = None,
csrf_token_cookie_httponly: bool = True,
csrf_token_cookie_samesite: Optional[Literal["lax", "strict", "none"]] = "lax",
) -> APIRouter:
"""Generate a router with the OAuth routes."""
router = APIRouter()
@@ -1531,9 +1514,6 @@ def get_oauth_router(
route_name=callback_route_name,
)

if csrf_token_cookie_secure is None:
csrf_token_cookie_secure = WEB_DOMAIN.startswith("https")

@router.get(
"/authorize",
name=f"oauth:{oauth_client.name}.{backend.name}.authorize",
@@ -1541,10 +1521,8 @@ def get_oauth_router(
)
async def authorize(
request: Request,
response: Response,
redirect: bool = Query(False),
scopes: List[str] = Query(None),
) -> Response | OAuth2AuthorizeResponse:
) -> OAuth2AuthorizeResponse:
referral_source = request.cookies.get("referral_source", None)

if redirect_url is not None:
@@ -1554,11 +1532,9 @@ def get_oauth_router(

next_url = request.query_params.get("next", "/")

csrf_token = generate_csrf_token()
state_data: Dict[str, str] = {
"next_url": next_url,
"referral_source": referral_source or "default_referral",
CSRF_TOKEN_KEY: csrf_token,
}
state = generate_state_token(state_data, state_secret)

@@ -1575,31 +1551,6 @@ def get_oauth_router(
authorization_url, {"access_type": "offline", "prompt": "consent"}
)

if redirect:
redirect_response = RedirectResponse(authorization_url, status_code=302)
redirect_response.set_cookie(
key=csrf_token_cookie_name,
value=csrf_token,
max_age=STATE_TOKEN_LIFETIME_SECONDS,
path=csrf_token_cookie_path,
domain=csrf_token_cookie_domain,
secure=csrf_token_cookie_secure,
httponly=csrf_token_cookie_httponly,
samesite=csrf_token_cookie_samesite,
)
return redirect_response

response.set_cookie(
key=csrf_token_cookie_name,
value=csrf_token,
max_age=STATE_TOKEN_LIFETIME_SECONDS,
path=csrf_token_cookie_path,
domain=csrf_token_cookie_domain,
secure=csrf_token_cookie_secure,
httponly=csrf_token_cookie_httponly,
samesite=csrf_token_cookie_samesite,
)

return OAuth2AuthorizeResponse(authorization_url=authorization_url)

@log_function_time(print_only=True)
@@ -1649,33 +1600,7 @@ def get_oauth_router(
try:
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
except jwt.DecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=getattr(
ErrorCode, "ACCESS_TOKEN_DECODE_ERROR", "ACCESS_TOKEN_DECODE_ERROR"
),
)
except jwt.ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=getattr(
ErrorCode,
"ACCESS_TOKEN_ALREADY_EXPIRED",
"ACCESS_TOKEN_ALREADY_EXPIRED",
),
)

cookie_csrf_token = request.cookies.get(csrf_token_cookie_name)
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
if (
not cookie_csrf_token
or not state_csrf_token
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=getattr(ErrorCode, "OAUTH_INVALID_STATE", "OAUTH_INVALID_STATE"),
)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)

@@ -1,98 +0,0 @@
# Overview of Onyx Background Jobs

The background jobs take care of:
1. Pulling/Indexing documents (from connectors)
2. Updating document metadata (from connectors)
3. Cleaning up checkpoints and logic around indexing work (indexing checkpoints and index attempt metadata)
4. Handling user uploaded files and deletions (from the Projects feature and uploads via the Chat)
5. Reporting metrics on things like queue length for monitoring purposes

## Worker → Queue Mapping

| Worker | File | Queues |
|--------|------|--------|
| Primary | `apps/primary.py` | `celery` |
| Light | `apps/light.py` | `vespa_metadata_sync`, `connector_deletion`, `doc_permissions_upsert`, `checkpoint_cleanup`, `index_attempt_cleanup` |
| Heavy | `apps/heavy.py` | `connector_pruning`, `connector_doc_permissions_sync`, `connector_external_group_sync`, `csv_generation`, `sandbox` |
| Docprocessing | `apps/docprocessing.py` | `docprocessing` |
| Docfetching | `apps/docfetching.py` | `connector_doc_fetching` |
| User File Processing | `apps/user_file_processing.py` | `user_file_processing`, `user_file_project_sync`, `user_file_delete` |
| Monitoring | `apps/monitoring.py` | `monitoring` |
| Background (consolidated) | `apps/background.py` | All queues above except `celery` |

## Non-Worker Apps
| App | File | Purpose |
|-----|------|---------|
| **Beat** | `beat.py` | Celery beat scheduler with `DynamicTenantScheduler` that generates per-tenant periodic task schedules |
| **Client** | `client.py` | Minimal app for task submission from non-worker processes (e.g., API server); see the sketch below |
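
The Client app's role is easiest to see with a small example. This is a hypothetical sketch (the config module, task name, and queue are illustrative assumptions, not the real Onyx wiring) of how a non-worker process can enqueue work by task name without importing any worker code:

```python
# Hypothetical sketch of task submission from a non-worker process.
from celery import Celery

client_app = Celery(__name__)
client_app.config_from_object("onyx.background.celery.configs.client")  # assumed config module


def request_csv_export(report_id: int) -> None:
    # send_task dispatches by registered task name, so the API server does not
    # need to import the worker's task functions.
    client_app.send_task(
        "generate_csv_export",         # illustrative task name
        kwargs={"report_id": report_id},
        queue="csv_generation",        # queue name taken from the mapping table above
    )
```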

### Shared Module
`app_base.py` provides:
- `TenantAwareTask` - Base task class that sets tenant context
- Signal handlers for logging, cleanup, and lifecycle events
- Readiness probes and health checks


## Worker Details

### Primary (Coordinator and task dispatcher)
The Primary worker is the single worker that handles tasks from the default `celery` queue. It is kept a singleton by the `PRIMARY_WORKER` Redis lock,
which it touches every `CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8` seconds (via a Celery bootstep).
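
As a rough illustration of that locking scheme (a sketch only, not the actual implementation; the key name, timeout value, and loop structure are assumptions), the pattern looks roughly like this:

```python
# Sketch of a singleton worker holding and refreshing a Redis lock.
import time

import redis

LOCK_KEY = "PRIMARY_WORKER"   # illustrative key name
LOCK_TIMEOUT = 120            # stand-in for CELERY_PRIMARY_WORKER_LOCK_TIMEOUT

r = redis.Redis()


def run_as_primary() -> None:
    # Only one process can create the key while it does not already exist.
    if not r.set(LOCK_KEY, "owner-id", nx=True, ex=LOCK_TIMEOUT):
        raise RuntimeError("Another primary worker already holds the lock")
    try:
        while True:
            # Touch the lock every LOCK_TIMEOUT / 8 seconds so it never expires
            # while this process is alive (the real code does this in a bootstep).
            r.expire(LOCK_KEY, LOCK_TIMEOUT)
            time.sleep(LOCK_TIMEOUT / 8)
    finally:
        r.delete(LOCK_KEY)
```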

On startup:
- waits for redis, postgres, and the document index to all be healthy
- acquires the singleton lock
- cleans all the Redis state associated with background jobs
- marks orphaned index attempts as failed

Then it cycles through its tasks as scheduled by Celery Beat:

| Task | Frequency | Description |
|------|-----------|-------------|
| `check_for_indexing` | 15s | Scans for connectors needing indexing → dispatches to `DOCFETCHING` queue |
| `check_for_vespa_sync_task` | 20s | Finds stale documents/document sets → dispatches sync tasks to `VESPA_METADATA_SYNC` queue |
| `check_for_pruning` | 20s | Finds connectors due for pruning → dispatches to `CONNECTOR_PRUNING` queue |
| `check_for_connector_deletion` | 20s | Processes deletion requests → dispatches to `CONNECTOR_DELETION` queue |
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
| `kombu_message_cleanup_task` | periodic | Cleans orphaned Kombu messages from the DB (Kombu is the messaging framework used by Celery) |
| `celery_beat_heartbeat` | 1m | Heartbeat for the Beat watchdog |

The watchdog is a separate Python process, managed by supervisord, that runs alongside the Celery workers. It checks `ONYX_CELERY_BEAT_HEARTBEAT_KEY` in
Redis to verify that Celery Beat is still alive. Beat schedules the `celery_beat_heartbeat` task on the Primary worker; the task touches the key, which proves Beat is still dispatching scheduled work.
See `supervisord.conf` for the watchdog configuration.
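
A minimal sketch of that heartbeat handshake (the TTL value and helper names are assumptions; only the key name comes from the text above):

```python
# Illustrative heartbeat: Beat-scheduled task refreshes a Redis key with a TTL,
# and the supervisord-managed watchdog checks that the key has not expired.
import redis

HEARTBEAT_KEY = "ONYX_CELERY_BEAT_HEARTBEAT_KEY"  # key name from the text
HEARTBEAT_TTL = 120                               # assumed expiry window

r = redis.Redis()


def celery_beat_heartbeat() -> None:
    # Runs on the Primary worker once a minute (scheduled by Beat); refreshing
    # the key shows that Beat is still dispatching scheduled tasks.
    r.set(HEARTBEAT_KEY, "alive", ex=HEARTBEAT_TTL)


def watchdog_check() -> bool:
    # The watchdog only needs to verify the key still exists.
    return r.exists(HEARTBEAT_KEY) == 1
```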


### Light
Fast, short-lived tasks that are not resource intensive. High concurrency:
it can run 24 concurrent worker processes, each with a prefetch of 8, for a total of 192 tasks in flight at once (see the sketch below).
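
For reference, that arithmetic corresponds to standard Celery settings along these lines (a sketch; the exact values used by the Light worker's config module are assumptions here):

```python
# 24 processes x prefetch 8 = 192 tasks reserved at once.
worker_concurrency = 24          # number of worker processes
worker_prefetch_multiplier = 8   # tasks each process reserves from the broker ahead of time
```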

Tasks it handles:
- Syncs access/permissions, document sets, boosts, and hidden state
- Deletes documents that are marked for deletion in Postgres
- Cleanup of checkpoints and index attempts


### Heavy
Long-running, resource-intensive tasks; handles pruning and sandbox operations. Low concurrency - a max concurrency of 4 with a prefetch of 1.

It does not interact with the Document Index; it handles syncs with external systems, making large volumes of API calls for pruning, permission fetching, and similar work.

It also generates CSV exports, which can take a long time when there is significant data in Postgres.

The sandbox (a new feature) runs Next.js, a Python virtual environment, and the OpenCode AI Agent, with access to knowledge files.


### Docprocessing, Docfetching, User File Processing
Docprocessing and Docfetching are for indexing documents (see the sketch below):
- Docfetching runs connectors to pull documents from external APIs (Google Drive, Confluence, etc.), stores batches to file storage, and dispatches docprocessing tasks
- Docprocessing retrieves batches, runs the indexing pipeline (chunking, embedding), and indexes into the Document Index

User Files come from uploads made directly via the chat input bar.
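
A hedged sketch of that two-stage handoff (task names, queue names, and the batch-storage helper are illustrative, not the real Onyx API):

```python
# Two-stage pipeline: docfetching persists batches, docprocessing indexes them.
import json
import os
import tempfile
from typing import Any

from celery import Celery

app = Celery(__name__)


def _store_batch(batch: list[dict[str, Any]]) -> str:
    # Stand-in for "stores batches to file storage"; the real code writes to
    # shared file storage, not a local temp file.
    fd, path = tempfile.mkstemp(suffix=".json")
    with os.fdopen(fd, "w") as f:
        json.dump(batch, f)
    return path


@app.task(name="sketch.docfetching")
def fetch_documents(connector_batches: list[list[dict[str, Any]]]) -> None:
    # Docfetching: pull batches from the external source, persist each batch,
    # then hand it off to the docprocessing queue.
    for batch in connector_batches:
        batch_path = _store_batch(batch)
        process_batch.apply_async(args=[batch_path], queue="docprocessing")


@app.task(name="sketch.docprocessing")
def process_batch(batch_path: str) -> None:
    # Docprocessing: reload the batch and run the indexing pipeline
    # (chunking, embedding, writing to the Document Index) -- elided here.
    with open(batch_path) as f:
        documents = json.load(f)
    _ = documents  # chunk, embed, and index in the real pipeline
```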


### Monitoring
Observability and metrics collection:
- Queue lengths, connector success/failure, connector latencies
- Memory usage of supervisord-managed processes (workers, beat, slack)
- Cloud- and multi-tenant-specific monitoring
@@ -26,13 +26,10 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.celery_utils import make_probe_path
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ENABLE_OPENSEARCH_FOR_ONYX
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.document_index.opensearch.client import (
wait_for_opensearch_with_timeout,
)
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector import RedisConnector
@@ -43,7 +40,6 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.tracing.braintrust_tracing import setup_braintrust_if_creds_available
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import PlainFormatter
@@ -238,9 +234,6 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
f"Multiprocessing selected start method: {multiprocessing.get_start_method()}"
)

# Initialize Braintrust tracing in workers if credentials are available.
setup_braintrust_if_creds_available()


def wait_for_redis(sender: Any, **kwargs: Any) -> None:
"""Waits for redis to become ready subject to a hardcoded timeout.
@@ -523,17 +516,14 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a timeout.
Raises WorkerShutdown if the timeout is reached."""

if ENABLE_OPENSEARCH_FOR_ONYX:
return

if not wait_for_vespa_with_timeout():
msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)
raise WorkerShutdown(msg)

if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
if not wait_for_opensearch_with_timeout():
msg = "[OpenSearch] Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)
raise WorkerShutdown(msg)


# File for validating worker liveness
class LivenessProbe(bootsteps.StartStopStep):

@@ -121,9 +121,9 @@ celery_app.autodiscover_tasks(
[
# Original background worker tasks
"onyx.background.celery.tasks.pruning",
"onyx.background.celery.tasks.kg_processing",
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.user_file_processing",
"onyx.background.celery.tasks.llm_model_update",
# Light worker tasks
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
@@ -133,7 +133,5 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.docprocessing",
# Docfetching worker tasks
"onyx.background.celery.tasks.docfetching",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)

@@ -98,7 +98,5 @@ for bootstep in base_bootsteps:
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.pruning",
# Sandbox tasks (file sync, cleanup)
"onyx.server.features.build.sandbox.tasks",
]
)

109
backend/onyx/background/celery/apps/kg_processing.py
Normal file
@@ -0,0 +1,109 @@
from typing import Any
from typing import cast

from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown

import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT


logger = setup_logger()

celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.kg_processing")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]


@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)


@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)


@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)


@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")

SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME)

pool_size = cast(int, sender.concurrency) # type: ignore
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)

app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)

# Less startup checks in multi-tenant case
if MULTI_TENANT:
return

app_base.on_secondary_worker_init(sender, **kwargs)


@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)


@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)


@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
SqlEngine.reset_engine()


@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)


base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.kg_processing",
]
)
@@ -116,7 +116,5 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.docprocessing",
# Sandbox cleanup tasks (isolated in build feature)
"onyx.server.features.build.sandbox.tasks",
]
)

@@ -323,6 +323,7 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.shared",
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.llm_model_update",
"onyx.background.celery.tasks.kg_processing",
"onyx.background.celery.tasks.user_file_processing",
]
)

@@ -21,7 +21,6 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -33,16 +32,10 @@ PRUNING_CHECKPOINTED_BATCH_SIZE = 32


def document_batch_to_ids(
doc_batch: (
Iterator[list[Document | HierarchyNode]]
| Iterator[list[SlimDocument | HierarchyNode]]
),
doc_batch: Iterator[list[Document]] | Iterator[list[SlimDocument]],
) -> Generator[set[str], None, None]:
for doc_list in doc_batch:
yield {
doc.raw_node_id if isinstance(doc, HierarchyNode) else doc.id
for doc in doc_list
}
yield {doc.id for doc in doc_list}


def extract_ids_from_runnable_connector(

Some files were not shown because too many files have changed in this diff