mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-05 15:02:43 +00:00
Compare commits
142 Commits
v2.10.7
...
feat/file-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
462537e2ee | ||
|
|
e3a9b2a021 | ||
|
|
7080b3d966 | ||
|
|
adc3c86b16 | ||
|
|
b110621b13 | ||
|
|
a2dc752d14 | ||
|
|
f7d47a6ca9 | ||
|
|
9cc71b71ee | ||
|
|
f2bafd113a | ||
|
|
bb00ebd4a8 | ||
|
|
fda04aa6d2 | ||
|
|
285aae6f2f | ||
|
|
b75b1019f3 | ||
|
|
bbba32b989 | ||
|
|
f06bf69956 | ||
|
|
7d4fe480cc | ||
|
|
7f5b512856 | ||
|
|
844a01f751 | ||
|
|
d64be385db | ||
|
|
d0518388d6 | ||
|
|
a7f6d5f535 | ||
|
|
059e2869e6 | ||
|
|
04d90fd496 | ||
|
|
7cd29f4892 | ||
|
|
c2b86efebf | ||
|
|
bc5835967e | ||
|
|
c2b11cae01 | ||
|
|
cf17ba6a1c | ||
|
|
b03634ecaa | ||
|
|
9a7e92464f | ||
|
|
09b2a69c82 | ||
|
|
c5c027c168 | ||
|
|
882163a4ea | ||
|
|
de83a9a6f0 | ||
|
|
f73ce0632f | ||
|
|
0b10b11af3 | ||
|
|
d9e3b657d0 | ||
|
|
f6e9928dc1 | ||
|
|
ca3179ad8d | ||
|
|
5529829ff5 | ||
|
|
bdc7f6c100 | ||
|
|
90f8656afa | ||
|
|
3c7d35a6e8 | ||
|
|
40d58a37e3 | ||
|
|
be3ecd9640 | ||
|
|
a6da511490 | ||
|
|
c7577ebe58 | ||
|
|
b87078a4f5 | ||
|
|
8a408e7023 | ||
|
|
4c7b73a355 | ||
|
|
8e9cb94d4f | ||
|
|
a21af4b906 | ||
|
|
7f0ce0531f | ||
|
|
b631bfa656 | ||
|
|
eca6b6bef2 | ||
|
|
51ef28305d | ||
|
|
144030c5ca | ||
|
|
a557d76041 | ||
|
|
605e808158 | ||
|
|
8fec88c90d | ||
|
|
e54969a693 | ||
|
|
1da2b2f28f | ||
|
|
eb7b91e08e | ||
|
|
3339000968 | ||
|
|
d9db849e94 | ||
|
|
046408359c | ||
|
|
4b8cca190f | ||
|
|
52a312a63b | ||
|
|
0594fd17de | ||
|
|
fded81dc28 | ||
|
|
31db112de9 | ||
|
|
a3e2da2c51 | ||
|
|
f4d33bcc0d | ||
|
|
464d957494 | ||
|
|
be12de9a44 | ||
|
|
3e4a1f8a09 | ||
|
|
af9b7826ab | ||
|
|
cb16eb13fc | ||
|
|
20a73bdd2e | ||
|
|
85cc2b99b7 | ||
|
|
1208a3ee2b | ||
|
|
900fcef9dd | ||
|
|
d4ed25753b | ||
|
|
0ee58333b4 | ||
|
|
11b7e0d571 | ||
|
|
a35831f328 | ||
|
|
048a6d5259 | ||
|
|
e4bdb15910 | ||
|
|
3517d59286 | ||
|
|
4bc08e5d88 | ||
|
|
4bd080cf62 | ||
|
|
b0a8625ffc | ||
|
|
f94baf6143 | ||
|
|
9e1867638a | ||
|
|
5b6d7c9f0d | ||
|
|
e5dcf31f10 | ||
|
|
8ca06ef3e7 | ||
|
|
6897dbd610 | ||
|
|
7f3cb77466 | ||
|
|
267042a5aa | ||
|
|
d02b3ae6ac | ||
|
|
683c3f7a7e | ||
|
|
008b4d2288 | ||
|
|
8be261405a | ||
|
|
61f2c48ebc | ||
|
|
dbde2e6d6d | ||
|
|
2860136214 | ||
|
|
49ec5994d3 | ||
|
|
8d5fb67f0f | ||
|
|
15d02f6e3c | ||
|
|
e58974c419 | ||
|
|
6b66c07952 | ||
|
|
cae058a3ac | ||
|
|
aa3b21a191 | ||
|
|
7a07a78696 | ||
|
|
a8db236e37 | ||
|
|
8a2e4ed36f | ||
|
|
216f2c95a7 | ||
|
|
67081efe08 | ||
|
|
9d40b8336f | ||
|
|
23f0033302 | ||
|
|
9011b76eb0 | ||
|
|
45e436bafc | ||
|
|
010bc36d61 | ||
|
|
468e488bdb | ||
|
|
9104c0ffce | ||
|
|
d36a6bd0b4 | ||
|
|
a3603c498c | ||
|
|
8f274e34c9 | ||
|
|
5c256760ff | ||
|
|
258e1372b3 | ||
|
|
83a543a265 | ||
|
|
f9719d199d | ||
|
|
1c7bb6e56a | ||
|
|
982ad7d329 | ||
|
|
f94292808b | ||
|
|
293553a2e2 | ||
|
|
ba906ae6fa | ||
|
|
c84c7a354e | ||
|
|
2187b0dd82 | ||
|
|
d88a417bf9 | ||
|
|
f2d32b0b3b |
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@@ -8,4 +8,5 @@
|
||||
|
||||
## Additional Options
|
||||
|
||||
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
|
||||
- [ ] [Optional] Override Linear Check
|
||||
|
||||
642
.github/workflows/deployment.yml
vendored
642
.github/workflows/deployment.yml
vendored
File diff suppressed because it is too large
Load Diff
2
.github/workflows/docker-tag-beta.yml
vendored
2
.github/workflows/docker-tag-beta.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
|
||||
2
.github/workflows/docker-tag-latest.yml
vendored
2
.github/workflows/docker-tag-latest.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
|
||||
1
.github/workflows/helm-chart-releases.yml
vendored
1
.github/workflows/helm-chart-releases.yml
vendored
@@ -29,6 +29,7 @@ jobs:
|
||||
run: |
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add opensearch https://opensearch-project.github.io/helm-charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
|
||||
2
.github/workflows/nightly-scan-licenses.yml
vendored
2
.github/workflows/nightly-scan-licenses.yml
vendored
@@ -94,7 +94,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
|
||||
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
Normal file
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
name: Require beta cherry-pick consideration
|
||||
concurrency:
|
||||
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, reopened, synchronize]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
beta-cherrypick-check:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Check PR body for beta cherry-pick consideration
|
||||
env:
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
run: |
|
||||
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
|
||||
echo "Cherry-pick consideration box is checked. Check passed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
|
||||
exit 1
|
||||
@@ -45,6 +45,9 @@ env:
|
||||
# TODO: debug why this is failing and enable
|
||||
CODE_INTERPRETER_BASE_URL: http://localhost:8000
|
||||
|
||||
# OpenSearch
|
||||
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
@@ -125,11 +128,13 @@ jobs:
|
||||
docker compose \
|
||||
-f docker-compose.yml \
|
||||
-f docker-compose.dev.yml \
|
||||
-f docker-compose.opensearch.yml \
|
||||
up -d \
|
||||
minio \
|
||||
relational_db \
|
||||
cache \
|
||||
index \
|
||||
opensearch \
|
||||
code-interpreter
|
||||
|
||||
- name: Run migrations
|
||||
@@ -158,7 +163,7 @@ jobs:
|
||||
cd deployment/docker_compose
|
||||
|
||||
# Get list of running containers
|
||||
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
|
||||
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)
|
||||
|
||||
# Collect logs from each container
|
||||
for container in $containers; do
|
||||
|
||||
8
.github/workflows/pr-helm-chart-testing.yml
vendored
8
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -88,6 +88,7 @@ jobs:
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add opensearch https://opensearch-project.github.io/helm-charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
@@ -180,6 +181,11 @@ jobs:
|
||||
trap cleanup EXIT
|
||||
|
||||
# Run the actual installation with detailed logging
|
||||
# Note that opensearch.enabled is true whereas others in this install
|
||||
# are false. There is some work that needs to be done to get this
|
||||
# entire step working in CI, enabling opensearch here is a small step
|
||||
# in that direction. If this is causing issues, disabling it in this
|
||||
# step should be ok in the short term.
|
||||
echo "=== Starting ct install ==="
|
||||
set +e
|
||||
ct install --all \
|
||||
@@ -187,6 +193,8 @@ jobs:
|
||||
--set=nginx.enabled=false \
|
||||
--set=minio.enabled=false \
|
||||
--set=vespa.enabled=false \
|
||||
--set=opensearch.enabled=true \
|
||||
--set=auth.opensearch.enabled=true \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.nameOverride=cloudnative-pg \
|
||||
|
||||
6
.github/workflows/pr-integration-tests.yml
vendored
6
.github/workflows/pr-integration-tests.yml
vendored
@@ -103,7 +103,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
@@ -163,7 +163,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
@@ -208,7 +208,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling openapitools/openapi-generator-cli
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
|
||||
@@ -95,7 +95,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
@@ -155,7 +155,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
@@ -214,7 +214,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling openapitools/openapi-generator-cli
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
|
||||
6
.github/workflows/pr-playwright-tests.yml
vendored
6
.github/workflows/pr-playwright-tests.yml
vendored
@@ -85,7 +85,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
@@ -146,7 +146,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
@@ -207,7 +207,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
|
||||
3
.github/workflows/pr-python-checks.yml
vendored
3
.github/workflows/pr-python-checks.yml
vendored
@@ -50,8 +50,9 @@ jobs:
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
|
||||
2
.github/workflows/pr-python-model-tests.yml
vendored
2
.github/workflows/pr-python-model-tests.yml
vendored
@@ -70,7 +70,7 @@ jobs:
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f
|
||||
|
||||
- name: Build and load
|
||||
uses: docker/bake-action@5be5f02ff8819ecd3092ea6b2e6261c31774f2b4 # ratchet:docker/bake-action@v6
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,5 +1,8 @@
|
||||
# editors
|
||||
.vscode
|
||||
!/.vscode/env_template.txt
|
||||
!/.vscode/launch.json
|
||||
!/.vscode/tasks.template.jsonc
|
||||
.zed
|
||||
.cursor
|
||||
|
||||
|
||||
@@ -66,7 +66,8 @@ repos:
|
||||
- id: uv-run
|
||||
name: Check lazy imports
|
||||
args: ["--active", "--with=onyx-devtools", "ods", "check-lazy-imports"]
|
||||
files: ^backend/(?!\.venv/).*\.py$
|
||||
pass_filenames: true
|
||||
files: ^backend/(?!\.venv/|scripts/).*\.py$
|
||||
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
|
||||
# - id: uv-run
|
||||
# name: mypy
|
||||
@@ -74,6 +75,13 @@ repos:
|
||||
# pass_filenames: true
|
||||
# files: ^backend/.*\.py$
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
name: Check for added large files
|
||||
args: ["--maxkb=1500"]
|
||||
|
||||
- repo: https://github.com/rhysd/actionlint
|
||||
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
|
||||
hooks:
|
||||
|
||||
155
.vscode/launch.template.jsonc → .vscode/launch.json
vendored
155
.vscode/launch.template.jsonc → .vscode/launch.json
vendored
@@ -1,5 +1,3 @@
|
||||
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */
|
||||
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
@@ -24,7 +22,7 @@
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery background",
|
||||
"Celery heavy",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat"
|
||||
@@ -151,6 +149,24 @@
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Discord Bot",
|
||||
"consoleName": "Discord Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/discord/client.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Discord Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "MCP Server",
|
||||
"consoleName": "MCP Server",
|
||||
@@ -399,7 +415,6 @@
|
||||
"onyx.background.celery.versioned_apps.docfetching",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docfetching@%n",
|
||||
@@ -430,7 +445,6 @@
|
||||
"onyx.background.celery.versioned_apps.docprocessing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=6",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docprocessing@%n",
|
||||
@@ -579,6 +593,137 @@
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Build Sandbox Templates",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "onyx.server.features.build.sandbox.build_templates",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
},
|
||||
"consoleTitle": "Build Sandbox Templates"
|
||||
},
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Database ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "4",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Restore seeded database dump",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"restore",
|
||||
"--fetch-seeded",
|
||||
"--yes"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clean restore seeded database dump (destructive)",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"restore",
|
||||
"--fetch-seeded",
|
||||
"--clean",
|
||||
"--yes"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Create database snapshot",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"dump",
|
||||
"backup.dump"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clean restore database snapshot (destructive)",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"restore",
|
||||
"--clean",
|
||||
"--yes",
|
||||
"backup.dump"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Upgrade database to head revision",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"upgrade"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
// script to generate the openapi schema
|
||||
"name": "Onyx OpenAPI Schema Generator",
|
||||
@@ -16,3 +16,8 @@ dist/
|
||||
.coverage
|
||||
htmlcov/
|
||||
model_server/legacy/
|
||||
|
||||
# Craft: demo_data directory should be unzipped at container startup, not copied
|
||||
**/demo_data/
|
||||
# Craft: templates/outputs/venv is created at container startup
|
||||
**/templates/outputs/venv
|
||||
|
||||
@@ -37,10 +37,6 @@ CVE-2023-50868
|
||||
CVE-2023-52425
|
||||
CVE-2024-28757
|
||||
|
||||
# sqlite, only used by NLTK library to grab word lemmatizer and stopwords
|
||||
# No impact in our settings
|
||||
CVE-2023-7104
|
||||
|
||||
# libharfbuzz0b, O(n^2) growth, worst case is denial of service
|
||||
# Accept the risk
|
||||
CVE-2023-25193
|
||||
|
||||
@@ -7,6 +7,10 @@ have a contract or agreement with DanswerAI, you are not permitted to use the En
|
||||
Edition features outside of personal development or testing purposes. Please reach out to \
|
||||
founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx"
|
||||
|
||||
# Build argument for Craft support (disabled by default)
|
||||
# Use --build-arg ENABLE_CRAFT=true to include Node.js and opencode CLI
|
||||
ARG ENABLE_CRAFT=false
|
||||
|
||||
# DO_NOT_TRACK is used to disable telemetry for Unstructured
|
||||
ENV DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
DO_NOT_TRACK="true" \
|
||||
@@ -46,7 +50,23 @@ RUN apt-get update && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
# Conditionally install Node.js 20 for Craft (required for Next.js)
|
||||
# Only installed when ENABLE_CRAFT=true
|
||||
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
|
||||
echo "Installing Node.js 20 for Craft support..." && \
|
||||
curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
|
||||
apt-get install -y nodejs && \
|
||||
rm -rf /var/lib/apt/lists/*; \
|
||||
fi
|
||||
|
||||
# Conditionally install opencode CLI for Craft agent functionality
|
||||
# Only installed when ENABLE_CRAFT=true
|
||||
# TODO: download a specific, versioned release of the opencode CLI
|
||||
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
|
||||
echo "Installing opencode CLI for Craft support..." && \
|
||||
curl -fsSL https://opencode.ai/install | bash; \
|
||||
fi
|
||||
ENV PATH="/root/.opencode/bin:${PATH}"
|
||||
|
||||
# Install Python dependencies
|
||||
# Remove py which is pulled in by retry, py is not needed and is a CVE
|
||||
@@ -91,8 +111,8 @@ Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
|
||||
# Pre-downloading NLTK for setups with limited egress
|
||||
RUN python -c "import nltk; \
|
||||
nltk.download('stopwords', quiet=True); \
|
||||
nltk.download('punkt_tab', quiet=True);"
|
||||
nltk.download('stopwords', quiet=True); \
|
||||
nltk.download('punkt_tab', quiet=True);"
|
||||
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
|
||||
|
||||
# Pre-downloading tiktoken for setups with limited egress
|
||||
@@ -119,7 +139,15 @@ COPY --chown=onyx:onyx ./static /app/static
|
||||
COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
|
||||
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
|
||||
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
|
||||
RUN chmod +x /app/scripts/supervisord_entrypoint.sh
|
||||
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
|
||||
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh
|
||||
|
||||
# Run Craft template setup at build time when ENABLE_CRAFT=true
|
||||
# This pre-bakes demo data, Python venv, and npm dependencies into the image
|
||||
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
|
||||
echo "Running Craft template setup at build time..." && \
|
||||
ENABLE_CRAFT=true /app/scripts/setup_craft_templates.sh; \
|
||||
fi
|
||||
|
||||
# Put logo in assets
|
||||
COPY --chown=onyx:onyx ./assets /app/assets
|
||||
|
||||
@@ -0,0 +1,351 @@
|
||||
"""single onyx craft migration
|
||||
|
||||
Consolidates all buildmode/onyx craft tables into a single migration.
|
||||
|
||||
Tables created:
|
||||
- build_session: User build sessions with status tracking
|
||||
- sandbox: User-owned containerized environments (one per user)
|
||||
- artifact: Build output files (web apps, documents, images)
|
||||
- snapshot: Sandbox filesystem snapshots
|
||||
- build_message: Conversation messages for build sessions
|
||||
|
||||
Existing table modified:
|
||||
- connector_credential_pair: Added processing_mode column
|
||||
|
||||
Revision ID: 2020d417ec84
|
||||
Revises: 41fa44bef321
|
||||
Create Date: 2026-01-26 14:43:54.641405
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2020d417ec84"
|
||||
down_revision = "41fa44bef321"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ==========================================================================
|
||||
# ENUMS
|
||||
# ==========================================================================
|
||||
|
||||
# Build session status enum
|
||||
build_session_status_enum = sa.Enum(
|
||||
"active",
|
||||
"idle",
|
||||
name="buildsessionstatus",
|
||||
native_enum=False,
|
||||
)
|
||||
|
||||
# Sandbox status enum
|
||||
sandbox_status_enum = sa.Enum(
|
||||
"provisioning",
|
||||
"running",
|
||||
"idle",
|
||||
"sleeping",
|
||||
"terminated",
|
||||
"failed",
|
||||
name="sandboxstatus",
|
||||
native_enum=False,
|
||||
)
|
||||
|
||||
# Artifact type enum
|
||||
artifact_type_enum = sa.Enum(
|
||||
"web_app",
|
||||
"pptx",
|
||||
"docx",
|
||||
"markdown",
|
||||
"excel",
|
||||
"image",
|
||||
name="artifacttype",
|
||||
native_enum=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# BUILD_SESSION TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.create_table(
|
||||
"build_session",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
build_session_status_enum,
|
||||
nullable=False,
|
||||
server_default="active",
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"last_activity_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("nextjs_port", sa.Integer(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_build_session_user_created",
|
||||
"build_session",
|
||||
["user_id", sa.text("created_at DESC")],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_build_session_status",
|
||||
"build_session",
|
||||
["status"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# SANDBOX TABLE (user-owned, one per user)
|
||||
# ==========================================================================
|
||||
|
||||
op.create_table(
|
||||
"sandbox",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("container_id", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
sandbox_status_enum,
|
||||
nullable=False,
|
||||
server_default="provisioning",
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("last_heartbeat", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("user_id", name="sandbox_user_id_key"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_sandbox_status",
|
||||
"sandbox",
|
||||
["status"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_sandbox_container_id",
|
||||
"sandbox",
|
||||
["container_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# ARTIFACT TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.create_table(
|
||||
"artifact",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"session_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("type", artifact_type_enum, nullable=False),
|
||||
sa.Column("path", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_artifact_session_created",
|
||||
"artifact",
|
||||
["session_id", sa.text("created_at DESC")],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_artifact_type",
|
||||
"artifact",
|
||||
["type"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# SNAPSHOT TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.create_table(
|
||||
"snapshot",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"session_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("storage_path", sa.String(), nullable=False),
|
||||
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_snapshot_session_created",
|
||||
"snapshot",
|
||||
["session_id", sa.text("created_at DESC")],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# BUILD_MESSAGE TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.create_table(
|
||||
"build_message",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"session_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"turn_index",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"type",
|
||||
sa.Enum(
|
||||
"SYSTEM",
|
||||
"USER",
|
||||
"ASSISTANT",
|
||||
"DANSWER",
|
||||
name="messagetype",
|
||||
create_type=False,
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"message_metadata",
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_build_message_session_turn",
|
||||
"build_message",
|
||||
["session_id", "turn_index", sa.text("created_at ASC")],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# CONNECTOR_CREDENTIAL_PAIR MODIFICATION
|
||||
# ==========================================================================
|
||||
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"processing_mode",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
server_default="regular",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Tear down the build-session feature schema in reverse creation order.

    Drops the processing_mode column, then the build_message, snapshot,
    artifact, sandbox, and build_session tables (indexes first), and finally
    any native PostgreSQL enum types. Enum drops use checkfirst=True so the
    downgrade is safe even if a type was never created natively.
    """

    # ==========================================================================
    # CONNECTOR_CREDENTIAL_PAIR MODIFICATION
    # ==========================================================================

    op.drop_column("connector_credential_pair", "processing_mode")

    # ==========================================================================
    # BUILD_MESSAGE TABLE
    # ==========================================================================

    # NOTE: "messagetype" was created with native_enum=False (no PG enum type),
    # so dropping the table is sufficient — there is no type to drop.
    op.drop_index("ix_build_message_session_turn", table_name="build_message")
    op.drop_table("build_message")

    # ==========================================================================
    # SNAPSHOT TABLE
    # ==========================================================================

    op.drop_index("ix_snapshot_session_created", table_name="snapshot")
    op.drop_table("snapshot")

    # ==========================================================================
    # ARTIFACT TABLE
    # ==========================================================================

    op.drop_index("ix_artifact_type", table_name="artifact")
    op.drop_index("ix_artifact_session_created", table_name="artifact")
    op.drop_table("artifact")
    sa.Enum(name="artifacttype").drop(op.get_bind(), checkfirst=True)

    # ==========================================================================
    # SANDBOX TABLE
    # ==========================================================================

    op.drop_index("ix_sandbox_container_id", table_name="sandbox")
    op.drop_index("ix_sandbox_status", table_name="sandbox")
    op.drop_table("sandbox")
    sa.Enum(name="sandboxstatus").drop(op.get_bind(), checkfirst=True)

    # ==========================================================================
    # BUILD_SESSION TABLE
    # ==========================================================================

    op.drop_index("ix_build_session_status", table_name="build_session")
    op.drop_index("ix_build_session_user_created", table_name="build_session")
    op.drop_table("build_session")
    sa.Enum(name="buildsessionstatus").drop(op.get_bind(), checkfirst=True)
|
||||
@@ -0,0 +1,42 @@
|
||||
"""add_unique_constraint_to_inputprompt_prompt_user_id
|
||||
|
||||
Revision ID: 2c2430828bdf
|
||||
Revises: fb80bdd256de
|
||||
Create Date: 2026-01-20 16:01:54.314805
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2c2430828bdf"
|
||||
down_revision = "fb80bdd256de"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Enforce unique prompt-shortcut names per owner.

    Two complementary mechanisms are required:
      * a unique constraint on (prompt, user_id) for user-owned prompts, and
      * a partial unique index on prompt alone for public prompts
        (user_id IS NULL), because PostgreSQL unique constraints treat NULLs
        as distinct and would not deduplicate public rows.
    """
    # Create unique constraint on (prompt, user_id) for user-owned prompts
    # This ensures each user can only have one shortcut with a given name
    op.create_unique_constraint(
        "uq_inputprompt_prompt_user_id",
        "inputprompt",
        ["prompt", "user_id"],
    )

    # Create partial unique index for public prompts (where user_id IS NULL)
    # PostgreSQL unique constraints don't enforce uniqueness for NULL values,
    # so we need a partial index to ensure public prompt names are also unique
    op.execute(
        """
        CREATE UNIQUE INDEX uq_inputprompt_prompt_public
        ON inputprompt (prompt)
        WHERE user_id IS NULL
        """
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove both uniqueness guarantees added by this revision."""
    # IF EXISTS keeps the downgrade idempotent if the index was removed manually.
    op.execute("DROP INDEX IF EXISTS uq_inputprompt_prompt_public")
    op.drop_constraint("uq_inputprompt_prompt_user_id", "inputprompt", type_="unique")
|
||||
@@ -0,0 +1,29 @@
|
||||
"""remove default prompt shortcuts
|
||||
|
||||
Revision ID: 41fa44bef321
|
||||
Revises: 2c2430828bdf
|
||||
Create Date: 2025-01-21
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "41fa44bef321"
|
||||
down_revision = "2c2430828bdf"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Delete the pre-seeded default prompt shortcuts.

    Default prompts were seeded with negative IDs, which makes them easy to
    target without touching user-created rows. Association rows are removed
    first to satisfy the foreign key from inputprompt__user to inputprompt.
    """
    # Delete any user associations for the default prompts first (foreign key constraint)
    op.execute(
        "DELETE FROM inputprompt__user WHERE input_prompt_id IN (SELECT id FROM inputprompt WHERE id < 0)"
    )
    # Delete the pre-seeded default prompt shortcuts (they have negative IDs)
    op.execute("DELETE FROM inputprompt WHERE id < 0")
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Intentional no-op: the deleted default prompts are not restored."""
    # We don't restore the default prompts on downgrade — the seed data is
    # gone for good and recreating it here would resurrect removed features.
    pass
|
||||
@@ -0,0 +1,45 @@
|
||||
"""make processing mode default all caps
|
||||
|
||||
Revision ID: 72aa7de2e5cf
|
||||
Revises: 2020d417ec84
|
||||
Create Date: 2026-01-26 18:58:47.705253
|
||||
|
||||
This migration fixes the ProcessingMode enum value mismatch:
|
||||
- SQLAlchemy's Enum with native_enum=False uses enum member NAMES as valid values
|
||||
- The original migration stored lowercase VALUES ('regular', 'file_system')
|
||||
- This converts existing data to uppercase NAMES ('REGULAR', 'FILE_SYSTEM')
|
||||
- Also drops any spurious native PostgreSQL enum type that may have been auto-created
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "72aa7de2e5cf"
|
||||
down_revision = "2020d417ec84"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Normalize processing_mode data to uppercase enum member NAMES.

    SQLAlchemy's Enum with native_enum=False validates against enum member
    names ('REGULAR', 'FILE_SYSTEM'), but the original migration stored
    lowercase values ('regular', 'file_system'). Rewrite existing rows and
    update the column's server default to match.
    """
    # Convert existing lowercase values to uppercase to match enum member names
    op.execute(
        "UPDATE connector_credential_pair SET processing_mode = 'REGULAR' "
        "WHERE processing_mode = 'regular'"
    )
    op.execute(
        "UPDATE connector_credential_pair SET processing_mode = 'FILE_SYSTEM' "
        "WHERE processing_mode = 'file_system'"
    )

    # Update the server default to use uppercase
    op.alter_column(
        "connector_credential_pair",
        "processing_mode",
        server_default="REGULAR",
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Intentional no-op."""
    # State prior to this revision was broken (lowercase values rejected by
    # the enum), so we don't want to revert back to it.
    pass
|
||||
349
backend/alembic/versions/81c22b1e2e78_hierarchy_nodes_v1.py
Normal file
349
backend/alembic/versions/81c22b1e2e78_hierarchy_nodes_v1.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""hierarchy_nodes_v1
|
||||
|
||||
Revision ID: 81c22b1e2e78
|
||||
Revises: 72aa7de2e5cf
|
||||
Create Date: 2026-01-13 18:10:01.021451
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "81c22b1e2e78"
|
||||
down_revision = "72aa7de2e5cf"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# Human-readable display names for each source, keyed by DocumentSource.value.
# Sources missing from this map fall back to title-casing the value
# (value.replace("_", " ").title()) in upgrade() below.
SOURCE_DISPLAY_NAMES: dict[str, str] = {
    "ingestion_api": "Ingestion API",
    "slack": "Slack",
    "web": "Web",
    "google_drive": "Google Drive",
    "gmail": "Gmail",
    "requesttracker": "Request Tracker",
    "github": "GitHub",
    "gitbook": "GitBook",
    "gitlab": "GitLab",
    "guru": "Guru",
    "bookstack": "BookStack",
    "outline": "Outline",
    "confluence": "Confluence",
    "jira": "Jira",
    "slab": "Slab",
    "productboard": "Productboard",
    "file": "File",
    "coda": "Coda",
    "notion": "Notion",
    "zulip": "Zulip",
    "linear": "Linear",
    "hubspot": "HubSpot",
    "document360": "Document360",
    "gong": "Gong",
    "google_sites": "Google Sites",
    "zendesk": "Zendesk",
    "loopio": "Loopio",
    "dropbox": "Dropbox",
    "sharepoint": "SharePoint",
    "teams": "Teams",
    "salesforce": "Salesforce",
    "discourse": "Discourse",
    "axero": "Axero",
    "clickup": "ClickUp",
    "mediawiki": "MediaWiki",
    "wikipedia": "Wikipedia",
    "asana": "Asana",
    "s3": "S3",
    "r2": "R2",
    "google_cloud_storage": "Google Cloud Storage",
    "oci_storage": "OCI Storage",
    "xenforo": "XenForo",
    "not_applicable": "Not Applicable",
    "discord": "Discord",
    "freshdesk": "Freshdesk",
    "fireflies": "Fireflies",
    "egnyte": "Egnyte",
    "airtable": "Airtable",
    "highspot": "Highspot",
    "drupal_wiki": "Drupal Wiki",
    "imap": "IMAP",
    "bitbucket": "Bitbucket",
    "testrail": "TestRail",
    "mock_connector": "Mock Connector",
    "user_file": "User File",
}
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Introduce hierarchy nodes v1.

    Steps (order matters — later steps reference objects created earlier):
      1. Create the hierarchy_node table plus its indexes and a partial
         unique index guaranteeing one SOURCE-type root node per source.
      2. Create the hierarchy_fetch_attempt tracking table.
      3. Seed one public SOURCE-type root node per DocumentSource.
      4. Add document.parent_hierarchy_node_id (FK to hierarchy_node).
      5. Backfill every existing document to point at its source's root node.
      6. Create persona__hierarchy_node / persona__document association
         tables and add connector_credential_pair.last_time_hierarchy_fetch.
    """
    # 1. Create hierarchy_node table
    op.create_table(
        "hierarchy_node",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("raw_node_id", sa.String(), nullable=False),
        sa.Column("display_name", sa.String(), nullable=False),
        sa.Column("link", sa.String(), nullable=True),
        sa.Column("source", sa.String(), nullable=False),
        sa.Column("node_type", sa.String(), nullable=False),
        sa.Column("document_id", sa.String(), nullable=True),
        sa.Column("parent_id", sa.Integer(), nullable=True),
        # Permission fields - same pattern as Document table
        sa.Column(
            "external_user_emails",
            postgresql.ARRAY(sa.String()),
            nullable=True,
        ),
        sa.Column(
            "external_user_group_ids",
            postgresql.ARRAY(sa.String()),
            nullable=True,
        ),
        sa.Column("is_public", sa.Boolean(), nullable=False, server_default="false"),
        sa.PrimaryKeyConstraint("id"),
        # When document is deleted, just unlink (node can exist without document)
        sa.ForeignKeyConstraint(["document_id"], ["document.id"], ondelete="SET NULL"),
        # When parent node is deleted, orphan children (cleanup via pruning)
        sa.ForeignKeyConstraint(
            ["parent_id"], ["hierarchy_node.id"], ondelete="SET NULL"
        ),
        sa.UniqueConstraint(
            "raw_node_id", "source", name="uq_hierarchy_node_raw_id_source"
        ),
    )
    op.create_index("ix_hierarchy_node_parent_id", "hierarchy_node", ["parent_id"])
    op.create_index(
        "ix_hierarchy_node_source_type", "hierarchy_node", ["source", "node_type"]
    )

    # Add partial unique index to ensure only one SOURCE-type node per source
    # This prevents duplicate source root nodes from being created
    # NOTE: node_type stores enum NAME ('SOURCE'), not value ('source')
    op.execute(
        sa.text(
            """
            CREATE UNIQUE INDEX uq_hierarchy_node_one_source_per_type
            ON hierarchy_node (source)
            WHERE node_type = 'SOURCE'
            """
        )
    )

    # 2. Create hierarchy_fetch_attempt table
    op.create_table(
        "hierarchy_fetch_attempt",
        sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
        sa.Column("status", sa.String(), nullable=False),
        sa.Column("nodes_fetched", sa.Integer(), nullable=True, server_default="0"),
        sa.Column("nodes_updated", sa.Integer(), nullable=True, server_default="0"),
        sa.Column("error_msg", sa.Text(), nullable=True),
        sa.Column("full_exception_trace", sa.Text(), nullable=True),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
        sa.Column("time_started", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.ForeignKeyConstraint(
            ["connector_credential_pair_id"],
            ["connector_credential_pair.id"],
            ondelete="CASCADE",
        ),
    )
    op.create_index(
        "ix_hierarchy_fetch_attempt_status", "hierarchy_fetch_attempt", ["status"]
    )
    op.create_index(
        "ix_hierarchy_fetch_attempt_time_created",
        "hierarchy_fetch_attempt",
        ["time_created"],
    )
    op.create_index(
        "ix_hierarchy_fetch_attempt_cc_pair",
        "hierarchy_fetch_attempt",
        ["connector_credential_pair_id"],
    )

    # 3. Insert SOURCE-type hierarchy nodes for each DocumentSource
    # We insert these so every existing document can have a parent hierarchy node
    # NOTE: SQLAlchemy's Enum with native_enum=False stores the enum NAME (e.g., 'GOOGLE_DRIVE'),
    # not the VALUE (e.g., 'google_drive'). We must use .name for source and node_type columns.
    # SOURCE nodes are always public since they're just categorical roots.
    for source in DocumentSource:
        source_name = (
            source.name
        )  # e.g., 'GOOGLE_DRIVE' - what SQLAlchemy stores/expects
        source_value = source.value  # e.g., 'google_drive' - the raw_node_id
        display_name = SOURCE_DISPLAY_NAMES.get(
            source_value, source_value.replace("_", " ").title()
        )
        # ON CONFLICT DO NOTHING keeps this idempotent if a node already exists.
        op.execute(
            sa.text(
                """
                INSERT INTO hierarchy_node (raw_node_id, display_name, source, node_type, parent_id, is_public)
                VALUES (:raw_node_id, :display_name, :source, 'SOURCE', NULL, true)
                ON CONFLICT (raw_node_id, source) DO NOTHING
                """
            ).bindparams(
                raw_node_id=source_value,  # Use .value for raw_node_id (human-readable identifier)
                display_name=display_name,
                source=source_name,  # Use .name for source column (SQLAlchemy enum storage)
            )
        )

    # 4. Add parent_hierarchy_node_id column to document table
    op.add_column(
        "document",
        sa.Column("parent_hierarchy_node_id", sa.Integer(), nullable=True),
    )
    # When hierarchy node is deleted, just unlink the document (SET NULL)
    op.create_foreign_key(
        "fk_document_parent_hierarchy_node",
        "document",
        "hierarchy_node",
        ["parent_hierarchy_node_id"],
        ["id"],
        ondelete="SET NULL",
    )
    op.create_index(
        "ix_document_parent_hierarchy_node_id",
        "document",
        ["parent_hierarchy_node_id"],
    )

    # 5. Set all existing documents' parent_hierarchy_node_id to their source's SOURCE node
    # For documents with multiple connectors, we pick one source deterministically (MIN connector_id)
    # NOTE: Both connector.source and hierarchy_node.source store enum NAMEs (e.g., 'GOOGLE_DRIVE')
    # because SQLAlchemy Enum(native_enum=False) uses the enum name for storage.
    op.execute(
        sa.text(
            """
            UPDATE document d
            SET parent_hierarchy_node_id = hn.id
            FROM (
                -- Get the source for each document (pick MIN connector_id for determinism)
                SELECT DISTINCT ON (dbcc.id)
                    dbcc.id as doc_id,
                    c.source as source
                FROM document_by_connector_credential_pair dbcc
                JOIN connector c ON dbcc.connector_id = c.id
                ORDER BY dbcc.id, dbcc.connector_id
            ) doc_source
            JOIN hierarchy_node hn ON hn.source = doc_source.source AND hn.node_type = 'SOURCE'
            WHERE d.id = doc_source.doc_id
            """
        )
    )

    # Create the persona__hierarchy_node association table
    op.create_table(
        "persona__hierarchy_node",
        sa.Column("persona_id", sa.Integer(), nullable=False),
        sa.Column("hierarchy_node_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["hierarchy_node_id"],
            ["hierarchy_node.id"],
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("persona_id", "hierarchy_node_id"),
    )

    # Add index for efficient lookups
    op.create_index(
        "ix_persona__hierarchy_node_hierarchy_node_id",
        "persona__hierarchy_node",
        ["hierarchy_node_id"],
    )

    # Create the persona__document association table for attaching individual
    # documents directly to assistants
    op.create_table(
        "persona__document",
        sa.Column("persona_id", sa.Integer(), nullable=False),
        sa.Column("document_id", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["document_id"],
            ["document.id"],
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("persona_id", "document_id"),
    )

    # Add index for efficient lookups by document_id
    op.create_index(
        "ix_persona__document_document_id",
        "persona__document",
        ["document_id"],
    )

    # 6. Add last_time_hierarchy_fetch column to connector_credential_pair table
    op.add_column(
        "connector_credential_pair",
        sa.Column(
            "last_time_hierarchy_fetch", sa.DateTime(timezone=True), nullable=True
        ),
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Revert hierarchy nodes v1 in reverse order of creation.

    Columns and association tables that reference hierarchy_node are removed
    before the hierarchy tables themselves so foreign keys never dangle.
    """
    # Remove last_time_hierarchy_fetch from connector_credential_pair
    op.drop_column("connector_credential_pair", "last_time_hierarchy_fetch")

    # Drop persona__document table
    op.drop_index("ix_persona__document_document_id", table_name="persona__document")
    op.drop_table("persona__document")

    # Drop persona__hierarchy_node table
    op.drop_index(
        "ix_persona__hierarchy_node_hierarchy_node_id",
        table_name="persona__hierarchy_node",
    )
    op.drop_table("persona__hierarchy_node")

    # Remove parent_hierarchy_node_id from document
    op.drop_index("ix_document_parent_hierarchy_node_id", table_name="document")
    op.drop_constraint(
        "fk_document_parent_hierarchy_node", "document", type_="foreignkey"
    )
    op.drop_column("document", "parent_hierarchy_node_id")

    # Drop hierarchy_fetch_attempt table
    op.drop_index(
        "ix_hierarchy_fetch_attempt_cc_pair", table_name="hierarchy_fetch_attempt"
    )
    op.drop_index(
        "ix_hierarchy_fetch_attempt_time_created", table_name="hierarchy_fetch_attempt"
    )
    op.drop_index(
        "ix_hierarchy_fetch_attempt_status", table_name="hierarchy_fetch_attempt"
    )
    op.drop_table("hierarchy_fetch_attempt")

    # Drop hierarchy_node table (partial unique index first)
    op.drop_index("uq_hierarchy_node_one_source_per_type", table_name="hierarchy_node")
    op.drop_index("ix_hierarchy_node_source_type", table_name="hierarchy_node")
    op.drop_index("ix_hierarchy_node_parent_id", table_name="hierarchy_node")
    op.drop_table("hierarchy_node")
|
||||
@@ -0,0 +1,31 @@
|
||||
"""add chat_background to user
|
||||
|
||||
Revision ID: fb80bdd256de
|
||||
Revises: 8b5ce697290e
|
||||
Create Date: 2026-01-16 16:15:59.222617
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "fb80bdd256de"
|
||||
down_revision = "8b5ce697290e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable user.chat_background column.

    Nullable with no default: existing users simply have no background set.
    """
    op.add_column(
        "user",
        sa.Column(
            "chat_background",
            sa.String(),
            nullable=True,
        ),
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop user.chat_background (discards any saved backgrounds)."""
    op.drop_column("user", "chat_background")
|
||||
@@ -122,6 +122,9 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
|
||||
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
|
||||
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
|
||||
POSTHOG_DEBUG_LOGS_ENABLED = (
|
||||
os.environ.get("POSTHOG_DEBUG_LOGS_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
MARKETING_POSTHOG_API_KEY = os.environ.get("MARKETING_POSTHOG_API_KEY")
|
||||
|
||||
@@ -133,3 +136,9 @@ GATED_TENANTS_KEY = "gated_tenants"
|
||||
LICENSE_ENFORCEMENT_ENABLED = (
|
||||
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Cloud data plane URL - self-hosted instances call this to reach cloud proxy endpoints
|
||||
# Used when MULTI_TENANT=false (self-hosted mode)
|
||||
CLOUD_DATA_PLANE_URL = os.environ.get(
|
||||
"CLOUD_DATA_PLANE_URL", "https://cloud.onyx.app/api"
|
||||
)
|
||||
|
||||
73
backend/ee/onyx/configs/license_enforcement_config.py
Normal file
73
backend/ee/onyx/configs/license_enforcement_config.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Constants for license enforcement.
|
||||
|
||||
This file is the single source of truth for:
|
||||
1. Paths that bypass license enforcement (always accessible)
|
||||
2. Paths that require an EE license (EE-only features)
|
||||
|
||||
Import these constants in both production code and tests to ensure consistency.
|
||||
"""
|
||||
|
||||
# Paths that are ALWAYS accessible, even when license is expired/gated.
|
||||
# These enable users to:
|
||||
# /auth - Log in/out (users can't fix billing if locked out of auth)
|
||||
# /license - Fetch, upload, or check license status
|
||||
# /health - Health checks for load balancers/orchestrators
|
||||
# /me - Basic user info needed for UI rendering
|
||||
# /settings, /enterprise-settings - View app status and branding
|
||||
# /billing - Unified billing API
|
||||
# /proxy - Self-hosted proxy endpoints (have own license-based auth)
|
||||
# /tenants/billing-* - Legacy billing endpoints (backwards compatibility)
|
||||
# /manage/users, /users - User management (needed for seat limit resolution)
|
||||
# /notifications - Needed for UI to load properly
|
||||
LICENSE_ENFORCEMENT_ALLOWED_PREFIXES: frozenset[str] = frozenset(
|
||||
{
|
||||
"/auth",
|
||||
"/license",
|
||||
"/health",
|
||||
"/me",
|
||||
"/settings",
|
||||
"/enterprise-settings",
|
||||
# Billing endpoints (unified API for both MT and self-hosted)
|
||||
"/billing",
|
||||
"/admin/billing",
|
||||
# Proxy endpoints for self-hosted billing (no tenant context)
|
||||
"/proxy",
|
||||
# Legacy tenant billing endpoints (kept for backwards compatibility)
|
||||
"/tenants/billing-information",
|
||||
"/tenants/create-customer-portal-session",
|
||||
"/tenants/create-subscription-session",
|
||||
# User management - needed to remove users when seat limit exceeded
|
||||
"/manage/users",
|
||||
"/manage/admin/users",
|
||||
"/manage/admin/valid-domains",
|
||||
"/manage/admin/deactivate-user",
|
||||
"/manage/admin/delete-user",
|
||||
"/users",
|
||||
# Notifications - needed for UI to load properly
|
||||
"/notifications",
|
||||
}
|
||||
)
|
||||
|
||||
# EE-only paths that require a valid license.
|
||||
# Users without a license (community edition) cannot access these.
|
||||
# These are blocked even when user has never subscribed (no license).
|
||||
EE_ONLY_PATH_PREFIXES: frozenset[str] = frozenset(
|
||||
{
|
||||
# User groups and access control
|
||||
"/manage/admin/user-group",
|
||||
# Analytics and reporting
|
||||
"/analytics",
|
||||
# Query history (admin chat session endpoints)
|
||||
"/admin/chat-sessions",
|
||||
"/admin/chat-session-history",
|
||||
"/admin/query-history",
|
||||
# Usage reporting/export
|
||||
"/admin/usage-report",
|
||||
# Standard answers (canned responses)
|
||||
"/manage/admin/standard-answer",
|
||||
# Token rate limits
|
||||
"/admin/token-rate-limits",
|
||||
# Evals
|
||||
"/evals",
|
||||
}
|
||||
)
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Database and cache operations for the license table."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import NamedTuple
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
@@ -9,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -23,6 +25,13 @@ LICENSE_METADATA_KEY = "license:metadata"
|
||||
LICENSE_CACHE_TTL_SECONDS = 86400 # 24 hours
|
||||
|
||||
|
||||
class SeatAvailabilityResult(NamedTuple):
    """Result of a seat availability check.

    Fields:
        available: True when the requested seats can be allocated (or when
            no enforcement applies).
        error_message: Human-readable reason when ``available`` is False;
            None on success.
    """

    available: bool
    error_message: str | None = None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Database CRUD Operations
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -95,23 +104,30 @@ def delete_license(db_session: Session) -> bool:
|
||||
|
||||
def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
"""
|
||||
Get current seat usage.
|
||||
Get current seat usage directly from database.
|
||||
|
||||
For multi-tenant: counts users in UserTenantMapping for this tenant.
|
||||
For self-hosted: counts all active users (includes both Onyx UI users
|
||||
and Slack users who have been converted to Onyx users).
|
||||
For self-hosted: counts all active users (excludes EXT_PERM_USER role).
|
||||
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_count
|
||||
|
||||
return get_tenant_count(tenant_id or get_current_tenant_id())
|
||||
else:
|
||||
# Self-hosted: count all active users (Onyx + converted Slack users)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
result = db_session.execute(
|
||||
select(func.count()).select_from(User).where(User.is_active) # type: ignore
|
||||
select(func.count())
|
||||
.select_from(User)
|
||||
.where(
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
User.role != UserRole.EXT_PERM_USER,
|
||||
)
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
@@ -276,3 +292,43 @@ def get_license_metadata(
|
||||
|
||||
# Refresh from database
|
||||
return refresh_license_cache(db_session, tenant_id)
|
||||
|
||||
|
||||
def check_seat_availability(
    db_session: Session,
    seats_needed: int = 1,
    tenant_id: str | None = None,
) -> SeatAvailabilityResult:
    """Determine whether `seats_needed` additional users fit under the license.

    Args:
        db_session: Database session used to look up license metadata.
        seats_needed: Number of seats the caller wants to claim (default 1).
        tenant_id: Tenant to check (multi-tenant deployments only).

    Returns:
        SeatAvailabilityResult with available=True when the seats fit — or
        when no license exists at all, since an unlicensed self-hosted
        deployment is unlimited. Otherwise available=False together with a
        human-readable error_message.
    """
    license_info = get_license_metadata(db_session, tenant_id)

    # No license means no enforcement (self-hosted without a license).
    if license_info is None:
        return SeatAvailabilityResult(available=True)

    # Count current usage straight from the database (not the cache) so the
    # decision reflects true current state.
    used = get_used_seats(tenant_id)
    limit = license_info.seats

    # Strictly-greater comparison: filling to exactly 100% capacity is allowed.
    if used + seats_needed > limit:
        message = (
            f"Seat limit would be exceeded: {used} of {limit} seats used, "
            f"cannot add {seats_needed} more user(s)."
        )
        return SeatAvailabilityResult(available=False, error_message=message)

    return SeatAvailabilityResult(available=True)
|
||||
|
||||
@@ -7,6 +7,7 @@ from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFun
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.connectors.gmail.connector import GmailConnector
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -60,6 +61,9 @@ def gmail_doc_sync(
|
||||
|
||||
callback.progress("gmail_doc_sync", 1)
|
||||
|
||||
if isinstance(slim_doc, HierarchyNode):
|
||||
# TODO: handle hierarchynodes during sync
|
||||
continue
|
||||
if slim_doc.external_access is None:
|
||||
logger.warning(f"No permissions found for document {slim_doc.id}")
|
||||
continue
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.connectors.google_utils.resources import GoogleDriveService
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -195,7 +196,9 @@ def gdrive_doc_sync(
|
||||
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
|
||||
|
||||
callback.progress("gdrive_doc_sync", 1)
|
||||
|
||||
if isinstance(slim_doc, HierarchyNode):
|
||||
# TODO: handle hierarchynodes during sync
|
||||
continue
|
||||
if slim_doc.external_access is None:
|
||||
raise ValueError(
|
||||
f"Drive perm sync: No external access for document {slim_doc.id}"
|
||||
|
||||
@@ -8,6 +8,7 @@ from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.slack.connector import get_channels
|
||||
from onyx.connectors.slack.connector import make_paginated_slack_api_call
|
||||
from onyx.connectors.slack.connector import SlackConnector
|
||||
@@ -111,6 +112,9 @@ def _get_slack_document_access(
|
||||
|
||||
for doc_metadata_batch in slim_doc_generator:
|
||||
for doc_metadata in doc_metadata_batch:
|
||||
if isinstance(doc_metadata, HierarchyNode):
|
||||
# TODO: handle hierarchynodes during sync
|
||||
continue
|
||||
if doc_metadata.external_access is None:
|
||||
raise ValueError(
|
||||
f"No external access for document {doc_metadata.id}. "
|
||||
|
||||
@@ -5,6 +5,7 @@ from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -49,6 +50,9 @@ def generic_doc_sync(
|
||||
callback.progress(label, 1)
|
||||
|
||||
for doc in doc_batch:
|
||||
if isinstance(doc, HierarchyNode):
|
||||
# TODO: handle hierarchynodes during sync
|
||||
continue
|
||||
if not doc.external_access:
|
||||
raise RuntimeError(
|
||||
f"No external access found for document ID; {cc_pair.id=} {doc_source=} {doc.id=}"
|
||||
|
||||
@@ -4,8 +4,10 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.server.analytics.api import router as analytics_router
|
||||
from ee.onyx.server.auth_check import check_ee_router_auth
|
||||
from ee.onyx.server.billing.api import router as billing_router
|
||||
from ee.onyx.server.documents.cc_pair import router as ee_document_cc_pair_router
|
||||
from ee.onyx.server.enterprise_settings.api import (
|
||||
admin_router as enterprise_settings_admin_router,
|
||||
@@ -85,10 +87,11 @@ def get_application() -> FastAPI:
|
||||
|
||||
if MULTI_TENANT:
|
||||
add_api_server_tenant_id_middleware(application, logger)
|
||||
|
||||
# Add license enforcement middleware (runs after tenant tracking)
|
||||
# This blocks access when license is expired/gated
|
||||
add_license_enforcement_middleware(application, logger)
|
||||
else:
|
||||
# License enforcement middleware for self-hosted deployments only
|
||||
# Checks LICENSE_ENFORCEMENT_ENABLED at runtime (can be toggled without restart)
|
||||
# MT deployments use control plane gating via is_tenant_gated() instead
|
||||
add_license_enforcement_middleware(application, logger)
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
# For Google OAuth, refresh tokens are requested by:
|
||||
@@ -148,6 +151,13 @@ def get_application() -> FastAPI:
|
||||
# License management
|
||||
include_router_with_global_prefix_prepended(application, license_router)
|
||||
|
||||
# Unified billing API - available when license system is enabled
|
||||
# Works for both self-hosted and cloud deployments
|
||||
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
|
||||
# primary billing API and /tenants/* billing endpoints can be removed
|
||||
if LICENSE_ENFORCEMENT_ENABLED:
|
||||
include_router_with_global_prefix_prepended(application, billing_router)
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Tenant management
|
||||
include_router_with_global_prefix_prepended(application, tenants_router)
|
||||
|
||||
@@ -17,7 +17,8 @@ from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.pipeline import merge_individual_chunks
|
||||
from onyx.context.search.pipeline import search_pipeline
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.factory import get_current_primary_default_document_index
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.secondary_llm_flows.document_filter import select_sections_for_expansion
|
||||
@@ -42,11 +43,13 @@ def _run_single_search(
|
||||
document_index: DocumentIndex,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
num_hits: int | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Execute a single search query and return chunks."""
|
||||
chunk_search_request = ChunkSearchRequest(
|
||||
query=query,
|
||||
user_selected_filters=filters,
|
||||
limit=num_hits,
|
||||
)
|
||||
|
||||
return search_pipeline(
|
||||
@@ -72,7 +75,9 @@ def stream_search_query(
|
||||
Used by both streaming and non-streaming endpoints.
|
||||
"""
|
||||
# Get document index
|
||||
document_index = get_current_primary_default_document_index(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
# This flow is for search so we do not get all indices.
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
|
||||
# Determine queries to execute
|
||||
original_query = request.search_query
|
||||
@@ -114,6 +119,7 @@ def stream_search_query(
|
||||
document_index=document_index,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
num_hits=request.num_hits,
|
||||
)
|
||||
else:
|
||||
# Multiple queries - run in parallel and merge with RRF
|
||||
@@ -121,7 +127,14 @@ def stream_search_query(
|
||||
search_functions = [
|
||||
(
|
||||
_run_single_search,
|
||||
(query, request.filters, document_index, user, db_session),
|
||||
(
|
||||
query,
|
||||
request.filters,
|
||||
document_index,
|
||||
user,
|
||||
db_session,
|
||||
request.num_hits,
|
||||
),
|
||||
)
|
||||
for query in all_executed_queries
|
||||
]
|
||||
@@ -168,6 +181,9 @@ def stream_search_query(
|
||||
# Merge chunks into sections
|
||||
sections = merge_individual_chunks(chunks)
|
||||
|
||||
# Truncate to the requested number of hits
|
||||
sections = sections[: request.num_hits]
|
||||
|
||||
# Apply LLM document selection if requested
|
||||
# num_docs_fed_to_llm_selection specifies how many sections to feed to the LLM for selection
|
||||
# The LLM will always try to select TARGET_NUM_SECTIONS_FOR_LLM_SELECTION sections from those fed to it
|
||||
|
||||
@@ -10,6 +10,16 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
|
||||
("/enterprise-settings/logo", {"GET"}),
|
||||
("/enterprise-settings/logotype", {"GET"}),
|
||||
("/enterprise-settings/custom-analytics-script", {"GET"}),
|
||||
# Stripe publishable key is safe to expose publicly
|
||||
("/tenants/stripe-publishable-key", {"GET"}),
|
||||
("/admin/billing/stripe-publishable-key", {"GET"}),
|
||||
# Proxy endpoints use license-based auth, not user auth
|
||||
("/proxy/create-checkout-session", {"POST"}),
|
||||
("/proxy/claim-license", {"POST"}),
|
||||
("/proxy/create-customer-portal-session", {"POST"}),
|
||||
("/proxy/billing-information", {"GET"}),
|
||||
("/proxy/license/{tenant_id}", {"GET"}),
|
||||
("/proxy/seats/update", {"POST"}),
|
||||
]
|
||||
|
||||
|
||||
|
||||
0
backend/ee/onyx/server/billing/__init__.py
Normal file
0
backend/ee/onyx/server/billing/__init__.py
Normal file
264
backend/ee/onyx/server/billing/api.py
Normal file
264
backend/ee/onyx/server/billing/api.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""Unified Billing API endpoints.
|
||||
|
||||
These endpoints provide Stripe billing functionality for both cloud and
|
||||
self-hosted deployments. The service layer routes requests appropriately:
|
||||
|
||||
- Self-hosted: Routes through cloud data plane proxy
|
||||
Flow: Backend /admin/billing/* → Cloud DP /proxy/* → Control plane
|
||||
|
||||
- Cloud (MULTI_TENANT): Routes directly to control plane
|
||||
Flow: Backend /admin/billing/* → Control plane
|
||||
|
||||
License claiming is handled separately by /license/claim endpoint (self-hosted only).
|
||||
|
||||
Migration Note (ENG-3533):
|
||||
This /admin/billing/* API replaces the older /tenants/* billing endpoints:
|
||||
- /tenants/billing-information -> /admin/billing/billing-information
|
||||
- /tenants/create-customer-portal-session -> /admin/billing/create-customer-portal-session
|
||||
- /tenants/create-subscription-session -> /admin/billing/create-checkout-session
|
||||
- /tenants/stripe-publishable-key -> /admin/billing/stripe-publishable-key
|
||||
|
||||
See: https://linear.app/onyx-app/issue/ENG-3533/migrate-tenantsbilling-adminbilling
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.db.license import get_license
|
||||
from ee.onyx.server.billing.models import BillingInformationResponse
|
||||
from ee.onyx.server.billing.models import CreateCheckoutSessionRequest
|
||||
from ee.onyx.server.billing.models import CreateCheckoutSessionResponse
|
||||
from ee.onyx.server.billing.models import CreateCustomerPortalSessionRequest
|
||||
from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse
|
||||
from ee.onyx.server.billing.models import SeatUpdateRequest
|
||||
from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.billing.models import StripePublishableKeyResponse
|
||||
from ee.onyx.server.billing.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.billing.service import BillingServiceError
|
||||
from ee.onyx.server.billing.service import (
|
||||
create_checkout_session as create_checkout_service,
|
||||
)
|
||||
from ee.onyx.server.billing.service import (
|
||||
create_customer_portal_session as create_portal_service,
|
||||
)
|
||||
from ee.onyx.server.billing.service import (
|
||||
get_billing_information as get_billing_service,
|
||||
)
|
||||
from ee.onyx.server.billing.service import update_seat_count as update_seat_service
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/admin/billing")
|
||||
|
||||
# Cache for Stripe publishable key to avoid hitting S3 on every request
|
||||
_stripe_publishable_key_cache: str | None = None
|
||||
_stripe_key_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def _get_license_data(db_session: Session) -> str | None:
|
||||
"""Get license data from database if exists (self-hosted only)."""
|
||||
if MULTI_TENANT:
|
||||
return None
|
||||
license_record = get_license(db_session)
|
||||
return license_record.license_data if license_record else None
|
||||
|
||||
|
||||
def _get_tenant_id() -> str | None:
|
||||
"""Get tenant ID for cloud deployments."""
|
||||
if MULTI_TENANT:
|
||||
return get_current_tenant_id()
|
||||
return None
|
||||
|
||||
|
||||
@router.post("/create-checkout-session")
|
||||
async def create_checkout_session(
|
||||
request: CreateCheckoutSessionRequest | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CreateCheckoutSessionResponse:
|
||||
"""Create a Stripe checkout session for new subscription or renewal.
|
||||
|
||||
For new customers, no license/tenant is required.
|
||||
For renewals, existing license (self-hosted) or tenant_id (cloud) is used.
|
||||
|
||||
After checkout completion:
|
||||
- Self-hosted: Use /license/claim to retrieve the license
|
||||
- Cloud: Subscription is automatically activated
|
||||
"""
|
||||
license_data = _get_license_data(db_session)
|
||||
tenant_id = _get_tenant_id()
|
||||
billing_period = request.billing_period if request else "monthly"
|
||||
email = request.email if request else None
|
||||
|
||||
# Build redirect URL for after checkout completion
|
||||
redirect_url = f"{WEB_DOMAIN}/admin/billing?checkout=success"
|
||||
|
||||
try:
|
||||
return await create_checkout_service(
|
||||
billing_period=billing_period,
|
||||
email=email,
|
||||
license_data=license_data,
|
||||
redirect_url=redirect_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(
|
||||
request: CreateCustomerPortalSessionRequest | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CreateCustomerPortalSessionResponse:
|
||||
"""Create a Stripe customer portal session for managing subscription.
|
||||
|
||||
Requires existing license (self-hosted) or active tenant (cloud).
|
||||
"""
|
||||
license_data = _get_license_data(db_session)
|
||||
tenant_id = _get_tenant_id()
|
||||
|
||||
# Self-hosted requires license
|
||||
if not MULTI_TENANT and not license_data:
|
||||
raise HTTPException(status_code=400, detail="No license found")
|
||||
|
||||
return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing"
|
||||
|
||||
try:
|
||||
return await create_portal_service(
|
||||
license_data=license_data,
|
||||
return_url=return_url,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def get_billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> BillingInformationResponse | SubscriptionStatusResponse:
|
||||
"""Get billing information for the current subscription.
|
||||
|
||||
Returns subscription status and details from Stripe.
|
||||
"""
|
||||
license_data = _get_license_data(db_session)
|
||||
tenant_id = _get_tenant_id()
|
||||
|
||||
# Self-hosted without license = no subscription
|
||||
if not MULTI_TENANT and not license_data:
|
||||
return SubscriptionStatusResponse(subscribed=False)
|
||||
|
||||
try:
|
||||
return await get_billing_service(
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.post("/seats/update")
|
||||
async def update_seats(
|
||||
request: SeatUpdateRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SeatUpdateResponse:
|
||||
"""Update the seat count for the current subscription.
|
||||
|
||||
Handles Stripe proration and license regeneration via control plane.
|
||||
"""
|
||||
license_data = _get_license_data(db_session)
|
||||
tenant_id = _get_tenant_id()
|
||||
|
||||
# Self-hosted requires license
|
||||
if not MULTI_TENANT and not license_data:
|
||||
raise HTTPException(status_code=400, detail="No license found")
|
||||
|
||||
try:
|
||||
return await update_seat_service(
|
||||
new_seat_count=request.new_seat_count,
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
"""Fetch the Stripe publishable key.
|
||||
|
||||
Priority: env var override (for testing) > S3 bucket (production).
|
||||
This endpoint is public (no auth required) since publishable keys are safe to expose.
|
||||
The key is cached in memory to avoid hitting S3 on every request.
|
||||
"""
|
||||
global _stripe_publishable_key_cache
|
||||
|
||||
# Fast path: return cached value without lock
|
||||
if _stripe_publishable_key_cache:
|
||||
return StripePublishableKeyResponse(
|
||||
publishable_key=_stripe_publishable_key_cache
|
||||
)
|
||||
|
||||
# Use lock to prevent concurrent S3 requests
|
||||
async with _stripe_key_lock:
|
||||
# Double-check after acquiring lock (another request may have populated cache)
|
||||
if _stripe_publishable_key_cache:
|
||||
return StripePublishableKeyResponse(
|
||||
publishable_key=_stripe_publishable_key_cache
|
||||
)
|
||||
|
||||
# Check for env var override first (for local testing with pk_test_* keys)
|
||||
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(STRIPE_PUBLISHABLE_KEY_URL)
|
||||
response.raise_for_status()
|
||||
key = response.text.strip()
|
||||
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
)
|
||||
75
backend/ee/onyx/server/billing/models.py
Normal file
75
backend/ee/onyx/server/billing/models.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Pydantic models for the billing API."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CreateCheckoutSessionRequest(BaseModel):
|
||||
"""Request to create a Stripe checkout session."""
|
||||
|
||||
billing_period: Literal["monthly", "annual"] = "monthly"
|
||||
email: str | None = None
|
||||
|
||||
|
||||
class CreateCheckoutSessionResponse(BaseModel):
|
||||
"""Response containing the Stripe checkout session URL."""
|
||||
|
||||
stripe_checkout_url: str
|
||||
|
||||
|
||||
class CreateCustomerPortalSessionRequest(BaseModel):
|
||||
"""Request to create a Stripe customer portal session."""
|
||||
|
||||
return_url: str | None = None
|
||||
|
||||
|
||||
class CreateCustomerPortalSessionResponse(BaseModel):
|
||||
"""Response containing the Stripe customer portal URL."""
|
||||
|
||||
stripe_customer_portal_url: str
|
||||
|
||||
|
||||
class BillingInformationResponse(BaseModel):
|
||||
"""Billing information for the current subscription."""
|
||||
|
||||
tenant_id: str
|
||||
status: str | None = None
|
||||
plan_type: str | None = None
|
||||
seats: int | None = None
|
||||
billing_period: str | None = None
|
||||
current_period_start: datetime | None = None
|
||||
current_period_end: datetime | None = None
|
||||
cancel_at_period_end: bool = False
|
||||
canceled_at: datetime | None = None
|
||||
trial_start: datetime | None = None
|
||||
trial_end: datetime | None = None
|
||||
payment_method_enabled: bool = False
|
||||
|
||||
|
||||
class SubscriptionStatusResponse(BaseModel):
|
||||
"""Response when no subscription exists."""
|
||||
|
||||
subscribed: bool = False
|
||||
|
||||
|
||||
class SeatUpdateRequest(BaseModel):
|
||||
"""Request to update seat count."""
|
||||
|
||||
new_seat_count: int
|
||||
|
||||
|
||||
class SeatUpdateResponse(BaseModel):
|
||||
"""Response from seat update operation."""
|
||||
|
||||
success: bool
|
||||
current_seats: int
|
||||
used_seats: int
|
||||
message: str | None = None
|
||||
|
||||
|
||||
class StripePublishableKeyResponse(BaseModel):
|
||||
"""Response containing the Stripe publishable key."""
|
||||
|
||||
publishable_key: str
|
||||
267
backend/ee/onyx/server/billing/service.py
Normal file
267
backend/ee/onyx/server/billing/service.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Service layer for billing operations.
|
||||
|
||||
This module provides functions for billing operations that route differently
|
||||
based on deployment type:
|
||||
|
||||
- Self-hosted (not MULTI_TENANT): Routes through cloud data plane proxy
|
||||
Flow: Self-hosted backend → Cloud DP /proxy/* → Control plane
|
||||
|
||||
- Cloud (MULTI_TENANT): Routes directly to control plane
|
||||
Flow: Cloud backend → Control plane
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
|
||||
from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL
|
||||
from ee.onyx.server.billing.models import BillingInformationResponse
|
||||
from ee.onyx.server.billing.models import CreateCheckoutSessionResponse
|
||||
from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse
|
||||
from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.billing.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# HTTP request timeout for billing service calls
|
||||
_REQUEST_TIMEOUT = 30.0
|
||||
|
||||
|
||||
class BillingServiceError(Exception):
|
||||
"""Exception raised for billing service errors."""
|
||||
|
||||
def __init__(self, message: str, status_code: int = 500):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
def _get_proxy_headers(license_data: str | None) -> dict[str, str]:
|
||||
"""Build headers for proxy requests (self-hosted).
|
||||
|
||||
Self-hosted instances authenticate with their license.
|
||||
"""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if license_data:
|
||||
headers["Authorization"] = f"Bearer {license_data}"
|
||||
return headers
|
||||
|
||||
|
||||
def _get_direct_headers() -> dict[str, str]:
|
||||
"""Build headers for direct control plane requests (cloud).
|
||||
|
||||
Cloud instances authenticate with JWT.
|
||||
"""
|
||||
token = generate_data_plane_token()
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {token}",
|
||||
}
|
||||
|
||||
|
||||
def _get_base_url() -> str:
|
||||
"""Get the base URL based on deployment type."""
|
||||
if MULTI_TENANT:
|
||||
return CONTROL_PLANE_API_BASE_URL
|
||||
return f"{CLOUD_DATA_PLANE_URL}/proxy"
|
||||
|
||||
|
||||
def _get_headers(license_data: str | None) -> dict[str, str]:
|
||||
"""Get appropriate headers based on deployment type."""
|
||||
if MULTI_TENANT:
|
||||
return _get_direct_headers()
|
||||
return _get_proxy_headers(license_data)
|
||||
|
||||
|
||||
async def _make_billing_request(
|
||||
method: Literal["GET", "POST"],
|
||||
path: str,
|
||||
license_data: str | None = None,
|
||||
body: dict | None = None,
|
||||
params: dict | None = None,
|
||||
error_message: str = "Billing service request failed",
|
||||
) -> dict:
|
||||
"""Make an HTTP request to the billing service.
|
||||
|
||||
Consolidates the common HTTP request pattern used by all billing operations.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET or POST)
|
||||
path: URL path (appended to base URL)
|
||||
license_data: License for authentication (self-hosted)
|
||||
body: Request body for POST requests
|
||||
params: Query parameters for GET requests
|
||||
error_message: Default error message if request fails
|
||||
|
||||
Returns:
|
||||
Response JSON as dict
|
||||
|
||||
Raises:
|
||||
BillingServiceError: If request fails
|
||||
"""
|
||||
base_url = _get_base_url()
|
||||
url = f"{base_url}{path}"
|
||||
headers = _get_headers(license_data)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT) as client:
|
||||
if method == "GET":
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
else:
|
||||
response = await client.post(url, headers=headers, json=body)
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
detail = error_message
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
detail = error_data.get("detail", detail)
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(f"{error_message}: {e.response.status_code} - {detail}")
|
||||
raise BillingServiceError(detail, e.response.status_code)
|
||||
|
||||
except httpx.RequestError:
|
||||
logger.exception("Failed to connect to billing service")
|
||||
raise BillingServiceError("Failed to connect to billing service", 502)
|
||||
|
||||
|
||||
async def create_checkout_session(
|
||||
billing_period: str = "monthly",
|
||||
email: str | None = None,
|
||||
license_data: str | None = None,
|
||||
redirect_url: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> CreateCheckoutSessionResponse:
|
||||
"""Create a Stripe checkout session.
|
||||
|
||||
Args:
|
||||
billing_period: "monthly" or "annual"
|
||||
email: Customer email for new subscriptions
|
||||
license_data: Existing license for renewals (self-hosted)
|
||||
redirect_url: URL to redirect after successful checkout
|
||||
tenant_id: Tenant ID (cloud only, for renewals)
|
||||
|
||||
Returns:
|
||||
CreateCheckoutSessionResponse with checkout URL
|
||||
"""
|
||||
body: dict = {"billing_period": billing_period}
|
||||
if email:
|
||||
body["email"] = email
|
||||
if redirect_url:
|
||||
body["redirect_url"] = redirect_url
|
||||
if tenant_id and MULTI_TENANT:
|
||||
body["tenant_id"] = tenant_id
|
||||
|
||||
data = await _make_billing_request(
|
||||
method="POST",
|
||||
path="/create-checkout-session",
|
||||
license_data=license_data,
|
||||
body=body,
|
||||
error_message="Failed to create checkout session",
|
||||
)
|
||||
return CreateCheckoutSessionResponse(stripe_checkout_url=data["url"])
|
||||
|
||||
|
||||
async def create_customer_portal_session(
|
||||
license_data: str | None = None,
|
||||
return_url: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> CreateCustomerPortalSessionResponse:
|
||||
"""Create a Stripe customer portal session.
|
||||
|
||||
Args:
|
||||
license_data: License blob for authentication (self-hosted)
|
||||
return_url: URL to return to after portal session
|
||||
tenant_id: Tenant ID (cloud only)
|
||||
|
||||
Returns:
|
||||
CreateCustomerPortalSessionResponse with portal URL
|
||||
"""
|
||||
body: dict = {}
|
||||
if return_url:
|
||||
body["return_url"] = return_url
|
||||
if tenant_id and MULTI_TENANT:
|
||||
body["tenant_id"] = tenant_id
|
||||
|
||||
data = await _make_billing_request(
|
||||
method="POST",
|
||||
path="/create-customer-portal-session",
|
||||
license_data=license_data,
|
||||
body=body,
|
||||
error_message="Failed to create customer portal session",
|
||||
)
|
||||
return CreateCustomerPortalSessionResponse(stripe_customer_portal_url=data["url"])
|
||||
|
||||
|
||||
async def get_billing_information(
|
||||
license_data: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> BillingInformationResponse | SubscriptionStatusResponse:
|
||||
"""Fetch billing information.
|
||||
|
||||
Args:
|
||||
license_data: License blob for authentication (self-hosted)
|
||||
tenant_id: Tenant ID (cloud only)
|
||||
|
||||
Returns:
|
||||
BillingInformationResponse or SubscriptionStatusResponse if no subscription
|
||||
"""
|
||||
params = {}
|
||||
if tenant_id and MULTI_TENANT:
|
||||
params["tenant_id"] = tenant_id
|
||||
|
||||
data = await _make_billing_request(
|
||||
method="GET",
|
||||
path="/billing-information",
|
||||
license_data=license_data,
|
||||
params=params or None,
|
||||
error_message="Failed to fetch billing information",
|
||||
)
|
||||
|
||||
# Check if no subscription
|
||||
if isinstance(data, dict) and data.get("subscribed") is False:
|
||||
return SubscriptionStatusResponse(subscribed=False)
|
||||
|
||||
return BillingInformationResponse(**data)
|
||||
|
||||
|
||||
async def update_seat_count(
|
||||
new_seat_count: int,
|
||||
license_data: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> SeatUpdateResponse:
|
||||
"""Update the seat count for the current subscription.
|
||||
|
||||
Args:
|
||||
new_seat_count: New number of seats
|
||||
license_data: License blob for authentication (self-hosted)
|
||||
tenant_id: Tenant ID (cloud only)
|
||||
|
||||
Returns:
|
||||
SeatUpdateResponse with updated seat information
|
||||
"""
|
||||
body: dict = {"new_seat_count": new_seat_count}
|
||||
if tenant_id and MULTI_TENANT:
|
||||
body["tenant_id"] = tenant_id
|
||||
|
||||
data = await _make_billing_request(
|
||||
method="POST",
|
||||
path="/seats/update",
|
||||
license_data=license_data,
|
||||
body=body,
|
||||
error_message="Failed to update seat count",
|
||||
)
|
||||
|
||||
return SeatUpdateResponse(
|
||||
success=data.get("success", False),
|
||||
current_seats=data.get("current_seats", 0),
|
||||
used_seats=data.get("used_seats", 0),
|
||||
message=data.get("message"),
|
||||
)
|
||||
@@ -1,4 +1,14 @@
|
||||
"""License API endpoints."""
|
||||
"""License API endpoints for self-hosted deployments.
|
||||
|
||||
These endpoints allow self-hosted Onyx instances to:
|
||||
1. Claim a license after Stripe checkout (via cloud data plane proxy)
|
||||
2. Upload a license file manually (for air-gapped deployments)
|
||||
3. View license status and seat usage
|
||||
4. Refresh/delete the local license
|
||||
|
||||
NOTE: Cloud (MULTI_TENANT) deployments do NOT use these endpoints.
|
||||
Cloud licensing is managed via the control plane and gated_tenants Redis key.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter
|
||||
@@ -9,6 +19,7 @@ from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL
|
||||
from ee.onyx.db.license import delete_license as db_delete_license
|
||||
from ee.onyx.db.license import get_license_metadata
|
||||
from ee.onyx.db.license import invalidate_license_cache
|
||||
@@ -20,13 +31,11 @@ from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.license.models import LicenseStatusResponse
|
||||
from ee.onyx.server.license.models import LicenseUploadResponse
|
||||
from ee.onyx.server.license.models import SeatUsageResponse
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -79,81 +88,80 @@ async def get_seat_usage(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/fetch")
|
||||
async def fetch_license(
|
||||
@router.post("/claim")
|
||||
async def claim_license(
|
||||
session_id: str,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseResponse:
|
||||
"""
|
||||
Fetch license from control plane.
|
||||
Used after Stripe checkout completion to retrieve the new license.
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
Claim a license after Stripe checkout (self-hosted only).
|
||||
|
||||
try:
|
||||
token = generate_data_plane_token()
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to generate data plane token: {e}")
|
||||
After a user completes Stripe checkout, they're redirected back with a
|
||||
session_id. This endpoint exchanges that session_id for a signed license
|
||||
via the cloud data plane proxy.
|
||||
|
||||
Flow:
|
||||
1. Self-hosted frontend redirects to Stripe checkout (via cloud proxy)
|
||||
2. User completes payment
|
||||
3. Stripe redirects back to self-hosted instance with session_id
|
||||
4. Frontend calls this endpoint with session_id
|
||||
5. We call cloud data plane /proxy/claim-license to get the signed license
|
||||
6. License is stored locally and cached
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Authentication configuration error"
|
||||
status_code=400,
|
||||
detail="License claiming is only available for self-hosted deployments",
|
||||
)
|
||||
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/license/{tenant_id}"
|
||||
response = requests.get(url, headers=headers, timeout=10)
|
||||
# Call cloud data plane to claim the license
|
||||
url = f"{CLOUD_DATA_PLANE_URL}/proxy/claim-license"
|
||||
response = requests.post(
|
||||
url,
|
||||
json={"session_id": session_id},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
if not isinstance(data, dict) or "license" not in data:
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Invalid response from control plane"
|
||||
)
|
||||
license_data = data.get("license")
|
||||
|
||||
license_data = data["license"]
|
||||
if not license_data:
|
||||
raise HTTPException(status_code=404, detail="No license found")
|
||||
raise HTTPException(status_code=404, detail="No license in response")
|
||||
|
||||
# Verify signature before persisting
|
||||
payload = verify_license_signature(license_data)
|
||||
|
||||
# Verify the fetched license is for this tenant
|
||||
if payload.tenant_id != tenant_id:
|
||||
logger.error(
|
||||
f"License tenant mismatch: expected {tenant_id}, got {payload.tenant_id}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License tenant ID mismatch - control plane returned wrong license",
|
||||
)
|
||||
|
||||
# Persist to DB and update cache atomically
|
||||
# Store in DB
|
||||
upsert_license(db_session, license_data)
|
||||
|
||||
try:
|
||||
update_license_cache(payload, source=LicenseSource.AUTO_FETCH)
|
||||
except Exception as cache_error:
|
||||
# Log but don't fail - DB is source of truth, cache will refresh on next read
|
||||
logger.warning(f"Failed to update license cache: {cache_error}")
|
||||
|
||||
logger.info(
|
||||
f"License claimed: seats={payload.seats}, expires={payload.expires_at.date()}"
|
||||
)
|
||||
return LicenseResponse(success=True, license=payload)
|
||||
|
||||
except requests.HTTPError as e:
|
||||
status_code = e.response.status_code if e.response is not None else 502
|
||||
logger.error(f"Control plane returned error: {status_code}")
|
||||
raise HTTPException(
|
||||
status_code=status_code,
|
||||
detail="Failed to fetch license from control plane",
|
||||
)
|
||||
detail = "Failed to claim license"
|
||||
try:
|
||||
error_data = e.response.json() if e.response is not None else {}
|
||||
detail = error_data.get("detail", detail)
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
except ValueError as e:
|
||||
logger.error(f"License verification failed: {type(e).__name__}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except requests.RequestException:
|
||||
logger.exception("Failed to fetch license from control plane")
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Failed to connect to control plane"
|
||||
status_code=502, detail="Failed to connect to license server"
|
||||
)
|
||||
|
||||
|
||||
@@ -164,33 +172,36 @@ async def upload_license(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseUploadResponse:
|
||||
"""
|
||||
Upload a license file manually.
|
||||
Used for air-gapped deployments where control plane is not accessible.
|
||||
Upload a license file manually (self-hosted only).
|
||||
|
||||
Used for air-gapped deployments where the cloud data plane is not accessible.
|
||||
The license file must be cryptographically signed by Onyx.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License upload is only available for self-hosted deployments",
|
||||
)
|
||||
|
||||
try:
|
||||
content = await license_file.read()
|
||||
license_data = content.decode("utf-8").strip()
|
||||
except UnicodeDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Invalid license file format")
|
||||
|
||||
# Verify cryptographic signature - this is the only validation needed
|
||||
# The license's tenant_id identifies the customer in control plane, not locally
|
||||
try:
|
||||
payload = verify_license_signature(license_data)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
if payload.tenant_id != tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"License tenant ID mismatch. Expected {tenant_id}, got {payload.tenant_id}",
|
||||
)
|
||||
|
||||
# Persist to DB and update cache
|
||||
upsert_license(db_session, license_data)
|
||||
|
||||
try:
|
||||
update_license_cache(payload, source=LicenseSource.MANUAL_UPLOAD)
|
||||
except Exception as cache_error:
|
||||
# Log but don't fail - DB is source of truth, cache will refresh on next read
|
||||
logger.warning(f"Failed to update license cache: {cache_error}")
|
||||
|
||||
return LicenseUploadResponse(
|
||||
@@ -205,8 +216,10 @@ async def refresh_license_cache_endpoint(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseStatusResponse:
|
||||
"""
|
||||
Force refresh the license cache from the database.
|
||||
Force refresh the license cache from the local database.
|
||||
|
||||
Useful after manual database changes or to verify license validity.
|
||||
Does NOT fetch from control plane - use /claim for that.
|
||||
"""
|
||||
metadata = refresh_license_cache(db_session)
|
||||
|
||||
@@ -233,9 +246,15 @@ async def delete_license(
|
||||
) -> dict[str, bool]:
|
||||
"""
|
||||
Delete the current license.
|
||||
Admin only - removes license and invalidates cache.
|
||||
|
||||
Admin only - removes license from database and invalidates cache.
|
||||
"""
|
||||
# Invalidate cache first - if DB delete fails, stale cache is worse than no cache
|
||||
if MULTI_TENANT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License deletion is only available for self-hosted deployments",
|
||||
)
|
||||
|
||||
try:
|
||||
invalidate_license_cache()
|
||||
except Exception as cache_error:
|
||||
|
||||
@@ -1,4 +1,42 @@
|
||||
"""Middleware to enforce license status application-wide."""
|
||||
"""Middleware to enforce license status for SELF-HOSTED deployments only.
|
||||
|
||||
NOTE: This middleware is NOT used for multi-tenant (cloud) deployments.
|
||||
Multi-tenant gating is handled separately by the control plane via the
|
||||
/tenants/product-gating endpoint and is_tenant_gated() checks.
|
||||
|
||||
IMPORTANT: Mutual Exclusivity with ENTERPRISE_EDITION_ENABLED
|
||||
============================================================
|
||||
This middleware is controlled by LICENSE_ENFORCEMENT_ENABLED env var.
|
||||
It works alongside the legacy ENTERPRISE_EDITION_ENABLED system:
|
||||
|
||||
- LICENSE_ENFORCEMENT_ENABLED=false (default):
|
||||
Middleware is disabled. EE features are controlled solely by
|
||||
ENTERPRISE_EDITION_ENABLED. This preserves legacy behavior.
|
||||
|
||||
- LICENSE_ENFORCEMENT_ENABLED=true:
|
||||
Middleware actively enforces license status. EE features require
|
||||
a valid license, regardless of ENTERPRISE_EDITION_ENABLED.
|
||||
|
||||
Eventually, ENTERPRISE_EDITION_ENABLED will be removed and license
|
||||
enforcement will be the only mechanism for gating EE features.
|
||||
|
||||
License Enforcement States (when enabled)
|
||||
=========================================
|
||||
For self-hosted deployments:
|
||||
|
||||
1. No license (never subscribed):
|
||||
- Allow community features (basic connectors, search, chat)
|
||||
- Block EE-only features (analytics, user groups, etc.)
|
||||
|
||||
2. GATED_ACCESS (fully expired):
|
||||
- Block all routes except billing/auth/license
|
||||
- User must renew subscription to continue
|
||||
|
||||
3. Valid license (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER):
|
||||
- Full access to all EE features
|
||||
- Seat limits enforced
|
||||
- GRACE_PERIOD/PAYMENT_REMINDER are for notifications only, not blocking
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable
|
||||
@@ -9,38 +47,30 @@ from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.configs.license_enforcement_config import EE_ONLY_PATH_PREFIXES
|
||||
from ee.onyx.configs.license_enforcement_config import (
|
||||
LICENSE_ENFORCEMENT_ALLOWED_PREFIXES,
|
||||
)
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.server.tenants.product_gating import is_tenant_gated
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
# Paths that are ALWAYS accessible, even when license is expired/gated.
|
||||
# These enable users to:
|
||||
# /auth - Log in/out (users can't fix billing if locked out of auth)
|
||||
# /license - Fetch, upload, or check license status
|
||||
# /health - Health checks for load balancers/orchestrators
|
||||
# /me - Basic user info needed for UI rendering
|
||||
# /settings, /enterprise-settings - View app status and branding
|
||||
# /tenants/billing-* - Manage subscription to resolve gating
|
||||
ALLOWED_PATH_PREFIXES = {
|
||||
"/auth",
|
||||
"/license",
|
||||
"/health",
|
||||
"/me",
|
||||
"/settings",
|
||||
"/enterprise-settings",
|
||||
"/tenants/billing-information",
|
||||
"/tenants/create-customer-portal-session",
|
||||
"/tenants/create-subscription-session",
|
||||
}
|
||||
|
||||
|
||||
def _is_path_allowed(path: str) -> bool:
|
||||
"""Check if path is in allowlist (prefix match)."""
|
||||
return any(path.startswith(prefix) for prefix in ALLOWED_PATH_PREFIXES)
|
||||
return any(
|
||||
path.startswith(prefix) for prefix in LICENSE_ENFORCEMENT_ALLOWED_PREFIXES
|
||||
)
|
||||
|
||||
|
||||
def _is_ee_only_path(path: str) -> bool:
|
||||
"""Check if path requires EE license (prefix match)."""
|
||||
return any(path.startswith(prefix) for prefix in EE_ONLY_PATH_PREFIXES)
|
||||
|
||||
|
||||
def add_license_enforcement_middleware(
|
||||
@@ -66,29 +96,84 @@ def add_license_enforcement_middleware(
|
||||
is_gated = False
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if MULTI_TENANT:
|
||||
try:
|
||||
is_gated = is_tenant_gated(tenant_id)
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check tenant gating status: {e}")
|
||||
# Fail open - don't block users due to Redis connectivity issues
|
||||
is_gated = False
|
||||
else:
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if metadata:
|
||||
if metadata.status == ApplicationStatus.GATED_ACCESS:
|
||||
is_gated = True
|
||||
else:
|
||||
# No license metadata = gated for self-hosted EE
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
|
||||
# If no cached metadata, check database (cache may have been cleared)
|
||||
if not metadata:
|
||||
logger.debug(
|
||||
"[license_enforcement] No cached license, checking database..."
|
||||
)
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
metadata = refresh_license_cache(db_session, tenant_id)
|
||||
if metadata:
|
||||
logger.info(
|
||||
"[license_enforcement] Loaded license from database"
|
||||
)
|
||||
except SQLAlchemyError as db_error:
|
||||
logger.warning(
|
||||
f"[license_enforcement] Failed to check database for license: {db_error}"
|
||||
)
|
||||
|
||||
if metadata:
|
||||
# User HAS a license (current or expired)
|
||||
if metadata.status == ApplicationStatus.GATED_ACCESS:
|
||||
# License fully expired - gate the user
|
||||
# Note: GRACE_PERIOD and PAYMENT_REMINDER are for notifications only,
|
||||
# they don't block access
|
||||
is_gated = True
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata: {e}")
|
||||
# Fail open - don't block users due to Redis connectivity issues
|
||||
else:
|
||||
# License is active - check seat limit
|
||||
# used_seats in cache is kept accurate via invalidation
|
||||
# when users are added/removed
|
||||
if metadata.used_seats > metadata.seats:
|
||||
logger.info(
|
||||
f"[license_enforcement] Blocking request: "
|
||||
f"seat limit exceeded ({metadata.used_seats}/{metadata.seats})"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=402,
|
||||
content={
|
||||
"detail": {
|
||||
"error": "seat_limit_exceeded",
|
||||
"message": f"Seat limit exceeded: {metadata.used_seats} of {metadata.seats} seats used.",
|
||||
"used_seats": metadata.used_seats,
|
||||
"seats": metadata.seats,
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
# No license in cache OR database = never subscribed
|
||||
# Allow community features, but block EE-only features
|
||||
if _is_ee_only_path(path):
|
||||
logger.info(
|
||||
f"[license_enforcement] Blocking EE-only path (no license): {path}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=402,
|
||||
content={
|
||||
"detail": {
|
||||
"error": "enterprise_license_required",
|
||||
"message": "This feature requires an Enterprise license. "
|
||||
"Please upgrade to access this functionality.",
|
||||
}
|
||||
},
|
||||
)
|
||||
logger.debug(
|
||||
"[license_enforcement] No license, allowing community features"
|
||||
)
|
||||
is_gated = False
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata: {e}")
|
||||
# Fail open - don't block users due to Redis connectivity issues
|
||||
is_gated = False
|
||||
|
||||
if is_gated:
|
||||
logger.info(f"Blocking request for gated tenant: {tenant_id}, path={path}")
|
||||
logger.info(
|
||||
f"[license_enforcement] Blocking request (license expired): {path}"
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=402,
|
||||
content={
|
||||
|
||||
@@ -32,6 +32,7 @@ class SendSearchQueryRequest(BaseModel):
|
||||
filters: BaseFilters | None = None
|
||||
num_docs_fed_to_llm_selection: int | None = None
|
||||
run_query_expansion: bool = False
|
||||
num_hits: int = 50
|
||||
|
||||
include_content: bool = False
|
||||
stream: bool = False
|
||||
|
||||
@@ -12,21 +12,51 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Statuses that indicate a billing/license problem - propagate these to settings
|
||||
_GATED_STATUSES = frozenset(
|
||||
{
|
||||
ApplicationStatus.GATED_ACCESS,
|
||||
ApplicationStatus.GRACE_PERIOD,
|
||||
ApplicationStatus.PAYMENT_REMINDER,
|
||||
}
|
||||
)
|
||||
# Only GATED_ACCESS actually blocks access - other statuses are for notifications
|
||||
_BLOCKING_STATUS = ApplicationStatus.GATED_ACCESS
|
||||
|
||||
|
||||
def check_ee_features_enabled() -> bool:
|
||||
"""EE version: checks if EE features should be available.
|
||||
|
||||
Returns True if:
|
||||
- LICENSE_ENFORCEMENT_ENABLED is False (legacy/rollout mode)
|
||||
- Cloud mode (MULTI_TENANT) - cloud handles its own gating
|
||||
- Self-hosted with a valid (non-expired) license
|
||||
|
||||
Returns False if:
|
||||
- Self-hosted with no license (never subscribed)
|
||||
- Self-hosted with expired license
|
||||
"""
|
||||
if not LICENSE_ENFORCEMENT_ENABLED:
|
||||
# License enforcement disabled - allow EE features (legacy behavior)
|
||||
return True
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Cloud mode - EE features always available (gating handled by is_tenant_gated)
|
||||
return True
|
||||
|
||||
# Self-hosted with enforcement - check for valid license
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if metadata and metadata.status != _BLOCKING_STATUS:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
return True
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license for EE features: {e}")
|
||||
# Fail closed - if Redis is down, other things will break anyway
|
||||
return False
|
||||
|
||||
# No license or GATED_ACCESS - no EE features
|
||||
return False
|
||||
|
||||
|
||||
def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
"""EE version: checks license status for self-hosted deployments.
|
||||
|
||||
For self-hosted, looks up license metadata and overrides application_status
|
||||
if the license is missing or indicates a problem (expired, grace period, etc.).
|
||||
if the license indicates GATED_ACCESS (fully expired).
|
||||
|
||||
For multi-tenant (cloud), the settings already have the correct status
|
||||
from the control plane, so no override is needed.
|
||||
@@ -43,11 +73,10 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if metadata and metadata.status in _GATED_STATUSES:
|
||||
if metadata and metadata.status == _BLOCKING_STATUS:
|
||||
settings.application_status = metadata.status
|
||||
elif not metadata:
|
||||
# No license = gated access for self-hosted EE
|
||||
settings.application_status = ApplicationStatus.GATED_ACCESS
|
||||
# No license = user hasn't purchased yet, allow access for upgrade flow
|
||||
# GRACE_PERIOD/PAYMENT_REMINDER don't block - they're for notifications
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata for settings: {e}")
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from fastapi import APIRouter
|
||||
from ee.onyx.server.tenants.admin_api import router as admin_router
|
||||
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
|
||||
from ee.onyx.server.tenants.billing_api import router as billing_router
|
||||
from ee.onyx.server.tenants.proxy import router as proxy_router
|
||||
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
|
||||
from ee.onyx.server.tenants.tenant_management_api import (
|
||||
router as tenant_management_router,
|
||||
@@ -22,3 +23,4 @@ router.include_router(billing_router)
|
||||
router.include_router(team_membership_router)
|
||||
router.include_router(tenant_management_router)
|
||||
router.include_router(user_invitations_router)
|
||||
router.include_router(proxy_router)
|
||||
|
||||
@@ -1,3 +1,24 @@
|
||||
"""Billing API endpoints for cloud multi-tenant deployments.
|
||||
|
||||
DEPRECATED: These /tenants/* billing endpoints are being replaced by /admin/billing/*
|
||||
which provides a unified API for both self-hosted and cloud deployments.
|
||||
|
||||
TODO(ENG-3533): Migrate frontend to use /admin/billing/* endpoints and remove this file.
|
||||
https://linear.app/onyx-app/issue/ENG-3533/migrate-tenantsbilling-adminbilling
|
||||
|
||||
Current endpoints to migrate:
|
||||
- GET /tenants/billing-information -> GET /admin/billing/information
|
||||
- POST /tenants/create-customer-portal-session -> POST /admin/billing/portal-session
|
||||
- POST /tenants/create-subscription-session -> POST /admin/billing/checkout-session
|
||||
- GET /tenants/stripe-publishable-key -> (keep as-is, shared endpoint)
|
||||
|
||||
Note: /tenants/product-gating/* endpoints are control-plane-to-data-plane calls
|
||||
and are NOT part of this migration - they stay here.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
@@ -12,11 +33,14 @@ from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
from ee.onyx.server.tenants.models import StripePublishableKeyResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
@@ -26,6 +50,10 @@ logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
# Cache for Stripe publishable key to avoid hitting S3 on every request
|
||||
_stripe_publishable_key_cache: str | None = None
|
||||
_stripe_key_lock = asyncio.Lock()
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
@@ -80,11 +108,7 @@ async def billing_information(
|
||||
async def create_customer_portal_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> dict:
|
||||
"""
|
||||
Create a Stripe customer portal session via the control plane.
|
||||
NOTE: This is currently only used for multi-tenant (cloud) deployments.
|
||||
Self-hosted proxy endpoints will be added in a future phase.
|
||||
"""
|
||||
"""Create a Stripe customer portal session via the control plane."""
|
||||
tenant_id = get_current_tenant_id()
|
||||
return_url = f"{WEB_DOMAIN}/admin/billing"
|
||||
|
||||
@@ -113,3 +137,67 @@ async def create_subscription_session(
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create subscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
"""
|
||||
Fetch the Stripe publishable key.
|
||||
Priority: env var override (for testing) > S3 bucket (production).
|
||||
This endpoint is public (no auth required) since publishable keys are safe to expose.
|
||||
The key is cached in memory to avoid hitting S3 on every request.
|
||||
"""
|
||||
global _stripe_publishable_key_cache
|
||||
|
||||
# Fast path: return cached value without lock
|
||||
if _stripe_publishable_key_cache:
|
||||
return StripePublishableKeyResponse(
|
||||
publishable_key=_stripe_publishable_key_cache
|
||||
)
|
||||
|
||||
# Use lock to prevent concurrent S3 requests
|
||||
async with _stripe_key_lock:
|
||||
# Double-check after acquiring lock (another request may have populated cache)
|
||||
if _stripe_publishable_key_cache:
|
||||
return StripePublishableKeyResponse(
|
||||
publishable_key=_stripe_publishable_key_cache
|
||||
)
|
||||
|
||||
# Check for env var override first (for local testing with pk_test_* keys)
|
||||
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(STRIPE_PUBLISHABLE_KEY_URL)
|
||||
response.raise_for_status()
|
||||
key = response.text.strip()
|
||||
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
)
|
||||
|
||||
@@ -105,3 +105,7 @@ class PendingUserSnapshot(BaseModel):
|
||||
|
||||
class ApproveUserRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class StripePublishableKeyResponse(BaseModel):
|
||||
publishable_key: str
|
||||
|
||||
485
backend/ee/onyx/server/tenants/proxy.py
Normal file
485
backend/ee/onyx/server/tenants/proxy.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""Proxy endpoints for billing operations.
|
||||
|
||||
These endpoints run on the CLOUD DATA PLANE (cloud.onyx.app) and serve as a proxy
|
||||
for self-hosted instances to reach the control plane.
|
||||
|
||||
Flow:
|
||||
Self-hosted backend → Cloud DP /proxy/* (license auth) → Control plane (JWT auth)
|
||||
|
||||
Self-hosted instances call these endpoints with their license in the Authorization
|
||||
header. The cloud data plane validates the license signature and forwards the
|
||||
request to the control plane using JWT authentication.
|
||||
|
||||
Auth levels by endpoint:
|
||||
- /create-checkout-session: No auth (new customer) or expired license OK (renewal)
|
||||
- /claim-license: Session ID based (one-time after Stripe payment)
|
||||
- /create-customer-portal-session: Expired license OK (need portal to fix payment)
|
||||
- /billing-information: Valid license required
|
||||
- /license/{tenant_id}: Valid license required
|
||||
- /seats/update: Valid license required
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Header
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import update_license_cache
|
||||
from ee.onyx.db.license import upsert_license
|
||||
from ee.onyx.server.billing.models import SeatUpdateRequest
|
||||
from ee.onyx.server.billing.models import SeatUpdateResponse
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.utils.license import is_license_valid
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/proxy")
|
||||
|
||||
|
||||
def _check_license_enforcement_enabled() -> None:
|
||||
"""Ensure LICENSE_ENFORCEMENT_ENABLED is true (proxy endpoints only work on cloud DP)."""
|
||||
if not LICENSE_ENFORCEMENT_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail="Proxy endpoints are only available on cloud data plane",
|
||||
)
|
||||
|
||||
|
||||
def _extract_license_from_header(
|
||||
authorization: str | None,
|
||||
required: bool = True,
|
||||
) -> str | None:
|
||||
"""Extract license data from Authorization header.
|
||||
|
||||
Self-hosted instances authenticate to these proxy endpoints by sending their
|
||||
license as a Bearer token: `Authorization: Bearer <base64-encoded-license>`.
|
||||
|
||||
We use the Bearer scheme (RFC 6750) because:
|
||||
1. It's the standard HTTP auth scheme for token-based authentication
|
||||
2. The license blob is cryptographically signed (RSA), so it's self-validating
|
||||
3. No other auth schemes (Basic, Digest, etc.) are supported for license auth
|
||||
|
||||
The license data is the base64-encoded signed blob that contains tenant_id,
|
||||
seats, expiration, etc. We verify the signature to authenticate the caller.
|
||||
|
||||
Args:
|
||||
authorization: The Authorization header value (e.g., "Bearer <license>")
|
||||
required: If True, raise 401 when header is missing/invalid
|
||||
|
||||
Returns:
|
||||
License data string (base64-encoded), or None if not required and missing
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if required and header is missing/invalid
|
||||
"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
if required:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Missing or invalid authorization header"
|
||||
)
|
||||
return None
|
||||
|
||||
return authorization.split(" ", 1)[1]
|
||||
|
||||
|
||||
def verify_license_auth(
|
||||
license_data: str,
|
||||
allow_expired: bool = False,
|
||||
) -> LicensePayload:
|
||||
"""Verify license signature and optionally check expiry.
|
||||
|
||||
Args:
|
||||
license_data: Base64-encoded signed license blob
|
||||
allow_expired: If True, accept expired licenses (for renewal flows)
|
||||
|
||||
Returns:
|
||||
LicensePayload if valid
|
||||
|
||||
Raises:
|
||||
HTTPException: If license is invalid or expired (when not allowed)
|
||||
"""
|
||||
_check_license_enforcement_enabled()
|
||||
|
||||
try:
|
||||
payload = verify_license_signature(license_data)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=f"Invalid license: {e}")
|
||||
|
||||
if not allow_expired and not is_license_valid(payload):
|
||||
raise HTTPException(status_code=401, detail="License has expired")
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
async def get_license_payload(
|
||||
authorization: str | None = Header(None, alias="Authorization"),
|
||||
) -> LicensePayload:
|
||||
"""Dependency: Require valid (non-expired) license.
|
||||
|
||||
Used for endpoints that require an active subscription.
|
||||
"""
|
||||
license_data = _extract_license_from_header(authorization, required=True)
|
||||
# license_data is guaranteed non-None when required=True
|
||||
assert license_data is not None
|
||||
return verify_license_auth(license_data, allow_expired=False)
|
||||
|
||||
|
||||
async def get_license_payload_allow_expired(
    authorization: str | None = Header(None, alias="Authorization"),
) -> LicensePayload:
    """FastAPI dependency requiring a correctly signed license; expiry is OK.

    Used for endpoints needed to fix payment issues (portal, renewal checkout).
    """
    blob = _extract_license_from_header(authorization, required=True)
    # required=True means a missing/invalid header already raised 401,
    # so the blob is always present here; the assert narrows the type.
    assert blob is not None
    return verify_license_auth(blob, allow_expired=True)
|
||||
|
||||
|
||||
async def get_optional_license_payload(
    authorization: str | None = Header(None, alias="Authorization"),
) -> LicensePayload | None:
    """FastAPI dependency for optional license auth.

    New customers (e.g. during checkout) have no license yet, so a missing
    header yields None instead of a 401. When a license is present, its
    signature is verified; expiry is tolerated to support renewal flows.
    """
    _check_license_enforcement_enabled()

    blob = _extract_license_from_header(authorization, required=False)
    return None if blob is None else verify_license_auth(blob, allow_expired=True)
|
||||
|
||||
|
||||
async def forward_to_control_plane(
    method: str,
    path: str,
    body: dict | None = None,
    params: dict | None = None,
) -> dict:
    """Forward a request to the control plane with proper authentication.

    Args:
        method: HTTP method, "GET" or "POST"
        path: Control plane path, appended to CONTROL_PLANE_API_BASE_URL
        body: JSON body for POST requests
        params: Query parameters for GET requests

    Returns:
        The control plane's JSON response as a dict.

    Raises:
        HTTPException: Mirrors the control plane's status and detail for HTTP
            errors, or 502 when the control plane is unreachable.
        ValueError: If an unsupported HTTP method is given.
    """
    token = generate_data_plane_token()
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }

    url = f"{CONTROL_PLANE_API_BASE_URL}{path}"

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            if method == "GET":
                response = await client.get(url, headers=headers, params=params)
            elif method == "POST":
                response = await client.post(url, headers=headers, json=body)
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")

            response.raise_for_status()
            return response.json()

    except httpx.HTTPStatusError as e:
        status_code = e.response.status_code
        detail = "Control plane request failed"
        try:
            error_data = e.response.json()
            detail = error_data.get("detail", detail)
        except Exception:
            # Error response body was not JSON; keep the generic detail.
            pass
        logger.error(f"Control plane returned {status_code}: {detail}")
        # Chain the httpx error so the original failure survives in the
        # traceback (flake8 B904).
        raise HTTPException(status_code=status_code, detail=detail) from e
    except httpx.RequestError as e:
        logger.exception("Failed to connect to control plane")
        raise HTTPException(
            status_code=502, detail="Failed to connect to control plane"
        ) from e
|
||||
|
||||
|
||||
def fetch_and_store_license(tenant_id: str, license_data: str) -> None:
    """Store license in database and update Redis cache.

    Args:
        tenant_id: The tenant ID
        license_data: Base64-encoded signed license blob
    """
    try:
        # Refuse to persist anything that does not carry a valid signature.
        verified = verify_license_signature(license_data)

        # Persist to the tenant's own schema...
        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
            upsert_license(db_session, license_data)

        # ...then refresh the Redis cache so readers see the new license.
        update_license_cache(
            verified,
            source=LicenseSource.AUTO_FETCH,
            tenant_id=tenant_id,
        )

    except ValueError as e:
        logger.error(f"Failed to verify license: {e}")
        raise
    except Exception:
        logger.exception("Failed to store license")
        raise
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CreateCheckoutSessionRequest(BaseModel):
    """Request body for creating a Stripe checkout session via the control plane."""

    # Subscription billing cadence; defaults to monthly.
    billing_period: Literal["monthly", "annual"] = "monthly"
    # Optional pre-filled customer email for the checkout page.
    email: str | None = None
    # Redirect URL after successful checkout - self-hosted passes their instance URL
    redirect_url: str | None = None
    # Cancel URL when user exits checkout - returns to upgrade page
    cancel_url: str | None = None
|
||||
|
||||
|
||||
class CreateCheckoutSessionResponse(BaseModel):
    """Response carrying the hosted Stripe checkout URL."""

    url: str
|
||||
|
||||
|
||||
@router.post("/create-checkout-session")
|
||||
async def proxy_create_checkout_session(
|
||||
request_body: CreateCheckoutSessionRequest,
|
||||
license_payload: LicensePayload | None = Depends(get_optional_license_payload),
|
||||
) -> CreateCheckoutSessionResponse:
|
||||
"""Proxy checkout session creation to control plane.
|
||||
|
||||
Auth: Optional license (new customers don't have one yet).
|
||||
If license provided, expired is OK (for renewals).
|
||||
"""
|
||||
# license_payload is None for new customers who don't have a license yet.
|
||||
# In that case, tenant_id is omitted from the request body and the control
|
||||
# plane will create a new tenant during checkout completion.
|
||||
tenant_id = license_payload.tenant_id if license_payload else None
|
||||
|
||||
body: dict = {
|
||||
"billing_period": request_body.billing_period,
|
||||
}
|
||||
if tenant_id:
|
||||
body["tenant_id"] = tenant_id
|
||||
if request_body.email:
|
||||
body["email"] = request_body.email
|
||||
if request_body.redirect_url:
|
||||
body["redirect_url"] = request_body.redirect_url
|
||||
if request_body.cancel_url:
|
||||
body["cancel_url"] = request_body.cancel_url
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST", "/create-checkout-session", body=body
|
||||
)
|
||||
return CreateCheckoutSessionResponse(url=result["url"])
|
||||
|
||||
|
||||
class ClaimLicenseRequest(BaseModel):
    """Request body for claiming a license after Stripe checkout."""

    # Stripe checkout session ID used as a one-time claim token.
    session_id: str
|
||||
|
||||
|
||||
class ClaimLicenseResponse(BaseModel):
    """Response returned after a license has been successfully claimed."""

    tenant_id: str
    # Base64-encoded signed license blob.
    license: str
    message: str | None = None
|
||||
|
||||
|
||||
@router.post("/claim-license")
|
||||
async def proxy_claim_license(
|
||||
request_body: ClaimLicenseRequest,
|
||||
) -> ClaimLicenseResponse:
|
||||
"""Claim a license after successful Stripe checkout.
|
||||
|
||||
Auth: Session ID based (one-time use after payment).
|
||||
The control plane verifies the session_id is valid and unclaimed.
|
||||
|
||||
Returns the license to the caller. For self-hosted instances, they will
|
||||
store the license locally. The cloud DP doesn't need to store it.
|
||||
"""
|
||||
_check_license_enforcement_enabled()
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST",
|
||||
"/claim-license",
|
||||
body={"session_id": request_body.session_id},
|
||||
)
|
||||
|
||||
tenant_id = result.get("tenant_id")
|
||||
license_data = result.get("license")
|
||||
|
||||
if not tenant_id or not license_data:
|
||||
logger.error(f"Control plane returned incomplete claim response: {result}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Control plane returned incomplete license data",
|
||||
)
|
||||
|
||||
return ClaimLicenseResponse(
|
||||
tenant_id=tenant_id,
|
||||
license=license_data,
|
||||
message="License claimed successfully",
|
||||
)
|
||||
|
||||
|
||||
class CreateCustomerPortalSessionRequest(BaseModel):
    """Request body for opening a Stripe customer portal session."""

    # Where the portal sends the user back when they are done.
    return_url: str | None = None
|
||||
|
||||
|
||||
class CreateCustomerPortalSessionResponse(BaseModel):
    """Response carrying the hosted customer portal URL."""

    url: str
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def proxy_create_customer_portal_session(
|
||||
request_body: CreateCustomerPortalSessionRequest | None = None,
|
||||
license_payload: LicensePayload = Depends(get_license_payload_allow_expired),
|
||||
) -> CreateCustomerPortalSessionResponse:
|
||||
"""Proxy customer portal session creation to control plane.
|
||||
|
||||
Auth: License required, expired OK (need portal to fix payment issues).
|
||||
"""
|
||||
# tenant_id is a required field in LicensePayload (Pydantic validates this),
|
||||
# but we check explicitly for defense in depth
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
tenant_id = license_payload.tenant_id
|
||||
|
||||
body: dict = {"tenant_id": tenant_id}
|
||||
if request_body and request_body.return_url:
|
||||
body["return_url"] = request_body.return_url
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST", "/create-customer-portal-session", body=body
|
||||
)
|
||||
return CreateCustomerPortalSessionResponse(url=result["url"])
|
||||
|
||||
|
||||
class BillingInformationResponse(BaseModel):
    """Billing/subscription state for a tenant, as reported by the control plane.

    All optional fields may be absent from the control plane response;
    timestamps are passed through as strings.
    """

    tenant_id: str
    status: str | None = None
    plan_type: str | None = None
    seats: int | None = None
    billing_period: str | None = None
    current_period_start: str | None = None
    current_period_end: str | None = None
    cancel_at_period_end: bool = False
    canceled_at: str | None = None
    trial_start: str | None = None
    trial_end: str | None = None
    payment_method_enabled: bool = False
    stripe_subscription_id: str | None = None
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def proxy_billing_information(
|
||||
license_payload: LicensePayload = Depends(get_license_payload),
|
||||
) -> BillingInformationResponse:
|
||||
"""Proxy billing information request to control plane.
|
||||
|
||||
Auth: Valid (non-expired) license required.
|
||||
"""
|
||||
# tenant_id is a required field in LicensePayload (Pydantic validates this),
|
||||
# but we check explicitly for defense in depth
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
tenant_id = license_payload.tenant_id
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"GET", "/billing-information", params={"tenant_id": tenant_id}
|
||||
)
|
||||
# Add tenant_id from license if not in response (control plane may not include it)
|
||||
if "tenant_id" not in result:
|
||||
result["tenant_id"] = tenant_id
|
||||
return BillingInformationResponse(**result)
|
||||
|
||||
|
||||
class LicenseFetchResponse(BaseModel):
    """Response carrying a freshly fetched license for a tenant."""

    # Base64-encoded signed license blob.
    license: str
    tenant_id: str
|
||||
|
||||
|
||||
@router.get("/license/{tenant_id}")
|
||||
async def proxy_license_fetch(
|
||||
tenant_id: str,
|
||||
license_payload: LicensePayload = Depends(get_license_payload),
|
||||
) -> LicenseFetchResponse:
|
||||
"""Proxy license fetch to control plane.
|
||||
|
||||
Auth: Valid license required.
|
||||
The tenant_id in path must match the authenticated tenant.
|
||||
"""
|
||||
# tenant_id is a required field in LicensePayload (Pydantic validates this),
|
||||
# but we check explicitly for defense in depth
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
if tenant_id != license_payload.tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Cannot fetch license for a different tenant",
|
||||
)
|
||||
|
||||
result = await forward_to_control_plane("GET", f"/license/{tenant_id}")
|
||||
|
||||
# Auto-store the refreshed license
|
||||
license_data = result.get("license")
|
||||
if not license_data:
|
||||
logger.error(f"Control plane returned incomplete license response: {result}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Control plane returned incomplete license data",
|
||||
)
|
||||
|
||||
fetch_and_store_license(tenant_id, license_data)
|
||||
|
||||
return LicenseFetchResponse(license=license_data, tenant_id=tenant_id)
|
||||
|
||||
|
||||
@router.post("/seats/update")
|
||||
async def proxy_seat_update(
|
||||
request_body: SeatUpdateRequest,
|
||||
license_payload: LicensePayload = Depends(get_license_payload),
|
||||
) -> SeatUpdateResponse:
|
||||
"""Proxy seat update to control plane.
|
||||
|
||||
Auth: Valid (non-expired) license required.
|
||||
Handles Stripe proration and license regeneration.
|
||||
"""
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
tenant_id = license_payload.tenant_id
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST",
|
||||
"/seats/update",
|
||||
body={
|
||||
"tenant_id": tenant_id,
|
||||
"new_seat_count": request_body.new_seat_count,
|
||||
},
|
||||
)
|
||||
|
||||
return SeatUpdateResponse(
|
||||
success=result.get("success", False),
|
||||
current_seats=result.get("current_seats", 0),
|
||||
used_seats=result.get("used_seats", 0),
|
||||
message=result.get("message"),
|
||||
)
|
||||
@@ -1,6 +1,7 @@
|
||||
from fastapi_users import exceptions
|
||||
from sqlalchemy import select
|
||||
|
||||
from ee.onyx.db.license import invalidate_license_cache
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import get_pending_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
@@ -47,6 +48,8 @@ def get_tenant_id_for_email(email: str) -> str:
|
||||
mapping.active = True
|
||||
db_session.commit()
|
||||
tenant_id = mapping.tenant_id
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error getting tenant id for email {email}: {e}")
|
||||
raise exceptions.UserNotExists()
|
||||
@@ -70,49 +73,104 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
"""
|
||||
Add users to a tenant with proper transaction handling.
|
||||
Checks if users already have a tenant mapping to avoid duplicates.
|
||||
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
|
||||
|
||||
If a user already has an active mapping to a different tenant, they receive
|
||||
an inactive mapping (invitation) to this tenant. They can accept the
|
||||
invitation later to switch tenants.
|
||||
|
||||
Raises:
|
||||
HTTPException: 402 if adding active users would exceed seat limit
|
||||
"""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.db.license import check_seat_availability
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant as get_tenant_session
|
||||
|
||||
unique_emails = set(emails)
|
||||
if not unique_emails:
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
# Start a transaction
|
||||
db_session.begin()
|
||||
|
||||
for email in emails:
|
||||
# Check if the user already has a mapping to this tenant
|
||||
existing_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.with_for_update()
|
||||
.first()
|
||||
# Batch query 1: Get all existing mappings for these emails to this tenant
|
||||
# Lock rows to prevent concurrent modifications
|
||||
existing_mappings = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email.in_(unique_emails),
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.with_for_update()
|
||||
.all()
|
||||
)
|
||||
emails_with_mapping = {m.email for m in existing_mappings}
|
||||
|
||||
# If user already has an active mapping, add this one as inactive
|
||||
if not existing_mapping:
|
||||
# Check if the user already has an active mapping to any tenant
|
||||
has_active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
# Batch query 2: Get all active mappings for these emails (any tenant)
|
||||
active_mappings = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email.in_(unique_emails),
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.all()
|
||||
)
|
||||
emails_with_active_mapping = {m.email for m in active_mappings}
|
||||
|
||||
db_session.add(
|
||||
UserTenantMapping(
|
||||
email=email,
|
||||
tenant_id=tenant_id,
|
||||
active=False if has_active_mapping else True,
|
||||
)
|
||||
# Determine which users will consume a new seat.
|
||||
# Users with active mappings elsewhere get INACTIVE mappings (invitations)
|
||||
# and don't consume seats until they accept. Only users without any active
|
||||
# mapping will get an ACTIVE mapping and consume a seat immediately.
|
||||
emails_consuming_seats = {
|
||||
email
|
||||
for email in unique_emails
|
||||
if email not in emails_with_mapping
|
||||
and email not in emails_with_active_mapping
|
||||
}
|
||||
|
||||
# Check seat availability inside the transaction to prevent race conditions.
|
||||
# Note: ALL users in unique_emails still get added below - this check only
|
||||
# validates we have capacity for users who will consume seats immediately.
|
||||
if emails_consuming_seats:
|
||||
with get_tenant_session(tenant_id=tenant_id) as tenant_session:
|
||||
result = check_seat_availability(
|
||||
tenant_session,
|
||||
seats_needed=len(emails_consuming_seats),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
if not result.available:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=result.error_message or "Seat limit exceeded",
|
||||
)
|
||||
|
||||
# Add mappings for emails that don't already have one to this tenant
|
||||
for email in unique_emails:
|
||||
if email in emails_with_mapping:
|
||||
continue
|
||||
|
||||
# Create mapping: inactive if user belongs to another tenant (invitation),
|
||||
# active otherwise
|
||||
db_session.add(
|
||||
UserTenantMapping(
|
||||
email=email,
|
||||
tenant_id=tenant_id,
|
||||
active=email not in emails_with_active_mapping,
|
||||
)
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
db_session.commit()
|
||||
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
|
||||
except HTTPException:
|
||||
db_session.rollback()
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(f"Failed to add users to tenant {tenant_id}")
|
||||
db_session.rollback()
|
||||
@@ -135,6 +193,9 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
db_session.delete(mapping)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
|
||||
@@ -149,6 +210,9 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
|
||||
|
||||
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
@@ -177,6 +241,9 @@ def approve_user_invite(email: str, tenant_id: str) -> None:
|
||||
db_session.add(new_mapping)
|
||||
db_session.commit()
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
|
||||
# Also remove the user from pending users list
|
||||
# Remove from pending users
|
||||
pending_users = get_pending_users()
|
||||
@@ -195,19 +262,42 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Accept an invitation to join a tenant.
|
||||
This activates the user's mapping to the tenant.
|
||||
|
||||
Raises:
|
||||
HTTPException: 402 if accepting would exceed seat limit
|
||||
"""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.db.license import check_seat_availability
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
# First check if there's an active mapping for this user and tenant
|
||||
# Lock the user's mappings first to prevent race conditions.
|
||||
# This ensures no concurrent request can modify this user's mappings
|
||||
# while we check seats and activate.
|
||||
active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
|
||||
# Check seat availability within the same logical operation.
|
||||
# Note: This queries fresh data from DB, not cache.
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
|
||||
result = check_seat_availability(
|
||||
tenant_session, seats_needed=1, tenant_id=tenant_id
|
||||
)
|
||||
if not result.available:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=result.error_message or "Seat limit exceeded",
|
||||
)
|
||||
|
||||
# If an active mapping exists, delete it
|
||||
if active_mapping:
|
||||
db_session.delete(active_mapping)
|
||||
@@ -237,6 +327,9 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
|
||||
mapping.active = True
|
||||
db_session.commit()
|
||||
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
else:
|
||||
logger.warning(
|
||||
f"No invitation found for user {email} in tenant {tenant_id}"
|
||||
@@ -297,16 +390,41 @@ def deny_user_invite(email: str, tenant_id: str) -> None:
|
||||
|
||||
def get_tenant_count(tenant_id: str) -> int:
|
||||
"""
|
||||
Get the number of active users for this tenant
|
||||
Get the number of active users for this tenant.
|
||||
|
||||
A user counts toward the seat count if:
|
||||
1. They have an active mapping to this tenant (UserTenantMapping.active == True)
|
||||
2. AND the User is active (User.is_active == True)
|
||||
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
"""
|
||||
from onyx.db.models import User
|
||||
|
||||
# First get all emails with active mappings to this tenant
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Count the number of active users for this tenant
|
||||
user_count = (
|
||||
db_session.query(UserTenantMapping)
|
||||
active_mapping_emails = (
|
||||
db_session.query(UserTenantMapping.email)
|
||||
.filter(
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.all()
|
||||
)
|
||||
emails = [email for (email,) in active_mapping_emails]
|
||||
|
||||
if not emails:
|
||||
return 0
|
||||
|
||||
# Now count how many of those users are actually active in the tenant's User table
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
user_count = (
|
||||
db_session.query(User)
|
||||
.filter(
|
||||
User.email.in_(emails), # type: ignore
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from pathlib import Path
|
||||
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
@@ -19,21 +20,27 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# RSA-4096 Public Key for license verification
|
||||
# Load from environment variable - key is generated on the control plane
|
||||
# In production, inject via Kubernetes secrets or secrets manager
|
||||
LICENSE_PUBLIC_KEY_PEM = os.environ.get("LICENSE_PUBLIC_KEY_PEM", "")
|
||||
# Path to the license public key file
|
||||
_LICENSE_PUBLIC_KEY_PATH = (
|
||||
Path(__file__).parent.parent.parent.parent / "keys" / "license_public_key.pem"
|
||||
)
|
||||
|
||||
|
||||
def _get_public_key() -> RSAPublicKey:
|
||||
"""Load the public key from environment variable."""
|
||||
if not LICENSE_PUBLIC_KEY_PEM:
|
||||
raise ValueError(
|
||||
"LICENSE_PUBLIC_KEY_PEM environment variable not set. "
|
||||
"License verification requires the control plane public key."
|
||||
)
|
||||
key = serialization.load_pem_public_key(LICENSE_PUBLIC_KEY_PEM.encode())
|
||||
"""Load the public key from file, with env var override."""
|
||||
# Allow env var override for flexibility
|
||||
key_pem = os.environ.get("LICENSE_PUBLIC_KEY_PEM")
|
||||
|
||||
if not key_pem:
|
||||
# Read from file
|
||||
if not _LICENSE_PUBLIC_KEY_PATH.exists():
|
||||
raise ValueError(
|
||||
f"License public key not found at {_LICENSE_PUBLIC_KEY_PATH}. "
|
||||
"License verification requires the control plane public key."
|
||||
)
|
||||
key_pem = _LICENSE_PUBLIC_KEY_PATH.read_text()
|
||||
|
||||
key = serialization.load_pem_public_key(key_pem.encode())
|
||||
if not isinstance(key, RSAPublicKey):
|
||||
raise ValueError("Expected RSA public key")
|
||||
return key
|
||||
@@ -53,17 +60,21 @@ def verify_license_signature(license_data: str) -> LicensePayload:
|
||||
ValueError: If license data is invalid or signature verification fails
|
||||
"""
|
||||
try:
|
||||
# Decode the license data
|
||||
decoded = json.loads(base64.b64decode(license_data))
|
||||
|
||||
# Parse into LicenseData to validate structure
|
||||
license_obj = LicenseData(**decoded)
|
||||
|
||||
payload_json = json.dumps(
|
||||
license_obj.payload.model_dump(mode="json"), sort_keys=True
|
||||
)
|
||||
# IMPORTANT: Use the ORIGINAL payload JSON for signature verification,
|
||||
# not re-serialized through Pydantic. Pydantic may format fields differently
|
||||
# (e.g., datetime "+00:00" vs "Z") which would break signature verification.
|
||||
original_payload = decoded.get("payload", {})
|
||||
payload_json = json.dumps(original_payload, sort_keys=True)
|
||||
signature_bytes = base64.b64decode(license_obj.signature)
|
||||
|
||||
# Verify signature using PSS padding (modern standard)
|
||||
public_key = _get_public_key()
|
||||
|
||||
public_key.verify(
|
||||
signature_bytes,
|
||||
payload_json.encode(),
|
||||
@@ -77,16 +88,18 @@ def verify_license_signature(license_data: str) -> LicensePayload:
|
||||
return license_obj.payload
|
||||
|
||||
except InvalidSignature:
|
||||
logger.error("License signature verification failed")
|
||||
logger.error("[verify_license] FAILED: Signature verification failed")
|
||||
raise ValueError("Invalid license signature")
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Failed to decode license JSON")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"[verify_license] FAILED: JSON decode error: {e}")
|
||||
raise ValueError("Invalid license format: not valid JSON")
|
||||
except (ValueError, KeyError, TypeError) as e:
|
||||
logger.error(f"License data validation error: {type(e).__name__}")
|
||||
raise ValueError(f"Invalid license format: {type(e).__name__}")
|
||||
logger.error(
|
||||
f"[verify_license] FAILED: Validation error: {type(e).__name__}: {e}"
|
||||
)
|
||||
raise ValueError(f"Invalid license format: {type(e).__name__}: {e}")
|
||||
except Exception:
|
||||
logger.exception("Unexpected error during license verification")
|
||||
logger.exception("[verify_license] FAILED: Unexpected error")
|
||||
raise ValueError("License verification failed: unexpected error")
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from posthog import Posthog
|
||||
|
||||
from ee.onyx.configs.app_configs import MARKETING_POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_DEBUG_LOGS_ENABLED
|
||||
from ee.onyx.configs.app_configs import POSTHOG_HOST
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -20,7 +21,7 @@ def posthog_on_error(error: Any, items: Any) -> None:
|
||||
posthog = Posthog(
|
||||
project_api_key=POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=True,
|
||||
debug=POSTHOG_DEBUG_LOGS_ENABLED,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
|
||||
@@ -33,7 +34,7 @@ if MARKETING_POSTHOG_API_KEY:
|
||||
marketing_posthog = Posthog(
|
||||
project_api_key=MARKETING_POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=True,
|
||||
debug=POSTHOG_DEBUG_LOGS_ENABLED,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
|
||||
|
||||
14
backend/keys/license_public_key.pem
Normal file
14
backend/keys/license_public_key.pem
Normal file
@@ -0,0 +1,14 @@
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA5DpchQujdxjCwpc4/RQP
|
||||
Hej6rc3SS/5ENCXL0I8NAfMogel0fqG6PKRhonyEh/Bt3P4q18y8vYzAShwf4b6Q
|
||||
aS0WwshbvnkjyWlsK0BY4HLBKPkTpes7kaz8MwmPZDeelvGJ7SNv3FvyJR4QsoSQ
|
||||
GSoB5iTH7hi63TjzdxtckkXoNG+GdVd/koxVDUv2uWcAoWIFTTcbKWyuq2SS/5Sf
|
||||
xdVaIArqfAhLpnNbnM9OS7lZ1xP+29ZXpHxDoeluz35tJLMNBYn9u0y+puo1kW1E
|
||||
TOGizlAq5kmEMsTJ55e9ZuyIV3gZAUaUKe8CxYJPkOGt0Gj6e1jHoHZCBJmaq97Y
|
||||
stKj//84HNBzajaryEZuEfRecJ94ANEjkD8u9cGmW+9VxRe5544zWguP5WMT/nv1
|
||||
0Q+jkOBW2hkY5SS0Rug4cblxiB7bDymWkaX6+sC0VWd5g6WXp36EuP2T0v3mYuHU
|
||||
GDEiWbD44ToREPVwE/M07ny8qhLo/HYk2l8DKFt83hXe7ePBnyQdcsrVbQWOO1na
|
||||
j43OkoU5gOFyOkrk2RmmtCjA8jSnw+tGCTpRaRcshqoWC1MjZyU+8/kDteXNkmv9
|
||||
/B5VxzYSyX+abl7yAu5wLiUPW8l+mOazzWu0nPkmiA160ArxnRyxbGnmp4dUIrt5
|
||||
azYku4tQYLSsSabfhcpeiCsCAwEAAQ==
|
||||
-----END PUBLIC KEY-----
|
||||
@@ -97,10 +97,14 @@ def get_access_for_documents(
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
user should have access to a document if at least one entry in the document's ACL
|
||||
matches one entry in the returned set.
|
||||
"""Returns a list of ACL entries that the user has access to.
|
||||
|
||||
This is meant to be used downstream to filter out documents that the user
|
||||
does not have access to. The user should have access to a document if at
|
||||
least one entry in the document's ACL matches one entry in the returned set.
|
||||
|
||||
NOTE: These strings must be formatted in the same way as the output of
|
||||
DocumentAccess::to_acl.
|
||||
"""
|
||||
if user:
|
||||
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
|
||||
|
||||
@@ -125,9 +125,11 @@ class DocumentAccess(ExternalAccess):
|
||||
)
|
||||
|
||||
def to_acl(self) -> set[str]:
|
||||
# the acl's emitted by this function are prefixed by type
|
||||
# to get the native objects, access the member variables directly
|
||||
"""Converts the access state to a set of formatted ACL strings.
|
||||
|
||||
NOTE: When querying for documents, the supplied ACL filter strings must
|
||||
be formatted in the same way as this function.
|
||||
"""
|
||||
acl_set: set[str] = set()
|
||||
for user_email in self.user_emails:
|
||||
if user_email:
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Protocol
|
||||
from typing import Tuple
|
||||
@@ -1456,6 +1457,9 @@ def get_default_admin_user_emails_() -> list[str]:
|
||||
|
||||
|
||||
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
|
||||
STATE_TOKEN_LIFETIME_SECONDS = 3600
|
||||
CSRF_TOKEN_KEY = "csrftoken"
|
||||
CSRF_TOKEN_COOKIE_NAME = "fastapiusersoauthcsrf"
|
||||
|
||||
|
||||
class OAuth2AuthorizeResponse(BaseModel):
|
||||
@@ -1463,13 +1467,19 @@ class OAuth2AuthorizeResponse(BaseModel):
|
||||
|
||||
|
||||
def generate_state_token(
|
||||
data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
|
||||
data: Dict[str, str],
|
||||
secret: SecretType,
|
||||
lifetime_seconds: int = STATE_TOKEN_LIFETIME_SECONDS,
|
||||
) -> str:
|
||||
data["aud"] = STATE_TOKEN_AUDIENCE
|
||||
|
||||
return generate_jwt(data, secret, lifetime_seconds)
|
||||
|
||||
|
||||
def generate_csrf_token() -> str:
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
|
||||
def create_onyx_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
@@ -1498,6 +1508,13 @@ def get_oauth_router(
|
||||
redirect_url: Optional[str] = None,
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
*,
|
||||
csrf_token_cookie_name: str = CSRF_TOKEN_COOKIE_NAME,
|
||||
csrf_token_cookie_path: str = "/",
|
||||
csrf_token_cookie_domain: Optional[str] = None,
|
||||
csrf_token_cookie_secure: Optional[bool] = None,
|
||||
csrf_token_cookie_httponly: bool = True,
|
||||
csrf_token_cookie_samesite: Optional[Literal["lax", "strict", "none"]] = "lax",
|
||||
) -> APIRouter:
|
||||
"""Generate a router with the OAuth routes."""
|
||||
router = APIRouter()
|
||||
@@ -1514,6 +1531,9 @@ def get_oauth_router(
|
||||
route_name=callback_route_name,
|
||||
)
|
||||
|
||||
if csrf_token_cookie_secure is None:
|
||||
csrf_token_cookie_secure = WEB_DOMAIN.startswith("https")
|
||||
|
||||
@router.get(
|
||||
"/authorize",
|
||||
name=f"oauth:{oauth_client.name}.{backend.name}.authorize",
|
||||
@@ -1521,8 +1541,10 @@ def get_oauth_router(
|
||||
)
|
||||
async def authorize(
|
||||
request: Request,
|
||||
response: Response,
|
||||
redirect: bool = Query(False),
|
||||
scopes: List[str] = Query(None),
|
||||
) -> OAuth2AuthorizeResponse:
|
||||
) -> Response | OAuth2AuthorizeResponse:
|
||||
referral_source = request.cookies.get("referral_source", None)
|
||||
|
||||
if redirect_url is not None:
|
||||
@@ -1532,9 +1554,11 @@ def get_oauth_router(
|
||||
|
||||
next_url = request.query_params.get("next", "/")
|
||||
|
||||
csrf_token = generate_csrf_token()
|
||||
state_data: Dict[str, str] = {
|
||||
"next_url": next_url,
|
||||
"referral_source": referral_source or "default_referral",
|
||||
CSRF_TOKEN_KEY: csrf_token,
|
||||
}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
|
||||
@@ -1551,6 +1575,31 @@ def get_oauth_router(
|
||||
authorization_url, {"access_type": "offline", "prompt": "consent"}
|
||||
)
|
||||
|
||||
if redirect:
|
||||
redirect_response = RedirectResponse(authorization_url, status_code=302)
|
||||
redirect_response.set_cookie(
|
||||
key=csrf_token_cookie_name,
|
||||
value=csrf_token,
|
||||
max_age=STATE_TOKEN_LIFETIME_SECONDS,
|
||||
path=csrf_token_cookie_path,
|
||||
domain=csrf_token_cookie_domain,
|
||||
secure=csrf_token_cookie_secure,
|
||||
httponly=csrf_token_cookie_httponly,
|
||||
samesite=csrf_token_cookie_samesite,
|
||||
)
|
||||
return redirect_response
|
||||
|
||||
response.set_cookie(
|
||||
key=csrf_token_cookie_name,
|
||||
value=csrf_token,
|
||||
max_age=STATE_TOKEN_LIFETIME_SECONDS,
|
||||
path=csrf_token_cookie_path,
|
||||
domain=csrf_token_cookie_domain,
|
||||
secure=csrf_token_cookie_secure,
|
||||
httponly=csrf_token_cookie_httponly,
|
||||
samesite=csrf_token_cookie_samesite,
|
||||
)
|
||||
|
||||
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
@@ -1600,7 +1649,33 @@ def get_oauth_router(
|
||||
try:
|
||||
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
|
||||
except jwt.DecodeError:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=getattr(
|
||||
ErrorCode, "ACCESS_TOKEN_DECODE_ERROR", "ACCESS_TOKEN_DECODE_ERROR"
|
||||
),
|
||||
)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=getattr(
|
||||
ErrorCode,
|
||||
"ACCESS_TOKEN_ALREADY_EXPIRED",
|
||||
"ACCESS_TOKEN_ALREADY_EXPIRED",
|
||||
),
|
||||
)
|
||||
|
||||
cookie_csrf_token = request.cookies.get(csrf_token_cookie_name)
|
||||
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
|
||||
if (
|
||||
not cookie_csrf_token
|
||||
or not state_csrf_token
|
||||
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=getattr(ErrorCode, "OAUTH_INVALID_STATE", "OAUTH_INVALID_STATE"),
|
||||
)
|
||||
|
||||
next_url = state_data.get("next_url", "/")
|
||||
referral_source = state_data.get("referral_source", None)
|
||||
|
||||
98
backend/onyx/background/README.md
Normal file
98
backend/onyx/background/README.md
Normal file
@@ -0,0 +1,98 @@
|
||||
# Overview of Onyx Background Jobs
|
||||
|
||||
The background jobs take care of:
|
||||
1. Pulling/Indexing documents (from connectors)
|
||||
2. Updating document metadata (from connectors)
|
||||
3. Cleaning up checkpoints and logic around indexing work (indexing indexing checkpoints and index attempt metadata)
|
||||
4. Handling user uploaded files and deletions (from the Projects feature and uploads via the Chat)
|
||||
5. Reporting metrics on things like queue length for monitoring purposes
|
||||
|
||||
## Worker → Queue Mapping
|
||||
|
||||
| Worker | File | Queues |
|
||||
|--------|------|--------|
|
||||
| Primary | `apps/primary.py` | `celery` |
|
||||
| Light | `apps/light.py` | `vespa_metadata_sync`, `connector_deletion`, `doc_permissions_upsert`, `checkpoint_cleanup`, `index_attempt_cleanup` |
|
||||
| Heavy | `apps/heavy.py` | `connector_pruning`, `connector_doc_permissions_sync`, `connector_external_group_sync`, `csv_generation`, `sandbox` |
|
||||
| Docprocessing | `apps/docprocessing.py` | `docprocessing` |
|
||||
| Docfetching | `apps/docfetching.py` | `connector_doc_fetching` |
|
||||
| User File Processing | `apps/user_file_processing.py` | `user_file_processing`, `user_file_project_sync`, `user_file_delete` |
|
||||
| Monitoring | `apps/monitoring.py` | `monitoring` |
|
||||
| Background (consolidated) | `apps/background.py` | All queues above except `celery` |
|
||||
|
||||
## Non-Worker Apps
|
||||
| App | File | Purpose |
|
||||
|-----|------|---------|
|
||||
| **Beat** | `beat.py` | Celery beat scheduler with `DynamicTenantScheduler` that generates per-tenant periodic task schedules |
|
||||
| **Client** | `client.py` | Minimal app for task submission from non-worker processes (e.g., API server) |
|
||||
|
||||
### Shared Module
|
||||
`app_base.py` provides:
|
||||
- `TenantAwareTask` - Base task class that sets tenant context
|
||||
- Signal handlers for logging, cleanup, and lifecycle events
|
||||
- Readiness probes and health checks
|
||||
|
||||
|
||||
## Worker Details
|
||||
|
||||
### Primary (Coordinator and task dispatcher)
|
||||
It is the single worker which handles tasks from the default celery queue. It is a singleton worker ensured by the `PRIMARY_WORKER` Redis lock
|
||||
which it touches every `CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8` seconds (using Celery Bootsteps)
|
||||
|
||||
On startup:
|
||||
- waits for redis, postgres, document index to all be healthy
|
||||
- acquires the singleton lock
|
||||
- cleans all the redis states associated with background jobs
|
||||
- mark orphaned index attempts failed
|
||||
|
||||
Then it cycles through its tasks as scheduled by Celery Beat:
|
||||
|
||||
| Task | Frequency | Description |
|
||||
|------|-----------|-------------|
|
||||
| `check_for_indexing` | 15s | Scans for connectors needing indexing → dispatches to `DOCFETCHING` queue |
|
||||
| `check_for_vespa_sync_task` | 20s | Finds stale documents/document sets → dispatches sync tasks to `VESPA_METADATA_SYNC` queue |
|
||||
| `check_for_pruning` | 20s | Finds connectors due for pruning → dispatches to `CONNECTOR_PRUNING` queue |
|
||||
| `check_for_connector_deletion` | 20s | Processes deletion requests → dispatches to `CONNECTOR_DELETION` queue |
|
||||
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
|
||||
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
|
||||
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
|
||||
| `kombu_message_cleanup_task` | periodic | Cleans orphaned Kombu messages from DB (Kombu being the messaging framework used by Celery) |
|
||||
| `celery_beat_heartbeat` | 1m | Heartbeat for Beat watchdog |
|
||||
|
||||
Watchdog is a separate Python process managed by supervisord which runs alongside celery workers. It checks the ONYX_CELERY_BEAT_HEARTBEAT_KEY in
|
||||
Redis to ensure Celery Beat is not dead. Beat schedules the celery_beat_heartbeat for Primary to touch the key and share that it's still alive.
|
||||
See supervisord.conf for watchdog config.
|
||||
|
||||
|
||||
### Light
|
||||
Fast and short living tasks that are not resource intensive. High concurrency:
|
||||
Can have 24 concurrent workers, each with a prefetch of 8 for a total of 192 tasks in flight at once.
|
||||
|
||||
Tasks it handles:
|
||||
- Syncs access/permissions, document sets, boosts, hidden state
|
||||
- Deletes documents that are marked for deletion in Postgres
|
||||
- Cleanup of checkpoints and index attempts
|
||||
|
||||
|
||||
### Heavy
|
||||
Long running, resource intensive tasks, handles pruning and sandbox operations. Low concurrency - max concurrency of 4 with 1 prefetch.
|
||||
|
||||
Does not interact with the Document Index, it handles the syncs with external systems. Large volume API calls to handle pruning and fetching permissions, etc.
|
||||
|
||||
Generates CSV exports which may take a long time with significant data in Postgres.
|
||||
|
||||
Sandbox (new feature) for running Next.js, Python virtual env, OpenCode AI Agent, and access to knowledge files
|
||||
|
||||
|
||||
### Docprocessing, Docfetching, User File Processing
|
||||
Docprocessing and Docfetching are for indexing documents:
|
||||
- Docfetching runs connectors to pull documents from external APIs (Google Drive, Confluence, etc.), stores batches to file storage, and dispatches docprocessing tasks
|
||||
- Docprocessing retrieves batches, runs the indexing pipeline (chunking, embedding), and indexes into the Document Index
|
||||
User Files come from uploads directly via the input bar
|
||||
|
||||
|
||||
### Monitoring
|
||||
Observability and metrics collections:
|
||||
- Queue lengths, connector success/failure, lconnector latencies
|
||||
- Memory of supervisor managed processes (workers, beat, slack)
|
||||
- Cloud and multitenant specific monitorings
|
||||
@@ -26,10 +26,13 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.background.celery.celery_utils import make_probe_path
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
from onyx.document_index.opensearch.client import (
|
||||
wait_for_opensearch_with_timeout,
|
||||
)
|
||||
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
@@ -40,6 +43,7 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.tracing.braintrust_tracing import setup_braintrust_if_creds_available
|
||||
from onyx.utils.logger import ColoredFormatter
|
||||
from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import PlainFormatter
|
||||
@@ -234,6 +238,9 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
f"Multiprocessing selected start method: {multiprocessing.get_start_method()}"
|
||||
)
|
||||
|
||||
# Initialize Braintrust tracing in workers if credentials are available.
|
||||
setup_braintrust_if_creds_available()
|
||||
|
||||
|
||||
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for redis to become ready subject to a hardcoded timeout.
|
||||
@@ -516,15 +523,17 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Raises WorkerShutdown if the timeout is reached."""
|
||||
|
||||
if ENABLE_OPENSEARCH_FOR_ONYX:
|
||||
# TODO(andrei): Do some similar liveness checking for OpenSearch.
|
||||
return
|
||||
|
||||
if not wait_for_vespa_with_timeout():
|
||||
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
|
||||
msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
if not wait_for_opensearch_with_timeout():
|
||||
msg = "[OpenSearch] Readiness probe did not succeed within the timeout. Exiting..."
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
|
||||
# File for validating worker liveness
|
||||
class LivenessProbe(bootsteps.StartStopStep):
|
||||
|
||||
@@ -121,7 +121,6 @@ celery_app.autodiscover_tasks(
|
||||
[
|
||||
# Original background worker tasks
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
@@ -134,5 +133,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
# Docfetching worker tasks
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -98,5 +98,8 @@ for bootstep in base_bootsteps:
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
# Sandbox tasks (file sync, cleanup)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,109 +0,0 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery import Task
|
||||
from celery.apps.worker import Worker
|
||||
from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_process_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.kg_processing")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def on_task_postrun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
retval: Any | None = None,
|
||||
state: str | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME)
|
||||
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_shutdown.connect
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None:
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
|
||||
@signals.setup_logging.connect
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
base_bootsteps = app_base.get_bootsteps()
|
||||
for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
]
|
||||
)
|
||||
@@ -116,5 +116,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -318,12 +318,12 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.interfaces import SlimConnectorWithPermSync
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -32,10 +33,16 @@ PRUNING_CHECKPOINTED_BATCH_SIZE = 32
|
||||
|
||||
|
||||
def document_batch_to_ids(
|
||||
doc_batch: Iterator[list[Document]] | Iterator[list[SlimDocument]],
|
||||
doc_batch: (
|
||||
Iterator[list[Document | HierarchyNode]]
|
||||
| Iterator[list[SlimDocument | HierarchyNode]]
|
||||
),
|
||||
) -> Generator[set[str], None, None]:
|
||||
for doc_list in doc_batch:
|
||||
yield {doc.id for doc in doc_list}
|
||||
yield {
|
||||
doc.raw_node_id if isinstance(doc, HierarchyNode) else doc.id
|
||||
for doc in doc_list
|
||||
}
|
||||
|
||||
|
||||
def extract_ids_from_runnable_connector(
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
import onyx.background.celery.configs.base as shared_config
|
||||
from onyx.configs.app_configs import CELERY_WORKER_KG_PROCESSING_CONCURRENCY
|
||||
|
||||
broker_url = shared_config.broker_url
|
||||
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||
broker_pool_limit = shared_config.broker_pool_limit
|
||||
broker_transport_options = shared_config.broker_transport_options
|
||||
|
||||
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||
|
||||
result_backend = shared_config.result_backend
|
||||
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
worker_concurrency = CELERY_WORKER_KG_PROCESSING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
@@ -57,24 +57,6 @@ beat_task_templates: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-kg-processing",
|
||||
"task": OnyxCeleryTask.CHECK_KG_PROCESSING,
|
||||
"schedule": timedelta(seconds=60),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-kg-processing-clustering-only",
|
||||
"task": OnyxCeleryTask.CHECK_KG_PROCESSING_CLUSTERING_ONLY,
|
||||
"schedule": timedelta(seconds=600),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
@@ -129,6 +111,15 @@ beat_task_templates: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-hierarchy-fetching",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_HIERARCHY_FETCHING,
|
||||
"schedule": timedelta(hours=1), # Check hourly, but only fetch once per day
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-background-processes",
|
||||
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
@@ -139,6 +130,27 @@ beat_task_templates: list[dict] = [
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
# Sandbox cleanup tasks
|
||||
{
|
||||
"name": "cleanup-idle-sandboxes",
|
||||
"task": OnyxCeleryTask.CLEANUP_IDLE_SANDBOXES,
|
||||
"schedule": timedelta(minutes=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.SANDBOX,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "cleanup-old-snapshots",
|
||||
"task": OnyxCeleryTask.CLEANUP_OLD_SNAPSHOTS,
|
||||
"schedule": timedelta(hours=24),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.SANDBOX,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
|
||||
@@ -87,7 +87,7 @@ from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
@@ -1436,7 +1436,7 @@ def _docprocessing_task(
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
document_index = get_default_document_index(
|
||||
document_indices = get_all_document_indices(
|
||||
index_attempt.search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
@@ -1473,7 +1473,7 @@ def _docprocessing_task(
|
||||
# real work happens here!
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
document_index=document_index,
|
||||
document_indices=document_indices,
|
||||
ignore_time_skip=True, # Documents are already filtered during extraction
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
|
||||
371
backend/onyx/background/celery/tasks/hierarchyfetching/tasks.py
Normal file
371
backend/onyx/background/celery/tasks/hierarchyfetching/tasks.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""Celery tasks for hierarchy fetching.
|
||||
|
||||
This module provides tasks for fetching hierarchy node information from connectors.
|
||||
Hierarchy nodes represent structural elements like folders, spaces, and pages that
|
||||
can be used to filter search results.
|
||||
|
||||
The hierarchy fetching pipeline runs once per day per connector and fetches
|
||||
structural information from the connector source.
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import HierarchyConnector
|
||||
from onyx.connectors.models import HierarchyNode as PydanticHierarchyNode
|
||||
from onyx.db.connector import mark_cc_pair_as_hierarchy_fetched
|
||||
from onyx.db.connector_credential_pair import (
|
||||
fetch_indexable_standard_connector_credential_pair_ids,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Hierarchy fetching runs once per day (24 hours in seconds)
|
||||
HIERARCHY_FETCH_INTERVAL_SECONDS = 24 * 60 * 60
|
||||
|
||||
|
||||
def _is_hierarchy_fetching_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if hierarchy fetching is due for this connector.
|
||||
|
||||
Hierarchy fetching should run once per day for active connectors.
|
||||
"""
|
||||
# Skip if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
return False
|
||||
|
||||
# Skip if connector has never successfully indexed
|
||||
if not cc_pair.last_successful_index_time:
|
||||
return False
|
||||
|
||||
# Check if we've fetched hierarchy recently
|
||||
last_fetch = cc_pair.last_time_hierarchy_fetch
|
||||
if last_fetch is None:
|
||||
# Never fetched before - fetch now
|
||||
return True
|
||||
|
||||
# Check if enough time has passed since last fetch
|
||||
next_fetch_time = last_fetch + timedelta(seconds=HIERARCHY_FETCH_INTERVAL_SECONDS)
|
||||
return datetime.now(timezone.utc) >= next_fetch_time
|
||||
|
||||
|
||||
def _try_creating_hierarchy_fetching_task(
|
||||
celery_app: Celery,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str,
|
||||
) -> str | None:
|
||||
"""Try to create a hierarchy fetching task for a connector.
|
||||
|
||||
Returns the task ID if created, None otherwise.
|
||||
"""
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
# Serialize task creation attempts
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + f"hierarchy_fetching_{cc_pair.id}",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Refresh to get latest state
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# Generate task ID
|
||||
custom_task_id = f"hierarchy_fetching_{cc_pair.id}_{uuid4()}"
|
||||
|
||||
# Send the task
|
||||
result = celery_app.send_task(
|
||||
OnyxCeleryTask.CONNECTOR_HIERARCHY_FETCHING_TASK,
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair.id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
if not result:
|
||||
raise RuntimeError("send_task for hierarchy_fetching_task failed.")
|
||||
|
||||
task_logger.info(
|
||||
f"Created hierarchy fetching task: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"celery_task_id={custom_task_id}"
|
||||
)
|
||||
|
||||
return custom_task_id
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Failed to create hierarchy fetching task: cc_pair={cc_pair.id}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_HIERARCHY_FETCHING,
    soft_time_limit=300,
    bind=True,
)
def check_for_hierarchy_fetching(self: Task, *, tenant_id: str) -> int | None:
    """Check for connectors that need hierarchy fetching and spawn tasks.

    Runs periodically (once per day) and iterates every indexable, active
    connector pair, enqueueing a hierarchy fetching task for each one that
    is due.

    Returns the number of tasks created, or None when another instance of
    this beat task already holds the lock.
    """
    time_start = time.monotonic()
    task_logger.info("check_for_hierarchy_fetching - Starting")

    tasks_created = 0
    locked = False
    redis_client = get_redis_client()

    lock_beat: RedisLock = redis_client.lock(
        OnyxRedisLocks.CHECK_HIERARCHY_FETCHING_BEAT_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )

    # These tasks should never overlap; bail out instead of waiting.
    if not lock_beat.acquire(blocking=False):
        return None

    try:
        locked = True

        with get_session_with_current_tenant() as db_session:
            # All active connector credential pairs eligible for indexing.
            candidate_ids = fetch_indexable_standard_connector_credential_pair_ids(
                db_session=db_session,
                active_cc_pairs_only=True,
            )

            for candidate_id in candidate_ids:
                # Keep the beat lock alive while looping over many pairs.
                lock_beat.reacquire()

                cc_pair = get_connector_credential_pair_from_id(
                    db_session=db_session,
                    cc_pair_id=candidate_id,
                )
                if not cc_pair or not _is_hierarchy_fetching_due(cc_pair):
                    continue

                if _try_creating_hierarchy_fetching_task(
                    celery_app=self.app,
                    cc_pair=cc_pair,
                    db_session=db_session,
                    r=redis_client,
                    tenant_id=tenant_id,
                ):
                    tasks_created += 1
    except Exception:
        task_logger.exception("check_for_hierarchy_fetching - Unexpected error")
    finally:
        if locked:
            if lock_beat.owned():
                lock_beat.release()
            else:
                task_logger.error(
                    "check_for_hierarchy_fetching - Lock not owned on completion"
                )

    time_elapsed = time.monotonic() - time_start
    task_logger.info(
        f"check_for_hierarchy_fetching finished: "
        f"tasks_created={tasks_created} elapsed={time_elapsed:.2f}s"
    )
    return tasks_created
||||
|
||||
|
||||
# Batch size for hierarchy node processing: nodes are upserted to the DB and
# cached in Redis in chunks of this size to bound per-batch memory use.
HIERARCHY_NODE_BATCH_SIZE = 100
||||
|
||||
|
||||
def _run_hierarchy_extraction(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
    source: DocumentSource,
    tenant_id: str,
) -> int:
    """Run the hierarchy extraction for a connector.

    Instantiates the connector and, when it implements HierarchyConnector,
    streams hierarchy nodes for the window since the last fetch, upserting
    them into the database and caching them in Redis in fixed-size batches.

    Returns the total number of hierarchy nodes extracted.
    """
    connector = cc_pair.connector
    credential = cc_pair.credential

    # Instantiate the connector using its configured input type.
    runnable_connector = instantiate_connector(
        db_session=db_session,
        source=source,
        input_type=connector.input_type,
        connector_specific_config=connector.connector_specific_config,
        credential=credential,
    )

    # Connectors without hierarchy support are simply skipped.
    if not isinstance(runnable_connector, HierarchyConnector):
        task_logger.debug(
            f"Connector {source} does not implement HierarchyConnector, skipping"
        )
        return 0

    # Time range: from the last hierarchy fetch (epoch 0 if never) until now.
    # NOTE(review): assumes last_time_hierarchy_fetch is timezone-aware;
    # .timestamp() on a naive datetime uses local time — confirm upstream.
    last_fetch = cc_pair.last_time_hierarchy_fetch
    start_time = last_fetch.timestamp() if last_fetch else 0
    end_time = datetime.now(timezone.utc).timestamp()

    total_nodes = 0
    pending: list[PydanticHierarchyNode] = []
    redis_client = get_redis_client(tenant_id=tenant_id)

    def _flush_pending() -> int:
        """Upsert and cache the accumulated node batch; return its size."""
        if not pending:
            return 0

        upserted_nodes = upsert_hierarchy_nodes_batch(
            db_session=db_session,
            nodes=pending,
            source=source,
            commit=True,
        )

        # Cache in Redis for fast ancestor resolution.
        cache_hierarchy_nodes_batch(
            redis_client=redis_client,
            source=source,
            entries=[
                HierarchyNodeCacheEntry.from_db_model(node)
                for node in upserted_nodes
            ],
        )

        flushed = len(pending)
        pending.clear()
        return flushed

    # Stream hierarchy nodes from the connector, flushing in batches.
    for node in runnable_connector.load_hierarchy(start=start_time, end=end_time):
        pending.append(node)
        if len(pending) >= HIERARCHY_NODE_BATCH_SIZE:
            total_nodes += _flush_pending()

    # Flush whatever remains after the stream ends.
    total_nodes += _flush_pending()

    return total_nodes
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.CONNECTOR_HIERARCHY_FETCHING_TASK,
    soft_time_limit=3600,  # 1 hour soft limit
    time_limit=3900,  # 1 hour 5 min hard limit
    bind=True,
)
def connector_hierarchy_fetching_task(
    self: Task,
    *,
    cc_pair_id: int,
    tenant_id: str,
) -> None:
    """Fetch hierarchy information from a connector.

    Pulls structural information (folders, spaces, pages, etc.) from the
    connector source, stores it in the database, and records the fetch time
    so the connector is not re-fetched until the next interval.

    Any failure is logged with a traceback and re-raised so celery marks
    the task as failed.
    """
    task_logger.info(
        f"connector_hierarchy_fetching_task starting: "
        f"cc_pair={cc_pair_id} tenant={tenant_id}"
    )

    try:
        with get_session_with_current_tenant() as db_session:
            cc_pair = get_connector_credential_pair_from_id(
                db_session=db_session,
                cc_pair_id=cc_pair_id,
            )
            if not cc_pair:
                task_logger.warning(
                    f"CC pair not found for hierarchy fetching: cc_pair={cc_pair_id}"
                )
                return

            if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
                task_logger.info(
                    f"Skipping hierarchy fetching for deleting connector: "
                    f"cc_pair={cc_pair_id}"
                )
                return

            total_nodes = _run_hierarchy_extraction(
                db_session=db_session,
                cc_pair=cc_pair,
                source=cc_pair.connector.source,
                tenant_id=tenant_id,
            )

            task_logger.info(
                f"connector_hierarchy_fetching_task: "
                f"Extracted {total_nodes} hierarchy nodes for cc_pair={cc_pair_id}"
            )

            # Update the last fetch time to prevent re-running until next interval.
            mark_cc_pair_as_hierarchy_fetched(db_session, cc_pair_id)
    except Exception:
        task_logger.exception(
            f"connector_hierarchy_fetching_task failed: cc_pair={cc_pair_id}"
        )
        raise

    task_logger.info(
        f"connector_hierarchy_fetching_task completed: cc_pair={cc_pair_id}"
    )
||||
@@ -1,77 +0,0 @@
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.apps.client import celery_app
|
||||
from onyx.background.celery.tasks.kg_processing.utils import is_kg_processing_blocked
|
||||
from onyx.background.celery.tasks.kg_processing.utils import (
|
||||
is_kg_processing_requirements_met,
|
||||
)
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
|
||||
|
||||
def try_creating_kg_processing_task(
    tenant_id: str,
) -> bool:
    """Schedules the KG Processing for a tenant immediately. Will not schedule if
    the tenant is not ready for KG processing.

    Returns True when the task was enqueued, False otherwise.
    """
    try:
        # Skip entirely when the tenant is not ready for KG processing.
        if not is_kg_processing_requirements_met():
            return False

        # Send the KG processing task
        send_result = celery_app.send_task(
            OnyxCeleryTask.KG_PROCESSING,
            kwargs=dict(tenant_id=tenant_id),
            queue=OnyxCeleryQueues.KG_PROCESSING,
            priority=OnyxCeleryPriority.MEDIUM,
        )

        if not send_result:
            task_logger.error("send_task for kg processing failed.")
        return bool(send_result)
    except Exception:
        task_logger.exception(
            f"try_creating_kg_processing_task - Unexpected exception for tenant={tenant_id}"
        )
        return False
||||
|
||||
|
||||
def try_creating_kg_source_reset_task(
    tenant_id: str,
    source_name: str | None,
    index_name: str,
) -> bool:
    """Schedules the KG Source Reset for a tenant immediately. Will not do anything if
    the tenant is currently being processed.

    Returns True when the task was enqueued, False otherwise.
    """
    try:
        # A running KG task holds the processing lock; don't reset underneath it.
        if is_kg_processing_blocked():
            return False

        # Send the KG source reset task
        send_result = celery_app.send_task(
            OnyxCeleryTask.KG_RESET_SOURCE_INDEX,
            kwargs=dict(
                tenant_id=tenant_id,
                source_name=source_name,
                index_name=index_name,
            ),
            queue=OnyxCeleryQueues.KG_PROCESSING,
            priority=OnyxCeleryPriority.MEDIUM,
        )

        if not send_result:
            task_logger.error("send_task for kg source reset failed.")
        return bool(send_result)
    except Exception:
        task_logger.exception(
            f"try_creating_kg_source_reset_task - Unexpected exception for tenant={tenant_id}"
        )
        return False
||||
@@ -1,253 +0,0 @@
|
||||
import time
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.kg_processing.utils import (
|
||||
is_kg_clustering_only_requirements_met,
|
||||
)
|
||||
from onyx.background.celery.tasks.kg_processing.utils import (
|
||||
is_kg_processing_requirements_met,
|
||||
)
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.kg.clustering.clustering import kg_clustering
|
||||
from onyx.kg.extractions.extraction_processing import kg_extraction
|
||||
from onyx.kg.resets.reset_source import reset_source_kg_index
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.CHECK_KG_PROCESSING,
    soft_time_limit=300,
    bind=True,
)
def check_for_kg_processing(self: Task, *, tenant_id: str) -> None:
    """a lightweight task used to kick off kg processing tasks."""
    started_at = time.monotonic()
    # NOTE(review): logged at warning level for a routine start — consider info.
    task_logger.warning("check_for_kg_processing - Starting")
    try:
        # Nothing to do; note the early return also skips the finish log below.
        if not is_kg_processing_requirements_met():
            return

        task_logger.info(
            f"Found documents needing KG processing for tenant {tenant_id}"
        )

        self.app.send_task(
            OnyxCeleryTask.KG_PROCESSING,
            kwargs={"tenant_id": tenant_id},
            queue=OnyxCeleryQueues.KG_PROCESSING,
            priority=OnyxCeleryPriority.MEDIUM,
        )
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception("Unexpected exception during kg processing check")

    elapsed = time.monotonic() - started_at
    task_logger.info(f"check_for_kg_processing finished: elapsed={elapsed:.2f}")
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.CHECK_KG_PROCESSING_CLUSTERING_ONLY,
    soft_time_limit=300,
    bind=True,
)
def check_for_kg_processing_clustering_only(self: Task, *, tenant_id: str) -> None:
    """a lightweight task used to kick off kg clustering tasks."""
    started_at = time.monotonic()
    # NOTE(review): logged at warning level for a routine start — consider info.
    task_logger.warning("check_for_kg_processing_clustering_only - Starting")

    try:
        # Nothing staged for clustering; the early return skips the finish log.
        if not is_kg_clustering_only_requirements_met():
            return

        task_logger.info(
            f"Found documents needing KG clustering for tenant {tenant_id}"
        )

        self.app.send_task(
            OnyxCeleryTask.KG_CLUSTERING_ONLY,
            kwargs={"tenant_id": tenant_id},
            queue=OnyxCeleryQueues.KG_PROCESSING,
            priority=OnyxCeleryPriority.MEDIUM,
        )
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception("Unexpected exception during kg clustering-only check")

    elapsed = time.monotonic() - started_at
    task_logger.info(
        f"check_for_kg_processing_clustering_only finished: elapsed={elapsed:.2f}"
    )
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.KG_PROCESSING,
    bind=True,
)
def kg_processing(self: Task, *, tenant_id: str) -> None:
    """a task for doing kg extraction and clustering."""
    task_logger.warning(f"kg_processing - Starting for tenant {tenant_id}")

    redis_client = get_redis_client()
    lock_beat: RedisLock = redis_client.lock(
        OnyxRedisLocks.KG_PROCESSING_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )

    # these tasks should never overlap
    if not lock_beat.acquire(blocking=False):
        return

    try:
        # Only the current index name is needed from the database.
        with get_session_with_current_tenant() as db_session:
            index_str = get_current_search_settings(db_session).index_name

        task_logger.info(f"KG processing set to in progress for tenant {tenant_id}")

        # Extraction first, then clustering over the freshly staged results.
        kg_extraction(
            tenant_id=tenant_id,
            index_name=index_str,
            lock=lock_beat,
            processing_chunk_batch_size=8,
        )
        kg_clustering(
            tenant_id=tenant_id,
            index_name=index_str,
            lock=lock_beat,
            processing_chunk_batch_size=8,
        )
    except Exception:
        task_logger.exception("Error during kg processing")
    finally:
        if lock_beat.owned():
            lock_beat.release()
        else:
            task_logger.error(
                f"kg_processing - Lock not owned on completion: tenant={tenant_id}"
            )
            redis_lock_dump(lock_beat, redis_client)

    task_logger.debug("Completed kg processing task!")
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.KG_CLUSTERING_ONLY,
    bind=True,
)
def kg_clustering_only(self: Task, *, tenant_id: str) -> None:
    """a task for doing kg clustering only."""
    task_logger.warning(f"kg_clustering_only - Starting for tenant {tenant_id}")

    redis_client = get_redis_client()
    lock_beat: RedisLock = redis_client.lock(
        OnyxRedisLocks.KG_PROCESSING_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )

    # these tasks should never overlap
    if not lock_beat.acquire(blocking=False):
        return

    try:
        # Only the current index name is needed from the database.
        with get_session_with_current_tenant() as db_session:
            index_str = get_current_search_settings(db_session).index_name

        task_logger.info(
            f"KG clustering-only set to in progress for tenant {tenant_id}"
        )

        kg_clustering(
            tenant_id=tenant_id,
            index_name=index_str,
            lock=lock_beat,
            processing_chunk_batch_size=8,
        )
    except Exception:
        task_logger.exception("Error during kg clustering-only")
    finally:
        if lock_beat.owned():
            lock_beat.release()
        else:
            task_logger.error(
                f"kg_clustering_only - Lock not owned on completion: tenant={tenant_id}"
            )
            redis_lock_dump(lock_beat, redis_client)

    task_logger.debug("Completed kg clustering-only task!")
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.KG_RESET_SOURCE_INDEX,
    bind=True,
)
def kg_reset_source_index(
    self: Task, *, tenant_id: str, source_name: str, index_name: str
) -> None:
    """a task for KG reset of a source."""
    task_logger.warning(f"kg_reset_source_index - Starting for tenant {tenant_id}")

    redis_client = get_redis_client()
    lock_beat: RedisLock = redis_client.lock(
        OnyxRedisLocks.KG_PROCESSING_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )

    # these tasks should never overlap
    if not lock_beat.acquire(blocking=False):
        return

    try:
        reset_source_kg_index(
            source_name=source_name,
            tenant_id=tenant_id,
            index_name=index_name,
            lock=lock_beat,
        )
    except Exception:
        task_logger.exception("Error during kg reset")
    finally:
        if lock_beat.owned():
            lock_beat.release()
        else:
            task_logger.error(
                f"kg_reset_source_index - Lock not owned on completion: tenant={tenant_id}"
            )
            redis_lock_dump(lock_beat, redis_client)

    task_logger.debug("Completed kg reset task!")
||||
@@ -1,78 +0,0 @@
|
||||
import time
|
||||
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.document import check_for_documents_needing_kg_processing
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.kg_config import is_kg_config_settings_enabled_valid
|
||||
from onyx.db.models import KGEntityExtractionStaging
|
||||
from onyx.db.models import KGRelationshipExtractionStaging
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
def is_kg_processing_blocked() -> bool:
    """Checks if there are any KG tasks in progress."""
    # Any running KG task holds the shared processing lock; only probe it.
    kg_lock: RedisLock = get_redis_client().lock(OnyxRedisLocks.KG_PROCESSING_LOCK)
    return kg_lock.locked()
||||
|
||||
|
||||
def is_kg_processing_requirements_met() -> bool:
    """Checks that there are no other KG tasks in progress, KG is enabled, valid,
    and there are documents that need KG processing."""
    # Guard 1: another KG task already holds the processing lock.
    if is_kg_processing_blocked():
        return False

    # Guard 2: KG must be enabled and its configuration valid.
    kg_config = get_kg_config_settings()
    if not is_kg_config_settings_enabled_valid(kg_config):
        return False

    with get_session_with_current_tenant() as db_session:
        has_staging_entities = (
            db_session.query(KGEntityExtractionStaging).first() is not None
        )
        has_staging_relationships = (
            db_session.query(KGRelationshipExtractionStaging).first() is not None
        )
        # Work exists if documents still need processing, or either staging
        # table has leftover rows from a previous run.
        needs_processing = check_for_documents_needing_kg_processing(
            db_session,
            kg_config.KG_COVERAGE_START_DATE,
            kg_config.KG_MAX_COVERAGE_DAYS,
        )
        return needs_processing or has_staging_entities or has_staging_relationships
||||
|
||||
|
||||
def is_kg_clustering_only_requirements_met() -> bool:
    """Checks that there are no other KG tasks in progress, KG is enabled, valid,
    and there are documents that need KG clustering."""
    # Guard 1: another KG task already holds the processing lock.
    if is_kg_processing_blocked():
        return False

    # Guard 2: KG must be enabled and its configuration valid.
    kg_config = get_kg_config_settings()
    if not is_kg_config_settings_enabled_valid(kg_config):
        return False

    # Clustering has pending work iff either staging table is non-empty.
    with get_session_with_current_tenant() as db_session:
        has_staging_entities = (
            db_session.query(KGEntityExtractionStaging).first() is not None
        )
        has_staging_relationships = (
            db_session.query(KGRelationshipExtractionStaging).first() is not None
        )

    return has_staging_entities or has_staging_relationships
||||
|
||||
|
||||
def extend_lock(lock: RedisLock, timeout: int, last_lock_time: float) -> float:
    """Reacquire (extend) *lock* when at least a quarter of *timeout* has
    elapsed since *last_lock_time*.

    Returns the monotonic time of the most recent (re)acquisition, so callers
    can thread the value through successive calls.
    """
    now = time.monotonic()
    if now - last_lock_time < (timeout / 4):
        # Not enough time has passed; keep the previous acquisition time.
        return last_lock_time

    lock.reacquire()
    return now
||||
@@ -25,7 +25,7 @@ from onyx.db.document_set import fetch_document_sets_for_document
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.relationships import delete_document_references_from_kg
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -97,13 +97,17 @@ def document_by_cc_pair_cleanup_task(
|
||||
action = "skip"
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
# This flow is for updates and deletion so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
active_search_settings.primary,
|
||||
active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
|
||||
count = get_document_connector_count(db_session, document_id)
|
||||
if count == 1:
|
||||
@@ -113,11 +117,12 @@ def document_by_cc_pair_cleanup_task(
|
||||
|
||||
chunk_count = fetch_chunk_count_for_document(document_id, db_session)
|
||||
|
||||
_ = retry_index.delete_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
for retry_document_index in retry_document_indices:
|
||||
_ = retry_document_index.delete_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
delete_document_references_from_kg(
|
||||
db_session=db_session,
|
||||
@@ -155,14 +160,18 @@ def document_by_cc_pair_cleanup_task(
|
||||
hidden=doc.hidden,
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
user_fields=None,
|
||||
)
|
||||
for retry_document_index in retry_document_indices:
|
||||
# TODO(andrei): Previously there was a comment here saying
|
||||
# it was ok if a doc did not exist in the document index. I
|
||||
# don't agree with that claim, so keep an eye on this task
|
||||
# to see if this raises.
|
||||
retry_document_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
user_fields=None,
|
||||
)
|
||||
|
||||
# there are still other cc_pair references to the doc, so just resync to Vespa
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
|
||||
@@ -27,12 +27,13 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -232,7 +233,9 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
|
||||
try:
|
||||
for batch in connector.load_from_state():
|
||||
documents.extend(batch)
|
||||
documents.extend(
|
||||
[doc for doc in batch if not isinstance(doc, HierarchyNode)]
|
||||
)
|
||||
|
||||
adapter = UserFileIndexingAdapter(
|
||||
tenant_id=tenant_id,
|
||||
@@ -244,7 +247,8 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
search_settings=current_search_settings,
|
||||
)
|
||||
|
||||
document_index = get_default_document_index(
|
||||
# This flow is for indexing so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
current_search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
@@ -258,7 +262,7 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
# real work happens here!
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
document_index=document_index,
|
||||
document_indices=document_indices,
|
||||
ignore_time_skip=True,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
@@ -412,12 +416,16 @@ def process_single_user_file_delete(
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
document_index = get_default_document_index(
|
||||
# This flow is for deletion so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
retry_index = RetryDocumentIndex(document_index)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
index_name = active_search_settings.primary.index_name
|
||||
selection = f"{index_name}.document_id=='{user_file_id}'"
|
||||
|
||||
@@ -438,11 +446,12 @@ def process_single_user_file_delete(
|
||||
else:
|
||||
chunk_count = user_file.chunk_count
|
||||
|
||||
retry_index.delete_single(
|
||||
doc_id=user_file_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.delete_single(
|
||||
doc_id=user_file_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
|
||||
file_store = get_default_file_store()
|
||||
@@ -564,12 +573,16 @@ def process_single_user_file_project_sync(
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
# This flow is for updates so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
@@ -579,13 +592,14 @@ def process_single_user_file_project_sync(
|
||||
return None
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
retry_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
)
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file id={user_file_id}"
|
||||
|
||||
@@ -21,6 +21,8 @@ from onyx.utils.logger import setup_logger
|
||||
DOCUMENT_SYNC_PREFIX = "documentsync"
|
||||
DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence"
|
||||
DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_TTL = FENCE_TTL
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -50,7 +52,7 @@ def set_document_sync_fence(r: Redis, payload: int | None) -> None:
|
||||
r.delete(DOCUMENT_SYNC_FENCE_KEY)
|
||||
return
|
||||
|
||||
r.set(DOCUMENT_SYNC_FENCE_KEY, payload)
|
||||
r.set(DOCUMENT_SYNC_FENCE_KEY, payload, ex=FENCE_TTL)
|
||||
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)
|
||||
|
||||
|
||||
@@ -110,6 +112,7 @@ def generate_document_sync_tasks(
|
||||
|
||||
# Add to the tracking taskset in Redis BEFORE creating the celery task
|
||||
r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id)
|
||||
r.expire(DOCUMENT_SYNC_TASKSET_KEY, TASKSET_TTL)
|
||||
|
||||
# Create the Celery task
|
||||
celery_app.send_task(
|
||||
|
||||
@@ -49,7 +49,7 @@ from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.sync_record import cleanup_sync_records
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
@@ -70,6 +70,8 @@ logger = setup_logger()
|
||||
|
||||
# celery auto associates tasks created inside another task,
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
# TODO(andrei): Rename all these kinds of functions from *vespa* to a more
|
||||
# generic *document_index*.
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
ignore_result=True,
|
||||
@@ -465,13 +467,17 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
# This flow is for updates so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
@@ -500,14 +506,18 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
|
||||
# aggregated_boost_factor=doc.aggregated_boost_factor,
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
user_fields=None,
|
||||
)
|
||||
for retry_document_index in retry_document_indices:
|
||||
# TODO(andrei): Previously there was a comment here saying
|
||||
# it was ok if a doc did not exist in the document index. I
|
||||
# don't agree with that claim, so keep an eye on this task
|
||||
# to see if this raises.
|
||||
retry_document_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
user_fields=None,
|
||||
)
|
||||
|
||||
# update db last. Worst case = we crash right before this and
|
||||
# the sync might repeat again later
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
"""Factory stub for running celery worker for knowledge graph processing.
|
||||
This code is different from the primary/beat stubs because there is no EE version to
|
||||
fetch. Port over the code in those files if we add an EE version of this worker."""
|
||||
|
||||
from celery import Celery
|
||||
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
|
||||
def get_app() -> Celery:
|
||||
from onyx.background.celery.apps.kg_processing import celery_app
|
||||
|
||||
return celery_app
|
||||
|
||||
|
||||
app = get_app()
|
||||
@@ -31,17 +31,21 @@ from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.constants import CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
|
||||
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.enums import ProcessingMode
|
||||
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
|
||||
@@ -53,7 +57,15 @@ from onyx.db.models import IndexAttempt
|
||||
from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
|
||||
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
|
||||
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.features.build.indexing.persistent_document_writer import (
|
||||
get_persistent_document_writer,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
|
||||
@@ -367,6 +379,7 @@ def connector_document_extraction(
|
||||
|
||||
db_connector = index_attempt.connector_credential_pair.connector
|
||||
db_credential = index_attempt.connector_credential_pair.credential
|
||||
processing_mode = index_attempt.connector_credential_pair.processing_mode
|
||||
is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT
|
||||
|
||||
from_beginning = index_attempt.from_beginning
|
||||
@@ -534,9 +547,12 @@ def connector_document_extraction(
|
||||
logger.info(
|
||||
f"Running '{db_connector.source.value}' connector with checkpoint: {checkpoint}"
|
||||
)
|
||||
for document_batch, failure, next_checkpoint in connector_runner.run(
|
||||
checkpoint
|
||||
):
|
||||
for (
|
||||
document_batch,
|
||||
hierarchy_node_batch,
|
||||
failure,
|
||||
next_checkpoint,
|
||||
) in connector_runner.run(checkpoint):
|
||||
# Check if connector is disabled mid run and stop if so unless it's the secondary
|
||||
# index being built. We want to populate it even for paused connectors
|
||||
# Often paused connectors are sources that aren't updated frequently but the
|
||||
@@ -571,6 +587,33 @@ def connector_document_extraction(
|
||||
if next_checkpoint:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
# Process hierarchy nodes batch - upsert to Postgres and cache in Redis
|
||||
if hierarchy_node_batch:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
upserted_nodes = upsert_hierarchy_nodes_batch(
|
||||
db_session=db_session,
|
||||
nodes=hierarchy_node_batch,
|
||||
source=db_connector.source,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Cache in Redis for fast ancestor resolution during doc processing
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
cache_entries = [
|
||||
HierarchyNodeCacheEntry.from_db_model(node)
|
||||
for node in upserted_nodes
|
||||
]
|
||||
cache_hierarchy_nodes_batch(
|
||||
redis_client=redis_client,
|
||||
source=db_connector.source,
|
||||
entries=cache_entries,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Persisted and cached {len(hierarchy_node_batch)} hierarchy nodes "
|
||||
f"for attempt={index_attempt_id}"
|
||||
)
|
||||
|
||||
# below is all document processing task, so if no batch we can just continue
|
||||
if not document_batch:
|
||||
continue
|
||||
@@ -600,34 +643,103 @@ def connector_document_extraction(
|
||||
logger.debug(f"Indexing batch of documents: {batch_description}")
|
||||
memory_tracer.increment_and_maybe_trace()
|
||||
|
||||
# Store documents in storage
|
||||
batch_storage.store_batch(batch_num, doc_batch_cleaned)
|
||||
# cc4a
|
||||
if processing_mode == ProcessingMode.FILE_SYSTEM:
|
||||
# File system only - write directly to persistent storage,
|
||||
# skip chunking/embedding/Vespa but still track documents in DB
|
||||
|
||||
# Create processing task data
|
||||
processing_batch_data = {
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"tenant_id": tenant_id,
|
||||
"batch_num": batch_num, # 0-indexed
|
||||
}
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Create metadata for the batch
|
||||
index_attempt_metadata = IndexAttemptMetadata(
|
||||
attempt_id=index_attempt_id,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
request_id=make_randomized_onyx_request_id("FSI"),
|
||||
structured_id=f"{tenant_id}:{cc_pair_id}:{index_attempt_id}:{batch_num}",
|
||||
batch_num=batch_num,
|
||||
)
|
||||
|
||||
# Queue document processing task
|
||||
app.send_task(
|
||||
OnyxCeleryTask.DOCPROCESSING_TASK,
|
||||
kwargs=processing_batch_data,
|
||||
queue=OnyxCeleryQueues.DOCPROCESSING,
|
||||
priority=docprocessing_priority,
|
||||
)
|
||||
# Upsert documents to PostgreSQL (document table + cc_pair relationship)
|
||||
# This is a subset of what docprocessing does - just DB tracking, no chunking/embedding
|
||||
index_doc_batch_prepare(
|
||||
documents=doc_batch_cleaned,
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
db_session=db_session,
|
||||
ignore_time_skip=True, # Documents already filtered during extraction
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
total_doc_batches_queued += 1
|
||||
# Mark documents as indexed for the CC pair
|
||||
mark_document_as_indexed_for_cc_pair__no_commit(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
document_ids=[doc.id for doc in doc_batch_cleaned],
|
||||
db_session=db_session,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Queued document processing batch: "
|
||||
f"batch_num={batch_num} "
|
||||
f"docs={len(doc_batch_cleaned)} "
|
||||
f"attempt={index_attempt_id}"
|
||||
)
|
||||
# Write documents to persistent file system
|
||||
# Use creator_id for user-segregated storage paths (sandbox isolation)
|
||||
creator_id = index_attempt.connector_credential_pair.creator_id
|
||||
if creator_id is None:
|
||||
raise ValueError(
|
||||
f"ConnectorCredentialPair {index_attempt.connector_credential_pair.id} "
|
||||
"must have a creator_id for persistent document storage"
|
||||
)
|
||||
user_id_str: str = str(creator_id)
|
||||
writer = get_persistent_document_writer(
|
||||
user_id=user_id_str,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
written_paths = writer.write_documents(doc_batch_cleaned)
|
||||
|
||||
# Update coordination directly (no docprocessing task)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
IndexingCoordination.update_batch_completion_and_docs(
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
total_docs_indexed=len(doc_batch_cleaned),
|
||||
new_docs_indexed=len(doc_batch_cleaned),
|
||||
total_chunks=0, # No chunks for file system mode
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
total_doc_batches_queued += 1
|
||||
|
||||
logger.info(
|
||||
f"Wrote documents to file system: "
|
||||
f"batch_num={batch_num} "
|
||||
f"docs={len(written_paths)} "
|
||||
f"attempt={index_attempt_id}"
|
||||
)
|
||||
else:
|
||||
# REGULAR mode (default): Full pipeline - store and queue docprocessing
|
||||
batch_storage.store_batch(batch_num, doc_batch_cleaned)
|
||||
|
||||
# Create processing task data
|
||||
processing_batch_data = {
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"tenant_id": tenant_id,
|
||||
"batch_num": batch_num, # 0-indexed
|
||||
}
|
||||
|
||||
# Queue document processing task
|
||||
app.send_task(
|
||||
OnyxCeleryTask.DOCPROCESSING_TASK,
|
||||
kwargs=processing_batch_data,
|
||||
queue=OnyxCeleryQueues.DOCPROCESSING,
|
||||
priority=docprocessing_priority,
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
total_doc_batches_queued += 1
|
||||
|
||||
logger.info(
|
||||
f"Queued document processing batch: "
|
||||
f"batch_num={batch_num} "
|
||||
f"docs={len(doc_batch_cleaned)} "
|
||||
f"attempt={index_attempt_id}"
|
||||
)
|
||||
|
||||
# Check checkpoint size periodically
|
||||
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
|
||||
@@ -663,6 +775,24 @@ def connector_document_extraction(
|
||||
total_batches=batch_num,
|
||||
)
|
||||
|
||||
# Trigger file sync to user's sandbox (if running) - only for FILE_SYSTEM mode
|
||||
# This syncs the newly written documents from S3 to any running sandbox pod
|
||||
if processing_mode == ProcessingMode.FILE_SYSTEM:
|
||||
creator_id = index_attempt.connector_credential_pair.creator_id
|
||||
if creator_id:
|
||||
app.send_task(
|
||||
OnyxCeleryTask.SANDBOX_FILE_SYNC,
|
||||
kwargs={
|
||||
"user_id": str(creator_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.SANDBOX,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered sandbox file sync for user {creator_id} "
|
||||
f"after indexing complete"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Document extraction failed: "
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
@@ -15,6 +16,11 @@ from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Type alias for search doc deduplication key
|
||||
# Simple key: just document_id (str)
|
||||
# Full key: (document_id, chunk_ind, match_highlights)
|
||||
SearchDocKey = str | tuple[str, int, tuple[str, ...]]
|
||||
|
||||
|
||||
class ChatStateContainer:
|
||||
"""Container for accumulating state during LLM loop execution.
|
||||
@@ -40,6 +46,10 @@ class ChatStateContainer:
|
||||
# True if this turn is a clarification question (deep research flow)
|
||||
self.is_clarification: bool = False
|
||||
# Note: LLM cost tracking is now handled in multi_llm.py
|
||||
# Search doc collection - maps dedup key to SearchDoc for all docs from tool calls
|
||||
self._all_search_docs: dict[SearchDocKey, SearchDoc] = {}
|
||||
# Track which citation numbers were actually emitted during streaming
|
||||
self._emitted_citations: set[int] = set()
|
||||
|
||||
def add_tool_call(self, tool_call: ToolCallInfo) -> None:
|
||||
"""Add a tool call to the accumulated state."""
|
||||
@@ -91,6 +101,54 @@ class ChatStateContainer:
|
||||
with self._lock:
|
||||
return self.is_clarification
|
||||
|
||||
@staticmethod
|
||||
def create_search_doc_key(
|
||||
search_doc: SearchDoc, use_simple_key: bool = True
|
||||
) -> SearchDocKey:
|
||||
"""Create a unique key for a SearchDoc for deduplication.
|
||||
|
||||
Args:
|
||||
search_doc: The SearchDoc to create a key for
|
||||
use_simple_key: If True (default), use only document_id for deduplication.
|
||||
If False, include chunk_ind and match_highlights so that the same
|
||||
document/chunk with different highlights are stored separately.
|
||||
"""
|
||||
if use_simple_key:
|
||||
return search_doc.document_id
|
||||
match_highlights_tuple = tuple(sorted(search_doc.match_highlights or []))
|
||||
return (search_doc.document_id, search_doc.chunk_ind, match_highlights_tuple)
|
||||
|
||||
def add_search_docs(
|
||||
self, search_docs: list[SearchDoc], use_simple_key: bool = True
|
||||
) -> None:
|
||||
"""Add search docs to the accumulated collection with deduplication.
|
||||
|
||||
Args:
|
||||
search_docs: List of SearchDoc objects to add
|
||||
use_simple_key: If True (default), deduplicate by document_id only.
|
||||
If False, deduplicate by document_id + chunk_ind + match_highlights.
|
||||
"""
|
||||
with self._lock:
|
||||
for doc in search_docs:
|
||||
key = self.create_search_doc_key(doc, use_simple_key)
|
||||
if key not in self._all_search_docs:
|
||||
self._all_search_docs[key] = doc
|
||||
|
||||
def get_all_search_docs(self) -> dict[SearchDocKey, SearchDoc]:
|
||||
"""Thread-safe getter for all accumulated search docs (returns a copy)."""
|
||||
with self._lock:
|
||||
return self._all_search_docs.copy()
|
||||
|
||||
def add_emitted_citation(self, citation_num: int) -> None:
|
||||
"""Add a citation number that was actually emitted during streaming."""
|
||||
with self._lock:
|
||||
self._emitted_citations.add(citation_num)
|
||||
|
||||
def get_emitted_citations(self) -> set[int]:
|
||||
"""Thread-safe getter for emitted citations (returns a copy)."""
|
||||
with self._lock:
|
||||
return self._emitted_citations.copy()
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
func: Callable[..., None],
|
||||
|
||||
@@ -9,12 +9,6 @@ from fastapi.datastructures import Headers
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import is_user_admin
|
||||
from onyx.background.celery.tasks.kg_processing.kg_indexing import (
|
||||
try_creating_kg_processing_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.kg_processing.kg_indexing import (
|
||||
try_creating_kg_source_reset_task,
|
||||
)
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
@@ -37,7 +31,6 @@ from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.projects import check_project_ownership
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
@@ -359,39 +352,7 @@ def process_kg_commands(
|
||||
if not is_kg_config_settings_enabled_valid(kg_config_settings):
|
||||
return
|
||||
|
||||
# get Vespa index
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
index_str = search_settings.index_name
|
||||
|
||||
if message == "kg_p":
|
||||
success = try_creating_kg_processing_task(tenant_id)
|
||||
if success:
|
||||
raise KGException("KG processing scheduled")
|
||||
else:
|
||||
raise KGException(
|
||||
"Cannot schedule another KG processing if one is already running "
|
||||
"or there are no documents to process"
|
||||
)
|
||||
|
||||
elif message.startswith("kg_rs_source"):
|
||||
msg_split = [x for x in message.split(":")]
|
||||
if len(msg_split) > 2:
|
||||
raise KGException("Invalid format for a source reset command")
|
||||
elif len(msg_split) == 2:
|
||||
source_name = msg_split[1].strip()
|
||||
elif len(msg_split) == 1:
|
||||
source_name = None
|
||||
else:
|
||||
raise KGException("Invalid format for a source reset command")
|
||||
|
||||
success = try_creating_kg_source_reset_task(tenant_id, source_name, index_str)
|
||||
if success:
|
||||
source_name = source_name or "all"
|
||||
raise KGException(f"KG index reset for source '{source_name}' scheduled")
|
||||
else:
|
||||
raise KGException("Cannot reset index while KG processing is running")
|
||||
|
||||
elif message == "kg_setup":
|
||||
if message == "kg_setup":
|
||||
populate_missing_default_entity_types__commit(db_session=db_session)
|
||||
raise KGException("KG setup done")
|
||||
|
||||
|
||||
@@ -53,6 +53,50 @@ def update_citation_processor_from_tool_response(
|
||||
citation_processor.update_citation_mapping(citation_to_doc)
|
||||
|
||||
|
||||
def extract_citation_order_from_text(text: str) -> list[int]:
|
||||
"""Extract citation numbers from text in order of first appearance.
|
||||
|
||||
Parses citation patterns like [1], [1, 2], [[1]], 【1】 etc. and returns
|
||||
the citation numbers in the order they first appear in the text.
|
||||
|
||||
Args:
|
||||
text: The text containing citations
|
||||
|
||||
Returns:
|
||||
List of citation numbers in order of first appearance (no duplicates)
|
||||
"""
|
||||
# Same pattern used in collapse_citations and DynamicCitationProcessor
|
||||
# Group 2 captures the number in double bracket format: [[1]], 【【1】】
|
||||
# Group 4 captures the numbers in single bracket format: [1], [1, 2]
|
||||
citation_pattern = re.compile(
|
||||
r"([\[【[]{2}(\d+)[\]】]]{2})|([\[【[]([\d]+(?: *, *\d+)*)[\]】]])"
|
||||
)
|
||||
seen: set[int] = set()
|
||||
order: list[int] = []
|
||||
|
||||
for match in citation_pattern.finditer(text):
|
||||
# Group 2 is for double bracket single number, group 4 is for single bracket
|
||||
if match.group(2):
|
||||
nums_str = match.group(2)
|
||||
elif match.group(4):
|
||||
nums_str = match.group(4)
|
||||
else:
|
||||
continue
|
||||
|
||||
for num_str in nums_str.split(","):
|
||||
num_str = num_str.strip()
|
||||
if num_str:
|
||||
try:
|
||||
num = int(num_str)
|
||||
if num not in seen:
|
||||
seen.add(num)
|
||||
order.append(num)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return order
|
||||
|
||||
|
||||
def collapse_citations(
|
||||
answer_text: str,
|
||||
existing_citation_mapping: CitationMapping,
|
||||
|
||||
@@ -45,6 +45,7 @@ from onyx.tools.tool_implementations.images.models import (
|
||||
FinalImageGenerationResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
from onyx.tools.tool_runner import run_tool_calls
|
||||
from onyx.tracing.framework.create import trace
|
||||
@@ -453,12 +454,16 @@ def run_llm_loop(
|
||||
|
||||
# The section below calculates the available tokens for history a bit more accurately
|
||||
# now that project files are loaded in.
|
||||
if persona and persona.replace_base_system_prompt and persona.system_prompt:
|
||||
if persona and persona.replace_base_system_prompt:
|
||||
# Handles the case where user has checked off the "Replace base system prompt" checkbox
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=persona.system_prompt,
|
||||
token_count=token_counter(persona.system_prompt),
|
||||
message_type=MessageType.SYSTEM,
|
||||
system_prompt = (
|
||||
ChatMessageSimple(
|
||||
message=persona.system_prompt,
|
||||
token_count=token_counter(persona.system_prompt),
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
if persona.system_prompt
|
||||
else None
|
||||
)
|
||||
custom_agent_prompt_msg = None
|
||||
else:
|
||||
@@ -612,6 +617,7 @@ def run_llm_loop(
|
||||
next_citation_num=citation_processor.get_next_citation_number(),
|
||||
max_concurrent_tools=None,
|
||||
skip_search_query_expansion=has_called_search_tool,
|
||||
url_snippet_map=extract_url_snippet_map(gathered_documents or []),
|
||||
)
|
||||
tool_responses = parallel_tool_call_results.tool_responses
|
||||
citation_mapping = parallel_tool_call_results.updated_citation_mapping
|
||||
@@ -650,8 +656,15 @@ def run_llm_loop(
|
||||
|
||||
# Extract search_docs if this is a search tool response
|
||||
search_docs = None
|
||||
displayed_docs = None
|
||||
if isinstance(tool_response.rich_response, SearchDocsResponse):
|
||||
search_docs = tool_response.rich_response.search_docs
|
||||
displayed_docs = tool_response.rich_response.displayed_docs
|
||||
|
||||
# Add ALL search docs to state container for DB persistence
|
||||
if search_docs:
|
||||
state_container.add_search_docs(search_docs)
|
||||
|
||||
if gathered_documents:
|
||||
gathered_documents.extend(search_docs)
|
||||
else:
|
||||
@@ -685,7 +698,7 @@ def run_llm_loop(
|
||||
reasoning_tokens=llm_step_result.reasoning, # All tool calls from this loop share the same reasoning
|
||||
tool_call_arguments=tool_call.tool_args,
|
||||
tool_call_response=saved_response,
|
||||
search_docs=search_docs,
|
||||
search_docs=displayed_docs or search_docs,
|
||||
generated_images=generated_images,
|
||||
)
|
||||
# Add to state container for partial save support
|
||||
|
||||
@@ -14,6 +14,7 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import ChatFileType
|
||||
@@ -432,7 +433,7 @@ def translate_history_to_llm_format(
|
||||
|
||||
for idx, msg in enumerate(history):
|
||||
# if the message is being added to the history
|
||||
if msg.message_type in [
|
||||
if PROMPT_CACHE_CHAT_HISTORY and msg.message_type in [
|
||||
MessageType.SYSTEM,
|
||||
MessageType.USER,
|
||||
MessageType.ASSISTANT,
|
||||
@@ -859,6 +860,11 @@ def run_llm_step_pkt_generator(
|
||||
),
|
||||
obj=result,
|
||||
)
|
||||
# Track emitted citation for saving
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(
|
||||
result.citation_number
|
||||
)
|
||||
else:
|
||||
# When citation_processor is None, use delta.content directly without modification
|
||||
accumulated_answer += delta.content
|
||||
@@ -985,6 +991,9 @@ def run_llm_step_pkt_generator(
|
||||
),
|
||||
obj=result,
|
||||
)
|
||||
# Track emitted citation for saving
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(result.citation_number)
|
||||
|
||||
# Note: Content (AgentResponseDelta) doesn't need an explicit end packet - OverallStop handles it
|
||||
# Tool calls are handled by tool execution code and emit their own packets (e.g., SectionEnd)
|
||||
|
||||
@@ -42,7 +42,6 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import CitationDocInfo
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
@@ -744,27 +743,16 @@ def llm_loop_completion_handle(
|
||||
else:
|
||||
final_answer = "The generation was stopped by the user."
|
||||
|
||||
# Build citation_docs_info from accumulated citations in state container
|
||||
citation_docs_info: list[CitationDocInfo] = []
|
||||
seen_citation_nums: set[int] = set()
|
||||
for citation_num, search_doc in state_container.citation_to_doc.items():
|
||||
if citation_num not in seen_citation_nums:
|
||||
seen_citation_nums.add(citation_num)
|
||||
citation_docs_info.append(
|
||||
CitationDocInfo(
|
||||
search_doc=search_doc,
|
||||
citation_number=citation_num,
|
||||
)
|
||||
)
|
||||
|
||||
save_chat_turn(
|
||||
message_text=final_answer,
|
||||
reasoning_tokens=state_container.reasoning_tokens,
|
||||
citation_docs_info=citation_docs_info,
|
||||
citation_to_doc=state_container.citation_to_doc,
|
||||
tool_calls=state_container.tool_calls,
|
||||
all_search_docs=state_container.get_all_search_docs(),
|
||||
db_session=db_session,
|
||||
assistant_message=assistant_message,
|
||||
is_clarification=state_container.is_clarification,
|
||||
emitted_citations=state_container.get_emitted_citations(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,8 +2,9 @@ import json
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_state import SearchDocKey
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import CitationDocInfo
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.chat import add_search_docs_to_chat_message
|
||||
from onyx.db.chat import add_search_docs_to_tool_call
|
||||
@@ -19,22 +20,6 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _create_search_doc_key(search_doc: SearchDoc) -> tuple[str, int, tuple[str, ...]]:
|
||||
"""
|
||||
Create a unique key for a SearchDoc that accounts for different versions of the same
|
||||
document/chunk with different match_highlights.
|
||||
|
||||
Args:
|
||||
search_doc: The SearchDoc pydantic model to create a key for
|
||||
|
||||
Returns:
|
||||
A tuple of (document_id, chunk_ind, sorted match_highlights) that uniquely identifies
|
||||
this specific version of the document
|
||||
"""
|
||||
match_highlights_tuple = tuple(sorted(search_doc.match_highlights or []))
|
||||
return (search_doc.document_id, search_doc.chunk_ind, match_highlights_tuple)
|
||||
|
||||
|
||||
def _create_and_link_tool_calls(
|
||||
tool_calls: list[ToolCallInfo],
|
||||
assistant_message: ChatMessage,
|
||||
@@ -154,38 +139,36 @@ def save_chat_turn(
|
||||
message_text: str,
|
||||
reasoning_tokens: str | None,
|
||||
tool_calls: list[ToolCallInfo],
|
||||
citation_docs_info: list[CitationDocInfo],
|
||||
citation_to_doc: dict[int, SearchDoc],
|
||||
all_search_docs: dict[SearchDocKey, SearchDoc],
|
||||
db_session: Session,
|
||||
assistant_message: ChatMessage,
|
||||
is_clarification: bool = False,
|
||||
emitted_citations: set[int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Save a chat turn by populating the assistant_message and creating related entities.
|
||||
|
||||
This function:
|
||||
1. Updates the ChatMessage with text, reasoning tokens, and token count
|
||||
2. Creates SearchDoc entries from ToolCall search_docs (for tool calls that returned documents)
|
||||
3. Collects all unique SearchDocs from all tool calls and links them to ChatMessage
|
||||
4. Builds citation mapping from citation_docs_info
|
||||
5. Links all unique SearchDocs from tool calls to the ChatMessage
|
||||
2. Creates DB SearchDoc entries from pre-deduplicated all_search_docs
|
||||
3. Builds tool_call -> search_doc mapping for displayed docs
|
||||
4. Builds citation mapping from citation_to_doc
|
||||
5. Links all unique SearchDocs to the ChatMessage
|
||||
6. Creates ToolCall entries and links SearchDocs to them
|
||||
7. Builds the citations mapping for the ChatMessage
|
||||
|
||||
Deduplication Logic:
|
||||
- SearchDocs are deduplicated using (document_id, chunk_ind, match_highlights) as the key
|
||||
- This ensures that the same document/chunk with different match_highlights (from different
|
||||
queries) are stored as separate SearchDoc entries
|
||||
- Each ToolCall and ChatMessage will map to the correct version of the SearchDoc that
|
||||
matches its specific query highlights
|
||||
|
||||
Args:
|
||||
message_text: The message content to save
|
||||
reasoning_tokens: Optional reasoning tokens for the message
|
||||
tool_calls: List of tool call information to create ToolCall entries (may include search_docs)
|
||||
citation_docs_info: List of citation document information for building citations mapping
|
||||
citation_to_doc: Mapping from citation number to SearchDoc for building citations
|
||||
all_search_docs: Pre-deduplicated search docs from ChatStateContainer
|
||||
db_session: Database session for persistence
|
||||
assistant_message: The ChatMessage object to populate (should already exist in DB)
|
||||
is_clarification: Whether this assistant message is a clarification question (deep research flow)
|
||||
emitted_citations: Set of citation numbers that were actually emitted during streaming.
|
||||
If provided, only citations in this set will be saved; others are filtered out.
|
||||
"""
|
||||
# 1. Update ChatMessage with message content, reasoning tokens, and token count
|
||||
assistant_message.message = message_text
|
||||
@@ -200,53 +183,53 @@ def save_chat_turn(
|
||||
else:
|
||||
assistant_message.token_count = 0
|
||||
|
||||
# 2. Create SearchDoc entries from tool_calls
|
||||
# Build mapping from SearchDoc to DB SearchDoc ID
|
||||
# Use (document_id, chunk_ind, match_highlights) as key to avoid duplicates
|
||||
# while ensuring different versions with different highlights are stored separately
|
||||
search_doc_key_to_id: dict[tuple[str, int, tuple[str, ...]], int] = {}
|
||||
tool_call_to_search_doc_ids: dict[str, list[int]] = {}
|
||||
# 2. Create DB SearchDoc entries from pre-deduplicated all_search_docs
|
||||
search_doc_key_to_id: dict[SearchDocKey, int] = {}
|
||||
for key, search_doc_py in all_search_docs.items():
|
||||
db_search_doc = create_db_search_doc(
|
||||
server_search_doc=search_doc_py,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
search_doc_key_to_id[key] = db_search_doc.id
|
||||
|
||||
# Process tool calls and their search docs
|
||||
# 3. Build tool_call -> search_doc mapping (for displayed docs in each tool call)
|
||||
tool_call_to_search_doc_ids: dict[str, list[int]] = {}
|
||||
for tool_call_info in tool_calls:
|
||||
if tool_call_info.search_docs:
|
||||
search_doc_ids_for_tool: list[int] = []
|
||||
for search_doc_py in tool_call_info.search_docs:
|
||||
# Create a unique key for this SearchDoc version
|
||||
search_doc_key = _create_search_doc_key(search_doc_py)
|
||||
|
||||
# Check if we've already created this exact SearchDoc version
|
||||
if search_doc_key in search_doc_key_to_id:
|
||||
search_doc_ids_for_tool.append(search_doc_key_to_id[search_doc_key])
|
||||
key = ChatStateContainer.create_search_doc_key(search_doc_py)
|
||||
if key in search_doc_key_to_id:
|
||||
search_doc_ids_for_tool.append(search_doc_key_to_id[key])
|
||||
else:
|
||||
# Create new DB SearchDoc entry
|
||||
# Displayed doc not in all_search_docs - create it
|
||||
# This can happen if displayed_docs contains docs not in search_docs
|
||||
db_search_doc = create_db_search_doc(
|
||||
server_search_doc=search_doc_py,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
search_doc_key_to_id[search_doc_key] = db_search_doc.id
|
||||
search_doc_key_to_id[key] = db_search_doc.id
|
||||
search_doc_ids_for_tool.append(db_search_doc.id)
|
||||
|
||||
tool_call_to_search_doc_ids[tool_call_info.tool_call_id] = list(
|
||||
set(search_doc_ids_for_tool)
|
||||
)
|
||||
|
||||
# 3. Collect all unique SearchDoc IDs from all tool calls to link to ChatMessage
|
||||
# Use a set to deduplicate by ID (since we've already deduplicated by key above)
|
||||
all_search_doc_ids_set: set[int] = set()
|
||||
for search_doc_ids in tool_call_to_search_doc_ids.values():
|
||||
all_search_doc_ids_set.update(search_doc_ids)
|
||||
# Collect all search doc IDs for ChatMessage linking
|
||||
all_search_doc_ids_set: set[int] = set(search_doc_key_to_id.values())
|
||||
|
||||
# 4. Build citation mapping from citation_docs_info
|
||||
# 4. Build a citation mapping from the citation number to the saved DB SearchDoc ID
|
||||
# Only include citations that were actually emitted during streaming
|
||||
citation_number_to_search_doc_id: dict[int, int] = {}
|
||||
|
||||
for citation_doc_info in citation_docs_info:
|
||||
# Extract SearchDoc pydantic model
|
||||
search_doc_py = citation_doc_info.search_doc
|
||||
for citation_num, search_doc_py in citation_to_doc.items():
|
||||
# Skip citations that weren't actually emitted (if emitted_citations is provided)
|
||||
if emitted_citations is not None and citation_num not in emitted_citations:
|
||||
continue
|
||||
|
||||
# Create the unique key for this SearchDoc version
|
||||
search_doc_key = _create_search_doc_key(search_doc_py)
|
||||
search_doc_key = ChatStateContainer.create_search_doc_key(search_doc_py)
|
||||
|
||||
# Get the search doc ID (should already exist from processing tool_calls)
|
||||
if search_doc_key in search_doc_key_to_id:
|
||||
@@ -283,10 +266,7 @@ def save_chat_turn(
|
||||
all_search_doc_ids_set.add(db_search_doc_id)
|
||||
|
||||
# Build mapping from citation number to search doc ID
|
||||
if citation_doc_info.citation_number is not None:
|
||||
citation_number_to_search_doc_id[citation_doc_info.citation_number] = (
|
||||
db_search_doc_id
|
||||
)
|
||||
citation_number_to_search_doc_id[citation_num] = db_search_doc_id
|
||||
|
||||
# 5. Link all unique SearchDocs (from both tool calls and citations) to ChatMessage
|
||||
final_search_doc_ids: list[int] = list(all_search_doc_ids_set)
|
||||
@@ -306,23 +286,10 @@ def save_chat_turn(
|
||||
tool_call_to_search_doc_ids=tool_call_to_search_doc_ids,
|
||||
)
|
||||
|
||||
# 7. Build citations mapping from citation_docs_info
|
||||
# Any citation_doc_info with a citation_number appeared in the text and should be mapped
|
||||
citations: dict[int, int] = {}
|
||||
for citation_doc_info in citation_docs_info:
|
||||
if citation_doc_info.citation_number is not None:
|
||||
search_doc_id = citation_number_to_search_doc_id.get(
|
||||
citation_doc_info.citation_number
|
||||
)
|
||||
if search_doc_id is not None:
|
||||
citations[citation_doc_info.citation_number] = search_doc_id
|
||||
else:
|
||||
logger.warning(
|
||||
f"Citation number {citation_doc_info.citation_number} found in citation_docs_info "
|
||||
f"but no matching search doc ID in mapping"
|
||||
)
|
||||
|
||||
assistant_message.citations = citations if citations else None
|
||||
# 7. Build citations mapping - use the mapping we already built in step 4
|
||||
assistant_message.citations = (
|
||||
citation_number_to_search_doc_id if citation_number_to_search_doc_id else None
|
||||
)
|
||||
|
||||
# Finally save the messages, tool calls, and docs
|
||||
db_session.commit()
|
||||
|
||||
@@ -207,9 +207,23 @@ OPENSEARCH_HOST = os.environ.get("OPENSEARCH_HOST") or "localhost"
|
||||
OPENSEARCH_REST_API_PORT = int(os.environ.get("OPENSEARCH_REST_API_PORT") or 9200)
|
||||
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
|
||||
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
|
||||
USING_AWS_MANAGED_OPENSEARCH = (
|
||||
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
|
||||
)
|
||||
|
||||
ENABLE_OPENSEARCH_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_FOR_ONYX", "").lower() == "true"
|
||||
# This is the "base" config for now, the idea is that at least for our dev
|
||||
# environments we always want to be dual indexing into both OpenSearch and Vespa
|
||||
# to stress test the new codepaths. Only enable this if there is some instance
|
||||
# of OpenSearch running for the relevant Onyx instance.
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
# Given that the "base" config above is true, this enables whether we want to
|
||||
# retrieve from OpenSearch or Vespa. We want to be able to quickly toggle this
|
||||
# in the event we see issues with OpenSearch retrieval in our dev environments.
|
||||
ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
@@ -399,7 +413,7 @@ CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
|
||||
os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
|
||||
)
|
||||
|
||||
# Consolidated background worker (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
|
||||
# Consolidated background worker (light, docprocessing, docfetching, heavy, monitoring, user_file_processing)
|
||||
# separate workers' defaults: light=24, docprocessing=6, docfetching=1, heavy=4, kg=2, monitoring=1, user_file=2
|
||||
# Total would be 40, but we use a more conservative default of 20 for the consolidated worker
|
||||
CELERY_WORKER_BACKGROUND_CONCURRENCY = int(
|
||||
@@ -411,10 +425,6 @@ CELERY_WORKER_HEAVY_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
|
||||
)
|
||||
|
||||
CELERY_WORKER_KG_PROCESSING_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_KG_PROCESSING_CONCURRENCY") or 2
|
||||
)
|
||||
|
||||
CELERY_WORKER_MONITORING_CONCURRENCY = int(
|
||||
os.environ.get("CELERY_WORKER_MONITORING_CONCURRENCY") or 1
|
||||
)
|
||||
@@ -738,6 +748,10 @@ JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
|
||||
LOG_ONYX_MODEL_INTERACTIONS = (
|
||||
os.environ.get("LOG_ONYX_MODEL_INTERACTIONS", "").lower() == "true"
|
||||
)
|
||||
|
||||
PROMPT_CACHE_CHAT_HISTORY = (
|
||||
os.environ.get("PROMPT_CACHE_CHAT_HISTORY", "").lower() == "true"
|
||||
)
|
||||
# If set to `true` will enable additional logs about Vespa query performance
|
||||
# (time spent on finding the right docs + time spent fetching summaries from disk)
|
||||
LOG_VESPA_TIMING_INFORMATION = (
|
||||
@@ -1016,3 +1030,25 @@ INSTANCE_TYPE = (
|
||||
## Discord Bot Configuration
|
||||
DISCORD_BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
||||
DISCORD_BOT_INVOKE_CHAR = os.environ.get("DISCORD_BOT_INVOKE_CHAR", "!")
|
||||
|
||||
|
||||
## Stripe Configuration
|
||||
# URL to fetch the Stripe publishable key from a public S3 bucket.
|
||||
# Publishable keys are safe to expose publicly - they can only initialize
|
||||
# Stripe.js and tokenize payment info, not make charges or access data.
|
||||
STRIPE_PUBLISHABLE_KEY_URL = (
|
||||
"https://onyx-stripe-public.s3.amazonaws.com/publishable-key.txt"
|
||||
)
|
||||
# Override for local testing with Stripe test keys (pk_test_*)
|
||||
STRIPE_PUBLISHABLE_KEY_OVERRIDE = os.environ.get("STRIPE_PUBLISHABLE_KEY")
|
||||
# Persistent Document Storage Configuration
|
||||
# When enabled, indexed documents are written to local filesystem with hierarchical structure
|
||||
PERSISTENT_DOCUMENT_STORAGE_ENABLED = (
|
||||
os.environ.get("PERSISTENT_DOCUMENT_STORAGE_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Base directory path for persistent document storage (local filesystem)
|
||||
# Example: /var/onyx/indexed-docs or /app/indexed-docs
|
||||
PERSISTENT_DOCUMENT_STORAGE_PATH = os.environ.get(
|
||||
"PERSISTENT_DOCUMENT_STORAGE_PATH", "/app/indexed-docs"
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
|
||||
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
|
||||
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
|
||||
PERSONAS_YAML = "./onyx/seeding/personas.yaml"
|
||||
NUM_RETURNED_HITS = 50
|
||||
|
||||
@@ -80,7 +80,6 @@ POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching"
|
||||
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME = "celery_worker_background"
|
||||
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
|
||||
POSTGRES_CELERY_WORKER_KG_PROCESSING_APP_NAME = "celery_worker_kg_processing"
|
||||
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
|
||||
POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = (
|
||||
"celery_worker_user_file_processing"
|
||||
@@ -241,6 +240,7 @@ class NotificationType(str, Enum):
|
||||
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
|
||||
RELEASE_NOTES = "release_notes"
|
||||
ASSISTANT_FILES_READY = "assistant_files_ready"
|
||||
FEATURE_ANNOUNCEMENT = "feature_announcement"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
@@ -327,6 +327,7 @@ class FileOrigin(str, Enum):
|
||||
PLAINTEXT_CACHE = "plaintext_cache"
|
||||
OTHER = "other"
|
||||
QUERY_HISTORY_CSV = "query_history_csv"
|
||||
SANDBOX_SNAPSHOT = "sandbox_snapshot"
|
||||
USER_FILE = "user_file"
|
||||
|
||||
|
||||
@@ -344,6 +345,7 @@ class MilestoneRecordType(str, Enum):
|
||||
MULTIPLE_ASSISTANTS = "multiple_assistants"
|
||||
CREATED_ASSISTANT = "created_assistant"
|
||||
CREATED_ONYX_BOT = "created_onyx_bot"
|
||||
REQUESTED_CONNECTOR = "requested_connector"
|
||||
|
||||
|
||||
class PostgresAdvisoryLocks(Enum):
|
||||
@@ -367,6 +369,7 @@ class OnyxCeleryQueues:
|
||||
CONNECTOR_PRUNING = "connector_pruning"
|
||||
CONNECTOR_DOC_PERMISSIONS_SYNC = "connector_doc_permissions_sync"
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC = "connector_external_group_sync"
|
||||
CONNECTOR_HIERARCHY_FETCHING = "connector_hierarchy_fetching"
|
||||
CSV_GENERATION = "csv_generation"
|
||||
|
||||
# User file processing queue
|
||||
@@ -380,8 +383,8 @@ class OnyxCeleryQueues:
|
||||
# Monitoring queue
|
||||
MONITORING = "monitoring"
|
||||
|
||||
# KG processing queue
|
||||
KG_PROCESSING = "kg_processing"
|
||||
# Sandbox processing queue
|
||||
SANDBOX = "sandbox"
|
||||
|
||||
|
||||
class OnyxRedisLocks:
|
||||
@@ -389,6 +392,7 @@ class OnyxRedisLocks:
|
||||
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
|
||||
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
|
||||
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
|
||||
CHECK_HIERARCHY_FETCHING_BEAT_LOCK = "da_lock:check_hierarchy_fetching_beat"
|
||||
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
|
||||
CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK = "da_lock:check_checkpoint_cleanup_beat"
|
||||
CHECK_INDEX_ATTEMPT_CLEANUP_BEAT_LOCK = "da_lock:check_index_attempt_cleanup_beat"
|
||||
@@ -417,9 +421,6 @@ class OnyxRedisLocks:
|
||||
CLOUD_BEAT_TASK_GENERATOR_LOCK = "da_lock:cloud_beat_task_generator"
|
||||
CLOUD_CHECK_ALEMBIC_BEAT_LOCK = "da_lock:cloud_check_alembic"
|
||||
|
||||
# KG processing
|
||||
KG_PROCESSING_LOCK = "da_lock:kg_processing"
|
||||
|
||||
# User file processing
|
||||
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
|
||||
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
|
||||
@@ -431,6 +432,10 @@ class OnyxRedisLocks:
|
||||
# Release notes
|
||||
RELEASE_NOTES_FETCH_LOCK = "da_lock:release_notes_fetch"
|
||||
|
||||
# Sandbox cleanup
|
||||
CLEANUP_IDLE_SANDBOXES_BEAT_LOCK = "da_lock:cleanup_idle_sandboxes_beat"
|
||||
CLEANUP_OLD_SNAPSHOTS_BEAT_LOCK = "da_lock:cleanup_old_snapshots_beat"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
BLOCK_VALIDATE_INDEXING_FENCES = "signal:block_validate_indexing_fences"
|
||||
@@ -447,9 +452,6 @@ class OnyxRedisSignals:
|
||||
"signal:block_validate_connector_deletion_fences"
|
||||
)
|
||||
|
||||
# KG processing
|
||||
CHECK_KG_PROCESSING_BEAT_LOCK = "da_lock:check_kg_processing_beat"
|
||||
|
||||
|
||||
class OnyxRedisConstants:
|
||||
ACTIVE_FENCES = "active_fences"
|
||||
@@ -493,6 +495,7 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
CHECK_FOR_INDEXING = "check_for_indexing"
|
||||
CHECK_FOR_PRUNING = "check_for_pruning"
|
||||
CHECK_FOR_HIERARCHY_FETCHING = "check_for_hierarchy_fetching"
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_AUTO_LLM_UPDATE = "check_for_auto_llm_update"
|
||||
@@ -534,6 +537,7 @@ class OnyxCeleryTask:
|
||||
DOCPROCESSING_TASK = "docprocessing_task"
|
||||
|
||||
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
|
||||
CONNECTOR_HIERARCHY_FETCHING_TASK = "connector_hierarchy_fetching_task"
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
|
||||
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
|
||||
|
||||
@@ -549,12 +553,12 @@ class OnyxCeleryTask:
|
||||
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
|
||||
EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task"
|
||||
|
||||
# KG processing
|
||||
CHECK_KG_PROCESSING = "check_kg_processing"
|
||||
KG_PROCESSING = "kg_processing"
|
||||
KG_CLUSTERING_ONLY = "kg_clustering_only"
|
||||
CHECK_KG_PROCESSING_CLUSTERING_ONLY = "check_kg_processing_clustering_only"
|
||||
KG_RESET_SOURCE_INDEX = "kg_reset_source_index"
|
||||
# Sandbox cleanup
|
||||
CLEANUP_IDLE_SANDBOXES = "cleanup_idle_sandboxes"
|
||||
CLEANUP_OLD_SNAPSHOTS = "cleanup_old_snapshots"
|
||||
|
||||
# Sandbox file sync
|
||||
SANDBOX_FILE_SYNC = "sandbox_file_sync"
|
||||
|
||||
|
||||
# this needs to correspond to the matching entry in supervisord
|
||||
|
||||
@@ -17,6 +17,7 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
@@ -419,7 +420,7 @@ class AirtableConnector(LoadConnector):
|
||||
# Process records in parallel batches using ThreadPoolExecutor
|
||||
PARALLEL_BATCH_SIZE = 8
|
||||
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
|
||||
record_documents: list[Document] = []
|
||||
record_documents: list[Document | HierarchyNode] = []
|
||||
|
||||
# Process records in batches
|
||||
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -56,7 +57,7 @@ class AsanaConnector(LoadConnector, PollConnector):
|
||||
workspace_gid=self.workspace_id,
|
||||
team_gid=self.asana_team_id,
|
||||
)
|
||||
docs_batch: list[Document] = []
|
||||
docs_batch: list[Document | HierarchyNode] = []
|
||||
tasks = asana.get_tasks(self.project_ids_to_index, start_time)
|
||||
|
||||
for task in tasks:
|
||||
@@ -116,5 +117,8 @@ if __name__ == "__main__":
|
||||
latest_docs = connector.poll_source(one_day_ago, current)
|
||||
for docs in latest_docs:
|
||||
for doc in docs:
|
||||
print(doc.id)
|
||||
if isinstance(doc, HierarchyNode):
|
||||
print("hierarchynode:", doc.display_name)
|
||||
else:
|
||||
print(doc.id)
|
||||
logger.notice("Asana connector test completed")
|
||||
|
||||
@@ -30,6 +30,7 @@ from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -271,9 +272,9 @@ class BitbucketConnector(
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> Iterator[list[SlimDocument]]:
|
||||
) -> Iterator[list[SlimDocument | HierarchyNode]]:
|
||||
"""Return only document IDs for all existing pull requests."""
|
||||
batch: list[SlimDocument] = []
|
||||
batch: list[SlimDocument | HierarchyNode] = []
|
||||
params = self._build_params(
|
||||
fields=SLIM_PR_LIST_RESPONSE_FIELDS,
|
||||
start=start,
|
||||
|
||||
@@ -36,6 +36,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
@@ -377,7 +378,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
paginator = self.s3_client.get_paginator("list_objects_v2")
|
||||
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
|
||||
|
||||
batch: list[Document] = []
|
||||
batch: list[Document | HierarchyNode] = []
|
||||
for page in pages:
|
||||
if "Contents" not in page:
|
||||
continue
|
||||
@@ -616,6 +617,10 @@ if __name__ == "__main__":
|
||||
for document_batch in document_batch_generator:
|
||||
print("First batch of documents:")
|
||||
for doc in document_batch:
|
||||
if isinstance(doc, HierarchyNode):
|
||||
print("hierarchynode:", doc.display_name)
|
||||
continue
|
||||
|
||||
print(f"Document ID: {doc.id}")
|
||||
print(f"Semantic Identifier: {doc.semantic_identifier}")
|
||||
print(f"Source: {doc.source}")
|
||||
|
||||
@@ -18,6 +18,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
|
||||
@@ -47,7 +48,7 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
start_ind: int,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> tuple[list[Document], int]:
|
||||
) -> tuple[list[Document | HierarchyNode], int]:
|
||||
params = {
|
||||
"count": str(batch_size),
|
||||
"offset": str(start_ind),
|
||||
@@ -65,7 +66,9 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
batch = bookstack_client.get(endpoint, params=params).get("data", [])
|
||||
doc_batch = [transformer(bookstack_client, item) for item in batch]
|
||||
doc_batch: list[Document | HierarchyNode] = [
|
||||
transformer(bookstack_client, item) for item in batch
|
||||
]
|
||||
|
||||
return doc_batch, len(batch)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
@@ -80,7 +81,7 @@ class ClickupConnector(LoadConnector, PollConnector):
|
||||
start: int | None = None,
|
||||
end: int | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch: list[Document] = []
|
||||
doc_batch: list[Document | HierarchyNode] = []
|
||||
page: int = 0
|
||||
params = {
|
||||
"include_markdown_description": "true",
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import copy
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -46,9 +47,11 @@ from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -62,6 +65,7 @@ _PAGE_EXPANSION_FIELDS = [
|
||||
"space",
|
||||
"metadata.labels",
|
||||
"history.lastUpdated",
|
||||
"ancestors", # For hierarchy node tracking
|
||||
]
|
||||
_ATTACHMENT_EXPANSION_FIELDS = [
|
||||
"version",
|
||||
@@ -133,6 +137,9 @@ class ConfluenceConnector(
|
||||
self._fetched_titles: set[str] = set()
|
||||
self.allow_images = False
|
||||
|
||||
# Track hierarchy nodes we've already yielded to avoid duplicates
|
||||
self.seen_hierarchy_node_raw_ids: set[str] = set()
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
"""
|
||||
@@ -183,6 +190,139 @@ class ConfluenceConnector(
|
||||
logger.info(f"Setting allow_images to {value}.")
|
||||
self.allow_images = value
|
||||
|
||||
def _yield_space_hierarchy_nodes(
|
||||
self,
|
||||
) -> Generator[HierarchyNode, None, None]:
|
||||
"""Yield hierarchy nodes for all spaces we're indexing."""
|
||||
space_keys = [self.space] if self.space else None
|
||||
|
||||
for space in self.confluence_client.retrieve_confluence_spaces(
|
||||
space_keys=space_keys,
|
||||
limit=50,
|
||||
):
|
||||
space_key = space.get("key")
|
||||
if not space_key or space_key in self.seen_hierarchy_node_raw_ids:
|
||||
continue
|
||||
|
||||
self.seen_hierarchy_node_raw_ids.add(space_key)
|
||||
|
||||
# Build space link
|
||||
space_link = f"{self.wiki_base}/spaces/{space_key}"
|
||||
|
||||
yield HierarchyNode(
|
||||
raw_node_id=space_key,
|
||||
raw_parent_id=None, # Parent is SOURCE
|
||||
display_name=space.get("name", space_key),
|
||||
link=space_link,
|
||||
node_type=HierarchyNodeType.SPACE,
|
||||
)
|
||||
|
||||
def _yield_ancestor_hierarchy_nodes(
|
||||
self,
|
||||
page: dict[str, Any],
|
||||
) -> Generator[HierarchyNode, None, None]:
|
||||
"""Yield hierarchy nodes for all unseen ancestors of this page.
|
||||
|
||||
Any page that appears as an ancestor of another page IS a hierarchy node
|
||||
(it has at least one child - the page we're currently processing).
|
||||
|
||||
This ensures parent nodes are always yielded before child documents.
|
||||
"""
|
||||
ancestors = page.get("ancestors", [])
|
||||
space_key = page.get("space", {}).get("key")
|
||||
|
||||
# Ensure space is yielded first (if not already)
|
||||
if space_key and space_key not in self.seen_hierarchy_node_raw_ids:
|
||||
self.seen_hierarchy_node_raw_ids.add(space_key)
|
||||
space = page.get("space", {})
|
||||
yield HierarchyNode(
|
||||
raw_node_id=space_key,
|
||||
raw_parent_id=None, # Parent is SOURCE
|
||||
display_name=space.get("name", space_key),
|
||||
link=f"{self.wiki_base}/spaces/{space_key}",
|
||||
node_type=HierarchyNodeType.SPACE,
|
||||
)
|
||||
|
||||
# Walk through ancestors (root to immediate parent)
|
||||
for i, ancestor in enumerate(ancestors):
|
||||
ancestor_id = str(ancestor.get("id"))
|
||||
if ancestor_id in self.seen_hierarchy_node_raw_ids:
|
||||
continue
|
||||
|
||||
self.seen_hierarchy_node_raw_ids.add(ancestor_id)
|
||||
|
||||
# Determine parent of this ancestor
|
||||
if i == 0:
|
||||
# First ancestor - parent is the space
|
||||
parent_raw_id = space_key
|
||||
else:
|
||||
# Parent is the previous ancestor
|
||||
parent_raw_id = str(ancestors[i - 1].get("id"))
|
||||
|
||||
# Build link from ancestor's _links
|
||||
ancestor_link = None
|
||||
if "_links" in ancestor and "webui" in ancestor["_links"]:
|
||||
ancestor_link = build_confluence_document_id(
|
||||
self.wiki_base, ancestor["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
yield HierarchyNode(
|
||||
raw_node_id=ancestor_id,
|
||||
raw_parent_id=parent_raw_id,
|
||||
display_name=ancestor.get("title", f"Page {ancestor_id}"),
|
||||
link=ancestor_link,
|
||||
node_type=HierarchyNodeType.PAGE,
|
||||
)
|
||||
|
||||
def _get_parent_hierarchy_raw_id(self, page: dict[str, Any]) -> str | None:
|
||||
"""Get the raw hierarchy node ID of this page's parent.
|
||||
|
||||
Returns:
|
||||
- Parent page ID if page has a parent page (last item in ancestors)
|
||||
- Space key if page is at top level of space
|
||||
- None if we can't determine
|
||||
"""
|
||||
ancestors = page.get("ancestors", [])
|
||||
if ancestors:
|
||||
# Last ancestor is the immediate parent page
|
||||
return str(ancestors[-1].get("id"))
|
||||
|
||||
# Top-level page - parent is the space
|
||||
return page.get("space", {}).get("key")
|
||||
|
||||
def _maybe_yield_page_hierarchy_node(
|
||||
self, page: dict[str, Any]
|
||||
) -> HierarchyNode | None:
|
||||
"""Yield a hierarchy node for this page if not already yielded.
|
||||
|
||||
Used when a page has attachments - attachments are children of the page
|
||||
in the hierarchy, so the page must be a hierarchy node.
|
||||
"""
|
||||
page_id = _get_page_id(page)
|
||||
if page_id in self.seen_hierarchy_node_raw_ids:
|
||||
return None
|
||||
|
||||
self.seen_hierarchy_node_raw_ids.add(page_id)
|
||||
|
||||
# Build page link
|
||||
page_link = None
|
||||
if "_links" in page and "webui" in page["_links"]:
|
||||
page_link = build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
# Get parent hierarchy ID
|
||||
parent_raw_id = self._get_parent_hierarchy_raw_id(page)
|
||||
|
||||
return HierarchyNode(
|
||||
raw_node_id=page_id,
|
||||
raw_parent_id=parent_raw_id,
|
||||
display_name=page.get("title", f"Page {page_id}"),
|
||||
link=page_link,
|
||||
node_type=HierarchyNodeType.PAGE,
|
||||
document_id=page_link, # Page is also a document
|
||||
)
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
if self._confluence_client is None:
|
||||
@@ -354,6 +494,9 @@ class ConfluenceConnector(
|
||||
BasicExpertInfo(display_name=display_name, email=email)
|
||||
)
|
||||
|
||||
# Determine parent hierarchy node
|
||||
parent_hierarchy_raw_node_id = self._get_parent_hierarchy_raw_id(page)
|
||||
|
||||
# Create the document
|
||||
return Document(
|
||||
id=page_url,
|
||||
@@ -363,6 +506,7 @@ class ConfluenceConnector(
|
||||
metadata=metadata,
|
||||
doc_updated_at=datetime_from_string(page["version"]["when"]),
|
||||
primary_owners=primary_owners if primary_owners else None,
|
||||
parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting page {page.get('id', 'unknown')}: {e}")
|
||||
@@ -382,18 +526,22 @@ class ConfluenceConnector(
|
||||
page: dict[str, Any],
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> tuple[list[Document], list[ConnectorFailure]]:
|
||||
) -> tuple[list[Document | HierarchyNode], list[ConnectorFailure]]:
|
||||
"""
|
||||
Inline attachments are added directly to the document as text or image sections by
|
||||
this function. The returned documents/connectorfailures are for non-inline attachments
|
||||
and those at the end of the page.
|
||||
|
||||
If there are valid attachments, the page itself is yielded as a hierarchy node
|
||||
(since attachments are children of the page in the hierarchy).
|
||||
"""
|
||||
attachment_query = self._construct_attachment_query(
|
||||
_get_page_id(page), start, end
|
||||
)
|
||||
attachment_failures: list[ConnectorFailure] = []
|
||||
attachment_docs: list[Document] = []
|
||||
attachment_docs: list[Document | HierarchyNode] = []
|
||||
page_url = ""
|
||||
page_hierarchy_node_yielded = False
|
||||
|
||||
try:
|
||||
for attachment in self.confluence_client.paginated_cql_retrieval(
|
||||
@@ -487,6 +635,9 @@ class ConfluenceConnector(
|
||||
BasicExpertInfo(display_name=display_name, email=email)
|
||||
]
|
||||
|
||||
# Attachments have their parent page as the hierarchy parent
|
||||
attachment_parent_hierarchy_raw_id = _get_page_id(page)
|
||||
|
||||
attachment_doc = Document(
|
||||
id=attachment_id,
|
||||
sections=sections,
|
||||
@@ -500,7 +651,19 @@ class ConfluenceConnector(
|
||||
else None
|
||||
),
|
||||
primary_owners=primary_owners,
|
||||
parent_hierarchy_raw_node_id=attachment_parent_hierarchy_raw_id,
|
||||
)
|
||||
|
||||
# If this is the first valid attachment, yield the page as a
|
||||
# hierarchy node (attachments are children of the page)
|
||||
if not page_hierarchy_node_yielded:
|
||||
page_hierarchy_node = self._maybe_yield_page_hierarchy_node(
|
||||
page
|
||||
)
|
||||
if page_hierarchy_node:
|
||||
attachment_docs.append(page_hierarchy_node)
|
||||
page_hierarchy_node_yielded = True
|
||||
|
||||
attachment_docs.append(attachment_doc)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -568,7 +731,8 @@ class ConfluenceConnector(
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> CheckpointOutput[ConfluenceCheckpoint]:
|
||||
"""
|
||||
Yields batches of Documents. For each page:
|
||||
Yields batches of Documents and HierarchyNodes. For each page:
|
||||
- Yield hierarchy nodes for spaces and ancestor pages (parent-before-child ordering)
|
||||
- Create a Document with 1 Section for the page text/comments
|
||||
- Then fetch attachments. For each attachment:
|
||||
- Attempt to convert it with convert_attachment_to_content(...)
|
||||
@@ -576,6 +740,10 @@ class ConfluenceConnector(
|
||||
"""
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
|
||||
# Yield space hierarchy nodes FIRST (only once per connector run)
|
||||
if not checkpoint.next_page_url:
|
||||
yield from self._yield_space_hierarchy_nodes()
|
||||
|
||||
# use "start" when last_updated is 0 or for confluence server
|
||||
start_ts = start
|
||||
page_query_url = checkpoint.next_page_url or self._build_page_retrieval_url(
|
||||
@@ -592,6 +760,9 @@ class ConfluenceConnector(
|
||||
limit=self.batch_size,
|
||||
next_page_callback=store_next_page_url,
|
||||
):
|
||||
# Yield hierarchy nodes for all ancestors (parent-before-child ordering)
|
||||
yield from self._yield_ancestor_hierarchy_nodes(page)
|
||||
|
||||
# Build doc from page
|
||||
doc_or_failure = self._convert_page_to_document(page)
|
||||
|
||||
@@ -700,7 +871,7 @@ class ConfluenceConnector(
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
include_permissions: bool = True,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
doc_metadata_list: list[SlimDocument | HierarchyNode] = []
|
||||
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
|
||||
|
||||
space_level_access_info: dict[str, ExternalAccess] = {}
|
||||
|
||||
@@ -14,6 +14,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -30,15 +31,16 @@ def batched_doc_ids(
|
||||
batch_size: int,
|
||||
) -> Generator[set[str], None, None]:
|
||||
batch: set[str] = set()
|
||||
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
|
||||
checkpoint_connector_generator
|
||||
):
|
||||
for document, hierarchy_node, failure, next_checkpoint in CheckpointOutputWrapper[
|
||||
CT
|
||||
]()(checkpoint_connector_generator):
|
||||
if document is not None:
|
||||
batch.add(document.id)
|
||||
elif (
|
||||
failure and failure.failed_document and failure.failed_document.document_id
|
||||
):
|
||||
batch.add(failure.failed_document.document_id)
|
||||
# HierarchyNodes don't have IDs that need to be batched for doc processing
|
||||
|
||||
if len(batch) >= batch_size:
|
||||
yield batch
|
||||
@@ -63,7 +65,9 @@ class CheckpointOutputWrapper(Generic[CT]):
|
||||
self,
|
||||
checkpoint_connector_generator: CheckpointOutput[CT],
|
||||
) -> Generator[
|
||||
tuple[Document | None, ConnectorFailure | None, CT | None],
|
||||
tuple[
|
||||
Document | None, HierarchyNode | None, ConnectorFailure | None, CT | None
|
||||
],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
@@ -74,22 +78,22 @@ class CheckpointOutputWrapper(Generic[CT]):
|
||||
self.next_checkpoint = yield from checkpoint_connector_generator
|
||||
return self.next_checkpoint # not used
|
||||
|
||||
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
|
||||
if isinstance(document_or_failure, Document):
|
||||
yield document_or_failure, None, None
|
||||
elif isinstance(document_or_failure, ConnectorFailure):
|
||||
yield None, document_or_failure, None
|
||||
for item in _inner_wrapper(checkpoint_connector_generator):
|
||||
if isinstance(item, Document):
|
||||
yield item, None, None, None
|
||||
elif isinstance(item, HierarchyNode):
|
||||
yield None, item, None, None
|
||||
elif isinstance(item, ConnectorFailure):
|
||||
yield None, None, item, None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid document_or_failure type: {type(document_or_failure)}"
|
||||
)
|
||||
raise ValueError(f"Invalid connector output type: {type(item)}")
|
||||
|
||||
if self.next_checkpoint is None:
|
||||
raise RuntimeError(
|
||||
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
|
||||
)
|
||||
|
||||
yield None, None, self.next_checkpoint
|
||||
yield None, None, None, self.next_checkpoint
|
||||
|
||||
|
||||
class ConnectorRunner(Generic[CT]):
|
||||
@@ -119,13 +123,27 @@ class ConnectorRunner(Generic[CT]):
|
||||
self.include_permissions = include_permissions
|
||||
|
||||
self.doc_batch: list[Document] = []
|
||||
self.hierarchy_node_batch: list[HierarchyNode] = []
|
||||
|
||||
def run(self, checkpoint: CT) -> Generator[
|
||||
tuple[list[Document] | None, ConnectorFailure | None, CT | None],
|
||||
tuple[
|
||||
list[Document] | None,
|
||||
list[HierarchyNode] | None,
|
||||
ConnectorFailure | None,
|
||||
CT | None,
|
||||
],
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
"""Adds additional exception logging to the connector."""
|
||||
"""
|
||||
Yields batches of Documents, HierarchyNodes, failures, and checkpoints.
|
||||
|
||||
Returns tuples of:
|
||||
- (doc_batch, None, None, None) - batch of documents
|
||||
- (None, hierarchy_batch, None, None) - batch of hierarchy nodes
|
||||
- (None, None, failure, None) - a connector failure
|
||||
- (None, None, None, checkpoint) - new checkpoint
|
||||
"""
|
||||
try:
|
||||
if isinstance(self.connector, CheckpointedConnector):
|
||||
if self.time_range is None:
|
||||
@@ -151,25 +169,47 @@ class ConnectorRunner(Generic[CT]):
|
||||
)
|
||||
next_checkpoint: CT | None = None
|
||||
# this is guaranteed to always run at least once with next_checkpoint being non-None
|
||||
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
|
||||
checkpoint_connector_generator
|
||||
):
|
||||
if document is not None and isinstance(document, Document):
|
||||
for (
|
||||
document,
|
||||
hierarchy_node,
|
||||
failure,
|
||||
next_checkpoint,
|
||||
) in CheckpointOutputWrapper[CT]()(checkpoint_connector_generator):
|
||||
if document is not None:
|
||||
self.doc_batch.append(document)
|
||||
|
||||
if failure is not None:
|
||||
yield None, failure, None
|
||||
if hierarchy_node is not None:
|
||||
self.hierarchy_node_batch.append(hierarchy_node)
|
||||
|
||||
if failure is not None:
|
||||
yield None, None, failure, None
|
||||
|
||||
# Yield hierarchy nodes batch if it reaches batch_size
|
||||
# (yield nodes before docs to maintain parent-before-child invariant)
|
||||
if len(self.hierarchy_node_batch) >= self.batch_size:
|
||||
yield None, self.hierarchy_node_batch, None, None
|
||||
self.hierarchy_node_batch = []
|
||||
|
||||
# Yield document batch if it reaches batch_size
|
||||
# First flush any pending hierarchy nodes to ensure parents exist
|
||||
if len(self.doc_batch) >= self.batch_size:
|
||||
yield self.doc_batch, None, None
|
||||
if len(self.hierarchy_node_batch) > 0:
|
||||
yield None, self.hierarchy_node_batch, None, None
|
||||
self.hierarchy_node_batch = []
|
||||
yield self.doc_batch, None, None, None
|
||||
self.doc_batch = []
|
||||
|
||||
# yield remaining hierarchy nodes first (parents before children)
|
||||
if len(self.hierarchy_node_batch) > 0:
|
||||
yield None, self.hierarchy_node_batch, None, None
|
||||
self.hierarchy_node_batch = []
|
||||
|
||||
# yield remaining documents
|
||||
if len(self.doc_batch) > 0:
|
||||
yield self.doc_batch, None, None
|
||||
yield self.doc_batch, None, None, None
|
||||
self.doc_batch = []
|
||||
|
||||
yield None, None, next_checkpoint
|
||||
yield None, None, None, next_checkpoint
|
||||
|
||||
logger.debug(
|
||||
f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint."
|
||||
@@ -183,18 +223,26 @@ class ConnectorRunner(Generic[CT]):
|
||||
if self.time_range is None:
|
||||
raise ValueError("time_range is required for PollConnector")
|
||||
|
||||
for document_batch in self.connector.poll_source(
|
||||
for batch in self.connector.poll_source(
|
||||
start=self.time_range[0].timestamp(),
|
||||
end=self.time_range[1].timestamp(),
|
||||
):
|
||||
yield document_batch, None, None
|
||||
docs, nodes = self._separate_batch(batch)
|
||||
if nodes:
|
||||
yield None, nodes, None, None
|
||||
if docs:
|
||||
yield docs, None, None, None
|
||||
|
||||
yield None, None, finished_checkpoint
|
||||
yield None, None, None, finished_checkpoint
|
||||
elif isinstance(self.connector, LoadConnector):
|
||||
for document_batch in self.connector.load_from_state():
|
||||
yield document_batch, None, None
|
||||
for batch in self.connector.load_from_state():
|
||||
docs, nodes = self._separate_batch(batch)
|
||||
if nodes:
|
||||
yield None, nodes, None, None
|
||||
if docs:
|
||||
yield docs, None, None, None
|
||||
|
||||
yield None, None, finished_checkpoint
|
||||
yield None, None, None, finished_checkpoint
|
||||
else:
|
||||
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
|
||||
except Exception:
|
||||
@@ -219,3 +267,16 @@ class ConnectorRunner(Generic[CT]):
|
||||
f"local_vars below -> \n{local_vars_str[:1024]}"
|
||||
)
|
||||
raise
|
||||
|
||||
def _separate_batch(
|
||||
self, batch: list[Document | HierarchyNode]
|
||||
) -> tuple[list[Document], list[HierarchyNode]]:
|
||||
"""Separate a mixed batch into Documents and HierarchyNodes."""
|
||||
docs: list[Document] = []
|
||||
nodes: list[HierarchyNode] = []
|
||||
for item in batch:
|
||||
if isinstance(item, Document):
|
||||
docs.append(item)
|
||||
elif isinstance(item, HierarchyNode):
|
||||
nodes.append(item)
|
||||
return docs, nodes
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -278,7 +279,7 @@ class DiscordConnector(PollConnector, LoadConnector):
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch = []
|
||||
doc_batch: list[Document | HierarchyNode] = []
|
||||
for doc in _manage_async_retrieval(
|
||||
token=self.discord_bot_token,
|
||||
requested_start_date_string=self.requested_start_date_string,
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
@@ -193,7 +194,7 @@ class DiscourseConnector(PollConnector):
|
||||
) -> GenerateDocumentsOutput:
|
||||
page = 0
|
||||
while topic_ids := self._get_latest_topics(start, end, page):
|
||||
doc_batch: list[Document] = []
|
||||
doc_batch: list[Document | HierarchyNode] = []
|
||||
for topic_id in topic_ids:
|
||||
doc_batch.append(self._get_doc_from_topic(topic_id))
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
@@ -119,7 +120,7 @@ class Document360Connector(LoadConnector, PollConnector):
|
||||
workspace_id = self._get_workspace_id_by_name()
|
||||
articles = self._get_articles_with_category(workspace_id)
|
||||
|
||||
doc_batch: List[Document] = []
|
||||
doc_batch: List[Document | HierarchyNode] = []
|
||||
|
||||
for article in articles:
|
||||
article_details = self._make_request(
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -80,7 +81,7 @@ class DropboxConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
while True:
|
||||
batch: list[Document] = []
|
||||
batch: list[Document | HierarchyNode] = []
|
||||
for entry in result.entries:
|
||||
if isinstance(entry, FileMetadata):
|
||||
modified_time = entry.client_modified
|
||||
|
||||
@@ -30,6 +30,7 @@ from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
@@ -740,7 +741,7 @@ class DrupalWikiConnector(
|
||||
Returns:
|
||||
Generator yielding batches of SlimDocument objects.
|
||||
"""
|
||||
slim_docs: list[SlimDocument] = []
|
||||
slim_docs: list[SlimDocument | HierarchyNode] = []
|
||||
logger.info(
|
||||
f"Starting retrieve_all_slim_docs with include_all_spaces={self.include_all_spaces}, spaces={self.spaces}"
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import detect_encoding
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
@@ -278,8 +279,8 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
self,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
) -> Generator[list[Document], None, None]:
|
||||
current_batch: list[Document] = []
|
||||
) -> Generator[list[Document | HierarchyNode], None, None]:
|
||||
current_batch: list[Document | HierarchyNode] = []
|
||||
|
||||
# Iterate through yielded files and filter them
|
||||
for file in self._get_files_list(self.folder_path):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user