Mirror of https://github.com/onyx-dot-app/onyx.git, synced 2026-02-28 21:25:44 +00:00.

Compare commits: test-tests...v2.11.0-cl (68 commits)
Commits (SHA1; author and date metadata were not captured):

- 6a02ff9922
- 71b8746a34
- 7080b3d966
- adc3c86b16
- b110621b13
- a2dc752d14
- f7d47a6ca9
- 9cc71b71ee
- f2bafd113a
- bb00ebd4a8
- fda04aa6d2
- 285aae6f2f
- b75b1019f3
- bbba32b989
- f06bf69956
- 7d4fe480cc
- 7f5b512856
- 844a01f751
- d64be385db
- d0518388d6
- a7f6d5f535
- 059e2869e6
- 04d90fd496
- 7cd29f4892
- c2b86efebf
- bc5835967e
- c2b11cae01
- cf17ba6a1c
- b03634ecaa
- 9a7e92464f
- 09b2a69c82
- c5c027c168
- 882163a4ea
- de83a9a6f0
- f73ce0632f
- 0b10b11af3
- d9e3b657d0
- f6e9928dc1
- ca3179ad8d
- 5529829ff5
- bdc7f6c100
- 90f8656afa
- 3c7d35a6e8
- 40d58a37e3
- be3ecd9640
- a6da511490
- c7577ebe58
- b87078a4f5
- 8a408e7023
- 4c7b73a355
- 8e9cb94d4f
- a21af4b906
- 7f0ce0531f
- b631bfa656
- eca6b6bef2
- 51ef28305d
- 144030c5ca
- a557d76041
- 605e808158
- 8fec88c90d
- e54969a693
- 1da2b2f28f
- eb7b91e08e
- 3339000968
- d9db849e94
- 046408359c
- 4b8cca190f
- 52a312a63b
.github/pull_request_template.md (vendored, 1 line changed)

@@ -8,4 +8,5 @@

## Additional Options

- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
- [ ] [Optional] Override Linear Check
.github/workflows/deployment.yml (vendored, 231 lines changed)

@@ -26,12 +26,14 @@ jobs:
build-web: ${{ steps.check.outputs.build-web }}
build-web-cloud: ${{ steps.check.outputs.build-web-cloud }}
build-backend: ${{ steps.check.outputs.build-backend }}
build-backend-craft: ${{ steps.check.outputs.build-backend-craft }}
build-model-server: ${{ steps.check.outputs.build-model-server }}
is-cloud-tag: ${{ steps.check.outputs.is-cloud-tag }}
is-stable: ${{ steps.check.outputs.is-stable }}
is-beta: ${{ steps.check.outputs.is-beta }}
is-stable-standalone: ${{ steps.check.outputs.is-stable-standalone }}
is-beta-standalone: ${{ steps.check.outputs.is-beta-standalone }}
is-craft-latest: ${{ steps.check.outputs.is-craft-latest }}
is-test-run: ${{ steps.check.outputs.is-test-run }}
sanitized-tag: ${{ steps.check.outputs.sanitized-tag }}
short-sha: ${{ steps.check.outputs.short-sha }}

@@ -54,15 +56,20 @@ jobs:
IS_BETA=false
IS_STABLE_STANDALONE=false
IS_BETA_STANDALONE=false
IS_CRAFT_LATEST=false
IS_PROD_TAG=false
IS_TEST_RUN=false
BUILD_DESKTOP=false
BUILD_WEB=false
BUILD_WEB_CLOUD=false
BUILD_BACKEND=true
BUILD_BACKEND_CRAFT=false
BUILD_MODEL_SERVER=true

# Determine tag type based on pattern matching (do regex checks once)
if [[ "$TAG" == craft-* ]]; then
IS_CRAFT_LATEST=true
fi
if [[ "$TAG" == *cloud* ]]; then
IS_CLOUD=true
fi

@@ -90,6 +97,12 @@ jobs:
fi
fi

# Craft-latest builds backend with Craft enabled
if [[ "$IS_CRAFT_LATEST" == "true" ]]; then
BUILD_BACKEND_CRAFT=true
BUILD_BACKEND=false
fi

# Standalone version checks (for backend/model-server - version excluding cloud tags)
if [[ "$IS_STABLE" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then
IS_STABLE_STANDALONE=true

@@ -113,12 +126,14 @@ jobs:
echo "build-web=$BUILD_WEB"
echo "build-web-cloud=$BUILD_WEB_CLOUD"
echo "build-backend=$BUILD_BACKEND"
echo "build-backend-craft=$BUILD_BACKEND_CRAFT"
echo "build-model-server=$BUILD_MODEL_SERVER"
echo "is-cloud-tag=$IS_CLOUD"
echo "is-stable=$IS_STABLE"
echo "is-beta=$IS_BETA"
echo "is-stable-standalone=$IS_STABLE_STANDALONE"
echo "is-beta-standalone=$IS_BETA_STANDALONE"
echo "is-craft-latest=$IS_CRAFT_LATEST"
echo "is-test-run=$IS_TEST_RUN"
echo "sanitized-tag=$SANITIZED_TAG"
echo "short-sha=$SHORT_SHA"

@@ -1003,6 +1018,217 @@ jobs:
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
$IMAGES

build-backend-craft-amd64:
needs: determine-builds
if: needs.determine-builds.outputs.build-backend-craft == 'true'
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-backend-craft-amd64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/onyx-backend
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

- name: Checkout code
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
latest=false

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}

- name: Build and push AMD64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
ENABLE_CRAFT=true
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64,mode=max
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

build-backend-craft-arm64:
needs: determine-builds
if: needs.determine-builds.outputs.build-backend-craft == 'true'
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-backend-craft-arm64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/onyx-backend
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

- name: Checkout code
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
latest=false

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}

- name: Build and push ARM64
id: build
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/arm64
labels: ${{ steps.meta.outputs.labels }}
build-args: |
ONYX_VERSION=${{ github.ref_name }}
ENABLE_CRAFT=true
cache-from: |
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64,mode=max
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

merge-backend-craft:
needs:
- determine-builds
- build-backend-craft-amd64
- build-backend-craft-arm64
if: needs.determine-builds.outputs.build-backend-craft == 'true'
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-backend-craft
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-backend
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}

- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
flavor: |
latest=false
tags: |
type=raw,value=craft-latest
# TODO: Consider aligning craft-latest tags with regular backend builds (e.g., latest, edge, beta)
# to keep tagging strategy consistent across all backend images

- name: Create and push manifest
env:
IMAGE_REPO: ${{ env.REGISTRY_IMAGE }}
AMD64_DIGEST: ${{ needs.build-backend-craft-amd64.outputs.digest }}
ARM64_DIGEST: ${{ needs.build-backend-craft-arm64.outputs.digest }}
META_TAGS: ${{ steps.meta.outputs.tags }}
run: |
IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
docker buildx imagetools create \
$(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
$IMAGES

build-model-server-amd64:
needs: determine-builds
if: needs.determine-builds.outputs.build-model-server == 'true'

@@ -1466,10 +1692,13 @@ jobs:
- build-backend-amd64
- build-backend-arm64
- merge-backend
- build-backend-craft-amd64
- build-backend-craft-arm64
- merge-backend-craft
- build-model-server-amd64
- build-model-server-arm64
- merge-model-server
if: always() && (needs.build-desktop.result == 'failure' || needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && needs.determine-builds.outputs.is-test-run != 'true'
if: always() && (needs.build-desktop.result == 'failure' || needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || (needs.determine-builds.outputs.build-backend-craft == 'true' && (needs.build-backend-craft-amd64.result == 'failure' || needs.build-backend-craft-arm64.result == 'failure' || needs.merge-backend-craft.result == 'failure')) || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && needs.determine-builds.outputs.is-test-run != 'true'
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 90
.github/workflows/pr-beta-cherrypick-check.yml (vendored, new file, 28 lines)

@@ -0,0 +1,28 @@
name: Require beta cherry-pick consideration
concurrency:
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true

on:
pull_request:
types: [opened, edited, reopened, synchronize]

permissions:
contents: read

jobs:
beta-cherrypick-check:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Check PR body for beta cherry-pick consideration
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
echo "Cherry-pick consideration box is checked. Check passed."
exit 0
fi

echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
exit 1
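For illustration only, the grep check above corresponds roughly to the Python pattern below; the sample PR body line is hypothetical and just shows the shape of a checked box that would satisfy the workflow step.

```python
import re

# Hypothetical PR body line; the real body comes from github.event.pull_request.body.
pr_body = "- [x] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch."

# Rough Python equivalent of the grep -qiE pattern used in the workflow step.
pattern = re.compile(
    r"\[x\]\s*\[Required\]\s*I have considered whether this PR needs to be "
    r"cherry[- ]picked to the latest beta branch",
    re.IGNORECASE,
)

print(bool(pattern.search(pr_body)))  # True -> the workflow step would exit 0
```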
.github/workflows/pr-python-checks.yml (vendored, 3 lines changed)

@@ -50,8 +50,9 @@ jobs:
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: backend/.mypy_cache
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
restore-keys: |
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
mypy-${{ runner.os }}-

- name: Run MyPy
@@ -66,7 +66,8 @@ repos:
- id: uv-run
name: Check lazy imports
args: ["--active", "--with=onyx-devtools", "ods", "check-lazy-imports"]
files: ^backend/(?!\.venv/).*\.py$
pass_filenames: true
files: ^backend/(?!\.venv/|scripts/).*\.py$
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
# - id: uv-run
# name: mypy
.vscode/launch.json (vendored, 58 lines changed)

@@ -149,6 +149,24 @@
},
"consoleTitle": "Slack Bot Console"
},
{
"name": "Discord Bot",
"consoleName": "Discord Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/discord/client.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2"
},
"consoleTitle": "Discord Bot Console"
},
{
"name": "MCP Server",
"consoleName": "MCP Server",

@@ -397,7 +415,6 @@
"onyx.background.celery.versioned_apps.docfetching",
"worker",
"--pool=threads",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docfetching@%n",

@@ -428,7 +445,6 @@
"onyx.background.celery.versioned_apps.docprocessing",
"worker",
"--pool=threads",
"--concurrency=6",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=docprocessing@%n",

@@ -577,6 +593,23 @@
"group": "3"
}
},
{
"name": "Build Sandbox Templates",
"type": "debugpy",
"request": "launch",
"module": "onyx.server.features.build.sandbox.build_templates",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"console": "integratedTerminal",
"presentation": {
"group": "3"
},
"consoleTitle": "Build Sandbox Templates"
},
{
// Dummy entry used to label the group
"name": "--- Database ---",

@@ -587,6 +620,27 @@
"order": 0
}
},
{
"name": "Restore seeded database dump",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"restore",
"--fetch-seeded",
"--yes"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Clean restore seeded database dump (destructive)",
"type": "node",
"request": "launch",
@@ -16,3 +16,8 @@ dist/
.coverage
htmlcov/
model_server/legacy/

# Craft: demo_data directory should be unzipped at container startup, not copied
**/demo_data/
# Craft: templates/outputs/venv is created at container startup
**/templates/outputs/venv
@@ -7,6 +7,10 @@ have a contract or agreement with DanswerAI, you are not permitted to use the En
Edition features outside of personal development or testing purposes. Please reach out to \
founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx"

# Build argument for Craft support (disabled by default)
# Use --build-arg ENABLE_CRAFT=true to include Node.js and opencode CLI
ARG ENABLE_CRAFT=false

# DO_NOT_TRACK is used to disable telemetry for Unstructured
ENV DANSWER_RUNNING_IN_DOCKER="true" \
DO_NOT_TRACK="true" \

@@ -46,7 +50,23 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean

# Conditionally install Node.js 20 for Craft (required for Next.js)
# Only installed when ENABLE_CRAFT=true
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
echo "Installing Node.js 20 for Craft support..." && \
curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
apt-get install -y nodejs && \
rm -rf /var/lib/apt/lists/*; \
fi

# Conditionally install opencode CLI for Craft agent functionality
# Only installed when ENABLE_CRAFT=true
# TODO: download a specific, versioned release of the opencode CLI
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
echo "Installing opencode CLI for Craft support..." && \
curl -fsSL https://opencode.ai/install | bash; \
fi
ENV PATH="/root/.opencode/bin:${PATH}"

# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE

@@ -89,6 +109,12 @@ RUN uv pip install --system --no-cache-dir --upgrade \
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"

# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \
nltk.download('punkt_tab', quiet=True);"
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed

# Pre-downloading tiktoken for setups with limited egress
RUN python -c "import tiktoken; \
tiktoken.get_encoding('cl100k_base')"

@@ -113,7 +139,15 @@ COPY --chown=onyx:onyx ./static /app/static
COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
RUN chmod +x /app/scripts/supervisord_entrypoint.sh
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh

# Run Craft template setup at build time when ENABLE_CRAFT=true
# This pre-bakes demo data, Python venv, and npm dependencies into the image
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
echo "Running Craft template setup at build time..." && \
ENABLE_CRAFT=true /app/scripts/setup_craft_templates.sh; \
fi

# Put logo in assets
COPY --chown=onyx:onyx ./assets /app/assets
@@ -0,0 +1,351 @@
"""single onyx craft migration

Consolidates all buildmode/onyx craft tables into a single migration.

Tables created:
- build_session: User build sessions with status tracking
- sandbox: User-owned containerized environments (one per user)
- artifact: Build output files (web apps, documents, images)
- snapshot: Sandbox filesystem snapshots
- build_message: Conversation messages for build sessions

Existing table modified:
- connector_credential_pair: Added processing_mode column

Revision ID: 2020d417ec84
Revises: 41fa44bef321
Create Date: 2026-01-26 14:43:54.641405

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "2020d417ec84"
down_revision = "41fa44bef321"
branch_labels = None
depends_on = None

def upgrade() -> None:
# ==========================================================================
# ENUMS
# ==========================================================================

# Build session status enum
build_session_status_enum = sa.Enum(
"active",
"idle",
name="buildsessionstatus",
native_enum=False,
)

# Sandbox status enum
sandbox_status_enum = sa.Enum(
"provisioning",
"running",
"idle",
"sleeping",
"terminated",
"failed",
name="sandboxstatus",
native_enum=False,
)

# Artifact type enum
artifact_type_enum = sa.Enum(
"web_app",
"pptx",
"docx",
"markdown",
"excel",
"image",
name="artifacttype",
native_enum=False,
)

# ==========================================================================
# BUILD_SESSION TABLE
# ==========================================================================

op.create_table(
"build_session",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("user.id", ondelete="CASCADE"),
nullable=True,
),
sa.Column("name", sa.String(), nullable=True),
sa.Column(
"status",
build_session_status_enum,
nullable=False,
server_default="active",
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"last_activity_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("nextjs_port", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)

op.create_index(
"ix_build_session_user_created",
"build_session",
["user_id", sa.text("created_at DESC")],
unique=False,
)
op.create_index(
"ix_build_session_status",
"build_session",
["status"],
unique=False,
)

# ==========================================================================
# SANDBOX TABLE (user-owned, one per user)
# ==========================================================================

op.create_table(
"sandbox",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("user.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("container_id", sa.String(), nullable=True),
sa.Column(
"status",
sandbox_status_enum,
nullable=False,
server_default="provisioning",
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("last_heartbeat", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("user_id", name="sandbox_user_id_key"),
)

op.create_index(
"ix_sandbox_status",
"sandbox",
["status"],
unique=False,
)
op.create_index(
"ix_sandbox_container_id",
"sandbox",
["container_id"],
unique=False,
)

# ==========================================================================
# ARTIFACT TABLE
# ==========================================================================

op.create_table(
"artifact",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"session_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("type", artifact_type_enum, nullable=False),
sa.Column("path", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)

op.create_index(
"ix_artifact_session_created",
"artifact",
["session_id", sa.text("created_at DESC")],
unique=False,
)
op.create_index(
"ix_artifact_type",
"artifact",
["type"],
unique=False,
)

# ==========================================================================
# SNAPSHOT TABLE
# ==========================================================================

op.create_table(
"snapshot",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"session_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("storage_path", sa.String(), nullable=False),
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)

op.create_index(
"ix_snapshot_session_created",
"snapshot",
["session_id", sa.text("created_at DESC")],
unique=False,
)

# ==========================================================================
# BUILD_MESSAGE TABLE
# ==========================================================================

op.create_table(
"build_message",
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
sa.Column(
"session_id",
postgresql.UUID(as_uuid=True),
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"turn_index",
sa.Integer(),
nullable=False,
),
sa.Column(
"type",
sa.Enum(
"SYSTEM",
"USER",
"ASSISTANT",
"DANSWER",
name="messagetype",
create_type=False,
native_enum=False,
),
nullable=False,
),
sa.Column(
"message_metadata",
postgresql.JSONB(),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)

op.create_index(
"ix_build_message_session_turn",
"build_message",
["session_id", "turn_index", sa.text("created_at ASC")],
unique=False,
)

# ==========================================================================
# CONNECTOR_CREDENTIAL_PAIR MODIFICATION
# ==========================================================================

op.add_column(
"connector_credential_pair",
sa.Column(
"processing_mode",
sa.String(),
nullable=False,
server_default="regular",
),
)

def downgrade() -> None:
# ==========================================================================
# CONNECTOR_CREDENTIAL_PAIR MODIFICATION
# ==========================================================================

op.drop_column("connector_credential_pair", "processing_mode")

# ==========================================================================
# BUILD_MESSAGE TABLE
# ==========================================================================

op.drop_index("ix_build_message_session_turn", table_name="build_message")
op.drop_table("build_message")

# ==========================================================================
# SNAPSHOT TABLE
# ==========================================================================

op.drop_index("ix_snapshot_session_created", table_name="snapshot")
op.drop_table("snapshot")

# ==========================================================================
# ARTIFACT TABLE
# ==========================================================================

op.drop_index("ix_artifact_type", table_name="artifact")
op.drop_index("ix_artifact_session_created", table_name="artifact")
op.drop_table("artifact")
sa.Enum(name="artifacttype").drop(op.get_bind(), checkfirst=True)

# ==========================================================================
# SANDBOX TABLE
# ==========================================================================

op.drop_index("ix_sandbox_container_id", table_name="sandbox")
op.drop_index("ix_sandbox_status", table_name="sandbox")
op.drop_table("sandbox")
sa.Enum(name="sandboxstatus").drop(op.get_bind(), checkfirst=True)

# ==========================================================================
# BUILD_SESSION TABLE
# ==========================================================================

op.drop_index("ix_build_session_status", table_name="build_session")
op.drop_index("ix_build_session_user_created", table_name="build_session")
op.drop_table("build_session")
sa.Enum(name="buildsessionstatus").drop(op.get_bind(), checkfirst=True)
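A minimal sketch (not part of the PR) of the kind of query that the ix_build_session_user_created (user_id, created_at DESC) index is presumably built to serve, i.e. "latest build sessions for a user". The table definition is abbreviated to the columns used here.

```python
import uuid

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

metadata = sa.MetaData()
# Abbreviated shadow of the build_session table created by the migration above.
build_session = sa.Table(
    "build_session",
    metadata,
    sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
    sa.Column("user_id", postgresql.UUID(as_uuid=True)),
    sa.Column("created_at", sa.DateTime(timezone=True)),
)

def recent_sessions_stmt(user_id: uuid.UUID) -> sa.Select:
    # Filters on user_id and orders by created_at DESC, matching the index shape.
    return (
        sa.select(build_session.c.id, build_session.c.created_at)
        .where(build_session.c.user_id == user_id)
        .order_by(build_session.c.created_at.desc())
        .limit(20)
    )
```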
@@ -0,0 +1,45 @@
"""make processing mode default all caps

Revision ID: 72aa7de2e5cf
Revises: 2020d417ec84
Create Date: 2026-01-26 18:58:47.705253

This migration fixes the ProcessingMode enum value mismatch:
- SQLAlchemy's Enum with native_enum=False uses enum member NAMES as valid values
- The original migration stored lowercase VALUES ('regular', 'file_system')
- This converts existing data to uppercase NAMES ('REGULAR', 'FILE_SYSTEM')
- Also drops any spurious native PostgreSQL enum type that may have been auto-created
"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "72aa7de2e5cf"
down_revision = "2020d417ec84"
branch_labels = None
depends_on = None

def upgrade() -> None:
# Convert existing lowercase values to uppercase to match enum member names
op.execute(
"UPDATE connector_credential_pair SET processing_mode = 'REGULAR' "
"WHERE processing_mode = 'regular'"
)
op.execute(
"UPDATE connector_credential_pair SET processing_mode = 'FILE_SYSTEM' "
"WHERE processing_mode = 'file_system'"
)

# Update the server default to use uppercase
op.alter_column(
"connector_credential_pair",
"processing_mode",
server_default="REGULAR",
)

def downgrade() -> None:
# State prior to this was broken, so we don't want to revert back to it
pass
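A small illustration of the mismatch this migration corrects; the ProcessingMode class below is a stand-in, not imported from the codebase. With native_enum=False, SQLAlchemy emits a VARCHAR column and validates against the enum member names, so only the uppercase names round-trip cleanly.

```python
import enum

import sqlalchemy as sa

class ProcessingMode(enum.Enum):  # stand-in for the application's enum
    REGULAR = "regular"
    FILE_SYSTEM = "file_system"

# With native_enum=False, SQLAlchemy persists the member NAME by default,
# not the value, which is why 'regular' rows had to be rewritten to 'REGULAR'.
col_type = sa.Enum(ProcessingMode, native_enum=False)
print(col_type.enums)  # the member NAMES: 'REGULAR', 'FILE_SYSTEM'
```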
backend/alembic/versions/81c22b1e2e78_hierarchy_nodes_v1.py (new file, 349 lines)

@@ -0,0 +1,349 @@
"""hierarchy_nodes_v1

Revision ID: 81c22b1e2e78
Revises: 72aa7de2e5cf
Create Date: 2026-01-13 18:10:01.021451

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from onyx.configs.constants import DocumentSource

# revision identifiers, used by Alembic.
revision = "81c22b1e2e78"
down_revision = "72aa7de2e5cf"
branch_labels = None
depends_on = None

# Human-readable display names for each source
SOURCE_DISPLAY_NAMES: dict[str, str] = {
"ingestion_api": "Ingestion API",
"slack": "Slack",
"web": "Web",
"google_drive": "Google Drive",
"gmail": "Gmail",
"requesttracker": "Request Tracker",
"github": "GitHub",
"gitbook": "GitBook",
"gitlab": "GitLab",
"guru": "Guru",
"bookstack": "BookStack",
"outline": "Outline",
"confluence": "Confluence",
"jira": "Jira",
"slab": "Slab",
"productboard": "Productboard",
"file": "File",
"coda": "Coda",
"notion": "Notion",
"zulip": "Zulip",
"linear": "Linear",
"hubspot": "HubSpot",
"document360": "Document360",
"gong": "Gong",
"google_sites": "Google Sites",
"zendesk": "Zendesk",
"loopio": "Loopio",
"dropbox": "Dropbox",
"sharepoint": "SharePoint",
"teams": "Teams",
"salesforce": "Salesforce",
"discourse": "Discourse",
"axero": "Axero",
"clickup": "ClickUp",
"mediawiki": "MediaWiki",
"wikipedia": "Wikipedia",
"asana": "Asana",
"s3": "S3",
"r2": "R2",
"google_cloud_storage": "Google Cloud Storage",
"oci_storage": "OCI Storage",
"xenforo": "XenForo",
"not_applicable": "Not Applicable",
"discord": "Discord",
"freshdesk": "Freshdesk",
"fireflies": "Fireflies",
"egnyte": "Egnyte",
"airtable": "Airtable",
"highspot": "Highspot",
"drupal_wiki": "Drupal Wiki",
"imap": "IMAP",
"bitbucket": "Bitbucket",
"testrail": "TestRail",
"mock_connector": "Mock Connector",
"user_file": "User File",
}

def upgrade() -> None:
# 1. Create hierarchy_node table
op.create_table(
"hierarchy_node",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("raw_node_id", sa.String(), nullable=False),
sa.Column("display_name", sa.String(), nullable=False),
sa.Column("link", sa.String(), nullable=True),
sa.Column("source", sa.String(), nullable=False),
sa.Column("node_type", sa.String(), nullable=False),
sa.Column("document_id", sa.String(), nullable=True),
sa.Column("parent_id", sa.Integer(), nullable=True),
# Permission fields - same pattern as Document table
sa.Column(
"external_user_emails",
postgresql.ARRAY(sa.String()),
nullable=True,
),
sa.Column(
"external_user_group_ids",
postgresql.ARRAY(sa.String()),
nullable=True,
),
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="false"),
sa.PrimaryKeyConstraint("id"),
# When document is deleted, just unlink (node can exist without document)
sa.ForeignKeyConstraint(["document_id"], ["document.id"], ondelete="SET NULL"),
# When parent node is deleted, orphan children (cleanup via pruning)
sa.ForeignKeyConstraint(
["parent_id"], ["hierarchy_node.id"], ondelete="SET NULL"
),
sa.UniqueConstraint(
"raw_node_id", "source", name="uq_hierarchy_node_raw_id_source"
),
)
op.create_index("ix_hierarchy_node_parent_id", "hierarchy_node", ["parent_id"])
op.create_index(
"ix_hierarchy_node_source_type", "hierarchy_node", ["source", "node_type"]
)

# Add partial unique index to ensure only one SOURCE-type node per source
# This prevents duplicate source root nodes from being created
# NOTE: node_type stores enum NAME ('SOURCE'), not value ('source')
op.execute(
sa.text(
"""
CREATE UNIQUE INDEX uq_hierarchy_node_one_source_per_type
ON hierarchy_node (source)
WHERE node_type = 'SOURCE'
"""
)
)

# 2. Create hierarchy_fetch_attempt table
op.create_table(
"hierarchy_fetch_attempt",
sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
sa.Column("status", sa.String(), nullable=False),
sa.Column("nodes_fetched", sa.Integer(), nullable=True, server_default="0"),
sa.Column("nodes_updated", sa.Integer(), nullable=True, server_default="0"),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column("full_exception_trace", sa.Text(), nullable=True),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column("time_started", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"time_updated",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["connector_credential_pair_id"],
["connector_credential_pair.id"],
ondelete="CASCADE",
),
)
op.create_index(
"ix_hierarchy_fetch_attempt_status", "hierarchy_fetch_attempt", ["status"]
)
op.create_index(
"ix_hierarchy_fetch_attempt_time_created",
"hierarchy_fetch_attempt",
["time_created"],
)
op.create_index(
"ix_hierarchy_fetch_attempt_cc_pair",
"hierarchy_fetch_attempt",
["connector_credential_pair_id"],
)

# 3. Insert SOURCE-type hierarchy nodes for each DocumentSource
# We insert these so every existing document can have a parent hierarchy node
# NOTE: SQLAlchemy's Enum with native_enum=False stores the enum NAME (e.g., 'GOOGLE_DRIVE'),
# not the VALUE (e.g., 'google_drive'). We must use .name for source and node_type columns.
# SOURCE nodes are always public since they're just categorical roots.
for source in DocumentSource:
source_name = (
source.name
) # e.g., 'GOOGLE_DRIVE' - what SQLAlchemy stores/expects
source_value = source.value # e.g., 'google_drive' - the raw_node_id
display_name = SOURCE_DISPLAY_NAMES.get(
source_value, source_value.replace("_", " ").title()
)
op.execute(
sa.text(
"""
INSERT INTO hierarchy_node (raw_node_id, display_name, source, node_type, parent_id, is_public)
VALUES (:raw_node_id, :display_name, :source, 'SOURCE', NULL, true)
ON CONFLICT (raw_node_id, source) DO NOTHING
"""
).bindparams(
raw_node_id=source_value, # Use .value for raw_node_id (human-readable identifier)
display_name=display_name,
source=source_name, # Use .name for source column (SQLAlchemy enum storage)
)
)
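To make the .name vs .value distinction in the comments above concrete, a quick illustration (assuming DocumentSource has a GOOGLE_DRIVE member, as those comments suggest):

```python
from onyx.configs.constants import DocumentSource  # already imported by this migration

src = DocumentSource.GOOGLE_DRIVE
print(src.name)   # "GOOGLE_DRIVE" -> what the hierarchy_node.source column stores
print(src.value)  # "google_drive" -> used as raw_node_id and for the display-name lookup
```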
# 4. Add parent_hierarchy_node_id column to document table
op.add_column(
"document",
sa.Column("parent_hierarchy_node_id", sa.Integer(), nullable=True),
)
# When hierarchy node is deleted, just unlink the document (SET NULL)
op.create_foreign_key(
"fk_document_parent_hierarchy_node",
"document",
"hierarchy_node",
["parent_hierarchy_node_id"],
["id"],
ondelete="SET NULL",
)
op.create_index(
"ix_document_parent_hierarchy_node_id",
"document",
["parent_hierarchy_node_id"],
)

# 5. Set all existing documents' parent_hierarchy_node_id to their source's SOURCE node
# For documents with multiple connectors, we pick one source deterministically (MIN connector_id)
# NOTE: Both connector.source and hierarchy_node.source store enum NAMEs (e.g., 'GOOGLE_DRIVE')
# because SQLAlchemy Enum(native_enum=False) uses the enum name for storage.
op.execute(
sa.text(
"""
UPDATE document d
SET parent_hierarchy_node_id = hn.id
FROM (
-- Get the source for each document (pick MIN connector_id for determinism)
SELECT DISTINCT ON (dbcc.id)
dbcc.id as doc_id,
c.source as source
FROM document_by_connector_credential_pair dbcc
JOIN connector c ON dbcc.connector_id = c.id
ORDER BY dbcc.id, dbcc.connector_id
) doc_source
JOIN hierarchy_node hn ON hn.source = doc_source.source AND hn.node_type = 'SOURCE'
WHERE d.id = doc_source.doc_id
"""
)
)

# Create the persona__hierarchy_node association table
op.create_table(
"persona__hierarchy_node",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("hierarchy_node_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["hierarchy_node_id"],
["hierarchy_node.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("persona_id", "hierarchy_node_id"),
)

# Add index for efficient lookups
op.create_index(
"ix_persona__hierarchy_node_hierarchy_node_id",
"persona__hierarchy_node",
["hierarchy_node_id"],
)

# Create the persona__document association table for attaching individual
# documents directly to assistants
op.create_table(
"persona__document",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("document_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["document_id"],
["document.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("persona_id", "document_id"),
)

# Add index for efficient lookups by document_id
op.create_index(
"ix_persona__document_document_id",
"persona__document",
["document_id"],
)

# 6. Add last_time_hierarchy_fetch column to connector_credential_pair table
op.add_column(
"connector_credential_pair",
sa.Column(
"last_time_hierarchy_fetch", sa.DateTime(timezone=True), nullable=True
),
)

def downgrade() -> None:
# Remove last_time_hierarchy_fetch from connector_credential_pair
op.drop_column("connector_credential_pair", "last_time_hierarchy_fetch")

# Drop persona__document table
op.drop_index("ix_persona__document_document_id", table_name="persona__document")
op.drop_table("persona__document")

# Drop persona__hierarchy_node table
op.drop_index(
"ix_persona__hierarchy_node_hierarchy_node_id",
table_name="persona__hierarchy_node",
)
op.drop_table("persona__hierarchy_node")

# Remove parent_hierarchy_node_id from document
op.drop_index("ix_document_parent_hierarchy_node_id", table_name="document")
op.drop_constraint(
"fk_document_parent_hierarchy_node", "document", type_="foreignkey"
)
op.drop_column("document", "parent_hierarchy_node_id")

# Drop hierarchy_fetch_attempt table
op.drop_index(
"ix_hierarchy_fetch_attempt_cc_pair", table_name="hierarchy_fetch_attempt"
)
op.drop_index(
"ix_hierarchy_fetch_attempt_time_created", table_name="hierarchy_fetch_attempt"
)
op.drop_index(
"ix_hierarchy_fetch_attempt_status", table_name="hierarchy_fetch_attempt"
)
op.drop_table("hierarchy_fetch_attempt")

# Drop hierarchy_node table
op.drop_index("uq_hierarchy_node_one_source_per_type", table_name="hierarchy_node")
op.drop_index("ix_hierarchy_node_source_type", table_name="hierarchy_node")
op.drop_index("ix_hierarchy_node_parent_id", table_name="hierarchy_node")
op.drop_table("hierarchy_node")
@@ -122,6 +122,9 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
POSTHOG_DEBUG_LOGS_ENABLED = (
os.environ.get("POSTHOG_DEBUG_LOGS_ENABLED", "").lower() == "true"
)

MARKETING_POSTHOG_API_KEY = os.environ.get("MARKETING_POSTHOG_API_KEY")

@@ -133,3 +136,9 @@ GATED_TENANTS_KEY = "gated_tenants"
LICENSE_ENFORCEMENT_ENABLED = (
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
)

# Cloud data plane URL - self-hosted instances call this to reach cloud proxy endpoints
# Used when MULTI_TENANT=false (self-hosted mode)
CLOUD_DATA_PLANE_URL = os.environ.get(
"CLOUD_DATA_PLANE_URL", "https://cloud.onyx.app/api"
)
backend/ee/onyx/configs/license_enforcement_config.py (new file, 73 lines)

@@ -0,0 +1,73 @@
"""Constants for license enforcement.

This file is the single source of truth for:
1. Paths that bypass license enforcement (always accessible)
2. Paths that require an EE license (EE-only features)

Import these constants in both production code and tests to ensure consistency.
"""

# Paths that are ALWAYS accessible, even when license is expired/gated.
# These enable users to:
# /auth - Log in/out (users can't fix billing if locked out of auth)
# /license - Fetch, upload, or check license status
# /health - Health checks for load balancers/orchestrators
# /me - Basic user info needed for UI rendering
# /settings, /enterprise-settings - View app status and branding
# /billing - Unified billing API
# /proxy - Self-hosted proxy endpoints (have own license-based auth)
# /tenants/billing-* - Legacy billing endpoints (backwards compatibility)
# /manage/users, /users - User management (needed for seat limit resolution)
# /notifications - Needed for UI to load properly
LICENSE_ENFORCEMENT_ALLOWED_PREFIXES: frozenset[str] = frozenset(
{
"/auth",
"/license",
"/health",
"/me",
"/settings",
"/enterprise-settings",
# Billing endpoints (unified API for both MT and self-hosted)
"/billing",
"/admin/billing",
# Proxy endpoints for self-hosted billing (no tenant context)
"/proxy",
# Legacy tenant billing endpoints (kept for backwards compatibility)
"/tenants/billing-information",
"/tenants/create-customer-portal-session",
"/tenants/create-subscription-session",
# User management - needed to remove users when seat limit exceeded
"/manage/users",
"/manage/admin/users",
"/manage/admin/valid-domains",
"/manage/admin/deactivate-user",
"/manage/admin/delete-user",
"/users",
# Notifications - needed for UI to load properly
"/notifications",
}
)

# EE-only paths that require a valid license.
# Users without a license (community edition) cannot access these.
# These are blocked even when user has never subscribed (no license).
EE_ONLY_PATH_PREFIXES: frozenset[str] = frozenset(
{
# User groups and access control
"/manage/admin/user-group",
# Analytics and reporting
"/analytics",
# Query history (admin chat session endpoints)
"/admin/chat-sessions",
"/admin/chat-session-history",
"/admin/query-history",
# Usage reporting/export
"/admin/usage-report",
# Standard answers (canned responses)
"/manage/admin/standard-answer",
# Token rate limits
"/admin/token-rate-limits",
# Evals
"/evals",
}
)
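A hedged sketch of how callers might consult these prefix sets; the helper functions below are illustrative only, and the actual enforcement middleware is not part of this changeset.

```python
from ee.onyx.configs.license_enforcement_config import (
    EE_ONLY_PATH_PREFIXES,
    LICENSE_ENFORCEMENT_ALLOWED_PREFIXES,
)

def is_always_allowed(path: str) -> bool:
    # Illustrative prefix match; the real enforcement logic may differ.
    return any(path.startswith(prefix) for prefix in LICENSE_ENFORCEMENT_ALLOWED_PREFIXES)

def requires_ee_license(path: str) -> bool:
    return any(path.startswith(prefix) for prefix in EE_ONLY_PATH_PREFIXES)

print(is_always_allowed("/auth/login"))             # True
print(requires_ee_license("/admin/query-history"))  # True
```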
@@ -1,6 +1,7 @@
|
||||
"""Database and cache operations for the license table."""
|
||||
|
from datetime import datetime
from typing import NamedTuple

from sqlalchemy import func
from sqlalchemy import select
@@ -9,6 +10,7 @@ from sqlalchemy.orm import Session
from ee.onyx.server.license.models import LicenseMetadata
from ee.onyx.server.license.models import LicensePayload
from ee.onyx.server.license.models import LicenseSource
from onyx.auth.schemas import UserRole
from onyx.db.models import License
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
@@ -23,6 +25,13 @@ LICENSE_METADATA_KEY = "license:metadata"
LICENSE_CACHE_TTL_SECONDS = 86400  # 24 hours


class SeatAvailabilityResult(NamedTuple):
    """Result of a seat availability check."""

    available: bool
    error_message: str | None = None


# -----------------------------------------------------------------------------
# Database CRUD Operations
# -----------------------------------------------------------------------------
@@ -95,23 +104,30 @@ def delete_license(db_session: Session) -> bool:

def get_used_seats(tenant_id: str | None = None) -> int:
    """
    Get current seat usage.
    Get current seat usage directly from database.

    For multi-tenant: counts users in UserTenantMapping for this tenant.
    For self-hosted: counts all active users (includes both Onyx UI users
    and Slack users who have been converted to Onyx users).
    For self-hosted: counts all active users (excludes EXT_PERM_USER role).

    TODO: Exclude API key dummy users from seat counting. API keys create
    users with emails like `__DANSWER_API_KEY_*` that should not count toward
    seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
    """
    if MULTI_TENANT:
        from ee.onyx.server.tenants.user_mapping import get_tenant_count

        return get_tenant_count(tenant_id or get_current_tenant_id())
    else:
        # Self-hosted: count all active users (Onyx + converted Slack users)
        from onyx.db.engine.sql_engine import get_session_with_current_tenant

        with get_session_with_current_tenant() as db_session:
            result = db_session.execute(
                select(func.count()).select_from(User).where(User.is_active)  # type: ignore
                select(func.count())
                .select_from(User)
                .where(
                    User.is_active == True,  # type: ignore # noqa: E712
                    User.role != UserRole.EXT_PERM_USER,
                )
            )
            return result.scalar() or 0

@@ -276,3 +292,43 @@ def get_license_metadata(

    # Refresh from database
    return refresh_license_cache(db_session, tenant_id)


def check_seat_availability(
    db_session: Session,
    seats_needed: int = 1,
    tenant_id: str | None = None,
) -> SeatAvailabilityResult:
    """
    Check if there are enough seats available to add users.

    Args:
        db_session: Database session
        seats_needed: Number of seats needed (default 1)
        tenant_id: Tenant ID (for multi-tenant deployments)

    Returns:
        SeatAvailabilityResult with available=True if seats are available,
        or available=False with error_message if limit would be exceeded.
        Returns available=True if no license exists (self-hosted = unlimited).
    """
    metadata = get_license_metadata(db_session, tenant_id)

    # No license = no enforcement (self-hosted without license)
    if metadata is None:
        return SeatAvailabilityResult(available=True)

    # Calculate current usage directly from DB (not cache) for accuracy
    current_used = get_used_seats(tenant_id)
    total_seats = metadata.seats

    # Use > (not >=) to allow filling to exactly 100% capacity
    would_exceed_limit = current_used + seats_needed > total_seats
    if would_exceed_limit:
        return SeatAvailabilityResult(
            available=False,
            error_message=f"Seat limit would be exceeded: {current_used} of {total_seats} seats used, "
            f"cannot add {seats_needed} more user(s).",
        )

    return SeatAvailabilityResult(available=True)
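A minimal usage sketch (not part of the diff) of how a caller might gate user creation on the new seat check; it assumes an open SQLAlchemy session, and the invite_user helper is hypothetical:

    # Sketch: refuse to add a user when check_seat_availability reports no room
    from ee.onyx.db.license import check_seat_availability

    def try_invite_user(db_session, email: str) -> None:
        result = check_seat_availability(db_session, seats_needed=1)
        if not result.available:
            # error_message comes straight from SeatAvailabilityResult
            raise RuntimeError(result.error_message or "Seat limit reached")
        invite_user(db_session, email)  # hypothetical helper, not in the diff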
@@ -7,6 +7,7 @@ from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFun
from onyx.access.models import DocExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -60,6 +61,9 @@ def gmail_doc_sync(

        callback.progress("gmail_doc_sync", 1)

        if isinstance(slim_doc, HierarchyNode):
            # TODO: handle hierarchynodes during sync
            continue
        if slim_doc.external_access is None:
            logger.warning(f"No permissions found for document {slim_doc.id}")
            continue

@@ -15,6 +15,7 @@ from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -195,7 +196,9 @@ def gdrive_doc_sync(
            raise RuntimeError("gdrive_doc_sync: Stop signal detected")

        callback.progress("gdrive_doc_sync", 1)

        if isinstance(slim_doc, HierarchyNode):
            # TODO: handle hierarchynodes during sync
            continue
        if slim_doc.external_access is None:
            raise ValueError(
                f"Drive perm sync: No external access for document {slim_doc.id}"

@@ -8,6 +8,7 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.models import HierarchyNode
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call
from onyx.connectors.slack.connector import SlackConnector
@@ -111,6 +112,9 @@ def _get_slack_document_access(

    for doc_metadata_batch in slim_doc_generator:
        for doc_metadata in doc_metadata_batch:
            if isinstance(doc_metadata, HierarchyNode):
                # TODO: handle hierarchynodes during sync
                continue
            if doc_metadata.external_access is None:
                raise ValueError(
                    f"No external access for document {doc_metadata.id}. "

@@ -5,6 +5,7 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -49,6 +50,9 @@ def generic_doc_sync(
        callback.progress(label, 1)

        for doc in doc_batch:
            if isinstance(doc, HierarchyNode):
                # TODO: handle hierarchynodes during sync
                continue
            if not doc.external_access:
                raise RuntimeError(
                    f"No external access found for document ID; {cc_pair.id=} {doc_source=} {doc.id=}"
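Each of the permission-sync hunks above adds the same guard before reading external_access. A standalone sketch of that shared pattern (not part of the diff; the generator name is illustrative):

    # Sketch: skip hierarchy nodes, require external_access on real slim docs
    from onyx.connectors.models import HierarchyNode

    def iter_permissioned_docs(doc_batch):
        for doc in doc_batch:
            if isinstance(doc, HierarchyNode):
                # hierarchy nodes carry structure, not permissions; skipped for now
                continue
            if doc.external_access is None:
                raise ValueError(f"No external access for document {doc.id}")
            yield doc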
@@ -4,8 +4,10 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2

from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
from ee.onyx.server.billing.api import router as billing_router
from ee.onyx.server.documents.cc_pair import router as ee_document_cc_pair_router
from ee.onyx.server.enterprise_settings.api import (
    admin_router as enterprise_settings_admin_router,
@@ -85,10 +87,11 @@ def get_application() -> FastAPI:

    if MULTI_TENANT:
        add_api_server_tenant_id_middleware(application, logger)

        # Add license enforcement middleware (runs after tenant tracking)
        # This blocks access when license is expired/gated
        add_license_enforcement_middleware(application, logger)
    else:
        # License enforcement middleware for self-hosted deployments only
        # Checks LICENSE_ENFORCEMENT_ENABLED at runtime (can be toggled without restart)
        # MT deployments use control plane gating via is_tenant_gated() instead
        add_license_enforcement_middleware(application, logger)

    if AUTH_TYPE == AuthType.CLOUD:
        # For Google OAuth, refresh tokens are requested by:
@@ -148,6 +151,13 @@ def get_application() -> FastAPI:
    # License management
    include_router_with_global_prefix_prepended(application, license_router)

    # Unified billing API - available when license system is enabled
    # Works for both self-hosted and cloud deployments
    # TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
    # primary billing API and /tenants/* billing endpoints can be removed
    if LICENSE_ENFORCEMENT_ENABLED:
        include_router_with_global_prefix_prepended(application, billing_router)

    if MULTI_TENANT:
        # Tenant management
        include_router_with_global_prefix_prepended(application, tenants_router)

@@ -12,6 +12,14 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
    ("/enterprise-settings/custom-analytics-script", {"GET"}),
    # Stripe publishable key is safe to expose publicly
    ("/tenants/stripe-publishable-key", {"GET"}),
    ("/admin/billing/stripe-publishable-key", {"GET"}),
    # Proxy endpoints use license-based auth, not user auth
    ("/proxy/create-checkout-session", {"POST"}),
    ("/proxy/claim-license", {"POST"}),
    ("/proxy/create-customer-portal-session", {"POST"}),
    ("/proxy/billing-information", {"GET"}),
    ("/proxy/license/{tenant_id}", {"GET"}),
    ("/proxy/seats/update", {"POST"}),
]

0  backend/ee/onyx/server/billing/__init__.py  Normal file
264  backend/ee/onyx/server/billing/api.py  Normal file
@@ -0,0 +1,264 @@
"""Unified Billing API endpoints.

These endpoints provide Stripe billing functionality for both cloud and
self-hosted deployments. The service layer routes requests appropriately:

- Self-hosted: Routes through cloud data plane proxy
  Flow: Backend /admin/billing/* → Cloud DP /proxy/* → Control plane

- Cloud (MULTI_TENANT): Routes directly to control plane
  Flow: Backend /admin/billing/* → Control plane

License claiming is handled separately by /license/claim endpoint (self-hosted only).

Migration Note (ENG-3533):
This /admin/billing/* API replaces the older /tenants/* billing endpoints:
- /tenants/billing-information -> /admin/billing/billing-information
- /tenants/create-customer-portal-session -> /admin/billing/create-customer-portal-session
- /tenants/create-subscription-session -> /admin/billing/create-checkout-session
- /tenants/stripe-publishable-key -> /admin/billing/stripe-publishable-key

See: https://linear.app/onyx-app/issue/ENG-3533/migrate-tenantsbilling-adminbilling
"""

import asyncio

import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from ee.onyx.auth.users import current_admin_user
from ee.onyx.db.license import get_license
from ee.onyx.server.billing.models import BillingInformationResponse
from ee.onyx.server.billing.models import CreateCheckoutSessionRequest
from ee.onyx.server.billing.models import CreateCheckoutSessionResponse
from ee.onyx.server.billing.models import CreateCustomerPortalSessionRequest
from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse
from ee.onyx.server.billing.models import SeatUpdateRequest
from ee.onyx.server.billing.models import SeatUpdateResponse
from ee.onyx.server.billing.models import StripePublishableKeyResponse
from ee.onyx.server.billing.models import SubscriptionStatusResponse
from ee.onyx.server.billing.service import BillingServiceError
from ee.onyx.server.billing.service import (
    create_checkout_session as create_checkout_service,
)
from ee.onyx.server.billing.service import (
    create_customer_portal_session as create_portal_service,
)
from ee.onyx.server.billing.service import (
    get_billing_information as get_billing_service,
)
from ee.onyx.server.billing.service import update_seat_count as update_seat_service
from onyx.auth.users import User
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.engine.sql_engine import get_session
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

router = APIRouter(prefix="/admin/billing")

# Cache for Stripe publishable key to avoid hitting S3 on every request
_stripe_publishable_key_cache: str | None = None
_stripe_key_lock = asyncio.Lock()


def _get_license_data(db_session: Session) -> str | None:
    """Get license data from database if exists (self-hosted only)."""
    if MULTI_TENANT:
        return None
    license_record = get_license(db_session)
    return license_record.license_data if license_record else None


def _get_tenant_id() -> str | None:
    """Get tenant ID for cloud deployments."""
    if MULTI_TENANT:
        return get_current_tenant_id()
    return None


@router.post("/create-checkout-session")
async def create_checkout_session(
    request: CreateCheckoutSessionRequest | None = None,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> CreateCheckoutSessionResponse:
    """Create a Stripe checkout session for new subscription or renewal.

    For new customers, no license/tenant is required.
    For renewals, existing license (self-hosted) or tenant_id (cloud) is used.

    After checkout completion:
    - Self-hosted: Use /license/claim to retrieve the license
    - Cloud: Subscription is automatically activated
    """
    license_data = _get_license_data(db_session)
    tenant_id = _get_tenant_id()
    billing_period = request.billing_period if request else "monthly"
    email = request.email if request else None

    # Build redirect URL for after checkout completion
    redirect_url = f"{WEB_DOMAIN}/admin/billing?checkout=success"

    try:
        return await create_checkout_service(
            billing_period=billing_period,
            email=email,
            license_data=license_data,
            redirect_url=redirect_url,
            tenant_id=tenant_id,
        )
    except BillingServiceError as e:
        raise HTTPException(status_code=e.status_code, detail=e.message)


@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
    request: CreateCustomerPortalSessionRequest | None = None,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> CreateCustomerPortalSessionResponse:
    """Create a Stripe customer portal session for managing subscription.

    Requires existing license (self-hosted) or active tenant (cloud).
    """
    license_data = _get_license_data(db_session)
    tenant_id = _get_tenant_id()

    # Self-hosted requires license
    if not MULTI_TENANT and not license_data:
        raise HTTPException(status_code=400, detail="No license found")

    return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing"

    try:
        return await create_portal_service(
            license_data=license_data,
            return_url=return_url,
            tenant_id=tenant_id,
        )
    except BillingServiceError as e:
        raise HTTPException(status_code=e.status_code, detail=e.message)


@router.get("/billing-information")
async def get_billing_information(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> BillingInformationResponse | SubscriptionStatusResponse:
    """Get billing information for the current subscription.

    Returns subscription status and details from Stripe.
    """
    license_data = _get_license_data(db_session)
    tenant_id = _get_tenant_id()

    # Self-hosted without license = no subscription
    if not MULTI_TENANT and not license_data:
        return SubscriptionStatusResponse(subscribed=False)

    try:
        return await get_billing_service(
            license_data=license_data,
            tenant_id=tenant_id,
        )
    except BillingServiceError as e:
        raise HTTPException(status_code=e.status_code, detail=e.message)


@router.post("/seats/update")
async def update_seats(
    request: SeatUpdateRequest,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> SeatUpdateResponse:
    """Update the seat count for the current subscription.

    Handles Stripe proration and license regeneration via control plane.
    """
    license_data = _get_license_data(db_session)
    tenant_id = _get_tenant_id()

    # Self-hosted requires license
    if not MULTI_TENANT and not license_data:
        raise HTTPException(status_code=400, detail="No license found")

    try:
        return await update_seat_service(
            new_seat_count=request.new_seat_count,
            license_data=license_data,
            tenant_id=tenant_id,
        )
    except BillingServiceError as e:
        raise HTTPException(status_code=e.status_code, detail=e.message)
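A sketch (not part of the diff) of how an admin client might drive the endpoints defined above; the API prefix and session cookie name are assumptions for illustration only:

    # Sketch: check subscription state, then start checkout if unsubscribed
    import httpx

    BASE = "http://localhost:8080/api"  # assumed global API prefix
    cookies = {"session": "<admin session cookie>"}  # assumed auth mechanism

    with httpx.Client(base_url=BASE, cookies=cookies) as client:
        info = client.get("/admin/billing/billing-information").json()
        if info.get("subscribed") is False:
            checkout = client.post(
                "/admin/billing/create-checkout-session",
                json={"billing_period": "annual", "email": "admin@example.com"},
            ).json()
            print("Open checkout:", checkout["stripe_checkout_url"])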
@router.get("/stripe-publishable-key")
async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
    """Fetch the Stripe publishable key.

    Priority: env var override (for testing) > S3 bucket (production).
    This endpoint is public (no auth required) since publishable keys are safe to expose.
    The key is cached in memory to avoid hitting S3 on every request.
    """
    global _stripe_publishable_key_cache

    # Fast path: return cached value without lock
    if _stripe_publishable_key_cache:
        return StripePublishableKeyResponse(
            publishable_key=_stripe_publishable_key_cache
        )

    # Use lock to prevent concurrent S3 requests
    async with _stripe_key_lock:
        # Double-check after acquiring lock (another request may have populated cache)
        if _stripe_publishable_key_cache:
            return StripePublishableKeyResponse(
                publishable_key=_stripe_publishable_key_cache
            )

        # Check for env var override first (for local testing with pk_test_* keys)
        if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
            key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
            if not key.startswith("pk_"):
                raise HTTPException(
                    status_code=500,
                    detail="Invalid Stripe publishable key format",
                )
            _stripe_publishable_key_cache = key
            return StripePublishableKeyResponse(publishable_key=key)

        # Fall back to S3 bucket
        if not STRIPE_PUBLISHABLE_KEY_URL:
            raise HTTPException(
                status_code=500,
                detail="Stripe publishable key is not configured",
            )

        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(STRIPE_PUBLISHABLE_KEY_URL)
                response.raise_for_status()
                key = response.text.strip()

                # Validate key format
                if not key.startswith("pk_"):
                    raise HTTPException(
                        status_code=500,
                        detail="Invalid Stripe publishable key format",
                    )

                _stripe_publishable_key_cache = key
                return StripePublishableKeyResponse(publishable_key=key)
        except httpx.HTTPError:
            raise HTTPException(
                status_code=500,
                detail="Failed to fetch Stripe publishable key",
            )

75  backend/ee/onyx/server/billing/models.py  Normal file
@@ -0,0 +1,75 @@
"""Pydantic models for the billing API."""

from datetime import datetime
from typing import Literal

from pydantic import BaseModel


class CreateCheckoutSessionRequest(BaseModel):
    """Request to create a Stripe checkout session."""

    billing_period: Literal["monthly", "annual"] = "monthly"
    email: str | None = None


class CreateCheckoutSessionResponse(BaseModel):
    """Response containing the Stripe checkout session URL."""

    stripe_checkout_url: str


class CreateCustomerPortalSessionRequest(BaseModel):
    """Request to create a Stripe customer portal session."""

    return_url: str | None = None


class CreateCustomerPortalSessionResponse(BaseModel):
    """Response containing the Stripe customer portal URL."""

    stripe_customer_portal_url: str


class BillingInformationResponse(BaseModel):
    """Billing information for the current subscription."""

    tenant_id: str
    status: str | None = None
    plan_type: str | None = None
    seats: int | None = None
    billing_period: str | None = None
    current_period_start: datetime | None = None
    current_period_end: datetime | None = None
    cancel_at_period_end: bool = False
    canceled_at: datetime | None = None
    trial_start: datetime | None = None
    trial_end: datetime | None = None
    payment_method_enabled: bool = False


class SubscriptionStatusResponse(BaseModel):
    """Response when no subscription exists."""

    subscribed: bool = False


class SeatUpdateRequest(BaseModel):
    """Request to update seat count."""

    new_seat_count: int


class SeatUpdateResponse(BaseModel):
    """Response from seat update operation."""

    success: bool
    current_seats: int
    used_seats: int
    message: str | None = None


class StripePublishableKeyResponse(BaseModel):
    """Response containing the Stripe publishable key."""

    publishable_key: str

267  backend/ee/onyx/server/billing/service.py  Normal file
@@ -0,0 +1,267 @@
"""Service layer for billing operations.

This module provides functions for billing operations that route differently
based on deployment type:

- Self-hosted (not MULTI_TENANT): Routes through cloud data plane proxy
  Flow: Self-hosted backend → Cloud DP /proxy/* → Control plane

- Cloud (MULTI_TENANT): Routes directly to control plane
  Flow: Cloud backend → Control plane
"""

from typing import Literal

import httpx

from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL
from ee.onyx.server.billing.models import BillingInformationResponse
from ee.onyx.server.billing.models import CreateCheckoutSessionResponse
from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse
from ee.onyx.server.billing.models import SeatUpdateResponse
from ee.onyx.server.billing.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.access import generate_data_plane_token
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()

# HTTP request timeout for billing service calls
_REQUEST_TIMEOUT = 30.0


class BillingServiceError(Exception):
    """Exception raised for billing service errors."""

    def __init__(self, message: str, status_code: int = 500):
        self.message = message
        self.status_code = status_code
        super().__init__(self.message)


def _get_proxy_headers(license_data: str | None) -> dict[str, str]:
    """Build headers for proxy requests (self-hosted).

    Self-hosted instances authenticate with their license.
    """
    headers = {"Content-Type": "application/json"}
    if license_data:
        headers["Authorization"] = f"Bearer {license_data}"
    return headers


def _get_direct_headers() -> dict[str, str]:
    """Build headers for direct control plane requests (cloud).

    Cloud instances authenticate with JWT.
    """
    token = generate_data_plane_token()
    return {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}",
    }


def _get_base_url() -> str:
    """Get the base URL based on deployment type."""
    if MULTI_TENANT:
        return CONTROL_PLANE_API_BASE_URL
    return f"{CLOUD_DATA_PLANE_URL}/proxy"


def _get_headers(license_data: str | None) -> dict[str, str]:
    """Get appropriate headers based on deployment type."""
    if MULTI_TENANT:
        return _get_direct_headers()
    return _get_proxy_headers(license_data)


async def _make_billing_request(
    method: Literal["GET", "POST"],
    path: str,
    license_data: str | None = None,
    body: dict | None = None,
    params: dict | None = None,
    error_message: str = "Billing service request failed",
) -> dict:
    """Make an HTTP request to the billing service.

    Consolidates the common HTTP request pattern used by all billing operations.

    Args:
        method: HTTP method (GET or POST)
        path: URL path (appended to base URL)
        license_data: License for authentication (self-hosted)
        body: Request body for POST requests
        params: Query parameters for GET requests
        error_message: Default error message if request fails

    Returns:
        Response JSON as dict

    Raises:
        BillingServiceError: If request fails
    """
    base_url = _get_base_url()
    url = f"{base_url}{path}"
    headers = _get_headers(license_data)

    try:
        async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT) as client:
            if method == "GET":
                response = await client.get(url, headers=headers, params=params)
            else:
                response = await client.post(url, headers=headers, json=body)

            response.raise_for_status()
            return response.json()

    except httpx.HTTPStatusError as e:
        detail = error_message
        try:
            error_data = e.response.json()
            detail = error_data.get("detail", detail)
        except Exception:
            pass
        logger.error(f"{error_message}: {e.response.status_code} - {detail}")
        raise BillingServiceError(detail, e.response.status_code)

    except httpx.RequestError:
        logger.exception("Failed to connect to billing service")
        raise BillingServiceError("Failed to connect to billing service", 502)


async def create_checkout_session(
    billing_period: str = "monthly",
    email: str | None = None,
    license_data: str | None = None,
    redirect_url: str | None = None,
    tenant_id: str | None = None,
) -> CreateCheckoutSessionResponse:
    """Create a Stripe checkout session.

    Args:
        billing_period: "monthly" or "annual"
        email: Customer email for new subscriptions
        license_data: Existing license for renewals (self-hosted)
        redirect_url: URL to redirect after successful checkout
        tenant_id: Tenant ID (cloud only, for renewals)

    Returns:
        CreateCheckoutSessionResponse with checkout URL
    """
    body: dict = {"billing_period": billing_period}
    if email:
        body["email"] = email
    if redirect_url:
        body["redirect_url"] = redirect_url
    if tenant_id and MULTI_TENANT:
        body["tenant_id"] = tenant_id

    data = await _make_billing_request(
        method="POST",
        path="/create-checkout-session",
        license_data=license_data,
        body=body,
        error_message="Failed to create checkout session",
    )
    return CreateCheckoutSessionResponse(stripe_checkout_url=data["url"])


async def create_customer_portal_session(
    license_data: str | None = None,
    return_url: str | None = None,
    tenant_id: str | None = None,
) -> CreateCustomerPortalSessionResponse:
    """Create a Stripe customer portal session.

    Args:
        license_data: License blob for authentication (self-hosted)
        return_url: URL to return to after portal session
        tenant_id: Tenant ID (cloud only)

    Returns:
        CreateCustomerPortalSessionResponse with portal URL
    """
    body: dict = {}
    if return_url:
        body["return_url"] = return_url
    if tenant_id and MULTI_TENANT:
        body["tenant_id"] = tenant_id

    data = await _make_billing_request(
        method="POST",
        path="/create-customer-portal-session",
        license_data=license_data,
        body=body,
        error_message="Failed to create customer portal session",
    )
    return CreateCustomerPortalSessionResponse(stripe_customer_portal_url=data["url"])


async def get_billing_information(
    license_data: str | None = None,
    tenant_id: str | None = None,
) -> BillingInformationResponse | SubscriptionStatusResponse:
    """Fetch billing information.

    Args:
        license_data: License blob for authentication (self-hosted)
        tenant_id: Tenant ID (cloud only)

    Returns:
        BillingInformationResponse or SubscriptionStatusResponse if no subscription
    """
    params = {}
    if tenant_id and MULTI_TENANT:
        params["tenant_id"] = tenant_id

    data = await _make_billing_request(
        method="GET",
        path="/billing-information",
        license_data=license_data,
        params=params or None,
        error_message="Failed to fetch billing information",
    )

    # Check if no subscription
    if isinstance(data, dict) and data.get("subscribed") is False:
        return SubscriptionStatusResponse(subscribed=False)

    return BillingInformationResponse(**data)


async def update_seat_count(
    new_seat_count: int,
    license_data: str | None = None,
    tenant_id: str | None = None,
) -> SeatUpdateResponse:
    """Update the seat count for the current subscription.

    Args:
        new_seat_count: New number of seats
        license_data: License blob for authentication (self-hosted)
        tenant_id: Tenant ID (cloud only)

    Returns:
        SeatUpdateResponse with updated seat information
    """
    body: dict = {"new_seat_count": new_seat_count}
    if tenant_id and MULTI_TENANT:
        body["tenant_id"] = tenant_id

    data = await _make_billing_request(
        method="POST",
        path="/seats/update",
        license_data=license_data,
        body=body,
        error_message="Failed to update seat count",
    )

    return SeatUpdateResponse(
        success=data.get("success", False),
        current_seats=data.get("current_seats", 0),
        used_seats=data.get("used_seats", 0),
        message=data.get("message"),
    )

@@ -1,4 +1,14 @@
"""License API endpoints."""
"""License API endpoints for self-hosted deployments.

These endpoints allow self-hosted Onyx instances to:
1. Claim a license after Stripe checkout (via cloud data plane proxy)
2. Upload a license file manually (for air-gapped deployments)
3. View license status and seat usage
4. Refresh/delete the local license

NOTE: Cloud (MULTI_TENANT) deployments do NOT use these endpoints.
Cloud licensing is managed via the control plane and gated_tenants Redis key.
"""

import requests
from fastapi import APIRouter
@@ -9,6 +19,7 @@ from fastapi import UploadFile
from sqlalchemy.orm import Session

from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL
from ee.onyx.db.license import delete_license as db_delete_license
from ee.onyx.db.license import get_license_metadata
from ee.onyx.db.license import invalidate_license_cache
@@ -20,13 +31,11 @@ from ee.onyx.server.license.models import LicenseSource
from ee.onyx.server.license.models import LicenseStatusResponse
from ee.onyx.server.license.models import LicenseUploadResponse
from ee.onyx.server.license.models import SeatUsageResponse
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.utils.license import verify_license_signature
from onyx.auth.users import User
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.db.engine.sql_engine import get_session
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()

@@ -79,81 +88,80 @@ async def get_seat_usage(
    )


@router.post("/fetch")
async def fetch_license(
@router.post("/claim")
async def claim_license(
    session_id: str,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> LicenseResponse:
    """
    Fetch license from control plane.
    Used after Stripe checkout completion to retrieve the new license.
    """
    tenant_id = get_current_tenant_id()
    Claim a license after Stripe checkout (self-hosted only).

    try:
        token = generate_data_plane_token()
    except ValueError as e:
        logger.error(f"Failed to generate data plane token: {e}")
    After a user completes Stripe checkout, they're redirected back with a
    session_id. This endpoint exchanges that session_id for a signed license
    via the cloud data plane proxy.

    Flow:
    1. Self-hosted frontend redirects to Stripe checkout (via cloud proxy)
    2. User completes payment
    3. Stripe redirects back to self-hosted instance with session_id
    4. Frontend calls this endpoint with session_id
    5. We call cloud data plane /proxy/claim-license to get the signed license
    6. License is stored locally and cached
    """
    if MULTI_TENANT:
        raise HTTPException(
            status_code=500, detail="Authentication configuration error"
            status_code=400,
            detail="License claiming is only available for self-hosted deployments",
        )

    try:
        headers = {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }
        url = f"{CONTROL_PLANE_API_BASE_URL}/license/{tenant_id}"
        response = requests.get(url, headers=headers, timeout=10)
        # Call cloud data plane to claim the license
        url = f"{CLOUD_DATA_PLANE_URL}/proxy/claim-license"
        response = requests.post(
            url,
            json={"session_id": session_id},
            headers={"Content-Type": "application/json"},
            timeout=30,
        )
        response.raise_for_status()

        data = response.json()
        if not isinstance(data, dict) or "license" not in data:
            raise HTTPException(
                status_code=502, detail="Invalid response from control plane"
            )
        license_data = data.get("license")

        license_data = data["license"]
        if not license_data:
            raise HTTPException(status_code=404, detail="No license found")
            raise HTTPException(status_code=404, detail="No license in response")

        # Verify signature before persisting
        payload = verify_license_signature(license_data)

        # Verify the fetched license is for this tenant
        if payload.tenant_id != tenant_id:
            logger.error(
                f"License tenant mismatch: expected {tenant_id}, got {payload.tenant_id}"
            )
            raise HTTPException(
                status_code=400,
                detail="License tenant ID mismatch - control plane returned wrong license",
            )

        # Persist to DB and update cache atomically
        # Store in DB
        upsert_license(db_session, license_data)

        try:
            update_license_cache(payload, source=LicenseSource.AUTO_FETCH)
        except Exception as cache_error:
            # Log but don't fail - DB is source of truth, cache will refresh on next read
            logger.warning(f"Failed to update license cache: {cache_error}")

        logger.info(
            f"License claimed: seats={payload.seats}, expires={payload.expires_at.date()}"
        )
        return LicenseResponse(success=True, license=payload)

    except requests.HTTPError as e:
        status_code = e.response.status_code if e.response is not None else 502
        logger.error(f"Control plane returned error: {status_code}")
        raise HTTPException(
            status_code=status_code,
            detail="Failed to fetch license from control plane",
        )
        detail = "Failed to claim license"
        try:
            error_data = e.response.json() if e.response is not None else {}
            detail = error_data.get("detail", detail)
        except Exception:
            pass
        raise HTTPException(status_code=status_code, detail=detail)
    except ValueError as e:
        logger.error(f"License verification failed: {type(e).__name__}")
        raise HTTPException(status_code=400, detail=str(e))
    except requests.RequestException:
        logger.exception("Failed to fetch license from control plane")
        raise HTTPException(
            status_code=502, detail="Failed to connect to control plane"
            status_code=502, detail="Failed to connect to license server"
        )


@@ -164,33 +172,36 @@ async def upload_license(
    db_session: Session = Depends(get_session),
) -> LicenseUploadResponse:
    """
    Upload a license file manually.
    Used for air-gapped deployments where control plane is not accessible.
    Upload a license file manually (self-hosted only).

    Used for air-gapped deployments where the cloud data plane is not accessible.
    The license file must be cryptographically signed by Onyx.
    """
    if MULTI_TENANT:
        raise HTTPException(
            status_code=400,
            detail="License upload is only available for self-hosted deployments",
        )

    try:
        content = await license_file.read()
        license_data = content.decode("utf-8").strip()
    except UnicodeDecodeError:
        raise HTTPException(status_code=400, detail="Invalid license file format")

    # Verify cryptographic signature - this is the only validation needed
    # The license's tenant_id identifies the customer in control plane, not locally
    try:
        payload = verify_license_signature(license_data)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    tenant_id = get_current_tenant_id()
    if payload.tenant_id != tenant_id:
        raise HTTPException(
            status_code=400,
            detail=f"License tenant ID mismatch. Expected {tenant_id}, got {payload.tenant_id}",
        )

    # Persist to DB and update cache
    upsert_license(db_session, license_data)

    try:
        update_license_cache(payload, source=LicenseSource.MANUAL_UPLOAD)
    except Exception as cache_error:
        # Log but don't fail - DB is source of truth, cache will refresh on next read
        logger.warning(f"Failed to update license cache: {cache_error}")

    return LicenseUploadResponse(
@@ -205,8 +216,10 @@ async def refresh_license_cache_endpoint(
    db_session: Session = Depends(get_session),
) -> LicenseStatusResponse:
    """
    Force refresh the license cache from the database.
    Force refresh the license cache from the local database.

    Useful after manual database changes or to verify license validity.
    Does NOT fetch from control plane - use /claim for that.
    """
    metadata = refresh_license_cache(db_session)

@@ -233,9 +246,15 @@ async def delete_license(
) -> dict[str, bool]:
    """
    Delete the current license.
    Admin only - removes license and invalidates cache.

    Admin only - removes license from database and invalidates cache.
    """
    # Invalidate cache first - if DB delete fails, stale cache is worse than no cache
    if MULTI_TENANT:
        raise HTTPException(
            status_code=400,
            detail="License deletion is only available for self-hosted deployments",
        )

    try:
        invalidate_license_cache()
    except Exception as cache_error:
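A sketch (not part of the diff) of the client side of the /license/claim flow described above, run after Stripe redirects back with a session_id; the URL prefix, cookie name, and response field names are assumptions:

    # Sketch: exchange the Stripe session_id for a signed license
    import httpx

    def claim_license_after_checkout(session_id: str) -> None:
        resp = httpx.post(
            "http://localhost:8080/api/license/claim",  # assumed prefix
            params={"session_id": session_id},
            cookies={"session": "<admin session cookie>"},  # assumed auth
        )
        resp.raise_for_status()
        payload = resp.json().get("license", {})
        print("seats:", payload.get("seats"), "expires:", payload.get("expires_at"))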
@@ -1,4 +1,42 @@
"""Middleware to enforce license status application-wide."""
"""Middleware to enforce license status for SELF-HOSTED deployments only.

NOTE: This middleware is NOT used for multi-tenant (cloud) deployments.
Multi-tenant gating is handled separately by the control plane via the
/tenants/product-gating endpoint and is_tenant_gated() checks.

IMPORTANT: Mutual Exclusivity with ENTERPRISE_EDITION_ENABLED
============================================================
This middleware is controlled by LICENSE_ENFORCEMENT_ENABLED env var.
It works alongside the legacy ENTERPRISE_EDITION_ENABLED system:

- LICENSE_ENFORCEMENT_ENABLED=false (default):
  Middleware is disabled. EE features are controlled solely by
  ENTERPRISE_EDITION_ENABLED. This preserves legacy behavior.

- LICENSE_ENFORCEMENT_ENABLED=true:
  Middleware actively enforces license status. EE features require
  a valid license, regardless of ENTERPRISE_EDITION_ENABLED.

Eventually, ENTERPRISE_EDITION_ENABLED will be removed and license
enforcement will be the only mechanism for gating EE features.

License Enforcement States (when enabled)
=========================================
For self-hosted deployments:

1. No license (never subscribed):
   - Allow community features (basic connectors, search, chat)
   - Block EE-only features (analytics, user groups, etc.)

2. GATED_ACCESS (fully expired):
   - Block all routes except billing/auth/license
   - User must renew subscription to continue

3. Valid license (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER):
   - Full access to all EE features
   - Seat limits enforced
   - GRACE_PERIOD/PAYMENT_REMINDER are for notifications only, not blocking
"""

import logging
from collections.abc import Awaitable
@@ -9,38 +47,30 @@ from fastapi import Request
from fastapi import Response
from fastapi.responses import JSONResponse
from redis.exceptions import RedisError
from sqlalchemy.exc import SQLAlchemyError

from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.configs.license_enforcement_config import EE_ONLY_PATH_PREFIXES
from ee.onyx.configs.license_enforcement_config import (
    LICENSE_ENFORCEMENT_ALLOWED_PREFIXES,
)
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.server.tenants.product_gating import is_tenant_gated
from ee.onyx.db.license import refresh_license_cache
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.settings.models import ApplicationStatus
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id

# Paths that are ALWAYS accessible, even when license is expired/gated.
# These enable users to:
# /auth - Log in/out (users can't fix billing if locked out of auth)
# /license - Fetch, upload, or check license status
# /health - Health checks for load balancers/orchestrators
# /me - Basic user info needed for UI rendering
# /settings, /enterprise-settings - View app status and branding
# /tenants/billing-* - Manage subscription to resolve gating
ALLOWED_PATH_PREFIXES = {
    "/auth",
    "/license",
    "/health",
    "/me",
    "/settings",
    "/enterprise-settings",
    "/tenants/billing-information",
    "/tenants/create-customer-portal-session",
    "/tenants/create-subscription-session",
}


def _is_path_allowed(path: str) -> bool:
    """Check if path is in allowlist (prefix match)."""
    return any(path.startswith(prefix) for prefix in ALLOWED_PATH_PREFIXES)
    return any(
        path.startswith(prefix) for prefix in LICENSE_ENFORCEMENT_ALLOWED_PREFIXES
    )


def _is_ee_only_path(path: str) -> bool:
    """Check if path requires EE license (prefix match)."""
    return any(path.startswith(prefix) for prefix in EE_ONLY_PATH_PREFIXES)


def add_license_enforcement_middleware(
@@ -66,29 +96,84 @@ def add_license_enforcement_middleware(
        is_gated = False
        tenant_id = get_current_tenant_id()

        if MULTI_TENANT:
            try:
                is_gated = is_tenant_gated(tenant_id)
            except RedisError as e:
                logger.warning(f"Failed to check tenant gating status: {e}")
                # Fail open - don't block users due to Redis connectivity issues
                is_gated = False
        else:
            try:
                metadata = get_cached_license_metadata(tenant_id)
                if metadata:
                    if metadata.status == ApplicationStatus.GATED_ACCESS:
                        is_gated = True
                else:
                    # No license metadata = gated for self-hosted EE
        try:
            metadata = get_cached_license_metadata(tenant_id)

            # If no cached metadata, check database (cache may have been cleared)
            if not metadata:
                logger.debug(
                    "[license_enforcement] No cached license, checking database..."
                )
                try:
                    with get_session_with_current_tenant() as db_session:
                        metadata = refresh_license_cache(db_session, tenant_id)
                        if metadata:
                            logger.info(
                                "[license_enforcement] Loaded license from database"
                            )
                except SQLAlchemyError as db_error:
                    logger.warning(
                        f"[license_enforcement] Failed to check database for license: {db_error}"
                    )

            if metadata:
                # User HAS a license (current or expired)
                if metadata.status == ApplicationStatus.GATED_ACCESS:
                    # License fully expired - gate the user
                    # Note: GRACE_PERIOD and PAYMENT_REMINDER are for notifications only,
                    # they don't block access
                    is_gated = True
            except RedisError as e:
                logger.warning(f"Failed to check license metadata: {e}")
                # Fail open - don't block users due to Redis connectivity issues
                else:
                    # License is active - check seat limit
                    # used_seats in cache is kept accurate via invalidation
                    # when users are added/removed
                    if metadata.used_seats > metadata.seats:
                        logger.info(
                            f"[license_enforcement] Blocking request: "
                            f"seat limit exceeded ({metadata.used_seats}/{metadata.seats})"
                        )
                        return JSONResponse(
                            status_code=402,
                            content={
                                "detail": {
                                    "error": "seat_limit_exceeded",
                                    "message": f"Seat limit exceeded: {metadata.used_seats} of {metadata.seats} seats used.",
                                    "used_seats": metadata.used_seats,
                                    "seats": metadata.seats,
                                }
                            },
                        )
            else:
                # No license in cache OR database = never subscribed
                # Allow community features, but block EE-only features
                if _is_ee_only_path(path):
                    logger.info(
                        f"[license_enforcement] Blocking EE-only path (no license): {path}"
                    )
                    return JSONResponse(
                        status_code=402,
                        content={
                            "detail": {
                                "error": "enterprise_license_required",
                                "message": "This feature requires an Enterprise license. "
                                "Please upgrade to access this functionality.",
                            }
                        },
                    )
                logger.debug(
                    "[license_enforcement] No license, allowing community features"
                )
                is_gated = False
        except RedisError as e:
            logger.warning(f"Failed to check license metadata: {e}")
            # Fail open - don't block users due to Redis connectivity issues
            is_gated = False

        if is_gated:
            logger.info(f"Blocking request for gated tenant: {tenant_id}, path={path}")
            logger.info(
                f"[license_enforcement] Blocking request (license expired): {path}"
            )

            return JSONResponse(
                status_code=402,
                content={
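A sketch (not part of the diff) of how an API client might interpret the 402 payloads the middleware above produces; the error keys are taken from the responses shown in the hunk:

    # Sketch: translate a license-enforcement 402 into a human-readable reason
    import httpx

    def explain_enforcement_error(resp: httpx.Response) -> str:
        if resp.status_code != 402:
            return "not a license-enforcement block"
        detail = resp.json().get("detail", {})
        error = detail.get("error")
        if error == "seat_limit_exceeded":
            return f"Over seat limit: {detail.get('used_seats')}/{detail.get('seats')}"
        if error == "enterprise_license_required":
            return "EE-only endpoint called without a license"
        return detail.get("message", "license expired or gated")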
@@ -12,21 +12,51 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Statuses that indicate a billing/license problem - propagate these to settings
|
||||
_GATED_STATUSES = frozenset(
|
||||
{
|
||||
ApplicationStatus.GATED_ACCESS,
|
||||
ApplicationStatus.GRACE_PERIOD,
|
||||
ApplicationStatus.PAYMENT_REMINDER,
|
||||
}
|
||||
)
|
||||
# Only GATED_ACCESS actually blocks access - other statuses are for notifications
|
||||
_BLOCKING_STATUS = ApplicationStatus.GATED_ACCESS
|
||||
|
||||
|
||||
def check_ee_features_enabled() -> bool:
|
||||
"""EE version: checks if EE features should be available.
|
||||
|
||||
Returns True if:
|
||||
- LICENSE_ENFORCEMENT_ENABLED is False (legacy/rollout mode)
|
||||
- Cloud mode (MULTI_TENANT) - cloud handles its own gating
|
||||
- Self-hosted with a valid (non-expired) license
|
||||
|
||||
Returns False if:
|
||||
- Self-hosted with no license (never subscribed)
|
||||
- Self-hosted with expired license
|
||||
"""
|
||||
if not LICENSE_ENFORCEMENT_ENABLED:
|
||||
# License enforcement disabled - allow EE features (legacy behavior)
|
||||
return True
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Cloud mode - EE features always available (gating handled by is_tenant_gated)
|
||||
return True
|
||||
|
||||
# Self-hosted with enforcement - check for valid license
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if metadata and metadata.status != _BLOCKING_STATUS:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
return True
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license for EE features: {e}")
|
||||
# Fail closed - if Redis is down, other things will break anyway
|
||||
return False
|
||||
|
||||
# No license or GATED_ACCESS - no EE features
|
||||
return False
|
||||
|
||||
|
||||
def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
"""EE version: checks license status for self-hosted deployments.
|
||||
|
||||
For self-hosted, looks up license metadata and overrides application_status
|
||||
if the license is missing or indicates a problem (expired, grace period, etc.).
|
||||
if the license indicates GATED_ACCESS (fully expired).
|
||||
|
||||
For multi-tenant (cloud), the settings already have the correct status
|
||||
from the control plane, so no override is needed.
|
||||
@@ -43,11 +73,10 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if metadata and metadata.status in _GATED_STATUSES:
|
||||
if metadata and metadata.status == _BLOCKING_STATUS:
|
||||
settings.application_status = metadata.status
|
||||
elif not metadata:
|
||||
# No license = gated access for self-hosted EE
|
||||
settings.application_status = ApplicationStatus.GATED_ACCESS
|
||||
# No license = user hasn't purchased yet, allow access for upgrade flow
|
||||
# GRACE_PERIOD/PAYMENT_REMINDER don't block - they're for notifications
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata for settings: {e}")
|
||||
|
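A sketch (not part of the diff) of gating an EE-only code path on check_ee_features_enabled(); the import path is omitted because it depends on where the hunk above lives, and the helper shown is illustrative:

    # Sketch: raise 402 before running EE-only logic
    from fastapi import HTTPException
    # import check_ee_features_enabled from the module shown above (path omitted)

    def require_ee_features() -> None:
        if not check_ee_features_enabled():
            raise HTTPException(
                status_code=402,
                detail="This feature requires an Enterprise license.",
            )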
@@ -3,6 +3,7 @@ from fastapi import APIRouter
from ee.onyx.server.tenants.admin_api import router as admin_router
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
from ee.onyx.server.tenants.billing_api import router as billing_router
from ee.onyx.server.tenants.proxy import router as proxy_router
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
from ee.onyx.server.tenants.tenant_management_api import (
    router as tenant_management_router,
@@ -22,3 +23,4 @@ router.include_router(billing_router)
router.include_router(team_membership_router)
router.include_router(tenant_management_router)
router.include_router(user_invitations_router)
router.include_router(proxy_router)

@@ -1,3 +1,21 @@
"""Billing API endpoints for cloud multi-tenant deployments.

DEPRECATED: These /tenants/* billing endpoints are being replaced by /admin/billing/*
which provides a unified API for both self-hosted and cloud deployments.

TODO(ENG-3533): Migrate frontend to use /admin/billing/* endpoints and remove this file.
https://linear.app/onyx-app/issue/ENG-3533/migrate-tenantsbilling-adminbilling

Current endpoints to migrate:
- GET /tenants/billing-information -> GET /admin/billing/information
- POST /tenants/create-customer-portal-session -> POST /admin/billing/portal-session
- POST /tenants/create-subscription-session -> POST /admin/billing/checkout-session
- GET /tenants/stripe-publishable-key -> (keep as-is, shared endpoint)

Note: /tenants/product-gating/* endpoints are control-plane-to-data-plane calls
and are NOT part of this migration - they stay here.
"""

import asyncio

import httpx
@@ -90,11 +108,7 @@ async def billing_information(
async def create_customer_portal_session(
    _: User = Depends(current_admin_user),
) -> dict:
    """
    Create a Stripe customer portal session via the control plane.
    NOTE: This is currently only used for multi-tenant (cloud) deployments.
    Self-hosted proxy endpoints will be added in a future phase.
    """
    """Create a Stripe customer portal session via the control plane."""
    tenant_id = get_current_tenant_id()
    return_url = f"{WEB_DOMAIN}/admin/billing"

485  backend/ee/onyx/server/tenants/proxy.py  Normal file
@@ -0,0 +1,485 @@
"""Proxy endpoints for billing operations.

These endpoints run on the CLOUD DATA PLANE (cloud.onyx.app) and serve as a proxy
for self-hosted instances to reach the control plane.

Flow:
    Self-hosted backend → Cloud DP /proxy/* (license auth) → Control plane (JWT auth)

Self-hosted instances call these endpoints with their license in the Authorization
header. The cloud data plane validates the license signature and forwards the
request to the control plane using JWT authentication.

Auth levels by endpoint:
- /create-checkout-session: No auth (new customer) or expired license OK (renewal)
- /claim-license: Session ID based (one-time after Stripe payment)
- /create-customer-portal-session: Expired license OK (need portal to fix payment)
- /billing-information: Valid license required
- /license/{tenant_id}: Valid license required
- /seats/update: Valid license required
"""

from typing import Literal

import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Header
from fastapi import HTTPException
from pydantic import BaseModel

from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import update_license_cache
from ee.onyx.db.license import upsert_license
from ee.onyx.server.billing.models import SeatUpdateRequest
from ee.onyx.server.billing.models import SeatUpdateResponse
from ee.onyx.server.license.models import LicensePayload
from ee.onyx.server.license.models import LicenseSource
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.utils.license import is_license_valid
from ee.onyx.utils.license import verify_license_signature
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.utils.logger import setup_logger

logger = setup_logger()

router = APIRouter(prefix="/proxy")


def _check_license_enforcement_enabled() -> None:
    """Ensure LICENSE_ENFORCEMENT_ENABLED is true (proxy endpoints only work on cloud DP)."""
    if not LICENSE_ENFORCEMENT_ENABLED:
        raise HTTPException(
            status_code=501,
            detail="Proxy endpoints are only available on cloud data plane",
        )


def _extract_license_from_header(
    authorization: str | None,
    required: bool = True,
) -> str | None:
    """Extract license data from Authorization header.

    Self-hosted instances authenticate to these proxy endpoints by sending their
    license as a Bearer token: `Authorization: Bearer <base64-encoded-license>`.

    We use the Bearer scheme (RFC 6750) because:
    1. It's the standard HTTP auth scheme for token-based authentication
    2. The license blob is cryptographically signed (RSA), so it's self-validating
    3. No other auth schemes (Basic, Digest, etc.) are supported for license auth

    The license data is the base64-encoded signed blob that contains tenant_id,
    seats, expiration, etc. We verify the signature to authenticate the caller.

    Args:
        authorization: The Authorization header value (e.g., "Bearer <license>")
        required: If True, raise 401 when header is missing/invalid

    Returns:
        License data string (base64-encoded), or None if not required and missing

    Raises:
        HTTPException: 401 if required and header is missing/invalid
    """
    if not authorization or not authorization.startswith("Bearer "):
        if required:
            raise HTTPException(
                status_code=401, detail="Missing or invalid authorization header"
            )
        return None

    return authorization.split(" ", 1)[1]


def verify_license_auth(
    license_data: str,
    allow_expired: bool = False,
) -> LicensePayload:
    """Verify license signature and optionally check expiry.

    Args:
        license_data: Base64-encoded signed license blob
        allow_expired: If True, accept expired licenses (for renewal flows)

    Returns:
        LicensePayload if valid

    Raises:
        HTTPException: If license is invalid or expired (when not allowed)
    """
    _check_license_enforcement_enabled()

    try:
        payload = verify_license_signature(license_data)
    except ValueError as e:
        raise HTTPException(status_code=401, detail=f"Invalid license: {e}")

    if not allow_expired and not is_license_valid(payload):
        raise HTTPException(status_code=401, detail="License has expired")

    return payload


async def get_license_payload(
    authorization: str | None = Header(None, alias="Authorization"),
) -> LicensePayload:
    """Dependency: Require valid (non-expired) license.

    Used for endpoints that require an active subscription.
    """
    license_data = _extract_license_from_header(authorization, required=True)
    # license_data is guaranteed non-None when required=True
    assert license_data is not None
    return verify_license_auth(license_data, allow_expired=False)


async def get_license_payload_allow_expired(
    authorization: str | None = Header(None, alias="Authorization"),
) -> LicensePayload:
    """Dependency: Require license with valid signature, expired OK.

    Used for endpoints needed to fix payment issues (portal, renewal checkout).
    """
    license_data = _extract_license_from_header(authorization, required=True)
    # license_data is guaranteed non-None when required=True
    assert license_data is not None
    return verify_license_auth(license_data, allow_expired=True)


async def get_optional_license_payload(
    authorization: str | None = Header(None, alias="Authorization"),
) -> LicensePayload | None:
    """Dependency: Optional license auth (for checkout - new customers have none).

    Returns None if no license provided, otherwise validates and returns payload.
    Expired licenses are allowed for renewal flows.
    """
    _check_license_enforcement_enabled()

    license_data = _extract_license_from_header(authorization, required=False)
    if license_data is None:
        return None

    return verify_license_auth(license_data, allow_expired=True)
async def forward_to_control_plane(
|
||||
method: str,
|
||||
path: str,
|
||||
body: dict | None = None,
|
||||
params: dict | None = None,
|
||||
) -> dict:
|
||||
"""Forward a request to the control plane with proper authentication."""
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}{path}"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
if method == "GET":
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
elif method == "POST":
|
||||
response = await client.post(url, headers=headers, json=body)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
status_code = e.response.status_code
|
||||
detail = "Control plane request failed"
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
detail = error_data.get("detail", detail)
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(f"Control plane returned {status_code}: {detail}")
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
except httpx.RequestError:
|
||||
logger.exception("Failed to connect to control plane")
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Failed to connect to control plane"
|
||||
)
|
||||
|
||||
|
||||
def fetch_and_store_license(tenant_id: str, license_data: str) -> None:
|
||||
"""Store license in database and update Redis cache.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID
|
||||
license_data: Base64-encoded signed license blob
|
||||
"""
|
||||
try:
|
||||
# Verify before storing
|
||||
payload = verify_license_signature(license_data)
|
||||
|
||||
# Store in database using the specific tenant's schema
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
upsert_license(db_session, license_data)
|
||||
|
||||
# Update Redis cache
|
||||
update_license_cache(
|
||||
payload,
|
||||
source=LicenseSource.AUTO_FETCH,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to verify license: {e}")
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to store license")
|
||||
raise
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CreateCheckoutSessionRequest(BaseModel):
|
||||
billing_period: Literal["monthly", "annual"] = "monthly"
|
||||
email: str | None = None
|
||||
# Redirect URL after successful checkout - self-hosted passes their instance URL
|
||||
redirect_url: str | None = None
|
||||
# Cancel URL when user exits checkout - returns to upgrade page
|
||||
cancel_url: str | None = None
|
||||
|
||||
|
||||
class CreateCheckoutSessionResponse(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
@router.post("/create-checkout-session")
|
||||
async def proxy_create_checkout_session(
|
||||
request_body: CreateCheckoutSessionRequest,
|
||||
license_payload: LicensePayload | None = Depends(get_optional_license_payload),
|
||||
) -> CreateCheckoutSessionResponse:
|
||||
"""Proxy checkout session creation to control plane.
|
||||
|
||||
Auth: Optional license (new customers don't have one yet).
|
||||
If license provided, expired is OK (for renewals).
|
||||
"""
|
||||
# license_payload is None for new customers who don't have a license yet.
|
||||
# In that case, tenant_id is omitted from the request body and the control
|
||||
# plane will create a new tenant during checkout completion.
|
||||
tenant_id = license_payload.tenant_id if license_payload else None
|
||||
|
||||
body: dict = {
|
||||
"billing_period": request_body.billing_period,
|
||||
}
|
||||
if tenant_id:
|
||||
body["tenant_id"] = tenant_id
|
||||
if request_body.email:
|
||||
body["email"] = request_body.email
|
||||
if request_body.redirect_url:
|
||||
body["redirect_url"] = request_body.redirect_url
|
||||
if request_body.cancel_url:
|
||||
body["cancel_url"] = request_body.cancel_url
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST", "/create-checkout-session", body=body
|
||||
)
|
||||
return CreateCheckoutSessionResponse(url=result["url"])
|
||||
|
||||
|
||||
class ClaimLicenseRequest(BaseModel):
|
||||
session_id: str
|
||||
|
||||
|
||||
class ClaimLicenseResponse(BaseModel):
|
||||
tenant_id: str
|
||||
license: str
|
||||
message: str | None = None
|
||||
|
||||
|
||||
@router.post("/claim-license")
|
||||
async def proxy_claim_license(
|
||||
request_body: ClaimLicenseRequest,
|
||||
) -> ClaimLicenseResponse:
|
||||
"""Claim a license after successful Stripe checkout.
|
||||
|
||||
Auth: Session ID based (one-time use after payment).
|
||||
The control plane verifies the session_id is valid and unclaimed.
|
||||
|
||||
Returns the license to the caller. For self-hosted instances, they will
|
||||
store the license locally. The cloud DP doesn't need to store it.
|
||||
"""
|
||||
_check_license_enforcement_enabled()
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST",
|
||||
"/claim-license",
|
||||
body={"session_id": request_body.session_id},
|
||||
)
|
||||
|
||||
tenant_id = result.get("tenant_id")
|
||||
license_data = result.get("license")
|
||||
|
||||
if not tenant_id or not license_data:
|
||||
logger.error(f"Control plane returned incomplete claim response: {result}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Control plane returned incomplete license data",
|
||||
)
|
||||
|
||||
return ClaimLicenseResponse(
|
||||
tenant_id=tenant_id,
|
||||
license=license_data,
|
||||
message="License claimed successfully",
|
||||
)
|
||||
|
||||
|
||||
class CreateCustomerPortalSessionRequest(BaseModel):
|
||||
return_url: str | None = None
|
||||
|
||||
|
||||
class CreateCustomerPortalSessionResponse(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def proxy_create_customer_portal_session(
|
||||
request_body: CreateCustomerPortalSessionRequest | None = None,
|
||||
license_payload: LicensePayload = Depends(get_license_payload_allow_expired),
|
||||
) -> CreateCustomerPortalSessionResponse:
|
||||
"""Proxy customer portal session creation to control plane.
|
||||
|
||||
Auth: License required, expired OK (need portal to fix payment issues).
|
||||
"""
|
||||
# tenant_id is a required field in LicensePayload (Pydantic validates this),
|
||||
# but we check explicitly for defense in depth
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
tenant_id = license_payload.tenant_id
|
||||
|
||||
body: dict = {"tenant_id": tenant_id}
|
||||
if request_body and request_body.return_url:
|
||||
body["return_url"] = request_body.return_url
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST", "/create-customer-portal-session", body=body
|
||||
)
|
||||
return CreateCustomerPortalSessionResponse(url=result["url"])
|
||||
|
||||
|
||||
class BillingInformationResponse(BaseModel):
|
||||
tenant_id: str
|
||||
status: str | None = None
|
||||
plan_type: str | None = None
|
||||
seats: int | None = None
|
||||
billing_period: str | None = None
|
||||
current_period_start: str | None = None
|
||||
current_period_end: str | None = None
|
||||
cancel_at_period_end: bool = False
|
||||
canceled_at: str | None = None
|
||||
trial_start: str | None = None
|
||||
trial_end: str | None = None
|
||||
payment_method_enabled: bool = False
|
||||
stripe_subscription_id: str | None = None
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def proxy_billing_information(
|
||||
license_payload: LicensePayload = Depends(get_license_payload),
|
||||
) -> BillingInformationResponse:
|
||||
"""Proxy billing information request to control plane.
|
||||
|
||||
Auth: Valid (non-expired) license required.
|
||||
"""
|
||||
# tenant_id is a required field in LicensePayload (Pydantic validates this),
|
||||
# but we check explicitly for defense in depth
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
tenant_id = license_payload.tenant_id
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"GET", "/billing-information", params={"tenant_id": tenant_id}
|
||||
)
|
||||
# Add tenant_id from license if not in response (control plane may not include it)
|
||||
if "tenant_id" not in result:
|
||||
result["tenant_id"] = tenant_id
|
||||
return BillingInformationResponse(**result)
|
||||
|
||||
|
||||
class LicenseFetchResponse(BaseModel):
|
||||
license: str
|
||||
tenant_id: str
|
||||
|
||||
|
||||
@router.get("/license/{tenant_id}")
|
||||
async def proxy_license_fetch(
|
||||
tenant_id: str,
|
||||
license_payload: LicensePayload = Depends(get_license_payload),
|
||||
) -> LicenseFetchResponse:
|
||||
"""Proxy license fetch to control plane.
|
||||
|
||||
Auth: Valid license required.
|
||||
The tenant_id in path must match the authenticated tenant.
|
||||
"""
|
||||
# tenant_id is a required field in LicensePayload (Pydantic validates this),
|
||||
# but we check explicitly for defense in depth
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
if tenant_id != license_payload.tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Cannot fetch license for a different tenant",
|
||||
)
|
||||
|
||||
result = await forward_to_control_plane("GET", f"/license/{tenant_id}")
|
||||
|
||||
# Auto-store the refreshed license
|
||||
license_data = result.get("license")
|
||||
if not license_data:
|
||||
logger.error(f"Control plane returned incomplete license response: {result}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Control plane returned incomplete license data",
|
||||
)
|
||||
|
||||
fetch_and_store_license(tenant_id, license_data)
|
||||
|
||||
return LicenseFetchResponse(license=license_data, tenant_id=tenant_id)
|
||||
|
||||
|
||||
@router.post("/seats/update")
|
||||
async def proxy_seat_update(
|
||||
request_body: SeatUpdateRequest,
|
||||
license_payload: LicensePayload = Depends(get_license_payload),
|
||||
) -> SeatUpdateResponse:
|
||||
"""Proxy seat update to control plane.
|
||||
|
||||
Auth: Valid (non-expired) license required.
|
||||
Handles Stripe proration and license regeneration.
|
||||
"""
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
tenant_id = license_payload.tenant_id
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST",
|
||||
"/seats/update",
|
||||
body={
|
||||
"tenant_id": tenant_id,
|
||||
"new_seat_count": request_body.new_seat_count,
|
||||
},
|
||||
)
|
||||
|
||||
return SeatUpdateResponse(
|
||||
success=result.get("success", False),
|
||||
current_seats=result.get("current_seats", 0),
|
||||
used_seats=result.get("used_seats", 0),
|
||||
message=result.get("message"),
|
||||
)
|
||||
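The module docstring at the top of `proxy.py` describes the flow: self-hosted backend → cloud DP `/proxy/*` (license auth) → control plane (JWT auth). Below is a minimal sketch of how a self-hosted instance might call one of these proxy endpoints, assuming only that the signed license blob is available as a base64 string; the base URL is an assumption for illustration, not a documented client API.

```python
# Hedged sketch: a self-hosted backend calling the cloud data plane's proxy.
# The base URL below is assumed for illustration purposes.
import httpx

CLOUD_DP_BASE_URL = "https://cloud.onyx.app/api/proxy"  # assumed base path


def fetch_billing_information(license_b64: str) -> dict:
    # The signed license blob doubles as the Bearer credential (RFC 6750),
    # matching the auth scheme described in _extract_license_from_header.
    headers = {"Authorization": f"Bearer {license_b64}"}
    resp = httpx.get(
        f"{CLOUD_DP_BASE_URL}/billing-information", headers=headers, timeout=30.0
    )
    resp.raise_for_status()
    return resp.json()
```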
@@ -1,6 +1,7 @@
|
||||
from fastapi_users import exceptions
|
||||
from sqlalchemy import select
|
||||
|
||||
from ee.onyx.db.license import invalidate_license_cache
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import get_pending_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
@@ -47,6 +48,8 @@ def get_tenant_id_for_email(email: str) -> str:
|
||||
mapping.active = True
|
||||
db_session.commit()
|
||||
tenant_id = mapping.tenant_id
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error getting tenant id for email {email}: {e}")
|
||||
raise exceptions.UserNotExists()
|
||||
@@ -70,49 +73,104 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
"""
|
||||
Add users to a tenant with proper transaction handling.
|
||||
Checks if users already have a tenant mapping to avoid duplicates.
|
||||
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
|
||||
|
||||
If a user already has an active mapping to a different tenant, they receive
|
||||
an inactive mapping (invitation) to this tenant. They can accept the
|
||||
invitation later to switch tenants.
|
||||
|
||||
Raises:
|
||||
HTTPException: 402 if adding active users would exceed seat limit
|
||||
"""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.db.license import check_seat_availability
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant as get_tenant_session
|
||||
|
||||
unique_emails = set(emails)
|
||||
if not unique_emails:
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
# Start a transaction
|
||||
db_session.begin()
|
||||
|
||||
for email in emails:
|
||||
# Check if the user already has a mapping to this tenant
|
||||
existing_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.with_for_update()
|
||||
.first()
|
||||
# Batch query 1: Get all existing mappings for these emails to this tenant
|
||||
# Lock rows to prevent concurrent modifications
|
||||
existing_mappings = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email.in_(unique_emails),
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.with_for_update()
|
||||
.all()
|
||||
)
|
||||
emails_with_mapping = {m.email for m in existing_mappings}
|
||||
|
||||
# If user already has an active mapping, add this one as inactive
|
||||
if not existing_mapping:
|
||||
# Check if the user already has an active mapping to any tenant
|
||||
has_active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
# Batch query 2: Get all active mappings for these emails (any tenant)
|
||||
active_mappings = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email.in_(unique_emails),
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.all()
|
||||
)
|
||||
emails_with_active_mapping = {m.email for m in active_mappings}
|
||||
|
||||
db_session.add(
|
||||
UserTenantMapping(
|
||||
email=email,
|
||||
tenant_id=tenant_id,
|
||||
active=False if has_active_mapping else True,
|
||||
)
|
||||
# Determine which users will consume a new seat.
|
||||
# Users with active mappings elsewhere get INACTIVE mappings (invitations)
|
||||
# and don't consume seats until they accept. Only users without any active
|
||||
# mapping will get an ACTIVE mapping and consume a seat immediately.
|
||||
emails_consuming_seats = {
|
||||
email
|
||||
for email in unique_emails
|
||||
if email not in emails_with_mapping
|
||||
and email not in emails_with_active_mapping
|
||||
}
|
||||
|
||||
# Check seat availability inside the transaction to prevent race conditions.
|
||||
# Note: ALL users in unique_emails still get added below - this check only
|
||||
# validates we have capacity for users who will consume seats immediately.
|
||||
if emails_consuming_seats:
|
||||
with get_tenant_session(tenant_id=tenant_id) as tenant_session:
|
||||
result = check_seat_availability(
|
||||
tenant_session,
|
||||
seats_needed=len(emails_consuming_seats),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
if not result.available:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=result.error_message or "Seat limit exceeded",
|
||||
)
|
||||
|
||||
# Add mappings for emails that don't already have one to this tenant
|
||||
for email in unique_emails:
|
||||
if email in emails_with_mapping:
|
||||
continue
|
||||
|
||||
# Create mapping: inactive if user belongs to another tenant (invitation),
|
||||
# active otherwise
|
||||
db_session.add(
|
||||
UserTenantMapping(
|
||||
email=email,
|
||||
tenant_id=tenant_id,
|
||||
active=email not in emails_with_active_mapping,
|
||||
)
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
db_session.commit()
|
||||
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
|
||||
except HTTPException:
|
||||
db_session.rollback()
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(f"Failed to add users to tenant {tenant_id}")
|
||||
db_session.rollback()
|
||||
@@ -135,6 +193,9 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
db_session.delete(mapping)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
|
||||
@@ -149,6 +210,9 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
|
||||
|
||||
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
@@ -177,6 +241,9 @@ def approve_user_invite(email: str, tenant_id: str) -> None:
|
||||
db_session.add(new_mapping)
|
||||
db_session.commit()
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
|
||||
# Also remove the user from pending users list
|
||||
# Remove from pending users
|
||||
pending_users = get_pending_users()
|
||||
@@ -195,19 +262,42 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Accept an invitation to join a tenant.
|
||||
This activates the user's mapping to the tenant.
|
||||
|
||||
Raises:
|
||||
HTTPException: 402 if accepting would exceed seat limit
|
||||
"""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.db.license import check_seat_availability
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
# First check if there's an active mapping for this user and tenant
|
||||
# Lock the user's mappings first to prevent race conditions.
|
||||
# This ensures no concurrent request can modify this user's mappings
|
||||
# while we check seats and activate.
|
||||
active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
|
||||
# Check seat availability within the same logical operation.
|
||||
# Note: This queries fresh data from DB, not cache.
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
|
||||
result = check_seat_availability(
|
||||
tenant_session, seats_needed=1, tenant_id=tenant_id
|
||||
)
|
||||
if not result.available:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=result.error_message or "Seat limit exceeded",
|
||||
)
|
||||
|
||||
# If an active mapping exists, delete it
|
||||
if active_mapping:
|
||||
db_session.delete(active_mapping)
|
||||
@@ -237,6 +327,9 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
|
||||
mapping.active = True
|
||||
db_session.commit()
|
||||
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
else:
|
||||
logger.warning(
|
||||
f"No invitation found for user {email} in tenant {tenant_id}"
|
||||
@@ -297,16 +390,41 @@ def deny_user_invite(email: str, tenant_id: str) -> None:
|
||||
|
||||
def get_tenant_count(tenant_id: str) -> int:
|
||||
"""
|
||||
Get the number of active users for this tenant
|
||||
Get the number of active users for this tenant.
|
||||
|
||||
A user counts toward the seat count if:
|
||||
1. They have an active mapping to this tenant (UserTenantMapping.active == True)
|
||||
2. AND the User is active (User.is_active == True)
|
||||
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
"""
|
||||
from onyx.db.models import User
|
||||
|
||||
# First get all emails with active mappings to this tenant
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Count the number of active users for this tenant
|
||||
user_count = (
|
||||
db_session.query(UserTenantMapping)
|
||||
active_mapping_emails = (
|
||||
db_session.query(UserTenantMapping.email)
|
||||
.filter(
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.all()
|
||||
)
|
||||
emails = [email for (email,) in active_mapping_emails]
|
||||
|
||||
if not emails:
|
||||
return 0
|
||||
|
||||
# Now count how many of those users are actually active in the tenant's User table
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
user_count = (
|
||||
db_session.query(User)
|
||||
.filter(
|
||||
User.email.in_(emails), # type: ignore
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from pathlib import Path
|
||||
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
@@ -19,21 +20,27 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# RSA-4096 Public Key for license verification
|
||||
# Load from environment variable - key is generated on the control plane
|
||||
# In production, inject via Kubernetes secrets or secrets manager
|
||||
LICENSE_PUBLIC_KEY_PEM = os.environ.get("LICENSE_PUBLIC_KEY_PEM", "")
|
||||
# Path to the license public key file
|
||||
_LICENSE_PUBLIC_KEY_PATH = (
|
||||
Path(__file__).parent.parent.parent.parent / "keys" / "license_public_key.pem"
|
||||
)
|
||||
|
||||
|
||||
def _get_public_key() -> RSAPublicKey:
|
||||
"""Load the public key from environment variable."""
|
||||
if not LICENSE_PUBLIC_KEY_PEM:
|
||||
raise ValueError(
|
||||
"LICENSE_PUBLIC_KEY_PEM environment variable not set. "
|
||||
"License verification requires the control plane public key."
|
||||
)
|
||||
key = serialization.load_pem_public_key(LICENSE_PUBLIC_KEY_PEM.encode())
|
||||
"""Load the public key from file, with env var override."""
|
||||
# Allow env var override for flexibility
|
||||
key_pem = os.environ.get("LICENSE_PUBLIC_KEY_PEM")
|
||||
|
||||
if not key_pem:
|
||||
# Read from file
|
||||
if not _LICENSE_PUBLIC_KEY_PATH.exists():
|
||||
raise ValueError(
|
||||
f"License public key not found at {_LICENSE_PUBLIC_KEY_PATH}. "
|
||||
"License verification requires the control plane public key."
|
||||
)
|
||||
key_pem = _LICENSE_PUBLIC_KEY_PATH.read_text()
|
||||
|
||||
key = serialization.load_pem_public_key(key_pem.encode())
|
||||
if not isinstance(key, RSAPublicKey):
|
||||
raise ValueError("Expected RSA public key")
|
||||
return key
|
||||
@@ -53,17 +60,21 @@ def verify_license_signature(license_data: str) -> LicensePayload:
|
||||
ValueError: If license data is invalid or signature verification fails
|
||||
"""
|
||||
try:
|
||||
# Decode the license data
|
||||
decoded = json.loads(base64.b64decode(license_data))
|
||||
|
||||
# Parse into LicenseData to validate structure
|
||||
license_obj = LicenseData(**decoded)
|
||||
|
||||
payload_json = json.dumps(
|
||||
license_obj.payload.model_dump(mode="json"), sort_keys=True
|
||||
)
|
||||
# IMPORTANT: Use the ORIGINAL payload JSON for signature verification,
|
||||
# not re-serialized through Pydantic. Pydantic may format fields differently
|
||||
# (e.g., datetime "+00:00" vs "Z") which would break signature verification.
|
||||
original_payload = decoded.get("payload", {})
|
||||
payload_json = json.dumps(original_payload, sort_keys=True)
|
||||
signature_bytes = base64.b64decode(license_obj.signature)
|
||||
|
||||
# Verify signature using PSS padding (modern standard)
|
||||
public_key = _get_public_key()
|
||||
|
||||
public_key.verify(
|
||||
signature_bytes,
|
||||
payload_json.encode(),
|
||||
@@ -77,16 +88,18 @@ def verify_license_signature(license_data: str) -> LicensePayload:
|
||||
return license_obj.payload
|
||||
|
||||
except InvalidSignature:
|
||||
logger.error("License signature verification failed")
|
||||
logger.error("[verify_license] FAILED: Signature verification failed")
|
||||
raise ValueError("Invalid license signature")
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Failed to decode license JSON")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"[verify_license] FAILED: JSON decode error: {e}")
|
||||
raise ValueError("Invalid license format: not valid JSON")
|
||||
except (ValueError, KeyError, TypeError) as e:
|
||||
logger.error(f"License data validation error: {type(e).__name__}")
|
||||
raise ValueError(f"Invalid license format: {type(e).__name__}")
|
||||
logger.error(
|
||||
f"[verify_license] FAILED: Validation error: {type(e).__name__}: {e}"
|
||||
)
|
||||
raise ValueError(f"Invalid license format: {type(e).__name__}: {e}")
|
||||
except Exception:
|
||||
logger.exception("Unexpected error during license verification")
|
||||
logger.exception("[verify_license] FAILED: Unexpected error")
|
||||
raise ValueError("License verification failed: unexpected error")
|
||||
|
||||
|
||||
|
||||
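The hunk above stresses that the signature must be verified against the payload JSON exactly as it was signed: the raw decoded dict canonicalized with `sort_keys=True`, not a Pydantic re-serialization that may reformat datetimes. A minimal sketch of that verification path follows; PSS padding is stated in the diff, while SHA-256 and the salt length are assumptions, and key loading and model validation are omitted.

```python
# Hedged sketch of signature verification over the ORIGINAL payload JSON.
import base64
import json

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey


def verify_payload(license_data: str, public_key: RSAPublicKey) -> dict:
    decoded = json.loads(base64.b64decode(license_data))
    # Canonicalize the raw payload dict, NOT a re-serialized Pydantic model,
    # so datetime formatting cannot drift from what was signed.
    payload_json = json.dumps(decoded["payload"], sort_keys=True)
    signature = base64.b64decode(decoded["signature"])
    public_key.verify(
        signature,
        payload_json.encode(),
        # Hash and salt length are assumptions; the diff only names PSS padding.
        padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
        hashes.SHA256(),
    )  # raises InvalidSignature on mismatch
    return decoded["payload"]
```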
@@ -6,6 +6,7 @@ from posthog import Posthog
|
||||
|
||||
from ee.onyx.configs.app_configs import MARKETING_POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_DEBUG_LOGS_ENABLED
|
||||
from ee.onyx.configs.app_configs import POSTHOG_HOST
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -20,7 +21,7 @@ def posthog_on_error(error: Any, items: Any) -> None:
|
||||
posthog = Posthog(
|
||||
project_api_key=POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=True,
|
||||
debug=POSTHOG_DEBUG_LOGS_ENABLED,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
|
||||
@@ -33,7 +34,7 @@ if MARKETING_POSTHOG_API_KEY:
|
||||
marketing_posthog = Posthog(
|
||||
project_api_key=MARKETING_POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=True,
|
||||
debug=POSTHOG_DEBUG_LOGS_ENABLED,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
|
||||
|
||||
14
backend/keys/license_public_key.pem
Normal file
@@ -0,0 +1,14 @@
-----BEGIN PUBLIC KEY-----
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA5DpchQujdxjCwpc4/RQP
Hej6rc3SS/5ENCXL0I8NAfMogel0fqG6PKRhonyEh/Bt3P4q18y8vYzAShwf4b6Q
aS0WwshbvnkjyWlsK0BY4HLBKPkTpes7kaz8MwmPZDeelvGJ7SNv3FvyJR4QsoSQ
GSoB5iTH7hi63TjzdxtckkXoNG+GdVd/koxVDUv2uWcAoWIFTTcbKWyuq2SS/5Sf
xdVaIArqfAhLpnNbnM9OS7lZ1xP+29ZXpHxDoeluz35tJLMNBYn9u0y+puo1kW1E
TOGizlAq5kmEMsTJ55e9ZuyIV3gZAUaUKe8CxYJPkOGt0Gj6e1jHoHZCBJmaq97Y
stKj//84HNBzajaryEZuEfRecJ94ANEjkD8u9cGmW+9VxRe5544zWguP5WMT/nv1
0Q+jkOBW2hkY5SS0Rug4cblxiB7bDymWkaX6+sC0VWd5g6WXp36EuP2T0v3mYuHU
GDEiWbD44ToREPVwE/M07ny8qhLo/HYk2l8DKFt83hXe7ePBnyQdcsrVbQWOO1na
j43OkoU5gOFyOkrk2RmmtCjA8jSnw+tGCTpRaRcshqoWC1MjZyU+8/kDteXNkmv9
/B5VxzYSyX+abl7yAu5wLiUPW8l+mOazzWu0nPkmiA160ArxnRyxbGnmp4dUIrt5
azYku4tQYLSsSabfhcpeiCsCAwEAAQ==
-----END PUBLIC KEY-----
@@ -97,10 +97,14 @@ def get_access_for_documents(
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
user should have access to a document if at least one entry in the document's ACL
|
||||
matches one entry in the returned set.
|
||||
"""Returns a list of ACL entries that the user has access to.
|
||||
|
||||
This is meant to be used downstream to filter out documents that the user
|
||||
does not have access to. The user should have access to a document if at
|
||||
least one entry in the document's ACL matches one entry in the returned set.
|
||||
|
||||
NOTE: These strings must be formatted in the same way as the output of
|
||||
DocumentAccess::to_acl.
|
||||
"""
|
||||
if user:
|
||||
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
|
||||
|
||||
@@ -125,9 +125,11 @@ class DocumentAccess(ExternalAccess):
|
||||
)
|
||||
|
||||
def to_acl(self) -> set[str]:
|
||||
# the acl's emitted by this function are prefixed by type
|
||||
# to get the native objects, access the member variables directly
|
||||
"""Converts the access state to a set of formatted ACL strings.
|
||||
|
||||
NOTE: When querying for documents, the supplied ACL filter strings must
|
||||
be formatted in the same way as this function.
|
||||
"""
|
||||
acl_set: set[str] = set()
|
||||
for user_email in self.user_emails:
|
||||
if user_email:
|
||||
|
||||
98
backend/onyx/background/README.md
Normal file
@@ -0,0 +1,98 @@
# Overview of Onyx Background Jobs

The background jobs take care of:
1. Pulling/indexing documents (from connectors)
2. Updating document metadata (from connectors)
3. Cleaning up checkpoints and other bookkeeping around indexing work (indexing checkpoints and index attempt metadata)
4. Handling user-uploaded files and deletions (from the Projects feature and uploads via the Chat)
5. Reporting metrics such as queue length for monitoring purposes

## Worker → Queue Mapping

| Worker | File | Queues |
|--------|------|--------|
| Primary | `apps/primary.py` | `celery` |
| Light | `apps/light.py` | `vespa_metadata_sync`, `connector_deletion`, `doc_permissions_upsert`, `checkpoint_cleanup`, `index_attempt_cleanup` |
| Heavy | `apps/heavy.py` | `connector_pruning`, `connector_doc_permissions_sync`, `connector_external_group_sync`, `csv_generation`, `sandbox` |
| Docprocessing | `apps/docprocessing.py` | `docprocessing` |
| Docfetching | `apps/docfetching.py` | `connector_doc_fetching` |
| User File Processing | `apps/user_file_processing.py` | `user_file_processing`, `user_file_project_sync`, `user_file_delete` |
| Monitoring | `apps/monitoring.py` | `monitoring` |
| Background (consolidated) | `apps/background.py` | All queues above except `celery` |

A minimal launch sketch for one of these workers follows the table.
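As a rough illustration of how a worker app maps to its queues, here is a minimal, hypothetical way to start the Light worker programmatically. Real deployments start workers via supervisord and the Celery CLI, the concurrency and prefetch values are configuration-dependent, and the import path is inferred from the table above.

```python
# Illustrative sketch only - not the deployment entrypoint.
from onyx.background.celery.apps.light import celery_app  # path inferred from the table

if __name__ == "__main__":
    celery_app.worker_main(
        argv=[
            "worker",
            "--loglevel=INFO",
            # Queues mirror the Light row of the mapping table.
            "--queues=vespa_metadata_sync,connector_deletion,doc_permissions_upsert,"
            "checkpoint_cleanup,index_attempt_cleanup",
            "--concurrency=24",
            "--prefetch-multiplier=8",
        ]
    )
```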
## Non-Worker Apps
| App | File | Purpose |
|-----|------|---------|
| **Beat** | `beat.py` | Celery beat scheduler with `DynamicTenantScheduler` that generates per-tenant periodic task schedules |
| **Client** | `client.py` | Minimal app for task submission from non-worker processes (e.g., API server) |

### Shared Module
`app_base.py` provides:
- `TenantAwareTask` - Base task class that sets tenant context
- Signal handlers for logging, cleanup, and lifecycle events
- Readiness probes and health checks


## Worker Details

### Primary (Coordinator and task dispatcher)
The Primary worker is the single worker that handles tasks from the default `celery` queue. It is kept a singleton by the `PRIMARY_WORKER` Redis lock,
which it re-touches every `CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8` seconds (using Celery Bootsteps).

On startup it:
- waits for Redis, Postgres, and the document index to all be healthy
- acquires the singleton lock
- cleans all the Redis state associated with background jobs
- marks orphaned index attempts as failed

Then it cycles through its tasks as scheduled by Celery Beat:

| Task | Frequency | Description |
|------|-----------|-------------|
| `check_for_indexing` | 15s | Scans for connectors needing indexing → dispatches to `DOCFETCHING` queue |
| `check_for_vespa_sync_task` | 20s | Finds stale documents/document sets → dispatches sync tasks to `VESPA_METADATA_SYNC` queue |
| `check_for_pruning` | 20s | Finds connectors due for pruning → dispatches to `CONNECTOR_PRUNING` queue |
| `check_for_connector_deletion` | 20s | Processes deletion requests → dispatches to `CONNECTOR_DELETION` queue |
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
| `kombu_message_cleanup_task` | periodic | Cleans orphaned Kombu messages from the DB (Kombu being the messaging framework used by Celery) |
| `celery_beat_heartbeat` | 1m | Heartbeat for the Beat watchdog |

The watchdog is a separate Python process managed by supervisord that runs alongside the Celery workers. It checks the `ONYX_CELERY_BEAT_HEARTBEAT_KEY` in
Redis to ensure Celery Beat is not dead: Beat schedules `celery_beat_heartbeat` on Primary, which touches the key to signal that Beat is still alive.
See `supervisord.conf` for the watchdog config. A rough sketch of this liveness check is shown below.
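A minimal sketch of what the watchdog's check might look like, assuming the heartbeat key is written with a TTL so that a missing key means the heartbeat has stopped. The key name literal and the TTL convention are assumptions; the real logic lives in the watchdog process configured in `supervisord.conf`.

```python
# Hedged sketch of the Beat watchdog's liveness check (not the real implementation).
import redis

HEARTBEAT_KEY = "ONYX_CELERY_BEAT_HEARTBEAT_KEY"  # assumed literal key name


def beat_is_alive(r: redis.Redis) -> bool:
    # If celery_beat_heartbeat has not refreshed the key recently,
    # it will have expired and the watchdog can restart Beat.
    return r.exists(HEARTBEAT_KEY) == 1
```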
### Light
Fast, short-lived tasks that are not resource intensive. High concurrency:
can run 24 concurrent workers, each with a prefetch of 8, for a total of 192 tasks in flight at once.

Tasks it handles:
- Syncs access/permissions, document sets, boosts, hidden state
- Deletes documents that are marked for deletion in Postgres
- Cleanup of checkpoints and index attempts


### Heavy
Long-running, resource-intensive tasks; handles pruning and sandbox operations. Low concurrency: a maximum concurrency of 4 with a prefetch of 1.

Does not interact with the Document Index; it handles the syncs with external systems, making large volumes of API calls for pruning, fetching permissions, etc.

Generates CSV exports, which may take a long time when there is significant data in Postgres.

Sandbox (new feature) for running Next.js, a Python virtual env, the OpenCode AI Agent, and access to knowledge files.


### Docprocessing, Docfetching, User File Processing
Docprocessing and Docfetching are for indexing documents:
- Docfetching runs connectors to pull documents from external APIs (Google Drive, Confluence, etc.), stores batches to file storage, and dispatches docprocessing tasks
- Docprocessing retrieves batches, runs the indexing pipeline (chunking, embedding), and indexes into the Document Index

User files come from uploads made directly via the chat input bar.


### Monitoring
Observability and metrics collection:
- Queue lengths, connector success/failure, connector latencies
- Memory of supervisor-managed processes (workers, beat, slack)
- Cloud- and multi-tenant-specific monitoring

A rough sketch of the queue-length check is shown below.
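As an illustration of the queue-length metric, the sketch below assumes a Redis broker where each Celery queue is stored as a list named after the queue (the broker's default layout). The queue names simply mirror the mapping table above and are not exhaustive.

```python
# Hedged sketch: reading Celery queue depths from a Redis broker.
import redis

QUEUES = ["docprocessing", "connector_doc_fetching", "vespa_metadata_sync"]


def queue_lengths(r: redis.Redis) -> dict[str, int]:
    # With the Redis broker, each queue is a list; LLEN gives the backlog size.
    return {queue: r.llen(queue) for queue in QUEUES}
```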
@@ -43,6 +43,7 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.tracing.braintrust_tracing import setup_braintrust_if_creds_available
|
||||
from onyx.utils.logger import ColoredFormatter
|
||||
from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import PlainFormatter
|
||||
@@ -237,6 +238,9 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
f"Multiprocessing selected start method: {multiprocessing.get_start_method()}"
|
||||
)
|
||||
|
||||
# Initialize Braintrust tracing in workers if credentials are available.
|
||||
setup_braintrust_if_creds_available()
|
||||
|
||||
|
||||
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for redis to become ready subject to a hardcoded timeout.
|
||||
|
||||
@@ -121,7 +121,6 @@ celery_app.autodiscover_tasks(
|
||||
[
|
||||
# Original background worker tasks
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
@@ -134,5 +133,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
# Docfetching worker tasks
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -98,5 +98,8 @@ for bootstep in base_bootsteps:
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
# Sandbox tasks (file sync, cleanup)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,109 +0,0 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.kg_processing")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME)
|
||||
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None:
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
base_bootsteps = app_base.get_bootsteps()
|
||||
for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
]
|
||||
)
|
||||
@@ -116,5 +116,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -318,12 +318,12 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -32,10 +33,16 @@ PRUNING_CHECKPOINTED_BATCH_SIZE = 32
|
||||
|
||||
|
||||
def document_batch_to_ids(
|
||||
doc_batch: Iterator[list[Document]] | Iterator[list[SlimDocument]],
|
||||
doc_batch: (
|
||||
Iterator[list[Document | HierarchyNode]]
|
||||
| Iterator[list[SlimDocument | HierarchyNode]]
|
||||
),
|
||||
) -> Generator[set[str], None, None]:
|
||||
for doc_list in doc_batch:
|
||||
yield {doc.id for doc in doc_list}
|
||||
yield {
|
||||
doc.raw_node_id if isinstance(doc, HierarchyNode) else doc.id
|
||||
for doc in doc_list
|
||||
}
|
||||
|
||||
|
||||
def extract_ids_from_runnable_connector(
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_KG_PROCESSING_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = CELERY_WORKER_KG_PROCESSING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
@@ -57,24 +57,6 @@ beat_task_templates: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-kg-processing",
|
||||
"task": OnyxCeleryTask.CHECK_KG_PROCESSING,
|
||||
"schedule": timedelta(seconds=60),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-kg-processing-clustering-only",
|
||||
"task": OnyxCeleryTask.CHECK_KG_PROCESSING_CLUSTERING_ONLY,
|
||||
"schedule": timedelta(seconds=600),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
@@ -129,6 +111,15 @@ beat_task_templates: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-hierarchy-fetching",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_HIERARCHY_FETCHING,
|
||||
"schedule": timedelta(hours=1), # Check hourly, but only fetch once per day
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-background-processes",
|
||||
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
@@ -139,6 +130,27 @@ beat_task_templates: list[dict] = [
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
# Sandbox cleanup tasks
|
||||
{
|
||||
"name": "cleanup-idle-sandboxes",
|
||||
"task": OnyxCeleryTask.CLEANUP_IDLE_SANDBOXES,
|
||||
"schedule": timedelta(minutes=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.SANDBOX,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "cleanup-old-snapshots",
|
||||
"task": OnyxCeleryTask.CLEANUP_OLD_SNAPSHOTS,
|
||||
"schedule": timedelta(hours=24),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.SANDBOX,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
|
||||
278
backend/onyx/background/celery/tasks/hierarchyfetching/tasks.py
Normal file
@@ -0,0 +1,278 @@
"""Celery tasks for hierarchy fetching.
|
||||
|
||||
This module provides tasks for fetching hierarchy node information from connectors.
|
||||
Hierarchy nodes represent structural elements like folders, spaces, and pages that
|
||||
can be used to filter search results.
|
||||
|
||||
The hierarchy fetching pipeline runs once per day per connector and fetches
|
||||
structural information from the connector source.
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.connector import mark_cc_pair_as_hierarchy_fetched
|
||||
from onyx.db.connector_credential_pair import (
|
||||
fetch_indexable_standard_connector_credential_pair_ids,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Hierarchy fetching runs once per day (24 hours in seconds)
|
||||
HIERARCHY_FETCH_INTERVAL_SECONDS = 24 * 60 * 60
|
||||
|
||||
|
||||
def _is_hierarchy_fetching_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if hierarchy fetching is due for this connector.
|
||||
|
||||
    Hierarchy fetching should run once per day for active connectors.
    """
    # Skip if not active
    if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
        return False

    # Skip if connector has never successfully indexed
    if not cc_pair.last_successful_index_time:
        return False

    # Check if we've fetched hierarchy recently
    last_fetch = cc_pair.last_time_hierarchy_fetch
    if last_fetch is None:
        # Never fetched before - fetch now
        return True

    # Check if enough time has passed since last fetch
    next_fetch_time = last_fetch + timedelta(seconds=HIERARCHY_FETCH_INTERVAL_SECONDS)
    return datetime.now(timezone.utc) >= next_fetch_time


def _try_creating_hierarchy_fetching_task(
    celery_app: Celery,
    cc_pair: ConnectorCredentialPair,
    db_session: Session,
    r: Redis,
    tenant_id: str,
) -> str | None:
    """Try to create a hierarchy fetching task for a connector.

    Returns the task ID if created, None otherwise.
    """
    LOCK_TIMEOUT = 30

    # Serialize task creation attempts
    lock: RedisLock = r.lock(
        DANSWER_REDIS_FUNCTION_LOCK_PREFIX + f"hierarchy_fetching_{cc_pair.id}",
        timeout=LOCK_TIMEOUT,
    )

    acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
    if not acquired:
        return None

    try:
        # Refresh to get latest state
        db_session.refresh(cc_pair)
        if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
            return None

        # Generate task ID
        custom_task_id = f"hierarchy_fetching_{cc_pair.id}_{uuid4()}"

        # Send the task
        result = celery_app.send_task(
            OnyxCeleryTask.CONNECTOR_HIERARCHY_FETCHING_TASK,
            kwargs=dict(
                cc_pair_id=cc_pair.id,
                tenant_id=tenant_id,
            ),
            queue=OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING,
            task_id=custom_task_id,
            priority=OnyxCeleryPriority.LOW,
        )

        if not result:
            raise RuntimeError("send_task for hierarchy_fetching_task failed.")

        task_logger.info(
            f"Created hierarchy fetching task: "
            f"cc_pair={cc_pair.id} "
            f"celery_task_id={custom_task_id}"
        )

        return custom_task_id

    except Exception:
        task_logger.exception(
            f"Failed to create hierarchy fetching task: cc_pair={cc_pair.id}"
        )
        return None
    finally:
        if lock.owned():
            lock.release()


@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_HIERARCHY_FETCHING,
    soft_time_limit=300,
    bind=True,
)
def check_for_hierarchy_fetching(self: Task, *, tenant_id: str) -> int | None:
    """Check for connectors that need hierarchy fetching and spawn tasks.

    This task runs periodically (once per day) and checks all active connectors
    to see if they need hierarchy information fetched.
    """
    time_start = time.monotonic()
    task_logger.info("check_for_hierarchy_fetching - Starting")

    tasks_created = 0
    locked = False
    redis_client = get_redis_client()

    lock_beat: RedisLock = redis_client.lock(
        OnyxRedisLocks.CHECK_HIERARCHY_FETCHING_BEAT_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )

    # These tasks should never overlap
    if not lock_beat.acquire(blocking=False):
        return None

    try:
        locked = True

        with get_session_with_current_tenant() as db_session:
            # Get all active connector credential pairs
            cc_pair_ids = fetch_indexable_standard_connector_credential_pair_ids(
                db_session=db_session,
                active_cc_pairs_only=True,
            )

            for cc_pair_id in cc_pair_ids:
                lock_beat.reacquire()
                cc_pair = get_connector_credential_pair_from_id(
                    db_session=db_session,
                    cc_pair_id=cc_pair_id,
                )

                if not cc_pair or not _is_hierarchy_fetching_due(cc_pair):
                    continue

                task_id = _try_creating_hierarchy_fetching_task(
                    celery_app=self.app,
                    cc_pair=cc_pair,
                    db_session=db_session,
                    r=redis_client,
                    tenant_id=tenant_id,
                )

                if task_id:
                    tasks_created += 1

    except Exception:
        task_logger.exception("check_for_hierarchy_fetching - Unexpected error")
    finally:
        if locked:
            if lock_beat.owned():
                lock_beat.release()
            else:
                task_logger.error(
                    "check_for_hierarchy_fetching - Lock not owned on completion"
                )

    time_elapsed = time.monotonic() - time_start
    task_logger.info(
        f"check_for_hierarchy_fetching finished: "
        f"tasks_created={tasks_created} elapsed={time_elapsed:.2f}s"
    )
    return tasks_created


@shared_task(
    name=OnyxCeleryTask.CONNECTOR_HIERARCHY_FETCHING_TASK,
    soft_time_limit=3600,  # 1 hour soft limit
    time_limit=3900,  # 1 hour 5 min hard limit
    bind=True,
)
def connector_hierarchy_fetching_task(
    self: Task,
    *,
    cc_pair_id: int,
    tenant_id: str,
) -> None:
    """Fetch hierarchy information from a connector.

    This task fetches structural information (folders, spaces, pages, etc.)
    from the connector source and stores it in the database.
    """
    task_logger.info(
        f"connector_hierarchy_fetching_task starting: "
        f"cc_pair={cc_pair_id} tenant={tenant_id}"
    )

    try:
        with get_session_with_current_tenant() as db_session:
            cc_pair = get_connector_credential_pair_from_id(
                db_session=db_session,
                cc_pair_id=cc_pair_id,
            )

            if not cc_pair:
                task_logger.warning(
                    f"CC pair not found for hierarchy fetching: cc_pair={cc_pair_id}"
                )
                return

            if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
                task_logger.info(
                    f"Skipping hierarchy fetching for deleting connector: "
                    f"cc_pair={cc_pair_id}"
                )
                return

            # TODO: Implement the actual hierarchy fetching logic
            # This will involve:
            # 1. Instantiating the connector
            # 2. Calling a hierarchy-specific method on the connector
            # 3. Upserting the hierarchy nodes to the database

            task_logger.info(
                f"connector_hierarchy_fetching_task: "
                f"Hierarchy fetching not yet implemented for cc_pair={cc_pair_id}"
            )

            # Update the last fetch time to prevent re-running until next interval
            mark_cc_pair_as_hierarchy_fetched(db_session, cc_pair_id)

    except Exception:
        task_logger.exception(
            f"connector_hierarchy_fetching_task failed: cc_pair={cc_pair_id}"
        )
        raise

    task_logger.info(
        f"connector_hierarchy_fetching_task completed: cc_pair={cc_pair_id}"
    )
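
The due-check above reduces to a single interval comparison against the last fetch time. A minimal standalone sketch of that logic, assuming a 24-hour interval (the real HIERARCHY_FETCH_INTERVAL_SECONDS is defined elsewhere in the configs and may differ):

from datetime import datetime, timedelta, timezone

# Assumed value for illustration; the actual constant lives in the configs module.
HIERARCHY_FETCH_INTERVAL_SECONDS = 24 * 60 * 60


def is_fetch_due(last_fetch: datetime | None, now: datetime | None = None) -> bool:
    """Return True if a hierarchy fetch never ran or the interval has elapsed."""
    now = now or datetime.now(timezone.utc)
    if last_fetch is None:
        return True
    return now >= last_fetch + timedelta(seconds=HIERARCHY_FETCH_INTERVAL_SECONDS)


# Example: a fetch from 25 hours ago is due again.
assert is_fetch_due(datetime.now(timezone.utc) - timedelta(hours=25))
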
@@ -1,77 +0,0 @@
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.apps.client import celery_app
from onyx.background.celery.tasks.kg_processing.utils import is_kg_processing_blocked
from onyx.background.celery.tasks.kg_processing.utils import (
    is_kg_processing_requirements_met,
)
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask


def try_creating_kg_processing_task(
    tenant_id: str,
) -> bool:
    """Schedules the KG Processing for a tenant immediately. Will not schedule if
    the tenant is not ready for KG processing.
    """

    try:
        if not is_kg_processing_requirements_met():
            return False

        # Send the KG processing task
        result = celery_app.send_task(
            OnyxCeleryTask.KG_PROCESSING,
            kwargs=dict(
                tenant_id=tenant_id,
            ),
            queue=OnyxCeleryQueues.KG_PROCESSING,
            priority=OnyxCeleryPriority.MEDIUM,
        )

        if not result:
            task_logger.error("send_task for kg processing failed.")
        return bool(result)

    except Exception:
        task_logger.exception(
            f"try_creating_kg_processing_task - Unexpected exception for tenant={tenant_id}"
        )
        return False


def try_creating_kg_source_reset_task(
    tenant_id: str,
    source_name: str | None,
    index_name: str,
) -> bool:
    """Schedules the KG Source Reset for a tenant immediately. Will not do anything if
    the tenant is currently being processed.
    """

    try:
        if is_kg_processing_blocked():
            return False

        # Send the KG source reset task
        result = celery_app.send_task(
            OnyxCeleryTask.KG_RESET_SOURCE_INDEX,
            kwargs=dict(
                tenant_id=tenant_id,
                source_name=source_name,
                index_name=index_name,
            ),
            queue=OnyxCeleryQueues.KG_PROCESSING,
            priority=OnyxCeleryPriority.MEDIUM,
        )

        if not result:
            task_logger.error("send_task for kg source reset failed.")
        return bool(result)

    except Exception:
        task_logger.exception(
            f"try_creating_kg_source_reset_task - Unexpected exception for tenant={tenant_id}"
        )
        return False
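
For context, both deleted helpers above are thin wrappers around a plain celery_app.send_task call. A minimal sketch of that underlying pattern, assuming a locally configured Celery app and placeholder task and queue names (not the real Onyx constants):

from celery import Celery

# Hypothetical broker URL, for illustration only.
celery_app = Celery(broker="redis://localhost:6379/0")


def schedule_task(tenant_id: str) -> bool:
    """Send a named task to a queue and report whether enqueueing succeeded."""
    result = celery_app.send_task(
        "example_processing_task",  # placeholder task name
        kwargs={"tenant_id": tenant_id},
        queue="example_queue",  # placeholder queue
    )
    return bool(result)
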
@@ -1,253 +0,0 @@
import time

from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock

from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.kg_processing.utils import (
    is_kg_clustering_only_requirements_met,
)
from onyx.background.celery.tasks.kg_processing.utils import (
    is_kg_processing_requirements_met,
)
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.search_settings import get_current_search_settings
from onyx.kg.clustering.clustering import kg_clustering
from onyx.kg.extractions.extraction_processing import kg_extraction
from onyx.kg.resets.reset_source import reset_source_kg_index
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.logger import setup_logger

logger = setup_logger()


@shared_task(
    name=OnyxCeleryTask.CHECK_KG_PROCESSING,
    soft_time_limit=300,
    bind=True,
)
def check_for_kg_processing(self: Task, *, tenant_id: str) -> None:
    """a lightweight task used to kick off kg processing tasks."""

    time_start = time.monotonic()
    task_logger.warning("check_for_kg_processing - Starting")
    try:
        if not is_kg_processing_requirements_met():
            return

        task_logger.info(
            f"Found documents needing KG processing for tenant {tenant_id}"
        )

        self.app.send_task(
            OnyxCeleryTask.KG_PROCESSING,
            kwargs={
                "tenant_id": tenant_id,
            },
            queue=OnyxCeleryQueues.KG_PROCESSING,
            priority=OnyxCeleryPriority.MEDIUM,
        )

    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception("Unexpected exception during kg processing check")

    time_elapsed = time.monotonic() - time_start
    task_logger.info(f"check_for_kg_processing finished: elapsed={time_elapsed:.2f}")


@shared_task(
    name=OnyxCeleryTask.CHECK_KG_PROCESSING_CLUSTERING_ONLY,
    soft_time_limit=300,
    bind=True,
)
def check_for_kg_processing_clustering_only(self: Task, *, tenant_id: str) -> None:
    """a lightweight task used to kick off kg clustering tasks."""

    time_start = time.monotonic()
    task_logger.warning("check_for_kg_processing_clustering_only - Starting")

    try:
        if not is_kg_clustering_only_requirements_met():
            return

        task_logger.info(
            f"Found documents needing KG clustering for tenant {tenant_id}"
        )

        self.app.send_task(
            OnyxCeleryTask.KG_CLUSTERING_ONLY,
            kwargs={
                "tenant_id": tenant_id,
            },
            queue=OnyxCeleryQueues.KG_PROCESSING,
            priority=OnyxCeleryPriority.MEDIUM,
        )

    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception("Unexpected exception during kg clustering-only check")

    time_elapsed = time.monotonic() - time_start
    task_logger.info(
        f"check_for_kg_processing_clustering_only finished: elapsed={time_elapsed:.2f}"
    )


@shared_task(
    name=OnyxCeleryTask.KG_PROCESSING,
    bind=True,
)
def kg_processing(self: Task, *, tenant_id: str) -> None:
    """a task for doing kg extraction and clustering."""

    task_logger.warning(f"kg_processing - Starting for tenant {tenant_id}")

    redis_client = get_redis_client()
    lock_beat: RedisLock = redis_client.lock(
        OnyxRedisLocks.KG_PROCESSING_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )

    # these tasks should never overlap
    if not lock_beat.acquire(blocking=False):
        return

    try:
        with get_session_with_current_tenant() as db_session:
            search_settings = get_current_search_settings(db_session)
            index_str = search_settings.index_name

        task_logger.info(f"KG processing set to in progress for tenant {tenant_id}")

        kg_extraction(
            tenant_id=tenant_id,
            index_name=index_str,
            lock=lock_beat,
            processing_chunk_batch_size=8,
        )

        kg_clustering(
            tenant_id=tenant_id,
            index_name=index_str,
            lock=lock_beat,
            processing_chunk_batch_size=8,
        )
    except Exception:
        task_logger.exception("Error during kg processing")
    finally:
        if lock_beat.owned():
            lock_beat.release()
        else:
            task_logger.error(
                "kg_processing - Lock not owned on completion: " f"tenant={tenant_id}"
            )
            redis_lock_dump(lock_beat, redis_client)

    task_logger.debug("Completed kg processing task!")


@shared_task(
    name=OnyxCeleryTask.KG_CLUSTERING_ONLY,
    bind=True,
)
def kg_clustering_only(self: Task, *, tenant_id: str) -> None:
    """a task for doing kg clustering only."""

    task_logger.warning(f"kg_clustering_only - Starting for tenant {tenant_id}")

    redis_client = get_redis_client()
    lock_beat: RedisLock = redis_client.lock(
        OnyxRedisLocks.KG_PROCESSING_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )

    # these tasks should never overlap
    if not lock_beat.acquire(blocking=False):
        return

    try:
        with get_session_with_current_tenant() as db_session:
            search_settings = get_current_search_settings(db_session)
            index_str = search_settings.index_name

        task_logger.info(
            f"KG clustering-only set to in progress for tenant {tenant_id}"
        )

        kg_clustering(
            tenant_id=tenant_id,
            index_name=index_str,
            lock=lock_beat,
            processing_chunk_batch_size=8,
        )
    except Exception:
        task_logger.exception("Error during kg clustering-only")
    finally:
        if lock_beat.owned():
            lock_beat.release()
        else:
            task_logger.error(
                "kg_clustering_only - Lock not owned on completion: "
                f"tenant={tenant_id}"
            )
            redis_lock_dump(lock_beat, redis_client)

    task_logger.debug("Completed kg clustering-only task!")


@shared_task(
    name=OnyxCeleryTask.KG_RESET_SOURCE_INDEX,
    bind=True,
)
def kg_reset_source_index(
    self: Task, *, tenant_id: str, source_name: str, index_name: str
) -> None:
    """a task for KG reset of a source."""

    task_logger.warning(f"kg_reset_source_index - Starting for tenant {tenant_id}")

    redis_client = get_redis_client()
    lock_beat: RedisLock = redis_client.lock(
        OnyxRedisLocks.KG_PROCESSING_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )

    # these tasks should never overlap
    if not lock_beat.acquire(blocking=False):
        return

    try:
        reset_source_kg_index(
            source_name=source_name,
            tenant_id=tenant_id,
            index_name=index_name,
            lock=lock_beat,
        )
    except Exception:
        task_logger.exception("Error during kg reset")
    finally:
        if lock_beat.owned():
            lock_beat.release()
        else:
            task_logger.error(
                "kg_reset_source_index - Lock not owned on completion: "
                f"tenant={tenant_id}"
            )
            redis_lock_dump(lock_beat, redis_client)

    task_logger.debug("Completed kg reset task!")
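
All three deleted KG tasks above share one guard: take a Redis lock without blocking, bail out if another run already holds it, and release only if the lock is still owned. A minimal sketch of that pattern with redis-py, assuming a locally reachable Redis and a placeholder lock name:

import redis
from redis.lock import Lock as RedisLock

r = redis.Redis(host="localhost", port=6379)  # assumed local Redis


def run_exclusive(lock_name: str = "example_lock", timeout: int = 600) -> None:
    lock: RedisLock = r.lock(lock_name, timeout=timeout)
    # Never wait: if another worker holds the lock, skip this run entirely.
    if not lock.acquire(blocking=False):
        return
    try:
        print("doing exclusive work")  # placeholder for the real work
    finally:
        # Only release if we still own it; the timeout may have expired mid-run.
        if lock.owned():
            lock.release()
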
@@ -1,78 +0,0 @@
import time

from redis.lock import Lock as RedisLock

from onyx.configs.constants import OnyxRedisLocks
from onyx.db.document import check_for_documents_needing_kg_processing
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.kg_config import is_kg_config_settings_enabled_valid
from onyx.db.models import KGEntityExtractionStaging
from onyx.db.models import KGRelationshipExtractionStaging
from onyx.redis.redis_pool import get_redis_client


def is_kg_processing_blocked() -> bool:
    """Checks if there are any KG tasks in progress."""
    redis_client = get_redis_client()
    lock_beat: RedisLock = redis_client.lock(OnyxRedisLocks.KG_PROCESSING_LOCK)
    return lock_beat.locked()


def is_kg_processing_requirements_met() -> bool:
    """Checks that there are no other KG tasks in progress, KG is enabled, valid,
    and there are documents that need KG processing."""
    if is_kg_processing_blocked():
        return False

    kg_config = get_kg_config_settings()
    if not is_kg_config_settings_enabled_valid(kg_config):
        return False

    with get_session_with_current_tenant() as db_session:
        has_staging_entities = (
            db_session.query(KGEntityExtractionStaging).first() is not None
        )
        has_staging_relationships = (
            db_session.query(KGRelationshipExtractionStaging).first() is not None
        )
        return (
            check_for_documents_needing_kg_processing(
                db_session,
                kg_config.KG_COVERAGE_START_DATE,
                kg_config.KG_MAX_COVERAGE_DAYS,
            )
            or has_staging_entities
            or has_staging_relationships
        )


def is_kg_clustering_only_requirements_met() -> bool:
    """Checks that there are no other KG tasks in progress, KG is enabled, valid,
    and there are documents that need KG clustering."""
    if is_kg_processing_blocked():
        return False

    kg_config = get_kg_config_settings()
    if not is_kg_config_settings_enabled_valid(kg_config):
        return False

    # Check if there are any entries in the staging tables
    with get_session_with_current_tenant() as db_session:
        has_staging_entities = (
            db_session.query(KGEntityExtractionStaging).first() is not None
        )
        has_staging_relationships = (
            db_session.query(KGRelationshipExtractionStaging).first() is not None
        )

    return has_staging_entities or has_staging_relationships


def extend_lock(lock: RedisLock, timeout: int, last_lock_time: float) -> float:
    current_time = time.monotonic()
    if current_time - last_lock_time >= (timeout / 4):
        lock.reacquire()
        last_lock_time = current_time

    return last_lock_time
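
The deleted extend_lock helper throttles lock renewal to at most once per quarter-timeout. A sketch of how a long-running loop might call it, assuming the same redis-py lock type and a placeholder per-item workload:

import time
from redis.lock import Lock as RedisLock


def extend_lock(lock: RedisLock, timeout: int, last_lock_time: float) -> float:
    """Reacquire the lock at most once per timeout/4 seconds (mirrors the helper above)."""
    current_time = time.monotonic()
    if current_time - last_lock_time >= (timeout / 4):
        lock.reacquire()
        last_lock_time = current_time
    return last_lock_time


def process_items(items: list[str], lock: RedisLock, timeout: int = 600) -> None:
    last_lock_time = time.monotonic()
    for item in items:
        # Keep the lock fresh without hammering Redis on every iteration.
        last_lock_time = extend_lock(lock, timeout, last_lock_time)
        print(f"processing {item}")  # placeholder work
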
@@ -27,6 +27,7 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
@@ -232,7 +233,9 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -

    try:
        for batch in connector.load_from_state():
            documents.extend(batch)
            documents.extend(
                [doc for doc in batch if not isinstance(doc, HierarchyNode)]
            )

        adapter = UserFileIndexingAdapter(
            tenant_id=tenant_id,

@@ -21,6 +21,8 @@ from onyx.utils.logger import setup_logger
DOCUMENT_SYNC_PREFIX = "documentsync"
DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence"
DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset"
FENCE_TTL = 7 * 24 * 60 * 60  # 7 days - defensive TTL to prevent memory leaks
TASKSET_TTL = FENCE_TTL

logger = setup_logger()

@@ -50,7 +52,7 @@ def set_document_sync_fence(r: Redis, payload: int | None) -> None:
        r.delete(DOCUMENT_SYNC_FENCE_KEY)
        return

    r.set(DOCUMENT_SYNC_FENCE_KEY, payload)
    r.set(DOCUMENT_SYNC_FENCE_KEY, payload, ex=FENCE_TTL)
    r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)


@@ -110,6 +112,7 @@ def generate_document_sync_tasks(

        # Add to the tracking taskset in Redis BEFORE creating the celery task
        r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id)
        r.expire(DOCUMENT_SYNC_TASKSET_KEY, TASKSET_TTL)

        # Create the Celery task
        celery_app.send_task(
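
The change above adds a defensive TTL to the document-sync fence and taskset so abandoned keys eventually expire instead of leaking. A minimal sketch of the same idea with redis-py, using placeholder key names:

import redis

FENCE_TTL = 7 * 24 * 60 * 60  # 7 days, matching the constant above

r = redis.Redis(host="localhost", port=6379)  # assumed local Redis

# Set the fence with an expiry instead of leaving it to live forever.
r.set("example_sync_fence", 3, ex=FENCE_TTL)

# Track a task id and refresh the set's TTL whenever a member is added.
r.sadd("example_sync_taskset", "task-123")
r.expire("example_sync_taskset", FENCE_TTL)
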
@@ -1,18 +0,0 @@
"""Factory stub for running celery worker for knowledge graph processing.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker."""

from celery import Celery

from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable

set_is_ee_based_on_env_variable()


def get_app() -> Celery:
    from onyx.background.celery.apps.kg_processing import celery_app

    return celery_app


app = get_app()
@@ -31,17 +31,21 @@ from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import TextSection
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.constants import CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.enums import ProcessingMode
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
@@ -53,7 +57,15 @@ from onyx.db.models import IndexAttempt
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
from onyx.redis.redis_pool import get_redis_client
from onyx.server.features.build.indexing.persistent_document_writer import (
    get_persistent_document_writer,
)
from onyx.utils.logger import setup_logger
from onyx.utils.middleware import make_randomized_onyx_request_id
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
@@ -367,6 +379,7 @@ def connector_document_extraction(

    db_connector = index_attempt.connector_credential_pair.connector
    db_credential = index_attempt.connector_credential_pair.credential
    processing_mode = index_attempt.connector_credential_pair.processing_mode
    is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT

    from_beginning = index_attempt.from_beginning
@@ -534,9 +547,12 @@ def connector_document_extraction(
        logger.info(
            f"Running '{db_connector.source.value}' connector with checkpoint: {checkpoint}"
        )
        for document_batch, failure, next_checkpoint in connector_runner.run(
            checkpoint
        ):
        for (
            document_batch,
            hierarchy_node_batch,
            failure,
            next_checkpoint,
        ) in connector_runner.run(checkpoint):
            # Check if connector is disabled mid run and stop if so unless it's the secondary
            # index being built. We want to populate it even for paused connectors
            # Often paused connectors are sources that aren't updated frequently but the
@@ -571,6 +587,33 @@ def connector_document_extraction(
            if next_checkpoint:
                checkpoint = next_checkpoint

            # Process hierarchy nodes batch - upsert to Postgres and cache in Redis
            if hierarchy_node_batch:
                with get_session_with_current_tenant() as db_session:
                    upserted_nodes = upsert_hierarchy_nodes_batch(
                        db_session=db_session,
                        nodes=hierarchy_node_batch,
                        source=db_connector.source,
                        commit=True,
                    )

                # Cache in Redis for fast ancestor resolution during doc processing
                redis_client = get_redis_client(tenant_id=tenant_id)
                cache_entries = [
                    HierarchyNodeCacheEntry.from_db_model(node)
                    for node in upserted_nodes
                ]
                cache_hierarchy_nodes_batch(
                    redis_client=redis_client,
                    source=db_connector.source,
                    entries=cache_entries,
                )

                logger.debug(
                    f"Persisted and cached {len(hierarchy_node_batch)} hierarchy nodes "
                    f"for attempt={index_attempt_id}"
                )

            # below is all document processing task, so if no batch we can just continue
            if not document_batch:
                continue
@@ -600,34 +643,103 @@ def connector_document_extraction(
            logger.debug(f"Indexing batch of documents: {batch_description}")
            memory_tracer.increment_and_maybe_trace()

            # Store documents in storage
            batch_storage.store_batch(batch_num, doc_batch_cleaned)
            # cc4a
            if processing_mode == ProcessingMode.FILE_SYSTEM:
                # File system only - write directly to persistent storage,
                # skip chunking/embedding/Vespa but still track documents in DB

                # Create processing task data
                processing_batch_data = {
                    "index_attempt_id": index_attempt_id,
                    "cc_pair_id": cc_pair_id,
                    "tenant_id": tenant_id,
                    "batch_num": batch_num,  # 0-indexed
                }
                with get_session_with_current_tenant() as db_session:
                    # Create metadata for the batch
                    index_attempt_metadata = IndexAttemptMetadata(
                        attempt_id=index_attempt_id,
                        connector_id=db_connector.id,
                        credential_id=db_credential.id,
                        request_id=make_randomized_onyx_request_id("FSI"),
                        structured_id=f"{tenant_id}:{cc_pair_id}:{index_attempt_id}:{batch_num}",
                        batch_num=batch_num,
                    )

                    # Queue document processing task
                    app.send_task(
                        OnyxCeleryTask.DOCPROCESSING_TASK,
                        kwargs=processing_batch_data,
                        queue=OnyxCeleryQueues.DOCPROCESSING,
                        priority=docprocessing_priority,
                    )
                    # Upsert documents to PostgreSQL (document table + cc_pair relationship)
                    # This is a subset of what docprocessing does - just DB tracking, no chunking/embedding
                    index_doc_batch_prepare(
                        documents=doc_batch_cleaned,
                        index_attempt_metadata=index_attempt_metadata,
                        db_session=db_session,
                        ignore_time_skip=True,  # Documents already filtered during extraction
                    )

                    batch_num += 1
                    total_doc_batches_queued += 1
                    # Mark documents as indexed for the CC pair
                    mark_document_as_indexed_for_cc_pair__no_commit(
                        connector_id=db_connector.id,
                        credential_id=db_credential.id,
                        document_ids=[doc.id for doc in doc_batch_cleaned],
                        db_session=db_session,
                    )
                    db_session.commit()

                logger.info(
                    f"Queued document processing batch: "
                    f"batch_num={batch_num} "
                    f"docs={len(doc_batch_cleaned)} "
                    f"attempt={index_attempt_id}"
                )
                # Write documents to persistent file system
                # Use creator_id for user-segregated storage paths (sandbox isolation)
                creator_id = index_attempt.connector_credential_pair.creator_id
                if creator_id is None:
                    raise ValueError(
                        f"ConnectorCredentialPair {index_attempt.connector_credential_pair.id} "
                        "must have a creator_id for persistent document storage"
                    )
                user_id_str: str = str(creator_id)
                writer = get_persistent_document_writer(
                    user_id=user_id_str,
                    tenant_id=tenant_id,
                )
                written_paths = writer.write_documents(doc_batch_cleaned)

                # Update coordination directly (no docprocessing task)
                with get_session_with_current_tenant() as db_session:
                    IndexingCoordination.update_batch_completion_and_docs(
                        db_session=db_session,
                        index_attempt_id=index_attempt_id,
                        total_docs_indexed=len(doc_batch_cleaned),
                        new_docs_indexed=len(doc_batch_cleaned),
                        total_chunks=0,  # No chunks for file system mode
                    )

                batch_num += 1
                total_doc_batches_queued += 1

                logger.info(
                    f"Wrote documents to file system: "
                    f"batch_num={batch_num} "
                    f"docs={len(written_paths)} "
                    f"attempt={index_attempt_id}"
                )
            else:
                # REGULAR mode (default): Full pipeline - store and queue docprocessing
                batch_storage.store_batch(batch_num, doc_batch_cleaned)

                # Create processing task data
                processing_batch_data = {
                    "index_attempt_id": index_attempt_id,
                    "cc_pair_id": cc_pair_id,
                    "tenant_id": tenant_id,
                    "batch_num": batch_num,  # 0-indexed
                }

                # Queue document processing task
                app.send_task(
                    OnyxCeleryTask.DOCPROCESSING_TASK,
                    kwargs=processing_batch_data,
                    queue=OnyxCeleryQueues.DOCPROCESSING,
                    priority=docprocessing_priority,
                )

                batch_num += 1
                total_doc_batches_queued += 1

                logger.info(
                    f"Queued document processing batch: "
                    f"batch_num={batch_num} "
                    f"docs={len(doc_batch_cleaned)} "
                    f"attempt={index_attempt_id}"
                )

            # Check checkpoint size periodically
            CHECKPOINT_SIZE_CHECK_INTERVAL = 100
@@ -663,6 +775,24 @@ def connector_document_extraction(
            total_batches=batch_num,
        )

        # Trigger file sync to user's sandbox (if running) - only for FILE_SYSTEM mode
        # This syncs the newly written documents from S3 to any running sandbox pod
        if processing_mode == ProcessingMode.FILE_SYSTEM:
            creator_id = index_attempt.connector_credential_pair.creator_id
            if creator_id:
                app.send_task(
                    OnyxCeleryTask.SANDBOX_FILE_SYNC,
                    kwargs={
                        "user_id": str(creator_id),
                        "tenant_id": tenant_id,
                    },
                    queue=OnyxCeleryQueues.SANDBOX,
                )
                logger.info(
                    f"Triggered sandbox file sync for user {creator_id} "
                    f"after indexing complete"
                )

    except Exception as e:
        logger.exception(
            f"Document extraction failed: "
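
The extraction loop above now unpacks a four-element tuple from the runner: documents, hierarchy nodes, a failure, or a new checkpoint, with at most one slot populated per yield. A minimal sketch of a consumer that branches on whichever slot is set; `runner` is assumed to be a ConnectorRunner instance and `checkpoint` a valid starting checkpoint, while the handler bodies are placeholders rather than the real persistence calls:

def persist_hierarchy(nodes) -> None:  # placeholder for upsert + Redis caching
    print(f"persisting {len(nodes)} hierarchy nodes")


def index_documents(docs) -> None:  # placeholder for batch storage / task queueing
    print(f"indexing {len(docs)} documents")


def record_failure(failure) -> None:  # placeholder for index attempt error tracking
    print(f"failure: {failure}")


def consume(runner, checkpoint) -> None:
    for doc_batch, hierarchy_batch, failure, next_checkpoint in runner.run(checkpoint):
        if hierarchy_batch:
            persist_hierarchy(hierarchy_batch)
        if doc_batch:
            index_documents(doc_batch)
        if failure is not None:
            record_failure(failure)
        if next_checkpoint is not None:
            checkpoint = next_checkpoint  # carry the checkpoint forward
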
@@ -9,12 +9,6 @@ from fastapi.datastructures import Headers
from sqlalchemy.orm import Session

from onyx.auth.users import is_user_admin
from onyx.background.celery.tasks.kg_processing.kg_indexing import (
    try_creating_kg_processing_task,
)
from onyx.background.celery.tasks.kg_processing.kg_indexing import (
    try_creating_kg_source_reset_task,
)
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import PersonaOverrideConfig
@@ -37,7 +31,6 @@ from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.projects import check_project_ownership
from onyx.db.search_settings import get_current_search_settings
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
@@ -359,39 +352,7 @@ def process_kg_commands(
    if not is_kg_config_settings_enabled_valid(kg_config_settings):
        return

    # get Vespa index
    search_settings = get_current_search_settings(db_session)
    index_str = search_settings.index_name

    if message == "kg_p":
        success = try_creating_kg_processing_task(tenant_id)
        if success:
            raise KGException("KG processing scheduled")
        else:
            raise KGException(
                "Cannot schedule another KG processing if one is already running "
                "or there are no documents to process"
            )

    elif message.startswith("kg_rs_source"):
        msg_split = [x for x in message.split(":")]
        if len(msg_split) > 2:
            raise KGException("Invalid format for a source reset command")
        elif len(msg_split) == 2:
            source_name = msg_split[1].strip()
        elif len(msg_split) == 1:
            source_name = None
        else:
            raise KGException("Invalid format for a source reset command")

        success = try_creating_kg_source_reset_task(tenant_id, source_name, index_str)
        if success:
            source_name = source_name or "all"
            raise KGException(f"KG index reset for source '{source_name}' scheduled")
        else:
            raise KGException("Cannot reset index while KG processing is running")

    elif message == "kg_setup":
    if message == "kg_setup":
        populate_missing_default_entity_types__commit(db_session=db_session)
        raise KGException("KG setup done")

@@ -207,6 +207,9 @@ OPENSEARCH_HOST = os.environ.get("OPENSEARCH_HOST") or "localhost"
OPENSEARCH_REST_API_PORT = int(os.environ.get("OPENSEARCH_REST_API_PORT") or 9200)
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
USING_AWS_MANAGED_OPENSEARCH = (
    os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
)

# This is the "base" config for now, the idea is that at least for our dev
# environments we always want to be dual indexing into both OpenSearch and Vespa
@@ -410,7 +413,7 @@ CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
    os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
)

# Consolidated background worker (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
# Consolidated background worker (light, docprocessing, docfetching, heavy, monitoring, user_file_processing)
# separate workers' defaults: light=24, docprocessing=6, docfetching=1, heavy=4, kg=2, monitoring=1, user_file=2
# Total would be 40, but we use a more conservative default of 20 for the consolidated worker
CELERY_WORKER_BACKGROUND_CONCURRENCY = int(
@@ -422,10 +425,6 @@ CELERY_WORKER_HEAVY_CONCURRENCY = int(
    os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
)

CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
    os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 2
)

CELERY_WORKER_MONITORING_CONCURRENCY = int(
    os.environ.get("CELERY_WORKER_MONITORING_CONCURRENCY") or 1
)
@@ -1042,3 +1041,14 @@ STRIPE_PUBLISHABLE_KEY_URL = (
)
# Override for local testing with Stripe test keys (pk_test_*)
STRIPE_PUBLISHABLE_KEY_OVERRIDE = os.environ.get("STRIPE_PUBLISHABLE_KEY")
# Persistent Document Storage Configuration
# When enabled, indexed documents are written to local filesystem with hierarchical structure
PERSISTENT_DOCUMENT_STORAGE_ENABLED = (
    os.environ.get("PERSISTENT_DOCUMENT_STORAGE_ENABLED", "").lower() == "true"
)

# Base directory path for persistent document storage (local filesystem)
# Example: /var/onyx/indexed-docs or /app/indexed-docs
PERSISTENT_DOCUMENT_STORAGE_PATH = os.environ.get(
    "PERSISTENT_DOCUMENT_STORAGE_PATH", "/app/indexed-docs"
)

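
The new settings above follow the same env-var convention as the rest of the config module: booleans are parsed from a lowercase "true" and paths fall back to a default. A minimal standalone sketch of that pattern using the same variable names and defaults shown in the diff:

import os

# Boolean flag: only the exact lowercase string "true" enables the feature.
STORAGE_ENABLED = (
    os.environ.get("PERSISTENT_DOCUMENT_STORAGE_ENABLED", "").lower() == "true"
)

# Path with a default when the variable is unset.
STORAGE_PATH = os.environ.get("PERSISTENT_DOCUMENT_STORAGE_PATH", "/app/indexed-docs")

print(STORAGE_ENABLED, STORAGE_PATH)
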
@@ -80,7 +80,6 @@ POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (
    "celery_worker_user_file_processing"
@@ -241,6 +240,7 @@ class NotificationType(str, Enum):
    TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending"  # 2 days left in trial
    RELEASE_NOTES = "release_notes"
    ASSISTANT_FILES_READY = "assistant_files_ready"
    FEATURE_ANNOUNCEMENT = "feature_announcement"


class BlobType(str, Enum):
@@ -327,6 +327,7 @@ class FileOrigin(str, Enum):
    PLAINTEXT_CACHE = "plaintext_cache"
    OTHER = "other"
    QUERY_HISTORY_CSV = "query_history_csv"
    SANDBOX_SNAPSHOT = "sandbox_snapshot"
    USER_FILE = "user_file"


@@ -344,6 +345,7 @@ class MilestoneRecordType(str, Enum):
    MULTIPLE_ASSISTANTS = "multiple_assistants"
    CREATED_ASSISTANT = "created_assistant"
    CREATED_ONYX_BOT = "created_onyx_bot"
    REQUESTED_CONNECTOR = "requested_connector"


class PostgresAdvisoryLocks(Enum):
@@ -367,6 +369,7 @@ class OnyxCeleryQueues:
    CONNECTOR_PRUNING = "connector_pruning"
    CONNECTOR_DOC_PERMISSIONS_SYNC = "connector_doc_permissions_sync"
    CONNECTOR_EXTERNAL_GROUP_SYNC = "connector_external_group_sync"
    CONNECTOR_HIERARCHY_FETCHING = "connector_hierarchy_fetching"
    CSV_GENERATION = "csv_generation"

    # User file processing queue
@@ -380,8 +383,8 @@ class OnyxCeleryQueues:
    # Monitoring queue
    MONITORING = "monitoring"

    # KG processing queue
    KG_PROCESSING = "kg_processing"
    # Sandbox processing queue
    SANDBOX = "sandbox"


class OnyxRedisLocks:
@@ -389,6 +392,7 @@ class OnyxRedisLocks:
    CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
    CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
    CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
    CHECK_HIERARCHY_FETCHING_BEAT_LOCK = "da_lock:check_hierarchy_fetching_beat"
    CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
    CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK = "da_lock:check_checkpoint_cleanup_beat"
    CHECK_INDEX_ATTEMPT_CLEANUP_BEAT_LOCK = "da_lock:check_index_attempt_cleanup_beat"
@@ -417,9 +421,6 @@ class OnyxRedisLocks:
    CLOUD_BEAT_TASK_GENERATOR_LOCK = "da_lock:cloud_beat_task_generator"
    CLOUD_CHECK_ALEMBIC_BEAT_LOCK = "da_lock:cloud_check_alembic"

    # KG processing
    KG_PROCESSING_LOCK = "da_lock:kg_processing"

    # User file processing
    USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
    USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
@@ -431,6 +432,10 @@ class OnyxRedisLocks:
    # Release notes
    RELEASE_NOTES_FETCH_LOCK = "da_lock:release_notes_fetch"

    # Sandbox cleanup
    CLEANUP_IDLE_SANDBOXES_BEAT_LOCK = "da_lock:cleanup_idle_sandboxes_beat"
    CLEANUP_OLD_SNAPSHOTS_BEAT_LOCK = "da_lock:cleanup_old_snapshots_beat"


class OnyxRedisSignals:
    BLOCK_VALIDATE_INDEXING_FENCES = "signal:block_validate_indexing_fences"
@@ -447,9 +452,6 @@ class OnyxRedisSignals:
        "signal:block_validate_connector_deletion_fences"
    )

    # KG processing
    CHECK_KG_PROCESSING_BEAT_LOCK = "da_lock:check_kg_processing_beat"


class OnyxRedisConstants:
    ACTIVE_FENCES = "active_fences"
@@ -493,6 +495,7 @@ class OnyxCeleryTask:
    CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
    CHECK_FOR_INDEXING = "check_for_indexing"
    CHECK_FOR_PRUNING = "check_for_pruning"
    CHECK_FOR_HIERARCHY_FETCHING = "check_for_hierarchy_fetching"
    CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
    CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
    CHECK_FOR_AUTO_LLM_UPDATE = "check_for_auto_llm_update"
@@ -534,6 +537,7 @@ class OnyxCeleryTask:
    DOCPROCESSING_TASK = "docprocessing_task"

    CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
    CONNECTOR_HIERARCHY_FETCHING_TASK = "connector_hierarchy_fetching_task"
    DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
    VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"

@@ -549,12 +553,12 @@ class OnyxCeleryTask:
    EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
    EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task"

    # KG processing
    CHECK_KG_PROCESSING = "check_kg_processing"
    KG_PROCESSING = "kg_processing"
    KG_CLUSTERING_ONLY = "kg_clustering_only"
    CHECK_KG_PROCESSING_CLUSTERING_ONLY = "check_kg_processing_clustering_only"
    KG_RESET_SOURCE_INDEX = "kg_reset_source_index"
    # Sandbox cleanup
    CLEANUP_IDLE_SANDBOXES = "cleanup_idle_sandboxes"
    CLEANUP_OLD_SNAPSHOTS = "cleanup_old_snapshots"

    # Sandbox file sync
    SANDBOX_FILE_SYNC = "sandbox_file_sync"


# this needs to correspond to the matching entry in supervisord

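
Task names, queues, and lock keys above are plain string constants grouped in classes, so call sites reference them by attribute rather than repeating literals. A minimal sketch of how a dispatcher might use the new hierarchy-fetching constants; the constant values come from this diff, while the Celery app and broker here are placeholders:

from celery import Celery

celery_app = Celery(broker="redis://localhost:6379/0")  # placeholder app/broker


class OnyxCeleryQueues:
    CONNECTOR_HIERARCHY_FETCHING = "connector_hierarchy_fetching"


class OnyxCeleryTask:
    CONNECTOR_HIERARCHY_FETCHING_TASK = "connector_hierarchy_fetching_task"


celery_app.send_task(
    OnyxCeleryTask.CONNECTOR_HIERARCHY_FETCHING_TASK,
    kwargs={"cc_pair_id": 1, "tenant_id": "public"},
    queue=OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING,
)
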
@@ -17,6 +17,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_file_text
@@ -419,7 +420,7 @@ class AirtableConnector(LoadConnector):
        # Process records in parallel batches using ThreadPoolExecutor
        PARALLEL_BATCH_SIZE = 8
        max_workers = min(PARALLEL_BATCH_SIZE, len(records))
        record_documents: list[Document] = []
        record_documents: list[Document | HierarchyNode] = []

        # Process records in batches
        for i in range(0, len(records), PARALLEL_BATCH_SIZE):

@@ -10,6 +10,7 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger

@@ -56,7 +57,7 @@ class AsanaConnector(LoadConnector, PollConnector):
            workspace_gid=self.workspace_id,
            team_gid=self.asana_team_id,
        )
        docs_batch: list[Document] = []
        docs_batch: list[Document | HierarchyNode] = []
        tasks = asana.get_tasks(self.project_ids_to_index, start_time)

        for task in tasks:
@@ -116,5 +117,8 @@ if __name__ == "__main__":
    latest_docs = connector.poll_source(one_day_ago, current)
    for docs in latest_docs:
        for doc in docs:
            print(doc.id)
            if isinstance(doc, HierarchyNode):
                print("hierarchynode:", doc.display_name)
            else:
                print(doc.id)
    logger.notice("Asana connector test completed")

@@ -30,6 +30,7 @@ from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -271,9 +272,9 @@ class BitbucketConnector(
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> Iterator[list[SlimDocument]]:
    ) -> Iterator[list[SlimDocument | HierarchyNode]]:
        """Return only document IDs for all existing pull requests."""
        batch: list[SlimDocument] = []
        batch: list[SlimDocument | HierarchyNode] = []
        params = self._build_params(
            fields=SLIM_PR_LIST_RESPONSE_FIELDS,
            start=start,

@@ -36,6 +36,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
@@ -377,7 +378,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
        paginator = self.s3_client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)

        batch: list[Document] = []
        batch: list[Document | HierarchyNode] = []
        for page in pages:
            if "Contents" not in page:
                continue
@@ -616,6 +617,10 @@ if __name__ == "__main__":
    for document_batch in document_batch_generator:
        print("First batch of documents:")
        for doc in document_batch:
            if isinstance(doc, HierarchyNode):
                print("hierarchynode:", doc.display_name)
                continue

            print(f"Document ID: {doc.id}")
            print(f"Semantic Identifier: {doc.semantic_identifier}")
            print(f"Source: {doc.source}")

@@ -18,6 +18,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic

@@ -47,7 +48,7 @@ class BookstackConnector(LoadConnector, PollConnector):
        start_ind: int,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> tuple[list[Document], int]:
    ) -> tuple[list[Document | HierarchyNode], int]:
        params = {
            "count": str(batch_size),
            "offset": str(start_ind),
@@ -65,7 +66,9 @@ class BookstackConnector(LoadConnector, PollConnector):
        )

        batch = bookstack_client.get(endpoint, params=params).get("data", [])
        doc_batch = [transformer(bookstack_client, item) for item in batch]
        doc_batch: list[Document | HierarchyNode] = [
            transformer(bookstack_client, item) for item in batch
        ]

        return doc_batch, len(batch)


@@ -17,6 +17,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import TextSection
from onyx.utils.retry_wrapper import retry_builder

@@ -80,7 +81,7 @@ class ClickupConnector(LoadConnector, PollConnector):
        start: int | None = None,
        end: int | None = None,
    ) -> GenerateDocumentsOutput:
        doc_batch: list[Document] = []
        doc_batch: list[Document | HierarchyNode] = []
        page: int = 0
        params = {
            "include_markdown_description": "true",

@@ -46,6 +46,7 @@ from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
@@ -382,7 +383,7 @@ class ConfluenceConnector(
        page: dict[str, Any],
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> tuple[list[Document], list[ConnectorFailure]]:
    ) -> tuple[list[Document | HierarchyNode], list[ConnectorFailure]]:
        """
        Inline attachments are added directly to the document as text or image sections by
        this function. The returned documents/connectorfailures are for non-inline attachments
@@ -392,7 +393,7 @@ class ConfluenceConnector(
            _get_page_id(page), start, end
        )
        attachment_failures: list[ConnectorFailure] = []
        attachment_docs: list[Document] = []
        attachment_docs: list[Document | HierarchyNode] = []
        page_url = ""

        try:
@@ -700,7 +701,7 @@ class ConfluenceConnector(
        callback: IndexingHeartbeatInterface | None = None,
        include_permissions: bool = True,
    ) -> GenerateSlimDocumentOutput:
        doc_metadata_list: list[SlimDocument] = []
        doc_metadata_list: list[SlimDocument | HierarchyNode] = []
        restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)

        space_level_access_info: dict[str, ExternalAccess] = {}

@@ -14,6 +14,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -30,15 +31,16 @@ def batched_doc_ids(
|
||||
batch_size: int,
|
||||
) -> Generator[set[str], None, None]:
|
||||
batch: set[str] = set()
|
||||
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
|
||||
checkpoint_connector_generator
|
||||
):
|
||||
for document, hierarchy_node, failure, next_checkpoint in CheckpointOutputWrapper[
|
||||
CT
|
||||
]()(checkpoint_connector_generator):
|
||||
if document is not None:
|
||||
batch.add(document.id)
|
||||
elif (
|
||||
failure and failure.failed_document and failure.failed_document.document_id
|
||||
):
|
||||
batch.add(failure.failed_document.document_id)
|
||||
# HierarchyNodes don't have IDs that need to be batched for doc processing
|
||||
|
||||
if len(batch) >= batch_size:
|
||||
yield batch
|
||||
@@ -63,7 +65,9 @@ class CheckpointOutputWrapper(Generic[CT]):
|
||||
self,
|
||||
checkpoint_connector_generator: CheckpointOutput[CT],
|
||||
) -> Generator[
|
||||
tuple[Document | None, ConnectorFailure | None, CT | None],
|
||||
tuple[
|
||||
Document | None, HierarchyNode | None, ConnectorFailure | None, CT | None
|
||||
],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
@@ -74,22 +78,22 @@ class CheckpointOutputWrapper(Generic[CT]):
|
||||
self.next_checkpoint = yield from checkpoint_connector_generator
|
||||
return self.next_checkpoint # not used
|
||||
|
||||
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
|
||||
if isinstance(document_or_failure, Document):
|
||||
yield document_or_failure, None, None
|
||||
elif isinstance(document_or_failure, ConnectorFailure):
|
||||
yield None, document_or_failure, None
|
||||
for item in _inner_wrapper(checkpoint_connector_generator):
|
||||
if isinstance(item, Document):
|
||||
yield item, None, None, None
|
||||
elif isinstance(item, HierarchyNode):
|
||||
yield None, item, None, None
|
||||
elif isinstance(item, ConnectorFailure):
|
||||
yield None, None, item, None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid document_or_failure type: {type(document_or_failure)}"
|
||||
)
|
||||
raise ValueError(f"Invalid connector output type: {type(item)}")
|
||||
|
||||
if self.next_checkpoint is None:
|
||||
raise RuntimeError(
|
||||
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
|
||||
)
|
||||
|
||||
yield None, None, self.next_checkpoint
|
||||
yield None, None, None, self.next_checkpoint
|
||||
|
||||
|
||||
class ConnectorRunner(Generic[CT]):
|
||||
@@ -119,13 +123,27 @@ class ConnectorRunner(Generic[CT]):
|
||||
self.include_permissions = include_permissions
|
||||
|
||||
self.doc_batch: list[Document] = []
|
||||
self.hierarchy_node_batch: list[HierarchyNode] = []
|
||||
|
||||
def run(self, checkpoint: CT) -> Generator[
|
||||
tuple[list[Document] | None, ConnectorFailure | None, CT | None],
|
||||
tuple[
|
||||
list[Document] | None,
|
||||
list[HierarchyNode] | None,
|
||||
ConnectorFailure | None,
|
||||
CT | None,
|
||||
],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
"""Adds additional exception logging to the connector."""
|
||||
"""
|
||||
Yields batches of Documents, HierarchyNodes, failures, and checkpoints.
|
||||
|
||||
Returns tuples of:
|
||||
- (doc_batch, None, None, None) - batch of documents
|
||||
- (None, hierarchy_batch, None, None) - batch of hierarchy nodes
|
||||
- (None, None, failure, None) - a connector failure
- (None, None, None, checkpoint) - new checkpoint
"""
try:
if isinstance(self.connector, CheckpointedConnector):
if self.time_range is None:

@@ -151,25 +169,47 @@ class ConnectorRunner(Generic[CT]):
)
next_checkpoint: CT | None = None
# this is guaranteed to always run at least once with next_checkpoint being non-None
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
if document is not None and isinstance(document, Document):
for (
document,
hierarchy_node,
failure,
next_checkpoint,
) in CheckpointOutputWrapper[CT]()(checkpoint_connector_generator):
if document is not None:
self.doc_batch.append(document)

if failure is not None:
yield None, failure, None
if hierarchy_node is not None:
self.hierarchy_node_batch.append(hierarchy_node)

if failure is not None:
yield None, None, failure, None

# Yield hierarchy nodes batch if it reaches batch_size
# (yield nodes before docs to maintain parent-before-child invariant)
if len(self.hierarchy_node_batch) >= self.batch_size:
yield None, self.hierarchy_node_batch, None, None
self.hierarchy_node_batch = []

# Yield document batch if it reaches batch_size
# First flush any pending hierarchy nodes to ensure parents exist
if len(self.doc_batch) >= self.batch_size:
yield self.doc_batch, None, None
if len(self.hierarchy_node_batch) > 0:
yield None, self.hierarchy_node_batch, None, None
self.hierarchy_node_batch = []
yield self.doc_batch, None, None, None
self.doc_batch = []

# yield remaining hierarchy nodes first (parents before children)
if len(self.hierarchy_node_batch) > 0:
yield None, self.hierarchy_node_batch, None, None
self.hierarchy_node_batch = []

# yield remaining documents
if len(self.doc_batch) > 0:
yield self.doc_batch, None, None
yield self.doc_batch, None, None, None
self.doc_batch = []

yield None, None, next_checkpoint
yield None, None, None, next_checkpoint

logger.debug(
f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint."

@@ -183,18 +223,26 @@ class ConnectorRunner(Generic[CT]):
if self.time_range is None:
raise ValueError("time_range is required for PollConnector")

for document_batch in self.connector.poll_source(
for batch in self.connector.poll_source(
start=self.time_range[0].timestamp(),
end=self.time_range[1].timestamp(),
):
yield document_batch, None, None
docs, nodes = self._separate_batch(batch)
if nodes:
yield None, nodes, None, None
if docs:
yield docs, None, None, None

yield None, None, finished_checkpoint
yield None, None, None, finished_checkpoint
elif isinstance(self.connector, LoadConnector):
for document_batch in self.connector.load_from_state():
yield document_batch, None, None
for batch in self.connector.load_from_state():
docs, nodes = self._separate_batch(batch)
if nodes:
yield None, nodes, None, None
if docs:
yield docs, None, None, None

yield None, None, finished_checkpoint
yield None, None, None, finished_checkpoint
else:
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
except Exception:

@@ -219,3 +267,16 @@ class ConnectorRunner(Generic[CT]):
f"local_vars below -> \n{local_vars_str[:1024]}"
)
raise

def _separate_batch(
self, batch: list[Document | HierarchyNode]
) -> tuple[list[Document], list[HierarchyNode]]:
"""Separate a mixed batch into Documents and HierarchyNodes."""
docs: list[Document] = []
nodes: list[HierarchyNode] = []
for item in batch:
if isinstance(item, Document):
docs.append(item)
elif isinstance(item, HierarchyNode):
nodes.append(item)
return docs, nodes
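
With this change, ConnectorRunner.run() yields four-tuples of (document batch, hierarchy node batch, failure, checkpoint) instead of the old three-tuples, and hierarchy node batches are always flushed before the document batch that depends on them. A minimal consumption sketch is below; `runner`, `checkpoint`, and `logger` are assumed to be set up as in the `__main__` examples further down this diff, and `_index_nodes` / `_index_docs` are hypothetical stand-ins for the real indexing pipeline.

# Sketch only: consuming the new 4-tuple shape yielded by ConnectorRunner.run().
while checkpoint.has_more:
    for doc_batch, hierarchy_node_batch, failure, next_checkpoint in runner.run(checkpoint):
        if hierarchy_node_batch:
            # nodes arrive before the documents that sit underneath them
            _index_nodes(hierarchy_node_batch)
        if doc_batch:
            _index_docs(doc_batch)
        if failure is not None:
            logger.warning(f"connector failure: {failure}")
        if next_checkpoint is not None:
            checkpoint = next_checkpoint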
@@ -21,6 +21,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -278,7 +279,7 @@ class DiscordConnector(PollConnector, LoadConnector):
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch = []
|
||||
doc_batch: list[Document | HierarchyNode] = []
|
||||
for doc in _manage_async_retrieval(
|
||||
token=self.discord_bot_token,
|
||||
requested_start_date_string=self.requested_start_date_string,
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
@@ -193,7 +194,7 @@ class DiscourseConnector(PollConnector):
|
||||
) -> GenerateDocumentsOutput:
|
||||
page = 0
|
||||
while topic_ids := self._get_latest_topics(start, end, page):
|
||||
doc_batch: list[Document] = []
|
||||
doc_batch: list[Document | HierarchyNode] = []
|
||||
for topic_id in topic_ids:
|
||||
doc_batch.append(self._get_doc_from_topic(topic_id))
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
@@ -119,7 +120,7 @@ class Document360Connector(LoadConnector, PollConnector):
|
||||
workspace_id = self._get_workspace_id_by_name()
|
||||
articles = self._get_articles_with_category(workspace_id)
|
||||
|
||||
doc_batch: List[Document] = []
|
||||
doc_batch: List[Document | HierarchyNode] = []
|
||||
|
||||
for article in articles:
|
||||
article_details = self._make_request(
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -80,7 +81,7 @@ class DropboxConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
while True:
|
||||
batch: list[Document] = []
|
||||
batch: list[Document | HierarchyNode] = []
|
||||
for entry in result.entries:
|
||||
if isinstance(entry, FileMetadata):
|
||||
modified_time = entry.client_modified
|
||||
|
||||
@@ -30,6 +30,7 @@ from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
@@ -740,7 +741,7 @@ class DrupalWikiConnector(
|
||||
Returns:
|
||||
Generator yielding batches of SlimDocument objects.
|
||||
"""
|
||||
slim_docs: list[SlimDocument] = []
|
||||
slim_docs: list[SlimDocument | HierarchyNode] = []
|
||||
logger.info(
|
||||
f"Starting retrieve_all_slim_docs with include_all_spaces={self.include_all_spaces}, spaces={self.spaces}"
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import detect_encoding
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
@@ -278,8 +279,8 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
self,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
) -> Generator[list[Document], None, None]:
|
||||
current_batch: list[Document] = []
|
||||
) -> Generator[list[Document | HierarchyNode], None, None]:
|
||||
current_batch: list[Document | HierarchyNode] = []
|
||||
|
||||
# Iterate through yielded files and filter them
|
||||
for file in self._get_files_list(self.folder_path):
|
||||
|
||||
@@ -14,6 +14,7 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
@@ -262,7 +263,7 @@ class LocalFileConnector(LoadConnector):
|
||||
Iterates over each file path, fetches from Postgres, tries to parse text
|
||||
or images, and yields Document batches.
|
||||
"""
|
||||
documents: list[Document] = []
|
||||
documents: list[Document | HierarchyNode] = []
|
||||
|
||||
for file_id in self.file_locations:
|
||||
file_store = get_default_file_store()
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -89,6 +90,9 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
meeting_date_unix = transcript["date"]
|
||||
meeting_date = datetime.fromtimestamp(meeting_date_unix / 1000, tz=timezone.utc)
|
||||
|
||||
# Build hierarchy based on meeting date (year-month)
|
||||
year_month = meeting_date.strftime("%Y-%m")
|
||||
|
||||
meeting_organizer_email = transcript["organizer_email"]
|
||||
organizer_email_user_info = [BasicExpertInfo(email=meeting_organizer_email)]
|
||||
|
||||
@@ -102,6 +106,14 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
sections=cast(list[TextSection | ImageSection], sections),
|
||||
source=DocumentSource.FIREFLIES,
|
||||
semantic_identifier=meeting_title,
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": [year_month],
|
||||
"year_month": year_month,
|
||||
"meeting_title": meeting_title,
|
||||
"organizer_email": meeting_organizer_email,
|
||||
}
|
||||
},
|
||||
metadata={
|
||||
k: str(v)
|
||||
for k, v in {
|
||||
@@ -183,7 +195,7 @@ class FirefliesConnector(PollConnector, LoadConnector):
|
||||
def _process_transcripts(
|
||||
self, start: str | None = None, end: str | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch: List[Document] = []
|
||||
doc_batch: List[Document | HierarchyNode] = []
|
||||
|
||||
for transcript_batch in self._fetch_transcripts(start, end):
|
||||
for transcript in transcript_batch:
|
||||
|
||||
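
For Fireflies, transcripts are grouped under a year-month folder derived from the meeting date. The snippet below just illustrates what the resulting hierarchy metadata would look like for a hypothetical transcript; the timestamp, title, and email are made-up values.

# Illustrative only: hierarchy metadata for a transcript dated 2024-03-18 (values hypothetical).
from datetime import datetime, timezone

meeting_date = datetime.fromtimestamp(1710748800000 / 1000, tz=timezone.utc)  # 2024-03-18 UTC
year_month = meeting_date.strftime("%Y-%m")  # "2024-03"
doc_metadata = {
    "hierarchy": {
        "source_path": [year_month],
        "year_month": year_month,
        "meeting_title": "Weekly sync",
        "organizer_email": "organizer@example.com",
    }
}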
@@ -18,6 +18,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -245,7 +246,7 @@ class FreshdeskConnector(PollConnector, LoadConnector):
|
||||
def _process_tickets(
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch: List[Document] = []
|
||||
doc_batch: List[Document | HierarchyNode] = []
|
||||
|
||||
for ticket_batch in self._fetch_tickets(start, end):
|
||||
for ticket in ticket_batch:
|
||||
|
||||
@@ -13,6 +13,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -230,7 +231,7 @@ class GitbookConnector(LoadConnector, PollConnector):
|
||||
try:
|
||||
content = self.client.get(f"/spaces/{self.space_id}/content/pages")
|
||||
pages: list[dict[str, Any]] = content.get("pages", [])
|
||||
current_batch: list[Document] = []
|
||||
current_batch: list[Document | HierarchyNode] = []
|
||||
|
||||
logger.info(f"Found {len(pages)} root pages.")
|
||||
logger.info(
|
||||
|
||||
@@ -240,8 +240,21 @@ def _get_userinfo(user: NamedUser) -> dict[str, str]:
|
||||
def _convert_pr_to_document(
|
||||
pull_request: PullRequest, repo_external_access: ExternalAccess | None
|
||||
) -> Document:
|
||||
repo_name = pull_request.base.repo.full_name if pull_request.base else ""
|
||||
doc_metadata = DocMetadata(repo=repo_name)
|
||||
repo_full_name = pull_request.base.repo.full_name if pull_request.base else ""
|
||||
# Split full_name (e.g., "owner/repo") into owner and repo
|
||||
parts = repo_full_name.split("/", 1)
|
||||
owner_name = parts[0] if parts else ""
|
||||
repo_name = parts[1] if len(parts) > 1 else repo_full_name
|
||||
|
||||
doc_metadata = {
|
||||
"repo": repo_full_name,
|
||||
"hierarchy": {
|
||||
"source_path": [owner_name, repo_name, "pull_requests"],
|
||||
"owner": owner_name,
|
||||
"repo": repo_name,
|
||||
"object_type": "pull_request",
|
||||
},
|
||||
}
|
||||
return Document(
|
||||
id=pull_request.html_url,
|
||||
sections=[
|
||||
@@ -259,7 +272,7 @@ def _convert_pr_to_document(
|
||||
else None
|
||||
),
|
||||
# this metadata is used in perm sync
|
||||
doc_metadata=doc_metadata.model_dump(),
|
||||
doc_metadata=doc_metadata,
|
||||
metadata={
|
||||
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
|
||||
for k, v in {
|
||||
@@ -316,8 +329,21 @@ def _fetch_issue_comments(issue: Issue) -> str:
|
||||
def _convert_issue_to_document(
|
||||
issue: Issue, repo_external_access: ExternalAccess | None
|
||||
) -> Document:
|
||||
repo_name = issue.repository.full_name if issue.repository else ""
|
||||
doc_metadata = DocMetadata(repo=repo_name)
|
||||
repo_full_name = issue.repository.full_name if issue.repository else ""
|
||||
# Split full_name (e.g., "owner/repo") into owner and repo
|
||||
parts = repo_full_name.split("/", 1)
|
||||
owner_name = parts[0] if parts else ""
|
||||
repo_name = parts[1] if len(parts) > 1 else repo_full_name
|
||||
|
||||
doc_metadata = {
|
||||
"repo": repo_full_name,
|
||||
"hierarchy": {
|
||||
"source_path": [owner_name, repo_name, "issues"],
|
||||
"owner": owner_name,
|
||||
"repo": repo_name,
|
||||
"object_type": "issue",
|
||||
},
|
||||
}
|
||||
return Document(
|
||||
id=issue.html_url,
|
||||
sections=[TextSection(link=issue.html_url, text=issue.body or "")],
|
||||
@@ -327,7 +353,7 @@ def _convert_issue_to_document(
|
||||
# updated_at is UTC time but is timezone unaware
|
||||
doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc),
|
||||
# this metadata is used in perm sync
|
||||
doc_metadata=doc_metadata.model_dump(),
|
||||
doc_metadata=doc_metadata,
|
||||
metadata={
|
||||
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
|
||||
for k, v in {
|
||||
@@ -943,7 +969,9 @@ if __name__ == "__main__":
|
||||
|
||||
# Run the connector
|
||||
while checkpoint.has_more:
|
||||
for doc_batch, failure, next_checkpoint in runner.run(checkpoint):
|
||||
for doc_batch, hierarchy_node_batch, failure, next_checkpoint in runner.run(
|
||||
checkpoint
|
||||
):
|
||||
if doc_batch:
|
||||
print(f"Retrieved batch of {len(doc_batch)} documents")
|
||||
for doc in doc_batch:
|
||||
|
||||
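
The GitHub changes above split a repository full name ("owner/repo") into owner and repo components for the hierarchy path. A small sketch of that split, written as a hypothetical helper (the connector inlines the same logic):

def split_repo_full_name(full_name: str) -> tuple[str, str]:
    # Mirrors the inline logic in _convert_pr_to_document / _convert_issue_to_document.
    parts = full_name.split("/", 1)
    owner = parts[0] if parts else ""
    repo = parts[1] if len(parts) > 1 else full_name
    return owner, repo

assert split_repo_full_name("onyx-dot-app/onyx") == ("onyx-dot-app", "onyx")
# A name without a slash keeps the full string as the repo component.
assert split_repo_full_name("onyx") == ("onyx", "onyx")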
@@ -22,6 +22,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -167,7 +168,7 @@ class GitlabConnector(LoadConnector, PollConnector):
|
||||
current_path = queue.popleft()
|
||||
files = project.repository_tree(path=current_path, all=True)
|
||||
for file_batch in _batch_gitlab_objects(files, self.batch_size):
|
||||
code_doc_batch: list[Document] = []
|
||||
code_doc_batch: list[Document | HierarchyNode] = []
|
||||
for file in file_batch:
|
||||
if _should_exclude(file["path"]):
|
||||
continue
|
||||
@@ -197,7 +198,7 @@ class GitlabConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
for mr_batch in _batch_gitlab_objects(merge_requests, self.batch_size):
|
||||
mr_doc_batch: list[Document] = []
|
||||
mr_doc_batch: list[Document | HierarchyNode] = []
|
||||
for mr in mr_batch:
|
||||
mr.updated_at = datetime.strptime(
|
||||
mr.updated_at, "%Y-%m-%dT%H:%M:%S.%f%z"
|
||||
@@ -216,7 +217,7 @@ class GitlabConnector(LoadConnector, PollConnector):
|
||||
issues = project.issues.list(state=self.state_filter, iterator=True)
|
||||
|
||||
for issue_batch in _batch_gitlab_objects(issues, self.batch_size):
|
||||
issue_doc_batch: list[Document] = []
|
||||
issue_doc_batch: list[Document | HierarchyNode] = []
|
||||
for issue in issue_batch:
|
||||
issue.updated_at = datetime.strptime(
|
||||
issue.updated_at, "%Y-%m-%dT%H:%M:%S.%f%z"
|
||||
|
||||
@@ -40,6 +40,7 @@ from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
@@ -390,7 +391,9 @@ class GmailConnector(
|
||||
"""
|
||||
List all user emails if we are on a Google Workspace domain.
|
||||
If the domain is gmail.com, or if we attempt to call the Admin SDK and
|
||||
get a 404, fall back to using the single user.
|
||||
get a 404 or 403, fall back to using the single user.
|
||||
A 404 indicates a personal Gmail account with no Workspace domain.
|
||||
A 403 indicates insufficient permissions (e.g., OAuth user without admin privileges).
|
||||
"""
|
||||
|
||||
try:
|
||||
@@ -413,6 +416,13 @@ class GmailConnector(
|
||||
"with no Workspace domain. Falling back to single user."
|
||||
)
|
||||
return [self.primary_admin_email]
|
||||
elif e.resp.status == 403:
|
||||
logger.warning(
|
||||
"Received 403 from Admin SDK; this may indicate insufficient permissions "
|
||||
"(e.g., OAuth user without admin privileges or service account without "
|
||||
"domain-wide delegation). Falling back to single user."
|
||||
)
|
||||
return [self.primary_admin_email]
|
||||
raise
|
||||
|
||||
def _fetch_threads_impl(
|
||||
@@ -426,7 +436,7 @@ class GmailConnector(
|
||||
is_slim: bool = False,
|
||||
) -> Iterator[Document | ConnectorFailure] | GenerateSlimDocumentOutput:
|
||||
query = _build_time_range_query(time_range_start, time_range_end)
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
slim_doc_batch: list[SlimDocument | HierarchyNode] = []
|
||||
logger.info(
|
||||
f"Fetching {'slim' if is_slim else 'full'} threads for user: {user_email}"
|
||||
)
|
||||
|
||||
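
The Gmail change widens the Admin SDK fallback: both a 404 (personal Gmail account) and a 403 (insufficient admin privileges) now degrade to indexing only the primary admin user. A hedged sketch of that control flow, with `list_users_fn` as a hypothetical callable wrapping the users().list() call:

from googleapiclient.errors import HttpError

def _emails_with_fallback(list_users_fn, primary_admin_email: str) -> list[str]:
    # Sketch of the fallback above; in the connector the HttpError is raised by the Google client.
    try:
        return list_users_fn()
    except HttpError as e:
        if e.resp.status in (404, 403):
            return [primary_admin_email]
        raise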
@@ -21,6 +21,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -181,7 +182,7 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
for transcript_batch in self._get_transcript_batches(
|
||||
start_datetime, end_datetime
|
||||
):
|
||||
doc_batch: list[Document] = []
|
||||
doc_batch: list[Document | HierarchyNode] = []
|
||||
|
||||
transcript_call_ids = cast(
|
||||
list[str],
|
||||
|
||||
@@ -73,6 +73,8 @@ from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
@@ -1358,7 +1360,7 @@ class GoogleDriveConnector(
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_batch = []
|
||||
slim_batch: list[SlimDocument | HierarchyNode] = []
|
||||
for file in self._fetch_drive_items(
|
||||
field_type=DriveFileFieldType.SLIM,
|
||||
checkpoint=checkpoint,
|
||||
|
||||
@@ -46,6 +46,138 @@ from onyx.utils.variable_functionality import noop_fallback
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Cache for folder path lookups to avoid redundant API calls
|
||||
# Maps folder_id -> (folder_name, parent_id)
|
||||
_folder_cache: dict[str, tuple[str, str | None]] = {}
|
||||
|
||||
|
||||
def _get_folder_info(
|
||||
service: GoogleDriveService, folder_id: str
|
||||
) -> tuple[str, str | None]:
|
||||
"""Fetch folder name and parent ID, with caching."""
|
||||
if folder_id in _folder_cache:
|
||||
return _folder_cache[folder_id]
|
||||
|
||||
try:
|
||||
folder = (
|
||||
service.files()
|
||||
.get(
|
||||
fileId=folder_id,
|
||||
fields="name, parents",
|
||||
supportsAllDrives=True,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
folder_name = folder.get("name", "Unknown")
|
||||
parents = folder.get("parents", [])
|
||||
parent_id = parents[0] if parents else None
|
||||
_folder_cache[folder_id] = (folder_name, parent_id)
|
||||
return folder_name, parent_id
|
||||
except HttpError as e:
|
||||
logger.warning(f"Failed to get folder info for {folder_id}: {e}")
|
||||
_folder_cache[folder_id] = ("Unknown", None)
|
||||
return "Unknown", None
|
||||
|
||||
|
||||
def _get_drive_name(service: GoogleDriveService, drive_id: str) -> str:
|
||||
"""Fetch shared drive name."""
|
||||
cache_key = f"drive_{drive_id}"
|
||||
if cache_key in _folder_cache:
|
||||
return _folder_cache[cache_key][0]
|
||||
|
||||
try:
|
||||
drive = service.drives().get(driveId=drive_id).execute()
|
||||
drive_name = drive.get("name", f"Shared Drive {drive_id}")
|
||||
_folder_cache[cache_key] = (drive_name, None)
|
||||
return drive_name
|
||||
except HttpError as e:
|
||||
logger.warning(f"Failed to get drive name for {drive_id}: {e}")
|
||||
_folder_cache[cache_key] = (f"Shared Drive {drive_id}", None)
|
||||
return f"Shared Drive {drive_id}"
|
||||
|
||||
|
||||
def build_folder_path(
|
||||
file: GoogleDriveFileType,
|
||||
service: GoogleDriveService,
|
||||
drive_id: str | None = None,
|
||||
user_email: str | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Build the full folder path for a file by walking up the parent chain.
|
||||
Returns a list of folder names from root to immediate parent.
|
||||
|
||||
Args:
|
||||
file: The Google Drive file object
|
||||
service: Google Drive service instance
|
||||
drive_id: Optional drive ID (will be extracted from file if not provided)
|
||||
user_email: Optional user email to check ownership for "My Drive" vs "Shared with me"
|
||||
"""
|
||||
path_parts: list[str] = []
|
||||
|
||||
# Get drive_id from file if not provided
|
||||
if drive_id is None:
|
||||
drive_id = file.get("driveId")
|
||||
|
||||
# Check if file is owned by the user (for distinguishing "My Drive" vs "Shared with me")
|
||||
is_owned_by_user = False
|
||||
if user_email:
|
||||
owners = file.get("owners", [])
|
||||
is_owned_by_user = any(
|
||||
owner.get("emailAddress", "").lower() == user_email.lower()
|
||||
for owner in owners
|
||||
)
|
||||
|
||||
# Get the file's parent folder ID
|
||||
parents = file.get("parents", [])
|
||||
if not parents:
|
||||
# File is at root level
|
||||
if drive_id:
|
||||
return [_get_drive_name(service, drive_id)]
|
||||
# If not in a shared drive, check if it's owned by the user
|
||||
if is_owned_by_user:
|
||||
return ["My Drive"]
|
||||
else:
|
||||
return ["Shared with me"]
|
||||
|
||||
parent_id: str | None = parents[0]
|
||||
|
||||
# Walk up the folder hierarchy (limit to 50 levels to prevent infinite loops)
|
||||
visited: set[str] = set()
|
||||
for _ in range(50):
|
||||
if not parent_id or parent_id in visited:
|
||||
break
|
||||
visited.add(parent_id)
|
||||
|
||||
folder_name, next_parent = _get_folder_info(service, parent_id)
|
||||
|
||||
# Check if we've reached the root (parent is the drive itself or no parent)
|
||||
if next_parent is None:
|
||||
# This folder's name is either the drive root, My Drive, or Shared with me
|
||||
if drive_id:
|
||||
path_parts.insert(0, _get_drive_name(service, drive_id))
|
||||
else:
|
||||
# Not in a shared drive - determine if it's "My Drive" or "Shared with me"
|
||||
if is_owned_by_user:
|
||||
path_parts.insert(0, "My Drive")
|
||||
else:
|
||||
path_parts.insert(0, "Shared with me")
|
||||
break
|
||||
else:
|
||||
path_parts.insert(0, folder_name)
|
||||
parent_id = next_parent
|
||||
|
||||
# If we didn't find a root, determine the root based on ownership and drive
|
||||
if not path_parts:
|
||||
if drive_id:
|
||||
return [_get_drive_name(service, drive_id)]
|
||||
elif is_owned_by_user:
|
||||
return ["My Drive"]
|
||||
else:
|
||||
return ["Shared with me"]
|
||||
|
||||
return path_parts
|
||||
|
||||
|
||||
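
build_folder_path walks the parents chain (capped at 50 levels, with cached folder lookups) and roots the path at the shared-drive name, "My Drive", or "Shared with me". A rough usage sketch, assuming `drive_service` comes from the connector's existing Google service helpers; the file dict and resulting path are hypothetical but mirror the fields requested via FILE_FIELDS (driveId, parents, owners, ...):

file = {
    "name": "Q3 plan",
    "parents": ["folder-id-123"],
    "owners": [{"emailAddress": "me@example.com"}],
    # no driveId -> the file lives in a user drive, not a shared drive
}
path = build_folder_path(file, drive_service, drive_id=None, user_email="me@example.com")
# e.g. ["My Drive", "Planning", "2024"], depending on the cached folder lookups
doc_metadata = {"hierarchy": {"source_path": path, "file_name": file["name"]}}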
# This is not a standard valid unicode char, it is used by the docs advanced API to
|
||||
# represent smart chips (elements like dates and doc links).
|
||||
SMART_CHIP_CHAR = "\ue907"
|
||||
@@ -526,12 +658,33 @@ def _convert_drive_item_to_document(
|
||||
else None
|
||||
)
|
||||
|
||||
# Build doc_metadata with hierarchy information
|
||||
file_name = file.get("name", "")
|
||||
mime_type = file.get("mimeType", "")
|
||||
drive_id = file.get("driveId")
|
||||
|
||||
# Build full folder path by walking up the parent chain
|
||||
# Pass retriever_email to determine if file is in "My Drive" vs "Shared with me"
|
||||
source_path = build_folder_path(
|
||||
file, _get_drive_service(), drive_id, retriever_email
|
||||
)
|
||||
|
||||
doc_metadata = {
|
||||
"hierarchy": {
|
||||
"source_path": source_path,
|
||||
"drive_id": drive_id,
|
||||
"file_name": file_name,
|
||||
"mime_type": mime_type,
|
||||
}
|
||||
}
|
||||
|
||||
# Create the document
|
||||
return Document(
|
||||
id=doc_id,
|
||||
sections=sections,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
semantic_identifier=file.get("name", ""),
|
||||
semantic_identifier=file_name,
|
||||
doc_metadata=doc_metadata,
|
||||
metadata={
|
||||
"owner_names": ", ".join(
|
||||
owner.get("displayName", "") for owner in file.get("owners", [])
|
||||
|
||||
@@ -39,11 +39,11 @@ PERMISSION_FULL_DESCRIPTION = (
|
||||
"permissions(id, emailAddress, type, domain, allowFileDiscovery, permissionDetails)"
|
||||
)
|
||||
FILE_FIELDS = (
|
||||
"nextPageToken, files(mimeType, id, name, "
|
||||
"nextPageToken, files(mimeType, id, name, driveId, parents, "
|
||||
"modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)"
|
||||
)
|
||||
FILE_FIELDS_WITH_PERMISSIONS = (
|
||||
f"nextPageToken, files(mimeType, id, name, {PERMISSION_FULL_DESCRIPTION}, permissionIds, "
|
||||
f"nextPageToken, files(mimeType, id, name, driveId, parents, {PERMISSION_FULL_DESCRIPTION}, permissionIds, "
|
||||
"modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)"
|
||||
)
|
||||
SLIM_FILE_FIELDS = (
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import load_files_from_zip
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
@@ -65,7 +66,7 @@ class GoogleSitesConnector(LoadConnector):
|
||||
pass
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
documents: list[Document] = []
|
||||
documents: list[Document | HierarchyNode] = []
|
||||
|
||||
file_content_io = get_default_file_store().read_file(self.zip_path, mode="b")
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -58,7 +59,7 @@ class GuruConnector(LoadConnector, PollConnector):
|
||||
if self.guru_user is None or self.guru_user_token is None:
|
||||
raise ConnectorMissingCredentialError("Guru")
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
doc_batch: list[Document | HierarchyNode] = []
|
||||
|
||||
session = requests.Session()
|
||||
session.auth = (self.guru_user, self.guru_user_token)
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
@@ -143,7 +144,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync)
|
||||
"""
|
||||
spots_to_process = self._fetch_spots_to_process()
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
doc_batch: list[Document | HierarchyNode] = []
|
||||
try:
|
||||
for spot in spots_to_process:
|
||||
try:
|
||||
@@ -378,7 +379,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync)
|
||||
"""
|
||||
spots_to_process = self._fetch_spots_to_process()
|
||||
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
slim_doc_batch: list[SlimDocument | HierarchyNode] = []
|
||||
try:
|
||||
for spot in spots_to_process:
|
||||
try:
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -414,7 +415,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
associations=["contacts", "companies", "deals"],
|
||||
)
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
doc_batch: list[Document | HierarchyNode] = []
|
||||
|
||||
for ticket in tickets_iter:
|
||||
updated_at = ticket.updated_at.replace(tzinfo=None)
|
||||
@@ -490,6 +491,13 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
semantic_identifier=title,
|
||||
doc_updated_at=ticket.updated_at.replace(tzinfo=timezone.utc),
|
||||
metadata=metadata,
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": ["Tickets"],
|
||||
"object_type": "ticket",
|
||||
"object_id": ticket.id,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -520,7 +528,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
associations=["contacts", "deals", "tickets"],
|
||||
)
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
doc_batch: list[Document | HierarchyNode] = []
|
||||
|
||||
for company in companies_iter:
|
||||
updated_at = company.updated_at.replace(tzinfo=None)
|
||||
@@ -615,6 +623,13 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
semantic_identifier=title,
|
||||
doc_updated_at=company.updated_at.replace(tzinfo=timezone.utc),
|
||||
metadata=metadata,
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": ["Companies"],
|
||||
"object_type": "company",
|
||||
"object_id": company.id,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -645,7 +660,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
associations=["contacts", "companies", "tickets"],
|
||||
)
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
doc_batch: list[Document | HierarchyNode] = []
|
||||
|
||||
for deal in deals_iter:
|
||||
updated_at = deal.updated_at.replace(tzinfo=None)
|
||||
@@ -738,6 +753,13 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
semantic_identifier=title,
|
||||
doc_updated_at=deal.updated_at.replace(tzinfo=timezone.utc),
|
||||
metadata=metadata,
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": ["Deals"],
|
||||
"object_type": "deal",
|
||||
"object_id": deal.id,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -770,7 +792,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
associations=["companies", "deals", "tickets"],
|
||||
)
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
doc_batch: list[Document | HierarchyNode] = []
|
||||
|
||||
for contact in contacts_iter:
|
||||
updated_at = contact.updated_at.replace(tzinfo=None)
|
||||
@@ -881,6 +903,13 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
semantic_identifier=title,
|
||||
doc_updated_at=contact.updated_at.replace(tzinfo=timezone.utc),
|
||||
metadata=metadata,
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": ["Contacts"],
|
||||
"object_type": "contact",
|
||||
"object_id": contact.id,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -13,14 +13,16 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop

SecondsSinceUnixEpoch = float

GenerateDocumentsOutput = Iterator[list[Document]]
GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
# Output types that can include HierarchyNode alongside Documents/SlimDocuments
GenerateDocumentsOutput = Iterator[list[Document | HierarchyNode]]
GenerateSlimDocumentOutput = Iterator[list[SlimDocument | HierarchyNode]]

CT = TypeVar("CT", bound=ConnectorCheckpoint)

@@ -239,7 +241,9 @@ class EventConnector(BaseConnector):
raise NotImplementedError


CheckpointOutput: TypeAlias = Generator[Document | ConnectorFailure, None, CT]
CheckpointOutput: TypeAlias = Generator[
Document | HierarchyNode | ConnectorFailure, None, CT
]


class CheckpointedConnector(BaseConnector[CT]):
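
The widened CheckpointOutput alias means a checkpointed connector may now yield HierarchyNode values alongside Documents and failures. Below is a free-function sketch of that generator shape, not the full CheckpointedConnector interface; MyCheckpoint and all literal values are made up, and HierarchyNodeType.FOLDER is an assumed member name of the real enum in onyx.db.enums.

from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.models import ConnectorCheckpoint, Document, HierarchyNode, TextSection
from onyx.db.enums import HierarchyNodeType


class MyCheckpoint(ConnectorCheckpoint):
    last_folder: str | None = None


def load_from_checkpoint(checkpoint: MyCheckpoint) -> CheckpointOutput[MyCheckpoint]:
    # Yield the folder first so the parent exists before its documents.
    yield HierarchyNode(
        raw_node_id="folder-1",
        raw_parent_id=None,
        display_name="Policies",
        node_type=HierarchyNodeType.FOLDER,  # assumed member name
    )
    yield Document(
        id="doc-1",
        sections=[TextSection(link="https://example.com/doc-1", text="hello")],
        source=DocumentSource.GOOGLE_DRIVE,
        semantic_identifier="hello.txt",
        metadata={},
        parent_hierarchy_raw_node_id="folder-1",
    )
    return MyCheckpoint(has_more=False, last_folder="folder-1")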
@@ -47,6 +47,7 @@ from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -676,7 +677,7 @@ class JiraConnector(
|
||||
checkpoint_callback = make_checkpoint_callback(checkpoint)
|
||||
prev_offset = 0
|
||||
current_offset = 0
|
||||
slim_doc_batch = []
|
||||
slim_doc_batch: list[SlimDocument | HierarchyNode] = []
|
||||
|
||||
while checkpoint.has_more:
|
||||
for issue in _perform_jql_search(
|
||||
|
||||
@@ -25,6 +25,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -251,7 +252,7 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
logger.debug(f"Raw response from Linear: {response_json}")
|
||||
edges = response_json["data"]["issues"]["edges"]
|
||||
|
||||
documents: list[Document] = []
|
||||
documents: list[Document | HierarchyNode] = []
|
||||
for edge in edges:
|
||||
node = edge["node"]
|
||||
# Create sections for description and comments
|
||||
@@ -274,6 +275,10 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
# Cast the sections list to the expected type
|
||||
typed_sections = cast(list[TextSection | ImageSection], sections)
|
||||
|
||||
# Extract team name for hierarchy
|
||||
team_name = (node.get("team") or {}).get("name") or "Unknown Team"
|
||||
identifier = node.get("identifier", node["id"])
|
||||
|
||||
documents.append(
|
||||
Document(
|
||||
id=node["id"],
|
||||
@@ -282,6 +287,13 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
semantic_identifier=f"[{node['identifier']}] {node['title']}",
|
||||
title=node["title"],
|
||||
doc_updated_at=time_str_to_utc(node["updatedAt"]),
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": [team_name],
|
||||
"team_name": team_name,
|
||||
"identifier": identifier,
|
||||
}
|
||||
},
|
||||
metadata={
|
||||
k: str(v)
|
||||
for k, v in {
|
||||
|
||||
@@ -17,6 +17,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.file_processing.html_utils import strip_excessive_newlines_and_spaces
|
||||
@@ -109,7 +110,7 @@ class LoopioConnector(LoadConnector, PollConnector):
|
||||
params: dict[str, str | int] = {"pageSize": self.batch_size}
|
||||
params["filter"] = json.dumps(filter)
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
doc_batch: list[Document | HierarchyNode] = []
|
||||
for library_entries in self._fetch_data(
|
||||
resource="v2/libraryEntries", params=params
|
||||
):
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import datetime
|
||||
import itertools
|
||||
import tempfile
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -21,6 +20,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.mediawiki.family import family_class_dispatch
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -160,7 +160,7 @@ class MediaWikiConnector(LoadConnector, PollConnector):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Generator[list[Document], None, None]:
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Request batches of pages from a MediaWiki site.
|
||||
|
||||
Args:
|
||||
@@ -170,7 +170,7 @@ class MediaWikiConnector(LoadConnector, PollConnector):
|
||||
Yields:
|
||||
Lists of Documents containing each parsed page in a batch.
|
||||
"""
|
||||
doc_batch: list[Document] = []
|
||||
doc_batch: list[Document | HierarchyNode] = []
|
||||
|
||||
# Pywikibot can handle batching for us, including only loading page contents when we finally request them.
|
||||
category_pages = [
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.db.enums import HierarchyNodeType
from onyx.db.enums import IndexModelStatus
from onyx.utils.text_processing import make_url_compatible

@@ -187,6 +188,10 @@ class DocumentBase(BaseModel):
external_access: ExternalAccess | None = None
doc_metadata: dict[str, Any] | None = None

# Parent hierarchy node raw ID - the folder/space/page containing this document
# If None, document's hierarchy position is unknown or connector doesn't support hierarchy
parent_hierarchy_raw_node_id: str | None = None

def get_title_for_document_index(
self,
) -> str | None:

@@ -244,6 +249,9 @@ def convert_metadata_dict_to_list_of_strings(
Each string is a key-value pair separated by the INDEX_SEPARATOR. If a key
points to a list of values, each value generates a unique pair.

NOTE: Whatever formatting strategy is used here to generate a key-value
string must be replicated when constructing query filters.

Args:
metadata: The metadata dict to convert where values can be either a
string or a list of strings.

@@ -365,6 +373,36 @@ class SlimDocument(BaseModel):
external_access: ExternalAccess | None = None


class HierarchyNode(BaseModel):
"""
Hierarchy node yielded by connectors.

This is the Pydantic model used by connectors, distinct from the
SQLAlchemy HierarchyNode model in db/models.py. The connector runner
layer converts this to the DB model when persisting to Postgres.
"""

# Raw identifier from the source system
# e.g., "1h7uWUR2BYZjtMfEXFt43tauj-Gp36DTPtwnsNuA665I" for Google Drive
raw_node_id: str

# Raw ID of parent node, or None for SOURCE-level children (direct children of the source root)
raw_parent_id: str | None = None

# Human-readable name for display
display_name: str

# Link to view this node in the source system
link: str | None = None

# What kind of structural node this is (folder, space, page, etc.)
node_type: HierarchyNodeType

# Optional: if this hierarchy node represents a document (e.g., Confluence page),
# this is the document ID. Set by the connector when the node IS a document.
document_id: str | None = None


class IndexAttemptMetadata(BaseModel):
connector_id: int
credential_id: int
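
To make the parent-linking semantics of the new model concrete: raw_parent_id of None marks a direct child of the source root, while a non-None value points at the raw ID of the parent node, and documents attach themselves via parent_hierarchy_raw_node_id. The identifiers below are made up, and HierarchyNodeType.FOLDER is an assumed member name of the real enum.

from onyx.connectors.models import HierarchyNode
from onyx.db.enums import HierarchyNodeType

root = HierarchyNode(
    raw_node_id="space-eng",
    raw_parent_id=None,  # direct child of the source root
    display_name="Engineering",
    node_type=HierarchyNodeType.FOLDER,  # assumed member name
)
child = HierarchyNode(
    raw_node_id="folder-runbooks",
    raw_parent_id="space-eng",  # raw ID of its parent node
    display_name="Runbooks",
    node_type=HierarchyNodeType.FOLDER,  # assumed member name
)
# A Document placed under the leaf folder would set
# parent_hierarchy_raw_node_id="folder-runbooks" so the runner can attach it.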
@@ -30,6 +30,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.batching import batch_generator
|
||||
@@ -492,7 +493,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
def _read_pages(
|
||||
self,
|
||||
pages: list[NotionPage],
|
||||
) -> Generator[Document, None, None]:
|
||||
) -> Generator[Document | HierarchyNode, None, None]:
|
||||
"""Reads pages for rich text content and generates Documents
|
||||
|
||||
Note that a page which is turned into a "wiki" becomes a database but both top level pages and top level databases
|
||||
@@ -624,7 +625,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
filtered_pages += [NotionPage(**page)]
|
||||
return filtered_pages
|
||||
|
||||
def _recursive_load(self) -> Generator[list[Document], None, None]:
|
||||
def _recursive_load(self) -> GenerateDocumentsOutput:
|
||||
if self.root_page_id is None or not self.recursive_index_enabled:
|
||||
raise RuntimeError(
|
||||
"Recursive page lookup is not enabled, but we are trying to "
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.outline.client import OutlineApiClient
|
||||
from onyx.connectors.outline.client import OutlineClientRequestFailedError
|
||||
@@ -165,7 +166,7 @@ class OutlineConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
# Apply time filtering if specified
|
||||
filtered_batch = []
|
||||
filtered_batch: list[Document | HierarchyNode] = []
|
||||
for doc in doc_batch:
|
||||
if time_filter is None or time_filter(doc):
|
||||
filtered_batch.append(doc)
|
||||
|
||||
@@ -16,6 +16,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -228,7 +229,7 @@ class ProductboardConnector(PollConnector):
|
||||
"Access token is not set up, was load_credentials called?"
|
||||
)
|
||||
|
||||
document_batch: list[Document] = []
|
||||
document_batch: list[Document | HierarchyNode] = []
|
||||
|
||||
# NOTE: there is a concept of a "Note" in productboard, however
|
||||
# there is no read API for it atm. Additionally, comments are not
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.salesforce.doc_conversion import convert_sf_object_to_doc
|
||||
@@ -443,7 +444,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
if not self._sf_client:
|
||||
raise RuntimeError("self._sf_client is None!")
|
||||
|
||||
docs_to_yield: list[Document] = []
|
||||
docs_to_yield: list[Document | HierarchyNode] = []
|
||||
|
||||
changed_ids_to_type: dict[str, str] = {}
|
||||
parents_changed = 0
|
||||
@@ -655,7 +656,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
)
|
||||
|
||||
# Step 3 - extract and index docs
|
||||
docs_to_yield: list[Document] = []
|
||||
docs_to_yield: list[Document | HierarchyNode] = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
last_log_time = 0.0
|
||||
@@ -1125,7 +1126,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
doc_metadata_list: list[SlimDocument | HierarchyNode] = []
|
||||
for parent_object_type in self.parent_object_list:
|
||||
query = f"SELECT Id FROM {parent_object_type}"
|
||||
query_result = self.sf_client.safe_query_all(query)
|
||||
@@ -1211,6 +1212,8 @@ if __name__ == "__main__":
|
||||
doc_count += len(doc_batch)
|
||||
print(f"doc_count: {doc_count}")
|
||||
for doc in doc_batch:
|
||||
if isinstance(doc, HierarchyNode):
|
||||
continue
|
||||
section_count += len(doc.sections)
|
||||
for section in doc.sections:
|
||||
if isinstance(section, TextSection) and section.text is not None:
|
||||
|
||||
@@ -51,6 +51,7 @@ from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import ExternalAccess
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
@@ -1146,7 +1147,7 @@ class SharepointConnector(
|
||||
site_descriptors = self.site_descriptors or self.fetch_sites()
|
||||
|
||||
# goes over all urls, converts them into SlimDocument objects and then yields them in batches
|
||||
doc_batch: list[SlimDocument] = []
|
||||
doc_batch: list[SlimDocument | HierarchyNode] = []
|
||||
for site_descriptor in site_descriptors:
|
||||
ctx: ClientContext | None = None
|
||||
|
||||
@@ -1708,7 +1709,9 @@ if __name__ == "__main__":
|
||||
|
||||
# Run the connector
|
||||
while checkpoint.has_more:
|
||||
for doc_batch, failure, next_checkpoint in runner.run(checkpoint):
|
||||
for doc_batch, hierarchy_node_batch, failure, next_checkpoint in runner.run(
|
||||
checkpoint
|
||||
):
|
||||
if doc_batch:
|
||||
print(f"Retrieved batch of {len(doc_batch)} documents")
|
||||
for doc in doc_batch:
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -187,7 +188,7 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
def _iterate_posts(
|
||||
self, time_filter: Callable[[datetime], bool] | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch: list[Document] = []
|
||||
doc_batch: list[Document | HierarchyNode] = []
|
||||
|
||||
if self.slab_bot_token is None:
|
||||
raise ConnectorMissingCredentialError("Slab")
|
||||
@@ -245,7 +246,7 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
slim_doc_batch: list[SlimDocument | HierarchyNode] = []
|
||||
for post_id in get_all_post_ids(self.slab_bot_token):
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
|
||||
@@ -52,6 +52,7 @@ from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.slack.access import get_channel_access
|
||||
@@ -234,6 +235,8 @@ def thread_to_doc(
|
||||
"\n", " "
|
||||
)
|
||||
|
||||
channel_name = channel["name"]
|
||||
|
||||
return Document(
|
||||
id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
|
||||
sections=[
|
||||
@@ -247,7 +250,14 @@ def thread_to_doc(
|
||||
semantic_identifier=doc_sem_id,
|
||||
doc_updated_at=get_latest_message_time(thread),
|
||||
primary_owners=valid_experts,
|
||||
metadata={"Channel": channel["name"]},
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": [channel_name],
|
||||
"channel_name": channel_name,
|
||||
"channel_id": channel_id,
|
||||
}
|
||||
},
|
||||
metadata={"Channel": channel_name},
|
||||
external_access=channel_access,
|
||||
)
|
||||
|
||||
@@ -492,7 +502,7 @@ def _get_all_doc_ids(
|
||||
)
|
||||
|
||||
for message_batch in channel_message_batches:
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
slim_doc_batch: list[SlimDocument | HierarchyNode] = []
|
||||
for message in message_batch:
|
||||
filter_reason = msg_filter_func(message)
|
||||
if filter_reason:
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.teams.models import Message
|
||||
@@ -301,7 +302,7 @@ class TeamsConnector(
|
||||
start=start,
|
||||
)
|
||||
|
||||
slim_doc_buffer = []
|
||||
slim_doc_buffer: list[SlimDocument | HierarchyNode] = []
|
||||
|
||||
for message in messages:
|
||||
slim_doc_buffer.append(
|
||||
|
||||