mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-24 19:25:46 +00:00
Compare commits
8 Commits
craft_chan
...
agent-mess
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9abf96f487 | ||
|
|
67a6266c97 | ||
|
|
1a076f557d | ||
|
|
087f6d8f6a | ||
|
|
040f779b20 | ||
|
|
107809543b | ||
|
|
95fd5f81a4 | ||
|
|
94ef6974d6 |
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@@ -8,5 +8,4 @@
|
||||
|
||||
## Additional Options
|
||||
|
||||
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
|
||||
- [ ] [Optional] Override Linear Check
|
||||
|
||||
259
.github/workflows/deployment.yml
vendored
259
.github/workflows/deployment.yml
vendored
@@ -26,14 +26,12 @@ jobs:
|
||||
build-web: ${{ steps.check.outputs.build-web }}
|
||||
build-web-cloud: ${{ steps.check.outputs.build-web-cloud }}
|
||||
build-backend: ${{ steps.check.outputs.build-backend }}
|
||||
build-backend-craft: ${{ steps.check.outputs.build-backend-craft }}
|
||||
build-model-server: ${{ steps.check.outputs.build-model-server }}
|
||||
is-cloud-tag: ${{ steps.check.outputs.is-cloud-tag }}
|
||||
is-stable: ${{ steps.check.outputs.is-stable }}
|
||||
is-beta: ${{ steps.check.outputs.is-beta }}
|
||||
is-stable-standalone: ${{ steps.check.outputs.is-stable-standalone }}
|
||||
is-beta-standalone: ${{ steps.check.outputs.is-beta-standalone }}
|
||||
is-craft-latest: ${{ steps.check.outputs.is-craft-latest }}
|
||||
is-test-run: ${{ steps.check.outputs.is-test-run }}
|
||||
sanitized-tag: ${{ steps.check.outputs.sanitized-tag }}
|
||||
short-sha: ${{ steps.check.outputs.short-sha }}
|
||||
@@ -56,20 +54,15 @@ jobs:
|
||||
IS_BETA=false
|
||||
IS_STABLE_STANDALONE=false
|
||||
IS_BETA_STANDALONE=false
|
||||
IS_CRAFT_LATEST=false
|
||||
IS_PROD_TAG=false
|
||||
IS_TEST_RUN=false
|
||||
BUILD_DESKTOP=false
|
||||
BUILD_WEB=false
|
||||
BUILD_WEB_CLOUD=false
|
||||
BUILD_BACKEND=true
|
||||
BUILD_BACKEND_CRAFT=false
|
||||
BUILD_MODEL_SERVER=true
|
||||
|
||||
# Determine tag type based on pattern matching (do regex checks once)
|
||||
if [[ "$TAG" == craft-* ]]; then
|
||||
IS_CRAFT_LATEST=true
|
||||
fi
|
||||
if [[ "$TAG" == *cloud* ]]; then
|
||||
IS_CLOUD=true
|
||||
fi
|
||||
@@ -97,12 +90,6 @@ jobs:
|
||||
fi
|
||||
fi
|
||||
|
||||
# Craft-latest builds backend with Craft enabled
|
||||
if [[ "$IS_CRAFT_LATEST" == "true" ]]; then
|
||||
BUILD_BACKEND_CRAFT=true
|
||||
BUILD_BACKEND=false
|
||||
fi
|
||||
|
||||
# Standalone version checks (for backend/model-server - version excluding cloud tags)
|
||||
if [[ "$IS_STABLE" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then
|
||||
IS_STABLE_STANDALONE=true
|
||||
@@ -126,14 +113,12 @@ jobs:
|
||||
echo "build-web=$BUILD_WEB"
|
||||
echo "build-web-cloud=$BUILD_WEB_CLOUD"
|
||||
echo "build-backend=$BUILD_BACKEND"
|
||||
echo "build-backend-craft=$BUILD_BACKEND_CRAFT"
|
||||
echo "build-model-server=$BUILD_MODEL_SERVER"
|
||||
echo "is-cloud-tag=$IS_CLOUD"
|
||||
echo "is-stable=$IS_STABLE"
|
||||
echo "is-beta=$IS_BETA"
|
||||
echo "is-stable-standalone=$IS_STABLE_STANDALONE"
|
||||
echo "is-beta-standalone=$IS_BETA_STANDALONE"
|
||||
echo "is-craft-latest=$IS_CRAFT_LATEST"
|
||||
echo "is-test-run=$IS_TEST_RUN"
|
||||
echo "sanitized-tag=$SANITIZED_TAG"
|
||||
echo "short-sha=$SHORT_SHA"
|
||||
@@ -145,13 +130,13 @@ jobs:
|
||||
if: ${{ !startsWith(github.ref_name, 'nightly-latest') && github.event_name != 'workflow_dispatch' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
|
||||
@@ -170,7 +155,7 @@ jobs:
|
||||
environment: release
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -219,7 +204,7 @@ jobs:
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6.0.2
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
|
||||
with:
|
||||
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
|
||||
persist-credentials: true # zizmor: ignore[artipacked]
|
||||
@@ -392,7 +377,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -465,7 +450,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -603,7 +588,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -684,7 +669,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -827,7 +812,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -899,7 +884,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -1018,217 +1003,6 @@ jobs:
|
||||
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
|
||||
$IMAGES
|
||||
|
||||
build-backend-craft-amd64:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-backend-craft == 'true'
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-backend-craft-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-backend
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
ENABLE_CRAFT=true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64,mode=max
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-backend-craft-arm64:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-backend-craft == 'true'
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-arm64
|
||||
- run-id=${{ github.run_id }}-backend-craft-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-backend
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/arm64
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
ENABLE_CRAFT=true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64,mode=max
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
merge-backend-craft:
|
||||
needs:
|
||||
- determine-builds
|
||||
- build-backend-craft-amd64
|
||||
- build-backend-craft-arm64
|
||||
if: needs.determine-builds.outputs.build-backend-craft == 'true'
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-merge-backend-craft
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-backend
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=craft-latest
|
||||
# TODO: Consider aligning craft-latest tags with regular backend builds (e.g., latest, edge, beta)
|
||||
# to keep tagging strategy consistent across all backend images
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
IMAGE_REPO: ${{ env.REGISTRY_IMAGE }}
|
||||
AMD64_DIGEST: ${{ needs.build-backend-craft-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.build-backend-craft-arm64.outputs.digest }}
|
||||
META_TAGS: ${{ steps.meta.outputs.tags }}
|
||||
run: |
|
||||
IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
|
||||
docker buildx imagetools create \
|
||||
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
|
||||
$IMAGES
|
||||
|
||||
build-model-server-amd64:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-model-server == 'true'
|
||||
@@ -1248,7 +1022,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -1327,7 +1101,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -1580,7 +1354,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -1692,20 +1466,17 @@ jobs:
|
||||
- build-backend-amd64
|
||||
- build-backend-arm64
|
||||
- merge-backend
|
||||
- build-backend-craft-amd64
|
||||
- build-backend-craft-arm64
|
||||
- merge-backend-craft
|
||||
- build-model-server-amd64
|
||||
- build-model-server-arm64
|
||||
- merge-model-server
|
||||
if: always() && (needs.build-desktop.result == 'failure' || needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || (needs.determine-builds.outputs.build-backend-craft == 'true' && (needs.build-backend-craft-amd64.result == 'failure' || needs.build-backend-craft-arm64.result == 'failure' || needs.merge-backend-craft.result == 'failure')) || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && needs.determine-builds.outputs.is-test-run != 'true'
|
||||
if: always() && (needs.build-desktop.result == 'failure' || needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && needs.determine-builds.outputs.is-test-run != 'true'
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
2
.github/workflows/helm-chart-releases.yml
vendored
2
.github/workflows/helm-chart-releases.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
2
.github/workflows/nightly-scan-licenses.yml
vendored
2
.github/workflows/nightly-scan-licenses.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
@@ -1,28 +0,0 @@
|
||||
name: Require beta cherry-pick consideration
|
||||
concurrency:
|
||||
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, reopened, synchronize]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
beta-cherrypick-check:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Check PR body for beta cherry-pick consideration
|
||||
env:
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
run: |
|
||||
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
|
||||
echo "Cherry-pick consideration box is checked. Check passed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
|
||||
exit 1
|
||||
2
.github/workflows/pr-database-tests.yml
vendored
2
.github/workflows/pr-database-tests.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
114
.github/workflows/pr-desktop-build.yml
vendored
114
.github/workflows/pr-desktop-build.yml
vendored
@@ -1,114 +0,0 @@
|
||||
name: Build Desktop App
|
||||
concurrency:
|
||||
group: Build-Desktop-App-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
paths:
|
||||
- "desktop/**"
|
||||
- ".github/workflows/pr-desktop-build.yml"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build-desktop:
|
||||
name: Build Desktop (${{ matrix.platform }})
|
||||
runs-on: ${{ matrix.os }}
|
||||
timeout-minutes: 60
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- platform: linux
|
||||
os: ubuntu-latest
|
||||
target: x86_64-unknown-linux-gnu
|
||||
args: "--bundles deb,rpm"
|
||||
# TODO: Fix and enable the macOS build.
|
||||
#- platform: macos
|
||||
# os: macos-latest
|
||||
# target: universal-apple-darwin
|
||||
# args: "--target universal-apple-darwin"
|
||||
# TODO: Fix and enable the Windows build.
|
||||
#- platform: windows
|
||||
# os: windows-latest
|
||||
# target: x86_64-pc-windows-msvc
|
||||
# args: ""
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020
|
||||
with:
|
||||
node-version: 24
|
||||
cache: "npm" # zizmor: ignore[cache-poisoning]
|
||||
cache-dependency-path: ./desktop/package-lock.json
|
||||
|
||||
- name: Setup Rust
|
||||
uses: dtolnay/rust-toolchain@4be9e76fd7c4901c61fb841f559994984270fce7
|
||||
with:
|
||||
toolchain: stable
|
||||
targets: ${{ matrix.target }}
|
||||
|
||||
- name: Cache Cargo registry and build
|
||||
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # zizmor: ignore[cache-poisoning]
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/bin/
|
||||
~/.cargo/registry/index/
|
||||
~/.cargo/registry/cache/
|
||||
~/.cargo/git/db/
|
||||
desktop/src-tauri/target/
|
||||
key: ${{ runner.os }}-cargo-${{ hashFiles('desktop/src-tauri/Cargo.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-cargo-
|
||||
|
||||
- name: Install Linux dependencies
|
||||
if: matrix.platform == 'linux'
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y \
|
||||
build-essential \
|
||||
libglib2.0-dev \
|
||||
libgirepository1.0-dev \
|
||||
libgtk-3-dev \
|
||||
libjavascriptcoregtk-4.1-dev \
|
||||
libwebkit2gtk-4.1-dev \
|
||||
libayatana-appindicator3-dev \
|
||||
gobject-introspection \
|
||||
pkg-config \
|
||||
curl \
|
||||
xdg-utils
|
||||
|
||||
- name: Install npm dependencies
|
||||
working-directory: ./desktop
|
||||
run: npm ci
|
||||
|
||||
- name: Build desktop app
|
||||
working-directory: ./desktop
|
||||
run: npx tauri build ${{ matrix.args }}
|
||||
env:
|
||||
TAURI_SIGNING_PRIVATE_KEY: ""
|
||||
TAURI_SIGNING_PRIVATE_KEY_PASSWORD: ""
|
||||
|
||||
- name: Upload build artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02
|
||||
with:
|
||||
name: desktop-build-${{ matrix.platform }}-${{ github.run_id }}
|
||||
path: |
|
||||
desktop/src-tauri/target/release/bundle/
|
||||
retention-days: 7
|
||||
if-no-files-found: ignore
|
||||
@@ -57,7 +57,7 @@ jobs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -91,7 +91,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
3
.github/workflows/pr-helm-chart-testing.yml
vendored
3
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -30,7 +30,7 @@ jobs:
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@@ -197,6 +197,7 @@ jobs:
|
||||
--set=auth.opensearch.enabled=true \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.nameOverride=cloudnative-pg \
|
||||
--set=postgresql.cluster.storage.storageClass=standard \
|
||||
--set=redis.enabled=true \
|
||||
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
|
||||
|
||||
12
.github/workflows/pr-integration-tests.yml
vendored
12
.github/workflows/pr-integration-tests.yml
vendored
@@ -48,7 +48,7 @@ jobs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -84,7 +84,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -144,7 +144,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -203,7 +203,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -279,7 +279,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -460,7 +460,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
2
.github/workflows/pr-jest-tests.yml
vendored
2
.github/workflows/pr-jest-tests.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
10
.github/workflows/pr-mit-integration-tests.yml
vendored
10
.github/workflows/pr-mit-integration-tests.yml
vendored
@@ -40,7 +40,7 @@ jobs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -76,7 +76,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -136,7 +136,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -195,7 +195,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -271,7 +271,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
10
.github/workflows/pr-playwright-tests.yml
vendored
10
.github/workflows/pr-playwright-tests.yml
vendored
@@ -66,7 +66,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -127,7 +127,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -188,7 +188,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -254,7 +254,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -488,7 +488,7 @@ jobs:
|
||||
# ]
|
||||
# steps:
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
# uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
# with:
|
||||
# fetch-depth: 0
|
||||
|
||||
|
||||
5
.github/workflows/pr-python-checks.yml
vendored
5
.github/workflows/pr-python-checks.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -50,9 +50,8 @@ jobs:
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
|
||||
@@ -65,7 +65,7 @@ env:
|
||||
ZENDESK_TOKEN: ${{ secrets.ZENDESK_TOKEN }}
|
||||
|
||||
# Salesforce
|
||||
SF_USERNAME: ${{ vars.SF_USERNAME }}
|
||||
SF_USERNAME: ${{ secrets.SF_USERNAME }}
|
||||
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
|
||||
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
|
||||
|
||||
@@ -110,9 +110,6 @@ env:
|
||||
# Slack
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
|
||||
# Discord
|
||||
DISCORD_CONNECTOR_BOT_TOKEN: ${{ secrets.DISCORD_CONNECTOR_BOT_TOKEN }}
|
||||
|
||||
# Teams
|
||||
TEAMS_APPLICATION_ID: ${{ secrets.TEAMS_APPLICATION_ID }}
|
||||
TEAMS_DIRECTORY_ID: ${{ secrets.TEAMS_DIRECTORY_ID }}
|
||||
@@ -142,7 +139,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
2
.github/workflows/pr-python-model-tests.yml
vendored
2
.github/workflows/pr-python-model-tests.yml
vendored
@@ -38,7 +38,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
2
.github/workflows/pr-python-tests.yml
vendored
2
.github/workflows/pr-python-tests.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
4
.github/workflows/pr-quality-checks.yml
vendored
4
.github/workflows/pr-quality-checks.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
- uses: j178/prek-action@9d6a3097e0c1865ecce00cfb89fe80f2ee91b547 # ratchet:j178/prek-action@v1
|
||||
- uses: j178/prek-action@91fd7d7cf70ae1dee9f4f44e7dfa5d1073fe6623 # ratchet:j178/prek-action@v1
|
||||
with:
|
||||
prek-version: '0.2.21'
|
||||
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
|
||||
|
||||
4
.github/workflows/release-devtools.yml
vendored
4
.github/workflows/release-devtools.yml
vendored
@@ -24,11 +24,11 @@ jobs:
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
- { goos: "", goarch: "" }
|
||||
steps:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
2
.github/workflows/sync_foss.yml
vendored
2
.github/workflows/sync_foss.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout main Onyx repo
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
2
.github/workflows/tag-nightly.yml
vendored
2
.github/workflows/tag-nightly.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
|
||||
# implement here which needs an actual user's deploy key
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
ssh-key: "${{ secrets.DEPLOY_KEY }}"
|
||||
persist-credentials: true
|
||||
|
||||
4
.github/workflows/zizmor.yml
vendored
4
.github/workflows/zizmor.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
security-events: write # needed for SARIF uploads
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6.0.2
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -31,7 +31,7 @@ jobs:
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,7 +1,6 @@
|
||||
# editors
|
||||
.vscode/*
|
||||
.vscode
|
||||
!/.vscode/env_template.txt
|
||||
!/.vscode/env.web_template.txt
|
||||
!/.vscode/launch.json
|
||||
!/.vscode/tasks.template.jsonc
|
||||
.zed
|
||||
|
||||
16
.vscode/env.web_template.txt
vendored
16
.vscode/env.web_template.txt
vendored
@@ -1,16 +0,0 @@
|
||||
# Copy this file to .env.web in the .vscode folder.
|
||||
# Fill in the <REPLACE THIS> values as needed
|
||||
# Web Server specific environment variables
|
||||
# Minimal set needed for Next.js dev server
|
||||
|
||||
# Auth
|
||||
AUTH_TYPE=basic
|
||||
DEV_MODE=true
|
||||
|
||||
# Enable the full set of Danswer Enterprise Edition features.
|
||||
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you
|
||||
# are using this for local testing/development).
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=false
|
||||
|
||||
# Enable Onyx Craft
|
||||
ENABLE_CRAFT=true
|
||||
7
.vscode/env_template.txt
vendored
7
.vscode/env_template.txt
vendored
@@ -6,13 +6,13 @@
|
||||
# processes.
|
||||
|
||||
|
||||
AUTH_TYPE=basic
|
||||
DEV_MODE=true
|
||||
# For local dev, often user Authentication is not needed.
|
||||
AUTH_TYPE=disabled
|
||||
|
||||
|
||||
# Always keep these on for Dev.
|
||||
# Logs model prompts, reasoning, and answer to stdout.
|
||||
LOG_ONYX_MODEL_INTERACTIONS=False
|
||||
LOG_ONYX_MODEL_INTERACTIONS=True
|
||||
# More verbose logging
|
||||
LOG_LEVEL=debug
|
||||
|
||||
@@ -35,6 +35,7 @@ GEN_AI_API_KEY=<REPLACE THIS>
|
||||
OPENAI_API_KEY=<REPLACE THIS>
|
||||
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper.
|
||||
GEN_AI_MODEL_VERSION=gpt-4o
|
||||
FAST_GEN_AI_MODEL_VERSION=gpt-4o
|
||||
|
||||
|
||||
# Python stuff
|
||||
|
||||
5
.vscode/launch.json
vendored
5
.vscode/launch.json
vendored
@@ -25,7 +25,6 @@
|
||||
"Celery heavy",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery user_file_processing",
|
||||
"Celery beat"
|
||||
],
|
||||
"presentation": {
|
||||
@@ -87,7 +86,7 @@
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}/web",
|
||||
"runtimeExecutable": "npm",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env.web",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"runtimeArgs": ["run", "dev"],
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
@@ -122,6 +121,7 @@
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_ONYX_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1"
|
||||
},
|
||||
@@ -572,6 +572,7 @@
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_ONYX_MODEL_INTERACTIONS": "True",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
|
||||
@@ -142,13 +142,6 @@ COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervis
|
||||
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
|
||||
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh
|
||||
|
||||
# Run Craft template setup at build time when ENABLE_CRAFT=true
|
||||
# This pre-bakes demo data, Python venv, and npm dependencies into the image
|
||||
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
|
||||
echo "Running Craft template setup at build time..." && \
|
||||
ENABLE_CRAFT=true /app/scripts/setup_craft_templates.sh; \
|
||||
fi
|
||||
|
||||
# Put logo in assets
|
||||
COPY --chown=onyx:onyx ./assets /app/assets
|
||||
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
"""LLMProvider deprecated fields are nullable
|
||||
|
||||
Revision ID: 001984c88745
|
||||
Revises: 01f8e6d95a33
|
||||
Create Date: 2026-02-01 22:24:34.171100
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "001984c88745"
|
||||
down_revision = "01f8e6d95a33"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Make default_model_name nullable (was NOT NULL)
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"default_model_name",
|
||||
existing_type=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Remove server_default from is_default_vision_provider (was server_default=false())
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"is_default_vision_provider",
|
||||
existing_type=sa.Boolean(),
|
||||
server_default=None,
|
||||
)
|
||||
|
||||
# is_default_provider and default_vision_model are already nullable with no server_default
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore default_model_name to NOT NULL (set empty string for any NULLs first)
|
||||
op.execute(
|
||||
"UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL"
|
||||
)
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"default_model_name",
|
||||
existing_type=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Restore server_default for is_default_vision_provider
|
||||
op.alter_column(
|
||||
"llm_provider",
|
||||
"is_default_vision_provider",
|
||||
existing_type=sa.Boolean(),
|
||||
server_default=sa.false(),
|
||||
)
|
||||
@@ -1,112 +0,0 @@
|
||||
"""Populate flow mapping data
|
||||
|
||||
Revision ID: 01f8e6d95a33
|
||||
Revises: f220515df7b4
|
||||
Create Date: 2026-01-31 17:37:10.485558
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "01f8e6d95a33"
|
||||
down_revision = "f220515df7b4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add each model config to the conversation flow, setting the global default if it exists
|
||||
# Exclude models that are part of ImageGenerationConfig
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO llm_model_flow (llm_model_flow_type, is_default, model_configuration_id)
|
||||
SELECT
|
||||
'chat' AS llm_model_flow_type,
|
||||
COALESCE(
|
||||
(lp.is_default_provider IS TRUE AND lp.default_model_name = mc.name),
|
||||
FALSE
|
||||
) AS is_default,
|
||||
mc.id AS model_configuration_id
|
||||
FROM model_configuration mc
|
||||
LEFT JOIN llm_provider lp
|
||||
ON lp.id = mc.llm_provider_id
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM image_generation_config igc
|
||||
WHERE igc.model_configuration_id = mc.id
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
# Add models with supports_image_input to the vision flow
|
||||
op.execute(
|
||||
"""
|
||||
INSERT INTO llm_model_flow (llm_model_flow_type, is_default, model_configuration_id)
|
||||
SELECT
|
||||
'vision' AS llm_model_flow_type,
|
||||
COALESCE(
|
||||
(lp.is_default_vision_provider IS TRUE AND lp.default_vision_model = mc.name),
|
||||
FALSE
|
||||
) AS is_default,
|
||||
mc.id AS model_configuration_id
|
||||
FROM model_configuration mc
|
||||
LEFT JOIN llm_provider lp
|
||||
ON lp.id = mc.llm_provider_id
|
||||
WHERE mc.supports_image_input IS TRUE;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Populate vision defaults from model_flow
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE llm_provider AS lp
|
||||
SET
|
||||
is_default_vision_provider = TRUE,
|
||||
default_vision_model = mc.name
|
||||
FROM llm_model_flow mf
|
||||
JOIN model_configuration mc ON mc.id = mf.model_configuration_id
|
||||
WHERE mf.llm_model_flow_type = 'vision'
|
||||
AND mf.is_default = TRUE
|
||||
AND mc.llm_provider_id = lp.id;
|
||||
"""
|
||||
)
|
||||
|
||||
# Populate conversation defaults from model_flow
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE llm_provider AS lp
|
||||
SET
|
||||
is_default_provider = TRUE,
|
||||
default_model_name = mc.name
|
||||
FROM llm_model_flow mf
|
||||
JOIN model_configuration mc ON mc.id = mf.model_configuration_id
|
||||
WHERE mf.llm_model_flow_type = 'chat'
|
||||
AND mf.is_default = TRUE
|
||||
AND mc.llm_provider_id = lp.id;
|
||||
"""
|
||||
)
|
||||
|
||||
# For providers that have conversation flow mappings but aren't the default,
|
||||
# we still need a default_model_name (it was NOT NULL originally)
|
||||
# Pick the first visible model or any model for that provider
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE llm_provider AS lp
|
||||
SET default_model_name = (
|
||||
SELECT mc.name
|
||||
FROM model_configuration mc
|
||||
JOIN llm_model_flow mf ON mf.model_configuration_id = mc.id
|
||||
WHERE mc.llm_provider_id = lp.id
|
||||
AND mf.llm_model_flow_type = 'chat'
|
||||
ORDER BY mc.is_visible DESC, mc.id ASC
|
||||
LIMIT 1
|
||||
)
|
||||
WHERE lp.default_model_name IS NULL;
|
||||
"""
|
||||
)
|
||||
|
||||
# Delete all model_flow entries (reverse the inserts from upgrade)
|
||||
op.execute("DELETE FROM llm_model_flow;")
|
||||
@@ -10,6 +10,8 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from onyx.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1f60f60c3401"
|
||||
down_revision = "f17bf3b0d9f1"
|
||||
@@ -64,7 +66,7 @@ def upgrade() -> None:
|
||||
"num_rerank",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default=str(20),
|
||||
server_default=str(NUM_POSTPROCESSED_RESULTS),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
"""remove reranking from search_settings
|
||||
|
||||
Revision ID: 78ebc66946a0
|
||||
Revises: 849b21c732f8
|
||||
Create Date: 2026-01-28
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "78ebc66946a0"
|
||||
down_revision = "849b21c732f8"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("search_settings", "disable_rerank_for_streaming")
|
||||
op.drop_column("search_settings", "rerank_model_name")
|
||||
op.drop_column("search_settings", "rerank_provider_type")
|
||||
op.drop_column("search_settings", "rerank_api_key")
|
||||
op.drop_column("search_settings", "rerank_api_url")
|
||||
op.drop_column("search_settings", "num_rerank")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"disable_rerank_for_streaming",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_model_name", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_provider_type", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_api_key", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"num_rerank",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default=str(20),
|
||||
),
|
||||
)
|
||||
@@ -1,349 +0,0 @@
|
||||
"""hierarchy_nodes_v1
|
||||
|
||||
Revision ID: 81c22b1e2e78
|
||||
Revises: 72aa7de2e5cf
|
||||
Create Date: 2026-01-13 18:10:01.021451
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "81c22b1e2e78"
|
||||
down_revision = "72aa7de2e5cf"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# Human-readable display names for each source
|
||||
SOURCE_DISPLAY_NAMES: dict[str, str] = {
|
||||
"ingestion_api": "Ingestion API",
|
||||
"slack": "Slack",
|
||||
"web": "Web",
|
||||
"google_drive": "Google Drive",
|
||||
"gmail": "Gmail",
|
||||
"requesttracker": "Request Tracker",
|
||||
"github": "GitHub",
|
||||
"gitbook": "GitBook",
|
||||
"gitlab": "GitLab",
|
||||
"guru": "Guru",
|
||||
"bookstack": "BookStack",
|
||||
"outline": "Outline",
|
||||
"confluence": "Confluence",
|
||||
"jira": "Jira",
|
||||
"slab": "Slab",
|
||||
"productboard": "Productboard",
|
||||
"file": "File",
|
||||
"coda": "Coda",
|
||||
"notion": "Notion",
|
||||
"zulip": "Zulip",
|
||||
"linear": "Linear",
|
||||
"hubspot": "HubSpot",
|
||||
"document360": "Document360",
|
||||
"gong": "Gong",
|
||||
"google_sites": "Google Sites",
|
||||
"zendesk": "Zendesk",
|
||||
"loopio": "Loopio",
|
||||
"dropbox": "Dropbox",
|
||||
"sharepoint": "SharePoint",
|
||||
"teams": "Teams",
|
||||
"salesforce": "Salesforce",
|
||||
"discourse": "Discourse",
|
||||
"axero": "Axero",
|
||||
"clickup": "ClickUp",
|
||||
"mediawiki": "MediaWiki",
|
||||
"wikipedia": "Wikipedia",
|
||||
"asana": "Asana",
|
||||
"s3": "S3",
|
||||
"r2": "R2",
|
||||
"google_cloud_storage": "Google Cloud Storage",
|
||||
"oci_storage": "OCI Storage",
|
||||
"xenforo": "XenForo",
|
||||
"not_applicable": "Not Applicable",
|
||||
"discord": "Discord",
|
||||
"freshdesk": "Freshdesk",
|
||||
"fireflies": "Fireflies",
|
||||
"egnyte": "Egnyte",
|
||||
"airtable": "Airtable",
|
||||
"highspot": "Highspot",
|
||||
"drupal_wiki": "Drupal Wiki",
|
||||
"imap": "IMAP",
|
||||
"bitbucket": "Bitbucket",
|
||||
"testrail": "TestRail",
|
||||
"mock_connector": "Mock Connector",
|
||||
"user_file": "User File",
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1. Create hierarchy_node table
|
||||
op.create_table(
|
||||
"hierarchy_node",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("raw_node_id", sa.String(), nullable=False),
|
||||
sa.Column("display_name", sa.String(), nullable=False),
|
||||
sa.Column("link", sa.String(), nullable=True),
|
||||
sa.Column("source", sa.String(), nullable=False),
|
||||
sa.Column("node_type", sa.String(), nullable=False),
|
||||
sa.Column("document_id", sa.String(), nullable=True),
|
||||
sa.Column("parent_id", sa.Integer(), nullable=True),
|
||||
# Permission fields - same pattern as Document table
|
||||
sa.Column(
|
||||
"external_user_emails",
|
||||
postgresql.ARRAY(sa.String()),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"external_user_group_ids",
|
||||
postgresql.ARRAY(sa.String()),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="false"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
# When document is deleted, just unlink (node can exist without document)
|
||||
sa.ForeignKeyConstraint(["document_id"], ["document.id"], ondelete="SET NULL"),
|
||||
# When parent node is deleted, orphan children (cleanup via pruning)
|
||||
sa.ForeignKeyConstraint(
|
||||
["parent_id"], ["hierarchy_node.id"], ondelete="SET NULL"
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
"raw_node_id", "source", name="uq_hierarchy_node_raw_id_source"
|
||||
),
|
||||
)
|
||||
op.create_index("ix_hierarchy_node_parent_id", "hierarchy_node", ["parent_id"])
|
||||
op.create_index(
|
||||
"ix_hierarchy_node_source_type", "hierarchy_node", ["source", "node_type"]
|
||||
)
|
||||
|
||||
# Add partial unique index to ensure only one SOURCE-type node per source
|
||||
# This prevents duplicate source root nodes from being created
|
||||
# NOTE: node_type stores enum NAME ('SOURCE'), not value ('source')
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
CREATE UNIQUE INDEX uq_hierarchy_node_one_source_per_type
|
||||
ON hierarchy_node (source)
|
||||
WHERE node_type = 'SOURCE'
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Create hierarchy_fetch_attempt table
|
||||
op.create_table(
|
||||
"hierarchy_fetch_attempt",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
|
||||
sa.Column("status", sa.String(), nullable=False),
|
||||
sa.Column("nodes_fetched", sa.Integer(), nullable=True, server_default="0"),
|
||||
sa.Column("nodes_updated", sa.Integer(), nullable=True, server_default="0"),
|
||||
sa.Column("error_msg", sa.Text(), nullable=True),
|
||||
sa.Column("full_exception_trace", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("time_started", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["connector_credential_pair_id"],
|
||||
["connector_credential_pair.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_hierarchy_fetch_attempt_status", "hierarchy_fetch_attempt", ["status"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_hierarchy_fetch_attempt_time_created",
|
||||
"hierarchy_fetch_attempt",
|
||||
["time_created"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_hierarchy_fetch_attempt_cc_pair",
|
||||
"hierarchy_fetch_attempt",
|
||||
["connector_credential_pair_id"],
|
||||
)
|
||||
|
||||
# 3. Insert SOURCE-type hierarchy nodes for each DocumentSource
|
||||
# We insert these so every existing document can have a parent hierarchy node
|
||||
# NOTE: SQLAlchemy's Enum with native_enum=False stores the enum NAME (e.g., 'GOOGLE_DRIVE'),
|
||||
# not the VALUE (e.g., 'google_drive'). We must use .name for source and node_type columns.
|
||||
# SOURCE nodes are always public since they're just categorical roots.
|
||||
for source in DocumentSource:
|
||||
source_name = (
|
||||
source.name
|
||||
) # e.g., 'GOOGLE_DRIVE' - what SQLAlchemy stores/expects
|
||||
source_value = source.value # e.g., 'google_drive' - the raw_node_id
|
||||
display_name = SOURCE_DISPLAY_NAMES.get(
|
||||
source_value, source_value.replace("_", " ").title()
|
||||
)
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO hierarchy_node (raw_node_id, display_name, source, node_type, parent_id, is_public)
|
||||
VALUES (:raw_node_id, :display_name, :source, 'SOURCE', NULL, true)
|
||||
ON CONFLICT (raw_node_id, source) DO NOTHING
|
||||
"""
|
||||
).bindparams(
|
||||
raw_node_id=source_value, # Use .value for raw_node_id (human-readable identifier)
|
||||
display_name=display_name,
|
||||
source=source_name, # Use .name for source column (SQLAlchemy enum storage)
|
||||
)
|
||||
)
|
||||
|
||||
# 4. Add parent_hierarchy_node_id column to document table
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("parent_hierarchy_node_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
# When hierarchy node is deleted, just unlink the document (SET NULL)
|
||||
op.create_foreign_key(
|
||||
"fk_document_parent_hierarchy_node",
|
||||
"document",
|
||||
"hierarchy_node",
|
||||
["parent_hierarchy_node_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
op.create_index(
|
||||
"ix_document_parent_hierarchy_node_id",
|
||||
"document",
|
||||
["parent_hierarchy_node_id"],
|
||||
)
|
||||
|
||||
# 5. Set all existing documents' parent_hierarchy_node_id to their source's SOURCE node
|
||||
# For documents with multiple connectors, we pick one source deterministically (MIN connector_id)
|
||||
# NOTE: Both connector.source and hierarchy_node.source store enum NAMEs (e.g., 'GOOGLE_DRIVE')
|
||||
# because SQLAlchemy Enum(native_enum=False) uses the enum name for storage.
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE document d
|
||||
SET parent_hierarchy_node_id = hn.id
|
||||
FROM (
|
||||
-- Get the source for each document (pick MIN connector_id for determinism)
|
||||
SELECT DISTINCT ON (dbcc.id)
|
||||
dbcc.id as doc_id,
|
||||
c.source as source
|
||||
FROM document_by_connector_credential_pair dbcc
|
||||
JOIN connector c ON dbcc.connector_id = c.id
|
||||
ORDER BY dbcc.id, dbcc.connector_id
|
||||
) doc_source
|
||||
JOIN hierarchy_node hn ON hn.source = doc_source.source AND hn.node_type = 'SOURCE'
|
||||
WHERE d.id = doc_source.doc_id
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Create the persona__hierarchy_node association table
|
||||
op.create_table(
|
||||
"persona__hierarchy_node",
|
||||
sa.Column("persona_id", sa.Integer(), nullable=False),
|
||||
sa.Column("hierarchy_node_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["hierarchy_node_id"],
|
||||
["hierarchy_node.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("persona_id", "hierarchy_node_id"),
|
||||
)
|
||||
|
||||
# Add index for efficient lookups
|
||||
op.create_index(
|
||||
"ix_persona__hierarchy_node_hierarchy_node_id",
|
||||
"persona__hierarchy_node",
|
||||
["hierarchy_node_id"],
|
||||
)
|
||||
|
||||
# Create the persona__document association table for attaching individual
|
||||
# documents directly to assistants
|
||||
op.create_table(
|
||||
"persona__document",
|
||||
sa.Column("persona_id", sa.Integer(), nullable=False),
|
||||
sa.Column("document_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["document_id"],
|
||||
["document.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("persona_id", "document_id"),
|
||||
)
|
||||
|
||||
# Add index for efficient lookups by document_id
|
||||
op.create_index(
|
||||
"ix_persona__document_document_id",
|
||||
"persona__document",
|
||||
["document_id"],
|
||||
)
|
||||
|
||||
# 6. Add last_time_hierarchy_fetch column to connector_credential_pair table
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"last_time_hierarchy_fetch", sa.DateTime(timezone=True), nullable=True
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove last_time_hierarchy_fetch from connector_credential_pair
|
||||
op.drop_column("connector_credential_pair", "last_time_hierarchy_fetch")
|
||||
|
||||
# Drop persona__document table
|
||||
op.drop_index("ix_persona__document_document_id", table_name="persona__document")
|
||||
op.drop_table("persona__document")
|
||||
|
||||
# Drop persona__hierarchy_node table
|
||||
op.drop_index(
|
||||
"ix_persona__hierarchy_node_hierarchy_node_id",
|
||||
table_name="persona__hierarchy_node",
|
||||
)
|
||||
op.drop_table("persona__hierarchy_node")
|
||||
|
||||
# Remove parent_hierarchy_node_id from document
|
||||
op.drop_index("ix_document_parent_hierarchy_node_id", table_name="document")
|
||||
op.drop_constraint(
|
||||
"fk_document_parent_hierarchy_node", "document", type_="foreignkey"
|
||||
)
|
||||
op.drop_column("document", "parent_hierarchy_node_id")
|
||||
|
||||
# Drop hierarchy_fetch_attempt table
|
||||
op.drop_index(
|
||||
"ix_hierarchy_fetch_attempt_cc_pair", table_name="hierarchy_fetch_attempt"
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_hierarchy_fetch_attempt_time_created", table_name="hierarchy_fetch_attempt"
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_hierarchy_fetch_attempt_status", table_name="hierarchy_fetch_attempt"
|
||||
)
|
||||
op.drop_table("hierarchy_fetch_attempt")
|
||||
|
||||
# Drop hierarchy_node table
|
||||
op.drop_index("uq_hierarchy_node_one_source_per_type", table_name="hierarchy_node")
|
||||
op.drop_index("ix_hierarchy_node_source_type", table_name="hierarchy_node")
|
||||
op.drop_index("ix_hierarchy_node_parent_id", table_name="hierarchy_node")
|
||||
op.drop_table("hierarchy_node")
|
||||
@@ -1,32 +0,0 @@
|
||||
"""add demo_data_enabled to build_session
|
||||
|
||||
Revision ID: 849b21c732f8
|
||||
Revises: 81c22b1e2e78
|
||||
Create Date: 2026-01-28 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "849b21c732f8"
|
||||
down_revision = "81c22b1e2e78"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"build_session",
|
||||
sa.Column(
|
||||
"demo_data_enabled",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("true"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("build_session", "demo_data_enabled")
|
||||
@@ -1,7 +1,7 @@
|
||||
"""add processing_duration_seconds to chat_message
|
||||
|
||||
Revision ID: 9d1543a37106
|
||||
Revises: cbc03e08d0f3
|
||||
Revises: 72aa7de2e5cf
|
||||
Create Date: 2026-01-21 11:42:18.546188
|
||||
|
||||
"""
|
||||
@@ -11,7 +11,7 @@ import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9d1543a37106"
|
||||
down_revision = "cbc03e08d0f3"
|
||||
down_revision = "72aa7de2e5cf"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
"""Persona new default model configuration id column
|
||||
|
||||
Revision ID: be87a654d5af
|
||||
Revises: e7f8a9b0c1d2
|
||||
Create Date: 2026-01-30 11:14:17.306275
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "be87a654d5af"
|
||||
down_revision = "e7f8a9b0c1d2"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("default_model_configuration_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_persona_default_model_configuration_id",
|
||||
"persona",
|
||||
"model_configuration",
|
||||
["default_model_configuration_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"fk_persona_default_model_configuration_id", "persona", type_="foreignkey"
|
||||
)
|
||||
|
||||
op.drop_column("persona", "default_model_configuration_id")
|
||||
@@ -1,128 +0,0 @@
|
||||
"""add_opensearch_migration_tables
|
||||
|
||||
Revision ID: cbc03e08d0f3
|
||||
Revises: be87a654d5af
|
||||
Create Date: 2026-01-31 17:00:45.176604
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "cbc03e08d0f3"
|
||||
down_revision = "be87a654d5af"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1. Create opensearch_document_migration_record table.
|
||||
op.create_table(
|
||||
"opensearch_document_migration_record",
|
||||
sa.Column("document_id", sa.String(), nullable=False),
|
||||
sa.Column("status", sa.String(), nullable=False, server_default="pending"),
|
||||
sa.Column("error_message", sa.Text(), nullable=True),
|
||||
sa.Column("attempts_count", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("last_attempt_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("document_id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["document_id"],
|
||||
["document.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
# 2. Create indices.
|
||||
op.create_index(
|
||||
"ix_opensearch_document_migration_record_status",
|
||||
"opensearch_document_migration_record",
|
||||
["status"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_opensearch_document_migration_record_attempts_count",
|
||||
"opensearch_document_migration_record",
|
||||
["attempts_count"],
|
||||
)
|
||||
op.create_index(
|
||||
"ix_opensearch_document_migration_record_created_at",
|
||||
"opensearch_document_migration_record",
|
||||
["created_at"],
|
||||
)
|
||||
|
||||
# 3. Create opensearch_tenant_migration_record table (singleton).
|
||||
op.create_table(
|
||||
"opensearch_tenant_migration_record",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"document_migration_record_table_population_status",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
server_default="pending",
|
||||
),
|
||||
sa.Column(
|
||||
"num_times_observed_no_additional_docs_to_populate_migration_table",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default="0",
|
||||
),
|
||||
sa.Column(
|
||||
"overall_document_migration_status",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
server_default="pending",
|
||||
),
|
||||
sa.Column(
|
||||
"num_times_observed_no_additional_docs_to_migrate",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default="0",
|
||||
),
|
||||
sa.Column(
|
||||
"last_updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
# 4. Create unique index on constant to enforce singleton pattern.
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
CREATE UNIQUE INDEX idx_opensearch_tenant_migration_singleton
|
||||
ON opensearch_tenant_migration_record ((true))
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop opensearch_tenant_migration_record.
|
||||
op.drop_index(
|
||||
"idx_opensearch_tenant_migration_singleton",
|
||||
table_name="opensearch_tenant_migration_record",
|
||||
)
|
||||
op.drop_table("opensearch_tenant_migration_record")
|
||||
|
||||
# Drop opensearch_document_migration_record.
|
||||
op.drop_index(
|
||||
"ix_opensearch_document_migration_record_created_at",
|
||||
table_name="opensearch_document_migration_record",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_opensearch_document_migration_record_attempts_count",
|
||||
table_name="opensearch_document_migration_record",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_opensearch_document_migration_record_status",
|
||||
table_name="opensearch_document_migration_record",
|
||||
)
|
||||
op.drop_table("opensearch_document_migration_record")
|
||||
@@ -1,125 +0,0 @@
|
||||
"""create_anonymous_user
|
||||
|
||||
This migration creates a permanent anonymous user in the database.
|
||||
When anonymous access is enabled, unauthenticated requests will use this user
|
||||
instead of returning user_id=NULL.
|
||||
|
||||
Revision ID: e7f8a9b0c1d2
|
||||
Revises: f7ca3e2f45d9
|
||||
Create Date: 2026-01-15 14:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e7f8a9b0c1d2"
|
||||
down_revision = "f7ca3e2f45d9"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# Must match constants in onyx/configs/constants.py file
|
||||
ANONYMOUS_USER_UUID = "00000000-0000-0000-0000-000000000002"
|
||||
ANONYMOUS_USER_EMAIL = "anonymous@onyx.app"
|
||||
|
||||
# Tables with user_id foreign key that may need migration
|
||||
TABLES_WITH_USER_ID = [
|
||||
"chat_session",
|
||||
"credential",
|
||||
"document_set",
|
||||
"persona",
|
||||
"tool",
|
||||
"notification",
|
||||
"inputprompt",
|
||||
]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""
|
||||
Create the anonymous user for anonymous access feature.
|
||||
Also migrates any remaining user_id=NULL records to the anonymous user.
|
||||
"""
|
||||
connection = op.get_bind()
|
||||
|
||||
# Create the anonymous user (using ON CONFLICT to be idempotent)
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO "user" (id, email, hashed_password, is_active, is_superuser, is_verified, role)
|
||||
VALUES (:id, :email, :hashed_password, :is_active, :is_superuser, :is_verified, :role)
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": ANONYMOUS_USER_UUID,
|
||||
"email": ANONYMOUS_USER_EMAIL,
|
||||
"hashed_password": "", # Empty password - user cannot log in directly
|
||||
"is_active": True, # Active so it can be used for anonymous access
|
||||
"is_superuser": False,
|
||||
"is_verified": True, # Verified since no email verification needed
|
||||
"role": "LIMITED", # Anonymous users have limited role to restrict access
|
||||
},
|
||||
)
|
||||
|
||||
# Migrate any remaining user_id=NULL records to anonymous user
|
||||
for table in TABLES_WITH_USER_ID:
|
||||
try:
|
||||
# Exclude public credential (id=0) which must remain user_id=NULL
|
||||
# Exclude builtin tools (in_code_tool_id IS NOT NULL) which must remain user_id=NULL
|
||||
# Exclude builtin personas (builtin_persona=True) which must remain user_id=NULL
|
||||
# Exclude system input prompts (is_public=True with user_id=NULL) which must remain user_id=NULL
|
||||
if table == "credential":
|
||||
condition = "user_id IS NULL AND id != 0"
|
||||
elif table == "tool":
|
||||
condition = "user_id IS NULL AND in_code_tool_id IS NULL"
|
||||
elif table == "persona":
|
||||
condition = "user_id IS NULL AND builtin_persona = false"
|
||||
elif table == "inputprompt":
|
||||
condition = "user_id IS NULL AND is_public = false"
|
||||
else:
|
||||
condition = "user_id IS NULL"
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
f"""
|
||||
UPDATE "{table}"
|
||||
SET user_id = :user_id
|
||||
WHERE {condition}
|
||||
"""
|
||||
),
|
||||
{"user_id": ANONYMOUS_USER_UUID},
|
||||
)
|
||||
if result.rowcount > 0:
|
||||
print(f"Updated {result.rowcount} rows in {table} to anonymous user")
|
||||
except Exception as e:
|
||||
print(f"Skipping {table}: {e}")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""
|
||||
Set anonymous user's records back to NULL and delete the anonymous user.
|
||||
"""
|
||||
connection = op.get_bind()
|
||||
|
||||
# Set records back to NULL
|
||||
for table in TABLES_WITH_USER_ID:
|
||||
try:
|
||||
connection.execute(
|
||||
sa.text(
|
||||
f"""
|
||||
UPDATE "{table}"
|
||||
SET user_id = NULL
|
||||
WHERE user_id = :user_id
|
||||
"""
|
||||
),
|
||||
{"user_id": ANONYMOUS_USER_UUID},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Delete the anonymous user
|
||||
connection.execute(
|
||||
sa.text('DELETE FROM "user" WHERE id = :user_id'),
|
||||
{"user_id": ANONYMOUS_USER_UUID},
|
||||
)
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Add flow mapping table
|
||||
|
||||
Revision ID: f220515df7b4
|
||||
Revises: cbc03e08d0f3
|
||||
Create Date: 2026-01-30 12:21:24.955922
|
||||
|
||||
"""
|
||||
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f220515df7b4"
|
||||
down_revision = "9d1543a37106"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"llm_model_flow",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"llm_model_flow_type",
|
||||
sa.Enum(LLMModelFlowType, name="llmmodelflowtype", native_enum=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_default", sa.Boolean(), nullable=False, server_default=sa.text("false")
|
||||
),
|
||||
sa.Column("model_configuration_id", sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["model_configuration_id"], ["model_configuration.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
"llm_model_flow_type",
|
||||
"model_configuration_id",
|
||||
name="uq_model_config_per_llm_model_flow_type",
|
||||
),
|
||||
)
|
||||
|
||||
# Partial unique index so that there is at most one default for each flow type
|
||||
op.create_index(
|
||||
"ix_one_default_per_llm_model_flow",
|
||||
"llm_model_flow",
|
||||
["llm_model_flow_type"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("is_default IS TRUE"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the llm_model_flow table (index is dropped automatically with table)
|
||||
op.drop_table("llm_model_flow")
|
||||
@@ -1,281 +0,0 @@
|
||||
"""migrate_no_auth_data_to_placeholder
|
||||
|
||||
This migration handles the transition from AUTH_TYPE=disabled to requiring
|
||||
authentication. It creates a placeholder user and assigns all data that was
|
||||
created without a user (user_id=NULL) to this placeholder.
|
||||
|
||||
A database trigger is installed that automatically transfers all data from
|
||||
the placeholder user to the first real user who registers, then drops itself.
|
||||
|
||||
Revision ID: f7ca3e2f45d9
|
||||
Revises: 78ebc66946a0
|
||||
Create Date: 2026-01-15 12:49:53.802741
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f7ca3e2f45d9"
|
||||
down_revision = "78ebc66946a0"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# Must match constants in onyx/configs/constants.py file
|
||||
NO_AUTH_PLACEHOLDER_USER_UUID = "00000000-0000-0000-0000-000000000001"
|
||||
NO_AUTH_PLACEHOLDER_USER_EMAIL = "no-auth-placeholder@onyx.app"
|
||||
|
||||
# Trigger and function names
|
||||
TRIGGER_NAME = "trg_migrate_no_auth_data"
|
||||
FUNCTION_NAME = "migrate_no_auth_data_to_user"
|
||||
|
||||
# Trigger function that migrates data from placeholder to first real user
|
||||
MIGRATE_NO_AUTH_TRIGGER_FUNCTION = f"""
|
||||
CREATE OR REPLACE FUNCTION {FUNCTION_NAME}()
|
||||
RETURNS TRIGGER AS $$
|
||||
DECLARE
|
||||
placeholder_uuid UUID := '00000000-0000-0000-0000-000000000001'::uuid;
|
||||
anonymous_uuid UUID := '00000000-0000-0000-0000-000000000002'::uuid;
|
||||
placeholder_row RECORD;
|
||||
schema_name TEXT;
|
||||
BEGIN
|
||||
-- Skip if this is the placeholder user being inserted
|
||||
IF NEW.id = placeholder_uuid THEN
|
||||
RETURN NULL;
|
||||
END IF;
|
||||
|
||||
-- Skip if this is the anonymous user being inserted (not a real user)
|
||||
IF NEW.id = anonymous_uuid THEN
|
||||
RETURN NULL;
|
||||
END IF;
|
||||
|
||||
-- Skip if the new user is not active
|
||||
IF NEW.is_active = FALSE THEN
|
||||
RETURN NULL;
|
||||
END IF;
|
||||
|
||||
-- Get current schema for self-cleanup
|
||||
schema_name := current_schema();
|
||||
|
||||
-- Try to lock the placeholder user row with FOR UPDATE SKIP LOCKED
|
||||
-- This ensures only one concurrent transaction can proceed with migration
|
||||
-- SKIP LOCKED means if another transaction has the lock, we skip (don't wait)
|
||||
SELECT id INTO placeholder_row
|
||||
FROM "user"
|
||||
WHERE id = placeholder_uuid
|
||||
FOR UPDATE SKIP LOCKED;
|
||||
|
||||
IF NOT FOUND THEN
|
||||
-- Either placeholder doesn't exist or another transaction has it locked
|
||||
-- Either way, drop the trigger and return without making admin
|
||||
EXECUTE format('DROP TRIGGER IF EXISTS {TRIGGER_NAME} ON %I."user"', schema_name);
|
||||
EXECUTE format('DROP FUNCTION IF EXISTS %I.{FUNCTION_NAME}()', schema_name);
|
||||
RETURN NULL;
|
||||
END IF;
|
||||
|
||||
-- We have exclusive lock on placeholder - proceed with migration
|
||||
-- The INSERT has already completed (AFTER INSERT), so NEW.id exists in the table
|
||||
|
||||
-- Migrate chat_session
|
||||
UPDATE "chat_session" SET user_id = NEW.id WHERE user_id = placeholder_uuid;
|
||||
|
||||
-- Migrate credential (exclude public credential id=0)
|
||||
UPDATE "credential" SET user_id = NEW.id WHERE user_id = placeholder_uuid AND id != 0;
|
||||
|
||||
-- Migrate document_set
|
||||
UPDATE "document_set" SET user_id = NEW.id WHERE user_id = placeholder_uuid;
|
||||
|
||||
-- Migrate persona (exclude builtin personas)
|
||||
UPDATE "persona" SET user_id = NEW.id WHERE user_id = placeholder_uuid AND builtin_persona = FALSE;
|
||||
|
||||
-- Migrate tool (exclude builtin tools)
|
||||
UPDATE "tool" SET user_id = NEW.id WHERE user_id = placeholder_uuid AND in_code_tool_id IS NULL;
|
||||
|
||||
-- Migrate notification
|
||||
UPDATE "notification" SET user_id = NEW.id WHERE user_id = placeholder_uuid;
|
||||
|
||||
-- Migrate inputprompt (exclude system/public prompts)
|
||||
UPDATE "inputprompt" SET user_id = NEW.id WHERE user_id = placeholder_uuid AND is_public = FALSE;
|
||||
|
||||
-- Make the new user an admin (they had admin access in no-auth mode)
|
||||
-- In AFTER INSERT trigger, we must UPDATE the row since it already exists
|
||||
UPDATE "user" SET role = 'ADMIN' WHERE id = NEW.id;
|
||||
|
||||
-- Delete the placeholder user (we hold the lock so this is safe)
|
||||
DELETE FROM "user" WHERE id = placeholder_uuid;
|
||||
|
||||
-- Drop the trigger and function (self-cleanup)
|
||||
EXECUTE format('DROP TRIGGER IF EXISTS {TRIGGER_NAME} ON %I."user"', schema_name);
|
||||
EXECUTE format('DROP FUNCTION IF EXISTS %I.{FUNCTION_NAME}()', schema_name);
|
||||
|
||||
RETURN NULL;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
"""
|
||||
|
||||
MIGRATE_NO_AUTH_TRIGGER = f"""
|
||||
CREATE TRIGGER {TRIGGER_NAME}
|
||||
AFTER INSERT ON "user"
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION {FUNCTION_NAME}();
|
||||
"""
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""
|
||||
Create a placeholder user and assign all NULL user_id records to it.
|
||||
Install a trigger that migrates data to the first real user and self-destructs.
|
||||
Only runs if AUTH_TYPE is currently disabled/none.
|
||||
|
||||
Skipped in multi-tenant mode - each tenant starts fresh with no legacy data.
|
||||
"""
|
||||
# Skip in multi-tenant mode - this migration handles single-tenant
|
||||
# AUTH_TYPE=disabled -> auth transitions only
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
# Only run if AUTH_TYPE is currently disabled/none
|
||||
# If they've already switched to auth-enabled, NULL data is stale anyway
|
||||
auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
if auth_type not in ("disabled", "none", ""):
|
||||
print(f"AUTH_TYPE is '{auth_type}', not disabled. Skipping migration.")
|
||||
return
|
||||
|
||||
connection = op.get_bind()
|
||||
|
||||
# Check if there are any NULL user_id records that need migration
|
||||
tables_to_check = [
|
||||
"chat_session",
|
||||
"credential",
|
||||
"document_set",
|
||||
"persona",
|
||||
"tool",
|
||||
"notification",
|
||||
"inputprompt",
|
||||
]
|
||||
|
||||
has_null_records = False
|
||||
for table in tables_to_check:
|
||||
try:
|
||||
result = connection.execute(
|
||||
sa.text(f'SELECT 1 FROM "{table}" WHERE user_id IS NULL LIMIT 1')
|
||||
)
|
||||
if result.fetchone():
|
||||
has_null_records = True
|
||||
break
|
||||
except Exception:
|
||||
# Table might not exist
|
||||
pass
|
||||
|
||||
if not has_null_records:
|
||||
return
|
||||
|
||||
# Create the placeholder user
|
||||
connection.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO "user" (id, email, hashed_password, is_active, is_superuser, is_verified, role)
|
||||
VALUES (:id, :email, :hashed_password, :is_active, :is_superuser, :is_verified, :role)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"id": NO_AUTH_PLACEHOLDER_USER_UUID,
|
||||
"email": NO_AUTH_PLACEHOLDER_USER_EMAIL,
|
||||
"hashed_password": "", # Empty password - user cannot log in
|
||||
"is_active": False, # Inactive - user cannot log in
|
||||
"is_superuser": False,
|
||||
"is_verified": False,
|
||||
"role": "BASIC",
|
||||
},
|
||||
)
|
||||
|
||||
# Assign NULL user_id records to the placeholder user
|
||||
for table in tables_to_check:
|
||||
try:
|
||||
# Base condition for all tables
|
||||
condition = "user_id IS NULL"
|
||||
# Exclude public credential (id=0) which must remain user_id=NULL
|
||||
if table == "credential":
|
||||
condition += " AND id != 0"
|
||||
# Exclude builtin tools (in_code_tool_id IS NOT NULL) which must remain user_id=NULL
|
||||
elif table == "tool":
|
||||
condition += " AND in_code_tool_id IS NULL"
|
||||
# Exclude builtin personas which must remain user_id=NULL
|
||||
elif table == "persona":
|
||||
condition += " AND builtin_persona = FALSE"
|
||||
# Exclude system/public input prompts which must remain user_id=NULL
|
||||
elif table == "inputprompt":
|
||||
condition += " AND is_public = FALSE"
|
||||
result = connection.execute(
|
||||
sa.text(
|
||||
f"""
|
||||
UPDATE "{table}"
|
||||
SET user_id = :user_id
|
||||
WHERE {condition}
|
||||
"""
|
||||
),
|
||||
{"user_id": NO_AUTH_PLACEHOLDER_USER_UUID},
|
||||
)
|
||||
if result.rowcount > 0:
|
||||
print(f"Updated {result.rowcount} rows in {table}")
|
||||
except Exception as e:
|
||||
print(f"Skipping {table}: {e}")
|
||||
|
||||
# Install the trigger function and trigger for automatic migration on first user registration
|
||||
connection.execute(sa.text(MIGRATE_NO_AUTH_TRIGGER_FUNCTION))
|
||||
connection.execute(sa.text(MIGRATE_NO_AUTH_TRIGGER))
|
||||
print("Installed trigger for automatic data migration on first user registration")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""
|
||||
Drop trigger and function, set placeholder user's records back to NULL,
|
||||
and delete the placeholder user.
|
||||
"""
|
||||
# Skip in multi-tenant mode for consistency with upgrade
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
connection = op.get_bind()
|
||||
|
||||
# Drop trigger and function if they exist (they may have already self-destructed)
|
||||
connection.execute(sa.text(f'DROP TRIGGER IF EXISTS {TRIGGER_NAME} ON "user"'))
|
||||
connection.execute(sa.text(f"DROP FUNCTION IF EXISTS {FUNCTION_NAME}()"))
|
||||
|
||||
tables_to_update = [
|
||||
"chat_session",
|
||||
"credential",
|
||||
"document_set",
|
||||
"persona",
|
||||
"tool",
|
||||
"notification",
|
||||
"inputprompt",
|
||||
]
|
||||
|
||||
# Set records back to NULL
|
||||
for table in tables_to_update:
|
||||
try:
|
||||
connection.execute(
|
||||
sa.text(
|
||||
f"""
|
||||
UPDATE "{table}"
|
||||
SET user_id = NULL
|
||||
WHERE user_id = :user_id
|
||||
"""
|
||||
),
|
||||
{"user_id": NO_AUTH_PLACEHOLDER_USER_UUID},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Delete the placeholder user
|
||||
connection.execute(
|
||||
sa.text('DELETE FROM "user" WHERE id = :user_id'),
|
||||
{"user_id": NO_AUTH_PLACEHOLDER_USER_UUID},
|
||||
)
|
||||
@@ -116,7 +116,7 @@ def _get_access_for_documents(
|
||||
return access_map
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User, db_session: Session) -> set[str]:
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
user should have access to a document if at least one entry in the document's ACL
|
||||
@@ -124,16 +124,13 @@ def _get_acl_for_user(user: User, db_session: Session) -> set[str]:
|
||||
|
||||
NOTE: is imported in onyx.access.access by `fetch_versioned_implementation`
|
||||
DO NOT REMOVE."""
|
||||
is_anonymous = user.is_anonymous
|
||||
db_user_groups = (
|
||||
[] if is_anonymous else fetch_user_groups_for_user(db_session, user.id)
|
||||
)
|
||||
db_user_groups = fetch_user_groups_for_user(db_session, user.id) if user else []
|
||||
prefixed_user_groups = [
|
||||
prefix_user_group(db_user_group.name) for db_user_group in db_user_groups
|
||||
]
|
||||
|
||||
db_external_groups = (
|
||||
[] if is_anonymous else fetch_external_groups_for_user(db_session, user.id)
|
||||
fetch_external_groups_for_user(db_session, user.id) if user else []
|
||||
)
|
||||
prefixed_external_groups = [
|
||||
prefix_external_group(db_external_group.external_user_group_id)
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.external_perm import fetch_external_groups_for_user
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
def _get_user_external_group_ids(db_session: Session, user: User) -> list[str]:
|
||||
if not user:
|
||||
return []
|
||||
external_groups = fetch_external_groups_for_user(db_session, user.id)
|
||||
return [external_group.external_user_group_id for external_group in external_groups]
|
||||
@@ -33,8 +33,8 @@ def get_default_admin_user_emails_() -> list[str]:
|
||||
|
||||
async def current_cloud_superuser(
|
||||
request: Request,
|
||||
user: User = Depends(current_admin_user),
|
||||
) -> User:
|
||||
user: User | None = Depends(current_admin_user),
|
||||
) -> User | None:
|
||||
api_key = request.headers.get("Authorization", "").replace("Bearer ", "")
|
||||
if api_key != SUPER_CLOUD_API_KEY:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
@@ -25,7 +25,6 @@ from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.document import upsert_document_external_perms
|
||||
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ElementExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
@@ -56,9 +55,6 @@ from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.hierarchy import (
|
||||
update_hierarchy_node_permissions as db_update_hierarchy_node_permissions,
|
||||
)
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.permission_sync_attempt import complete_doc_permission_sync_attempt
|
||||
from onyx.db.permission_sync_attempt import create_doc_permission_sync_attempt
|
||||
@@ -641,24 +637,17 @@ def connector_permission_sync_generator_task(
|
||||
),
|
||||
stop=stop_after_delay(DOCUMENT_PERMISSIONS_UPDATE_STOP_AFTER),
|
||||
)
|
||||
def element_update_permissions(
|
||||
def document_update_permissions(
|
||||
tenant_id: str,
|
||||
permissions: ElementExternalAccess,
|
||||
permissions: DocExternalAccess,
|
||||
source_type_str: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> bool:
|
||||
"""Update permissions for a document or hierarchy node."""
|
||||
start = time.monotonic()
|
||||
external_access = permissions.external_access
|
||||
|
||||
# Determine element type and identifier for logging
|
||||
if isinstance(permissions, DocExternalAccess):
|
||||
element_id = permissions.doc_id
|
||||
element_type = "doc"
|
||||
else:
|
||||
element_id = permissions.raw_node_id
|
||||
element_type = "node"
|
||||
doc_id = permissions.doc_id
|
||||
external_access = permissions.external_access
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
@@ -668,57 +657,39 @@ def element_update_permissions(
|
||||
emails=list(external_access.external_user_emails),
|
||||
continue_on_error=True,
|
||||
)
|
||||
# Then upsert the document's external permissions
|
||||
created_new_doc = upsert_document_external_perms(
|
||||
db_session=db_session,
|
||||
doc_id=doc_id,
|
||||
external_access=external_access,
|
||||
source_type=DocumentSource(source_type_str),
|
||||
)
|
||||
|
||||
if isinstance(permissions, DocExternalAccess):
|
||||
# Document permission update
|
||||
created_new_doc = upsert_document_external_perms(
|
||||
if created_new_doc:
|
||||
# If a new document was created, we associate it with the cc_pair
|
||||
upsert_document_by_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
doc_id=permissions.doc_id,
|
||||
external_access=external_access,
|
||||
source_type=DocumentSource(source_type_str),
|
||||
)
|
||||
|
||||
if created_new_doc:
|
||||
# If a new document was created, we associate it with the cc_pair
|
||||
upsert_document_by_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
document_ids=[permissions.doc_id],
|
||||
)
|
||||
else:
|
||||
# Hierarchy node permission update
|
||||
db_update_hierarchy_node_permissions(
|
||||
db_session=db_session,
|
||||
raw_node_id=permissions.raw_node_id,
|
||||
source=DocumentSource(permissions.source),
|
||||
is_public=external_access.is_public,
|
||||
external_user_emails=(
|
||||
list(external_access.external_user_emails)
|
||||
if external_access.external_user_emails
|
||||
else None
|
||||
),
|
||||
external_user_group_ids=(
|
||||
list(external_access.external_user_group_ids)
|
||||
if external_access.external_user_group_ids
|
||||
else None
|
||||
),
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
document_ids=[doc_id],
|
||||
)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"{element_type}={element_id} "
|
||||
f"connector_id={connector_id} "
|
||||
f"doc={doc_id} "
|
||||
f"action=update_permissions "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"element_update_permissions exceptioned: {element_type}={element_id}, {connector_id=} {credential_id=}"
|
||||
f"document_update_permissions exceptioned: "
|
||||
f"connector_id={connector_id} doc_id={doc_id}"
|
||||
)
|
||||
raise e
|
||||
finally:
|
||||
task_logger.info(
|
||||
f"element_update_permissions completed: {element_type}={element_id}, {connector_id=} {credential_id=}"
|
||||
f"document_update_permissions completed: connector_id={connector_id} doc={doc_id}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@@ -122,9 +122,6 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
|
||||
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
|
||||
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
|
||||
POSTHOG_DEBUG_LOGS_ENABLED = (
|
||||
os.environ.get("POSTHOG_DEBUG_LOGS_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
MARKETING_POSTHOG_API_KEY = os.environ.get("MARKETING_POSTHOG_API_KEY")
|
||||
|
||||
@@ -136,9 +133,3 @@ GATED_TENANTS_KEY = "gated_tenants"
|
||||
LICENSE_ENFORCEMENT_ENABLED = (
|
||||
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Cloud data plane URL - self-hosted instances call this to reach cloud proxy endpoints
|
||||
# Used when MULTI_TENANT=false (self-hosted mode)
|
||||
CLOUD_DATA_PLANE_URL = os.environ.get(
|
||||
"CLOUD_DATA_PLANE_URL", "https://cloud.onyx.app/api"
|
||||
)
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Constants for license enforcement.
|
||||
|
||||
This file is the single source of truth for:
|
||||
1. Paths that bypass license enforcement (always accessible)
|
||||
2. Paths that require an EE license (EE-only features)
|
||||
|
||||
Import these constants in both production code and tests to ensure consistency.
|
||||
"""
|
||||
|
||||
# Paths that are ALWAYS accessible, even when license is expired/gated.
|
||||
# These enable users to:
|
||||
# /auth - Log in/out (users can't fix billing if locked out of auth)
|
||||
# /license - Fetch, upload, or check license status
|
||||
# /health - Health checks for load balancers/orchestrators
|
||||
# /me - Basic user info needed for UI rendering
|
||||
# /settings, /enterprise-settings - View app status and branding
|
||||
# /billing - Unified billing API
|
||||
# /proxy - Self-hosted proxy endpoints (have own license-based auth)
|
||||
# /tenants/billing-* - Legacy billing endpoints (backwards compatibility)
|
||||
# /manage/users, /users - User management (needed for seat limit resolution)
|
||||
# /notifications - Needed for UI to load properly
|
||||
LICENSE_ENFORCEMENT_ALLOWED_PREFIXES: frozenset[str] = frozenset(
|
||||
{
|
||||
"/auth",
|
||||
"/license",
|
||||
"/health",
|
||||
"/me",
|
||||
"/settings",
|
||||
"/enterprise-settings",
|
||||
# Billing endpoints (unified API for both MT and self-hosted)
|
||||
"/billing",
|
||||
"/admin/billing",
|
||||
# Proxy endpoints for self-hosted billing (no tenant context)
|
||||
"/proxy",
|
||||
# Legacy tenant billing endpoints (kept for backwards compatibility)
|
||||
"/tenants/billing-information",
|
||||
"/tenants/create-customer-portal-session",
|
||||
"/tenants/create-subscription-session",
|
||||
# User management - needed to remove users when seat limit exceeded
|
||||
"/manage/users",
|
||||
"/manage/admin/users",
|
||||
"/manage/admin/valid-domains",
|
||||
"/manage/admin/deactivate-user",
|
||||
"/manage/admin/delete-user",
|
||||
"/users",
|
||||
# Notifications - needed for UI to load properly
|
||||
"/notifications",
|
||||
}
|
||||
)
|
||||
|
||||
# EE-only paths that require a valid license.
|
||||
# Users without a license (community edition) cannot access these.
|
||||
# These are blocked even when user has never subscribed (no license).
|
||||
EE_ONLY_PATH_PREFIXES: frozenset[str] = frozenset(
|
||||
{
|
||||
# User groups and access control
|
||||
"/manage/admin/user-group",
|
||||
# Analytics and reporting
|
||||
"/analytics",
|
||||
# Query history (admin chat session endpoints)
|
||||
"/admin/chat-sessions",
|
||||
"/admin/chat-session-history",
|
||||
"/admin/query-history",
|
||||
# Usage reporting/export
|
||||
"/admin/usage-report",
|
||||
# Standard answers (canned responses)
|
||||
"/manage/admin/standard-answer",
|
||||
# Token rate limits
|
||||
"/admin/token-rate-limits",
|
||||
# Evals
|
||||
"/evals",
|
||||
}
|
||||
)
|
||||
@@ -334,9 +334,11 @@ def fetch_assistant_unique_users_total(
|
||||
# Users can view assistant stats if they created the persona,
|
||||
# or if they are an admin
|
||||
def user_can_view_assistant_stats(
|
||||
db_session: Session, user: User, assistant_id: int
|
||||
db_session: Session, user: User | None, assistant_id: int
|
||||
) -> bool:
|
||||
if user.role == UserRole.ADMIN:
|
||||
# If user is None and auth is disabled, assume the user is an admin
|
||||
|
||||
if user is None or user.role == UserRole.ADMIN:
|
||||
return True
|
||||
|
||||
# Check if the user created the persona
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
"""EE version of hierarchy node access control.
|
||||
|
||||
This module provides permission-aware hierarchy node access for Enterprise Edition.
|
||||
It filters hierarchy nodes based on user email and external group membership.
|
||||
"""
|
||||
|
||||
from sqlalchemy import any_
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import HierarchyNode
|
||||
|
||||
|
||||
def _build_hierarchy_access_filter(
|
||||
user_email: str | None,
|
||||
external_group_ids: list[str],
|
||||
) -> ColumnElement[bool]:
|
||||
"""Build SQLAlchemy filter for hierarchy node access.
|
||||
|
||||
A user can access a hierarchy node if any of the following are true:
|
||||
- The node is marked as public (is_public=True)
|
||||
- The user's email is in the node's external_user_emails list
|
||||
- Any of the user's external group IDs overlap with the node's external_user_group_ids
|
||||
"""
|
||||
access_filters: list[ColumnElement[bool]] = [HierarchyNode.is_public.is_(True)]
|
||||
if user_email:
|
||||
access_filters.append(any_(HierarchyNode.external_user_emails) == user_email)
|
||||
if external_group_ids:
|
||||
access_filters.append(
|
||||
HierarchyNode.external_user_group_ids.overlap(
|
||||
postgresql.array(external_group_ids)
|
||||
)
|
||||
)
|
||||
return or_(*access_filters)
|
||||
|
||||
|
||||
def _get_accessible_hierarchy_nodes_for_source(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
user_email: str | None,
|
||||
external_group_ids: list[str],
|
||||
) -> list[HierarchyNode]:
|
||||
"""
|
||||
EE version: Returns hierarchy nodes filtered by user permissions.
|
||||
|
||||
A user can access a hierarchy node if any of the following are true:
|
||||
- The node is marked as public (is_public=True)
|
||||
- The user's email is in the node's external_user_emails list
|
||||
- Any of the user's external group IDs overlap with the node's external_user_group_ids
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
source: Document source type
|
||||
user_email: User's email for permission checking
|
||||
external_group_ids: User's external group IDs for permission checking
|
||||
|
||||
Returns:
|
||||
List of HierarchyNode objects the user has access to
|
||||
"""
|
||||
stmt = select(HierarchyNode).where(HierarchyNode.source == source)
|
||||
stmt = stmt.where(_build_hierarchy_access_filter(user_email, external_group_ids))
|
||||
stmt = stmt.order_by(HierarchyNode.display_name)
|
||||
return list(db_session.execute(stmt).scalars().all())
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Database and cache operations for the license table."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import NamedTuple
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
@@ -10,7 +9,6 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -25,13 +23,6 @@ LICENSE_METADATA_KEY = "license:metadata"
|
||||
LICENSE_CACHE_TTL_SECONDS = 86400 # 24 hours
|
||||
|
||||
|
||||
class SeatAvailabilityResult(NamedTuple):
|
||||
"""Result of a seat availability check."""
|
||||
|
||||
available: bool
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Database CRUD Operations
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -104,30 +95,23 @@ def delete_license(db_session: Session) -> bool:
|
||||
|
||||
def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
"""
|
||||
Get current seat usage directly from database.
|
||||
Get current seat usage.
|
||||
|
||||
For multi-tenant: counts users in UserTenantMapping for this tenant.
|
||||
For self-hosted: counts all active users (excludes EXT_PERM_USER role).
|
||||
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
For self-hosted: counts all active users (includes both Onyx UI users
|
||||
and Slack users who have been converted to Onyx users).
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_count
|
||||
|
||||
return get_tenant_count(tenant_id or get_current_tenant_id())
|
||||
else:
|
||||
# Self-hosted: count all active users (Onyx + converted Slack users)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
result = db_session.execute(
|
||||
select(func.count())
|
||||
.select_from(User)
|
||||
.where(
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
User.role != UserRole.EXT_PERM_USER,
|
||||
)
|
||||
select(func.count()).select_from(User).where(User.is_active) # type: ignore
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
@@ -227,10 +211,10 @@ def update_license_cache(
|
||||
stripe_subscription_id=payload.stripe_subscription_id,
|
||||
)
|
||||
|
||||
redis_client.set(
|
||||
redis_client.setex(
|
||||
LICENSE_METADATA_KEY,
|
||||
LICENSE_CACHE_TTL_SECONDS,
|
||||
metadata.model_dump_json(),
|
||||
ex=LICENSE_CACHE_TTL_SECONDS,
|
||||
)
|
||||
|
||||
logger.info(f"License cache updated: {metadata.seats} seats, status={status.value}")
|
||||
@@ -292,43 +276,3 @@ def get_license_metadata(
|
||||
|
||||
# Refresh from database
|
||||
return refresh_license_cache(db_session, tenant_id)
|
||||
|
||||
|
||||
def check_seat_availability(
|
||||
db_session: Session,
|
||||
seats_needed: int = 1,
|
||||
tenant_id: str | None = None,
|
||||
) -> SeatAvailabilityResult:
|
||||
"""
|
||||
Check if there are enough seats available to add users.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
seats_needed: Number of seats needed (default 1)
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
|
||||
Returns:
|
||||
SeatAvailabilityResult with available=True if seats are available,
|
||||
or available=False with error_message if limit would be exceeded.
|
||||
Returns available=True if no license exists (self-hosted = unlimited).
|
||||
"""
|
||||
metadata = get_license_metadata(db_session, tenant_id)
|
||||
|
||||
# No license = no enforcement (self-hosted without license)
|
||||
if metadata is None:
|
||||
return SeatAvailabilityResult(available=True)
|
||||
|
||||
# Calculate current usage directly from DB (not cache) for accuracy
|
||||
current_used = get_used_seats(tenant_id)
|
||||
total_seats = metadata.seats
|
||||
|
||||
# Use > (not >=) to allow filling to exactly 100% capacity
|
||||
would_exceed_limit = current_used + seats_needed > total_seats
|
||||
if would_exceed_limit:
|
||||
return SeatAvailabilityResult(
|
||||
available=False,
|
||||
error_message=f"Seat limit would be exceeded: {current_used} of {total_seats} seats used, "
|
||||
f"cannot add {seats_needed} more user(s).",
|
||||
)
|
||||
|
||||
return SeatAvailabilityResult(available=True)
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.constants import TokenRateLimitScope
|
||||
from onyx.db.models import TokenRateLimit
|
||||
from onyx.db.models import TokenRateLimit__UserGroup
|
||||
@@ -17,15 +18,13 @@ from onyx.db.models import UserRole
|
||||
from onyx.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
|
||||
|
||||
def _add_user_filters(stmt: Select, user: User, get_editable: bool = True) -> Select:
|
||||
if user.role == UserRole.ADMIN:
|
||||
def _add_user_filters(
|
||||
stmt: Select, user: User | None, get_editable: bool = True
|
||||
) -> Select:
|
||||
# If user is None and auth is disabled, assume the user is an admin
|
||||
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
|
||||
return stmt
|
||||
|
||||
# If anonymous user, only show global/public token_rate_limits
|
||||
if user.is_anonymous:
|
||||
where_clause = TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
|
||||
return stmt.where(where_clause)
|
||||
|
||||
stmt = stmt.distinct()
|
||||
TRLimit_UG = aliased(TokenRateLimit__UserGroup)
|
||||
User__UG = aliased(User__UserGroup)
|
||||
@@ -50,6 +49,11 @@ def _add_user_filters(stmt: Select, user: User, get_editable: bool = True) -> Se
|
||||
- if we are not editing, we show all token_rate_limits in the groups the user curates
|
||||
"""
|
||||
|
||||
# If user is None, this is an anonymous user and we should only show public token_rate_limits
|
||||
if user is None:
|
||||
where_clause = TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
|
||||
return stmt.where(where_clause)
|
||||
|
||||
where_clause = User__UG.user_id == user.id
|
||||
if user.role == UserRole.CURATOR and get_editable:
|
||||
where_clause &= User__UG.is_curator == True # noqa: E712
|
||||
@@ -110,7 +114,7 @@ def insert_user_group_token_rate_limit(
|
||||
def fetch_user_group_token_rate_limits_for_user(
|
||||
db_session: Session,
|
||||
group_id: int,
|
||||
user: User,
|
||||
user: User | None,
|
||||
enabled_only: bool = False,
|
||||
ordered: bool = True,
|
||||
get_editable: bool = True,
|
||||
|
||||
@@ -125,7 +125,7 @@ def _cleanup_document_set__user_group_relationships__no_commit(
|
||||
|
||||
def validate_object_creation_for_user(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
user: User | None,
|
||||
target_group_ids: list[int] | None = None,
|
||||
object_is_public: bool | None = None,
|
||||
object_is_perm_sync: bool | None = None,
|
||||
@@ -144,8 +144,7 @@ def validate_object_creation_for_user(
|
||||
if object_is_perm_sync and not target_group_ids:
|
||||
return
|
||||
|
||||
# Admins are allowed
|
||||
if user.role == UserRole.ADMIN:
|
||||
if not user or user.role == UserRole.ADMIN:
|
||||
return
|
||||
|
||||
# Allow curators and global curators to create public objects
|
||||
@@ -475,15 +474,14 @@ def remove_curator_status__no_commit(db_session: Session, user: User) -> None:
|
||||
def _validate_curator_relationship_update_requester(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
user_making_change: User,
|
||||
user_making_change: User | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
This function validates that the user making the change has the necessary permissions
|
||||
to update the curator relationship for the target user in the given user group.
|
||||
"""
|
||||
|
||||
# Admins can update curator relationships for any group
|
||||
if user_making_change.role == UserRole.ADMIN:
|
||||
if user_making_change is None or user_making_change.role == UserRole.ADMIN:
|
||||
return
|
||||
|
||||
# check if the user making the change is a curator in the group they are changing the curator relationship for
|
||||
@@ -552,7 +550,7 @@ def update_user_curator_relationship(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
user_making_change: User,
|
||||
user_making_change: User | None = None,
|
||||
) -> None:
|
||||
target_user = fetch_user_by_id(db_session, set_curator_request.user_id)
|
||||
if not target_user:
|
||||
@@ -601,7 +599,7 @@ def update_user_curator_relationship(
|
||||
|
||||
def add_users_to_user_group(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
user: User | None,
|
||||
user_group_id: int,
|
||||
user_ids: list[UUID],
|
||||
) -> UserGroup:
|
||||
@@ -643,7 +641,7 @@ def add_users_to_user_group(
|
||||
|
||||
def update_user_group(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
user: User | None,
|
||||
user_group_id: int,
|
||||
user_group_update: UserGroupUpdate,
|
||||
) -> UserGroup:
|
||||
|
||||
@@ -8,7 +8,7 @@ from collections.abc import Generator
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from ee.onyx.external_permissions.utils import generic_doc_sync
|
||||
from onyx.access.models import ElementExternalAccess
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
@@ -28,7 +28,7 @@ def confluence_doc_sync(
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[ElementExternalAccess, None, None]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
Fetches document permissions from Confluence and yields DocExternalAccess objects.
|
||||
Compares fetched documents against existing documents in the DB for the connector.
|
||||
|
||||
@@ -5,12 +5,8 @@ from datetime import timezone
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ElementExternalAccess
|
||||
from onyx.access.models import NodeExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -42,12 +38,12 @@ def gmail_doc_sync(
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[ElementExternalAccess, None, None]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
Adds the external permissions to the documents and hierarchy nodes in postgres.
|
||||
If the document doesn't already exist in postgres, we create
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
it in postgres so that when it gets created later, the permissions are
|
||||
already populated.
|
||||
already populated
|
||||
"""
|
||||
gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config)
|
||||
gmail_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
@@ -64,15 +60,6 @@ def gmail_doc_sync(
|
||||
|
||||
callback.progress("gmail_doc_sync", 1)
|
||||
|
||||
if isinstance(slim_doc, HierarchyNode):
|
||||
# Yield hierarchy node permissions to be processed in outer layer
|
||||
if slim_doc.external_access:
|
||||
yield NodeExternalAccess(
|
||||
external_access=slim_doc.external_access,
|
||||
raw_node_id=slim_doc.raw_node_id,
|
||||
source=DocumentSource.GMAIL.value,
|
||||
)
|
||||
continue
|
||||
if slim_doc.external_access is None:
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
continue
|
||||
|
||||
@@ -10,15 +10,11 @@ from ee.onyx.external_permissions.google_drive.permission_retrieval import (
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ElementExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.models import NodeExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.connectors.google_utils.resources import GoogleDriveService
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -171,101 +167,17 @@ def get_external_access_for_raw_gdrive_file(
|
||||
)
|
||||
|
||||
|
||||
def get_external_access_for_folder(
|
||||
folder: GoogleDriveFileType,
|
||||
google_domain: str,
|
||||
drive_service: GoogleDriveService,
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
Extract ExternalAccess from a folder's permissions.
|
||||
|
||||
This fetches permissions using the Drive API (via permissionIds) and extracts
|
||||
user emails, group emails, and public access status.
|
||||
|
||||
Args:
|
||||
folder: The folder metadata from Google Drive API (must include permissionIds field)
|
||||
google_domain: The company's Google Workspace domain (e.g., "company.com")
|
||||
drive_service: Google Drive service for fetching permission details
|
||||
|
||||
Returns:
|
||||
ExternalAccess with extracted permission info
|
||||
"""
|
||||
folder_id = folder.get("id")
|
||||
if not folder_id:
|
||||
logger.warning("Folder missing ID, returning empty permissions")
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
# Get permission IDs from folder metadata
|
||||
permission_ids = folder.get("permissionIds") or []
|
||||
if not permission_ids:
|
||||
logger.debug(f"No permissionIds found for folder {folder_id}")
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
# Fetch full permission objects using the permission IDs
|
||||
permissions_list = get_permissions_by_ids(
|
||||
drive_service=drive_service,
|
||||
doc_id=folder_id,
|
||||
permission_ids=permission_ids,
|
||||
)
|
||||
|
||||
user_emails: set[str] = set()
|
||||
group_emails: set[str] = set()
|
||||
is_public = False
|
||||
|
||||
for permission in permissions_list:
|
||||
if permission.type == PermissionType.USER:
|
||||
if permission.email_address:
|
||||
user_emails.add(permission.email_address)
|
||||
else:
|
||||
logger.warning(f"User permission without email for folder {folder_id}")
|
||||
elif permission.type == PermissionType.GROUP:
|
||||
# Groups are represented as email addresses in Google Drive
|
||||
if permission.email_address:
|
||||
group_emails.add(permission.email_address)
|
||||
else:
|
||||
logger.warning(f"Group permission without email for folder {folder_id}")
|
||||
elif permission.type == PermissionType.DOMAIN:
|
||||
# Domain permission - check if it matches company domain
|
||||
if permission.domain == google_domain:
|
||||
# Only public if discoverable (allowFileDiscovery is not False)
|
||||
# If allowFileDiscovery is False, it's "link only" access
|
||||
is_public = permission.allow_file_discovery is not False
|
||||
else:
|
||||
logger.debug(
|
||||
f"Domain permission for {permission.domain} does not match "
|
||||
f"company domain {google_domain} for folder {folder_id}"
|
||||
)
|
||||
elif permission.type == PermissionType.ANYONE:
|
||||
# Only public if discoverable (allowFileDiscovery is not False)
|
||||
# If allowFileDiscovery is False, it's "link only" access
|
||||
is_public = permission.allow_file_discovery is not False
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_emails,
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
|
||||
def gdrive_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[ElementExternalAccess, None, None]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
Adds the external permissions to the documents and hierarchy nodes in postgres.
|
||||
If the document doesn't already exist in postgres, we create
|
||||
Adds the external permissions to the documents in postgres
|
||||
if the document doesn't already exists in postgres, we create
|
||||
it in postgres so that when it gets created later, the permissions are
|
||||
already populated.
|
||||
already populated
|
||||
"""
|
||||
google_drive_connector = GoogleDriveConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
@@ -283,15 +195,7 @@ def gdrive_doc_sync(
|
||||
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("gdrive_doc_sync", 1)
|
||||
if isinstance(slim_doc, HierarchyNode):
|
||||
# Yield hierarchy node permissions to be processed in outer layer
|
||||
if slim_doc.external_access:
|
||||
yield NodeExternalAccess(
|
||||
external_access=slim_doc.external_access,
|
||||
raw_node_id=slim_doc.raw_node_id,
|
||||
source=DocumentSource.GOOGLE_DRIVE.value,
|
||||
)
|
||||
continue
|
||||
|
||||
if slim_doc.external_access is None:
|
||||
raise ValueError(
|
||||
f"Drive perm sync: No external access for document {slim_doc.id}"
|
||||
|
||||
@@ -30,10 +30,6 @@ class GoogleDrivePermission(BaseModel):
|
||||
type: PermissionType
|
||||
domain: str | None # only applies to domain permissions
|
||||
permission_details: GoogleDrivePermissionDetails | None
|
||||
# Whether this permission makes the file discoverable in search
|
||||
# False means "anyone with the link" (not searchable/discoverable)
|
||||
# Only applicable for domain/anyone permission types
|
||||
allow_file_discovery: bool | None
|
||||
|
||||
@classmethod
|
||||
def from_drive_permission(
|
||||
@@ -50,7 +46,6 @@ class GoogleDrivePermission(BaseModel):
|
||||
email_address=drive_permission.get("emailAddress"),
|
||||
type=PermissionType(drive_permission["type"]),
|
||||
domain=drive_permission.get("domain"),
|
||||
allow_file_discovery=drive_permission.get("allowFileDiscovery"),
|
||||
permission_details=(
|
||||
GoogleDrivePermissionDetails(
|
||||
permission_type=permission_details.get("type"),
|
||||
|
||||
@@ -36,7 +36,7 @@ def get_permissions_by_ids(
|
||||
retrieval_function=drive_service.permissions().list,
|
||||
list_key="permissions",
|
||||
fileId=doc_id,
|
||||
fields="permissions(id, emailAddress, type, domain, allowFileDiscovery, permissionDetails),nextPageToken",
|
||||
fields="permissions(id, emailAddress, type, domain, permissionDetails),nextPageToken",
|
||||
supportsAllDrives=True,
|
||||
continue_on_404_or_403=True,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections.abc import Generator
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from ee.onyx.external_permissions.utils import generic_doc_sync
|
||||
from onyx.access.models import ElementExternalAccess
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.jira.connector import JiraConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -20,7 +20,7 @@ def jira_doc_sync(
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> Generator[ElementExternalAccess, None, None]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
jira_connector = JiraConnector(
|
||||
**cc_pair.connector.connector_specific_config,
|
||||
)
|
||||
|
||||
@@ -5,8 +5,6 @@ from typing import Protocol
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup # noqa
|
||||
from onyx.access.models import DocExternalAccess # noqa
|
||||
from onyx.access.models import ElementExternalAccess # noqa
|
||||
from onyx.access.models import NodeExternalAccess # noqa
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.models import ConnectorCredentialPair # noqa
|
||||
from onyx.db.utils import DocumentRow
|
||||
@@ -55,7 +53,7 @@ DocSyncFuncType = Callable[
|
||||
FetchAllDocumentsIdsFunction,
|
||||
Optional[IndexingHeartbeatInterface],
|
||||
],
|
||||
Generator[ElementExternalAccess, None, None],
|
||||
Generator[DocExternalAccess, None, None],
|
||||
]
|
||||
|
||||
GroupSyncFuncType = Callable[
|
||||
|
||||
@@ -34,21 +34,21 @@ def _get_all_censoring_enabled_sources() -> set[DocumentSource]:
|
||||
# NOTE: This is only called if ee is enabled.
|
||||
def _post_query_chunk_censoring(
|
||||
chunks: list[InferenceChunk],
|
||||
user: User,
|
||||
user: User | None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
This function checks all chunks to see if they need to be sent to a censoring
|
||||
function. If they do, it sends them to the censoring function and returns the
|
||||
censored chunks. If they don't, it returns the original chunks.
|
||||
"""
|
||||
sources_to_censor = _get_all_censoring_enabled_sources()
|
||||
|
||||
# Anonymous users can only access public (non-permission-synced) content
|
||||
if user.is_anonymous:
|
||||
return [chunk for chunk in chunks if chunk.source_type not in sources_to_censor]
|
||||
if user is None:
|
||||
# if user is None, permissions are not enforced
|
||||
return chunks
|
||||
|
||||
final_chunk_dict: dict[str, InferenceChunk] = {}
|
||||
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}
|
||||
|
||||
sources_to_censor = _get_all_censoring_enabled_sources()
|
||||
for chunk in chunks:
|
||||
# Separate out chunks that require permission post-processing by source
|
||||
if chunk.source_type in sources_to_censor:
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections.abc import Generator
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from ee.onyx.external_permissions.utils import generic_doc_sync
|
||||
from onyx.access.models import ElementExternalAccess
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -20,7 +20,7 @@ def sharepoint_doc_sync(
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> Generator[ElementExternalAccess, None, None]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
sharepoint_connector = SharepointConnector(
|
||||
**cc_pair.connector.connector_specific_config,
|
||||
)
|
||||
|
||||
@@ -8,7 +8,6 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.slack.connector import get_channels
|
||||
from onyx.connectors.slack.connector import make_paginated_slack_api_call
|
||||
from onyx.connectors.slack.connector import SlackConnector
|
||||
@@ -112,9 +111,6 @@ def _get_slack_document_access(
|
||||
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
if isinstance(doc_metadata, HierarchyNode):
|
||||
# TODO: handle hierarchynodes during sync
|
||||
continue
|
||||
if doc_metadata.external_access is None:
|
||||
raise ValueError(
|
||||
f"No external access for document {doc_metadata.id}. "
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections.abc import Generator
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from ee.onyx.external_permissions.utils import generic_doc_sync
|
||||
from onyx.access.models import ElementExternalAccess
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.teams.connector import TeamsConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -21,7 +21,7 @@ def teams_doc_sync(
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[ElementExternalAccess, None, None]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
teams_connector = TeamsConnector(
|
||||
**cc_pair.connector.connector_specific_config,
|
||||
)
|
||||
|
||||
@@ -2,12 +2,9 @@ from collections.abc import Generator
|
||||
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ElementExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.models import NodeExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -22,7 +19,7 @@ def generic_doc_sync(
|
||||
doc_source: DocumentSource,
|
||||
slim_connector: SlimConnectorWithPermSync,
|
||||
label: str,
|
||||
) -> Generator[ElementExternalAccess, None, None]:
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
A convenience function for performing a generic document synchronization.
|
||||
|
||||
@@ -32,7 +29,7 @@ def generic_doc_sync(
|
||||
- fetching *all* new (slim) docs
|
||||
- yielding external-access permissions for existing docs which do not exist in the newly fetched slim-docs set (with their
|
||||
`external_access` set to "private")
|
||||
- yielding external-access permissions for newly fetched docs and hierarchy nodes
|
||||
- yielding external-access permissions for newly fetched docs
|
||||
|
||||
Returns:
|
||||
A `Generator` which yields existing and newly fetched external-access permissions.
|
||||
@@ -52,15 +49,6 @@ def generic_doc_sync(
|
||||
callback.progress(label, 1)
|
||||
|
||||
for doc in doc_batch:
|
||||
if isinstance(doc, HierarchyNode):
|
||||
# Yield hierarchy node permissions to be processed in outer layer
|
||||
if doc.external_access:
|
||||
yield NodeExternalAccess(
|
||||
external_access=doc.external_access,
|
||||
raw_node_id=doc.raw_node_id,
|
||||
source=doc_source.value,
|
||||
)
|
||||
continue
|
||||
if not doc.external_access:
|
||||
raise RuntimeError(
|
||||
f"No external access found for document ID; {cc_pair.id=} {doc_source=} {doc.id=}"
|
||||
|
||||
@@ -4,10 +4,8 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.server.analytics.api import router as analytics_router
|
||||
from ee.onyx.server.auth_check import check_ee_router_auth
|
||||
from ee.onyx.server.billing.api import router as billing_router
|
||||
from ee.onyx.server.documents.cc_pair import router as ee_document_cc_pair_router
|
||||
from ee.onyx.server.enterprise_settings.api import (
|
||||
admin_router as enterprise_settings_admin_router,
|
||||
@@ -87,11 +85,10 @@ def get_application() -> FastAPI:
|
||||
|
||||
if MULTI_TENANT:
|
||||
add_api_server_tenant_id_middleware(application, logger)
|
||||
else:
|
||||
# License enforcement middleware for self-hosted deployments only
|
||||
# Checks LICENSE_ENFORCEMENT_ENABLED at runtime (can be toggled without restart)
|
||||
# MT deployments use control plane gating via is_tenant_gated() instead
|
||||
add_license_enforcement_middleware(application, logger)
|
||||
|
||||
# Add license enforcement middleware (runs after tenant tracking)
|
||||
# This blocks access when license is expired/gated
|
||||
add_license_enforcement_middleware(application, logger)
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
# For Google OAuth, refresh tokens are requested by:
|
||||
@@ -151,13 +148,6 @@ def get_application() -> FastAPI:
|
||||
# License management
|
||||
include_router_with_global_prefix_prepended(application, license_router)
|
||||
|
||||
# Unified billing API - available when license system is enabled
|
||||
# Works for both self-hosted and cloud deployments
|
||||
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
|
||||
# primary billing API and /tenants/* billing endpoints can be removed
|
||||
if LICENSE_ENFORCEMENT_ENABLED:
|
||||
include_router_with_global_prefix_prepended(application, billing_router)
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Tenant management
|
||||
include_router_with_global_prefix_prepended(application, tenants_router)
|
||||
|
||||
@@ -41,7 +41,7 @@ def _run_single_search(
|
||||
query: str,
|
||||
filters: BaseFilters | None,
|
||||
document_index: DocumentIndex,
|
||||
user: User,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
num_hits: int | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
@@ -63,7 +63,7 @@ def _run_single_search(
|
||||
|
||||
def stream_search_query(
|
||||
request: SendSearchQueryRequest,
|
||||
user: User,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> Generator[
|
||||
SearchQueriesPacket | SearchDocsPacket | LLMSelectedDocsPacket | SearchErrorPacket,
|
||||
@@ -101,7 +101,8 @@ def stream_search_query(
|
||||
# Build list of all executed queries for tracking
|
||||
all_executed_queries = [original_query] + keyword_expansions
|
||||
|
||||
if not user.is_anonymous:
|
||||
# TODO remove this check, user should not be None
|
||||
if user is not None:
|
||||
create_search_query(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
|
||||
@@ -40,7 +40,7 @@ class QueryAnalyticsResponse(BaseModel):
|
||||
def get_query_analytics(
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[QueryAnalyticsResponse]:
|
||||
daily_query_usage_info = fetch_query_analytics(
|
||||
@@ -71,7 +71,7 @@ class UserAnalyticsResponse(BaseModel):
|
||||
def get_user_analytics(
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserAnalyticsResponse]:
|
||||
daily_query_usage_info_per_user = fetch_per_user_query_analytics(
|
||||
@@ -105,7 +105,7 @@ class OnyxbotAnalyticsResponse(BaseModel):
|
||||
def get_onyxbot_analytics(
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[OnyxbotAnalyticsResponse]:
|
||||
daily_onyxbot_info = fetch_onyxbot_analytics(
|
||||
@@ -141,7 +141,7 @@ def get_persona_messages(
|
||||
persona_id: int,
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[PersonaMessageAnalyticsResponse]:
|
||||
"""Fetch daily message counts for a single persona within the given time range."""
|
||||
@@ -179,7 +179,7 @@ def get_persona_unique_users(
|
||||
persona_id: int,
|
||||
start: datetime.datetime,
|
||||
end: datetime.datetime,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[PersonaUniqueUsersResponse]:
|
||||
"""Get unique users per day for a single persona."""
|
||||
@@ -218,7 +218,7 @@ def get_assistant_stats(
|
||||
assistant_id: int,
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
user: User = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantStatsResponse:
|
||||
"""
|
||||
|
||||
@@ -12,14 +12,6 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
|
||||
("/enterprise-settings/custom-analytics-script", {"GET"}),
|
||||
# Stripe publishable key is safe to expose publicly
|
||||
("/tenants/stripe-publishable-key", {"GET"}),
|
||||
("/admin/billing/stripe-publishable-key", {"GET"}),
|
||||
# Proxy endpoints use license-based auth, not user auth
|
||||
("/proxy/create-checkout-session", {"POST"}),
|
||||
("/proxy/claim-license", {"POST"}),
|
||||
("/proxy/create-customer-portal-session", {"POST"}),
|
||||
("/proxy/billing-information", {"GET"}),
|
||||
("/proxy/license/{tenant_id}", {"GET"}),
|
||||
("/proxy/seats/update", {"POST"}),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,373 +0,0 @@
|
||||
"""Unified Billing API endpoints.
|
||||
|
||||
These endpoints provide Stripe billing functionality for both cloud and
|
||||
self-hosted deployments. The service layer routes requests appropriately:
|
||||
|
||||
- Self-hosted: Routes through cloud data plane proxy
|
||||
Flow: Backend /admin/billing/* → Cloud DP /proxy/* → Control plane
|
||||
|
||||
- Cloud (MULTI_TENANT): Routes directly to control plane
|
||||
Flow: Backend /admin/billing/* → Control plane
|
||||
|
||||
License claiming is handled separately by /license/claim endpoint (self-hosted only).
|
||||
|
||||
Migration Note (ENG-3533):
|
||||
This /admin/billing/* API replaces the older /tenants/* billing endpoints:
|
||||
- /tenants/billing-information -> /admin/billing/billing-information
|
||||
- /tenants/create-customer-portal-session -> /admin/billing/create-customer-portal-session
|
||||
- /tenants/create-subscription-session -> /admin/billing/create-checkout-session
|
||||
- /tenants/stripe-publishable-key -> /admin/billing/stripe-publishable-key
|
||||
|
||||
See: https://linear.app/onyx-app/issue/ENG-3533/migrate-tenantsbilling-adminbilling
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.db.license import get_license
|
||||
from ee.onyx.server.billing.models import BillingInformationResponse
|
||||
from ee.onyx.server.billing.models import CreateCheckoutSessionRequest
|
||||
from ee.onyx.server.billing.models import CreateCheckoutSessionResponse
|
||||
from ee.onyx.server.billing.models import CreateCustomerPortalSessionRequest
|
||||
from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse
|
||||
from ee.onyx.server.billing.models import SeatUpdateRequest
|
||||
from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.billing.models import StripePublishableKeyResponse
|
||||
from ee.onyx.server.billing.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.billing.service import BillingServiceError
|
||||
from ee.onyx.server.billing.service import (
|
||||
create_checkout_session as create_checkout_service,
|
||||
)
|
||||
from ee.onyx.server.billing.service import (
|
||||
create_customer_portal_session as create_portal_service,
|
||||
)
|
||||
from ee.onyx.server.billing.service import (
|
||||
get_billing_information as get_billing_service,
|
||||
)
|
||||
from ee.onyx.server.billing.service import update_seat_count as update_seat_service
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/admin/billing")
|
||||
|
||||
# Cache for Stripe publishable key to avoid hitting S3 on every request
|
||||
_stripe_publishable_key_cache: str | None = None
|
||||
_stripe_key_lock = asyncio.Lock()
|
||||
|
||||
# Redis key for billing circuit breaker (self-hosted only)
|
||||
# When set, billing requests to Stripe are disabled until user manually retries
|
||||
BILLING_CIRCUIT_BREAKER_KEY = "billing_circuit_open"
|
||||
# Circuit breaker auto-expires after 1 hour (user can manually retry sooner)
|
||||
BILLING_CIRCUIT_BREAKER_TTL_SECONDS = 3600
|
||||
|
||||
|
||||
def _is_billing_circuit_open() -> bool:
|
||||
"""Check if the billing circuit breaker is open (self-hosted only)."""
|
||||
if MULTI_TENANT:
|
||||
return False
|
||||
try:
|
||||
redis_client = get_shared_redis_client()
|
||||
is_open = bool(redis_client.exists(BILLING_CIRCUIT_BREAKER_KEY))
|
||||
logger.debug(
|
||||
f"Circuit breaker check: key={BILLING_CIRCUIT_BREAKER_KEY}, is_open={is_open}"
|
||||
)
|
||||
return is_open
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check circuit breaker: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _open_billing_circuit() -> None:
|
||||
"""Open the billing circuit breaker after a failure (self-hosted only)."""
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
try:
|
||||
redis_client = get_shared_redis_client()
|
||||
redis_client.set(
|
||||
BILLING_CIRCUIT_BREAKER_KEY,
|
||||
"1",
|
||||
ex=BILLING_CIRCUIT_BREAKER_TTL_SECONDS,
|
||||
)
|
||||
# Verify it was set
|
||||
exists = redis_client.exists(BILLING_CIRCUIT_BREAKER_KEY)
|
||||
logger.warning(
|
||||
f"Billing circuit breaker opened (TTL={BILLING_CIRCUIT_BREAKER_TTL_SECONDS}s, "
|
||||
f"verified={exists}). Stripe billing requests are disabled until manually reset."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to open circuit breaker: {e}")
|
||||
|
||||
|
||||
def _close_billing_circuit() -> None:
|
||||
"""Close the billing circuit breaker (re-enable Stripe requests)."""
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
try:
|
||||
redis_client = get_shared_redis_client()
|
||||
redis_client.delete(BILLING_CIRCUIT_BREAKER_KEY)
|
||||
logger.info(
|
||||
"Billing circuit breaker closed. Stripe billing requests re-enabled."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to close circuit breaker: {e}")
|
||||
|
||||
|
||||
def _get_license_data(db_session: Session) -> str | None:
|
||||
"""Get license data from database if exists (self-hosted only)."""
|
||||
if MULTI_TENANT:
|
||||
return None
|
||||
license_record = get_license(db_session)
|
||||
return license_record.license_data if license_record else None
|
||||
|
||||
|
||||
def _get_tenant_id() -> str | None:
|
||||
"""Get tenant ID for cloud deployments."""
|
||||
if MULTI_TENANT:
|
||||
return get_current_tenant_id()
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/create-checkout-session")
|
||||
async def create_checkout_session(
|
||||
request: CreateCheckoutSessionRequest | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CreateCheckoutSessionResponse:
|
||||
"""Create a Stripe checkout session for new subscription or renewal.
|
||||
|
||||
For new customers, no license/tenant is required.
|
||||
For renewals, existing license (self-hosted) or tenant_id (cloud) is used.
|
||||
|
||||
After checkout completion:
|
||||
- Self-hosted: Use /license/claim to retrieve the license
|
||||
- Cloud: Subscription is automatically activated
|
||||
"""
|
||||
license_data = _get_license_data(db_session)
|
||||
tenant_id = _get_tenant_id()
|
||||
billing_period = request.billing_period if request else "monthly"
|
||||
seats = request.seats if request else None
|
||||
email = request.email if request else None
|
||||
|
||||
# Build redirect URL for after checkout completion
|
||||
redirect_url = f"{WEB_DOMAIN}/admin/billing?checkout=success"
|
||||
|
||||
try:
|
||||
return await create_checkout_service(
|
||||
billing_period=billing_period,
|
||||
seats=seats,
|
||||
email=email,
|
||||
license_data=license_data,
|
||||
redirect_url=redirect_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(
|
||||
request: CreateCustomerPortalSessionRequest | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CreateCustomerPortalSessionResponse:
|
||||
"""Create a Stripe customer portal session for managing subscription.
|
||||
|
||||
Requires existing license (self-hosted) or active tenant (cloud).
|
||||
"""
|
||||
license_data = _get_license_data(db_session)
|
||||
tenant_id = _get_tenant_id()
|
||||
|
||||
# Self-hosted requires license
|
||||
if not MULTI_TENANT and not license_data:
|
||||
raise HTTPException(status_code=400, detail="No license found")
|
||||
|
||||
return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing"
|
||||
|
||||
try:
|
||||
return await create_portal_service(
|
||||
license_data=license_data,
|
||||
return_url=return_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def get_billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> BillingInformationResponse | SubscriptionStatusResponse:
|
||||
"""Get billing information for the current subscription.
|
||||
|
||||
Returns subscription status and details from Stripe.
|
||||
For self-hosted: If the circuit breaker is open (previous failure),
|
||||
returns a 503 error without making the request.
|
||||
"""
|
||||
license_data = _get_license_data(db_session)
|
||||
tenant_id = _get_tenant_id()
|
||||
|
||||
# Self-hosted without license = no subscription
|
||||
if not MULTI_TENANT and not license_data:
|
||||
return SubscriptionStatusResponse(subscribed=False)
|
||||
|
||||
# Check circuit breaker (self-hosted only)
|
||||
if _is_billing_circuit_open():
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
|
||||
)
|
||||
|
||||
try:
|
||||
return await get_billing_service(
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
# Open circuit breaker on connection failures (self-hosted only)
|
||||
if e.status_code in (502, 503, 504):
|
||||
_open_billing_circuit()
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.post("/seats/update")
|
||||
async def update_seats(
|
||||
request: SeatUpdateRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SeatUpdateResponse:
|
||||
"""Update the seat count for the current subscription.
|
||||
|
||||
Handles Stripe proration and license regeneration via control plane.
|
||||
For self-hosted, the frontend should call /license/claim after a short delay
|
||||
to fetch the regenerated license.
|
||||
"""
|
||||
license_data = _get_license_data(db_session)
|
||||
tenant_id = _get_tenant_id()
|
||||
|
||||
# Self-hosted requires license
|
||||
if not MULTI_TENANT and not license_data:
|
||||
raise HTTPException(status_code=400, detail="No license found")
|
||||
|
||||
try:
|
||||
result = await update_seat_service(
|
||||
new_seat_count=request.new_seat_count,
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Note: Don't store license here - the control plane may still be processing
|
||||
# the subscription update. The frontend should call /license/claim after a
|
||||
# short delay to get the freshly generated license.
|
||||
|
||||
return result
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
"""Fetch the Stripe publishable key.
|
||||
|
||||
Priority: env var override (for testing) > S3 bucket (production).
|
||||
This endpoint is public (no auth required) since publishable keys are safe to expose.
|
||||
The key is cached in memory to avoid hitting S3 on every request.
|
||||
"""
|
||||
global _stripe_publishable_key_cache
|
||||
|
||||
# Fast path: return cached value without lock
|
||||
if _stripe_publishable_key_cache:
|
||||
return StripePublishableKeyResponse(
|
||||
publishable_key=_stripe_publishable_key_cache
|
||||
)
|
||||
|
||||
# Use lock to prevent concurrent S3 requests
|
||||
async with _stripe_key_lock:
|
||||
# Double-check after acquiring lock (another request may have populated cache)
|
||||
if _stripe_publishable_key_cache:
|
||||
return StripePublishableKeyResponse(
|
||||
publishable_key=_stripe_publishable_key_cache
|
||||
)
|
||||
|
||||
# Check for env var override first (for local testing with pk_test_* keys)
|
||||
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(STRIPE_PUBLISHABLE_KEY_URL)
|
||||
response.raise_for_status()
|
||||
key = response.text.strip()
|
||||
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
)
|
||||
|
||||
|
||||
class ResetConnectionResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
@router.post("/reset-connection")
|
||||
async def reset_stripe_connection(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> ResetConnectionResponse:
|
||||
"""Reset the Stripe connection circuit breaker.
|
||||
|
||||
Called when user clicks "Connect to Stripe" to retry after a previous failure.
|
||||
This clears the circuit breaker flag, allowing billing requests to proceed again.
|
||||
Self-hosted only - cloud deployments don't use the circuit breaker.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
return ResetConnectionResponse(
|
||||
success=True,
|
||||
message="Circuit breaker not applicable for cloud deployments",
|
||||
)
|
||||
|
||||
_close_billing_circuit()
|
||||
return ResetConnectionResponse(
|
||||
success=True,
|
||||
message="Stripe connection reset. Billing requests re-enabled.",
|
||||
)
|
||||
@@ -1,77 +0,0 @@
|
||||
"""Pydantic models for the billing API."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CreateCheckoutSessionRequest(BaseModel):
|
||||
"""Request to create a Stripe checkout session."""
|
||||
|
||||
billing_period: Literal["monthly", "annual"] = "monthly"
|
||||
seats: int | None = None
|
||||
email: str | None = None
|
||||
|
||||
|
||||
class CreateCheckoutSessionResponse(BaseModel):
|
||||
"""Response containing the Stripe checkout session URL."""
|
||||
|
||||
stripe_checkout_url: str
|
||||
|
||||
|
||||
class CreateCustomerPortalSessionRequest(BaseModel):
|
||||
"""Request to create a Stripe customer portal session."""
|
||||
|
||||
return_url: str | None = None
|
||||
|
||||
|
||||
class CreateCustomerPortalSessionResponse(BaseModel):
|
||||
"""Response containing the Stripe customer portal URL."""
|
||||
|
||||
stripe_customer_portal_url: str
|
||||
|
||||
|
||||
class BillingInformationResponse(BaseModel):
|
||||
"""Billing information for the current subscription."""
|
||||
|
||||
tenant_id: str
|
||||
status: str | None = None
|
||||
plan_type: str | None = None
|
||||
seats: int | None = None
|
||||
billing_period: str | None = None
|
||||
current_period_start: datetime | None = None
|
||||
current_period_end: datetime | None = None
|
||||
cancel_at_period_end: bool = False
|
||||
canceled_at: datetime | None = None
|
||||
trial_start: datetime | None = None
|
||||
trial_end: datetime | None = None
|
||||
payment_method_enabled: bool = False
|
||||
|
||||
|
||||
class SubscriptionStatusResponse(BaseModel):
|
||||
"""Response when no subscription exists."""
|
||||
|
||||
subscribed: bool = False
|
||||
|
||||
|
||||
class SeatUpdateRequest(BaseModel):
|
||||
"""Request to update seat count."""
|
||||
|
||||
new_seat_count: int
|
||||
|
||||
|
||||
class SeatUpdateResponse(BaseModel):
|
||||
"""Response from seat update operation."""
|
||||
|
||||
success: bool
|
||||
current_seats: int
|
||||
used_seats: int
|
||||
message: str | None = None
|
||||
license: str | None = None # Regenerated license (self-hosted stores this)
|
||||
|
||||
|
||||
class StripePublishableKeyResponse(BaseModel):
|
||||
"""Response containing the Stripe publishable key."""
|
||||
|
||||
publishable_key: str
|
||||
@@ -1,273 +0,0 @@
|
||||
"""Service layer for billing operations.
|
||||
|
||||
This module provides functions for billing operations that route differently
|
||||
based on deployment type:
|
||||
|
||||
- Self-hosted (not MULTI_TENANT): Routes through cloud data plane proxy
|
||||
Flow: Self-hosted backend → Cloud DP /proxy/* → Control plane
|
||||
|
||||
- Cloud (MULTI_TENANT): Routes directly to control plane
|
||||
Flow: Cloud backend → Control plane
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
|
||||
from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL
|
||||
from ee.onyx.server.billing.models import BillingInformationResponse
|
||||
from ee.onyx.server.billing.models import CreateCheckoutSessionResponse
|
||||
from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse
|
||||
from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.billing.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# HTTP request timeout for billing service calls
|
||||
_REQUEST_TIMEOUT = 30.0
|
||||
|
||||
|
||||
class BillingServiceError(Exception):
|
||||
"""Exception raised for billing service errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int = 500):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
def _get_proxy_headers(license_data: str | None) -> dict[str, str]:
|
||||
"""Build headers for proxy requests (self-hosted).
|
||||
|
||||
Self-hosted instances authenticate with their license.
|
||||
"""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if license_data:
|
||||
headers["Authorization"] = f"Bearer {license_data}"
|
||||
return headers
|
||||
|
||||
|
||||
def _get_direct_headers() -> dict[str, str]:
|
||||
"""Build headers for direct control plane requests (cloud).
|
||||
|
||||
Cloud instances authenticate with JWT.
|
||||
"""
|
||||
token = generate_data_plane_token()
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {token}",
|
||||
}
|
||||
|
||||
|
||||
def _get_base_url() -> str:
|
||||
"""Get the base URL based on deployment type."""
|
||||
if MULTI_TENANT:
|
||||
return CONTROL_PLANE_API_BASE_URL
|
||||
return f"{CLOUD_DATA_PLANE_URL}/proxy"
|
||||
|
||||
|
||||
def _get_headers(license_data: str | None) -> dict[str, str]:
|
||||
"""Get appropriate headers based on deployment type."""
|
||||
if MULTI_TENANT:
|
||||
return _get_direct_headers()
|
||||
return _get_proxy_headers(license_data)
|
||||
|
||||
|
||||
async def _make_billing_request(
|
||||
method: Literal["GET", "POST"],
|
||||
path: str,
|
||||
license_data: str | None = None,
|
||||
body: dict | None = None,
|
||||
params: dict | None = None,
|
||||
error_message: str = "Billing service request failed",
|
||||
) -> dict:
|
||||
"""Make an HTTP request to the billing service.
|
||||
|
||||
Consolidates the common HTTP request pattern used by all billing operations.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET or POST)
|
||||
path: URL path (appended to base URL)
|
||||
license_data: License for authentication (self-hosted)
|
||||
body: Request body for POST requests
|
||||
params: Query parameters for GET requests
|
||||
error_message: Default error message if request fails
|
||||
|
||||
Returns:
|
||||
Response JSON as dict
|
||||
|
||||
Raises:
|
||||
BillingServiceError: If request fails
|
||||
"""
|
||||
|
||||
base_url = _get_base_url()
|
||||
url = f"{base_url}{path}"
|
||||
headers = _get_headers(license_data)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT) as client:
|
||||
if method == "GET":
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
else:
|
||||
response = await client.post(url, headers=headers, json=body)
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
detail = error_message
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
detail = error_data.get("detail", detail)
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(f"{error_message}: {e.response.status_code} - {detail}")
|
||||
raise BillingServiceError(detail, e.response.status_code)
|
||||
|
||||
except httpx.RequestError:
|
||||
logger.exception("Failed to connect to billing service")
|
||||
raise BillingServiceError("Failed to connect to billing service", 502)
|
||||
|
||||
|
||||
async def create_checkout_session(
|
||||
billing_period: str = "monthly",
|
||||
seats: int | None = None,
|
||||
email: str | None = None,
|
||||
license_data: str | None = None,
|
||||
redirect_url: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> CreateCheckoutSessionResponse:
|
||||
"""Create a Stripe checkout session.
|
||||
|
||||
Args:
|
||||
billing_period: "monthly" or "annual"
|
||||
seats: Number of seats to purchase (optional, uses default if not provided)
|
||||
email: Customer email for new subscriptions
|
||||
license_data: Existing license for renewals (self-hosted)
|
||||
redirect_url: URL to redirect after successful checkout
|
||||
tenant_id: Tenant ID (cloud only, for renewals)
|
||||
|
||||
Returns:
|
||||
CreateCheckoutSessionResponse with checkout URL
|
||||
"""
|
||||
body: dict = {"billing_period": billing_period}
|
||||
if seats is not None:
|
||||
body["seats"] = seats
|
||||
if email:
|
||||
body["email"] = email
|
||||
if redirect_url:
|
||||
body["redirect_url"] = redirect_url
|
||||
if tenant_id and MULTI_TENANT:
|
||||
body["tenant_id"] = tenant_id
|
||||
|
||||
data = await _make_billing_request(
|
||||
method="POST",
|
||||
path="/create-checkout-session",
|
||||
license_data=license_data,
|
||||
body=body,
|
||||
error_message="Failed to create checkout session",
|
||||
)
|
||||
return CreateCheckoutSessionResponse(stripe_checkout_url=data["url"])
|
||||
|
||||
|
||||
async def create_customer_portal_session(
|
||||
license_data: str | None = None,
|
||||
return_url: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> CreateCustomerPortalSessionResponse:
|
||||
"""Create a Stripe customer portal session.
|
||||
|
||||
Args:
|
||||
license_data: License blob for authentication (self-hosted)
|
||||
return_url: URL to return to after portal session
|
||||
tenant_id: Tenant ID (cloud only)
|
||||
|
||||
Returns:
|
||||
CreateCustomerPortalSessionResponse with portal URL
|
||||
"""
|
||||
body: dict = {}
|
||||
if return_url:
|
||||
body["return_url"] = return_url
|
||||
if tenant_id and MULTI_TENANT:
|
||||
body["tenant_id"] = tenant_id
|
||||
|
||||
data = await _make_billing_request(
|
||||
method="POST",
|
||||
path="/create-customer-portal-session",
|
||||
license_data=license_data,
|
||||
body=body,
|
||||
error_message="Failed to create customer portal session",
|
||||
)
|
||||
return CreateCustomerPortalSessionResponse(stripe_customer_portal_url=data["url"])
|
||||
|
||||
|
||||
async def get_billing_information(
|
||||
license_data: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> BillingInformationResponse | SubscriptionStatusResponse:
|
||||
"""Fetch billing information.
|
||||
|
||||
Args:
|
||||
license_data: License blob for authentication (self-hosted)
|
||||
tenant_id: Tenant ID (cloud only)
|
||||
|
||||
Returns:
|
||||
BillingInformationResponse or SubscriptionStatusResponse if no subscription
|
||||
"""
|
||||
params = {}
|
||||
if tenant_id and MULTI_TENANT:
|
||||
params["tenant_id"] = tenant_id
|
||||
|
||||
data = await _make_billing_request(
|
||||
method="GET",
|
||||
path="/billing-information",
|
||||
license_data=license_data,
|
||||
params=params or None,
|
||||
error_message="Failed to fetch billing information",
|
||||
)
|
||||
|
||||
# Check if no subscription
|
||||
if isinstance(data, dict) and data.get("subscribed") is False:
|
||||
return SubscriptionStatusResponse(subscribed=False)
|
||||
|
||||
return BillingInformationResponse(**data)
|
||||
|
||||
|
||||
async def update_seat_count(
|
||||
new_seat_count: int,
|
||||
license_data: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> SeatUpdateResponse:
|
||||
"""Update the seat count for the current subscription.
|
||||
|
||||
Args:
|
||||
new_seat_count: New number of seats
|
||||
license_data: License blob for authentication (self-hosted)
|
||||
tenant_id: Tenant ID (cloud only)
|
||||
|
||||
Returns:
|
||||
SeatUpdateResponse with updated seat information
|
||||
"""
|
||||
body: dict = {"new_seat_count": new_seat_count}
|
||||
if tenant_id and MULTI_TENANT:
|
||||
body["tenant_id"] = tenant_id
|
||||
|
||||
data = await _make_billing_request(
|
||||
method="POST",
|
||||
path="/seats/update",
|
||||
license_data=license_data,
|
||||
body=body,
|
||||
error_message="Failed to update seat count",
|
||||
)
|
||||
|
||||
return SeatUpdateResponse(
|
||||
success=data.get("success", False),
|
||||
current_seats=data.get("current_seats", 0),
|
||||
used_seats=data.get("used_seats", 0),
|
||||
message=data.get("message"),
|
||||
license=data.get("license"),
|
||||
)
|
||||
@@ -115,7 +115,7 @@ async def refresh_access_token(
|
||||
|
||||
@admin_router.put("")
|
||||
def admin_ee_put_settings(
|
||||
settings: EnterpriseSettings, _: User = Depends(current_admin_user)
|
||||
settings: EnterpriseSettings, _: User | None = Depends(current_admin_user)
|
||||
) -> None:
|
||||
store_settings(settings)
|
||||
|
||||
@@ -134,7 +134,7 @@ def ee_fetch_settings() -> EnterpriseSettings:
|
||||
def put_logo(
|
||||
file: UploadFile,
|
||||
is_logotype: bool = False,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
upload_logo(file=file, is_logotype=is_logotype)
|
||||
|
||||
@@ -187,7 +187,7 @@ def fetch_logo(
|
||||
|
||||
@admin_router.put("/custom-analytics-script")
|
||||
def upload_custom_analytics_script(
|
||||
script_upload: AnalyticsScriptUpload, _: User = Depends(current_admin_user)
|
||||
script_upload: AnalyticsScriptUpload, _: User | None = Depends(current_admin_user)
|
||||
) -> None:
|
||||
try:
|
||||
store_analytics_script(script_upload)
|
||||
|
||||
@@ -1,14 +1,4 @@
|
||||
"""License API endpoints for self-hosted deployments.
|
||||
|
||||
These endpoints allow self-hosted Onyx instances to:
|
||||
1. Claim a license after Stripe checkout (via cloud data plane proxy)
|
||||
2. Upload a license file manually (for air-gapped deployments)
|
||||
3. View license status and seat usage
|
||||
4. Refresh/delete the local license
|
||||
|
||||
NOTE: Cloud (MULTI_TENANT) deployments do NOT use these endpoints.
|
||||
Cloud licensing is managed via the control plane and gated_tenants Redis key.
|
||||
"""
|
||||
"""License API endpoints."""
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter
|
||||
@@ -19,9 +9,7 @@ from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL
|
||||
from ee.onyx.db.license import delete_license as db_delete_license
|
||||
from ee.onyx.db.license import get_license
|
||||
from ee.onyx.db.license import get_license_metadata
|
||||
from ee.onyx.db.license import invalidate_license_cache
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
@@ -32,11 +20,13 @@ from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.license.models import LicenseStatusResponse
|
||||
from ee.onyx.server.license.models import LicenseUploadResponse
|
||||
from ee.onyx.server.license.models import SeatUsageResponse
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -89,103 +79,81 @@ async def get_seat_usage(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/claim")
|
||||
async def claim_license(
|
||||
session_id: str | None = None,
|
||||
@router.post("/fetch")
|
||||
async def fetch_license(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseResponse:
|
||||
"""
|
||||
Claim a license from the control plane (self-hosted only).
|
||||
|
||||
Two modes:
|
||||
1. With session_id: After Stripe checkout, exchange session_id for license
|
||||
2. Without session_id: Re-claim using existing license for auth
|
||||
|
||||
Use without session_id after:
|
||||
- Updating seats via the billing API
|
||||
- Returning from the Stripe customer portal
|
||||
- Any operation that regenerates the license on control plane
|
||||
Fetch license from control plane.
|
||||
Used after Stripe checkout completion to retrieve the new license.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
try:
|
||||
token = generate_data_plane_token()
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to generate data plane token: {e}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License claiming is only available for self-hosted deployments",
|
||||
status_code=500, detail="Authentication configuration error"
|
||||
)
|
||||
|
||||
try:
|
||||
if session_id:
|
||||
# Claim license after checkout using session_id
|
||||
url = f"{CLOUD_DATA_PLANE_URL}/proxy/claim-license"
|
||||
response = requests.post(
|
||||
url,
|
||||
json={"session_id": session_id},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=30,
|
||||
)
|
||||
else:
|
||||
# Re-claim using existing license for auth
|
||||
metadata = get_license_metadata(db_session)
|
||||
if not metadata or not metadata.tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No license found. Provide session_id after checkout.",
|
||||
)
|
||||
|
||||
license_row = get_license(db_session)
|
||||
if not license_row or not license_row.license_data:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No license found in database"
|
||||
)
|
||||
|
||||
url = f"{CLOUD_DATA_PLANE_URL}/proxy/license/{metadata.tenant_id}"
|
||||
response = requests.get(
|
||||
url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {license_row.license_data}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/license/{tenant_id}"
|
||||
response = requests.get(url, headers=headers, timeout=10)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
license_data = data.get("license")
|
||||
if not isinstance(data, dict) or "license" not in data:
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Invalid response from control plane"
|
||||
)
|
||||
|
||||
license_data = data["license"]
|
||||
if not license_data:
|
||||
raise HTTPException(status_code=404, detail="No license in response")
|
||||
raise HTTPException(status_code=404, detail="No license found")
|
||||
|
||||
# Verify signature before persisting
|
||||
payload = verify_license_signature(license_data)
|
||||
|
||||
# Store in DB
|
||||
upsert_license(db_session, license_data)
|
||||
# Verify the fetched license is for this tenant
|
||||
if payload.tenant_id != tenant_id:
|
||||
logger.error(
|
||||
f"License tenant mismatch: expected {tenant_id}, got {payload.tenant_id}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License tenant ID mismatch - control plane returned wrong license",
|
||||
)
|
||||
|
||||
# Persist to DB and update cache atomically
|
||||
upsert_license(db_session, license_data)
|
||||
try:
|
||||
update_license_cache(payload, source=LicenseSource.AUTO_FETCH)
|
||||
except Exception as cache_error:
|
||||
# Log but don't fail - DB is source of truth, cache will refresh on next read
|
||||
logger.warning(f"Failed to update license cache: {cache_error}")
|
||||
|
||||
logger.info(
|
||||
f"License claimed: seats={payload.seats}, expires={payload.expires_at.date()}"
|
||||
)
|
||||
return LicenseResponse(success=True, license=payload)
|
||||
|
||||
except requests.HTTPError as e:
|
||||
status_code = e.response.status_code if e.response is not None else 502
|
||||
detail = "Failed to claim license"
|
||||
try:
|
||||
error_data = e.response.json() if e.response is not None else {}
|
||||
detail = error_data.get("detail", detail)
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
logger.error(f"Control plane returned error: {status_code}")
|
||||
raise HTTPException(
|
||||
status_code=status_code,
|
||||
detail="Failed to fetch license from control plane",
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"License verification failed: {type(e).__name__}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except requests.RequestException:
|
||||
logger.exception("Failed to fetch license from control plane")
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Failed to connect to license server"
|
||||
status_code=502, detail="Failed to connect to control plane"
|
||||
)
|
||||
|
||||
|
||||
@@ -196,36 +164,33 @@ async def upload_license(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseUploadResponse:
|
||||
"""
|
||||
Upload a license file manually (self-hosted only).
|
||||
|
||||
Used for air-gapped deployments where the cloud data plane is not accessible.
|
||||
The license file must be cryptographically signed by Onyx.
|
||||
Upload a license file manually.
|
||||
Used for air-gapped deployments where control plane is not accessible.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License upload is only available for self-hosted deployments",
|
||||
)
|
||||
|
||||
try:
|
||||
content = await license_file.read()
|
||||
license_data = content.decode("utf-8").strip()
|
||||
except UnicodeDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Invalid license file format")
|
||||
|
||||
# Verify cryptographic signature - this is the only validation needed
|
||||
# The license's tenant_id identifies the customer in control plane, not locally
|
||||
try:
|
||||
payload = verify_license_signature(license_data)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
if payload.tenant_id != tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"License tenant ID mismatch. Expected {tenant_id}, got {payload.tenant_id}",
|
||||
)
|
||||
|
||||
# Persist to DB and update cache
|
||||
upsert_license(db_session, license_data)
|
||||
|
||||
try:
|
||||
update_license_cache(payload, source=LicenseSource.MANUAL_UPLOAD)
|
||||
except Exception as cache_error:
|
||||
# Log but don't fail - DB is source of truth, cache will refresh on next read
|
||||
logger.warning(f"Failed to update license cache: {cache_error}")
|
||||
|
||||
return LicenseUploadResponse(
|
||||
@@ -240,10 +205,8 @@ async def refresh_license_cache_endpoint(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseStatusResponse:
|
||||
"""
|
||||
Force refresh the license cache from the local database.
|
||||
|
||||
Force refresh the license cache from the database.
|
||||
Useful after manual database changes or to verify license validity.
|
||||
Does NOT fetch from control plane - use /claim for that.
|
||||
"""
|
||||
metadata = refresh_license_cache(db_session)
|
||||
|
||||
@@ -270,15 +233,9 @@ async def delete_license(
|
||||
) -> dict[str, bool]:
|
||||
"""
|
||||
Delete the current license.
|
||||
|
||||
Admin only - removes license from database and invalidates cache.
|
||||
Admin only - removes license and invalidates cache.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License deletion is only available for self-hosted deployments",
|
||||
)
|
||||
|
||||
# Invalidate cache first - if DB delete fails, stale cache is worse than no cache
|
||||
try:
|
||||
invalidate_license_cache()
|
||||
except Exception as cache_error:
|
||||
|
||||
@@ -27,7 +27,7 @@ router = APIRouter(prefix="/manage")
|
||||
def create_standard_answer(
|
||||
standard_answer_creation_request: StandardAnswerCreationRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> StandardAnswer:
|
||||
standard_answer_model = insert_standard_answer(
|
||||
keyword=standard_answer_creation_request.keyword,
|
||||
@@ -43,7 +43,7 @@ def create_standard_answer(
|
||||
@router.get("/admin/standard-answer")
|
||||
def list_standard_answers(
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> list[StandardAnswer]:
|
||||
standard_answer_models = fetch_standard_answers(db_session=db_session)
|
||||
return [
|
||||
@@ -57,7 +57,7 @@ def patch_standard_answer(
|
||||
standard_answer_id: int,
|
||||
standard_answer_creation_request: StandardAnswerCreationRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> StandardAnswer:
|
||||
existing_standard_answer = fetch_standard_answer(
|
||||
standard_answer_id=standard_answer_id,
|
||||
@@ -83,7 +83,7 @@ def patch_standard_answer(
|
||||
def delete_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
return remove_standard_answer(
|
||||
standard_answer_id=standard_answer_id,
|
||||
@@ -95,7 +95,7 @@ def delete_standard_answer(
|
||||
def create_standard_answer_category(
|
||||
standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> StandardAnswerCategory:
|
||||
standard_answer_category_model = insert_standard_answer_category(
|
||||
category_name=standard_answer_category_creation_request.name,
|
||||
@@ -107,7 +107,7 @@ def create_standard_answer_category(
|
||||
@router.get("/admin/standard-answer/category")
|
||||
def list_standard_answer_categories(
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> list[StandardAnswerCategory]:
|
||||
standard_answer_category_models = fetch_standard_answer_categories(
|
||||
db_session=db_session
|
||||
@@ -123,7 +123,7 @@ def patch_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> StandardAnswerCategory:
|
||||
existing_standard_answer_category = fetch_standard_answer_category(
|
||||
standard_answer_category_id=standard_answer_category_id,
|
||||
|
||||
@@ -1,42 +1,4 @@
|
||||
"""Middleware to enforce license status for SELF-HOSTED deployments only.
|
||||
|
||||
NOTE: This middleware is NOT used for multi-tenant (cloud) deployments.
|
||||
Multi-tenant gating is handled separately by the control plane via the
|
||||
/tenants/product-gating endpoint and is_tenant_gated() checks.
|
||||
|
||||
IMPORTANT: Mutual Exclusivity with ENTERPRISE_EDITION_ENABLED
|
||||
============================================================
|
||||
This middleware is controlled by LICENSE_ENFORCEMENT_ENABLED env var.
|
||||
It works alongside the legacy ENTERPRISE_EDITION_ENABLED system:
|
||||
|
||||
- LICENSE_ENFORCEMENT_ENABLED=false (default):
|
||||
Middleware is disabled. EE features are controlled solely by
|
||||
ENTERPRISE_EDITION_ENABLED. This preserves legacy behavior.
|
||||
|
||||
- LICENSE_ENFORCEMENT_ENABLED=true:
|
||||
Middleware actively enforces license status. EE features require
|
||||
a valid license, regardless of ENTERPRISE_EDITION_ENABLED.
|
||||
|
||||
Eventually, ENTERPRISE_EDITION_ENABLED will be removed and license
|
||||
enforcement will be the only mechanism for gating EE features.
|
||||
|
||||
License Enforcement States (when enabled)
|
||||
=========================================
|
||||
For self-hosted deployments:
|
||||
|
||||
1. No license (never subscribed):
|
||||
- Allow community features (basic connectors, search, chat)
|
||||
- Block EE-only features (analytics, user groups, etc.)
|
||||
|
||||
2. GATED_ACCESS (fully expired):
|
||||
- Block all routes except billing/auth/license
|
||||
- User must renew subscription to continue
|
||||
|
||||
3. Valid license (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER):
|
||||
- Full access to all EE features
|
||||
- Seat limits enforced
|
||||
- GRACE_PERIOD/PAYMENT_REMINDER are for notifications only, not blocking
|
||||
"""
|
||||
"""Middleware to enforce license status application-wide."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable
|
||||
@@ -47,30 +9,38 @@ from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.configs.license_enforcement_config import EE_ONLY_PATH_PREFIXES
|
||||
from ee.onyx.configs.license_enforcement_config import (
|
||||
LICENSE_ENFORCEMENT_ALLOWED_PREFIXES,
|
||||
)
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from ee.onyx.server.tenants.product_gating import is_tenant_gated
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
# Paths that are ALWAYS accessible, even when license is expired/gated.
|
||||
# These enable users to:
|
||||
# /auth - Log in/out (users can't fix billing if locked out of auth)
|
||||
# /license - Fetch, upload, or check license status
|
||||
# /health - Health checks for load balancers/orchestrators
|
||||
# /me - Basic user info needed for UI rendering
|
||||
# /settings, /enterprise-settings - View app status and branding
|
||||
# /tenants/billing-* - Manage subscription to resolve gating
|
||||
ALLOWED_PATH_PREFIXES = {
|
||||
"/auth",
|
||||
"/license",
|
||||
"/health",
|
||||
"/me",
|
||||
"/settings",
|
||||
"/enterprise-settings",
|
||||
"/tenants/billing-information",
|
||||
"/tenants/create-customer-portal-session",
|
||||
"/tenants/create-subscription-session",
|
||||
}
|
||||
|
||||
|
||||
def _is_path_allowed(path: str) -> bool:
|
||||
"""Check if path is in allowlist (prefix match)."""
|
||||
return any(
|
||||
path.startswith(prefix) for prefix in LICENSE_ENFORCEMENT_ALLOWED_PREFIXES
|
||||
)
|
||||
|
||||
|
||||
def _is_ee_only_path(path: str) -> bool:
|
||||
"""Check if path requires EE license (prefix match)."""
|
||||
return any(path.startswith(prefix) for prefix in EE_ONLY_PATH_PREFIXES)
|
||||
return any(path.startswith(prefix) for prefix in ALLOWED_PATH_PREFIXES)
|
||||
|
||||
|
||||
def add_license_enforcement_middleware(
|
||||
@@ -96,84 +66,29 @@ def add_license_enforcement_middleware(
|
||||
is_gated = False
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
|
||||
# If no cached metadata, check database (cache may have been cleared)
|
||||
if not metadata:
|
||||
logger.debug(
|
||||
"[license_enforcement] No cached license, checking database..."
|
||||
)
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
metadata = refresh_license_cache(db_session, tenant_id)
|
||||
if metadata:
|
||||
logger.info(
|
||||
"[license_enforcement] Loaded license from database"
|
||||
)
|
||||
except SQLAlchemyError as db_error:
|
||||
logger.warning(
|
||||
f"[license_enforcement] Failed to check database for license: {db_error}"
|
||||
)
|
||||
|
||||
if metadata:
|
||||
# User HAS a license (current or expired)
|
||||
if metadata.status == ApplicationStatus.GATED_ACCESS:
|
||||
# License fully expired - gate the user
|
||||
# Note: GRACE_PERIOD and PAYMENT_REMINDER are for notifications only,
|
||||
# they don't block access
|
||||
is_gated = True
|
||||
else:
|
||||
# License is active - check seat limit
|
||||
# used_seats in cache is kept accurate via invalidation
|
||||
# when users are added/removed
|
||||
if metadata.used_seats > metadata.seats:
|
||||
logger.info(
|
||||
f"[license_enforcement] Blocking request: "
|
||||
f"seat limit exceeded ({metadata.used_seats}/{metadata.seats})"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=402,
|
||||
content={
|
||||
"detail": {
|
||||
"error": "seat_limit_exceeded",
|
||||
"message": f"Seat limit exceeded: {metadata.used_seats} of {metadata.seats} seats used.",
|
||||
"used_seats": metadata.used_seats,
|
||||
"seats": metadata.seats,
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
# No license in cache OR database = never subscribed
|
||||
# Allow community features, but block EE-only features
|
||||
if _is_ee_only_path(path):
|
||||
logger.info(
|
||||
f"[license_enforcement] Blocking EE-only path (no license): {path}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=402,
|
||||
content={
|
||||
"detail": {
|
||||
"error": "enterprise_license_required",
|
||||
"message": "This feature requires an Enterprise license. "
|
||||
"Please upgrade to access this functionality.",
|
||||
}
|
||||
},
|
||||
)
|
||||
logger.debug(
|
||||
"[license_enforcement] No license, allowing community features"
|
||||
)
|
||||
if MULTI_TENANT:
|
||||
try:
|
||||
is_gated = is_tenant_gated(tenant_id)
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check tenant gating status: {e}")
|
||||
# Fail open - don't block users due to Redis connectivity issues
|
||||
is_gated = False
|
||||
else:
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if metadata:
|
||||
if metadata.status == ApplicationStatus.GATED_ACCESS:
|
||||
is_gated = True
|
||||
else:
|
||||
# No license metadata = gated for self-hosted EE
|
||||
is_gated = True
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata: {e}")
|
||||
# Fail open - don't block users due to Redis connectivity issues
|
||||
is_gated = False
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata: {e}")
|
||||
# Fail open - don't block users due to Redis connectivity issues
|
||||
is_gated = False
|
||||
|
||||
if is_gated:
|
||||
logger.info(
|
||||
f"[license_enforcement] Blocking request (license expired): {path}"
|
||||
)
|
||||
|
||||
logger.info(f"Blocking request for gated tenant: {tenant_id}, path={path}")
|
||||
return JSONResponse(
|
||||
status_code=402,
|
||||
content={
|
||||
|
||||
@@ -22,7 +22,7 @@ basic_router = APIRouter(prefix="/query")
|
||||
def get_standard_answer(
|
||||
request: StandardAnswerRequest,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User = Depends(current_user),
|
||||
_: User | None = Depends(current_user),
|
||||
) -> StandardAnswerResponse:
|
||||
try:
|
||||
standard_answers = oneoff_standard_answers(
|
||||
|
||||
@@ -37,7 +37,8 @@ router = APIRouter(prefix="/search")
|
||||
@router.post("/search-flow-classification")
|
||||
def search_flow_classification(
|
||||
request: SearchFlowClassificationRequest,
|
||||
_: User = Depends(current_user),
|
||||
# This is added just to ensure this endpoint isn't spammed by non-authorized users since there's an LLM call underneath it
|
||||
_: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SearchFlowClassificationResponse:
|
||||
query = request.user_query
|
||||
@@ -69,7 +70,7 @@ def search_flow_classification(
|
||||
@router.post("/send-search-message", response_model=None)
|
||||
def handle_send_search_message(
|
||||
request: SendSearchQueryRequest,
|
||||
user: User = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse | SearchFullResponse:
|
||||
"""
|
||||
@@ -113,7 +114,7 @@ def handle_send_search_message(
|
||||
def get_search_history(
|
||||
limit: int = 100,
|
||||
filter_days: int | None = None,
|
||||
user: User = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SearchHistoryResponse:
|
||||
"""
|
||||
@@ -145,6 +146,11 @@ def get_search_history(
|
||||
detail="filter_days must be greater than 0",
|
||||
)
|
||||
|
||||
# TODO(yuhong) remove this
|
||||
if user is None:
|
||||
# Return empty list for unauthenticated users
|
||||
return SearchHistoryResponse(search_queries=[])
|
||||
|
||||
search_queries = fetch_search_queries_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
|
||||
@@ -28,9 +28,9 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
|
||||
def _check_token_rate_limits(user: User) -> None:
|
||||
# Anonymous users are only rate limited by global settings
|
||||
if user.is_anonymous:
|
||||
def _check_token_rate_limits(user: User | None) -> None:
|
||||
if user is None:
|
||||
# Unauthenticated users are only rate limited by global settings
|
||||
_user_is_rate_limited_by_global()
|
||||
|
||||
elif is_api_key_email_address(user.email):
|
||||
|
||||
@@ -153,7 +153,7 @@ def snapshot_from_chat_session(
|
||||
@router.get("/admin/chat-sessions")
|
||||
def admin_get_chat_sessions(
|
||||
user_id: UUID,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionsResponse:
|
||||
# we specifically don't allow this endpoint if "anonymized" since
|
||||
@@ -196,7 +196,7 @@ def get_chat_session_history(
|
||||
feedback_type: QAFeedbackType | None = None,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedReturn[ChatSessionMinimal]:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
@@ -234,7 +234,7 @@ def get_chat_session_history(
|
||||
@router.get("/admin/chat-session-history/{chat_session_id}")
|
||||
def get_chat_session_admin(
|
||||
chat_session_id: UUID,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionSnapshot:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
@@ -269,7 +269,7 @@ def get_chat_session_admin(
|
||||
|
||||
@router.get("/admin/query-history/list")
|
||||
def list_all_query_history_exports(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[QueryHistoryExport]:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
@@ -297,7 +297,7 @@ def list_all_query_history_exports(
|
||||
|
||||
@router.post("/admin/query-history/start-export", tags=PUBLIC_API_TAGS)
|
||||
def start_query_history_export(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
@@ -344,7 +344,7 @@ def start_query_history_export(
|
||||
@router.get("/admin/query-history/export-status", tags=PUBLIC_API_TAGS)
|
||||
def get_query_history_export_status(
|
||||
request_id: str,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
@@ -378,7 +378,7 @@ def get_query_history_export_status(
|
||||
@router.get("/admin/query-history/download", tags=PUBLIC_API_TAGS)
|
||||
def download_query_history_csv(
|
||||
request_id: str,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED])
|
||||
|
||||
@@ -58,7 +58,7 @@ def generate_report(
|
||||
@router.get("/admin/usage-report/{report_name}")
|
||||
def read_usage_report(
|
||||
report_name: str,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
try:
|
||||
@@ -82,7 +82,7 @@ def read_usage_report(
|
||||
|
||||
@router.get("/admin/usage-report")
|
||||
def fetch_usage_reports(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UsageReportMetadata]:
|
||||
try:
|
||||
|
||||
@@ -123,14 +123,9 @@ def _seed_llms(
|
||||
upsert_llm_provider(llm_upsert_request, db_session)
|
||||
for llm_upsert_request in llm_upsert_requests
|
||||
]
|
||||
|
||||
if len(seeded_providers[0].model_configurations) > 0:
|
||||
default_model = seeded_providers[0].model_configurations[0].name
|
||||
update_default_provider(
|
||||
provider_id=seeded_providers[0].id,
|
||||
model_name=default_model,
|
||||
db_session=db_session,
|
||||
)
|
||||
update_default_provider(
|
||||
provider_id=seeded_providers[0].id, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
|
||||
|
||||
@@ -12,51 +12,21 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Only GATED_ACCESS actually blocks access - other statuses are for notifications
|
||||
_BLOCKING_STATUS = ApplicationStatus.GATED_ACCESS
|
||||
|
||||
|
||||
def check_ee_features_enabled() -> bool:
|
||||
"""EE version: checks if EE features should be available.
|
||||
|
||||
Returns True if:
|
||||
- LICENSE_ENFORCEMENT_ENABLED is False (legacy/rollout mode)
|
||||
- Cloud mode (MULTI_TENANT) - cloud handles its own gating
|
||||
- Self-hosted with a valid (non-expired) license
|
||||
|
||||
Returns False if:
|
||||
- Self-hosted with no license (never subscribed)
|
||||
- Self-hosted with expired license
|
||||
"""
|
||||
if not LICENSE_ENFORCEMENT_ENABLED:
|
||||
# License enforcement disabled - allow EE features (legacy behavior)
|
||||
return True
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Cloud mode - EE features always available (gating handled by is_tenant_gated)
|
||||
return True
|
||||
|
||||
# Self-hosted with enforcement - check for valid license
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if metadata and metadata.status != _BLOCKING_STATUS:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
return True
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license for EE features: {e}")
|
||||
# Fail closed - if Redis is down, other things will break anyway
|
||||
return False
|
||||
|
||||
# No license or GATED_ACCESS - no EE features
|
||||
return False
|
||||
# Statuses that indicate a billing/license problem - propagate these to settings
|
||||
_GATED_STATUSES = frozenset(
|
||||
{
|
||||
ApplicationStatus.GATED_ACCESS,
|
||||
ApplicationStatus.GRACE_PERIOD,
|
||||
ApplicationStatus.PAYMENT_REMINDER,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
"""EE version: checks license status for self-hosted deployments.
|
||||
|
||||
For self-hosted, looks up license metadata and overrides application_status
|
||||
if the license indicates GATED_ACCESS (fully expired).
|
||||
if the license is missing or indicates a problem (expired, grace period, etc.).
|
||||
|
||||
For multi-tenant (cloud), the settings already have the correct status
|
||||
from the control plane, so no override is needed.
|
||||
@@ -73,10 +43,11 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if metadata and metadata.status == _BLOCKING_STATUS:
|
||||
if metadata and metadata.status in _GATED_STATUSES:
|
||||
settings.application_status = metadata.status
|
||||
# No license = user hasn't purchased yet, allow access for upgrade flow
|
||||
# GRACE_PERIOD/PAYMENT_REMINDER don't block - they're for notifications
|
||||
elif not metadata:
|
||||
# No license = gated access for self-hosted EE
|
||||
settings.application_status = ApplicationStatus.GATED_ACCESS
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata for settings: {e}")
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_p
|
||||
from ee.onyx.server.tenants.models import AnonymousUserPath
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
@@ -28,7 +29,7 @@ router = APIRouter(prefix="/tenants")
|
||||
|
||||
@router.get("/anonymous-user-path")
|
||||
async def get_anonymous_user_path_api(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> AnonymousUserPath:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
@@ -44,7 +45,7 @@ async def get_anonymous_user_path_api(
|
||||
@router.post("/anonymous-user-path")
|
||||
async def set_anonymous_user_path_api(
|
||||
anonymous_user_path: str,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
@@ -71,6 +72,7 @@ async def set_anonymous_user_path_api(
|
||||
@router.post("/anonymous-user")
|
||||
async def login_as_anonymous_user(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(optional_user),
|
||||
) -> Response:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
tenant_id = get_tenant_id_for_anonymous_user_path(
|
||||
|
||||
@@ -3,7 +3,6 @@ from fastapi import APIRouter
|
||||
from ee.onyx.server.tenants.admin_api import router as admin_router
|
||||
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
|
||||
from ee.onyx.server.tenants.billing_api import router as billing_router
|
||||
from ee.onyx.server.tenants.proxy import router as proxy_router
|
||||
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
|
||||
from ee.onyx.server.tenants.tenant_management_api import (
|
||||
router as tenant_management_router,
|
||||
@@ -23,4 +22,3 @@ router.include_router(billing_router)
|
||||
router.include_router(team_membership_router)
|
||||
router.include_router(tenant_management_router)
|
||||
router.include_router(user_invitations_router)
|
||||
router.include_router(proxy_router)
|
||||
|
||||
@@ -1,21 +1,3 @@
|
||||
"""Billing API endpoints for cloud multi-tenant deployments.
|
||||
|
||||
DEPRECATED: These /tenants/* billing endpoints are being replaced by /admin/billing/*
|
||||
which provides a unified API for both self-hosted and cloud deployments.
|
||||
|
||||
TODO(ENG-3533): Migrate frontend to use /admin/billing/* endpoints and remove this file.
|
||||
https://linear.app/onyx-app/issue/ENG-3533/migrate-tenantsbilling-adminbilling
|
||||
|
||||
Current endpoints to migrate:
|
||||
- GET /tenants/billing-information -> GET /admin/billing/information
|
||||
- POST /tenants/create-customer-portal-session -> POST /admin/billing/portal-session
|
||||
- POST /tenants/create-subscription-session -> POST /admin/billing/checkout-session
|
||||
- GET /tenants/stripe-publishable-key -> (keep as-is, shared endpoint)
|
||||
|
||||
Note: /tenants/product-gating/* endpoints are control-plane-to-data-plane calls
|
||||
and are NOT part of this migration - they stay here.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
@@ -108,7 +90,11 @@ async def billing_information(
|
||||
async def create_customer_portal_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> dict:
|
||||
"""Create a Stripe customer portal session via the control plane."""
|
||||
"""
|
||||
Create a Stripe customer portal session via the control plane.
|
||||
NOTE: This is currently only used for multi-tenant (cloud) deployments.
|
||||
Self-hosted proxy endpoints will be added in a future phase.
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
return_url = f"{WEB_DOMAIN}/admin/billing"
|
||||
|
||||
|
||||
@@ -300,12 +300,12 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
|
||||
has_set_default_provider = False
|
||||
|
||||
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
|
||||
def _upsert(request: LLMProviderUpsertRequest) -> None:
|
||||
nonlocal has_set_default_provider
|
||||
try:
|
||||
provider = upsert_llm_provider(request, db_session)
|
||||
if not has_set_default_provider:
|
||||
update_default_provider(provider.id, default_model, db_session)
|
||||
update_default_provider(provider.id, db_session)
|
||||
has_set_default_provider = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure {request.provider} provider: {e}")
|
||||
@@ -323,13 +323,14 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
OPENAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openai_provider, default_model_name)
|
||||
_upsert(openai_provider)
|
||||
|
||||
# Create default image generation config using the OpenAI API key
|
||||
try:
|
||||
@@ -358,13 +359,14 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
ANTHROPIC_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(anthropic_provider, default_model_name)
|
||||
_upsert(anthropic_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
@@ -389,13 +391,14 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Google Vertex AI",
|
||||
provider=VERTEXAI_PROVIDER_NAME,
|
||||
custom_config=custom_config,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
VERTEXAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(vertexai_provider, default_model_name)
|
||||
_upsert(vertexai_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
|
||||
@@ -427,11 +430,12 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="OpenRouter",
|
||||
provider=OPENROUTER_PROVIDER_NAME,
|
||||
api_key=OPENROUTER_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=model_configurations,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openrouter_provider, default_model_name)
|
||||
_upsert(openrouter_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"
|
||||
|
||||
@@ -1,455 +0,0 @@
|
||||
"""Proxy endpoints for billing operations.
|
||||
|
||||
These endpoints run on the CLOUD DATA PLANE (cloud.onyx.app) and serve as a proxy
|
||||
for self-hosted instances to reach the control plane.
|
||||
|
||||
Flow:
|
||||
Self-hosted backend → Cloud DP /proxy/* (license auth) → Control plane (JWT auth)
|
||||
|
||||
Self-hosted instances call these endpoints with their license in the Authorization
|
||||
header. The cloud data plane validates the license signature and forwards the
|
||||
request to the control plane using JWT authentication.
|
||||
|
||||
Auth levels by endpoint:
|
||||
- /create-checkout-session: No auth (new customer) or expired license OK (renewal)
|
||||
- /claim-license: Session ID based (one-time after Stripe payment)
|
||||
- /create-customer-portal-session: Expired license OK (need portal to fix payment)
|
||||
- /billing-information: Valid license required
|
||||
- /license/{tenant_id}: Valid license required
|
||||
- /seats/update: Valid license required
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Header
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.server.billing.models import SeatUpdateRequest
|
||||
from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.utils.license import is_license_valid
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/proxy")
|
||||
|
||||
|
||||
def _check_license_enforcement_enabled() -> None:
|
||||
"""Ensure LICENSE_ENFORCEMENT_ENABLED is true (proxy endpoints only work on cloud DP)."""
|
||||
if not LICENSE_ENFORCEMENT_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail="Proxy endpoints are only available on cloud data plane",
|
||||
)
|
||||
|
||||
|
||||
def _extract_license_from_header(
|
||||
authorization: str | None,
|
||||
required: bool = True,
|
||||
) -> str | None:
|
||||
"""Extract license data from Authorization header.
|
||||
|
||||
Self-hosted instances authenticate to these proxy endpoints by sending their
|
||||
license as a Bearer token: `Authorization: Bearer <base64-encoded-license>`.
|
||||
|
||||
We use the Bearer scheme (RFC 6750) because:
|
||||
1. It's the standard HTTP auth scheme for token-based authentication
|
||||
2. The license blob is cryptographically signed (RSA), so it's self-validating
|
||||
3. No other auth schemes (Basic, Digest, etc.) are supported for license auth
|
||||
|
||||
The license data is the base64-encoded signed blob that contains tenant_id,
|
||||
seats, expiration, etc. We verify the signature to authenticate the caller.
|
||||
|
||||
Args:
|
||||
authorization: The Authorization header value (e.g., "Bearer <license>")
|
||||
required: If True, raise 401 when header is missing/invalid
|
||||
|
||||
Returns:
|
||||
License data string (base64-encoded), or None if not required and missing
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if required and header is missing/invalid
|
||||
"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
if required:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Missing or invalid authorization header"
|
||||
)
|
||||
return None
|
||||
|
||||
return authorization.split(" ", 1)[1]
|
||||
|
||||
|
||||
def verify_license_auth(
|
||||
license_data: str,
|
||||
allow_expired: bool = False,
|
||||
) -> LicensePayload:
|
||||
"""Verify license signature and optionally check expiry.
|
||||
|
||||
Args:
|
||||
license_data: Base64-encoded signed license blob
|
||||
allow_expired: If True, accept expired licenses (for renewal flows)
|
||||
|
||||
Returns:
|
||||
LicensePayload if valid
|
||||
|
||||
Raises:
|
||||
HTTPException: If license is invalid or expired (when not allowed)
|
||||
"""
|
||||
_check_license_enforcement_enabled()
|
||||
|
||||
try:
|
||||
payload = verify_license_signature(license_data)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=f"Invalid license: {e}")
|
||||
|
||||
if not allow_expired and not is_license_valid(payload):
|
||||
raise HTTPException(status_code=401, detail="License has expired")
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
async def get_license_payload(
|
||||
authorization: str | None = Header(None, alias="Authorization"),
|
||||
) -> LicensePayload:
|
||||
"""Dependency: Require valid (non-expired) license.
|
||||
|
||||
Used for endpoints that require an active subscription.
|
||||
"""
|
||||
license_data = _extract_license_from_header(authorization, required=True)
|
||||
# license_data is guaranteed non-None when required=True
|
||||
assert license_data is not None
|
||||
return verify_license_auth(license_data, allow_expired=False)
|
||||
|
||||
|
||||
async def get_license_payload_allow_expired(
|
||||
authorization: str | None = Header(None, alias="Authorization"),
|
||||
) -> LicensePayload:
|
||||
"""Dependency: Require license with valid signature, expired OK.
|
||||
|
||||
Used for endpoints needed to fix payment issues (portal, renewal checkout).
|
||||
"""
|
||||
license_data = _extract_license_from_header(authorization, required=True)
|
||||
# license_data is guaranteed non-None when required=True
|
||||
assert license_data is not None
|
||||
return verify_license_auth(license_data, allow_expired=True)
|
||||
|
||||
|
||||
async def get_optional_license_payload(
|
||||
authorization: str | None = Header(None, alias="Authorization"),
|
||||
) -> LicensePayload | None:
|
||||
"""Dependency: Optional license auth (for checkout - new customers have none).
|
||||
|
||||
Returns None if no license provided, otherwise validates and returns payload.
|
||||
Expired licenses are allowed for renewal flows.
|
||||
"""
|
||||
_check_license_enforcement_enabled()
|
||||
|
||||
license_data = _extract_license_from_header(authorization, required=False)
|
||||
if license_data is None:
|
||||
return None
|
||||
|
||||
return verify_license_auth(license_data, allow_expired=True)
|
||||
|
||||
|
||||
async def forward_to_control_plane(
|
||||
method: str,
|
||||
path: str,
|
||||
body: dict | None = None,
|
||||
params: dict | None = None,
|
||||
) -> dict:
|
||||
"""Forward a request to the control plane with proper authentication."""
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}{path}"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
if method == "GET":
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
elif method == "POST":
|
||||
response = await client.post(url, headers=headers, json=body)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
status_code = e.response.status_code
|
||||
detail = "Control plane request failed"
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
detail = error_data.get("detail", detail)
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(f"Control plane returned {status_code}: {detail}")
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
except httpx.RequestError:
|
||||
logger.exception("Failed to connect to control plane")
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Failed to connect to control plane"
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CreateCheckoutSessionRequest(BaseModel):
|
||||
billing_period: Literal["monthly", "annual"] = "monthly"
|
||||
seats: int | None = None
|
||||
email: str | None = None
|
||||
# Redirect URL after successful checkout - self-hosted passes their instance URL
|
||||
redirect_url: str | None = None
|
||||
# Cancel URL when user exits checkout - returns to upgrade page
|
||||
cancel_url: str | None = None
|
||||
|
||||
|
||||
class CreateCheckoutSessionResponse(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
@router.post("/create-checkout-session")
|
||||
async def proxy_create_checkout_session(
|
||||
request_body: CreateCheckoutSessionRequest,
|
||||
license_payload: LicensePayload | None = Depends(get_optional_license_payload),
|
||||
) -> CreateCheckoutSessionResponse:
|
||||
"""Proxy checkout session creation to control plane.
|
||||
|
||||
Auth: Optional license (new customers don't have one yet).
|
||||
If license provided, expired is OK (for renewals).
|
||||
"""
|
||||
# license_payload is None for new customers who don't have a license yet.
|
||||
# In that case, tenant_id is omitted from the request body and the control
|
||||
# plane will create a new tenant during checkout completion.
|
||||
tenant_id = license_payload.tenant_id if license_payload else None
|
||||
|
||||
body: dict = {
|
||||
"billing_period": request_body.billing_period,
|
||||
}
|
||||
if tenant_id:
|
||||
body["tenant_id"] = tenant_id
|
||||
if request_body.seats is not None:
|
||||
body["seats"] = request_body.seats
|
||||
if request_body.email:
|
||||
body["email"] = request_body.email
|
||||
if request_body.redirect_url:
|
||||
body["redirect_url"] = request_body.redirect_url
|
||||
if request_body.cancel_url:
|
||||
body["cancel_url"] = request_body.cancel_url
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST", "/create-checkout-session", body=body
|
||||
)
|
||||
return CreateCheckoutSessionResponse(url=result["url"])
|
||||
|
||||
|
||||
class ClaimLicenseRequest(BaseModel):
|
||||
session_id: str
|
||||
|
||||
|
||||
class ClaimLicenseResponse(BaseModel):
|
||||
tenant_id: str
|
||||
license: str
|
||||
message: str | None = None
|
||||
|
||||
|
||||
@router.post("/claim-license")
|
||||
async def proxy_claim_license(
|
||||
request_body: ClaimLicenseRequest,
|
||||
) -> ClaimLicenseResponse:
|
||||
"""Claim a license after successful Stripe checkout.
|
||||
|
||||
Auth: Session ID based (one-time use after payment).
|
||||
The control plane verifies the session_id is valid and unclaimed.
|
||||
|
||||
Returns the license to the caller. For self-hosted instances, they will
|
||||
store the license locally. The cloud DP doesn't need to store it.
|
||||
"""
|
||||
_check_license_enforcement_enabled()
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST",
|
||||
"/claim-license",
|
||||
body={"session_id": request_body.session_id},
|
||||
)
|
||||
|
||||
tenant_id = result.get("tenant_id")
|
||||
license_data = result.get("license")
|
||||
|
||||
if not tenant_id or not license_data:
|
||||
logger.error(f"Control plane returned incomplete claim response: {result}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Control plane returned incomplete license data",
|
||||
)
|
||||
|
||||
return ClaimLicenseResponse(
|
||||
tenant_id=tenant_id,
|
||||
license=license_data,
|
||||
message="License claimed successfully",
|
||||
)
|
||||
|
||||
|
||||
class CreateCustomerPortalSessionRequest(BaseModel):
|
||||
return_url: str | None = None
|
||||
|
||||
|
||||
class CreateCustomerPortalSessionResponse(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def proxy_create_customer_portal_session(
|
||||
request_body: CreateCustomerPortalSessionRequest | None = None,
|
||||
license_payload: LicensePayload = Depends(get_license_payload_allow_expired),
|
||||
) -> CreateCustomerPortalSessionResponse:
|
||||
"""Proxy customer portal session creation to control plane.
|
||||
|
||||
Auth: License required, expired OK (need portal to fix payment issues).
|
||||
"""
|
||||
# tenant_id is a required field in LicensePayload (Pydantic validates this),
|
||||
# but we check explicitly for defense in depth
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
tenant_id = license_payload.tenant_id
|
||||
|
||||
body: dict = {"tenant_id": tenant_id}
|
||||
if request_body and request_body.return_url:
|
||||
body["return_url"] = request_body.return_url
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST", "/create-customer-portal-session", body=body
|
||||
)
|
||||
return CreateCustomerPortalSessionResponse(url=result["url"])
|
||||
|
||||
|
||||
class BillingInformationResponse(BaseModel):
|
||||
tenant_id: str
|
||||
status: str | None = None
|
||||
plan_type: str | None = None
|
||||
seats: int | None = None
|
||||
billing_period: str | None = None
|
||||
current_period_start: str | None = None
|
||||
current_period_end: str | None = None
|
||||
cancel_at_period_end: bool = False
|
||||
canceled_at: str | None = None
|
||||
trial_start: str | None = None
|
||||
trial_end: str | None = None
|
||||
payment_method_enabled: bool = False
|
||||
stripe_subscription_id: str | None = None
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def proxy_billing_information(
|
||||
license_payload: LicensePayload = Depends(get_license_payload),
|
||||
) -> BillingInformationResponse:
|
||||
"""Proxy billing information request to control plane.
|
||||
|
||||
Auth: Valid (non-expired) license required.
|
||||
"""
|
||||
# tenant_id is a required field in LicensePayload (Pydantic validates this),
|
||||
# but we check explicitly for defense in depth
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
tenant_id = license_payload.tenant_id
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"GET", "/billing-information", params={"tenant_id": tenant_id}
|
||||
)
|
||||
# Add tenant_id from license if not in response (control plane may not include it)
|
||||
if "tenant_id" not in result:
|
||||
result["tenant_id"] = tenant_id
|
||||
return BillingInformationResponse(**result)
|
||||
|
||||
|
||||
class LicenseFetchResponse(BaseModel):
|
||||
license: str
|
||||
tenant_id: str
|
||||
|
||||
|
||||
@router.get("/license/{tenant_id}")
|
||||
async def proxy_license_fetch(
|
||||
tenant_id: str,
|
||||
license_payload: LicensePayload = Depends(get_license_payload),
|
||||
) -> LicenseFetchResponse:
|
||||
"""Proxy license fetch to control plane.
|
||||
|
||||
Auth: Valid license required.
|
||||
The tenant_id in path must match the authenticated tenant.
|
||||
"""
|
||||
# tenant_id is a required field in LicensePayload (Pydantic validates this),
|
||||
# but we check explicitly for defense in depth
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
if tenant_id != license_payload.tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Cannot fetch license for a different tenant",
|
||||
)
|
||||
|
||||
result = await forward_to_control_plane("GET", f"/license/{tenant_id}")
|
||||
|
||||
license_data = result.get("license")
|
||||
if not license_data:
|
||||
logger.error(f"Control plane returned incomplete license response: {result}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Control plane returned incomplete license data",
|
||||
)
|
||||
|
||||
# Return license to caller - self-hosted instance stores it via /api/license/claim
|
||||
return LicenseFetchResponse(license=license_data, tenant_id=tenant_id)
|
||||
|
||||
|
||||
@router.post("/seats/update")
|
||||
async def proxy_seat_update(
|
||||
request_body: SeatUpdateRequest,
|
||||
license_payload: LicensePayload = Depends(get_license_payload),
|
||||
) -> SeatUpdateResponse:
|
||||
"""Proxy seat update to control plane.
|
||||
|
||||
Auth: Valid (non-expired) license required.
|
||||
Handles Stripe proration and license regeneration.
|
||||
Returns the regenerated license in the response for the caller to store.
|
||||
"""
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
tenant_id = license_payload.tenant_id
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST",
|
||||
"/seats/update",
|
||||
body={
|
||||
"tenant_id": tenant_id,
|
||||
"new_seat_count": request_body.new_seat_count,
|
||||
},
|
||||
)
|
||||
|
||||
# Return license in response - self-hosted instance stores it via /api/license/claim
|
||||
return SeatUpdateResponse(
|
||||
success=result.get("success", False),
|
||||
current_seats=result.get("current_seats", 0),
|
||||
used_seats=result.get("used_seats", 0),
|
||||
message=result.get("message"),
|
||||
license=result.get("license"),
|
||||
)
|
||||
@@ -24,12 +24,12 @@ router = APIRouter(prefix="/tenants")
|
||||
@router.post("/leave-team")
|
||||
async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User = Depends(current_admin_user),
|
||||
current_user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if current_user.email != user_email.user_email:
|
||||
if current_user is None or current_user.email != user_email.user_email:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You can only leave the organization as yourself"
|
||||
)
|
||||
|
||||
@@ -26,8 +26,10 @@ FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
|
||||
|
||||
@router.get("/existing-team-by-domain")
|
||||
def get_existing_tenant_by_domain(
|
||||
user: User = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
) -> TenantByDomainResponse | None:
|
||||
if not user:
|
||||
return None
|
||||
domain = user.email.split("@")[1]
|
||||
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):
|
||||
return None
|
||||
|
||||
@@ -24,8 +24,10 @@ router = APIRouter(prefix="/tenants")
|
||||
@router.post("/users/invite/request")
|
||||
async def request_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="User not authenticated")
|
||||
try:
|
||||
invite_self_to_tenant(user.email, invite_request.tenant_id)
|
||||
except Exception as e:
|
||||
@@ -37,7 +39,7 @@ async def request_invite(
|
||||
|
||||
@router.get("/users/pending")
|
||||
def list_pending_users(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> list[PendingUserSnapshot]:
|
||||
pending_emails = get_pending_users()
|
||||
return [PendingUserSnapshot(email=email) for email in pending_emails]
|
||||
@@ -46,7 +48,7 @@ def list_pending_users(
|
||||
@router.post("/users/invite/approve")
|
||||
async def approve_user(
|
||||
approve_user_request: ApproveUserRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
approve_user_invite(approve_user_request.email, tenant_id)
|
||||
@@ -55,11 +57,14 @@ async def approve_user(
|
||||
@router.post("/users/invite/accept")
|
||||
async def accept_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
) -> None:
|
||||
"""
|
||||
Accept an invitation to join a tenant.
|
||||
"""
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
accept_user_invite(user.email, invite_request.tenant_id)
|
||||
except Exception as e:
|
||||
@@ -70,11 +75,14 @@ async def accept_invite(
|
||||
@router.post("/users/invite/deny")
|
||||
async def deny_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User = Depends(current_user),
|
||||
user: User | None = Depends(current_user),
|
||||
) -> None:
|
||||
"""
|
||||
Deny an invitation to join a tenant.
|
||||
"""
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
deny_user_invite(user.email, invite_request.tenant_id)
|
||||
except Exception as e:
|
||||
|
||||
@@ -70,60 +70,45 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
"""
|
||||
Add users to a tenant with proper transaction handling.
|
||||
Checks if users already have a tenant mapping to avoid duplicates.
|
||||
|
||||
If a user already has an active mapping to a different tenant, they receive
|
||||
an inactive mapping (invitation) to this tenant. They can accept the
|
||||
invitation later to switch tenants.
|
||||
|
||||
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
|
||||
"""
|
||||
unique_emails = set(emails)
|
||||
if not unique_emails:
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
# Start a transaction
|
||||
db_session.begin()
|
||||
|
||||
# Batch query 1: Get all existing mappings for these emails to this tenant
|
||||
# Lock rows to prevent concurrent modifications
|
||||
existing_mappings = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email.in_(unique_emails),
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.with_for_update()
|
||||
.all()
|
||||
)
|
||||
emails_with_mapping = {m.email for m in existing_mappings}
|
||||
|
||||
# Batch query 2: Get all active mappings for these emails (any tenant)
|
||||
active_mappings = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email.in_(unique_emails),
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.all()
|
||||
)
|
||||
emails_with_active_mapping = {m.email for m in active_mappings}
|
||||
|
||||
# Add mappings for emails that don't already have one to this tenant
|
||||
for email in unique_emails:
|
||||
if email in emails_with_mapping:
|
||||
continue
|
||||
|
||||
# Create mapping: inactive if user belongs to another tenant (invitation),
|
||||
# active otherwise
|
||||
db_session.add(
|
||||
UserTenantMapping(
|
||||
email=email,
|
||||
tenant_id=tenant_id,
|
||||
active=email not in emails_with_active_mapping,
|
||||
for email in emails:
|
||||
# Check if the user already has a mapping to this tenant
|
||||
existing_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
|
||||
# If user already has an active mapping, add this one as inactive
|
||||
if not existing_mapping:
|
||||
# Check if the user already has an active mapping to any tenant
|
||||
has_active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
db_session.add(
|
||||
UserTenantMapping(
|
||||
email=email,
|
||||
tenant_id=tenant_id,
|
||||
active=False if has_active_mapping else True,
|
||||
)
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
db_session.commit()
|
||||
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
|
||||
@@ -213,15 +198,13 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
# Lock the user's mappings first to prevent race conditions.
|
||||
# This ensures no concurrent request can modify this user's mappings.
|
||||
# First check if there's an active mapping for this user and tenant
|
||||
active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -314,41 +297,16 @@ def deny_user_invite(email: str, tenant_id: str) -> None:
|
||||
|
||||
def get_tenant_count(tenant_id: str) -> int:
|
||||
"""
|
||||
Get the number of active users for this tenant.
|
||||
|
||||
A user counts toward the seat count if:
|
||||
1. They have an active mapping to this tenant (UserTenantMapping.active == True)
|
||||
2. AND the User is active (User.is_active == True)
|
||||
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
Get the number of active users for this tenant
|
||||
"""
|
||||
from onyx.db.models import User
|
||||
|
||||
# First get all emails with active mappings to this tenant
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
active_mapping_emails = (
|
||||
db_session.query(UserTenantMapping.email)
|
||||
# Count the number of active users for this tenant
|
||||
user_count = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.all()
|
||||
)
|
||||
emails = [email for (email,) in active_mapping_emails]
|
||||
|
||||
if not emails:
|
||||
return 0
|
||||
|
||||
# Now count how many of those users are actually active in the tenant's User table
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
user_count = (
|
||||
db_session.query(User)
|
||||
.filter(
|
||||
User.email.in_(emails), # type: ignore
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ Group Token Limit Settings
|
||||
|
||||
@router.get("/user-groups")
|
||||
def get_all_group_token_limit_settings(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, list[TokenRateLimitDisplay]]:
|
||||
user_groups_to_token_rate_limits = fetch_all_user_group_token_rate_limits_by_group(
|
||||
@@ -47,7 +47,7 @@ def get_all_group_token_limit_settings(
|
||||
@router.get("/user-group/{group_id}")
|
||||
def get_group_token_limit_settings(
|
||||
group_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[TokenRateLimitDisplay]:
|
||||
return [
|
||||
@@ -64,7 +64,7 @@ def get_group_token_limit_settings(
|
||||
def create_group_token_limit_settings(
|
||||
group_id: int,
|
||||
token_limit_settings: TokenRateLimitArgs,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> TokenRateLimitDisplay:
|
||||
rate_limit_display = TokenRateLimitDisplay.from_db(
|
||||
@@ -86,7 +86,7 @@ User Token Limit Settings
|
||||
|
||||
@router.get("/users")
|
||||
def get_user_token_limit_settings(
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[TokenRateLimitDisplay]:
|
||||
return [
|
||||
@@ -98,7 +98,7 @@ def get_user_token_limit_settings(
|
||||
@router.post("/users")
|
||||
def create_user_token_limit_settings(
|
||||
token_limit_settings: TokenRateLimitArgs,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> TokenRateLimitDisplay:
|
||||
rate_limit_display = TokenRateLimitDisplay.from_db(
|
||||
|
||||
@@ -31,10 +31,10 @@ router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@router.get("/admin/user-group")
|
||||
def list_user_groups(
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserGroup]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
if user is None or user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
@@ -48,7 +48,7 @@ def list_user_groups(
|
||||
@router.post("/admin/user-group")
|
||||
def create_user_group(
|
||||
user_group: UserGroupCreate,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
try:
|
||||
@@ -66,7 +66,7 @@ def create_user_group(
|
||||
def patch_user_group(
|
||||
user_group_id: int,
|
||||
user_group_update: UserGroupUpdate,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
try:
|
||||
@@ -86,7 +86,7 @@ def patch_user_group(
|
||||
def add_users(
|
||||
user_group_id: int,
|
||||
add_users_request: AddUsersToUserGroupRequest,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
try:
|
||||
@@ -106,7 +106,7 @@ def add_users(
|
||||
def set_user_curator(
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
@@ -124,7 +124,7 @@ def set_user_curator(
|
||||
@router.delete("/admin/user-group/{user_group_id}")
|
||||
def delete_user_group(
|
||||
user_group_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
|
||||
@@ -6,7 +6,6 @@ from posthog import Posthog
|
||||
|
||||
from ee.onyx.configs.app_configs import MARKETING_POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_DEBUG_LOGS_ENABLED
|
||||
from ee.onyx.configs.app_configs import POSTHOG_HOST
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -21,7 +20,7 @@ def posthog_on_error(error: Any, items: Any) -> None:
|
||||
posthog = Posthog(
|
||||
project_api_key=POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=POSTHOG_DEBUG_LOGS_ENABLED,
|
||||
debug=True,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
|
||||
@@ -34,7 +33,7 @@ if MARKETING_POSTHOG_API_KEY:
|
||||
marketing_posthog = Posthog(
|
||||
project_api_key=MARKETING_POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=POSTHOG_DEBUG_LOGS_ENABLED,
|
||||
debug=True,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
|
||||
|
||||
BIN
backend/hello-vmlinux.bin
Normal file
BIN
backend/hello-vmlinux.bin
Normal file
Binary file not shown.
@@ -96,20 +96,22 @@ def get_access_for_documents(
|
||||
return versioned_get_access_for_documents_fn(document_ids, db_session)
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
user should have access to a document if at least one entry in the document's ACL
|
||||
matches one entry in the returned set.
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to.
|
||||
|
||||
Anonymous users only have access to public documents.
|
||||
This is meant to be used downstream to filter out documents that the user
|
||||
does not have access to. The user should have access to a document if at
|
||||
least one entry in the document's ACL matches one entry in the returned set.
|
||||
|
||||
NOTE: These strings must be formatted in the same way as the output of
|
||||
DocumentAccess::to_acl.
|
||||
"""
|
||||
if user.is_anonymous:
|
||||
return {PUBLIC_DOC_PAT}
|
||||
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
|
||||
if user:
|
||||
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
|
||||
return {PUBLIC_DOC_PAT}
|
||||
|
||||
|
||||
def get_acl_for_user(user: User, db_session: Session | None = None) -> set[str]:
|
||||
def get_acl_for_user(user: User | None, db_session: Session | None = None) -> set[str]:
|
||||
versioned_acl_for_user_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "_get_acl_for_user"
|
||||
)
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
def _get_user_external_group_ids(db_session: Session, user: User) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
def get_user_external_group_ids(db_session: Session, user: User) -> list[str]:
|
||||
versioned_get_user_external_group_ids = fetch_versioned_implementation(
|
||||
"onyx.access.hierarchy_access", "_get_user_external_group_ids"
|
||||
)
|
||||
return versioned_get_user_external_group_ids(db_session, user)
|
||||
@@ -105,54 +105,6 @@ class DocExternalAccess:
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NodeExternalAccess:
|
||||
"""
|
||||
Wraps external access with a hierarchy node's raw ID.
|
||||
Used for syncing hierarchy node permissions (e.g., folder permissions).
|
||||
"""
|
||||
|
||||
external_access: ExternalAccess
|
||||
# The raw node ID from the source system (e.g., Google Drive folder ID)
|
||||
raw_node_id: str
|
||||
# The source type (e.g., "google_drive")
|
||||
source: str
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"external_access": {
|
||||
"external_user_emails": list(self.external_access.external_user_emails),
|
||||
"external_user_group_ids": list(
|
||||
self.external_access.external_user_group_ids
|
||||
),
|
||||
"is_public": self.external_access.is_public,
|
||||
},
|
||||
"raw_node_id": self.raw_node_id,
|
||||
"source": self.source,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "NodeExternalAccess":
|
||||
external_access = ExternalAccess(
|
||||
external_user_emails=set(
|
||||
data["external_access"].get("external_user_emails", [])
|
||||
),
|
||||
external_user_group_ids=set(
|
||||
data["external_access"].get("external_user_group_ids", [])
|
||||
),
|
||||
is_public=data["external_access"]["is_public"],
|
||||
)
|
||||
return cls(
|
||||
external_access=external_access,
|
||||
raw_node_id=data["raw_node_id"],
|
||||
source=data["source"],
|
||||
)
|
||||
|
||||
|
||||
# Union type for elements that can have permissions synced
|
||||
ElementExternalAccess = DocExternalAccess | NodeExternalAccess
|
||||
|
||||
|
||||
# TODO(andrei): First refactor this into a pydantic model, then get rid of
|
||||
# duplicate fields.
|
||||
@dataclass(frozen=True, init=False)
|
||||
|
||||
@@ -56,7 +56,6 @@ class DisposableEmailValidator:
|
||||
"guerrillamail.com",
|
||||
"mailinator.com",
|
||||
"tempmail.com",
|
||||
"chat-tempmail.com",
|
||||
"throwaway.email",
|
||||
"yopmail.com",
|
||||
"temp-mail.org",
|
||||
|
||||
@@ -3,10 +3,10 @@ from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.configs.constants import ANONYMOUS_USER_INFO_ID
|
||||
from onyx.configs.constants import KV_ANONYMOUS_USER_PERSONALIZATION_KEY
|
||||
from onyx.configs.constants import KV_ANONYMOUS_USER_PREFERENCES_KEY
|
||||
from onyx.configs.constants import KV_NO_AUTH_USER_PERSONALIZATION_KEY
|
||||
from onyx.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
|
||||
from onyx.configs.constants import NO_AUTH_USER_EMAIL
|
||||
from onyx.configs.constants import NO_AUTH_USER_ID
|
||||
from onyx.key_value_store.store import KeyValueStore
|
||||
from onyx.key_value_store.store import KvKeyNotFoundError
|
||||
from onyx.server.manage.models import UserInfo
|
||||
@@ -14,22 +14,22 @@ from onyx.server.manage.models import UserPersonalization
|
||||
from onyx.server.manage.models import UserPreferences
|
||||
|
||||
|
||||
def set_anonymous_user_preferences(
|
||||
def set_no_auth_user_preferences(
|
||||
store: KeyValueStore, preferences: UserPreferences
|
||||
) -> None:
|
||||
store.store(KV_ANONYMOUS_USER_PREFERENCES_KEY, preferences.model_dump())
|
||||
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())
|
||||
|
||||
|
||||
def set_anonymous_user_personalization(
|
||||
def set_no_auth_user_personalization(
|
||||
store: KeyValueStore, personalization: UserPersonalization
|
||||
) -> None:
|
||||
store.store(KV_ANONYMOUS_USER_PERSONALIZATION_KEY, personalization.model_dump())
|
||||
store.store(KV_NO_AUTH_USER_PERSONALIZATION_KEY, personalization.model_dump())
|
||||
|
||||
|
||||
def load_anonymous_user_preferences(store: KeyValueStore) -> UserPreferences:
|
||||
def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
|
||||
try:
|
||||
preferences_data = cast(
|
||||
Mapping[str, Any], store.load(KV_ANONYMOUS_USER_PREFERENCES_KEY)
|
||||
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
|
||||
)
|
||||
return UserPreferences(**preferences_data)
|
||||
except KvKeyNotFoundError:
|
||||
@@ -38,26 +38,27 @@ def load_anonymous_user_preferences(store: KeyValueStore) -> UserPreferences:
|
||||
)
|
||||
|
||||
|
||||
def fetch_anonymous_user_info(store: KeyValueStore) -> UserInfo:
|
||||
"""Fetch a UserInfo object for anonymous users (used for API responses)."""
|
||||
def fetch_no_auth_user(
|
||||
store: KeyValueStore, *, anonymous_user_enabled: bool | None = None
|
||||
) -> UserInfo:
|
||||
personalization = UserPersonalization()
|
||||
try:
|
||||
personalization_data = cast(
|
||||
Mapping[str, Any], store.load(KV_ANONYMOUS_USER_PERSONALIZATION_KEY)
|
||||
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PERSONALIZATION_KEY)
|
||||
)
|
||||
personalization = UserPersonalization(**personalization_data)
|
||||
except KvKeyNotFoundError:
|
||||
pass
|
||||
|
||||
return UserInfo(
|
||||
id=ANONYMOUS_USER_INFO_ID,
|
||||
email=ANONYMOUS_USER_EMAIL,
|
||||
id=NO_AUTH_USER_ID,
|
||||
email=NO_AUTH_USER_EMAIL,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=UserRole.LIMITED,
|
||||
preferences=load_anonymous_user_preferences(store),
|
||||
role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
|
||||
preferences=load_no_auth_user_preferences(store),
|
||||
personalization=personalization,
|
||||
is_anonymous_user=True,
|
||||
is_anonymous_user=anonymous_user_enabled,
|
||||
password_configured=False,
|
||||
)
|
||||
@@ -75,6 +75,7 @@ from onyx.auth.schemas import UserUpdateWithRole
|
||||
from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import JWT_PUBLIC_KEY_URL
|
||||
from onyx.configs.app_configs import PASSWORD_MAX_LENGTH
|
||||
@@ -91,8 +92,6 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.configs.constants import ANONYMOUS_USER_UUID
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
@@ -135,8 +134,12 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_user_admin(user: User) -> bool:
|
||||
return user.role == UserRole.ADMIN
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
if user and user.role == UserRole.ADMIN:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
@@ -1328,14 +1331,6 @@ async def optional_user(
|
||||
user: User | None = Depends(optional_fastapi_current_user),
|
||||
) -> User | None:
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
if (
|
||||
user is not None
|
||||
and user.is_anonymous
|
||||
and anonymous_user_enabled(tenant_id=tenant_id)
|
||||
):
|
||||
return get_anonymous_user()
|
||||
|
||||
if user := await _check_for_saml_and_jwt(request, user, async_db_session):
|
||||
# If user is already set, _check_for_saml_and_jwt returns the same user object
|
||||
return user
|
||||
@@ -1352,26 +1347,15 @@ async def optional_user(
|
||||
return user
|
||||
|
||||
|
||||
def get_anonymous_user() -> User:
|
||||
"""Create anonymous user object."""
|
||||
user = User(
|
||||
id=uuid.UUID(ANONYMOUS_USER_UUID),
|
||||
email=ANONYMOUS_USER_EMAIL,
|
||||
hashed_password="",
|
||||
is_active=True,
|
||||
is_verified=True,
|
||||
is_superuser=False,
|
||||
role=UserRole.LIMITED,
|
||||
use_memories=False,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
async def double_check_user(
|
||||
user: User | None,
|
||||
optional: bool = DISABLE_AUTH,
|
||||
include_expired: bool = False,
|
||||
allow_anonymous_access: bool = False,
|
||||
) -> User:
|
||||
) -> User | None:
|
||||
if optional:
|
||||
return user
|
||||
|
||||
if user is not None:
|
||||
# If user attempted to authenticate, verify them, do not default
|
||||
# to anonymous access if it fails.
|
||||
@@ -1392,7 +1376,7 @@ async def double_check_user(
|
||||
return user
|
||||
|
||||
if allow_anonymous_access:
|
||||
return get_anonymous_user()
|
||||
return None
|
||||
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not authenticated.",
|
||||
@@ -1401,19 +1385,19 @@ async def double_check_user(
|
||||
|
||||
async def current_user_with_expired_token(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User:
|
||||
) -> User | None:
|
||||
return await double_check_user(user, include_expired=True)
|
||||
|
||||
|
||||
async def current_limited_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User:
|
||||
) -> User | None:
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_chat_accessible_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User:
|
||||
) -> User | None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
return await double_check_user(
|
||||
@@ -1423,8 +1407,10 @@ async def current_chat_accessible_user(
|
||||
|
||||
async def current_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User:
|
||||
) -> User | None:
|
||||
user = await double_check_user(user)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
if user.role == UserRole.LIMITED:
|
||||
raise BasicAuthenticationError(
|
||||
@@ -1434,8 +1420,16 @@ async def current_user(
|
||||
|
||||
|
||||
async def current_curator_or_admin_user(
|
||||
user: User = Depends(current_user),
|
||||
) -> User:
|
||||
user: User | None = Depends(current_user),
|
||||
) -> User | None:
|
||||
if DISABLE_AUTH:
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role"):
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not authenticated or lacks role information.",
|
||||
)
|
||||
|
||||
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
|
||||
if user.role not in allowed_roles:
|
||||
raise BasicAuthenticationError(
|
||||
@@ -1445,8 +1439,11 @@ async def current_curator_or_admin_user(
|
||||
return user
|
||||
|
||||
|
||||
async def current_admin_user(user: User = Depends(current_user)) -> User:
|
||||
if user.role != UserRole.ADMIN:
|
||||
async def current_admin_user(user: User | None = Depends(current_user)) -> User | None:
|
||||
if DISABLE_AUTH:
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User must be an admin to perform this action.",
|
||||
)
|
||||
@@ -1471,7 +1468,7 @@ class OAuth2AuthorizeResponse(BaseModel):
|
||||
|
||||
def generate_state_token(
|
||||
data: Dict[str, str],
|
||||
secret: SecretType,
|
||||
secret: SecretType, # type: ignore[valid-type]
|
||||
lifetime_seconds: int = STATE_TOKEN_LIFETIME_SECONDS,
|
||||
) -> str:
|
||||
data["aud"] = STATE_TOKEN_AUDIENCE
|
||||
@@ -1487,7 +1484,7 @@ def generate_csrf_token() -> str:
|
||||
def create_onyx_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
backend: AuthenticationBackend,
|
||||
state_secret: SecretType,
|
||||
state_secret: SecretType, # type: ignore[valid-type]
|
||||
redirect_url: Optional[str] = None,
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
@@ -1507,7 +1504,7 @@ def get_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
backend: AuthenticationBackend,
|
||||
get_user_manager: UserManagerDependency[models.UP, models.ID],
|
||||
state_secret: SecretType,
|
||||
state_secret: SecretType, # type: ignore[valid-type]
|
||||
redirect_url: Optional[str] = None,
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user