mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-16 23:35:46 +00:00
Compare commits
402 Commits
v2.6.0-bet
...
thread_sen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5848975679 | ||
|
|
dcc330010e | ||
|
|
d0f5f1f5ae | ||
|
|
3e475993ff | ||
|
|
7c2b5fa822 | ||
|
|
409cfdc788 | ||
|
|
7a9a132739 | ||
|
|
33bad8c37b | ||
|
|
9241ff7a75 | ||
|
|
0a25bc30ec | ||
|
|
e359732f4c | ||
|
|
be47866a4d | ||
|
|
8a20540559 | ||
|
|
e6e1f2860a | ||
|
|
fc3f433df7 | ||
|
|
016caf453b | ||
|
|
a9de25053f | ||
|
|
8ef8dfdeb7 | ||
|
|
0643b626d9 | ||
|
|
64a0eb52e0 | ||
|
|
b82ffc82cf | ||
|
|
b3014b9911 | ||
|
|
439707c395 | ||
|
|
65351aa8bd | ||
|
|
b44ee07eaf | ||
|
|
065d391c08 | ||
|
|
14fe3b375f | ||
|
|
bb1b96dded | ||
|
|
9f949ae2d9 | ||
|
|
975c0e8009 | ||
|
|
3dfb38c460 | ||
|
|
a1512a0485 | ||
|
|
8ea3bacd38 | ||
|
|
6b560b8162 | ||
|
|
3b750939ed | ||
|
|
bd4cb17a48 | ||
|
|
485cd9a311 | ||
|
|
2108c72353 | ||
|
|
98f43fb6ab | ||
|
|
e112ebb371 | ||
|
|
f88cbcfe27 | ||
|
|
0df0b10d3a | ||
|
|
ed0d12452a | ||
|
|
dc7cb80594 | ||
|
|
4312b24945 | ||
|
|
afd920bb33 | ||
|
|
d009b12aa7 | ||
|
|
596b3d9f3e | ||
|
|
1981c912b7 | ||
|
|
68b1bb8448 | ||
|
|
4676b5017f | ||
|
|
eb7b6a5ce1 | ||
|
|
87d6df2621 | ||
|
|
13b4108b53 | ||
|
|
13e806b625 | ||
|
|
f4f7839d84 | ||
|
|
2dbf1c3b1f | ||
|
|
288d4147c3 | ||
|
|
fee27b2274 | ||
|
|
340e938627 | ||
|
|
6faa47e0f7 | ||
|
|
ba6801f5af | ||
|
|
d7447eb8af | ||
|
|
196f890a68 | ||
|
|
3ac96572c3 | ||
|
|
3d8ae22b3a | ||
|
|
233d06ec0e | ||
|
|
9ff82ac740 | ||
|
|
b15f01fd78 | ||
|
|
6480cf6738 | ||
|
|
c521a4397a | ||
|
|
41a8d86df3 | ||
|
|
735cf926e4 | ||
|
|
035e73655f | ||
|
|
f317420f58 | ||
|
|
d50a84f2e4 | ||
|
|
9b441e3686 | ||
|
|
c4c1e16f19 | ||
|
|
9044e0f5fa | ||
|
|
a180e1337b | ||
|
|
6ca72291bc | ||
|
|
c23046f7c0 | ||
|
|
d5f66ac146 | ||
|
|
241fc8f877 | ||
|
|
f1ea41b519 | ||
|
|
ed3f72bc75 | ||
|
|
2247e3cf8e | ||
|
|
47c49d86e8 | ||
|
|
8c11330d46 | ||
|
|
22ac22c17d | ||
|
|
c0a6a0fb4a | ||
|
|
7f31a39dc2 | ||
|
|
f1f61690e3 | ||
|
|
8c3e17bbe5 | ||
|
|
a1ab3678a0 | ||
|
|
2d79ed7bb4 | ||
|
|
f472fd763e | ||
|
|
e47b2fccb4 | ||
|
|
17a6fc4ebf | ||
|
|
391c8c5cf7 | ||
|
|
d0e3ee1055 | ||
|
|
dc760cf580 | ||
|
|
d49931fce1 | ||
|
|
41d1d265a0 | ||
|
|
45a2207662 | ||
|
|
725ed6a523 | ||
|
|
2452671420 | ||
|
|
a4a767f146 | ||
|
|
8304fbd14c | ||
|
|
7db7d4c965 | ||
|
|
2cc2b5aee9 | ||
|
|
0c35ffe468 | ||
|
|
adece3f812 | ||
|
|
b44349e67d | ||
|
|
3134e5f840 | ||
|
|
5b8223b6af | ||
|
|
30ab85f5a0 | ||
|
|
daa343c30b | ||
|
|
c67936a4c1 | ||
|
|
4578c268ed | ||
|
|
7658917fe8 | ||
|
|
fd4695d5bd | ||
|
|
a25362a709 | ||
|
|
1eb4962861 | ||
|
|
aa1c956608 | ||
|
|
19e5c47f85 | ||
|
|
872a2ed58a | ||
|
|
42047a4dce | ||
|
|
a3a9847d76 | ||
|
|
3ade17c380 | ||
|
|
9150ba1905 | ||
|
|
cb14e84750 | ||
|
|
c916517342 | ||
|
|
45b902c950 | ||
|
|
981b43e47b | ||
|
|
b5c45cbce0 | ||
|
|
451f10343e | ||
|
|
ceeed2a562 | ||
|
|
bcc7a7f264 | ||
|
|
972ef34b92 | ||
|
|
9d11d1f218 | ||
|
|
4db68853cd | ||
|
|
b08fafc66b | ||
|
|
1e61bf401e | ||
|
|
0541c2989d | ||
|
|
743b996698 | ||
|
|
16e77aebfc | ||
|
|
944f4a2464 | ||
|
|
67db7c0346 | ||
|
|
8e47cd4e4f | ||
|
|
e8a4fca0a3 | ||
|
|
6d783ca691 | ||
|
|
283317bd65 | ||
|
|
2afbc74224 | ||
|
|
5b273de8be | ||
|
|
a0a24147b5 | ||
|
|
fd31da3159 | ||
|
|
cd76ac876b | ||
|
|
8f205172eb | ||
|
|
be70fa21e3 | ||
|
|
0687bddb6f | ||
|
|
73091118e3 | ||
|
|
bf8590a637 | ||
|
|
8a6d597496 | ||
|
|
f0bc538f60 | ||
|
|
0b6d9347bb | ||
|
|
415538f9f8 | ||
|
|
969261f314 | ||
|
|
eaa4d5d434 | ||
|
|
19e6900d96 | ||
|
|
f3535b94a0 | ||
|
|
383aa222ba | ||
|
|
f32b21400f | ||
|
|
5d5e71900e | ||
|
|
06ce7484b3 | ||
|
|
700db01b33 | ||
|
|
521e9f108f | ||
|
|
1dfb62bb69 | ||
|
|
14a1b3d197 | ||
|
|
f3feac84f3 | ||
|
|
d6e7c11c92 | ||
|
|
d66eef36d3 | ||
|
|
05fd974968 | ||
|
|
ad882e587d | ||
|
|
f2b1f20161 | ||
|
|
6ec3b4c6cf | ||
|
|
529a2e0336 | ||
|
|
35602519c5 | ||
|
|
7e0b773247 | ||
|
|
924b5e5c70 | ||
|
|
cfcb09070d | ||
|
|
27b0fee3c4 | ||
|
|
5617e86b14 | ||
|
|
b909eb0205 | ||
|
|
2a821134c0 | ||
|
|
ad632e4440 | ||
|
|
153e313021 | ||
|
|
abc80d7feb | ||
|
|
1a96e894fe | ||
|
|
5a09a73df8 | ||
|
|
02723291b3 | ||
|
|
324388fefc | ||
|
|
4a119e869b | ||
|
|
20127ba115 | ||
|
|
3d6344073d | ||
|
|
7dd98b717b | ||
|
|
0ce5667444 | ||
|
|
b03414e643 | ||
|
|
7a67de2d72 | ||
|
|
300bf58715 | ||
|
|
b2bd0ddc50 | ||
|
|
a3d847b05c | ||
|
|
d529d0672d | ||
|
|
f98a5e1119 | ||
|
|
6ec0b09139 | ||
|
|
53691fc95a | ||
|
|
3400e2a14d | ||
|
|
d8cc1f7a2c | ||
|
|
2098e910dd | ||
|
|
e5491d6f79 | ||
|
|
a8934a083a | ||
|
|
80e9507e01 | ||
|
|
60d3be5fe2 | ||
|
|
b481cc36d0 | ||
|
|
65c5da8912 | ||
|
|
0a0366e6ca | ||
|
|
84a623e884 | ||
|
|
6b91607b17 | ||
|
|
82fb737ad9 | ||
|
|
eed49e699e | ||
|
|
3cc7afd334 | ||
|
|
bcbfd28234 | ||
|
|
faa47d9691 | ||
|
|
6649561bf3 | ||
|
|
026cda0468 | ||
|
|
64297e5996 | ||
|
|
c517137c0a | ||
|
|
cbfbe0bbbe | ||
|
|
13ca4c6650 | ||
|
|
e8d9e36d62 | ||
|
|
77e4f3c574 | ||
|
|
2bdc06201a | ||
|
|
077ba9624c | ||
|
|
81eb1a1c7c | ||
|
|
1a16fef783 | ||
|
|
027692d5eb | ||
|
|
3a889f7069 | ||
|
|
20d67bd956 | ||
|
|
8d6b6accaf | ||
|
|
ed76b4eb55 | ||
|
|
7613c100d1 | ||
|
|
c52d3412de | ||
|
|
96b6162b52 | ||
|
|
502ed8909b | ||
|
|
8de75dd033 | ||
|
|
74e3668e38 | ||
|
|
2475a9ef92 | ||
|
|
690f54c441 | ||
|
|
71bb0c029e | ||
|
|
ccf890a129 | ||
|
|
a7bfdebddf | ||
|
|
6fc5ca12a3 | ||
|
|
8298452522 | ||
|
|
2559327636 | ||
|
|
ef185ce2c8 | ||
|
|
a04fee5cbd | ||
|
|
e507378244 | ||
|
|
e6be3f85b2 | ||
|
|
cc96e303ce | ||
|
|
e0fcb1f860 | ||
|
|
f5442c431d | ||
|
|
652e5848e5 | ||
|
|
3fa1896316 | ||
|
|
f855ecab11 | ||
|
|
fd26176e7d | ||
|
|
8986f67779 | ||
|
|
42f2d4aca5 | ||
|
|
7116d24a8c | ||
|
|
7f4593be32 | ||
|
|
f47e25e693 | ||
|
|
877184ae97 | ||
|
|
54961ec8ef | ||
|
|
e797971ce5 | ||
|
|
566cca70d8 | ||
|
|
be2d0e2b5d | ||
|
|
692f937ca4 | ||
|
|
11de1ceb65 | ||
|
|
19993b4679 | ||
|
|
9063827782 | ||
|
|
0cc6fa49d7 | ||
|
|
3f3508b668 | ||
|
|
1c3a88daf8 | ||
|
|
92f30bbad9 | ||
|
|
4abf43d85b | ||
|
|
b08f9adb23 | ||
|
|
7a915833bb | ||
|
|
9698b700e6 | ||
|
|
fd944acc5b | ||
|
|
a1309257f5 | ||
|
|
6266dc816d | ||
|
|
83c011a9e4 | ||
|
|
8d1ac81d09 | ||
|
|
d8cd4c9928 | ||
|
|
5caa4fdaa0 | ||
|
|
f22f33564b | ||
|
|
f86d282a47 | ||
|
|
ece1edb80f | ||
|
|
c9c17e19f3 | ||
|
|
40e834e0b8 | ||
|
|
45bd82d031 | ||
|
|
27c1619c3d | ||
|
|
8cfeb85c43 | ||
|
|
491b550ebc | ||
|
|
1a94dfd113 | ||
|
|
bcd9d7ae41 | ||
|
|
98b4353632 | ||
|
|
f071b280d4 | ||
|
|
f7ebaa42fc | ||
|
|
11737c2069 | ||
|
|
1712253e5f | ||
|
|
de8f292fce | ||
|
|
bbe5058131 | ||
|
|
45fc5e3c97 | ||
|
|
5c976815cc | ||
|
|
3ea4b6e6cc | ||
|
|
7b75c0049b | ||
|
|
04bdce55f4 | ||
|
|
2446b1898e | ||
|
|
6f22a2f656 | ||
|
|
e307a84863 | ||
|
|
2dd27f25cb | ||
|
|
e402c0e3b4 | ||
|
|
2721c8582a | ||
|
|
43c8b7a712 | ||
|
|
f473b85acd | ||
|
|
02cd84c39a | ||
|
|
46d17d6c64 | ||
|
|
10ad536491 | ||
|
|
ccabc1a7a7 | ||
|
|
8e262e4da8 | ||
|
|
79dea9d901 | ||
|
|
2f650bbef8 | ||
|
|
021e67ca71 | ||
|
|
87ae024280 | ||
|
|
5092429557 | ||
|
|
dc691199f5 | ||
|
|
1662c391f0 | ||
|
|
08aefbc115 | ||
|
|
fb6342daa9 | ||
|
|
4e7adcc9ee | ||
|
|
aa4b3d8a24 | ||
|
|
f3bc459b6e | ||
|
|
87cab60b01 | ||
|
|
08ab73caf8 | ||
|
|
675761c81e | ||
|
|
18e15c6da6 | ||
|
|
e1f77e2e17 | ||
|
|
4ef388b2dc | ||
|
|
031485232b | ||
|
|
c0debefaf6 | ||
|
|
bbebe5f201 | ||
|
|
ac9cb22fee | ||
|
|
5e281ce2e6 | ||
|
|
9ea5b7a424 | ||
|
|
e0b83fad4c | ||
|
|
7191b9010d | ||
|
|
fb3428ed37 | ||
|
|
444ad297da | ||
|
|
f46df421a7 | ||
|
|
98a2e12090 | ||
|
|
36bfa8645e | ||
|
|
56e71d7f6c | ||
|
|
e0d172615b | ||
|
|
bde52b13d4 | ||
|
|
b273d91512 | ||
|
|
1fbe76a607 | ||
|
|
6ee7316130 | ||
|
|
51802f46bb | ||
|
|
d430444424 | ||
|
|
17fff6c805 | ||
|
|
a33f6e8416 | ||
|
|
d157649069 | ||
|
|
77bbb9f7a7 | ||
|
|
996b5177d9 | ||
|
|
ab9a3ba970 | ||
|
|
87c1f0ab10 | ||
|
|
dcea1d88e5 | ||
|
|
cc481e20d3 | ||
|
|
4d141a8f68 | ||
|
|
cb32c81d1b | ||
|
|
64f327fdef | ||
|
|
902d6112c3 | ||
|
|
f71e3b9151 | ||
|
|
dd7e1520c5 | ||
|
|
97553de299 | ||
|
|
c80ab8b200 | ||
|
|
85c4ddce39 | ||
|
|
1caa860f8e | ||
|
|
7181cc41af | ||
|
|
959b8c320d | ||
|
|
96fd0432ff | ||
|
|
4c73a03f57 |
8
.git-blame-ignore-revs
Normal file
8
.git-blame-ignore-revs
Normal file
@@ -0,0 +1,8 @@
|
||||
# Exclude these commits from git blame (e.g. mass reformatting).
|
||||
# These are ignored by GitHub automatically.
|
||||
# To enable this locally, run:
|
||||
#
|
||||
# git config blame.ignoreRevsFile .git-blame-ignore-revs
|
||||
|
||||
3134e5f840c12c8f32613ce520101a047c89dcc2 # refactor(whitespace): rm temporary react fragments (#7161)
|
||||
ed3f72bc75f3e3a9ae9e4d8cd38278f9c97e78b4 # refactor(whitespace): rm react fragment #7190
|
||||
7
.github/CODEOWNERS
vendored
7
.github/CODEOWNERS
vendored
@@ -1,3 +1,10 @@
|
||||
* @onyx-dot-app/onyx-core-team
|
||||
# Helm charts Owners
|
||||
/helm/ @justin-tahara
|
||||
|
||||
# Web standards updates
|
||||
/web/STANDARDS.md @raunakab @Weves
|
||||
|
||||
# Agent context files
|
||||
/CLAUDE.md.template @Weves
|
||||
/AGENTS.md.template @Weves
|
||||
|
||||
@@ -7,12 +7,6 @@ inputs:
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@caf0cab7a618c569241d31dcd442f54681755d39 # ratchet:astral-sh/setup-uv@v3
|
||||
# TODO: Enable caching once there is a uv.lock file checked in.
|
||||
# with:
|
||||
# enable-cache: true
|
||||
|
||||
- name: Compute requirements hash
|
||||
id: req-hash
|
||||
shell: bash
|
||||
@@ -28,6 +22,8 @@ runs:
|
||||
done <<< "$REQUIREMENTS"
|
||||
echo "hash=$(echo "$hash" | sha256sum | cut -d' ' -f1)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
# NOTE: This comes before Setup uv since clean-ups run in reverse chronological order
|
||||
# such that Setup uv's prune-cache is able to prune the cache before we upload.
|
||||
- name: Cache uv cache directory
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
@@ -36,6 +32,14 @@ runs:
|
||||
restore-keys: |
|
||||
${{ runner.os }}-uv-
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
# TODO: Enable caching once there is a uv.lock file checked in.
|
||||
# with:
|
||||
# enable-cache: true
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # ratchet:actions/setup-python@v5
|
||||
with:
|
||||
|
||||
4
.github/pull_request_template.md
vendored
4
.github/pull_request_template.md
vendored
@@ -1,10 +1,10 @@
|
||||
## Description
|
||||
|
||||
[Provide a brief description of the changes in this PR]
|
||||
<!--- Provide a brief description of the changes in this PR --->
|
||||
|
||||
## How Has This Been Tested?
|
||||
|
||||
[Describe the tests you ran to verify your changes]
|
||||
<!--- Describe the tests you ran to verify your changes --->
|
||||
|
||||
## Additional Options
|
||||
|
||||
|
||||
318
.github/workflows/deployment.yml
vendored
318
.github/workflows/deployment.yml
vendored
@@ -6,11 +6,11 @@ on:
|
||||
- "*"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
# Set restrictive default permissions for all jobs. Jobs that need more permissions
|
||||
# should explicitly declare them.
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
IS_DRY_RUN: ${{ github.event_name == 'workflow_dispatch' }}
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
|
||||
jobs:
|
||||
@@ -20,6 +20,7 @@ jobs:
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 90
|
||||
outputs:
|
||||
build-desktop: ${{ steps.check.outputs.build-desktop }}
|
||||
build-web: ${{ steps.check.outputs.build-web }}
|
||||
build-web-cloud: ${{ steps.check.outputs.build-web-cloud }}
|
||||
build-backend: ${{ steps.check.outputs.build-backend }}
|
||||
@@ -29,32 +30,46 @@ jobs:
|
||||
is-beta: ${{ steps.check.outputs.is-beta }}
|
||||
is-stable-standalone: ${{ steps.check.outputs.is-stable-standalone }}
|
||||
is-beta-standalone: ${{ steps.check.outputs.is-beta-standalone }}
|
||||
is-test-run: ${{ steps.check.outputs.is-test-run }}
|
||||
sanitized-tag: ${{ steps.check.outputs.sanitized-tag }}
|
||||
short-sha: ${{ steps.check.outputs.short-sha }}
|
||||
steps:
|
||||
- name: Check which components to build and version info
|
||||
id: check
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
run: |
|
||||
TAG="${GITHUB_REF_NAME}"
|
||||
# Sanitize tag name by replacing slashes with hyphens (for Docker tag compatibility)
|
||||
SANITIZED_TAG=$(echo "$TAG" | tr '/' '-')
|
||||
SHORT_SHA="${GITHUB_SHA::7}"
|
||||
|
||||
# Initialize all flags to false
|
||||
IS_CLOUD=false
|
||||
BUILD_WEB=false
|
||||
BUILD_WEB_CLOUD=false
|
||||
BUILD_BACKEND=true
|
||||
BUILD_MODEL_SERVER=true
|
||||
IS_NIGHTLY=false
|
||||
IS_VERSION_TAG=false
|
||||
IS_STABLE=false
|
||||
IS_BETA=false
|
||||
IS_STABLE_STANDALONE=false
|
||||
IS_BETA_STANDALONE=false
|
||||
IS_PROD_TAG=false
|
||||
IS_TEST_RUN=false
|
||||
BUILD_DESKTOP=false
|
||||
BUILD_WEB=false
|
||||
BUILD_WEB_CLOUD=false
|
||||
BUILD_BACKEND=true
|
||||
BUILD_MODEL_SERVER=true
|
||||
|
||||
# Determine tag type based on pattern matching (do regex checks once)
|
||||
if [[ "$TAG" == *cloud* ]]; then
|
||||
IS_CLOUD=true
|
||||
BUILD_WEB_CLOUD=true
|
||||
else
|
||||
BUILD_WEB=true
|
||||
fi
|
||||
|
||||
# Version checks (for web - any stable version)
|
||||
if [[ "$TAG" == nightly* ]]; then
|
||||
IS_NIGHTLY=true
|
||||
fi
|
||||
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+ ]]; then
|
||||
IS_VERSION_TAG=true
|
||||
fi
|
||||
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
IS_STABLE=true
|
||||
fi
|
||||
@@ -62,15 +77,37 @@ jobs:
|
||||
IS_BETA=true
|
||||
fi
|
||||
|
||||
# Version checks (for backend/model-server - stable version excluding cloud tags)
|
||||
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "$TAG" != *cloud* ]]; then
|
||||
# Determine what to build based on tag type
|
||||
if [[ "$IS_CLOUD" == "true" ]]; then
|
||||
BUILD_WEB_CLOUD=true
|
||||
else
|
||||
BUILD_WEB=true
|
||||
# Skip desktop builds on beta tags and nightly runs
|
||||
if [[ "$IS_BETA" != "true" ]] && [[ "$IS_NIGHTLY" != "true" ]]; then
|
||||
BUILD_DESKTOP=true
|
||||
fi
|
||||
fi
|
||||
|
||||
# Standalone version checks (for backend/model-server - version excluding cloud tags)
|
||||
if [[ "$IS_STABLE" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then
|
||||
IS_STABLE_STANDALONE=true
|
||||
fi
|
||||
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+-beta\.[0-9]+$ ]] && [[ "$TAG" != *cloud* ]]; then
|
||||
if [[ "$IS_BETA" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then
|
||||
IS_BETA_STANDALONE=true
|
||||
fi
|
||||
|
||||
# Determine if this is a production tag
|
||||
# Production tags are: version tags (v1.2.3*) or nightly tags
|
||||
if [[ "$IS_VERSION_TAG" == "true" ]] || [[ "$IS_NIGHTLY" == "true" ]]; then
|
||||
IS_PROD_TAG=true
|
||||
fi
|
||||
|
||||
# Determine if this is a test run (workflow_dispatch on non-production ref)
|
||||
if [[ "$EVENT_NAME" == "workflow_dispatch" ]] && [[ "$IS_PROD_TAG" != "true" ]]; then
|
||||
IS_TEST_RUN=true
|
||||
fi
|
||||
{
|
||||
echo "build-desktop=$BUILD_DESKTOP"
|
||||
echo "build-web=$BUILD_WEB"
|
||||
echo "build-web-cloud=$BUILD_WEB_CLOUD"
|
||||
echo "build-backend=$BUILD_BACKEND"
|
||||
@@ -80,7 +117,9 @@ jobs:
|
||||
echo "is-beta=$IS_BETA"
|
||||
echo "is-stable-standalone=$IS_STABLE_STANDALONE"
|
||||
echo "is-beta-standalone=$IS_BETA_STANDALONE"
|
||||
echo "is-test-run=$IS_TEST_RUN"
|
||||
echo "sanitized-tag=$SANITIZED_TAG"
|
||||
echo "short-sha=$SHORT_SHA"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
check-version-tag:
|
||||
@@ -95,8 +134,9 @@ jobs:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7.1.4
|
||||
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
version: "0.9.9"
|
||||
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
|
||||
enable-cache: false
|
||||
|
||||
@@ -124,6 +164,138 @@ jobs:
|
||||
title: "🚨 Version Tag Check Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
build-desktop:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-desktop == 'true'
|
||||
permissions:
|
||||
contents: write
|
||||
actions: read
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- platform: "macos-latest" # Build a universal image for macOS.
|
||||
args: "--target universal-apple-darwin"
|
||||
- platform: "ubuntu-24.04"
|
||||
args: "--bundles deb,rpm"
|
||||
- platform: "ubuntu-24.04-arm" # Only available in public repos.
|
||||
args: "--bundles deb,rpm"
|
||||
- platform: "windows-latest"
|
||||
args: ""
|
||||
|
||||
runs-on: ${{ matrix.platform }}
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
|
||||
with:
|
||||
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
|
||||
persist-credentials: true # zizmor: ignore[artipacked]
|
||||
|
||||
- name: install dependencies (ubuntu only)
|
||||
if: startsWith(matrix.platform, 'ubuntu-')
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y \
|
||||
build-essential \
|
||||
libglib2.0-dev \
|
||||
libgirepository1.0-dev \
|
||||
libgtk-3-dev \
|
||||
libjavascriptcoregtk-4.1-dev \
|
||||
libwebkit2gtk-4.1-dev \
|
||||
libayatana-appindicator3-dev \
|
||||
gobject-introspection \
|
||||
pkg-config \
|
||||
curl \
|
||||
xdg-utils
|
||||
|
||||
- name: setup node
|
||||
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v6.1.0
|
||||
with:
|
||||
node-version: 24
|
||||
package-manager-cache: false
|
||||
|
||||
- name: install Rust stable
|
||||
uses: dtolnay/rust-toolchain@6d9817901c499d6b02debbb57edb38d33daa680b # zizmor: ignore[impostor-commit]
|
||||
with:
|
||||
# Those targets are only used on macos runners so it's in an `if` to slightly speed up windows and linux builds.
|
||||
targets: ${{ matrix.platform == 'macos-latest' && 'aarch64-apple-darwin,x86_64-apple-darwin' || '' }}
|
||||
|
||||
- name: install frontend dependencies
|
||||
working-directory: ./desktop
|
||||
run: npm install
|
||||
|
||||
- name: Inject version (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
working-directory: ./desktop
|
||||
env:
|
||||
SHORT_SHA: ${{ needs.determine-builds.outputs.short-sha }}
|
||||
IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }}
|
||||
run: |
|
||||
if [ "${IS_TEST_RUN}" == "true" ]; then
|
||||
VERSION="0.0.0-dev+${SHORT_SHA}"
|
||||
else
|
||||
VERSION="${GITHUB_REF_NAME#v}"
|
||||
fi
|
||||
echo "Injecting version: $VERSION"
|
||||
|
||||
# Update Cargo.toml
|
||||
sed "s/^version = .*/version = \"$VERSION\"/" src-tauri/Cargo.toml > src-tauri/Cargo.toml.tmp
|
||||
mv src-tauri/Cargo.toml.tmp src-tauri/Cargo.toml
|
||||
|
||||
# Update tauri.conf.json
|
||||
jq --arg v "$VERSION" '.version = $v' src-tauri/tauri.conf.json > src-tauri/tauri.conf.json.tmp
|
||||
mv src-tauri/tauri.conf.json.tmp src-tauri/tauri.conf.json
|
||||
|
||||
# Update package.json
|
||||
jq --arg v "$VERSION" '.version = $v' package.json > package.json.tmp
|
||||
mv package.json.tmp package.json
|
||||
|
||||
echo "Versions set to: $VERSION"
|
||||
|
||||
- name: Inject version (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
working-directory: ./desktop
|
||||
shell: pwsh
|
||||
env:
|
||||
IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }}
|
||||
run: |
|
||||
# Windows MSI requires numeric-only build metadata, so we skip the SHA suffix
|
||||
if ($env:IS_TEST_RUN -eq "true") {
|
||||
$VERSION = "0.0.0"
|
||||
} else {
|
||||
# Strip 'v' prefix and any pre-release suffix (e.g., -beta.13) for MSI compatibility
|
||||
$VERSION = "$env:GITHUB_REF_NAME" -replace '^v', '' -replace '-.*$', ''
|
||||
}
|
||||
Write-Host "Injecting version: $VERSION"
|
||||
|
||||
# Update Cargo.toml
|
||||
$cargo = Get-Content src-tauri/Cargo.toml -Raw
|
||||
$cargo = $cargo -replace '(?m)^version = .*', "version = `"$VERSION`""
|
||||
Set-Content src-tauri/Cargo.toml $cargo -NoNewline
|
||||
|
||||
# Update tauri.conf.json
|
||||
$json = Get-Content src-tauri/tauri.conf.json | ConvertFrom-Json
|
||||
$json.version = $VERSION
|
||||
$json | ConvertTo-Json -Depth 100 | Set-Content src-tauri/tauri.conf.json
|
||||
|
||||
# Update package.json
|
||||
$pkg = Get-Content package.json | ConvertFrom-Json
|
||||
$pkg.version = $VERSION
|
||||
$pkg | ConvertTo-Json -Depth 100 | Set-Content package.json
|
||||
|
||||
Write-Host "Versions set to: $VERSION"
|
||||
|
||||
- uses: tauri-apps/tauri-action@19b93bb55601e3e373a93cfb6eb4242e45f5af20 # ratchet:tauri-apps/tauri-action@action-v0.6.0
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseBody: "See the assets to download this version and install."
|
||||
releaseDraft: true
|
||||
prerelease: false
|
||||
args: ${{ matrix.args }}
|
||||
|
||||
build-web-amd64:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-web == 'true'
|
||||
@@ -147,9 +319,9 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -179,7 +351,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-web-arm64:
|
||||
@@ -205,9 +377,9 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -237,7 +409,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
merge-web:
|
||||
@@ -267,20 +439,20 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('web-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
IMAGE_REPO: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
AMD64_DIGEST: ${{ needs.build-web-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.build-web-arm64.outputs.digest }}
|
||||
META_TAGS: ${{ steps.meta.outputs.tags }}
|
||||
@@ -313,9 +485,9 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -353,7 +525,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-web-cloud-arm64:
|
||||
@@ -379,9 +551,9 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -419,7 +591,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
merge-web-cloud:
|
||||
@@ -449,17 +621,17 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('web-cloud-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-cloud-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
IMAGE_REPO: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
AMD64_DIGEST: ${{ needs.build-web-cloud-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.build-web-cloud-arm64.outputs.digest }}
|
||||
META_TAGS: ${{ steps.meta.outputs.tags }}
|
||||
@@ -492,9 +664,9 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -523,7 +695,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-backend-arm64:
|
||||
@@ -549,9 +721,9 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -580,7 +752,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
merge-backend:
|
||||
@@ -610,20 +782,20 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('backend-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('backend-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
IMAGE_REPO: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
AMD64_DIGEST: ${{ needs.build-backend-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.build-backend-arm64.outputs.digest }}
|
||||
META_TAGS: ${{ steps.meta.outputs.tags }}
|
||||
@@ -657,9 +829,9 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -692,7 +864,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ env.EDGE_TAG != 'true' && vars.MODEL_SERVER_NO_CACHE == 'true' }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
@@ -721,9 +893,9 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
|
||||
@@ -756,7 +928,7 @@ jobs:
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max
|
||||
outputs: type=image,name=${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
no-cache: ${{ env.EDGE_TAG != 'true' && vars.MODEL_SERVER_NO_CACHE == 'true' }}
|
||||
provenance: false
|
||||
sbom: false
|
||||
@@ -788,20 +960,20 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
latest=false
|
||||
tags: |
|
||||
type=raw,value=${{ github.event_name == 'workflow_dispatch' && format('model-server-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ github.event_name != 'workflow_dispatch' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('model-server-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-stable-standalone == 'true' && 'latest' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
|
||||
type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
IMAGE_REPO: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
IMAGE_REPO: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
AMD64_DIGEST: ${{ needs.build-model-server-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.build-model-server-arm64.outputs.digest }}
|
||||
META_TAGS: ${{ steps.meta.outputs.tags }}
|
||||
@@ -834,7 +1006,7 @@ jobs:
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
|
||||
SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:web-${{ needs.determine-builds.outputs.sanitized-tag }}"
|
||||
else
|
||||
SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
|
||||
@@ -874,7 +1046,7 @@ jobs:
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
|
||||
SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:web-cloud-${{ needs.determine-builds.outputs.sanitized-tag }}"
|
||||
else
|
||||
SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
|
||||
@@ -919,7 +1091,7 @@ jobs:
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
|
||||
SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:backend-${{ needs.determine-builds.outputs.sanitized-tag }}"
|
||||
else
|
||||
SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
|
||||
@@ -961,7 +1133,7 @@ jobs:
|
||||
max_attempts: 3
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
if [ "${{ needs.determine-builds.outputs.is-test-run }}" == "true" ]; then
|
||||
SCAN_IMAGE="${{ env.RUNS_ON_ECR_CACHE }}:model-server-${{ needs.determine-builds.outputs.sanitized-tag }}"
|
||||
else
|
||||
SCAN_IMAGE="docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}"
|
||||
@@ -980,6 +1152,8 @@ jobs:
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs:
|
||||
- determine-builds
|
||||
- build-desktop
|
||||
- build-web-amd64
|
||||
- build-web-arm64
|
||||
- merge-web
|
||||
@@ -992,7 +1166,7 @@ jobs:
|
||||
- build-model-server-amd64
|
||||
- build-model-server-arm64
|
||||
- merge-model-server
|
||||
if: always() && (needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
|
||||
if: always() && (needs.build-desktop.result == 'failure' || needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && needs.determine-builds.outputs.is-test-run != 'true'
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 90
|
||||
@@ -1007,6 +1181,9 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
FAILED_JOBS=""
|
||||
if [ "${NEEDS_BUILD_DESKTOP_RESULT}" == "failure" ]; then
|
||||
FAILED_JOBS="${FAILED_JOBS}• build-desktop\\n"
|
||||
fi
|
||||
if [ "${NEEDS_BUILD_WEB_AMD64_RESULT}" == "failure" ]; then
|
||||
FAILED_JOBS="${FAILED_JOBS}• build-web-amd64\\n"
|
||||
fi
|
||||
@@ -1047,6 +1224,7 @@ jobs:
|
||||
FAILED_JOBS=$(printf '%s' "$FAILED_JOBS" | sed 's/\\n$//')
|
||||
echo "jobs=$FAILED_JOBS" >> "$GITHUB_OUTPUT"
|
||||
env:
|
||||
NEEDS_BUILD_DESKTOP_RESULT: ${{ needs.build-desktop.result }}
|
||||
NEEDS_BUILD_WEB_AMD64_RESULT: ${{ needs.build-web-amd64.result }}
|
||||
NEEDS_BUILD_WEB_ARM64_RESULT: ${{ needs.build-web-arm64.result }}
|
||||
NEEDS_MERGE_WEB_RESULT: ${{ needs.merge-web.result }}
|
||||
|
||||
31
.github/workflows/merge-group.yml
vendored
Normal file
31
.github/workflows/merge-group.yml
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
name: Merge Group-Specific
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
# This job immediately succeeds to satisfy branch protection rules on merge_group events.
|
||||
# There is a similarly named "required" job in pr-integration-tests.yml which runs the actual
|
||||
# integration tests. That job runs on both pull_request and merge_group events, and this job
|
||||
# exists solely to provide a fast-passing check with the same name for branch protection.
|
||||
# The actual tests remain enforced on presubmit (pull_request events).
|
||||
required:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Success
|
||||
run: echo "Success"
|
||||
# This job immediately succeeds to satisfy branch protection rules on merge_group events.
|
||||
# There is a similarly named "playwright-required" job in pr-playwright-tests.yml which runs
|
||||
# the actual playwright tests. That job runs on both pull_request and merge_group events, and
|
||||
# this job exists solely to provide a fast-passing check with the same name for branch protection.
|
||||
# The actual tests remain enforced on presubmit (pull_request events).
|
||||
playwright-required:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Success
|
||||
run: echo "Success"
|
||||
62
.github/workflows/pr-database-tests.yml
vendored
Normal file
62
.github/workflows/pr-database-tests.yml
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
name: Database Tests
|
||||
concurrency:
|
||||
group: Database-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
database-tests:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-database-tests"
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Python and Install Dependencies
|
||||
uses: ./.github/actions/setup-python-and-install-dependencies
|
||||
with:
|
||||
requirements: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Generate OpenAPI schema and Python client
|
||||
shell: bash
|
||||
run: |
|
||||
ods openapi all
|
||||
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Start Docker containers
|
||||
working-directory: ./deployment/docker_compose
|
||||
run: |
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d \
|
||||
relational_db
|
||||
|
||||
- name: Run Database Tests
|
||||
working-directory: ./backend
|
||||
run: pytest -m alembic tests/integration/tests/migrations/
|
||||
@@ -38,6 +38,8 @@ env:
|
||||
# LLMs
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
VERTEX_CREDENTIALS: ${{ secrets.VERTEX_CREDENTIALS }}
|
||||
VERTEX_LOCATION: ${{ vars.VERTEX_LOCATION }}
|
||||
|
||||
# Code Interpreter
|
||||
# TODO: debug why this is failing and enable
|
||||
|
||||
412
.github/workflows/pr-helm-chart-testing.yml
vendored
412
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -6,11 +6,11 @@ concurrency:
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches: [main]
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -18,225 +18,233 @@ permissions:
|
||||
jobs:
|
||||
helm-chart-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}-helm-chart-check"]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=8cpu-linux-x64,
|
||||
hdd=256,
|
||||
"run-id=${{ github.run_id }}-helm-chart-check",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
|
||||
with:
|
||||
version: v3.19.0
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
|
||||
with:
|
||||
version: v3.19.0
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@6ec842c01de15ebb84c8627d2744a0c2f2755c9f # ratchet:helm/chart-testing-action@v2.8.0
|
||||
- name: Set up chart-testing
|
||||
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
|
||||
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
|
||||
with:
|
||||
uv_version: "0.9.9"
|
||||
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
env:
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
run: |
|
||||
echo "default_branch: ${DEFAULT_BRANCH}"
|
||||
changed=$(ct list-changed --remote origin --target-branch ${DEFAULT_BRANCH} --chart-dirs deployment/helm/charts)
|
||||
echo "list-changed output: $changed"
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# uncomment to force run chart-testing
|
||||
# - name: Force run chart-testing (list-changed)
|
||||
# id: list-changed
|
||||
# run: echo "changed=true" >> $GITHUB_OUTPUT
|
||||
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --config ct.yaml --all
|
||||
# the following would lint only changed charts, but linting isn't expensive
|
||||
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
|
||||
|
||||
- name: Pre-install cluster status check
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-install Cluster Status ==="
|
||||
kubectl get nodes -o wide
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get storageclass
|
||||
|
||||
- name: Add Helm repositories and update
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
|
||||
helm repo update
|
||||
|
||||
- name: Install Redis operator
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
echo "=== Installing redis-operator CRDs ==="
|
||||
helm upgrade --install redis-operator ot-container-kit/redis-operator \
|
||||
--namespace redis-operator --create-namespace --wait --timeout 300s
|
||||
|
||||
- name: Pre-pull required images
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-pulling required images to avoid timeout ==="
|
||||
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
|
||||
echo "Kind cluster: $KIND_CLUSTER"
|
||||
|
||||
IMAGES=(
|
||||
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
|
||||
"quay.io/opstree/redis:v7.0.15"
|
||||
"docker.io/onyxdotapp/onyx-web-server:latest"
|
||||
)
|
||||
|
||||
for image in "${IMAGES[@]}"; do
|
||||
echo "Pre-pulling $image"
|
||||
if docker pull "$image"; then
|
||||
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
|
||||
else
|
||||
echo "Failed to pull $image"
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
env:
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
run: |
|
||||
echo "default_branch: ${DEFAULT_BRANCH}"
|
||||
changed=$(ct list-changed --remote origin --target-branch ${DEFAULT_BRANCH} --chart-dirs deployment/helm/charts)
|
||||
echo "list-changed output: $changed"
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
done
|
||||
|
||||
echo "=== Images loaded into Kind cluster ==="
|
||||
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
|
||||
# uncomment to force run chart-testing
|
||||
# - name: Force run chart-testing (list-changed)
|
||||
# id: list-changed
|
||||
# run: echo "changed=true" >> $GITHUB_OUTPUT
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --config ct.yaml --all
|
||||
# the following would lint only changed charts, but linting isn't expensive
|
||||
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Validate chart dependencies
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Validating chart dependencies ==="
|
||||
cd deployment/helm/charts/onyx
|
||||
helm dependency update
|
||||
helm lint .
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
|
||||
|
||||
- name: Run chart-testing (install) with enhanced monitoring
|
||||
timeout-minutes: 25
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Starting chart installation with monitoring ==="
|
||||
- name: Pre-install cluster status check
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-install Cluster Status ==="
|
||||
kubectl get nodes -o wide
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get storageclass
|
||||
|
||||
# Function to monitor cluster state
|
||||
monitor_cluster() {
|
||||
while true; do
|
||||
echo "=== Cluster Status Check at $(date) ==="
|
||||
# Only show non-running pods to reduce noise
|
||||
NON_RUNNING_PODS=$(kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded --no-headers 2>/dev/null | wc -l)
|
||||
if [ "$NON_RUNNING_PODS" -gt 0 ]; then
|
||||
echo "Non-running pods:"
|
||||
kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded
|
||||
- name: Add Helm repositories and update
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
|
||||
helm repo update
|
||||
|
||||
- name: Install Redis operator
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
echo "=== Installing redis-operator CRDs ==="
|
||||
helm upgrade --install redis-operator ot-container-kit/redis-operator \
|
||||
--namespace redis-operator --create-namespace --wait --timeout 300s
|
||||
|
||||
- name: Pre-pull required images
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-pulling required images to avoid timeout ==="
|
||||
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
|
||||
echo "Kind cluster: $KIND_CLUSTER"
|
||||
|
||||
IMAGES=(
|
||||
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
|
||||
"quay.io/opstree/redis:v7.0.15"
|
||||
"docker.io/onyxdotapp/onyx-web-server:latest"
|
||||
)
|
||||
|
||||
for image in "${IMAGES[@]}"; do
|
||||
echo "Pre-pulling $image"
|
||||
if docker pull "$image"; then
|
||||
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
|
||||
else
|
||||
echo "All pods running successfully"
|
||||
echo "Failed to pull $image"
|
||||
fi
|
||||
# Only show recent events if there are issues
|
||||
RECENT_EVENTS=$(kubectl get events --sort-by=.lastTimestamp --all-namespaces --field-selector=type!=Normal 2>/dev/null | tail -5)
|
||||
if [ -n "$RECENT_EVENTS" ]; then
|
||||
echo "Recent warnings/errors:"
|
||||
echo "$RECENT_EVENTS"
|
||||
fi
|
||||
sleep 60
|
||||
done
|
||||
}
|
||||
|
||||
# Start monitoring in background
|
||||
monitor_cluster &
|
||||
MONITOR_PID=$!
|
||||
echo "=== Images loaded into Kind cluster ==="
|
||||
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
|
||||
|
||||
# Set up cleanup
|
||||
cleanup() {
|
||||
echo "=== Cleaning up monitoring process ==="
|
||||
kill $MONITOR_PID 2>/dev/null || true
|
||||
- name: Validate chart dependencies
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Validating chart dependencies ==="
|
||||
cd deployment/helm/charts/onyx
|
||||
helm dependency update
|
||||
helm lint .
|
||||
|
||||
- name: Run chart-testing (install) with enhanced monitoring
|
||||
timeout-minutes: 25
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Starting chart installation with monitoring ==="
|
||||
|
||||
# Function to monitor cluster state
|
||||
monitor_cluster() {
|
||||
while true; do
|
||||
echo "=== Cluster Status Check at $(date) ==="
|
||||
# Only show non-running pods to reduce noise
|
||||
NON_RUNNING_PODS=$(kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded --no-headers 2>/dev/null | wc -l)
|
||||
if [ "$NON_RUNNING_PODS" -gt 0 ]; then
|
||||
echo "Non-running pods:"
|
||||
kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded
|
||||
else
|
||||
echo "All pods running successfully"
|
||||
fi
|
||||
# Only show recent events if there are issues
|
||||
RECENT_EVENTS=$(kubectl get events --sort-by=.lastTimestamp --all-namespaces --field-selector=type!=Normal 2>/dev/null | tail -5)
|
||||
if [ -n "$RECENT_EVENTS" ]; then
|
||||
echo "Recent warnings/errors:"
|
||||
echo "$RECENT_EVENTS"
|
||||
fi
|
||||
sleep 60
|
||||
done
|
||||
}
|
||||
|
||||
# Start monitoring in background
|
||||
monitor_cluster &
|
||||
MONITOR_PID=$!
|
||||
|
||||
# Set up cleanup
|
||||
cleanup() {
|
||||
echo "=== Cleaning up monitoring process ==="
|
||||
kill $MONITOR_PID 2>/dev/null || true
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -20
|
||||
}
|
||||
|
||||
# Trap cleanup on exit
|
||||
trap cleanup EXIT
|
||||
|
||||
# Run the actual installation with detailed logging
|
||||
echo "=== Starting ct install ==="
|
||||
set +e
|
||||
ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
--set=minio.enabled=false \
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.nameOverride=cloudnative-pg \
|
||||
--set=postgresql.cluster.storage.storageClass=standard \
|
||||
--set=redis.enabled=true \
|
||||
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
|
||||
--set=webserver.replicaCount=1 \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
--set=indexCapability.replicaCount=0 \
|
||||
--set=celery_beat.replicaCount=0 \
|
||||
--set=celery_worker_heavy.replicaCount=0 \
|
||||
--set=celery_worker_docfetching.replicaCount=0 \
|
||||
--set=celery_worker_docprocessing.replicaCount=0 \
|
||||
--set=celery_worker_light.replicaCount=0 \
|
||||
--set=celery_worker_monitoring.replicaCount=0 \
|
||||
--set=celery_worker_primary.replicaCount=0 \
|
||||
--set=celery_worker_user_file_processing.replicaCount=0 \
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--helm-extra-args="--timeout 900s --debug" \
|
||||
--debug --config ct.yaml
|
||||
CT_EXIT=$?
|
||||
set -e
|
||||
|
||||
if [[ $CT_EXIT -ne 0 ]]; then
|
||||
echo "ct install failed with exit code $CT_EXIT"
|
||||
exit $CT_EXIT
|
||||
else
|
||||
echo "=== Installation completed successfully ==="
|
||||
fi
|
||||
|
||||
kubectl get pods --all-namespaces
|
||||
|
||||
- name: Post-install verification
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Post-install verification ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get services --all-namespaces
|
||||
# Only show issues if they exist
|
||||
kubectl describe pods --all-namespaces | grep -A 5 -B 2 "Failed\|Error\|Warning" || echo "No pod issues found"
|
||||
|
||||
- name: Cleanup on failure
|
||||
if: failure() && steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Cleanup on failure ==="
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -20
|
||||
}
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
|
||||
|
||||
# Trap cleanup on exit
|
||||
trap cleanup EXIT
|
||||
echo "=== Pod descriptions for debugging ==="
|
||||
kubectl describe pods --all-namespaces | grep -A 10 -B 3 "Failed\|Error\|Warning\|Pending" || echo "No problematic pods found"
|
||||
|
||||
# Run the actual installation with detailed logging
|
||||
echo "=== Starting ct install ==="
|
||||
set +e
|
||||
ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
--set=minio.enabled=false \
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.nameOverride=cloudnative-pg \
|
||||
--set=postgresql.cluster.storage.storageClass=standard \
|
||||
--set=redis.enabled=true \
|
||||
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
|
||||
--set=webserver.replicaCount=1 \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
--set=indexCapability.replicaCount=0 \
|
||||
--set=celery_beat.replicaCount=0 \
|
||||
--set=celery_worker_heavy.replicaCount=0 \
|
||||
--set=celery_worker_docfetching.replicaCount=0 \
|
||||
--set=celery_worker_docprocessing.replicaCount=0 \
|
||||
--set=celery_worker_light.replicaCount=0 \
|
||||
--set=celery_worker_monitoring.replicaCount=0 \
|
||||
--set=celery_worker_primary.replicaCount=0 \
|
||||
--set=celery_worker_user_file_processing.replicaCount=0 \
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--helm-extra-args="--timeout 900s --debug" \
|
||||
--debug --config ct.yaml
|
||||
CT_EXIT=$?
|
||||
set -e
|
||||
echo "=== Recent logs for debugging ==="
|
||||
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
|
||||
|
||||
if [[ $CT_EXIT -ne 0 ]]; then
|
||||
echo "ct install failed with exit code $CT_EXIT"
|
||||
exit $CT_EXIT
|
||||
else
|
||||
echo "=== Installation completed successfully ==="
|
||||
fi
|
||||
|
||||
kubectl get pods --all-namespaces
|
||||
|
||||
- name: Post-install verification
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Post-install verification ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get services --all-namespaces
|
||||
# Only show issues if they exist
|
||||
kubectl describe pods --all-namespaces | grep -A 5 -B 2 "Failed\|Error\|Warning" || echo "No pod issues found"
|
||||
|
||||
- name: Cleanup on failure
|
||||
if: failure() && steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Cleanup on failure ==="
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
|
||||
|
||||
echo "=== Pod descriptions for debugging ==="
|
||||
kubectl describe pods --all-namespaces | grep -A 10 -B 3 "Failed\|Error\|Warning\|Pending" || echo "No problematic pods found"
|
||||
|
||||
echo "=== Recent logs for debugging ==="
|
||||
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
|
||||
|
||||
echo "=== Helm releases ==="
|
||||
helm list --all-namespaces
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
echo "=== Helm releases ==="
|
||||
helm list --all-namespaces
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
67
.github/workflows/pr-integration-tests.yml
vendored
67
.github/workflows/pr-integration-tests.yml
vendored
@@ -33,6 +33,11 @@ env:
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN: ${{ secrets.ONYX_GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN }}
|
||||
GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC: ${{ secrets.ONYX_GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC }}
|
||||
GITHUB_ADMIN_EMAIL: ${{ secrets.ONYX_GITHUB_ADMIN_EMAIL }}
|
||||
GITHUB_TEST_USER_1_EMAIL: ${{ secrets.ONYX_GITHUB_TEST_USER_1_EMAIL }}
|
||||
GITHUB_TEST_USER_2_EMAIL: ${{ secrets.ONYX_GITHUB_TEST_USER_2_EMAIL }}
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
@@ -51,7 +56,7 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
@@ -67,9 +72,14 @@ jobs:
|
||||
all_dirs="[${all_dirs%,}]"
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
|
||||
build-backend-image:
|
||||
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-backend-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
@@ -122,9 +132,14 @@ jobs:
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-model-server-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
@@ -176,9 +191,14 @@ jobs:
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
|
||||
|
||||
build-integration-image:
|
||||
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-integration-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
@@ -220,7 +240,7 @@ jobs:
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
run: |
|
||||
cd backend && docker buildx bake --push \
|
||||
docker buildx bake --push \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
|
||||
@@ -290,6 +310,7 @@ jobs:
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
|
||||
MCP_SERVER_ENABLED=true
|
||||
EOF
|
||||
|
||||
@@ -304,7 +325,6 @@ jobs:
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
mcp_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker
|
||||
@@ -347,12 +367,6 @@ jobs:
|
||||
}
|
||||
|
||||
wait_for_service "http://localhost:8080/health" "API server"
|
||||
test_dir="${{ matrix.test-dir.path }}"
|
||||
if [ "$test_dir" = "tests/mcp" ]; then
|
||||
wait_for_service "http://localhost:8090/health" "MCP server"
|
||||
else
|
||||
echo "Skipping MCP server wait for non-MCP suite: $test_dir"
|
||||
fi
|
||||
echo "Finished waiting for services."
|
||||
|
||||
- name: Start Mock Services
|
||||
@@ -382,8 +396,6 @@ jobs:
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e MCP_SERVER_HOST=mcp_server \
|
||||
-e MCP_SERVER_PORT=8090 \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
@@ -399,6 +411,11 @@ jobs:
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN=${GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN} \
|
||||
-e GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC=${GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC} \
|
||||
-e GITHUB_ADMIN_EMAIL=${GITHUB_ADMIN_EMAIL} \
|
||||
-e GITHUB_TEST_USER_1_EMAIL=${GITHUB_TEST_USER_1_EMAIL} \
|
||||
-e GITHUB_TEST_USER_2_EMAIL=${GITHUB_TEST_USER_2_EMAIL} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
@@ -427,15 +444,16 @@ jobs:
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
|
||||
multitenant-tests:
|
||||
needs:
|
||||
[build-backend-image, build-model-server-image, build-integration-image]
|
||||
runs-on:
|
||||
[
|
||||
build-backend-image,
|
||||
build-model-server-image,
|
||||
build-integration-image,
|
||||
runs-on,
|
||||
runner=8cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-multitenant-tests",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
runs-on: [runs-on, runner=8cpu-linux-arm64, "run-id=${{ github.run_id }}-multitenant-tests", "extras=ecr-cache"]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
@@ -462,10 +480,10 @@ jobs:
|
||||
AUTH_TYPE=cloud \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
OPENAI_DEFAULT_API_KEY=${OPENAI_API_KEY} \
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
|
||||
DEV_MODE=true \
|
||||
MCP_SERVER_ENABLED=true \
|
||||
docker compose -f docker-compose.multitenant-dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
@@ -474,7 +492,6 @@ jobs:
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
mcp_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker_multi_tenant
|
||||
@@ -523,8 +540,6 @@ jobs:
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e MCP_SERVER_HOST=mcp_server \
|
||||
-e MCP_SERVER_PORT=8090 \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
|
||||
7
.github/workflows/pr-jest-tests.yml
vendored
7
.github/workflows/pr-jest-tests.yml
vendored
@@ -4,7 +4,14 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
39
.github/workflows/pr-mit-integration-tests.yml
vendored
39
.github/workflows/pr-mit-integration-tests.yml
vendored
@@ -48,7 +48,7 @@ jobs:
|
||||
id: set-matrix
|
||||
run: |
|
||||
# Find all leaf-level directories in both test directories
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
|
||||
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
|
||||
|
||||
# Create JSON array with directory info
|
||||
@@ -65,7 +65,13 @@ jobs:
|
||||
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
|
||||
|
||||
build-backend-image:
|
||||
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-backend-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
@@ -119,7 +125,13 @@ jobs:
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-model-server-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
@@ -172,7 +184,13 @@ jobs:
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
|
||||
build-integration-image:
|
||||
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-integration-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
@@ -214,7 +232,7 @@ jobs:
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
run: |
|
||||
cd backend && docker buildx bake --push \
|
||||
docker buildx bake --push \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
|
||||
@@ -283,6 +301,7 @@ jobs:
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
MCP_SERVER_ENABLED=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
@@ -296,7 +315,6 @@ jobs:
|
||||
api_server \
|
||||
inference_model_server \
|
||||
indexing_model_server \
|
||||
mcp_server \
|
||||
background \
|
||||
-d
|
||||
id: start_docker
|
||||
@@ -339,12 +357,6 @@ jobs:
|
||||
}
|
||||
|
||||
wait_for_service "http://localhost:8080/health" "API server"
|
||||
test_dir="${{ matrix.test-dir.path }}"
|
||||
if [ "$test_dir" = "tests/mcp" ]; then
|
||||
wait_for_service "http://localhost:8090/health" "MCP server"
|
||||
else
|
||||
echo "Skipping MCP server wait for non-MCP suite: $test_dir"
|
||||
fi
|
||||
echo "Finished waiting for services."
|
||||
|
||||
- name: Start Mock Services
|
||||
@@ -375,8 +387,6 @@ jobs:
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e MCP_SERVER_HOST=mcp_server \
|
||||
-e MCP_SERVER_PORT=8090 \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
@@ -420,7 +430,6 @@ jobs:
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
|
||||
required:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
|
||||
35
.github/workflows/pr-playwright-tests.yml
vendored
35
.github/workflows/pr-playwright-tests.yml
vendored
@@ -4,7 +4,14 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -47,7 +54,13 @@ env:
|
||||
|
||||
jobs:
|
||||
build-web-image:
|
||||
runs-on: [runs-on, runner=4cpu-linux-arm64, "run-id=${{ github.run_id }}-build-web-image", "extras=ecr-cache"]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=4cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-web-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
@@ -102,7 +115,13 @@ jobs:
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-backend-image:
|
||||
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache"]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-backend-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
@@ -157,7 +176,13 @@ jobs:
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-model-server-image:
|
||||
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache"]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-model-server-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
@@ -231,14 +256,13 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: 'npm'
|
||||
cache: "npm"
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install node dependencies
|
||||
@@ -447,7 +471,6 @@ jobs:
|
||||
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
|
||||
run: exit 1
|
||||
|
||||
|
||||
# NOTE: Chromatic UI diff testing is currently disabled.
|
||||
# We are using Playwright for local and CI testing without visual regression checks.
|
||||
# Chromatic may be reintroduced in the future for UI diff testing if needed.
|
||||
|
||||
17
.github/workflows/release-devtools.yml
vendored
17
.github/workflows/release-devtools.yml
vendored
@@ -16,21 +16,22 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os-arch:
|
||||
- {goos: "linux", goarch: "amd64"}
|
||||
- {goos: "linux", goarch: "arm64"}
|
||||
- {goos: "windows", goarch: "amd64"}
|
||||
- {goos: "windows", goarch: "arm64"}
|
||||
- {goos: "darwin", goarch: "amd64"}
|
||||
- {goos: "darwin", goarch: "arm64"}
|
||||
- {goos: "", goarch: ""}
|
||||
- { goos: "linux", goarch: "amd64" }
|
||||
- { goos: "linux", goarch: "arm64" }
|
||||
- { goos: "windows", goarch: "amd64" }
|
||||
- { goos: "windows", goarch: "arm64" }
|
||||
- { goos: "darwin", goarch: "amd64" }
|
||||
- { goos: "darwin", goarch: "arm64" }
|
||||
- { goos: "", goarch: "" }
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
- uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
- run: |
|
||||
GOOS="${{ matrix.os-arch.goos }}" \
|
||||
GOARCH="${{ matrix.os-arch.goarch }}" \
|
||||
|
||||
14
.github/workflows/zizmor.yml
vendored
14
.github/workflows/zizmor.yml
vendored
@@ -21,17 +21,29 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Detect changes
|
||||
id: filter
|
||||
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # ratchet:dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
zizmor:
|
||||
- '.github/**'
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7.1.4
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Run zizmor
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
run: uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Upload SARIF file
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
|
||||
@@ -8,30 +8,66 @@ repos:
|
||||
# From: https://github.com/astral-sh/uv-pre-commit/pull/53/commits/d30b4298e4fb63ce8609e29acdbcf4c9018a483c
|
||||
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
|
||||
hooks:
|
||||
- id: uv-run
|
||||
name: Check lazy imports
|
||||
args: ["--with=onyx-devtools", "ods", "check-lazy-imports"]
|
||||
files: ^backend/(?!\.venv/).*\.py$
|
||||
- id: uv-sync
|
||||
args: ["--locked", "--all-extras"]
|
||||
- id: uv-lock
|
||||
files: ^pyproject\.toml$
|
||||
- id: uv-export
|
||||
name: uv-export default.txt
|
||||
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "backend", "-o", "backend/requirements/default.txt"]
|
||||
args:
|
||||
[
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"backend",
|
||||
"-o",
|
||||
"backend/requirements/default.txt",
|
||||
]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-export
|
||||
name: uv-export dev.txt
|
||||
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "dev", "-o", "backend/requirements/dev.txt"]
|
||||
args:
|
||||
[
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"dev",
|
||||
"-o",
|
||||
"backend/requirements/dev.txt",
|
||||
]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-export
|
||||
name: uv-export ee.txt
|
||||
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "ee", "-o", "backend/requirements/ee.txt"]
|
||||
args:
|
||||
[
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"ee",
|
||||
"-o",
|
||||
"backend/requirements/ee.txt",
|
||||
]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-export
|
||||
name: uv-export model_server.txt
|
||||
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "model_server", "-o", "backend/requirements/model_server.txt"]
|
||||
args:
|
||||
[
|
||||
"--no-emit-project",
|
||||
"--no-default-groups",
|
||||
"--no-hashes",
|
||||
"--extra",
|
||||
"model_server",
|
||||
"-o",
|
||||
"backend/requirements/model_server.txt",
|
||||
]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-run
|
||||
name: Check lazy imports
|
||||
args: ["--active", "--with=onyx-devtools", "ods", "check-lazy-imports"]
|
||||
files: ^backend/(?!\.venv/).*\.py$
|
||||
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
|
||||
# - id: uv-run
|
||||
# name: mypy
|
||||
@@ -39,69 +75,68 @@ repos:
|
||||
# pass_filenames: true
|
||||
# files: ^backend/.*\.py$
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
files: ^.github/
|
||||
|
||||
- repo: https://github.com/rhysd/actionlint
|
||||
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
|
||||
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
|
||||
hooks:
|
||||
- id: actionlint
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 8a737e727ac5ab2f1d4cf5876720ed276dc8dc4b # frozen: 25.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3.11
|
||||
- id: black
|
||||
language_version: python3.11
|
||||
|
||||
# this is a fork which keeps compatibility with black
|
||||
- repo: https://github.com/wimglenn/reorder-python-imports-black
|
||||
rev: f55cd27f90f0cf0ee775002c2383ce1c7820013d # frozen: v3.14.0
|
||||
rev: f55cd27f90f0cf0ee775002c2383ce1c7820013d # frozen: v3.14.0
|
||||
hooks:
|
||||
- id: reorder-python-imports
|
||||
args: ['--py311-plus', '--application-directories=backend/']
|
||||
# need to ignore alembic files, since reorder-python-imports gets confused
|
||||
# and thinks that alembic is a local package since there is a folder
|
||||
# in the backend directory called `alembic`
|
||||
exclude: ^backend/alembic/
|
||||
- id: reorder-python-imports
|
||||
args: ["--py311-plus", "--application-directories=backend/"]
|
||||
# need to ignore alembic files, since reorder-python-imports gets confused
|
||||
# and thinks that alembic is a local package since there is a folder
|
||||
# in the backend directory called `alembic`
|
||||
exclude: ^backend/alembic/
|
||||
|
||||
# These settings will remove unused imports with side effects
|
||||
# Note: The repo currently does not and should not have imports with side effects
|
||||
- repo: https://github.com/PyCQA/autoflake
|
||||
rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1
|
||||
rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1
|
||||
hooks:
|
||||
- id: autoflake
|
||||
args: [ '--remove-all-unused-imports', '--remove-unused-variables', '--in-place' , '--recursive']
|
||||
args:
|
||||
[
|
||||
"--remove-all-unused-imports",
|
||||
"--remove-unused-variables",
|
||||
"--in-place",
|
||||
"--recursive",
|
||||
]
|
||||
|
||||
- repo: https://github.com/golangci/golangci-lint
|
||||
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
|
||||
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
|
||||
hooks:
|
||||
- id: golangci-lint
|
||||
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: 971923581912ef60a6b70dbf0c3e9a39563c9d47 # frozen: v0.11.4
|
||||
rev: 971923581912ef60a6b70dbf0c3e9a39563c9d47 # frozen: v0.11.4
|
||||
hooks:
|
||||
- id: ruff
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||
rev: ffb6a759a979008c0e6dff86e39f4745a2d9eac4 # frozen: v3.1.0
|
||||
rev: ffb6a759a979008c0e6dff86e39f4745a2d9eac4 # frozen: v3.1.0
|
||||
hooks:
|
||||
- id: prettier
|
||||
types_or: [html, css, javascript, ts, tsx]
|
||||
language_version: system
|
||||
- id: prettier
|
||||
types_or: [html, css, javascript, ts, tsx]
|
||||
language_version: system
|
||||
|
||||
- repo: https://github.com/sirwart/ripsecrets
|
||||
rev: 7d94620933e79b8acaa0cd9e60e9864b07673d86 # frozen: v0.1.11
|
||||
rev: 7d94620933e79b8acaa0cd9e60e9864b07673d86 # frozen: v0.1.11
|
||||
hooks:
|
||||
- id: ripsecrets
|
||||
args:
|
||||
- --additional-pattern
|
||||
- ^sk-[A-Za-z0-9_\-]{20,}$
|
||||
|
||||
- --additional-pattern
|
||||
- ^sk-[A-Za-z0-9_\-]{20,}$
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
@@ -112,9 +147,13 @@ repos:
|
||||
pass_filenames: false
|
||||
files: \.tf$
|
||||
|
||||
# Uses tsgo (TypeScript's native Go compiler) for ~10x faster type checking.
|
||||
# This is a preview package - if it breaks:
|
||||
# 1. Try updating: cd web && npm update @typescript/native-preview
|
||||
# 2. Or fallback to tsc: replace 'tsgo' with 'tsc' below
|
||||
- id: typescript-check
|
||||
name: TypeScript type check
|
||||
entry: bash -c 'cd web && npm run types:check'
|
||||
entry: bash -c 'cd web && npx tsgo --noEmit --project tsconfig.types.json'
|
||||
language: system
|
||||
pass_filenames: false
|
||||
files: ^web/.*\.(ts|tsx)$
|
||||
|
||||
51
.vscode/env_template.txt
vendored
51
.vscode/env_template.txt
vendored
@@ -1,36 +1,45 @@
|
||||
# Copy this file to .env in the .vscode folder
|
||||
# Fill in the <REPLACE THIS> values as needed, it is recommended to set the GEN_AI_API_KEY value to avoid having to set up an LLM in the UI
|
||||
# Also check out onyx/backend/scripts/restart_containers.sh for a script to restart the containers which Onyx relies on outside of VSCode/Cursor processes
|
||||
# Copy this file to .env in the .vscode folder.
|
||||
# Fill in the <REPLACE THIS> values as needed; it is recommended to set the
|
||||
# GEN_AI_API_KEY value to avoid having to set up an LLM in the UI.
|
||||
# Also check out onyx/backend/scripts/restart_containers.sh for a script to
|
||||
# restart the containers which Onyx relies on outside of VSCode/Cursor
|
||||
# processes.
|
||||
|
||||
# For local dev, often user Authentication is not needed
|
||||
|
||||
# For local dev, often user Authentication is not needed.
|
||||
AUTH_TYPE=disabled
|
||||
|
||||
# Always keep these on for Dev
|
||||
# Logs model prompts, reasoning, and answer to stdout
|
||||
|
||||
# Always keep these on for Dev.
|
||||
# Logs model prompts, reasoning, and answer to stdout.
|
||||
LOG_ONYX_MODEL_INTERACTIONS=True
|
||||
# More verbose logging
|
||||
LOG_LEVEL=debug
|
||||
|
||||
|
||||
# This passes top N results to LLM an additional time for reranking prior to answer generation
|
||||
# This step is quite heavy on token usage so we disable it for dev generally
|
||||
# This passes top N results to LLM an additional time for reranking prior to
|
||||
# answer generation.
|
||||
# This step is quite heavy on token usage so we disable it for dev generally.
|
||||
DISABLE_LLM_DOC_RELEVANCE=False
|
||||
|
||||
|
||||
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
|
||||
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically).
|
||||
OAUTH_CLIENT_ID=<REPLACE THIS>
|
||||
OAUTH_CLIENT_SECRET=<REPLACE THIS>
|
||||
OPENID_CONFIG_URL=<REPLACE THIS>
|
||||
SAML_CONF_DIR=/<ABSOLUTE PATH TO ONYX>/onyx/backend/ee/onyx/configs/saml_config
|
||||
|
||||
# Generally not useful for dev, we don't generally want to set up an SMTP server for dev
|
||||
|
||||
# Generally not useful for dev, we don't generally want to set up an SMTP server
|
||||
# for dev.
|
||||
REQUIRE_EMAIL_VERIFICATION=False
|
||||
|
||||
|
||||
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
|
||||
# Set these so if you wipe the DB, you don't end up having to go through the UI
|
||||
# every time.
|
||||
GEN_AI_API_KEY=<REPLACE THIS>
|
||||
OPENAI_API_KEY=<REPLACE THIS>
|
||||
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
|
||||
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper.
|
||||
GEN_AI_MODEL_VERSION=gpt-4o
|
||||
FAST_GEN_AI_MODEL_VERSION=gpt-4o
|
||||
|
||||
@@ -40,26 +49,36 @@ PYTHONPATH=../backend
|
||||
PYTHONUNBUFFERED=1
|
||||
|
||||
|
||||
# Enable the full set of Danswer Enterprise Edition features
|
||||
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
|
||||
# Enable the full set of Danswer Enterprise Edition features.
|
||||
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you
|
||||
# are using this for local testing/development).
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
|
||||
|
||||
|
||||
# S3 File Store Configuration (MinIO for local development)
|
||||
S3_ENDPOINT_URL=http://localhost:9004
|
||||
S3_FILE_STORE_BUCKET_NAME=onyx-file-store-bucket
|
||||
S3_AWS_ACCESS_KEY_ID=minioadmin
|
||||
S3_AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
|
||||
# Show extra/uncommon connectors
|
||||
|
||||
# Show extra/uncommon connectors.
|
||||
SHOW_EXTRA_CONNECTORS=True
|
||||
|
||||
|
||||
# Local langsmith tracing
|
||||
LANGSMITH_TRACING="true"
|
||||
LANGSMITH_ENDPOINT="https://api.smith.langchain.com"
|
||||
LANGSMITH_API_KEY=<REPLACE_THIS>
|
||||
LANGSMITH_PROJECT=<REPLACE_THIS>
|
||||
|
||||
|
||||
# Local Confluence OAuth testing
|
||||
# OAUTH_CONFLUENCE_CLOUD_CLIENT_ID=<REPLACE_THIS>
|
||||
# OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET=<REPLACE_THIS>
|
||||
# NEXT_PUBLIC_TEST_ENV=True
|
||||
# NEXT_PUBLIC_TEST_ENV=True
|
||||
|
||||
|
||||
# OpenSearch
|
||||
# Arbitrary password is fine for local development.
|
||||
OPENSEARCH_INITIAL_ADMIN_PASSWORD=<REPLACE THIS>
|
||||
|
||||
15
.vscode/launch.template.jsonc
vendored
15
.vscode/launch.template.jsonc
vendored
@@ -512,6 +512,21 @@
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clear and Restart OpenSearch Container",
|
||||
// Generic debugger type, required arg but has no bearing on bash.
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeArgs": [
|
||||
"${workspaceFolder}/backend/scripts/restart_opensearch_container.sh"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Eval CLI",
|
||||
"type": "debugpy",
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
# AGENTS.md
|
||||
|
||||
This file provides guidance to Codex when working with code in this repository.
|
||||
This file provides guidance to AI agents when working with code in this repository.
|
||||
|
||||
## KEY NOTES
|
||||
|
||||
- If you run into any missing python dependency errors, try running your command with `source backend/.venv/bin/activate` \
|
||||
- If you run into any missing python dependency errors, try running your command with `source .venv/bin/activate` \
|
||||
to assume the python venv.
|
||||
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@example.com` and password
|
||||
`a`. The app can be accessed at `http://localhost:3000`.
|
||||
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
|
||||
make sure we see logs coming out from the relevant service.
|
||||
@@ -181,6 +181,286 @@ web/
|
||||
└── src/lib/ # Utilities & business logic
|
||||
```
|
||||
|
||||
## Frontend Standards
|
||||
|
||||
### 1. Import Standards
|
||||
|
||||
**Always use absolute imports with the `@` prefix.**
|
||||
|
||||
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useAuth } from "@/hooks/useAuth";
|
||||
import { Text } from "@/refresh-components/texts/Text";
|
||||
|
||||
// ❌ Bad
|
||||
import { Button } from "../../../components/ui/button";
|
||||
import { useAuth } from "./hooks/useAuth";
|
||||
```
|
||||
|
||||
### 2. React Component Functions
|
||||
|
||||
**Prefer regular functions over arrow functions for React components.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
function UserProfile({ userId }: UserProfileProps) {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
const UserProfile = ({ userId }: UserProfileProps) => {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Props Interface Extraction
|
||||
|
||||
**Extract prop types into their own interface definitions.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
interface UserCardProps {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}
|
||||
|
||||
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({
|
||||
user,
|
||||
showActions = false,
|
||||
onEdit
|
||||
}: {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Spacing Guidelines
|
||||
|
||||
**Prefer padding over margins for spacing.**
|
||||
|
||||
**Reason:** We want to consolidate usage to paddings instead of margins.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
<div className="p-4 space-y-2">
|
||||
<div className="p-2">Content</div>
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className="m-4 space-y-2">
|
||||
<div className="m-2">Content</div>
|
||||
</div>
|
||||
```
|
||||
|
||||
### 5. Tailwind Dark Mode
|
||||
|
||||
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
|
||||
|
||||
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
|
||||
|
||||
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Standard components use `web/tailwind-themes/tailwind.config.js` / `web/src/app/css/colors.css`
|
||||
<div className="bg-background-neutral-03 text-text-02">
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
monochromatic: true, // Will apply dark:invert internally
|
||||
});
|
||||
|
||||
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
|
||||
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
|
||||
});
|
||||
|
||||
// ❌ Bad - Manual dark mode overrides
|
||||
<div className="bg-white dark:bg-black text-black dark:text-white">
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 6. Class Name Utilities
|
||||
|
||||
**Use the `cn` utility instead of raw string formatting for classNames.**
|
||||
|
||||
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
|
||||
|
||||
```typescript
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
// ✅ Good
|
||||
<div className={cn(
|
||||
'base-class',
|
||||
isActive && 'active-class',
|
||||
className
|
||||
)}>
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 7. Custom Hooks Organization
|
||||
|
||||
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
|
||||
|
||||
**Reason:** This is just a layout preference. Keeps code clean.
|
||||
|
||||
```typescript
|
||||
// web/src/hooks/useUserData.ts
|
||||
export function useUserData(userId: string) {
|
||||
// hook implementation
|
||||
}
|
||||
|
||||
// web/src/hooks/useLocalStorage.ts
|
||||
export function useLocalStorage<T>(key: string, initialValue: T) {
|
||||
// hook implementation
|
||||
}
|
||||
```
|
||||
|
||||
### 8. Icon Usage
|
||||
|
||||
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
|
||||
|
||||
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import SvgX from "@/icons/x";
|
||||
import SvgMoreHorizontal from "@/icons/more-horizontal";
|
||||
|
||||
// ❌ Bad
|
||||
import { User } from "lucide-react";
|
||||
import { FiSearch } from "react-icons/fi";
|
||||
```
|
||||
|
||||
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
|
||||
If you need help with this step, reach out to `raunak@onyx.app`.
|
||||
|
||||
### 9. Text Rendering
|
||||
|
||||
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
|
||||
|
||||
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Text } from '@/refresh-components/texts/Text'
|
||||
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<Text
|
||||
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
|
||||
text03
|
||||
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
|
||||
mainAction
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<div>
|
||||
<h2>{name}</h2>
|
||||
<p>User details</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 10. Component Usage
|
||||
|
||||
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
|
||||
|
||||
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import Button from '@/refresh-components/buttons/Button'
|
||||
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
|
||||
import SvgPlusCircle from '@/icons/plus-circle'
|
||||
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<InputTypeIn placeholder="Search..." />
|
||||
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<input placeholder="Name" />
|
||||
<textarea placeholder="Message" />
|
||||
<button type="submit">Submit</button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 11. Colors
|
||||
|
||||
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
|
||||
|
||||
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
|
||||
|
||||
**Available color categories:**
|
||||
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
|
||||
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
|
||||
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
|
||||
- **Actions:** `action-link-XX`, `action-danger-XX`
|
||||
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
|
||||
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Use custom Onyx color classes
|
||||
<div className="bg-background-neutral-01 border border-border-02" />
|
||||
<div className="bg-background-tint-02 border border-border-01" />
|
||||
<div className="bg-status-success-01" />
|
||||
<div className="bg-action-link-01" />
|
||||
<div className="bg-theme-primary-05" />
|
||||
|
||||
// ❌ Bad - Do NOT use standard Tailwind colors
|
||||
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
|
||||
<div className="bg-white border border-slate-200" />
|
||||
<div className="bg-green-100 text-green-700" />
|
||||
<div className="bg-blue-100 text-blue-600" />
|
||||
<div className="bg-indigo-500" />
|
||||
```
|
||||
|
||||
### 12. Data Fetching
|
||||
|
||||
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
|
||||
|
||||
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
### Running Migrations
|
||||
@@ -295,14 +575,6 @@ will be tailing their logs to this file.
|
||||
- Token management and rate limiting
|
||||
- Custom prompts and agent actions
|
||||
|
||||
## UI/UX Patterns
|
||||
|
||||
- Tailwind CSS with design system in `web/src/components/ui/`
|
||||
- Radix UI and Headless UI for accessible components
|
||||
- SWR for data fetching and caching
|
||||
- Form validation with react-hook-form
|
||||
- Error handling with popup notifications
|
||||
|
||||
## Creating a Plan
|
||||
When creating a plan in the `plans` directory, make sure to include at least these elements:
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
||||
- If you run into any missing python dependency errors, try running your command with `source .venv/bin/activate` \
|
||||
to assume the python venv.
|
||||
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@example.com` and password
|
||||
`a`. The app can be accessed at `http://localhost:3000`.
|
||||
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
|
||||
make sure we see logs coming out from the relevant service.
|
||||
@@ -184,6 +184,286 @@ web/
|
||||
└── src/lib/ # Utilities & business logic
|
||||
```
|
||||
|
||||
## Frontend Standards
|
||||
|
||||
### 1. Import Standards
|
||||
|
||||
**Always use absolute imports with the `@` prefix.**
|
||||
|
||||
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useAuth } from "@/hooks/useAuth";
|
||||
import { Text } from "@/refresh-components/texts/Text";
|
||||
|
||||
// ❌ Bad
|
||||
import { Button } from "../../../components/ui/button";
|
||||
import { useAuth } from "./hooks/useAuth";
|
||||
```
|
||||
|
||||
### 2. React Component Functions
|
||||
|
||||
**Prefer regular functions over arrow functions for React components.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
function UserProfile({ userId }: UserProfileProps) {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
const UserProfile = ({ userId }: UserProfileProps) => {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Props Interface Extraction
|
||||
|
||||
**Extract prop types into their own interface definitions.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
interface UserCardProps {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}
|
||||
|
||||
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({
|
||||
user,
|
||||
showActions = false,
|
||||
onEdit
|
||||
}: {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Spacing Guidelines
|
||||
|
||||
**Prefer padding over margins for spacing.**
|
||||
|
||||
**Reason:** We want to consolidate usage to paddings instead of margins.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
<div className="p-4 space-y-2">
|
||||
<div className="p-2">Content</div>
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className="m-4 space-y-2">
|
||||
<div className="m-2">Content</div>
|
||||
</div>
|
||||
```
|
||||
|
||||
### 5. Tailwind Dark Mode
|
||||
|
||||
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
|
||||
|
||||
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
|
||||
|
||||
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
|
||||
<div className="bg-background-neutral-03 text-text-02">
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
monochromatic: true, // Will apply dark:invert internally
|
||||
});
|
||||
|
||||
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
|
||||
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
|
||||
});
|
||||
|
||||
// ❌ Bad - Manual dark mode overrides
|
||||
<div className="bg-white dark:bg-black text-black dark:text-white">
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 6. Class Name Utilities
|
||||
|
||||
**Use the `cn` utility instead of raw string formatting for classNames.**
|
||||
|
||||
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
|
||||
|
||||
```typescript
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
// ✅ Good
|
||||
<div className={cn(
|
||||
'base-class',
|
||||
isActive && 'active-class',
|
||||
className
|
||||
)}>
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 7. Custom Hooks Organization
|
||||
|
||||
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
|
||||
|
||||
**Reason:** This is just a layout preference. Keeps code clean.
|
||||
|
||||
```typescript
|
||||
// web/src/hooks/useUserData.ts
|
||||
export function useUserData(userId: string) {
|
||||
// hook implementation
|
||||
}
|
||||
|
||||
// web/src/hooks/useLocalStorage.ts
|
||||
export function useLocalStorage<T>(key: string, initialValue: T) {
|
||||
// hook implementation
|
||||
}
|
||||
```
|
||||
|
||||
### 8. Icon Usage
|
||||
|
||||
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
|
||||
|
||||
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import SvgX from "@/icons/x";
|
||||
import SvgMoreHorizontal from "@/icons/more-horizontal";
|
||||
|
||||
// ❌ Bad
|
||||
import { User } from "lucide-react";
|
||||
import { FiSearch } from "react-icons/fi";
|
||||
```
|
||||
|
||||
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
|
||||
If you need help with this step, reach out to `raunak@onyx.app`.
|
||||
|
||||
### 9. Text Rendering
|
||||
|
||||
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
|
||||
|
||||
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Text } from '@/refresh-components/texts/Text'
|
||||
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<Text
|
||||
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
|
||||
text03
|
||||
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
|
||||
mainAction
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<div>
|
||||
<h2>{name}</h2>
|
||||
<p>User details</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 10. Component Usage
|
||||
|
||||
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
|
||||
|
||||
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import Button from '@/refresh-components/buttons/Button'
|
||||
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
|
||||
import SvgPlusCircle from '@/icons/plus-circle'
|
||||
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<InputTypeIn placeholder="Search..." />
|
||||
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<input placeholder="Name" />
|
||||
<textarea placeholder="Message" />
|
||||
<button type="submit">Submit</button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 11. Colors
|
||||
|
||||
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
|
||||
|
||||
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
|
||||
|
||||
**Available color categories:**
|
||||
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
|
||||
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
|
||||
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
|
||||
- **Actions:** `action-link-XX`, `action-danger-XX`
|
||||
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
|
||||
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Use custom Onyx color classes
|
||||
<div className="bg-background-neutral-01 border border-border-02" />
|
||||
<div className="bg-background-tint-02 border border-border-01" />
|
||||
<div className="bg-status-success-01" />
|
||||
<div className="bg-action-link-01" />
|
||||
<div className="bg-theme-primary-05" />
|
||||
|
||||
// ❌ Bad - Do NOT use standard Tailwind colors
|
||||
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
|
||||
<div className="bg-white border border-slate-200" />
|
||||
<div className="bg-green-100 text-green-700" />
|
||||
<div className="bg-blue-100 text-blue-600" />
|
||||
<div className="bg-indigo-500" />
|
||||
```
|
||||
|
||||
### 12. Data Fetching
|
||||
|
||||
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
|
||||
|
||||
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
### Running Migrations
|
||||
@@ -300,14 +580,6 @@ will be tailing their logs to this file.
|
||||
- Token management and rate limiting
|
||||
- Custom prompts and agent actions
|
||||
|
||||
## UI/UX Patterns
|
||||
|
||||
- Tailwind CSS with design system in `web/src/components/ui/`
|
||||
- Radix UI and Headless UI for accessible components
|
||||
- SWR for data fetching and caching
|
||||
- Form validation with react-hook-form
|
||||
- Error handling with popup notifications
|
||||
|
||||
## Creating a Plan
|
||||
When creating a plan in the `plans` directory, make sure to include at least these elements:
|
||||
|
||||
|
||||
@@ -161,7 +161,7 @@ You will need Docker installed to run these containers.
|
||||
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:
|
||||
|
||||
```bash
|
||||
docker compose up -d index relational_db cache minio
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d index relational_db cache minio
|
||||
```
|
||||
|
||||
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
|
||||
|
||||
@@ -15,3 +15,4 @@ build/
|
||||
dist/
|
||||
.coverage
|
||||
htmlcov/
|
||||
model_server/legacy/
|
||||
|
||||
@@ -13,23 +13,10 @@ RUN uv pip install --system --no-cache-dir --upgrade \
|
||||
-r /tmp/requirements.txt && \
|
||||
rm -rf ~/.cache/uv /tmp/*.txt
|
||||
|
||||
# Stage for downloading tokenizers
|
||||
FROM base AS tokenizers
|
||||
RUN python -c "from transformers import AutoTokenizer; \
|
||||
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
|
||||
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1');"
|
||||
|
||||
# Stage for downloading Onyx models
|
||||
FROM base AS onyx-models
|
||||
RUN python -c "from huggingface_hub import snapshot_download; \
|
||||
snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \
|
||||
snapshot_download(repo_id='onyx-dot-app/information-content-model');"
|
||||
|
||||
# Stage for downloading embedding and reranking models
|
||||
# Stage for downloading embedding models
|
||||
FROM base AS embedding-models
|
||||
RUN python -c "from huggingface_hub import snapshot_download; \
|
||||
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
|
||||
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1');"
|
||||
snapshot_download('nomic-ai/nomic-embed-text-v1');"
|
||||
|
||||
# Initialize SentenceTransformer to cache the custom architecture
|
||||
RUN python -c "from sentence_transformers import SentenceTransformer; \
|
||||
@@ -54,8 +41,6 @@ RUN groupadd -g 1001 onyx && \
|
||||
# In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
|
||||
# running Onyx, move the current contents of the cache folder to a temporary location to ensure
|
||||
# it's preserved in order to combine with the user's cache contents
|
||||
COPY --chown=onyx:onyx --from=tokenizers /app/.cache/huggingface /app/.cache/temp_huggingface
|
||||
COPY --chown=onyx:onyx --from=onyx-models /app/.cache/huggingface /app/.cache/temp_huggingface
|
||||
COPY --chown=onyx:onyx --from=embedding-models /app/.cache/huggingface /app/.cache/temp_huggingface
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -39,7 +39,9 @@ config = context.config
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
fileConfig(config.config_file_name)
|
||||
# disable_existing_loggers=False prevents breaking pytest's caplog fixture
|
||||
# See: https://pytest-alembic.readthedocs.io/en/latest/setup.html#caplog-issues
|
||||
fileConfig(config.config_file_name, disable_existing_loggers=False)
|
||||
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
@@ -460,8 +462,49 @@ def run_migrations_offline() -> None:
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
logger.info("run_migrations_online starting.")
|
||||
asyncio.run(run_async_migrations())
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
Supports pytest-alembic by checking for a pre-configured connection
|
||||
in context.config.attributes["connection"]. If present, uses that
|
||||
connection/engine directly instead of creating a new async engine.
|
||||
"""
|
||||
# Check if pytest-alembic is providing a connection/engine
|
||||
connectable = context.config.attributes.get("connection", None)
|
||||
|
||||
if connectable is not None:
|
||||
# pytest-alembic is providing an engine - use it directly
|
||||
logger.info("run_migrations_online starting (pytest-alembic mode).")
|
||||
|
||||
# For pytest-alembic, we use the default schema (public)
|
||||
schema_name = context.config.attributes.get(
|
||||
"schema_name", POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
|
||||
# pytest-alembic passes an Engine, we need to get a connection from it
|
||||
with connectable.connect() as connection:
|
||||
# Set search path for the schema
|
||||
connection.execute(text(f'SET search_path TO "{schema_name}"'))
|
||||
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
compare_server_default=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
# Commit the transaction to ensure changes are visible to next migration
|
||||
connection.commit()
|
||||
else:
|
||||
# Normal operation - use async migrations
|
||||
logger.info("run_migrations_online starting.")
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
|
||||
@@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "23957775e5f5"
|
||||
down_revision = "bc9771dccadf"
|
||||
branch_labels = None # type: ignore
|
||||
depends_on = None # type: ignore
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
"""add last refreshed at mcp server
|
||||
|
||||
Revision ID: 2a391f840e85
|
||||
Revises: 4cebcbc9b2ae
|
||||
Create Date: 2025-12-06 15:19:59.766066
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembi.
|
||||
revision = "2a391f840e85"
|
||||
down_revision = "4cebcbc9b2ae"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"mcp_server",
|
||||
sa.Column("last_refreshed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("mcp_server", "last_refreshed_at")
|
||||
46
backend/alembic/versions/2b90f3af54b8_usage_limits.py
Normal file
46
backend/alembic/versions/2b90f3af54b8_usage_limits.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""usage_limits
|
||||
|
||||
Revision ID: 2b90f3af54b8
|
||||
Revises: 9a0296d7421e
|
||||
Create Date: 2026-01-03 16:55:30.449692
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2b90f3af54b8"
|
||||
down_revision = "9a0296d7421e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"tenant_usage",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"window_start", sa.DateTime(timezone=True), nullable=False, index=True
|
||||
),
|
||||
sa.Column("llm_cost_cents", sa.Float(), nullable=False, server_default="0.0"),
|
||||
sa.Column("chunks_indexed", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("api_calls", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column(
|
||||
"non_streaming_api_calls", sa.Integer(), nullable=False, server_default="0"
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("window_start", name="uq_tenant_usage_window"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_tenant_usage_window_start", table_name="tenant_usage")
|
||||
op.drop_table("tenant_usage")
|
||||
@@ -11,7 +11,7 @@ from pydantic import BaseModel, ConfigDict
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from onyx.llm.llm_provider_options import (
|
||||
from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
fetch_model_names_for_provider_as_set,
|
||||
fetch_visible_model_names_for_provider_as_set,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
"""add tab_index to tool_call
|
||||
|
||||
Revision ID: 4cebcbc9b2ae
|
||||
Revises: a1b2c3d4e5f6
|
||||
Create Date: 2025-12-16
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4cebcbc9b2ae"
|
||||
down_revision = "a1b2c3d4e5f6"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"tool_call",
|
||||
sa.Column("tab_index", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("tool_call", "tab_index")
|
||||
@@ -62,6 +62,11 @@ def upgrade() -> None:
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the temporary table to avoid conflicts if migration runs again
|
||||
# (e.g., during upgrade -> downgrade -> upgrade cycles in tests)
|
||||
op.execute("DROP TABLE IF EXISTS temp_connector_credential")
|
||||
|
||||
# If no exception was raised, alter the column
|
||||
op.alter_column("credential", "source", nullable=True) # TODO modify
|
||||
# # ### end Alembic commands ###
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""backend driven notification details
|
||||
|
||||
Revision ID: 5c3dca366b35
|
||||
Revises: 9087b548dd69
|
||||
Create Date: 2026-01-06 16:03:11.413724
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5c3dca366b35"
|
||||
down_revision = "9087b548dd69"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"notification",
|
||||
sa.Column(
|
||||
"title", sa.String(), nullable=False, server_default="New Notification"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"notification",
|
||||
sa.Column("description", sa.String(), nullable=True, server_default=""),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("notification", "title")
|
||||
op.drop_column("notification", "description")
|
||||
@@ -0,0 +1,75 @@
|
||||
"""nullify_default_task_prompt
|
||||
|
||||
Revision ID: 699221885109
|
||||
Revises: 7e490836d179
|
||||
Create Date: 2025-12-30 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "699221885109"
|
||||
down_revision = "7e490836d179"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Make task_prompt column nullable
|
||||
# Note: The model had nullable=True but the DB column was NOT NULL until this point
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"task_prompt",
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Set task_prompt to NULL for the default persona
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET task_prompt = NULL
|
||||
WHERE id = :persona_id
|
||||
"""
|
||||
),
|
||||
{"persona_id": DEFAULT_PERSONA_ID},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore task_prompt to empty string for the default persona
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET task_prompt = ''
|
||||
WHERE id = :persona_id AND task_prompt IS NULL
|
||||
"""
|
||||
),
|
||||
{"persona_id": DEFAULT_PERSONA_ID},
|
||||
)
|
||||
|
||||
# Set any remaining NULL task_prompts to empty string before making non-nullable
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET task_prompt = ''
|
||||
WHERE task_prompt IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Revert task_prompt column to not nullable
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"task_prompt",
|
||||
nullable=False,
|
||||
)
|
||||
@@ -0,0 +1,54 @@
|
||||
"""add image generation config table
|
||||
|
||||
Revision ID: 7206234e012a
|
||||
Revises: 699221885109
|
||||
Create Date: 2025-12-21 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7206234e012a"
|
||||
down_revision = "699221885109"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"image_generation_config",
|
||||
sa.Column("image_provider_id", sa.String(), primary_key=True),
|
||||
sa.Column("model_configuration_id", sa.Integer(), nullable=False),
|
||||
sa.Column("is_default", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["model_configuration_id"],
|
||||
["model_configuration.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_image_generation_config_is_default",
|
||||
"image_generation_config",
|
||||
["is_default"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_image_generation_config_model_configuration_id",
|
||||
"image_generation_config",
|
||||
["model_configuration_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(
|
||||
"ix_image_generation_config_model_configuration_id",
|
||||
table_name="image_generation_config",
|
||||
)
|
||||
op.drop_index(
|
||||
"ix_image_generation_config_is_default", table_name="image_generation_config"
|
||||
)
|
||||
op.drop_table("image_generation_config")
|
||||
@@ -10,7 +10,7 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from onyx.llm.llm_provider_options import (
|
||||
from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
fetch_model_names_for_provider_as_set,
|
||||
fetch_visible_model_names_for_provider_as_set,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
"""nullify_default_system_prompt
|
||||
|
||||
Revision ID: 7e490836d179
|
||||
Revises: c1d2e3f4a5b6
|
||||
Create Date: 2025-12-29 16:54:36.635574
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7e490836d179"
|
||||
down_revision = "c1d2e3f4a5b6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# This is the default system prompt from the previous migration (87c52ec39f84)
|
||||
# ruff: noqa: E501, W605 start
|
||||
PREVIOUS_DEFAULT_SYSTEM_PROMPT = """
|
||||
You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the user's intent, ask clarifying questions when needed, think step-by-step through complex problems, provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always prioritize being truthful, nuanced, insightful, and efficient.
|
||||
|
||||
The current date is [[CURRENT_DATETIME]].[[CITATION_GUIDANCE]]
|
||||
|
||||
# Response Style
|
||||
You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make your responses more readable and engaging.
|
||||
You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline.
|
||||
For code you prefer to use Markdown and specify the language.
|
||||
You can use horizontal rules (---) to separate sections of your responses.
|
||||
You can use Markdown tables to format your responses for data, lists, and other structured information.
|
||||
""".lstrip()
|
||||
# ruff: noqa: E501, W605 end
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Make system_prompt column nullable (model already has nullable=True but DB doesn't)
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"system_prompt",
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Set system_prompt to NULL where it matches the previous default
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET system_prompt = NULL
|
||||
WHERE system_prompt = :previous_default
|
||||
"""
|
||||
),
|
||||
{"previous_default": PREVIOUS_DEFAULT_SYSTEM_PROMPT},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Restore the default system prompt for personas that have NULL
|
||||
# Note: This may restore the prompt to personas that originally had NULL
|
||||
# before this migration, but there's no way to distinguish them
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET system_prompt = :previous_default
|
||||
WHERE system_prompt IS NULL
|
||||
"""
|
||||
),
|
||||
{"previous_default": PREVIOUS_DEFAULT_SYSTEM_PROMPT},
|
||||
)
|
||||
|
||||
# Revert system_prompt column to not nullable
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"system_prompt",
|
||||
nullable=False,
|
||||
)
|
||||
@@ -42,13 +42,13 @@ def upgrade() -> None:
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
@@ -63,13 +63,13 @@ def upgrade() -> None:
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"), # type: ignore
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
"""seed_default_image_gen_config
|
||||
|
||||
Revision ID: 9087b548dd69
|
||||
Revises: 2b90f3af54b8
|
||||
Create Date: 2026-01-05 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9087b548dd69"
|
||||
down_revision = "2b90f3af54b8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# Constants for default image generation config
|
||||
# Source: web/src/app/admin/configuration/image-generation/constants.ts
|
||||
IMAGE_PROVIDER_ID = "openai_gpt_image_1"
|
||||
MODEL_NAME = "gpt-image-1"
|
||||
PROVIDER_NAME = "openai"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Check if image_generation_config table already has records
|
||||
existing_configs = (
|
||||
conn.execute(sa.text("SELECT COUNT(*) FROM image_generation_config")).scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
if existing_configs > 0:
|
||||
# Skip if configs already exist - user may have configured manually
|
||||
return
|
||||
|
||||
# Find the first OpenAI LLM provider
|
||||
openai_provider = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, api_key
|
||||
FROM llm_provider
|
||||
WHERE provider = :provider
|
||||
ORDER BY id
|
||||
LIMIT 1
|
||||
"""
|
||||
),
|
||||
{"provider": PROVIDER_NAME},
|
||||
).fetchone()
|
||||
|
||||
if not openai_provider:
|
||||
# No OpenAI provider found - nothing to do
|
||||
return
|
||||
|
||||
source_provider_id, api_key = openai_provider
|
||||
|
||||
# Create new LLM provider for image generation (clone only api_key)
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO llm_provider (
|
||||
name, provider, api_key, api_base, api_version,
|
||||
deployment_name, default_model_name, is_public,
|
||||
is_default_provider, is_default_vision_provider, is_auto_mode
|
||||
)
|
||||
VALUES (
|
||||
:name, :provider, :api_key, NULL, NULL,
|
||||
NULL, :default_model_name, :is_public,
|
||||
NULL, NULL, :is_auto_mode
|
||||
)
|
||||
RETURNING id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"name": f"Image Gen - {IMAGE_PROVIDER_ID}",
|
||||
"provider": PROVIDER_NAME,
|
||||
"api_key": api_key,
|
||||
"default_model_name": MODEL_NAME,
|
||||
"is_public": True,
|
||||
"is_auto_mode": False,
|
||||
},
|
||||
)
|
||||
new_provider_id = result.scalar()
|
||||
|
||||
# Create model configuration
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO model_configuration (
|
||||
llm_provider_id, name, is_visible, max_input_tokens,
|
||||
supports_image_input, display_name
|
||||
)
|
||||
VALUES (
|
||||
:llm_provider_id, :name, :is_visible, :max_input_tokens,
|
||||
:supports_image_input, :display_name
|
||||
)
|
||||
RETURNING id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"llm_provider_id": new_provider_id,
|
||||
"name": MODEL_NAME,
|
||||
"is_visible": True,
|
||||
"max_input_tokens": None,
|
||||
"supports_image_input": False,
|
||||
"display_name": None,
|
||||
},
|
||||
)
|
||||
model_config_id = result.scalar()
|
||||
|
||||
# Create image generation config
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO image_generation_config (
|
||||
image_provider_id, model_configuration_id, is_default
|
||||
)
|
||||
VALUES (
|
||||
:image_provider_id, :model_configuration_id, :is_default
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"image_provider_id": IMAGE_PROVIDER_ID,
|
||||
"model_configuration_id": model_config_id,
|
||||
"is_default": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# We don't remove the config on downgrade since it's safe to keep around
|
||||
# If we upgrade again, it will be a no-op due to the existing records check
|
||||
pass
|
||||
@@ -0,0 +1,33 @@
|
||||
"""add_is_auto_mode_to_llm_provider
|
||||
|
||||
Revision ID: 9a0296d7421e
|
||||
Revises: 7206234e012a
|
||||
Create Date: 2025-12-17 18:14:29.620981
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9a0296d7421e"
|
||||
down_revision = "7206234e012a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"llm_provider",
|
||||
sa.Column(
|
||||
"is_auto_mode",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("llm_provider", "is_auto_mode")
|
||||
@@ -234,6 +234,8 @@ def downgrade() -> None:
|
||||
if "instructions" in columns:
|
||||
op.drop_column("user_project", "instructions")
|
||||
op.execute("ALTER TABLE user_project RENAME TO user_folder")
|
||||
# Update NULL descriptions to empty string before setting NOT NULL constraint
|
||||
op.execute("UPDATE user_folder SET description = '' WHERE description IS NULL")
|
||||
op.alter_column("user_folder", "description", nullable=False)
|
||||
logger.info("Renamed user_project back to user_folder")
|
||||
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
"""update_default_tool_descriptions
|
||||
|
||||
Revision ID: a01bf2971c5d
|
||||
Revises: 87c52ec39f84
|
||||
Create Date: 2025-12-16 15:21:25.656375
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a01bf2971c5d"
|
||||
down_revision = "18b5b2524446"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# new tool descriptions (12/2025)
|
||||
TOOL_DESCRIPTIONS = {
|
||||
"SearchTool": "The Search Action allows the agent to search through connected knowledge to help build an answer.",
|
||||
"ImageGenerationTool": (
|
||||
"The Image Generation Action allows the agent to use DALL-E 3 or GPT-IMAGE-1 to generate images. "
|
||||
"The action will be used when the user asks the agent to generate an image."
|
||||
),
|
||||
"WebSearchTool": (
|
||||
"The Web Search Action allows the agent "
|
||||
"to perform internet searches for up-to-date information."
|
||||
),
|
||||
"KnowledgeGraphTool": (
|
||||
"The Knowledge Graph Search Action allows the agent to search the "
|
||||
"Knowledge Graph for information. This tool can (for now) only be active in the KG Beta Agent, "
|
||||
"and it requires the Knowledge Graph to be enabled."
|
||||
),
|
||||
"OktaProfileTool": (
|
||||
"The Okta Profile Action allows the agent to fetch the current user's information from Okta. "
|
||||
"This may include the user's name, email, phone number, address, and other details such as their "
|
||||
"manager and direct reports."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
for tool_id, description in TOOL_DESCRIPTIONS.items():
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
|
||||
),
|
||||
{"description": description, "tool_id": tool_id},
|
||||
)
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
except Exception as e:
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
49
backend/alembic/versions/a1b2c3d4e5f6_add_license_table.py
Normal file
49
backend/alembic/versions/a1b2c3d4e5f6_add_license_table.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""add license table
|
||||
|
||||
Revision ID: a1b2c3d4e5f6
|
||||
Revises: a01bf2971c5d
|
||||
Create Date: 2025-12-04 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a1b2c3d4e5f6"
|
||||
down_revision = "a01bf2971c5d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"license",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("license_data", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Singleton pattern - only ever one row in this table
|
||||
op.create_index(
|
||||
"idx_license_singleton",
|
||||
"license",
|
||||
[sa.text("(true)")],
|
||||
unique=True,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("idx_license_singleton", table_name="license")
|
||||
op.drop_table("license")
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Remove fast_default_model_name from llm_provider
|
||||
|
||||
Revision ID: a2b3c4d5e6f7
|
||||
Revises: 2a391f840e85
|
||||
Create Date: 2024-12-17
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a2b3c4d5e6f7"
|
||||
down_revision = "2a391f840e85"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("llm_provider", "fast_default_model_name")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"llm_provider",
|
||||
sa.Column("fast_default_model_name", sa.String(), nullable=True),
|
||||
)
|
||||
@@ -0,0 +1,39 @@
|
||||
"""remove userfile related deprecated fields
|
||||
|
||||
Revision ID: a3c1a7904cd0
|
||||
Revises: 5c3dca366b35
|
||||
Create Date: 2026-01-06 13:00:30.634396
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a3c1a7904cd0"
|
||||
down_revision = "5c3dca366b35"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("user_file", "document_id")
|
||||
op.drop_column("user_file", "document_id_migrated")
|
||||
op.drop_column("connector_credential_pair", "is_user_file")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("is_user_file", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column("document_id", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"document_id_migrated", sa.Boolean(), nullable=False, server_default="true"
|
||||
),
|
||||
)
|
||||
@@ -280,6 +280,14 @@ def downgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
# Recreate the FK constraint that was implicitly dropped when the column was dropped
|
||||
op.create_foreign_key(
|
||||
"fk_chat_message_persona",
|
||||
"chat_message",
|
||||
"persona",
|
||||
["alternate_assistant_id"],
|
||||
["id"],
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("rephrased_query", sa.Text(), nullable=True)
|
||||
)
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Drop milestone table
|
||||
|
||||
Revision ID: b8c9d0e1f2a3
|
||||
Revises: a2b3c4d5e6f7
|
||||
Create Date: 2025-12-18
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import fastapi_users_db_sqlalchemy
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b8c9d0e1f2a3"
|
||||
down_revision = "a2b3c4d5e6f7"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_table("milestone")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_table(
|
||||
"milestone",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("tenant_id", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("event_type", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("event_tracker", postgresql.JSONB(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("event_type", name="uq_milestone_event_type"),
|
||||
)
|
||||
@@ -0,0 +1,52 @@
|
||||
"""add_deep_research_tool
|
||||
|
||||
Revision ID: c1d2e3f4a5b6
|
||||
Revises: b8c9d0e1f2a3
|
||||
Create Date: 2025-12-18 16:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c1d2e3f4a5b6"
|
||||
down_revision = "b8c9d0e1f2a3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
DEEP_RESEARCH_TOOL = {
|
||||
"name": RESEARCH_AGENT_DB_NAME,
|
||||
"display_name": "Research Agent",
|
||||
"description": "The Research Agent is a sub-agent that conducts research on a specific topic.",
|
||||
"in_code_tool_id": "ResearchAgent",
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id, false)
|
||||
"""
|
||||
),
|
||||
DEEP_RESEARCH_TOOL,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
DELETE FROM tool
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
{"in_code_tool_id": DEEP_RESEARCH_TOOL["in_code_tool_id"]},
|
||||
)
|
||||
@@ -257,8 +257,8 @@ def _migrate_files_to_external_storage() -> None:
|
||||
print(f"File {file_id} not found in PostgreSQL storage.")
|
||||
continue
|
||||
|
||||
lobj_id = cast(int, file_record.lobj_oid) # type: ignore
|
||||
file_metadata = cast(Any, file_record.file_metadata) # type: ignore
|
||||
lobj_id = cast(int, file_record.lobj_oid)
|
||||
file_metadata = cast(Any, file_record.file_metadata)
|
||||
|
||||
# Read file content from PostgreSQL
|
||||
try:
|
||||
@@ -280,7 +280,7 @@ def _migrate_files_to_external_storage() -> None:
|
||||
else:
|
||||
# Convert other types to dict if possible, otherwise None
|
||||
try:
|
||||
file_metadata = dict(file_record.file_metadata) # type: ignore
|
||||
file_metadata = dict(file_record.file_metadata)
|
||||
except (TypeError, ValueError):
|
||||
file_metadata = None
|
||||
|
||||
|
||||
@@ -11,8 +11,8 @@ import sqlalchemy as sa
|
||||
|
||||
revision = "e209dc5a8156"
|
||||
down_revision = "48d14957fe80"
|
||||
branch_labels = None # type: ignore
|
||||
depends_on = None # type: ignore
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
@@ -8,7 +8,7 @@ Create Date: 2025-11-28 11:15:37.667340
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from onyx.db.enums import ( # type: ignore[import-untyped]
|
||||
from onyx.db.enums import (
|
||||
MCPTransport,
|
||||
MCPAuthenticationType,
|
||||
MCPAuthenticationPerformer,
|
||||
|
||||
@@ -20,7 +20,9 @@ config = context.config
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
fileConfig(config.config_file_name)
|
||||
# disable_existing_loggers=False prevents breaking pytest's caplog fixture
|
||||
# See: https://pytest-alembic.readthedocs.io/en/latest/setup.html#caplog-issues
|
||||
fileConfig(config.config_file_name, disable_existing_loggers=False)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
@@ -82,9 +84,9 @@ def run_migrations_offline() -> None:
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
target_metadata=target_metadata, # type: ignore[arg-type]
|
||||
include_object=include_object,
|
||||
) # type: ignore
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
@@ -108,9 +110,24 @@ async def run_async_migrations() -> None:
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode."""
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
asyncio.run(run_async_migrations())
|
||||
Supports pytest-alembic by checking for a pre-configured connection
|
||||
in context.config.attributes["connection"]. If present, uses that
|
||||
connection/engine directly instead of creating a new async engine.
|
||||
"""
|
||||
# Check if pytest-alembic is providing a connection/engine
|
||||
connectable = context.config.attributes.get("connection", None)
|
||||
|
||||
if connectable is not None:
|
||||
# pytest-alembic is providing an engine - use it directly
|
||||
with connectable.connect() as connection:
|
||||
do_run_migrations(connection)
|
||||
# Commit to ensure changes are visible to next migration
|
||||
connection.commit()
|
||||
else:
|
||||
# Normal operation - use async migrations
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
|
||||
@@ -111,10 +111,6 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
|
||||
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
|
||||
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
|
||||
|
||||
OPENAI_DEFAULT_API_KEY = os.environ.get("OPENAI_DEFAULT_API_KEY")
|
||||
ANTHROPIC_DEFAULT_API_KEY = os.environ.get("ANTHROPIC_DEFAULT_API_KEY")
|
||||
COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")
|
||||
|
||||
# JWT Public Key URL
|
||||
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
|
||||
|
||||
|
||||
@@ -118,6 +118,6 @@ def fetch_document_sets(
|
||||
.all()
|
||||
)
|
||||
|
||||
document_set_with_cc_pairs.append((document_set, cc_pairs)) # type: ignore
|
||||
document_set_with_cc_pairs.append((document_set, cc_pairs))
|
||||
|
||||
return document_set_with_cc_pairs
|
||||
|
||||
278
backend/ee/onyx/db/license.py
Normal file
278
backend/ee/onyx/db/license.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Database and cache operations for the license table."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
LICENSE_METADATA_KEY = "license:metadata"
|
||||
LICENSE_CACHE_TTL_SECONDS = 86400 # 24 hours
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Database CRUD Operations
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_license(db_session: Session) -> License | None:
|
||||
"""
|
||||
Get the current license (singleton pattern - only one row).
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
License object if exists, None otherwise
|
||||
"""
|
||||
return db_session.execute(select(License)).scalars().first()
|
||||
|
||||
|
||||
def upsert_license(db_session: Session, license_data: str) -> License:
|
||||
"""
|
||||
Insert or update the license (singleton pattern).
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
license_data: Base64-encoded signed license blob
|
||||
|
||||
Returns:
|
||||
The created or updated License object
|
||||
"""
|
||||
existing = get_license(db_session)
|
||||
|
||||
if existing:
|
||||
existing.license_data = license_data
|
||||
db_session.commit()
|
||||
db_session.refresh(existing)
|
||||
logger.info("License updated")
|
||||
return existing
|
||||
|
||||
new_license = License(license_data=license_data)
|
||||
db_session.add(new_license)
|
||||
db_session.commit()
|
||||
db_session.refresh(new_license)
|
||||
logger.info("License created")
|
||||
return new_license
|
||||
|
||||
|
||||
def delete_license(db_session: Session) -> bool:
|
||||
"""
|
||||
Delete the current license.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
True if deleted, False if no license existed
|
||||
"""
|
||||
existing = get_license(db_session)
|
||||
if existing:
|
||||
db_session.delete(existing)
|
||||
db_session.commit()
|
||||
logger.info("License deleted")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Seat Counting
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
"""
|
||||
Get current seat usage.
|
||||
|
||||
For multi-tenant: counts users in UserTenantMapping for this tenant.
|
||||
For self-hosted: counts all active users (includes both Onyx UI users
|
||||
and Slack users who have been converted to Onyx users).
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_count
|
||||
|
||||
return get_tenant_count(tenant_id or get_current_tenant_id())
|
||||
else:
|
||||
# Self-hosted: count all active users (Onyx + converted Slack users)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
result = db_session.execute(
|
||||
select(func.count()).select_from(User).where(User.is_active) # type: ignore
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Redis Cache Operations
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata | None:
|
||||
"""
|
||||
Get license metadata from Redis cache.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
|
||||
Returns:
|
||||
LicenseMetadata if cached, None otherwise
|
||||
"""
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
redis_client = get_redis_replica_client(tenant_id=tenant)
|
||||
|
||||
cached = redis_client.get(LICENSE_METADATA_KEY)
|
||||
if cached:
|
||||
try:
|
||||
cached_str: str
|
||||
if isinstance(cached, bytes):
|
||||
cached_str = cached.decode("utf-8")
|
||||
else:
|
||||
cached_str = str(cached)
|
||||
return LicenseMetadata.model_validate_json(cached_str)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse cached license metadata: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def invalidate_license_cache(tenant_id: str | None = None) -> None:
|
||||
"""
|
||||
Invalidate the license metadata cache (not the license itself).
|
||||
|
||||
This deletes the cached LicenseMetadata from Redis. The actual license
|
||||
in the database is not affected. Redis delete is idempotent - if the
|
||||
key doesn't exist, this is a no-op.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
"""
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
redis_client = get_redis_client(tenant_id=tenant)
|
||||
redis_client.delete(LICENSE_METADATA_KEY)
|
||||
logger.info("License cache invalidated")
|
||||
|
||||
|
||||
def update_license_cache(
|
||||
payload: LicensePayload,
|
||||
source: LicenseSource | None = None,
|
||||
grace_period_end: datetime | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> LicenseMetadata:
|
||||
"""
|
||||
Update the Redis cache with license metadata.
|
||||
|
||||
We cache all license statuses (ACTIVE, GRACE_PERIOD, GATED_ACCESS) because:
|
||||
1. Frontend needs status to show appropriate UI/banners
|
||||
2. Caching avoids repeated DB + crypto verification on every request
|
||||
3. Status enforcement happens at the feature level, not here
|
||||
|
||||
Args:
|
||||
payload: Verified license payload
|
||||
source: How the license was obtained
|
||||
grace_period_end: Optional grace period end time
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
|
||||
Returns:
|
||||
The cached LicenseMetadata
|
||||
"""
|
||||
from ee.onyx.utils.license import get_license_status
|
||||
|
||||
tenant = tenant_id or get_current_tenant_id()
|
||||
redis_client = get_redis_client(tenant_id=tenant)
|
||||
|
||||
used_seats = get_used_seats(tenant)
|
||||
status = get_license_status(payload, grace_period_end)
|
||||
|
||||
metadata = LicenseMetadata(
|
||||
tenant_id=payload.tenant_id,
|
||||
organization_name=payload.organization_name,
|
||||
seats=payload.seats,
|
||||
used_seats=used_seats,
|
||||
plan_type=payload.plan_type,
|
||||
issued_at=payload.issued_at,
|
||||
expires_at=payload.expires_at,
|
||||
grace_period_end=grace_period_end,
|
||||
status=status,
|
||||
source=source,
|
||||
stripe_subscription_id=payload.stripe_subscription_id,
|
||||
)
|
||||
|
||||
redis_client.setex(
|
||||
LICENSE_METADATA_KEY,
|
||||
LICENSE_CACHE_TTL_SECONDS,
|
||||
metadata.model_dump_json(),
|
||||
)
|
||||
|
||||
logger.info(f"License cache updated: {metadata.seats} seats, status={status.value}")
|
||||
return metadata
|
||||
|
||||
|
||||
def refresh_license_cache(
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
) -> LicenseMetadata | None:
|
||||
"""
|
||||
Refresh the license cache from the database.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
|
||||
Returns:
|
||||
LicenseMetadata if license exists, None otherwise
|
||||
"""
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
|
||||
license_record = get_license(db_session)
|
||||
if not license_record:
|
||||
invalidate_license_cache(tenant_id)
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = verify_license_signature(license_record.license_data)
|
||||
return update_license_cache(
|
||||
payload,
|
||||
source=LicenseSource.AUTO_FETCH,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to verify license during cache refresh: {e}")
|
||||
invalidate_license_cache(tenant_id)
|
||||
return None
|
||||
|
||||
|
||||
def get_license_metadata(
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
) -> LicenseMetadata | None:
|
||||
"""
|
||||
Get license metadata, using cache if available.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
|
||||
Returns:
|
||||
LicenseMetadata if license exists, None otherwise
|
||||
"""
|
||||
# Try cache first
|
||||
cached = get_cached_license_metadata(tenant_id)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
# Refresh from database
|
||||
return refresh_license_cache(db_session, tenant_id)
|
||||
@@ -34,6 +34,7 @@ def make_persona_private(
|
||||
create_notification(
|
||||
user_id=user_id,
|
||||
notif_type=NotificationType.PERSONA_SHARED,
|
||||
title="A new agent was shared with you!",
|
||||
db_session=db_session,
|
||||
additional_data=PersonaSharedNotificationData(
|
||||
persona_id=persona_id,
|
||||
|
||||
@@ -14,6 +14,7 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
basic_router as enterprise_settings_router,
|
||||
)
|
||||
from ee.onyx.server.evals.api import router as evals_router
|
||||
from ee.onyx.server.license.api import router as license_router
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.tenant_tracking import (
|
||||
add_api_server_tenant_id_middleware,
|
||||
@@ -139,6 +140,8 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, enterprise_settings_router)
|
||||
include_router_with_global_prefix_prepended(application, usage_export_router)
|
||||
# License management
|
||||
include_router_with_global_prefix_prepended(application, license_router)
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Tenant management
|
||||
|
||||
@@ -21,8 +21,9 @@ from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
|
||||
router = APIRouter(prefix="/analytics")
|
||||
router = APIRouter(prefix="/analytics", tags=PUBLIC_API_TAGS)
|
||||
|
||||
|
||||
_DEFAULT_LOOKBACK_DAYS = 30
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import List
|
||||
|
||||
@@ -23,6 +24,12 @@ class NavigationItem(BaseModel):
|
||||
return instance
|
||||
|
||||
|
||||
class LogoDisplayStyle(str, Enum):
|
||||
LOGO_AND_NAME = "logo_and_name"
|
||||
LOGO_ONLY = "logo_only"
|
||||
NAME_ONLY = "name_only"
|
||||
|
||||
|
||||
class EnterpriseSettings(BaseModel):
|
||||
"""General settings that only apply to the Enterprise Edition of Onyx
|
||||
|
||||
@@ -31,6 +38,7 @@ class EnterpriseSettings(BaseModel):
|
||||
application_name: str | None = None
|
||||
use_custom_logo: bool = False
|
||||
use_custom_logotype: bool = False
|
||||
logo_display_style: LogoDisplayStyle | None = None
|
||||
|
||||
# custom navigation
|
||||
custom_nav_items: List[NavigationItem] = Field(default_factory=list)
|
||||
@@ -42,6 +50,9 @@ class EnterpriseSettings(BaseModel):
|
||||
custom_popup_header: str | None = None
|
||||
custom_popup_content: str | None = None
|
||||
enable_consent_screen: bool | None = None
|
||||
consent_screen_prompt: str | None = None
|
||||
show_first_visit_notice: bool | None = None
|
||||
custom_greeting_message: str | None = None
|
||||
|
||||
def check_validity(self) -> None:
|
||||
return
|
||||
|
||||
246
backend/ee/onyx/server/license/api.py
Normal file
246
backend/ee/onyx/server/license/api.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""License API endpoints."""
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.db.license import delete_license as db_delete_license
|
||||
from ee.onyx.db.license import get_license_metadata
|
||||
from ee.onyx.db.license import invalidate_license_cache
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from ee.onyx.db.license import update_license_cache
|
||||
from ee.onyx.db.license import upsert_license
|
||||
from ee.onyx.server.license.models import LicenseResponse
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.license.models import LicenseStatusResponse
|
||||
from ee.onyx.server.license.models import LicenseUploadResponse
|
||||
from ee.onyx.server.license.models import SeatUsageResponse
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/license")
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_license_status(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseStatusResponse:
|
||||
"""Get current license status and seat usage."""
|
||||
metadata = get_license_metadata(db_session)
|
||||
|
||||
if not metadata:
|
||||
return LicenseStatusResponse(has_license=False)
|
||||
|
||||
return LicenseStatusResponse(
|
||||
has_license=True,
|
||||
seats=metadata.seats,
|
||||
used_seats=metadata.used_seats,
|
||||
plan_type=metadata.plan_type,
|
||||
issued_at=metadata.issued_at,
|
||||
expires_at=metadata.expires_at,
|
||||
grace_period_end=metadata.grace_period_end,
|
||||
status=metadata.status,
|
||||
source=metadata.source,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/seats")
|
||||
async def get_seat_usage(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SeatUsageResponse:
|
||||
"""Get detailed seat usage information."""
|
||||
metadata = get_license_metadata(db_session)
|
||||
|
||||
if not metadata:
|
||||
return SeatUsageResponse(
|
||||
total_seats=0,
|
||||
used_seats=0,
|
||||
available_seats=0,
|
||||
)
|
||||
|
||||
return SeatUsageResponse(
|
||||
total_seats=metadata.seats,
|
||||
used_seats=metadata.used_seats,
|
||||
available_seats=max(0, metadata.seats - metadata.used_seats),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/fetch")
|
||||
async def fetch_license(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseResponse:
|
||||
"""
|
||||
Fetch license from control plane.
|
||||
Used after Stripe checkout completion to retrieve the new license.
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
try:
|
||||
token = generate_data_plane_token()
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to generate data plane token: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Authentication configuration error"
|
||||
)
|
||||
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/license/{tenant_id}"
|
||||
response = requests.get(url, headers=headers, timeout=10)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
if not isinstance(data, dict) or "license" not in data:
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Invalid response from control plane"
|
||||
)
|
||||
|
||||
license_data = data["license"]
|
||||
if not license_data:
|
||||
raise HTTPException(status_code=404, detail="No license found")
|
||||
|
||||
# Verify signature before persisting
|
||||
payload = verify_license_signature(license_data)
|
||||
|
||||
# Verify the fetched license is for this tenant
|
||||
if payload.tenant_id != tenant_id:
|
||||
logger.error(
|
||||
f"License tenant mismatch: expected {tenant_id}, got {payload.tenant_id}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="License tenant ID mismatch - control plane returned wrong license",
|
||||
)
|
||||
|
||||
# Persist to DB and update cache atomically
|
||||
upsert_license(db_session, license_data)
|
||||
try:
|
||||
update_license_cache(payload, source=LicenseSource.AUTO_FETCH)
|
||||
except Exception as cache_error:
|
||||
# Log but don't fail - DB is source of truth, cache will refresh on next read
|
||||
logger.warning(f"Failed to update license cache: {cache_error}")
|
||||
|
||||
return LicenseResponse(success=True, license=payload)
|
||||
|
||||
except requests.HTTPError as e:
|
||||
status_code = e.response.status_code if e.response is not None else 502
|
||||
logger.error(f"Control plane returned error: {status_code}")
|
||||
raise HTTPException(
|
||||
status_code=status_code,
|
||||
detail="Failed to fetch license from control plane",
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"License verification failed: {type(e).__name__}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except requests.RequestException:
|
||||
logger.exception("Failed to fetch license from control plane")
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Failed to connect to control plane"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_license(
|
||||
license_file: UploadFile = File(...),
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseUploadResponse:
|
||||
"""
|
||||
Upload a license file manually.
|
||||
Used for air-gapped deployments where control plane is not accessible.
|
||||
"""
|
||||
try:
|
||||
content = await license_file.read()
|
||||
license_data = content.decode("utf-8").strip()
|
||||
except UnicodeDecodeError:
|
||||
raise HTTPException(status_code=400, detail="Invalid license file format")
|
||||
|
||||
try:
|
||||
payload = verify_license_signature(license_data)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
if payload.tenant_id != tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"License tenant ID mismatch. Expected {tenant_id}, got {payload.tenant_id}",
|
||||
)
|
||||
|
||||
# Persist to DB and update cache
|
||||
upsert_license(db_session, license_data)
|
||||
try:
|
||||
update_license_cache(payload, source=LicenseSource.MANUAL_UPLOAD)
|
||||
except Exception as cache_error:
|
||||
# Log but don't fail - DB is source of truth, cache will refresh on next read
|
||||
logger.warning(f"Failed to update license cache: {cache_error}")
|
||||
|
||||
return LicenseUploadResponse(
|
||||
success=True,
|
||||
message=f"License uploaded successfully. {payload.seats} seats, expires {payload.expires_at.date()}",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh_license_cache_endpoint(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> LicenseStatusResponse:
|
||||
"""
|
||||
Force refresh the license cache from the database.
|
||||
Useful after manual database changes or to verify license validity.
|
||||
"""
|
||||
metadata = refresh_license_cache(db_session)
|
||||
|
||||
if not metadata:
|
||||
return LicenseStatusResponse(has_license=False)
|
||||
|
||||
return LicenseStatusResponse(
|
||||
has_license=True,
|
||||
seats=metadata.seats,
|
||||
used_seats=metadata.used_seats,
|
||||
plan_type=metadata.plan_type,
|
||||
issued_at=metadata.issued_at,
|
||||
expires_at=metadata.expires_at,
|
||||
grace_period_end=metadata.grace_period_end,
|
||||
status=metadata.status,
|
||||
source=metadata.source,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("")
|
||||
async def delete_license(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""
|
||||
Delete the current license.
|
||||
Admin only - removes license and invalidates cache.
|
||||
"""
|
||||
# Invalidate cache first - if DB delete fails, stale cache is worse than no cache
|
||||
try:
|
||||
invalidate_license_cache()
|
||||
except Exception as cache_error:
|
||||
logger.warning(f"Failed to invalidate license cache: {cache_error}")
|
||||
|
||||
deleted = db_delete_license(db_session)
|
||||
|
||||
return {"deleted": deleted}
|
||||
92
backend/ee/onyx/server/license/models.py
Normal file
92
backend/ee/onyx/server/license/models.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
|
||||
|
||||
class PlanType(str, Enum):
|
||||
MONTHLY = "monthly"
|
||||
ANNUAL = "annual"
|
||||
|
||||
|
||||
class LicenseSource(str, Enum):
|
||||
AUTO_FETCH = "auto_fetch"
|
||||
MANUAL_UPLOAD = "manual_upload"
|
||||
|
||||
|
||||
class LicensePayload(BaseModel):
|
||||
"""The payload portion of a signed license."""
|
||||
|
||||
version: str
|
||||
tenant_id: str
|
||||
organization_name: str | None = None
|
||||
issued_at: datetime
|
||||
expires_at: datetime
|
||||
seats: int
|
||||
plan_type: PlanType
|
||||
billing_cycle: str | None = None
|
||||
grace_period_days: int = 30
|
||||
stripe_subscription_id: str | None = None
|
||||
stripe_customer_id: str | None = None
|
||||
|
||||
|
||||
class LicenseData(BaseModel):
|
||||
"""Full signed license structure."""
|
||||
|
||||
payload: LicensePayload
|
||||
signature: str
|
||||
|
||||
|
||||
class LicenseMetadata(BaseModel):
|
||||
"""Cached license metadata stored in Redis."""
|
||||
|
||||
tenant_id: str
|
||||
organization_name: str | None = None
|
||||
seats: int
|
||||
used_seats: int
|
||||
plan_type: PlanType
|
||||
issued_at: datetime
|
||||
expires_at: datetime
|
||||
grace_period_end: datetime | None = None
|
||||
status: ApplicationStatus
|
||||
source: LicenseSource | None = None
|
||||
stripe_subscription_id: str | None = None
|
||||
|
||||
|
||||
class LicenseStatusResponse(BaseModel):
|
||||
"""Response for license status API."""
|
||||
|
||||
has_license: bool
|
||||
seats: int = 0
|
||||
used_seats: int = 0
|
||||
plan_type: PlanType | None = None
|
||||
issued_at: datetime | None = None
|
||||
expires_at: datetime | None = None
|
||||
grace_period_end: datetime | None = None
|
||||
status: ApplicationStatus | None = None
|
||||
source: LicenseSource | None = None
|
||||
|
||||
|
||||
class LicenseResponse(BaseModel):
|
||||
"""Response after license fetch/upload."""
|
||||
|
||||
success: bool
|
||||
message: str | None = None
|
||||
license: LicensePayload | None = None
|
||||
|
||||
|
||||
class LicenseUploadResponse(BaseModel):
|
||||
"""Response after license upload."""
|
||||
|
||||
success: bool
|
||||
message: str | None = None
|
||||
|
||||
|
||||
class SeatUsageResponse(BaseModel):
|
||||
"""Response for seat usage API."""
|
||||
|
||||
total_seats: int
|
||||
used_seats: int
|
||||
available_seats: int
|
||||
@@ -20,7 +20,7 @@ from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -100,14 +100,12 @@ def handle_simplified_chat_message(
|
||||
chunks_below=0,
|
||||
full_doc=chat_message_req.full_doc,
|
||||
structured_response_format=chat_message_req.structured_response_format,
|
||||
use_agentic_search=chat_message_req.use_agentic_search,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=full_chat_msg_info,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
enforce_chat_session_id_for_search_docs=False,
|
||||
)
|
||||
|
||||
return gather_stream(packets)
|
||||
@@ -158,7 +156,7 @@ def handle_send_message_simple_with_history(
|
||||
persona_id=req.persona_id,
|
||||
)
|
||||
|
||||
llm, _ = get_llms_for_persona(persona=chat_session.persona, user=user)
|
||||
llm = get_llm_for_persona(persona=chat_session.persona, user=user)
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
@@ -205,14 +203,12 @@ def handle_send_message_simple_with_history(
|
||||
chunks_below=0,
|
||||
full_doc=req.full_doc,
|
||||
structured_response_format=req.structured_response_format,
|
||||
use_agentic_search=req.use_agentic_search,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=full_chat_msg_info,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
enforce_chat_session_id_for_search_docs=False,
|
||||
)
|
||||
|
||||
return gather_stream(packets)
|
||||
|
||||
@@ -54,9 +54,6 @@ class BasicCreateChatMessageRequest(ChunkContext):
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
# If True, uses agentic search instead of basic search
|
||||
use_agentic_search: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
|
||||
if self.chat_session_id is None and self.persona_id is None:
|
||||
@@ -76,8 +73,6 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
|
||||
# only works if using an OpenAI model. See the following for more details:
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
# If True, uses agentic search instead of basic search
|
||||
use_agentic_search: bool = False
|
||||
|
||||
|
||||
class SimpleDoc(BaseModel):
|
||||
|
||||
@@ -48,6 +48,7 @@ from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.query_and_chat.models import ChatSessionDetails
|
||||
from onyx.server.query_and_chat.models import ChatSessionsResponse
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
from onyx.utils.threadpool_concurrency import parallel_yield
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -294,7 +295,7 @@ def list_all_query_history_exports(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/query-history/start-export")
|
||||
@router.post("/admin/query-history/start-export", tags=PUBLIC_API_TAGS)
|
||||
def start_query_history_export(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
@@ -340,7 +341,7 @@ def start_query_history_export(
|
||||
return {"request_id": task_id}
|
||||
|
||||
|
||||
@router.get("/admin/query-history/export-status")
|
||||
@router.get("/admin/query-history/export-status", tags=PUBLIC_API_TAGS)
|
||||
def get_query_history_export_status(
|
||||
request_id: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
@@ -374,7 +375,7 @@ def get_query_history_export_status(
|
||||
return {"status": TaskStatus.SUCCESS}
|
||||
|
||||
|
||||
@router.get("/admin/query-history/download")
|
||||
@router.get("/admin/query-history/download", tags=PUBLIC_API_TAGS)
|
||||
def download_query_history_csv(
|
||||
request_id: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
|
||||
92
backend/ee/onyx/server/tenant_usage_limits.py
Normal file
92
backend/ee/onyx/server/tenant_usage_limits.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Tenant-specific usage limit overrides from the control plane (EE version)."""
|
||||
|
||||
import requests
|
||||
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.server.tenant_usage_limits import TenantUsageLimitOverrides
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# In-memory storage for tenant overrides (populated at startup)
|
||||
_tenant_usage_limit_overrides: dict[str, TenantUsageLimitOverrides] | None = None
|
||||
|
||||
|
||||
def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
|
||||
"""
|
||||
Fetch tenant-specific usage limit overrides from the control plane.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping tenant_id to their specific limit overrides.
|
||||
Returns empty dict on any error (falls back to defaults).
|
||||
"""
|
||||
try:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/usage-limit-overrides"
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
tenant_overrides = response.json()
|
||||
|
||||
# Parse each tenant's overrides
|
||||
result: dict[str, TenantUsageLimitOverrides] = {}
|
||||
for override_data in tenant_overrides:
|
||||
tenant_id = override_data["tenant_id"]
|
||||
try:
|
||||
result[tenant_id] = TenantUsageLimitOverrides(**override_data)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to parse usage limit overrides for tenant {tenant_id}: {e}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(f"Failed to fetch usage limit overrides from control plane: {e}")
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing usage limit overrides: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def load_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
|
||||
"""
|
||||
Load tenant usage limit overrides from the control plane.
|
||||
|
||||
Called at server startup to populate the in-memory cache.
|
||||
"""
|
||||
global _tenant_usage_limit_overrides
|
||||
|
||||
logger.info("Loading tenant usage limit overrides from control plane...")
|
||||
overrides = fetch_usage_limit_overrides()
|
||||
_tenant_usage_limit_overrides = overrides
|
||||
|
||||
if overrides:
|
||||
logger.info(f"Loaded usage limit overrides for {len(overrides)} tenants")
|
||||
else:
|
||||
logger.info("No tenant-specific usage limit overrides found")
|
||||
return overrides
|
||||
|
||||
|
||||
def get_tenant_usage_limit_overrides(
|
||||
tenant_id: str,
|
||||
) -> TenantUsageLimitOverrides | None:
|
||||
"""
|
||||
Get the usage limit overrides for a specific tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID to look up
|
||||
|
||||
Returns:
|
||||
TenantUsageLimitOverrides if the tenant has overrides, None otherwise.
|
||||
"""
|
||||
global _tenant_usage_limit_overrides
|
||||
if _tenant_usage_limit_overrides is None:
|
||||
_tenant_usage_limit_overrides = load_usage_limit_overrides()
|
||||
return _tenant_usage_limit_overrides.get(tenant_id)
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import aiohttp # Async HTTP client
|
||||
@@ -10,10 +9,7 @@ from fastapi import Request
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
|
||||
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
|
||||
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import TenantByDomainResponse
|
||||
from ee.onyx.server.tenants.models import TenantCreationPayload
|
||||
@@ -25,11 +21,18 @@ from ee.onyx.server.tenants.user_mapping import add_users_to_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
|
||||
from onyx.auth.users import exceptions
|
||||
from onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
|
||||
from onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from onyx.configs.app_configs import OPENROUTER_DEFAULT_API_KEY
|
||||
from onyx.configs.app_configs import VERTEXAI_DEFAULT_CREDENTIALS
|
||||
from onyx.configs.app_configs import VERTEXAI_DEFAULT_LOCATION
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
@@ -37,15 +40,25 @@ from onyx.db.models import AvailableTenant
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import UserTenantMapping
|
||||
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.llm_provider_options import get_anthropic_model_names
|
||||
from onyx.llm.llm_provider_options import get_openai_model_names
|
||||
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations
|
||||
from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.constants import VERTEX_CREDENTIALS_FILE_KWARG
|
||||
from onyx.llm.well_known_providers.constants import VERTEX_LOCATION_KWARG
|
||||
from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME
|
||||
from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
get_recommendations,
|
||||
)
|
||||
from onyx.llm.well_known_providers.llm_provider_options import (
|
||||
model_configurations_for_provider,
|
||||
)
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.setup import setup_onyx
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
@@ -53,7 +66,7 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
async def get_or_provision_tenant(
|
||||
@@ -262,61 +275,173 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}")
|
||||
|
||||
|
||||
def configure_default_api_keys(db_session: Session) -> None:
|
||||
if ANTHROPIC_DEFAULT_API_KEY:
|
||||
anthropic_provider = LLMProviderUpsertRequest(
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name="claude-3-7-sonnet-20250219",
|
||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=name,
|
||||
is_visible=False,
|
||||
max_input_tokens=None,
|
||||
)
|
||||
for name in get_anthropic_model_names()
|
||||
],
|
||||
api_key_changed=True,
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(anthropic_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure Anthropic provider: {e}")
|
||||
else:
|
||||
logger.error(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
def _build_model_configuration_upsert_requests(
|
||||
provider_name: str,
|
||||
recommendations: LLMRecommendations,
|
||||
) -> list[ModelConfigurationUpsertRequest]:
|
||||
model_configurations = model_configurations_for_provider(
|
||||
provider_name, recommendations
|
||||
)
|
||||
return [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_configuration.name,
|
||||
is_visible=model_configuration.is_visible,
|
||||
max_input_tokens=model_configuration.max_input_tokens,
|
||||
supports_image_input=model_configuration.supports_image_input,
|
||||
)
|
||||
for model_configuration in model_configurations
|
||||
]
|
||||
|
||||
|
||||
def configure_default_api_keys(db_session: Session) -> None:
|
||||
"""Configure default LLM providers using recommended-models.json for model selection."""
|
||||
# Load recommendations from JSON config
|
||||
recommendations = get_recommendations()
|
||||
|
||||
has_set_default_provider = False
|
||||
|
||||
def _upsert(request: LLMProviderUpsertRequest) -> None:
|
||||
nonlocal has_set_default_provider
|
||||
try:
|
||||
provider = upsert_llm_provider(request, db_session)
|
||||
if not has_set_default_provider:
|
||||
update_default_provider(provider.id, db_session)
|
||||
has_set_default_provider = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure {request.provider} provider: {e}")
|
||||
|
||||
# Configure OpenAI provider
|
||||
if OPENAI_DEFAULT_API_KEY:
|
||||
default_model = recommendations.get_default_model(OPENAI_PROVIDER_NAME)
|
||||
if default_model is None:
|
||||
logger.error(
|
||||
f"No default model found for {OPENAI_PROVIDER_NAME} in recommendations"
|
||||
)
|
||||
default_model_name = default_model.name if default_model else "gpt-5.2"
|
||||
|
||||
openai_provider = LLMProviderUpsertRequest(
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name="gpt-4o",
|
||||
fast_default_model_name="gpt-4o-mini",
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name,
|
||||
is_visible=False,
|
||||
max_input_tokens=None,
|
||||
)
|
||||
for model_name in get_openai_model_names()
|
||||
],
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
OPENAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openai_provider)
|
||||
|
||||
# Create default image generation config using the OpenAI API key
|
||||
try:
|
||||
full_provider = upsert_llm_provider(openai_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
create_default_image_gen_config_from_api_key(
|
||||
db_session, OPENAI_DEFAULT_API_KEY
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure OpenAI provider: {e}")
|
||||
logger.error(f"Failed to create default image gen config: {e}")
|
||||
else:
|
||||
logger.error(
|
||||
logger.info(
|
||||
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
|
||||
)
|
||||
|
||||
# Configure Anthropic provider
|
||||
if ANTHROPIC_DEFAULT_API_KEY:
|
||||
default_model = recommendations.get_default_model(ANTHROPIC_PROVIDER_NAME)
|
||||
if default_model is None:
|
||||
logger.error(
|
||||
f"No default model found for {ANTHROPIC_PROVIDER_NAME} in recommendations"
|
||||
)
|
||||
default_model_name = (
|
||||
default_model.name if default_model else "claude-sonnet-4-5"
|
||||
)
|
||||
|
||||
anthropic_provider = LLMProviderUpsertRequest(
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
ANTHROPIC_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(anthropic_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
)
|
||||
|
||||
# Configure Vertex AI provider
|
||||
if VERTEXAI_DEFAULT_CREDENTIALS:
|
||||
default_model = recommendations.get_default_model(VERTEXAI_PROVIDER_NAME)
|
||||
if default_model is None:
|
||||
logger.error(
|
||||
f"No default model found for {VERTEXAI_PROVIDER_NAME} in recommendations"
|
||||
)
|
||||
default_model_name = default_model.name if default_model else "gemini-2.5-pro"
|
||||
|
||||
# Vertex AI uses custom_config for credentials and location
|
||||
custom_config = {
|
||||
VERTEX_CREDENTIALS_FILE_KWARG: VERTEXAI_DEFAULT_CREDENTIALS,
|
||||
VERTEX_LOCATION_KWARG: VERTEXAI_DEFAULT_LOCATION,
|
||||
}
|
||||
|
||||
vertexai_provider = LLMProviderUpsertRequest(
|
||||
name="Google Vertex AI",
|
||||
provider=VERTEXAI_PROVIDER_NAME,
|
||||
custom_config=custom_config,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=_build_model_configuration_upsert_requests(
|
||||
VERTEXAI_PROVIDER_NAME, recommendations
|
||||
),
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(vertexai_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
|
||||
)
|
||||
|
||||
# Configure OpenRouter provider
|
||||
if OPENROUTER_DEFAULT_API_KEY:
|
||||
default_model = recommendations.get_default_model(OPENROUTER_PROVIDER_NAME)
|
||||
if default_model is None:
|
||||
logger.error(
|
||||
f"No default model found for {OPENROUTER_PROVIDER_NAME} in recommendations"
|
||||
)
|
||||
default_model_name = default_model.name if default_model else "z-ai/glm-4.7"
|
||||
|
||||
# For OpenRouter, we use the visible models from recommendations as model_configurations
|
||||
# since OpenRouter models are dynamic (fetched from their API)
|
||||
visible_models = recommendations.get_visible_models(OPENROUTER_PROVIDER_NAME)
|
||||
model_configurations = [
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model.name,
|
||||
is_visible=True,
|
||||
max_input_tokens=None,
|
||||
display_name=model.display_name,
|
||||
)
|
||||
for model in visible_models
|
||||
]
|
||||
|
||||
openrouter_provider = LLMProviderUpsertRequest(
|
||||
name="OpenRouter",
|
||||
provider=OPENROUTER_PROVIDER_NAME,
|
||||
api_key=OPENROUTER_DEFAULT_API_KEY,
|
||||
default_model_name=default_model_name,
|
||||
model_configurations=model_configurations,
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openrouter_provider)
|
||||
else:
|
||||
logger.info(
|
||||
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"
|
||||
)
|
||||
|
||||
# Configure Cohere embedding provider
|
||||
if COHERE_DEFAULT_API_KEY:
|
||||
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
|
||||
provider_type=EmbeddingProvider.COHERE,
|
||||
@@ -562,17 +687,11 @@ async def assign_tenant_to_user(
|
||||
try:
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
# Create milestone record in the same transaction context as the tenant assignment
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=email,
|
||||
event=MilestoneRecordType.TENANT_CREATED,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
|
||||
raise Exception("Failed to assign tenant to user")
|
||||
|
||||
@@ -249,6 +249,17 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
|
||||
)
|
||||
raise
|
||||
|
||||
# Remove from invited users list since they've accepted
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
invited_users = get_invited_users()
|
||||
if email in invited_users:
|
||||
invited_users.remove(email)
|
||||
write_invited_users(invited_users)
|
||||
logger.info(f"Removed {email} from invited users list after acceptance")
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def deny_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
|
||||
@@ -16,8 +16,9 @@ from onyx.db.token_limit import insert_user_token_rate_limit
|
||||
from onyx.server.query_and_chat.token_limit import any_rate_limit_exists
|
||||
from onyx.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
from onyx.server.token_rate_limits.models import TokenRateLimitDisplay
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
|
||||
router = APIRouter(prefix="/admin/token-rate-limits")
|
||||
router = APIRouter(prefix="/admin/token-rate-limits", tags=PUBLIC_API_TAGS)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
47
backend/ee/onyx/server/usage_limits.py
Normal file
47
backend/ee/onyx/server/usage_limits.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""EE Usage limits - trial detection via billing information."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_tenant_on_trial(tenant_id: str) -> bool:
|
||||
"""
|
||||
Determine if a tenant is currently on a trial subscription.
|
||||
|
||||
In multi-tenant mode, we fetch billing information from the control plane
|
||||
to determine if the tenant has an active trial.
|
||||
"""
|
||||
if not MULTI_TENANT:
|
||||
return False
|
||||
|
||||
try:
|
||||
billing_info = fetch_billing_information(tenant_id)
|
||||
|
||||
# If not subscribed at all, check if we have trial information
|
||||
if isinstance(billing_info, SubscriptionStatusResponse):
|
||||
# No subscription means they're likely on trial (new tenant)
|
||||
return True
|
||||
|
||||
if isinstance(billing_info, BillingInformation):
|
||||
# Check if trial is active
|
||||
if billing_info.trial_end is not None:
|
||||
now = datetime.now(timezone.utc)
|
||||
# Trial active if trial_end is in the future
|
||||
# and subscription status indicates trialing
|
||||
if billing_info.trial_end > now and billing_info.status == "trialing":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch billing info for trial check: {e}")
|
||||
# Default to trial limits on error (more restrictive = safer)
|
||||
return True
|
||||
@@ -21,11 +21,12 @@ from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
|
||||
|
||||
|
||||
@router.get("/admin/user-group")
|
||||
|
||||
126
backend/ee/onyx/utils/license.py
Normal file
126
backend/ee/onyx/utils/license.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""RSA-4096 license signature verification utilities."""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
|
||||
|
||||
from ee.onyx.server.license.models import LicenseData
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# RSA-4096 Public Key for license verification
|
||||
# Load from environment variable - key is generated on the control plane
|
||||
# In production, inject via Kubernetes secrets or secrets manager
|
||||
LICENSE_PUBLIC_KEY_PEM = os.environ.get("LICENSE_PUBLIC_KEY_PEM", "")
|
||||
|
||||
|
||||
def _get_public_key() -> RSAPublicKey:
|
||||
"""Load the public key from environment variable."""
|
||||
if not LICENSE_PUBLIC_KEY_PEM:
|
||||
raise ValueError(
|
||||
"LICENSE_PUBLIC_KEY_PEM environment variable not set. "
|
||||
"License verification requires the control plane public key."
|
||||
)
|
||||
key = serialization.load_pem_public_key(LICENSE_PUBLIC_KEY_PEM.encode())
|
||||
if not isinstance(key, RSAPublicKey):
|
||||
raise ValueError("Expected RSA public key")
|
||||
return key
|
||||
|
||||
|
||||
def verify_license_signature(license_data: str) -> LicensePayload:
|
||||
"""
|
||||
Verify RSA-4096 signature and return payload if valid.
|
||||
|
||||
Args:
|
||||
license_data: Base64-encoded JSON containing payload and signature
|
||||
|
||||
Returns:
|
||||
LicensePayload if signature is valid
|
||||
|
||||
Raises:
|
||||
ValueError: If license data is invalid or signature verification fails
|
||||
"""
|
||||
try:
|
||||
# Decode the license data
|
||||
decoded = json.loads(base64.b64decode(license_data))
|
||||
license_obj = LicenseData(**decoded)
|
||||
|
||||
payload_json = json.dumps(
|
||||
license_obj.payload.model_dump(mode="json"), sort_keys=True
|
||||
)
|
||||
signature_bytes = base64.b64decode(license_obj.signature)
|
||||
|
||||
# Verify signature using PSS padding (modern standard)
|
||||
public_key = _get_public_key()
|
||||
public_key.verify(
|
||||
signature_bytes,
|
||||
payload_json.encode(),
|
||||
padding.PSS(
|
||||
mgf=padding.MGF1(hashes.SHA256()),
|
||||
salt_length=padding.PSS.MAX_LENGTH,
|
||||
),
|
||||
hashes.SHA256(),
|
||||
)
|
||||
|
||||
return license_obj.payload
|
||||
|
||||
except InvalidSignature:
|
||||
logger.error("License signature verification failed")
|
||||
raise ValueError("Invalid license signature")
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Failed to decode license JSON")
|
||||
raise ValueError("Invalid license format: not valid JSON")
|
||||
except (ValueError, KeyError, TypeError) as e:
|
||||
logger.error(f"License data validation error: {type(e).__name__}")
|
||||
raise ValueError(f"Invalid license format: {type(e).__name__}")
|
||||
except Exception:
|
||||
logger.exception("Unexpected error during license verification")
|
||||
raise ValueError("License verification failed: unexpected error")
|
||||
|
||||
|
||||
def get_license_status(
|
||||
payload: LicensePayload,
|
||||
grace_period_end: datetime | None = None,
|
||||
) -> ApplicationStatus:
|
||||
"""
|
||||
Determine current license status based on expiry.
|
||||
|
||||
Args:
|
||||
payload: The verified license payload
|
||||
grace_period_end: Optional grace period end datetime
|
||||
|
||||
Returns:
|
||||
ApplicationStatus indicating current license state
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Check if grace period has expired
|
||||
if grace_period_end and now > grace_period_end:
|
||||
return ApplicationStatus.GATED_ACCESS
|
||||
|
||||
# Check if license has expired
|
||||
if now > payload.expires_at:
|
||||
if grace_period_end and now <= grace_period_end:
|
||||
return ApplicationStatus.GRACE_PERIOD
|
||||
return ApplicationStatus.GATED_ACCESS
|
||||
|
||||
# License is valid
|
||||
return ApplicationStatus.ACTIVE
|
||||
|
||||
|
||||
def is_license_valid(payload: LicensePayload) -> bool:
|
||||
"""Check if a license is currently valid (not expired)."""
|
||||
now = datetime.now(timezone.utc)
|
||||
return now <= payload.expires_at
|
||||
@@ -1,5 +1,4 @@
|
||||
MODEL_WARM_UP_STRING = "hi " * 512
|
||||
INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
|
||||
|
||||
|
||||
class GPUStatus:
|
||||
|
||||
@@ -1,562 +0,0 @@
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fastapi import APIRouter
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
|
||||
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
|
||||
from model_server.constants import MODEL_WARM_UP_STRING
|
||||
from model_server.onyx_torch_model import ConnectorClassifier
|
||||
from model_server.onyx_torch_model import HybridClassifier
|
||||
from model_server.utils import simple_log_function_time
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
|
||||
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
|
||||
from shared_configs.configs import (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
|
||||
)
|
||||
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
|
||||
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
from shared_configs.configs import (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE,
|
||||
)
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.configs import INFORMATION_CONTENT_MODEL_TAG
|
||||
from shared_configs.configs import INFORMATION_CONTENT_MODEL_VERSION
|
||||
from shared_configs.configs import INTENT_MODEL_TAG
|
||||
from shared_configs.configs import INTENT_MODEL_VERSION
|
||||
from shared_configs.model_server_models import ConnectorClassificationRequest
|
||||
from shared_configs.model_server_models import ConnectorClassificationResponse
|
||||
from shared_configs.model_server_models import ContentClassificationPrediction
|
||||
from shared_configs.model_server_models import IntentRequest
|
||||
from shared_configs.model_server_models import IntentResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from setfit import SetFitModel # type: ignore
|
||||
from transformers import PreTrainedTokenizer, BatchEncoding # type: ignore
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/custom")
|
||||
|
||||
_CONNECTOR_CLASSIFIER_TOKENIZER: Optional["PreTrainedTokenizer"] = None
|
||||
_CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
|
||||
|
||||
_INTENT_TOKENIZER: Optional["PreTrainedTokenizer"] = None
|
||||
_INTENT_MODEL: HybridClassifier | None = None
|
||||
|
||||
_INFORMATION_CONTENT_MODEL: Optional["SetFitModel"] = None
|
||||
|
||||
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
|
||||
|
||||
|
||||
def get_connector_classifier_tokenizer() -> "PreTrainedTokenizer":
|
||||
global _CONNECTOR_CLASSIFIER_TOKENIZER
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
|
||||
# The tokenizer details are not uploaded to the HF hub since it's just the
|
||||
# unmodified distilbert tokenizer.
|
||||
_CONNECTOR_CLASSIFIER_TOKENIZER = cast(
|
||||
PreTrainedTokenizer,
|
||||
AutoTokenizer.from_pretrained("distilbert-base-uncased"),
|
||||
)
|
||||
return _CONNECTOR_CLASSIFIER_TOKENIZER
|
||||
|
||||
|
||||
def get_local_connector_classifier(
|
||||
model_name_or_path: str = CONNECTOR_CLASSIFIER_MODEL_REPO,
|
||||
tag: str = CONNECTOR_CLASSIFIER_MODEL_TAG,
|
||||
) -> ConnectorClassifier:
|
||||
global _CONNECTOR_CLASSIFIER_MODEL
|
||||
if _CONNECTOR_CLASSIFIER_MODEL is None:
|
||||
try:
|
||||
# Calculate where the cache should be, then load from local if available
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
)
|
||||
_CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
|
||||
local_path
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load model directly: {e}")
|
||||
try:
|
||||
# Attempt to download the model snapshot
|
||||
logger.info(f"Downloading model snapshot for {model_name_or_path}")
|
||||
local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
|
||||
_CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
|
||||
local_path
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to load model even after attempted snapshot download: {e}"
|
||||
)
|
||||
raise
|
||||
return _CONNECTOR_CLASSIFIER_MODEL
|
||||
|
||||
|
||||
def get_intent_model_tokenizer() -> "PreTrainedTokenizer":
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
global _INTENT_TOKENIZER
|
||||
if _INTENT_TOKENIZER is None:
|
||||
# The tokenizer details are not uploaded to the HF hub since it's just the
|
||||
# unmodified distilbert tokenizer.
|
||||
_INTENT_TOKENIZER = cast(
|
||||
PreTrainedTokenizer,
|
||||
AutoTokenizer.from_pretrained("distilbert-base-uncased"),
|
||||
)
|
||||
return _INTENT_TOKENIZER
|
||||
|
||||
|
||||
def get_local_intent_model(
|
||||
model_name_or_path: str = INTENT_MODEL_VERSION,
|
||||
tag: str | None = INTENT_MODEL_TAG,
|
||||
) -> HybridClassifier:
|
||||
global _INTENT_MODEL
|
||||
if _INTENT_MODEL is None:
|
||||
try:
|
||||
# Calculate where the cache should be, then load from local if available
|
||||
logger.notice(f"Loading model from local cache: {model_name_or_path}")
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
)
|
||||
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
|
||||
logger.notice(f"Loaded model from local cache: {local_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load model directly: {e}")
|
||||
try:
|
||||
# Attempt to download the model snapshot
|
||||
logger.notice(f"Downloading model snapshot for {model_name_or_path}")
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=False
|
||||
)
|
||||
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to load model even after attempted snapshot download: {e}"
|
||||
)
|
||||
raise
|
||||
return _INTENT_MODEL
|
||||
|
||||
|
||||
def get_local_information_content_model(
|
||||
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
|
||||
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
|
||||
) -> "SetFitModel":
|
||||
from setfit import SetFitModel
|
||||
|
||||
global _INFORMATION_CONTENT_MODEL
|
||||
if _INFORMATION_CONTENT_MODEL is None:
|
||||
try:
|
||||
# Calculate where the cache should be, then load from local if available
|
||||
logger.notice(
|
||||
f"Loading content information model from local cache: {model_name_or_path}"
|
||||
)
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
)
|
||||
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
|
||||
logger.notice(
|
||||
f"Loaded content information model from local cache: {local_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load content information model directly: {e}")
|
||||
try:
|
||||
# Attempt to download the model snapshot
|
||||
logger.notice(
|
||||
f"Downloading content information model snapshot for {model_name_or_path}"
|
||||
)
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=False
|
||||
)
|
||||
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to load content information model even after attempted snapshot download: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
return _INFORMATION_CONTENT_MODEL
|
||||
|
||||
|
||||
def tokenize_connector_classification_query(
|
||||
connectors: list[str],
|
||||
query: str,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
connector_token_end_id: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Tokenize the connectors & user query into one prompt for the forward pass of ConnectorClassifier models
|
||||
|
||||
The attention mask is just all 1s. The prompt is CLS + each connector name suffixed with the connector end
|
||||
token and then the user query.
|
||||
"""
|
||||
|
||||
input_ids = torch.tensor([tokenizer.cls_token_id], dtype=torch.long)
|
||||
|
||||
for connector in connectors:
|
||||
connector_token_ids = tokenizer(
|
||||
connector,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = torch.cat(
|
||||
(
|
||||
input_ids,
|
||||
connector_token_ids["input_ids"].squeeze(dim=0),
|
||||
torch.tensor([connector_token_end_id], dtype=torch.long),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
query_token_ids = tokenizer(
|
||||
query,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
input_ids = torch.cat(
|
||||
(
|
||||
input_ids,
|
||||
query_token_ids["input_ids"].squeeze(dim=0),
|
||||
torch.tensor([tokenizer.sep_token_id], dtype=torch.long),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
attention_mask = torch.ones(input_ids.numel(), dtype=torch.long)
|
||||
|
||||
return input_ids.unsqueeze(0), attention_mask.unsqueeze(0)
|
||||
|
||||
|
||||
def warm_up_connector_classifier_model() -> None:
|
||||
logger.info(
|
||||
f"Warming up connector_classifier model {CONNECTOR_CLASSIFIER_MODEL_TAG}"
|
||||
)
|
||||
connector_classifier_tokenizer = get_connector_classifier_tokenizer()
|
||||
connector_classifier = get_local_connector_classifier()
|
||||
|
||||
input_ids, attention_mask = tokenize_connector_classification_query(
|
||||
["GitHub"],
|
||||
"onyx classifier query google doc",
|
||||
connector_classifier_tokenizer,
|
||||
connector_classifier.connector_end_token_id,
|
||||
)
|
||||
input_ids = input_ids.to(connector_classifier.device)
|
||||
attention_mask = attention_mask.to(connector_classifier.device)
|
||||
|
||||
connector_classifier(input_ids, attention_mask)
|
||||
|
||||
|
||||
def warm_up_intent_model() -> None:
|
||||
logger.notice(f"Warming up Intent Model: {INTENT_MODEL_VERSION}")
|
||||
intent_tokenizer = get_intent_model_tokenizer()
|
||||
tokens = intent_tokenizer(
|
||||
MODEL_WARM_UP_STRING, return_tensors="pt", truncation=True, padding=True
|
||||
)
|
||||
|
||||
intent_model = get_local_intent_model()
|
||||
device = intent_model.device
|
||||
intent_model(
|
||||
query_ids=tokens["input_ids"].to(device),
|
||||
query_mask=tokens["attention_mask"].to(device),
|
||||
)
|
||||
|
||||
|
||||
def warm_up_information_content_model() -> None:
|
||||
logger.notice("Warming up Content Model") # TODO: add version if needed
|
||||
|
||||
information_content_model = get_local_information_content_model()
|
||||
information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def run_inference(tokens: "BatchEncoding") -> tuple[list[float], list[float]]:
|
||||
intent_model = get_local_intent_model()
|
||||
device = intent_model.device
|
||||
|
||||
outputs = intent_model(
|
||||
query_ids=tokens["input_ids"].to(device),
|
||||
query_mask=tokens["attention_mask"].to(device),
|
||||
)
|
||||
|
||||
token_logits = outputs["token_logits"]
|
||||
intent_logits = outputs["intent_logits"]
|
||||
|
||||
# Move tensors to CPU before applying softmax and converting to numpy
|
||||
intent_probabilities = F.softmax(intent_logits.cpu(), dim=-1).numpy()[0]
|
||||
token_probabilities = F.softmax(token_logits.cpu(), dim=-1).numpy()[0]
|
||||
|
||||
# Extract the probabilities for the positive class (index 1) for each token
|
||||
token_positive_probs = token_probabilities[:, 1].tolist()
|
||||
|
||||
return intent_probabilities.tolist(), token_positive_probs
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def run_content_classification_inference(
|
||||
text_inputs: list[str],
|
||||
) -> list[ContentClassificationPrediction]:
|
||||
"""
|
||||
Assign a score to the segments in question. The model stored in get_local_information_content_model()
|
||||
creates the 'model score' based on its training, and the scores are then converted to a 0.0-1.0 scale.
|
||||
In the code outside of the model/inference model servers that score will be converted into the actual
|
||||
boost factor.
|
||||
"""
|
||||
|
||||
def _prob_to_score(prob: float) -> float:
|
||||
"""
|
||||
Conversion of base score to 0.0 - 1.0 score. Note that the min/max values depend on the model!
|
||||
"""
|
||||
_MIN_BASE_SCORE = 0.25
|
||||
_MAX_BASE_SCORE = 0.75
|
||||
if prob < _MIN_BASE_SCORE:
|
||||
raw_score = 0.0
|
||||
elif prob < _MAX_BASE_SCORE:
|
||||
raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
|
||||
else:
|
||||
raw_score = 1.0
|
||||
return (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
+ (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
|
||||
- INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
)
|
||||
* raw_score
|
||||
)
|
||||
|
||||
_BATCH_SIZE = 32
|
||||
content_model = get_local_information_content_model()
|
||||
|
||||
# Process inputs in batches
|
||||
all_output_classes: list[int] = []
|
||||
all_base_output_probabilities: list[float] = []
|
||||
|
||||
for i in range(0, len(text_inputs), _BATCH_SIZE):
|
||||
batch = text_inputs[i : i + _BATCH_SIZE]
|
||||
batch_with_prefix = []
|
||||
batch_indices = []
|
||||
|
||||
# Pre-allocate results for this batch
|
||||
batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
|
||||
batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)
|
||||
|
||||
# Pre-process batch to handle long input exceptions
|
||||
for j, text in enumerate(batch):
|
||||
if len(text) == 0:
|
||||
# if no input, treat as non-informative from the model's perspective
|
||||
batch_output_classes[j] = np.array(0)
|
||||
batch_probabilities[j] = np.array(0.0)
|
||||
logger.warning("Input for Content Information Model is empty")
|
||||
|
||||
elif (
|
||||
len(text.split())
|
||||
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
|
||||
):
|
||||
# if input is short, use the model
|
||||
batch_with_prefix.append(
|
||||
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
|
||||
)
|
||||
batch_indices.append(j)
|
||||
else:
|
||||
# if longer than cutoff, treat as informative (stay with default), but issue warning
|
||||
logger.warning("Input for Content Information Model too long")
|
||||
|
||||
if batch_with_prefix: # Only run model if we have valid inputs
|
||||
# Get predictions for the batch
|
||||
model_output_classes = content_model(batch_with_prefix)
|
||||
model_output_probabilities = content_model.predict_proba(batch_with_prefix)
|
||||
|
||||
# Place results in the correct positions
|
||||
for idx, batch_idx in enumerate(batch_indices):
|
||||
batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
|
||||
batch_probabilities[batch_idx] = model_output_probabilities[idx][
|
||||
1
|
||||
].numpy() # x[1] is prob of the positive class
|
||||
|
||||
all_output_classes.extend([int(x) for x in batch_output_classes])
|
||||
all_base_output_probabilities.extend([float(x) for x in batch_probabilities])
|
||||
|
||||
logits = [
|
||||
np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
|
||||
for p in all_base_output_probabilities
|
||||
]
|
||||
scaled_logits = [
|
||||
logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
|
||||
for logit in logits
|
||||
]
|
||||
output_probabilities_with_temp = [
|
||||
np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
|
||||
for scaled_logit in scaled_logits
|
||||
]
|
||||
|
||||
prediction_scores = [
|
||||
_prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
|
||||
]
|
||||
|
||||
content_classification_predictions = [
|
||||
ContentClassificationPrediction(
|
||||
predicted_label=predicted_label, content_boost_factor=output_score
|
||||
)
|
||||
for predicted_label, output_score in zip(all_output_classes, prediction_scores)
|
||||
]
|
||||
|
||||
return content_classification_predictions
|
||||
|
||||
|
||||
def map_keywords(
|
||||
input_ids: torch.Tensor, tokenizer: "PreTrainedTokenizer", is_keyword: list[bool]
|
||||
) -> list[str]:
|
||||
tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore
|
||||
|
||||
if not len(tokens) == len(is_keyword):
|
||||
raise ValueError("Length of tokens and keyword predictions must match")
|
||||
|
||||
if input_ids[0] == tokenizer.cls_token_id:
|
||||
tokens = tokens[1:]
|
||||
is_keyword = is_keyword[1:]
|
||||
|
||||
if input_ids[-1] == tokenizer.sep_token_id:
|
||||
tokens = tokens[:-1]
|
||||
is_keyword = is_keyword[:-1]
|
||||
|
||||
unk_token = tokenizer.unk_token
|
||||
if unk_token in tokens:
|
||||
raise ValueError("Unknown token detected in the input")
|
||||
|
||||
keywords = []
|
||||
current_keyword = ""
|
||||
|
||||
for ind, token in enumerate(tokens):
|
||||
if is_keyword[ind]:
|
||||
if token.startswith("##"):
|
||||
current_keyword += token[2:]
|
||||
else:
|
||||
if current_keyword:
|
||||
keywords.append(current_keyword)
|
||||
current_keyword = token
|
||||
else:
|
||||
# If mispredicted a later token of a keyword, add it to the current keyword
|
||||
# to complete it
|
||||
if current_keyword:
|
||||
if len(current_keyword) > 2 and current_keyword.startswith("##"):
|
||||
current_keyword = current_keyword[2:]
|
||||
|
||||
else:
|
||||
keywords.append(current_keyword)
|
||||
current_keyword = ""
|
||||
|
||||
if current_keyword:
|
||||
keywords.append(current_keyword)
|
||||
|
||||
return keywords
|
||||
|
||||
|
||||
def clean_keywords(keywords: list[str]) -> list[str]:
|
||||
cleaned_words = []
|
||||
for word in keywords:
|
||||
word = word[:-2] if word.endswith("'s") else word
|
||||
word = word.replace("/", " ")
|
||||
word = word.replace("'", "").replace('"', "")
|
||||
cleaned_words.extend([w for w in word.strip().split() if w and not w.isspace()])
|
||||
return cleaned_words
|
||||
|
||||
|
||||
def run_connector_classification(req: ConnectorClassificationRequest) -> list[str]:
|
||||
tokenizer = get_connector_classifier_tokenizer()
|
||||
model = get_local_connector_classifier()
|
||||
|
||||
connector_names = req.available_connectors
|
||||
|
||||
input_ids, attention_mask = tokenize_connector_classification_query(
|
||||
connector_names,
|
||||
req.query,
|
||||
tokenizer,
|
||||
model.connector_end_token_id,
|
||||
)
|
||||
input_ids = input_ids.to(model.device)
|
||||
attention_mask = attention_mask.to(model.device)
|
||||
|
||||
global_confidence, classifier_confidence = model(input_ids, attention_mask)
|
||||
|
||||
if global_confidence.item() < 0.5:
|
||||
return []
|
||||
|
||||
passed_connectors = []
|
||||
|
||||
for i, connector_name in enumerate(connector_names):
|
||||
if classifier_confidence.view(-1)[i].item() > 0.5:
|
||||
passed_connectors.append(connector_name)
|
||||
|
||||
return passed_connectors
|
||||
|
||||
|
||||
def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
|
||||
tokenizer = get_intent_model_tokenizer()
|
||||
model_input = tokenizer(
|
||||
intent_req.query, return_tensors="pt", truncation=False, padding=False
|
||||
)
|
||||
|
||||
if len(model_input.input_ids[0]) > 512:
|
||||
# If the user text is too long, assume it is semantic and keep all words
|
||||
return True, intent_req.query.split()
|
||||
|
||||
intent_probs, token_probs = run_inference(model_input)
|
||||
|
||||
is_keyword_sequence = intent_probs[0] >= intent_req.keyword_percent_threshold
|
||||
|
||||
keyword_preds = [
|
||||
token_prob >= intent_req.keyword_percent_threshold for token_prob in token_probs
|
||||
]
|
||||
|
||||
try:
|
||||
keywords = map_keywords(model_input.input_ids[0], tokenizer, keyword_preds)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to extract keywords for query: {intent_req.query} due to {e}"
|
||||
)
|
||||
# Fallback to keeping all words
|
||||
keywords = intent_req.query.split()
|
||||
|
||||
cleaned_keywords = clean_keywords(keywords)
|
||||
|
||||
return is_keyword_sequence, cleaned_keywords
|
||||
|
||||
|
||||
@router.post("/connector-classification")
|
||||
async def process_connector_classification_request(
|
||||
classification_request: ConnectorClassificationRequest,
|
||||
) -> ConnectorClassificationResponse:
|
||||
if INDEXING_ONLY:
|
||||
raise RuntimeError(
|
||||
"Indexing model server should not call connector classification endpoint"
|
||||
)
|
||||
|
||||
if len(classification_request.available_connectors) == 0:
|
||||
return ConnectorClassificationResponse(connectors=[])
|
||||
|
||||
connectors = run_connector_classification(classification_request)
|
||||
return ConnectorClassificationResponse(connectors=connectors)
|
||||
|
||||
|
||||
@router.post("/query-analysis")
|
||||
async def process_analysis_request(
|
||||
intent_request: IntentRequest,
|
||||
) -> IntentResponse:
|
||||
if INDEXING_ONLY:
|
||||
raise RuntimeError("Indexing model server should not call intent endpoint")
|
||||
|
||||
is_keyword, keywords = run_analysis(intent_request)
|
||||
return IntentResponse(is_keyword=is_keyword, keywords=keywords)
|
||||
|
||||
|
||||
@router.post("/content-classification")
|
||||
async def process_content_classification_request(
|
||||
content_classification_requests: list[str],
|
||||
) -> list[ContentClassificationPrediction]:
|
||||
return run_content_classification_inference(content_classification_requests)
|
||||
@@ -1,7 +1,6 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -10,16 +9,13 @@ from fastapi import Request
|
||||
|
||||
from model_server.utils import simple_log_function_time
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.model_server_models import Embedding
|
||||
from shared_configs.model_server_models import EmbedRequest
|
||||
from shared_configs.model_server_models import EmbedResponse
|
||||
from shared_configs.model_server_models import RerankRequest
|
||||
from shared_configs.model_server_models import RerankResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentence_transformers import CrossEncoder, SentenceTransformer
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -27,11 +23,6 @@ router = APIRouter(prefix="/encoder")
|
||||
|
||||
|
||||
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
|
||||
_RERANK_MODEL: Optional["CrossEncoder"] = None
|
||||
|
||||
# If we are not only indexing, dont want retry very long
|
||||
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
|
||||
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
|
||||
|
||||
|
||||
def get_embedding_model(
|
||||
@@ -42,7 +33,7 @@ def get_embedding_model(
|
||||
Loads or returns a cached SentenceTransformer, sets max_seq_length, pins device,
|
||||
pre-warms rotary caches once, and wraps encode() with a lock to avoid cache races.
|
||||
"""
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
def _prewarm_rope(st_model: "SentenceTransformer", target_len: int) -> None:
|
||||
"""
|
||||
@@ -87,19 +78,6 @@ def get_embedding_model(
|
||||
return _GLOBAL_MODELS_DICT[model_name]
|
||||
|
||||
|
||||
def get_local_reranking_model(
|
||||
model_name: str,
|
||||
) -> "CrossEncoder":
|
||||
global _RERANK_MODEL
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
|
||||
if _RERANK_MODEL is None:
|
||||
logger.notice(f"Loading {model_name}")
|
||||
model = CrossEncoder(model_name)
|
||||
_RERANK_MODEL = model
|
||||
return _RERANK_MODEL
|
||||
|
||||
|
||||
ENCODING_RETRIES = 3
|
||||
ENCODING_RETRY_DELAY = 0.1
|
||||
|
||||
@@ -189,16 +167,6 @@ async def embed_text(
|
||||
return embeddings
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
|
||||
cross_encoder = get_local_reranking_model(model_name)
|
||||
# Run CPU-bound reranking in a thread pool
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(), # type: ignore
|
||||
)
|
||||
|
||||
|
||||
@router.post("/bi-encoder-embed")
|
||||
async def route_bi_encoder_embed(
|
||||
request: Request,
|
||||
@@ -254,39 +222,3 @@ async def process_embed_request(
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error during embedding process: {e}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/cross-encoder-scores")
|
||||
async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
|
||||
"""Cross encoders can be purely black box from the app perspective"""
|
||||
# Only local models should use this endpoint - API providers should make direct API calls
|
||||
if rerank_request.provider_type is not None:
|
||||
raise ValueError(
|
||||
f"Model server reranking endpoint should only be used for local models. "
|
||||
f"API provider '{rerank_request.provider_type}' should make direct API calls instead."
|
||||
)
|
||||
|
||||
if INDEXING_ONLY:
|
||||
raise RuntimeError("Indexing model server should not call intent endpoint")
|
||||
|
||||
if not rerank_request.documents or not rerank_request.query:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing documents or query for reranking"
|
||||
)
|
||||
if not all(rerank_request.documents):
|
||||
raise ValueError("Empty documents cannot be reranked.")
|
||||
|
||||
try:
|
||||
# At this point, provider_type is None, so handle local reranking
|
||||
sim_scores = await local_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during reranking process:\n{str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to run Cross-Encoder reranking"
|
||||
)
|
||||
|
||||
5
backend/model_server/legacy/README.md
Normal file
5
backend/model_server/legacy/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
This directory contains code that was useful and may become useful again in the future.
|
||||
|
||||
We stopped using rerankers because the state of the art rerankers are not significantly better than the biencoders and much worse than LLMs which are also capable of acting on a small set of documents for filtering, reranking, etc.
|
||||
|
||||
We stopped using the internal query classifier as that's now offloaded to the LLM which does query expansion so we know ahead of time if it's a keyword or semantic query.
|
||||
0
backend/model_server/legacy/__init__.py
Normal file
0
backend/model_server/legacy/__init__.py
Normal file
573
backend/model_server/legacy/custom_models.py
Normal file
573
backend/model_server/legacy/custom_models.py
Normal file
@@ -0,0 +1,573 @@
|
||||
# from typing import cast
|
||||
# from typing import Optional
|
||||
# from typing import TYPE_CHECKING
|
||||
|
||||
# import numpy as np
|
||||
# import torch
|
||||
# import torch.nn.functional as F
|
||||
# from fastapi import APIRouter
|
||||
# from huggingface_hub import snapshot_download
|
||||
# from pydantic import BaseModel
|
||||
|
||||
# from model_server.constants import MODEL_WARM_UP_STRING
|
||||
# from model_server.legacy.onyx_torch_model import ConnectorClassifier
|
||||
# from model_server.legacy.onyx_torch_model import HybridClassifier
|
||||
# from model_server.utils import simple_log_function_time
|
||||
# from onyx.utils.logger import setup_logger
|
||||
# from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
|
||||
# from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
|
||||
# from shared_configs.configs import INDEXING_ONLY
|
||||
# from shared_configs.configs import INTENT_MODEL_TAG
|
||||
# from shared_configs.configs import INTENT_MODEL_VERSION
|
||||
# from shared_configs.model_server_models import IntentRequest
|
||||
# from shared_configs.model_server_models import IntentResponse
|
||||
|
||||
# if TYPE_CHECKING:
|
||||
# from setfit import SetFitModel # type: ignore[import-untyped]
|
||||
# from transformers import PreTrainedTokenizer, BatchEncoding
|
||||
|
||||
|
||||
# INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi" * 50
|
||||
|
||||
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX = 1.0
|
||||
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN = 0.7
|
||||
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE = 4.0
|
||||
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = 10
|
||||
# INFORMATION_CONTENT_MODEL_VERSION = "onyx-dot-app/information-content-model"
|
||||
# INFORMATION_CONTENT_MODEL_TAG: str | None = None
|
||||
|
||||
|
||||
# class ConnectorClassificationRequest(BaseModel):
|
||||
# available_connectors: list[str]
|
||||
# query: str
|
||||
|
||||
|
||||
# class ConnectorClassificationResponse(BaseModel):
|
||||
# connectors: list[str]
|
||||
|
||||
|
||||
# class ContentClassificationPrediction(BaseModel):
|
||||
# predicted_label: int
|
||||
# content_boost_factor: float
|
||||
|
||||
|
||||
# logger = setup_logger()
|
||||
|
||||
# router = APIRouter(prefix="/custom")
|
||||
|
||||
# _CONNECTOR_CLASSIFIER_TOKENIZER: Optional["PreTrainedTokenizer"] = None
|
||||
# _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
|
||||
|
||||
# _INTENT_TOKENIZER: Optional["PreTrainedTokenizer"] = None
|
||||
# _INTENT_MODEL: HybridClassifier | None = None
|
||||
|
||||
# _INFORMATION_CONTENT_MODEL: Optional["SetFitModel"] = None
|
||||
|
||||
# _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
|
||||
|
||||
|
||||
# def get_connector_classifier_tokenizer() -> "PreTrainedTokenizer":
|
||||
# global _CONNECTOR_CLASSIFIER_TOKENIZER
|
||||
# from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
# if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
|
||||
# # The tokenizer details are not uploaded to the HF hub since it's just the
|
||||
# # unmodified distilbert tokenizer.
|
||||
# _CONNECTOR_CLASSIFIER_TOKENIZER = cast(
|
||||
# PreTrainedTokenizer,
|
||||
# AutoTokenizer.from_pretrained("distilbert-base-uncased"),
|
||||
# )
|
||||
# return _CONNECTOR_CLASSIFIER_TOKENIZER
|
||||
|
||||
|
||||
# def get_local_connector_classifier(
|
||||
# model_name_or_path: str = CONNECTOR_CLASSIFIER_MODEL_REPO,
|
||||
# tag: str = CONNECTOR_CLASSIFIER_MODEL_TAG,
|
||||
# ) -> ConnectorClassifier:
|
||||
# global _CONNECTOR_CLASSIFIER_MODEL
|
||||
# if _CONNECTOR_CLASSIFIER_MODEL is None:
|
||||
# try:
|
||||
# # Calculate where the cache should be, then load from local if available
|
||||
# local_path = snapshot_download(
|
||||
# repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
# )
|
||||
# _CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
|
||||
# local_path
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to load model directly: {e}")
|
||||
# try:
|
||||
# # Attempt to download the model snapshot
|
||||
# logger.info(f"Downloading model snapshot for {model_name_or_path}")
|
||||
# local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
|
||||
# _CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
|
||||
# local_path
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.error(
|
||||
# f"Failed to load model even after attempted snapshot download: {e}"
|
||||
# )
|
||||
# raise
|
||||
# return _CONNECTOR_CLASSIFIER_MODEL
|
||||
|
||||
|
||||
# def get_intent_model_tokenizer() -> "PreTrainedTokenizer":
|
||||
# from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
# global _INTENT_TOKENIZER
|
||||
# if _INTENT_TOKENIZER is None:
|
||||
# # The tokenizer details are not uploaded to the HF hub since it's just the
|
||||
# # unmodified distilbert tokenizer.
|
||||
# _INTENT_TOKENIZER = cast(
|
||||
# PreTrainedTokenizer,
|
||||
# AutoTokenizer.from_pretrained("distilbert-base-uncased"),
|
||||
# )
|
||||
# return _INTENT_TOKENIZER
|
||||
|
||||
|
||||
# def get_local_intent_model(
|
||||
# model_name_or_path: str = INTENT_MODEL_VERSION,
|
||||
# tag: str | None = INTENT_MODEL_TAG,
|
||||
# ) -> HybridClassifier:
|
||||
# global _INTENT_MODEL
|
||||
# if _INTENT_MODEL is None:
|
||||
# try:
|
||||
# # Calculate where the cache should be, then load from local if available
|
||||
# logger.notice(f"Loading model from local cache: {model_name_or_path}")
|
||||
# local_path = snapshot_download(
|
||||
# repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
# )
|
||||
# _INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
|
||||
# logger.notice(f"Loaded model from local cache: {local_path}")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to load model directly: {e}")
|
||||
# try:
|
||||
# # Attempt to download the model snapshot
|
||||
# logger.notice(f"Downloading model snapshot for {model_name_or_path}")
|
||||
# local_path = snapshot_download(
|
||||
# repo_id=model_name_or_path, revision=tag, local_files_only=False
|
||||
# )
|
||||
# _INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
|
||||
# except Exception as e:
|
||||
# logger.error(
|
||||
# f"Failed to load model even after attempted snapshot download: {e}"
|
||||
# )
|
||||
# raise
|
||||
# return _INTENT_MODEL
|
||||
|
||||
|
||||
# def get_local_information_content_model(
|
||||
# model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
|
||||
# tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
|
||||
# ) -> "SetFitModel":
|
||||
# from setfit import SetFitModel
|
||||
|
||||
# global _INFORMATION_CONTENT_MODEL
|
||||
# if _INFORMATION_CONTENT_MODEL is None:
|
||||
# try:
|
||||
# # Calculate where the cache should be, then load from local if available
|
||||
# logger.notice(
|
||||
# f"Loading content information model from local cache: {model_name_or_path}"
|
||||
# )
|
||||
# local_path = snapshot_download(
|
||||
# repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
# )
|
||||
# _INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
|
||||
# logger.notice(
|
||||
# f"Loaded content information model from local cache: {local_path}"
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to load content information model directly: {e}")
|
||||
# try:
|
||||
# # Attempt to download the model snapshot
|
||||
# logger.notice(
|
||||
# f"Downloading content information model snapshot for {model_name_or_path}"
|
||||
# )
|
||||
# local_path = snapshot_download(
|
||||
# repo_id=model_name_or_path, revision=tag, local_files_only=False
|
||||
# )
|
||||
# _INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
|
||||
# except Exception as e:
|
||||
# logger.error(
|
||||
# f"Failed to load content information model even after attempted snapshot download: {e}"
|
||||
# )
|
||||
# raise
|
||||
|
||||
# return _INFORMATION_CONTENT_MODEL
|
||||
|
||||
|
||||
# def tokenize_connector_classification_query(
|
||||
# connectors: list[str],
|
||||
# query: str,
|
||||
# tokenizer: "PreTrainedTokenizer",
|
||||
# connector_token_end_id: int,
|
||||
# ) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# """
|
||||
# Tokenize the connectors & user query into one prompt for the forward pass of ConnectorClassifier models
|
||||
|
||||
# The attention mask is just all 1s. The prompt is CLS + each connector name suffixed with the connector end
|
||||
# token and then the user query.
|
||||
# """
|
||||
|
||||
# input_ids = torch.tensor([tokenizer.cls_token_id], dtype=torch.long)
|
||||
|
||||
# for connector in connectors:
|
||||
# connector_token_ids = tokenizer(
|
||||
# connector,
|
||||
# add_special_tokens=False,
|
||||
# return_tensors="pt",
|
||||
# )
|
||||
|
||||
# input_ids = torch.cat(
|
||||
# (
|
||||
# input_ids,
|
||||
# connector_token_ids["input_ids"].squeeze(dim=0),
|
||||
# torch.tensor([connector_token_end_id], dtype=torch.long),
|
||||
# ),
|
||||
# dim=-1,
|
||||
# )
|
||||
# query_token_ids = tokenizer(
|
||||
# query,
|
||||
# add_special_tokens=False,
|
||||
# return_tensors="pt",
|
||||
# )
|
||||
|
||||
# input_ids = torch.cat(
|
||||
# (
|
||||
# input_ids,
|
||||
# query_token_ids["input_ids"].squeeze(dim=0),
|
||||
# torch.tensor([tokenizer.sep_token_id], dtype=torch.long),
|
||||
# ),
|
||||
# dim=-1,
|
||||
# )
|
||||
# attention_mask = torch.ones(input_ids.numel(), dtype=torch.long)
|
||||
|
||||
# return input_ids.unsqueeze(0), attention_mask.unsqueeze(0)
|
||||
|
||||
|
||||
# def warm_up_connector_classifier_model() -> None:
|
||||
# logger.info(
|
||||
# f"Warming up connector_classifier model {CONNECTOR_CLASSIFIER_MODEL_TAG}"
|
||||
# )
|
||||
# connector_classifier_tokenizer = get_connector_classifier_tokenizer()
|
||||
# connector_classifier = get_local_connector_classifier()
|
||||
|
||||
# input_ids, attention_mask = tokenize_connector_classification_query(
|
||||
# ["GitHub"],
|
||||
# "onyx classifier query google doc",
|
||||
# connector_classifier_tokenizer,
|
||||
# connector_classifier.connector_end_token_id,
|
||||
# )
|
||||
# input_ids = input_ids.to(connector_classifier.device)
|
||||
# attention_mask = attention_mask.to(connector_classifier.device)
|
||||
|
||||
# connector_classifier(input_ids, attention_mask)
|
||||
|
||||
|
||||
# def warm_up_intent_model() -> None:
|
||||
# logger.notice(f"Warming up Intent Model: {INTENT_MODEL_VERSION}")
|
||||
# intent_tokenizer = get_intent_model_tokenizer()
|
||||
# tokens = intent_tokenizer(
|
||||
# MODEL_WARM_UP_STRING, return_tensors="pt", truncation=True, padding=True
|
||||
# )
|
||||
|
||||
# intent_model = get_local_intent_model()
|
||||
# device = intent_model.device
|
||||
# intent_model(
|
||||
# query_ids=tokens["input_ids"].to(device),
|
||||
# query_mask=tokens["attention_mask"].to(device),
|
||||
# )
|
||||
|
||||
|
||||
# def warm_up_information_content_model() -> None:
|
||||
# logger.notice("Warming up Content Model") # TODO: add version if needed
|
||||
|
||||
# information_content_model = get_local_information_content_model()
|
||||
# information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)
|
||||
|
||||
|
||||
# @simple_log_function_time()
|
||||
# def run_inference(tokens: "BatchEncoding") -> tuple[list[float], list[float]]:
|
||||
# intent_model = get_local_intent_model()
|
||||
# device = intent_model.device
|
||||
|
||||
# outputs = intent_model(
|
||||
# query_ids=tokens["input_ids"].to(device),
|
||||
# query_mask=tokens["attention_mask"].to(device),
|
||||
# )
|
||||
|
||||
# token_logits = outputs["token_logits"]
|
||||
# intent_logits = outputs["intent_logits"]
|
||||
|
||||
# # Move tensors to CPU before applying softmax and converting to numpy
|
||||
# intent_probabilities = F.softmax(intent_logits.cpu(), dim=-1).numpy()[0]
|
||||
# token_probabilities = F.softmax(token_logits.cpu(), dim=-1).numpy()[0]
|
||||
|
||||
# # Extract the probabilities for the positive class (index 1) for each token
|
||||
# token_positive_probs = token_probabilities[:, 1].tolist()
|
||||
|
||||
# return intent_probabilities.tolist(), token_positive_probs
|
||||
|
||||
|
||||
# @simple_log_function_time()
|
||||
# def run_content_classification_inference(
|
||||
# text_inputs: list[str],
|
||||
# ) -> list[ContentClassificationPrediction]:
|
||||
# """
|
||||
# Assign a score to the segments in question. The model stored in get_local_information_content_model()
|
||||
# creates the 'model score' based on its training, and the scores are then converted to a 0.0-1.0 scale.
|
||||
# In the code outside of the model/inference model servers that score will be converted into the actual
|
||||
# boost factor.
|
||||
# """
|
||||
|
||||
# def _prob_to_score(prob: float) -> float:
|
||||
# """
|
||||
# Conversion of base score to 0.0 - 1.0 score. Note that the min/max values depend on the model!
|
||||
# """
|
||||
# _MIN_BASE_SCORE = 0.25
|
||||
# _MAX_BASE_SCORE = 0.75
|
||||
# if prob < _MIN_BASE_SCORE:
|
||||
# raw_score = 0.0
|
||||
# elif prob < _MAX_BASE_SCORE:
|
||||
# raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
|
||||
# else:
|
||||
# raw_score = 1.0
|
||||
# return (
|
||||
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
# + (
|
||||
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
|
||||
# - INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
# )
|
||||
# * raw_score
|
||||
# )
|
||||
|
||||
# _BATCH_SIZE = 32
|
||||
# content_model = get_local_information_content_model()
|
||||
|
||||
# # Process inputs in batches
|
||||
# all_output_classes: list[int] = []
|
||||
# all_base_output_probabilities: list[float] = []
|
||||
|
||||
# for i in range(0, len(text_inputs), _BATCH_SIZE):
|
||||
# batch = text_inputs[i : i + _BATCH_SIZE]
|
||||
# batch_with_prefix = []
|
||||
# batch_indices = []
|
||||
|
||||
# # Pre-allocate results for this batch
|
||||
# batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
|
||||
# batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)
|
||||
|
||||
# # Pre-process batch to handle long input exceptions
|
||||
# for j, text in enumerate(batch):
|
||||
# if len(text) == 0:
|
||||
# # if no input, treat as non-informative from the model's perspective
|
||||
# batch_output_classes[j] = np.array(0)
|
||||
# batch_probabilities[j] = np.array(0.0)
|
||||
# logger.warning("Input for Content Information Model is empty")
|
||||
|
||||
# elif (
|
||||
# len(text.split())
|
||||
# <= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
|
||||
# ):
|
||||
# # if input is short, use the model
|
||||
# batch_with_prefix.append(
|
||||
# _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
|
||||
# )
|
||||
# batch_indices.append(j)
|
||||
# else:
|
||||
# # if longer than cutoff, treat as informative (stay with default), but issue warning
|
||||
# logger.warning("Input for Content Information Model too long")
|
||||
|
||||
# if batch_with_prefix: # Only run model if we have valid inputs
|
||||
# # Get predictions for the batch
|
||||
# model_output_classes = content_model(batch_with_prefix)
|
||||
# model_output_probabilities = content_model.predict_proba(batch_with_prefix)
|
||||
|
||||
# # Place results in the correct positions
|
||||
# for idx, batch_idx in enumerate(batch_indices):
|
||||
# batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
|
||||
# batch_probabilities[batch_idx] = model_output_probabilities[idx][
|
||||
# 1
|
||||
# ].numpy() # x[1] is prob of the positive class
|
||||
|
||||
# all_output_classes.extend([int(x) for x in batch_output_classes])
|
||||
# all_base_output_probabilities.extend([float(x) for x in batch_probabilities])
|
||||
|
||||
# logits = [
|
||||
# np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
|
||||
# for p in all_base_output_probabilities
|
||||
# ]
|
||||
# scaled_logits = [
|
||||
# logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
|
||||
# for logit in logits
|
||||
# ]
|
||||
# output_probabilities_with_temp = [
|
||||
# np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
|
||||
# for scaled_logit in scaled_logits
|
||||
# ]
|
||||
|
||||
# prediction_scores = [
|
||||
# _prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
|
||||
# ]
|
||||
|
||||
# content_classification_predictions = [
|
||||
# ContentClassificationPrediction(
|
||||
# predicted_label=predicted_label, content_boost_factor=output_score
|
||||
# )
|
||||
# for predicted_label, output_score in zip(all_output_classes, prediction_scores)
|
||||
# ]
|
||||
|
||||
# return content_classification_predictions
|
||||
|
||||
|
||||
# def map_keywords(
|
||||
# input_ids: torch.Tensor, tokenizer: "PreTrainedTokenizer", is_keyword: list[bool]
|
||||
# ) -> list[str]:
|
||||
# tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore
|
||||
|
||||
# if not len(tokens) == len(is_keyword):
|
||||
# raise ValueError("Length of tokens and keyword predictions must match")
|
||||
|
||||
# if input_ids[0] == tokenizer.cls_token_id:
|
||||
# tokens = tokens[1:]
|
||||
# is_keyword = is_keyword[1:]
|
||||
|
||||
# if input_ids[-1] == tokenizer.sep_token_id:
|
||||
# tokens = tokens[:-1]
|
||||
# is_keyword = is_keyword[:-1]
|
||||
|
||||
# unk_token = tokenizer.unk_token
|
||||
# if unk_token in tokens:
|
||||
# raise ValueError("Unknown token detected in the input")
|
||||
|
||||
# keywords = []
|
||||
# current_keyword = ""
|
||||
|
||||
# for ind, token in enumerate(tokens):
|
||||
# if is_keyword[ind]:
|
||||
# if token.startswith("##"):
|
||||
# current_keyword += token[2:]
|
||||
# else:
|
||||
# if current_keyword:
|
||||
# keywords.append(current_keyword)
|
||||
# current_keyword = token
|
||||
# else:
|
||||
# # If mispredicted a later token of a keyword, add it to the current keyword
|
||||
# # to complete it
|
||||
# if current_keyword:
|
||||
# if len(current_keyword) > 2 and current_keyword.startswith("##"):
|
||||
# current_keyword = current_keyword[2:]
|
||||
|
||||
# else:
|
||||
# keywords.append(current_keyword)
|
||||
# current_keyword = ""
|
||||
|
||||
# if current_keyword:
|
||||
# keywords.append(current_keyword)
|
||||
|
||||
# return keywords
|
||||
|
||||
|
||||
# def clean_keywords(keywords: list[str]) -> list[str]:
|
||||
# cleaned_words = []
|
||||
# for word in keywords:
|
||||
# word = word[:-2] if word.endswith("'s") else word
|
||||
# word = word.replace("/", " ")
|
||||
# word = word.replace("'", "").replace('"', "")
|
||||
# cleaned_words.extend([w for w in word.strip().split() if w and not w.isspace()])
|
||||
# return cleaned_words
|
||||
|
||||
|
||||
# def run_connector_classification(req: ConnectorClassificationRequest) -> list[str]:
|
||||
# tokenizer = get_connector_classifier_tokenizer()
|
||||
# model = get_local_connector_classifier()
|
||||
|
||||
# connector_names = req.available_connectors
|
||||
|
||||
# input_ids, attention_mask = tokenize_connector_classification_query(
|
||||
# connector_names,
|
||||
# req.query,
|
||||
# tokenizer,
|
||||
# model.connector_end_token_id,
|
||||
# )
|
||||
# input_ids = input_ids.to(model.device)
|
||||
# attention_mask = attention_mask.to(model.device)
|
||||
|
||||
# global_confidence, classifier_confidence = model(input_ids, attention_mask)
|
||||
|
||||
# if global_confidence.item() < 0.5:
|
||||
# return []
|
||||
|
||||
# passed_connectors = []
|
||||
|
||||
# for i, connector_name in enumerate(connector_names):
|
||||
# if classifier_confidence.view(-1)[i].item() > 0.5:
|
||||
# passed_connectors.append(connector_name)
|
||||
|
||||
# return passed_connectors
|
||||
|
||||
|
||||
# def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
|
||||
# tokenizer = get_intent_model_tokenizer()
|
||||
# model_input = tokenizer(
|
||||
# intent_req.query, return_tensors="pt", truncation=False, padding=False
|
||||
# )
|
||||
|
||||
# if len(model_input.input_ids[0]) > 512:
|
||||
# # If the user text is too long, assume it is semantic and keep all words
|
||||
# return True, intent_req.query.split()
|
||||
|
||||
# intent_probs, token_probs = run_inference(model_input)
|
||||
|
||||
# is_keyword_sequence = intent_probs[0] >= intent_req.keyword_percent_threshold
|
||||
|
||||
# keyword_preds = [
|
||||
# token_prob >= intent_req.keyword_percent_threshold for token_prob in token_probs
|
||||
# ]
|
||||
|
||||
# try:
|
||||
# keywords = map_keywords(model_input.input_ids[0], tokenizer, keyword_preds)
|
||||
# except Exception as e:
|
||||
# logger.warning(
|
||||
# f"Failed to extract keywords for query: {intent_req.query} due to {e}"
|
||||
# )
|
||||
# # Fallback to keeping all words
|
||||
# keywords = intent_req.query.split()
|
||||
|
||||
# cleaned_keywords = clean_keywords(keywords)
|
||||
|
||||
# return is_keyword_sequence, cleaned_keywords
|
||||
|
||||
|
||||
# @router.post("/connector-classification")
|
||||
# async def process_connector_classification_request(
|
||||
# classification_request: ConnectorClassificationRequest,
|
||||
# ) -> ConnectorClassificationResponse:
|
||||
# if INDEXING_ONLY:
|
||||
# raise RuntimeError(
|
||||
# "Indexing model server should not call connector classification endpoint"
|
||||
# )
|
||||
|
||||
# if len(classification_request.available_connectors) == 0:
|
||||
# return ConnectorClassificationResponse(connectors=[])
|
||||
|
||||
# connectors = run_connector_classification(classification_request)
|
||||
# return ConnectorClassificationResponse(connectors=connectors)
|
||||
|
||||
|
||||
# @router.post("/query-analysis")
|
||||
# async def process_analysis_request(
|
||||
# intent_request: IntentRequest,
|
||||
# ) -> IntentResponse:
|
||||
# if INDEXING_ONLY:
|
||||
# raise RuntimeError("Indexing model server should not call intent endpoint")
|
||||
|
||||
# is_keyword, keywords = run_analysis(intent_request)
|
||||
# return IntentResponse(is_keyword=is_keyword, keywords=keywords)
|
||||
|
||||
|
||||
# @router.post("/content-classification")
|
||||
# async def process_content_classification_request(
|
||||
# content_classification_requests: list[str],
|
||||
# ) -> list[ContentClassificationPrediction]:
|
||||
# return run_content_classification_inference(content_classification_requests)
|
||||
154
backend/model_server/legacy/onyx_torch_model.py
Normal file
154
backend/model_server/legacy/onyx_torch_model.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# import json
|
||||
# import os
|
||||
# from typing import cast
|
||||
# from typing import TYPE_CHECKING
|
||||
|
||||
# import torch
|
||||
# import torch.nn as nn
|
||||
|
||||
|
||||
# if TYPE_CHECKING:
|
||||
# from transformers import DistilBertConfig
|
||||
|
||||
|
||||
# class HybridClassifier(nn.Module):
|
||||
# def __init__(self) -> None:
|
||||
# from transformers import DistilBertConfig, DistilBertModel
|
||||
|
||||
# super().__init__()
|
||||
# config = DistilBertConfig()
|
||||
# self.distilbert = DistilBertModel(config)
|
||||
# config = self.distilbert.config # type: ignore
|
||||
|
||||
# # Keyword tokenwise binary classification layer
|
||||
# self.keyword_classifier = nn.Linear(config.dim, 2)
|
||||
|
||||
# # Intent Classifier layers
|
||||
# self.pre_classifier = nn.Linear(config.dim, config.dim)
|
||||
# self.intent_classifier = nn.Linear(config.dim, 2)
|
||||
|
||||
# self.device = torch.device("cpu")
|
||||
|
||||
# def forward(
|
||||
# self,
|
||||
# query_ids: torch.Tensor,
|
||||
# query_mask: torch.Tensor,
|
||||
# ) -> dict[str, torch.Tensor]:
|
||||
# outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)
|
||||
# sequence_output = outputs.last_hidden_state
|
||||
|
||||
# # Intent classification on the CLS token
|
||||
# cls_token_state = sequence_output[:, 0, :]
|
||||
# pre_classifier_out = self.pre_classifier(cls_token_state)
|
||||
# intent_logits = self.intent_classifier(pre_classifier_out)
|
||||
|
||||
# # Keyword classification on all tokens
|
||||
# token_logits = self.keyword_classifier(sequence_output)
|
||||
|
||||
# return {"intent_logits": intent_logits, "token_logits": token_logits}
|
||||
|
||||
# @classmethod
|
||||
# def from_pretrained(cls, load_directory: str) -> "HybridClassifier":
|
||||
# model_path = os.path.join(load_directory, "pytorch_model.bin")
|
||||
# config_path = os.path.join(load_directory, "config.json")
|
||||
|
||||
# with open(config_path, "r") as f:
|
||||
# config = json.load(f)
|
||||
# model = cls(**config)
|
||||
|
||||
# if torch.backends.mps.is_available():
|
||||
# # Apple silicon GPU
|
||||
# device = torch.device("mps")
|
||||
# elif torch.cuda.is_available():
|
||||
# device = torch.device("cuda")
|
||||
# else:
|
||||
# device = torch.device("cpu")
|
||||
|
||||
# model.load_state_dict(torch.load(model_path, map_location=device))
|
||||
# model = model.to(device)
|
||||
|
||||
# model.device = device
|
||||
|
||||
# model.eval()
|
||||
# # Eval doesn't set requires_grad to False, do it manually to save memory and have faster inference
|
||||
# for param in model.parameters():
|
||||
# param.requires_grad = False
|
||||
|
||||
# return model
|
||||
|
||||
|
||||
# class ConnectorClassifier(nn.Module):
|
||||
# def __init__(self, config: "DistilBertConfig") -> None:
|
||||
# from transformers import DistilBertTokenizer, DistilBertModel
|
||||
|
||||
# super().__init__()
|
||||
|
||||
# self.config = config
|
||||
# self.distilbert = DistilBertModel(config)
|
||||
# config = self.distilbert.config # type: ignore
|
||||
# self.connector_global_classifier = nn.Linear(config.dim, 1)
|
||||
# self.connector_match_classifier = nn.Linear(config.dim, 1)
|
||||
# self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
|
||||
|
||||
# # Token indicating end of connector name, and on which classifier is used
|
||||
# self.connector_end_token_id = self.tokenizer.get_vocab()[
|
||||
# self.config.connector_end_token
|
||||
# ]
|
||||
|
||||
# self.device = torch.device("cpu")
|
||||
|
||||
# def forward(
|
||||
# self,
|
||||
# input_ids: torch.Tensor,
|
||||
# attention_mask: torch.Tensor,
|
||||
# ) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# hidden_states = self.distilbert(
|
||||
# input_ids=input_ids, attention_mask=attention_mask
|
||||
# ).last_hidden_state
|
||||
|
||||
# cls_hidden_states = hidden_states[
|
||||
# :, 0, :
|
||||
# ] # Take leap of faith that first token is always [CLS]
|
||||
# global_logits = self.connector_global_classifier(cls_hidden_states).view(-1)
|
||||
# global_confidence = torch.sigmoid(global_logits).view(-1)
|
||||
|
||||
# connector_end_position_ids = input_ids == self.connector_end_token_id
|
||||
# connector_end_hidden_states = hidden_states[connector_end_position_ids]
|
||||
# classifier_output = self.connector_match_classifier(connector_end_hidden_states)
|
||||
# classifier_confidence = torch.nn.functional.sigmoid(classifier_output).view(-1)
|
||||
|
||||
# return global_confidence, classifier_confidence
|
||||
|
||||
# @classmethod
|
||||
# def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
|
||||
# from transformers import DistilBertConfig
|
||||
|
||||
# config = cast(
|
||||
# DistilBertConfig,
|
||||
# DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
|
||||
# )
|
||||
# device = (
|
||||
# torch.device("cuda")
|
||||
# if torch.cuda.is_available()
|
||||
# else (
|
||||
# torch.device("mps")
|
||||
# if torch.backends.mps.is_available()
|
||||
# else torch.device("cpu")
|
||||
# )
|
||||
# )
|
||||
# state_dict = torch.load(
|
||||
# os.path.join(repo_dir, "pytorch_model.pt"),
|
||||
# map_location=device,
|
||||
# weights_only=True,
|
||||
# )
|
||||
|
||||
# model = cls(config)
|
||||
# model.load_state_dict(state_dict)
|
||||
# model.to(device)
|
||||
# model.device = device
|
||||
# model.eval()
|
||||
|
||||
# for param in model.parameters():
|
||||
# param.requires_grad = False
|
||||
|
||||
# return model
|
||||
80
backend/model_server/legacy/reranker.py
Normal file
80
backend/model_server/legacy/reranker.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# import asyncio
|
||||
# from typing import Optional
|
||||
# from typing import TYPE_CHECKING
|
||||
|
||||
# from fastapi import APIRouter
|
||||
# from fastapi import HTTPException
|
||||
|
||||
# from model_server.utils import simple_log_function_time
|
||||
# from onyx.utils.logger import setup_logger
|
||||
# from shared_configs.configs import INDEXING_ONLY
|
||||
# from shared_configs.model_server_models import RerankRequest
|
||||
# from shared_configs.model_server_models import RerankResponse
|
||||
|
||||
# if TYPE_CHECKING:
|
||||
# from sentence_transformers import CrossEncoder
|
||||
|
||||
# logger = setup_logger()
|
||||
|
||||
# router = APIRouter(prefix="/encoder")
|
||||
|
||||
# _RERANK_MODEL: Optional["CrossEncoder"] = None
|
||||
|
||||
|
||||
# def get_local_reranking_model(
|
||||
# model_name: str,
|
||||
# ) -> "CrossEncoder":
|
||||
# global _RERANK_MODEL
|
||||
# from sentence_transformers import CrossEncoder
|
||||
|
||||
# if _RERANK_MODEL is None:
|
||||
# logger.notice(f"Loading {model_name}")
|
||||
# model = CrossEncoder(model_name)
|
||||
# _RERANK_MODEL = model
|
||||
# return _RERANK_MODEL
|
||||
|
||||
|
||||
# @simple_log_function_time()
|
||||
# async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
|
||||
# cross_encoder = get_local_reranking_model(model_name)
|
||||
# # Run CPU-bound reranking in a thread pool
|
||||
# return await asyncio.get_event_loop().run_in_executor(
|
||||
# None,
|
||||
# lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(),
|
||||
# )
|
||||
|
||||
|
||||
# @router.post("/cross-encoder-scores")
|
||||
# async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
|
||||
# """Cross encoders can be purely black box from the app perspective"""
|
||||
# # Only local models should use this endpoint - API providers should make direct API calls
|
||||
# if rerank_request.provider_type is not None:
|
||||
# raise ValueError(
|
||||
# f"Model server reranking endpoint should only be used for local models. "
|
||||
# f"API provider '{rerank_request.provider_type}' should make direct API calls instead."
|
||||
# )
|
||||
|
||||
# if INDEXING_ONLY:
|
||||
# raise RuntimeError("Indexing model server should not call reranking endpoint")
|
||||
|
||||
# if not rerank_request.documents or not rerank_request.query:
|
||||
# raise HTTPException(
|
||||
# status_code=400, detail="Missing documents or query for reranking"
|
||||
# )
|
||||
# if not all(rerank_request.documents):
|
||||
# raise ValueError("Empty documents cannot be reranked.")
|
||||
|
||||
# try:
|
||||
# # At this point, provider_type is None, so handle local reranking
|
||||
# sim_scores = await local_rerank(
|
||||
# query=rerank_request.query,
|
||||
# docs=rerank_request.documents,
|
||||
# model_name=rerank_request.model_name,
|
||||
# )
|
||||
# return RerankResponse(scores=sim_scores)
|
||||
|
||||
# except Exception as e:
|
||||
# logger.exception(f"Error during reranking process:\n{str(e)}")
|
||||
# raise HTTPException(
|
||||
# status_code=500, detail="Failed to run Cross-Encoder reranking"
|
||||
# )
|
||||
@@ -12,11 +12,8 @@ from fastapi import FastAPI
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||
from sentry_sdk.integrations.starlette import StarletteIntegration
|
||||
from transformers import logging as transformer_logging # type:ignore
|
||||
from transformers import logging as transformer_logging
|
||||
|
||||
from model_server.custom_models import router as custom_models_router
|
||||
from model_server.custom_models import warm_up_information_content_model
|
||||
from model_server.custom_models import warm_up_intent_model
|
||||
from model_server.encoders import router as encoders_router
|
||||
from model_server.management_endpoints import router as management_router
|
||||
from model_server.utils import get_gpu_type
|
||||
@@ -30,7 +27,6 @@ from shared_configs.configs import MIN_THREADS_ML_MODELS
|
||||
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
from shared_configs.configs import SKIP_WARM_UP
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
@@ -92,18 +88,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
|
||||
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
|
||||
|
||||
if not SKIP_WARM_UP:
|
||||
if not INDEXING_ONLY:
|
||||
logger.notice("Warming up intent model for inference model server")
|
||||
warm_up_intent_model()
|
||||
else:
|
||||
logger.notice(
|
||||
"Warming up content information model for indexing model server"
|
||||
)
|
||||
warm_up_information_content_model()
|
||||
else:
|
||||
logger.notice("Skipping model warmup due to SKIP_WARM_UP=true")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -123,7 +107,6 @@ def get_model_app() -> FastAPI:
|
||||
|
||||
application.include_router(management_router)
|
||||
application.include_router(encoders_router)
|
||||
application.include_router(custom_models_router)
|
||||
|
||||
request_id_prefix = "INF"
|
||||
if INDEXING_ONLY:
|
||||
|
||||
@@ -1,154 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import DistilBertConfig # type: ignore
|
||||
|
||||
|
||||
class HybridClassifier(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
from transformers import DistilBertConfig, DistilBertModel
|
||||
|
||||
super().__init__()
|
||||
config = DistilBertConfig()
|
||||
self.distilbert = DistilBertModel(config)
|
||||
config = self.distilbert.config # type: ignore
|
||||
|
||||
# Keyword tokenwise binary classification layer
|
||||
self.keyword_classifier = nn.Linear(config.dim, 2)
|
||||
|
||||
# Intent Classifier layers
|
||||
self.pre_classifier = nn.Linear(config.dim, config.dim)
|
||||
self.intent_classifier = nn.Linear(config.dim, 2)
|
||||
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query_ids: torch.Tensor,
|
||||
query_mask: torch.Tensor,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask) # type: ignore
|
||||
sequence_output = outputs.last_hidden_state
|
||||
|
||||
# Intent classification on the CLS token
|
||||
cls_token_state = sequence_output[:, 0, :]
|
||||
pre_classifier_out = self.pre_classifier(cls_token_state)
|
||||
intent_logits = self.intent_classifier(pre_classifier_out)
|
||||
|
||||
# Keyword classification on all tokens
|
||||
token_logits = self.keyword_classifier(sequence_output)
|
||||
|
||||
return {"intent_logits": intent_logits, "token_logits": token_logits}
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, load_directory: str) -> "HybridClassifier":
|
||||
model_path = os.path.join(load_directory, "pytorch_model.bin")
|
||||
config_path = os.path.join(load_directory, "config.json")
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
model = cls(**config)
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
# Apple silicon GPU
|
||||
device = torch.device("mps")
|
||||
elif torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
model.load_state_dict(torch.load(model_path, map_location=device))
|
||||
model = model.to(device)
|
||||
|
||||
model.device = device
|
||||
|
||||
model.eval()
|
||||
# Eval doesn't set requires_grad to False, do it manually to save memory and have faster inference
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ConnectorClassifier(nn.Module):
|
||||
def __init__(self, config: "DistilBertConfig") -> None:
|
||||
from transformers import DistilBertTokenizer, DistilBertModel
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.distilbert = DistilBertModel(config)
|
||||
config = self.distilbert.config # type: ignore
|
||||
self.connector_global_classifier = nn.Linear(config.dim, 1)
|
||||
self.connector_match_classifier = nn.Linear(config.dim, 1)
|
||||
self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
|
||||
|
||||
# Token indicating end of connector name, and on which classifier is used
|
||||
self.connector_end_token_id = self.tokenizer.get_vocab()[
|
||||
self.config.connector_end_token
|
||||
]
|
||||
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
hidden_states = self.distilbert( # type: ignore
|
||||
input_ids=input_ids, attention_mask=attention_mask
|
||||
).last_hidden_state
|
||||
|
||||
cls_hidden_states = hidden_states[
|
||||
:, 0, :
|
||||
] # Take leap of faith that first token is always [CLS]
|
||||
global_logits = self.connector_global_classifier(cls_hidden_states).view(-1)
|
||||
global_confidence = torch.sigmoid(global_logits).view(-1)
|
||||
|
||||
connector_end_position_ids = input_ids == self.connector_end_token_id
|
||||
connector_end_hidden_states = hidden_states[connector_end_position_ids]
|
||||
classifier_output = self.connector_match_classifier(connector_end_hidden_states)
|
||||
classifier_confidence = torch.nn.functional.sigmoid(classifier_output).view(-1)
|
||||
|
||||
return global_confidence, classifier_confidence
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
|
||||
from transformers import DistilBertConfig
|
||||
|
||||
config = cast(
|
||||
DistilBertConfig,
|
||||
DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
|
||||
)
|
||||
device = (
|
||||
torch.device("cuda")
|
||||
if torch.cuda.is_available()
|
||||
else (
|
||||
torch.device("mps")
|
||||
if torch.backends.mps.is_available()
|
||||
else torch.device("cpu")
|
||||
)
|
||||
)
|
||||
state_dict = torch.load(
|
||||
os.path.join(repo_dir, "pytorch_model.pt"),
|
||||
map_location=device,
|
||||
weights_only=True,
|
||||
)
|
||||
|
||||
model = cls(config)
|
||||
model.load_state_dict(state_dict)
|
||||
model.to(device)
|
||||
model.device = device
|
||||
model.eval()
|
||||
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
return model
|
||||
@@ -43,7 +43,7 @@ def get_access_for_document(
|
||||
versioned_get_access_for_document_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "_get_access_for_document"
|
||||
)
|
||||
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
|
||||
return versioned_get_access_for_document_fn(document_id, db_session)
|
||||
|
||||
|
||||
def get_null_document_access() -> DocumentAccess:
|
||||
@@ -93,9 +93,7 @@ def get_access_for_documents(
|
||||
versioned_get_access_for_documents_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "_get_access_for_documents"
|
||||
)
|
||||
return versioned_get_access_for_documents_fn(
|
||||
document_ids, db_session
|
||||
) # type: ignore
|
||||
return versioned_get_access_for_documents_fn(document_ids, db_session)
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
@@ -113,7 +111,7 @@ def get_acl_for_user(user: User | None, db_session: Session | None = None) -> se
|
||||
versioned_acl_for_user_fn = fetch_versioned_implementation(
|
||||
"onyx.access.access", "_get_acl_for_user"
|
||||
)
|
||||
return versioned_acl_for_user_fn(user, db_session) # type: ignore
|
||||
return versioned_acl_for_user_fn(user, db_session)
|
||||
|
||||
|
||||
def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> bool:
|
||||
|
||||
107
backend/onyx/auth/captcha.py
Normal file
107
backend/onyx/auth/captcha.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Captcha verification for user registration."""
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.app_configs import CAPTCHA_ENABLED
|
||||
from onyx.configs.app_configs import RECAPTCHA_SCORE_THRESHOLD
|
||||
from onyx.configs.app_configs import RECAPTCHA_SECRET_KEY
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
RECAPTCHA_VERIFY_URL = "https://www.google.com/recaptcha/api/siteverify"
|
||||
|
||||
|
||||
class CaptchaVerificationError(Exception):
|
||||
"""Raised when captcha verification fails."""
|
||||
|
||||
|
||||
class RecaptchaResponse(BaseModel):
|
||||
"""Response from Google reCAPTCHA verification API."""
|
||||
|
||||
success: bool
|
||||
score: float | None = None # Only present for reCAPTCHA v3
|
||||
action: str | None = None
|
||||
challenge_ts: str | None = None
|
||||
hostname: str | None = None
|
||||
error_codes: list[str] | None = Field(default=None, alias="error-codes")
|
||||
|
||||
|
||||
def is_captcha_enabled() -> bool:
|
||||
"""Check if captcha verification is enabled."""
|
||||
return CAPTCHA_ENABLED and bool(RECAPTCHA_SECRET_KEY)
|
||||
|
||||
|
||||
async def verify_captcha_token(
|
||||
token: str,
|
||||
expected_action: str = "signup",
|
||||
) -> None:
|
||||
"""
|
||||
Verify a reCAPTCHA token with Google's API.
|
||||
|
||||
Args:
|
||||
token: The reCAPTCHA response token from the client
|
||||
expected_action: Expected action name for v3 verification
|
||||
|
||||
Raises:
|
||||
CaptchaVerificationError: If verification fails
|
||||
"""
|
||||
if not is_captcha_enabled():
|
||||
return
|
||||
|
||||
if not token:
|
||||
raise CaptchaVerificationError("Captcha token is required")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
RECAPTCHA_VERIFY_URL,
|
||||
data={
|
||||
"secret": RECAPTCHA_SECRET_KEY,
|
||||
"response": token,
|
||||
},
|
||||
timeout=10.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
result = RecaptchaResponse(**data)
|
||||
|
||||
if not result.success:
|
||||
error_codes = result.error_codes or ["unknown-error"]
|
||||
logger.warning(f"Captcha verification failed: {error_codes}")
|
||||
raise CaptchaVerificationError(
|
||||
f"Captcha verification failed: {', '.join(error_codes)}"
|
||||
)
|
||||
|
||||
# For reCAPTCHA v3, also check the score
|
||||
if result.score is not None:
|
||||
if result.score < RECAPTCHA_SCORE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Captcha score too low: {result.score} < {RECAPTCHA_SCORE_THRESHOLD}"
|
||||
)
|
||||
raise CaptchaVerificationError(
|
||||
"Captcha verification failed: suspicious activity detected"
|
||||
)
|
||||
|
||||
# Optionally verify the action matches
|
||||
if result.action and result.action != expected_action:
|
||||
logger.warning(
|
||||
f"Captcha action mismatch: {result.action} != {expected_action}"
|
||||
)
|
||||
raise CaptchaVerificationError(
|
||||
"Captcha verification failed: action mismatch"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Captcha verification passed: score={result.score}, "
|
||||
f"action={result.action}"
|
||||
)
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.error(f"Captcha API request failed: {e}")
|
||||
# In case of API errors, we might want to allow registration
|
||||
# to prevent blocking legitimate users. This is a policy decision.
|
||||
raise CaptchaVerificationError("Captcha verification service unavailable")
|
||||
192
backend/onyx/auth/disposable_email_validator.py
Normal file
192
backend/onyx/auth/disposable_email_validator.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
Utility to validate and block disposable/temporary email addresses.
|
||||
|
||||
This module fetches a list of known disposable email domains from a remote source
|
||||
and caches them for performance. It's used during user registration to prevent
|
||||
abuse from temporary email services.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import Set
|
||||
|
||||
import httpx
|
||||
|
||||
from onyx.configs.app_configs import DISPOSABLE_EMAIL_DOMAINS_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DisposableEmailValidator:
|
||||
"""
|
||||
Thread-safe singleton validator for disposable email domains.
|
||||
|
||||
Fetches and caches the list of disposable domains, with periodic refresh.
|
||||
"""
|
||||
|
||||
_instance: "DisposableEmailValidator | None" = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls) -> "DisposableEmailValidator":
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Check if already initialized using a try/except to avoid type issues
|
||||
try:
|
||||
if self._initialized:
|
||||
return
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
self._domains: Set[str] = set()
|
||||
self._last_fetch_time: float = 0
|
||||
self._fetch_lock = threading.Lock()
|
||||
# Cache for 1 hour
|
||||
self._cache_duration = 3600
|
||||
# Hardcoded fallback list of common disposable domains
|
||||
# This ensures we block at least these even if the remote fetch fails
|
||||
self._fallback_domains = {
|
||||
"trashlify.com",
|
||||
"10minutemail.com",
|
||||
"guerrillamail.com",
|
||||
"mailinator.com",
|
||||
"tempmail.com",
|
||||
"throwaway.email",
|
||||
"yopmail.com",
|
||||
"temp-mail.org",
|
||||
"getnada.com",
|
||||
"maildrop.cc",
|
||||
}
|
||||
# Set initialized flag last to prevent race conditions
|
||||
self._initialized: bool = True
|
||||
|
||||
def _should_refresh(self) -> bool:
|
||||
"""Check if the cached domains should be refreshed."""
|
||||
return (time.time() - self._last_fetch_time) > self._cache_duration
|
||||
|
||||
def _fetch_domains(self) -> Set[str]:
|
||||
"""
|
||||
Fetch disposable email domains from the configured URL.
|
||||
|
||||
Returns:
|
||||
Set of domain strings (lowercased)
|
||||
"""
|
||||
if not DISPOSABLE_EMAIL_DOMAINS_URL:
|
||||
logger.debug("DISPOSABLE_EMAIL_DOMAINS_URL not configured")
|
||||
return self._fallback_domains.copy()
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"Fetching disposable email domains from {DISPOSABLE_EMAIL_DOMAINS_URL}"
|
||||
)
|
||||
with httpx.Client(timeout=10.0) as client:
|
||||
response = client.get(DISPOSABLE_EMAIL_DOMAINS_URL)
|
||||
response.raise_for_status()
|
||||
|
||||
domains_list = response.json()
|
||||
|
||||
if not isinstance(domains_list, list):
|
||||
logger.error(
|
||||
f"Expected list from disposable domains URL, got {type(domains_list)}"
|
||||
)
|
||||
return self._fallback_domains.copy()
|
||||
|
||||
# Convert all to lowercase and create set
|
||||
domains = {domain.lower().strip() for domain in domains_list if domain}
|
||||
|
||||
# Always include fallback domains
|
||||
domains.update(self._fallback_domains)
|
||||
|
||||
logger.info(
|
||||
f"Successfully fetched {len(domains)} disposable email domains"
|
||||
)
|
||||
return domains
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning(f"Failed to fetch disposable domains (HTTP error): {e}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch disposable domains: {e}")
|
||||
|
||||
# On error, return fallback domains
|
||||
return self._fallback_domains.copy()
|
||||
|
||||
def get_domains(self) -> Set[str]:
|
||||
"""
|
||||
Get the cached set of disposable email domains.
|
||||
Refreshes the cache if needed.
|
||||
|
||||
Returns:
|
||||
Set of disposable domain strings (lowercased)
|
||||
"""
|
||||
# Fast path: return cached domains if still fresh
|
||||
if self._domains and not self._should_refresh():
|
||||
return self._domains.copy()
|
||||
|
||||
# Slow path: need to refresh
|
||||
with self._fetch_lock:
|
||||
# Double-check after acquiring lock
|
||||
if self._domains and not self._should_refresh():
|
||||
return self._domains.copy()
|
||||
|
||||
self._domains = self._fetch_domains()
|
||||
self._last_fetch_time = time.time()
|
||||
return self._domains.copy()
|
||||
|
||||
def is_disposable(self, email: str) -> bool:
|
||||
"""
|
||||
Check if an email address uses a disposable domain.
|
||||
|
||||
Args:
|
||||
email: The email address to check
|
||||
|
||||
Returns:
|
||||
True if the email domain is disposable, False otherwise
|
||||
"""
|
||||
if not email or "@" not in email:
|
||||
return False
|
||||
|
||||
parts = email.split("@")
|
||||
if len(parts) != 2 or not parts[0]: # Must have user@domain with non-empty user
|
||||
return False
|
||||
|
||||
domain = parts[1].lower().strip()
|
||||
if not domain: # Domain part must not be empty
|
||||
return False
|
||||
|
||||
disposable_domains = self.get_domains()
|
||||
return domain in disposable_domains
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_validator = DisposableEmailValidator()
|
||||
|
||||
|
||||
def is_disposable_email(email: str) -> bool:
|
||||
"""
|
||||
Check if an email address uses a disposable/temporary domain.
|
||||
|
||||
This is a convenience function that uses the global validator instance.
|
||||
|
||||
Args:
|
||||
email: The email address to check
|
||||
|
||||
Returns:
|
||||
True if the email uses a disposable domain, False otherwise
|
||||
"""
|
||||
return _validator.is_disposable(email)
|
||||
|
||||
|
||||
def refresh_disposable_domains() -> None:
|
||||
"""
|
||||
Force a refresh of the disposable domains list.
|
||||
|
||||
This can be called manually if you want to update the list
|
||||
without waiting for the cache to expire.
|
||||
"""
|
||||
_validator._last_fetch_time = 0
|
||||
_validator.get_domains()
|
||||
@@ -40,6 +40,8 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
tenant_id: str | None = None
|
||||
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
|
||||
captcha_token: str | None = None
|
||||
|
||||
|
||||
class UserUpdateWithRole(schemas.BaseUserUpdate):
|
||||
|
||||
@@ -60,6 +60,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.disposable_email_validator import is_disposable_email
|
||||
from onyx.auth.email_utils import send_forgot_password_email
|
||||
from onyx.auth.email_utils import send_user_verification_email
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
@@ -117,7 +118,7 @@ from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.timing import log_function_time
|
||||
@@ -248,13 +249,23 @@ def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
|
||||
|
||||
|
||||
def verify_email_domain(email: str) -> None:
|
||||
if email.count("@") != 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email is not valid",
|
||||
)
|
||||
|
||||
domain = email.split("@")[-1].lower()
|
||||
|
||||
# Check if email uses a disposable/temporary domain
|
||||
if is_disposable_email(email):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Disposable email addresses are not allowed. Please use a permanent email address.",
|
||||
)
|
||||
|
||||
# Check domain whitelist if configured
|
||||
if VALID_EMAIL_DOMAINS:
|
||||
if email.count("@") != 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email is not valid",
|
||||
)
|
||||
domain = email.split("@")[-1].lower()
|
||||
if domain not in VALID_EMAIL_DOMAINS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -292,11 +303,57 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
# Verify captcha if enabled (for cloud signup protection)
|
||||
from onyx.auth.captcha import CaptchaVerificationError
|
||||
from onyx.auth.captcha import is_captcha_enabled
|
||||
from onyx.auth.captcha import verify_captcha_token
|
||||
|
||||
if is_captcha_enabled() and request is not None:
|
||||
# Get captcha token from request body or headers
|
||||
captcha_token = None
|
||||
if hasattr(user_create, "captcha_token"):
|
||||
captcha_token = getattr(user_create, "captcha_token", None)
|
||||
|
||||
# Also check headers as a fallback
|
||||
if not captcha_token:
|
||||
captcha_token = request.headers.get("X-Captcha-Token")
|
||||
|
||||
try:
|
||||
await verify_captcha_token(
|
||||
captcha_token or "", expected_action="signup"
|
||||
)
|
||||
except CaptchaVerificationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": str(e)},
|
||||
)
|
||||
|
||||
# We verify the password here to make sure it's valid before we proceed
|
||||
await self.validate_password(
|
||||
user_create.password, cast(schemas.UC, user_create)
|
||||
)
|
||||
|
||||
# Check for disposable emails BEFORE provisioning tenant
|
||||
# This prevents creating tenants for throwaway email addresses
|
||||
try:
|
||||
verify_email_domain(user_create.email)
|
||||
except HTTPException as e:
|
||||
# Log blocked disposable email attempts
|
||||
if (
|
||||
e.status_code == status.HTTP_400_BAD_REQUEST
|
||||
and "Disposable email" in str(e.detail)
|
||||
):
|
||||
domain = (
|
||||
user_create.email.split("@")[-1]
|
||||
if "@" in user_create.email
|
||||
else "unknown"
|
||||
)
|
||||
logger.warning(
|
||||
f"Blocked disposable email registration attempt: {domain}",
|
||||
extra={"email_domain": domain},
|
||||
)
|
||||
raise
|
||||
|
||||
user_count: int | None = None
|
||||
referral_source = (
|
||||
request.cookies.get("referral_source", None)
|
||||
@@ -318,8 +375,17 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
async with get_async_session_context_manager(tenant_id) as db_session:
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
# Check invite list based on deployment mode
|
||||
if MULTI_TENANT:
|
||||
# Multi-tenant: Only require invite for existing tenants
|
||||
# New tenant creation (first user) doesn't require an invite
|
||||
user_count = await get_user_count()
|
||||
if user_count > 0:
|
||||
# Tenant already has users - require invite for new users
|
||||
verify_email_is_invited(user_create.email)
|
||||
else:
|
||||
# Single-tenant: Check invite list (skips if SAML/OIDC or no list configured)
|
||||
verify_email_is_invited(user_create.email)
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
|
||||
db_session, User, OAuthAccount
|
||||
@@ -338,9 +404,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
user_created = False
|
||||
try:
|
||||
user = await super().create(
|
||||
user_create, safe=safe, request=request
|
||||
) # type: ignore
|
||||
user = await super().create(user_create, safe=safe, request=request)
|
||||
user_created = True
|
||||
except IntegrityError as error:
|
||||
# Race condition: another request created the same user after the
|
||||
@@ -604,10 +668,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
if (
|
||||
user.oidc_expiry is not None # type: ignore
|
||||
and not TRACK_EXTERNAL_IDP_EXPIRY
|
||||
):
|
||||
if user.oidc_expiry is not None and not TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
await self.user_db.update(user, {"oidc_expiry": None})
|
||||
user.oidc_expiry = None # type: ignore
|
||||
remove_user_from_invited_users(user.email)
|
||||
@@ -653,19 +714,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_count = await get_user_count()
|
||||
logger.debug(f"Current tenant user count: {user_count}")
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
event_type = (
|
||||
MilestoneRecordType.USER_SIGNED_UP
|
||||
if user_count == 1
|
||||
else MilestoneRecordType.MULTIPLE_USERS
|
||||
)
|
||||
create_milestone_and_report(
|
||||
user=user,
|
||||
distinct_id=user.email,
|
||||
event_type=event_type,
|
||||
properties=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=user.email,
|
||||
event=MilestoneRecordType.USER_SIGNED_UP,
|
||||
)
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
@@ -1186,7 +1239,7 @@ async def _sync_jwt_oidc_expiry(
|
||||
return
|
||||
|
||||
await user_manager.user_db.update(user, {"oidc_expiry": oidc_expiry})
|
||||
user.oidc_expiry = oidc_expiry # type: ignore
|
||||
user.oidc_expiry = oidc_expiry
|
||||
return
|
||||
|
||||
if user.oidc_expiry is not None:
|
||||
|
||||
@@ -26,6 +26,7 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.background.celery.celery_utils import make_probe_path
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_FOR_ONYX
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
@@ -515,6 +516,9 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Raises WorkerShutdown if the timeout is reached."""
|
||||
|
||||
if ENABLE_OPENSEARCH_FOR_ONYX:
|
||||
return
|
||||
|
||||
if not wait_for_vespa_with_timeout():
|
||||
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
|
||||
logger.error(msg)
|
||||
|
||||
@@ -98,8 +98,5 @@ for bootstep in base_bootsteps:
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
# Ensure the user files indexing worker registers the doc_id migration task
|
||||
# TODO(subash): remove this once the doc_id migration is complete
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -2,8 +2,12 @@ import copy
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from celery.schedules import crontab
|
||||
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -53,16 +57,6 @@ beat_task_templates: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "user-file-docid-migration",
|
||||
"task": OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
|
||||
"schedule": timedelta(minutes=10),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.USER_FILES_INDEXING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-kg-processing",
|
||||
"task": OnyxCeleryTask.CHECK_KG_PROCESSING,
|
||||
@@ -171,13 +165,32 @@ if ENTERPRISE_EDITION_ENABLED:
|
||||
]
|
||||
)
|
||||
|
||||
# Only add the LLM model update task if the API URL is configured
|
||||
if LLM_MODEL_UPDATE_API_URL:
|
||||
# Add the Auto LLM update task if the config URL is set (has a default)
|
||||
if AUTO_LLM_CONFIG_URL:
|
||||
beat_task_templates.append(
|
||||
{
|
||||
"name": "check-for-llm-model-update",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
"schedule": timedelta(hours=1), # Check every hour
|
||||
"name": "check-for-auto-llm-update",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_AUTO_LLM_UPDATE,
|
||||
"schedule": timedelta(seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": AUTO_LLM_UPDATE_INTERVAL_SECONDS,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Add scheduled eval task if datasets are configured
|
||||
if SCHEDULED_EVAL_DATASET_NAMES:
|
||||
beat_task_templates.append(
|
||||
{
|
||||
"name": "scheduled-eval-pipeline",
|
||||
"task": OnyxCeleryTask.SCHEDULED_EVAL_TASK,
|
||||
# run every Sunday at midnight UTC
|
||||
"schedule": crontab(
|
||||
hour=0,
|
||||
minute=0,
|
||||
day_of_week=0,
|
||||
),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import SearchSettings
|
||||
|
||||
|
||||
def try_creating_docfetching_task(
|
||||
celery_app: Celery,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
search_settings: SearchSettings,
|
||||
reindex: bool,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the indexing task from being
|
||||
created, then creates the task.
|
||||
|
||||
Does not check for scheduling related conditions as this function
|
||||
is used to trigger indexing immediately.
|
||||
|
||||
Now uses database-based coordination instead of Redis fencing.
|
||||
"""
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
# we need to serialize any attempt to trigger indexing since it can be triggered
|
||||
# either via celery beat or manually (API call)
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
index_attempt_id = None
|
||||
try:
|
||||
# Basic status checks
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# Generate custom task ID for tracking
|
||||
custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}"
|
||||
|
||||
# Try to create a new index attempt using database coordination
|
||||
# This replaces the Redis fencing mechanism
|
||||
index_attempt_id = IndexingCoordination.try_create_index_attempt(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
celery_task_id=custom_task_id,
|
||||
from_beginning=reindex,
|
||||
)
|
||||
|
||||
if index_attempt_id is None:
|
||||
# Another indexing attempt is already running
|
||||
return None
|
||||
|
||||
# Use higher priority for first-time indexing to ensure new connectors
|
||||
# get processed before re-indexing of existing connectors
|
||||
has_successful_attempt = cc_pair.last_successful_index_time is not None
|
||||
priority = (
|
||||
OnyxCeleryPriority.MEDIUM
|
||||
if has_successful_attempt
|
||||
else OnyxCeleryPriority.HIGH
|
||||
)
|
||||
|
||||
# Send the task to Celery
|
||||
result = celery_app.send_task(
|
||||
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
task_id=custom_task_id,
|
||||
priority=priority,
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("send_task for connector_doc_fetching_task failed.")
|
||||
|
||||
task_logger.info(
|
||||
f"Created docfetching task: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id} "
|
||||
f"attempt_id={index_attempt_id} "
|
||||
f"celery_task_id={custom_task_id}"
|
||||
)
|
||||
|
||||
return index_attempt_id
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"try_creating_indexing_task - Unexpected exception: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
# Clean up on failure
|
||||
if index_attempt_id is not None:
|
||||
mark_attempt_failed(index_attempt_id, db_session)
|
||||
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return index_attempt_id
|
||||
@@ -12,6 +12,7 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
@@ -25,14 +26,14 @@ from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.background.celery.tasks.docfetching.task_creation_utils import (
|
||||
try_creating_docfetching_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
|
||||
from onyx.background.celery.tasks.docprocessing.heartbeat import stop_heartbeat
|
||||
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallback
|
||||
from onyx.background.celery.tasks.docprocessing.utils import is_in_repeated_error_state
|
||||
from onyx.background.celery.tasks.docprocessing.utils import should_index
|
||||
from onyx.background.celery.tasks.docprocessing.utils import (
|
||||
try_creating_docfetching_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.models import DocProcessingContext
|
||||
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
|
||||
from onyx.background.indexing.checkpointing_utils import (
|
||||
@@ -40,11 +41,14 @@ from onyx.background.indexing.checkpointing_utils import (
|
||||
)
|
||||
from onyx.background.indexing.index_attempt_utils import cleanup_index_attempts
|
||||
from onyx.background.indexing.index_attempt_utils import get_old_index_attempts
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -58,11 +62,9 @@ from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import (
|
||||
fetch_indexable_standard_connector_credential_pair_ids,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import (
|
||||
fetch_indexable_user_file_connector_credential_pair_ids,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import set_cc_pair_repeated_error_state
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair_from_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -95,9 +97,6 @@ from onyx.indexing.adapters.document_indexing_adapter import (
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -108,11 +107,13 @@ from onyx.redis.redis_utils import is_fence
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import USAGE_LIMITS_ENABLED
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
|
||||
|
||||
@@ -539,14 +540,15 @@ def check_indexing_completion(
|
||||
]:
|
||||
# User file connectors must be paused on success
|
||||
# NOTE: _run_indexing doesn't update connectors if the index attempt is the future embedding model
|
||||
# TODO: figure out why this doesn't pause connectors during swap
|
||||
cc_pair.status = (
|
||||
ConnectorCredentialPairStatus.PAUSED
|
||||
if cc_pair.is_user_file
|
||||
else ConnectorCredentialPairStatus.ACTIVE
|
||||
)
|
||||
cc_pair.status = ConnectorCredentialPairStatus.ACTIVE
|
||||
db_session.commit()
|
||||
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=tenant_id,
|
||||
event=MilestoneRecordType.CONNECTOR_SUCCEEDED,
|
||||
)
|
||||
|
||||
# Clear repeated error state on success
|
||||
if cc_pair.in_repeated_error_state:
|
||||
cc_pair.in_repeated_error_state = False
|
||||
@@ -804,13 +806,8 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
db_session, active_cc_pairs_only=True
|
||||
)
|
||||
)
|
||||
user_file_cc_pair_ids = (
|
||||
fetch_indexable_user_file_connector_credential_pair_ids(
|
||||
db_session, search_settings_id=current_search_settings.id
|
||||
)
|
||||
)
|
||||
|
||||
primary_cc_pair_ids = standard_cc_pair_ids + user_file_cc_pair_ids
|
||||
primary_cc_pair_ids = standard_cc_pair_ids
|
||||
|
||||
# Get CC pairs for secondary search settings
|
||||
secondary_cc_pair_ids: list[int] = []
|
||||
@@ -826,30 +823,47 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
db_session, active_cc_pairs_only=not include_paused
|
||||
)
|
||||
)
|
||||
user_file_cc_pair_ids = (
|
||||
fetch_indexable_user_file_connector_credential_pair_ids(
|
||||
db_session, search_settings_id=secondary_search_settings.id
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
secondary_cc_pair_ids = standard_cc_pair_ids + user_file_cc_pair_ids
|
||||
secondary_cc_pair_ids = standard_cc_pair_ids
|
||||
|
||||
# Flag CC pairs in repeated error state for primary/current search settings
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for cc_pair_id in primary_cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
|
||||
if is_in_repeated_error_state(
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=current_search_settings.id,
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# if already in repeated error state, don't do anything
|
||||
# this is important so that we don't keep pausing the connector
|
||||
# immediately upon a user un-pausing it to manually re-trigger and
|
||||
# recover.
|
||||
if (
|
||||
cc_pair
|
||||
and not cc_pair.in_repeated_error_state
|
||||
and is_in_repeated_error_state(
|
||||
cc_pair=cc_pair,
|
||||
search_settings_id=current_search_settings.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
):
|
||||
set_cc_pair_repeated_error_state(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
in_repeated_error_state=True,
|
||||
)
|
||||
# When entering repeated error state, also pause the connector
|
||||
# to prevent continued indexing retry attempts burning through embedding credits.
|
||||
# NOTE: only for Cloud, since most self-hosted users use self-hosted embedding
|
||||
# models. Also, they are more prone to repeated failures -> eventual success.
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
update_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=ConnectorCredentialPairStatus.PAUSED,
|
||||
)
|
||||
|
||||
# NOTE: At this point, we haven't done heavy checks on whether or not the CC pairs should actually be indexed
|
||||
# Heavy check, should_index(), is called in _kickoff_indexing_tasks
|
||||
@@ -1274,6 +1288,26 @@ def docprocessing_task(
|
||||
INDEX_ATTEMPT_INFO_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def _check_chunk_usage_limit(tenant_id: str) -> None:
|
||||
"""Check if chunk indexing usage limit has been exceeded.
|
||||
|
||||
Raises UsageLimitExceededError if the limit is exceeded.
|
||||
"""
|
||||
if not USAGE_LIMITS_ENABLED:
|
||||
return
|
||||
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.server.usage_limits import check_usage_and_raise
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
check_usage_and_raise(
|
||||
db_session=db_session,
|
||||
usage_type=UsageType.CHUNKS_INDEXED,
|
||||
tenant_id=tenant_id,
|
||||
pending_amount=0, # Just check current usage
|
||||
)
|
||||
|
||||
|
||||
def _docprocessing_task(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
@@ -1285,6 +1319,25 @@ def _docprocessing_task(
|
||||
if tenant_id:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Check if chunk indexing usage limit has been exceeded before processing
|
||||
if USAGE_LIMITS_ENABLED:
|
||||
try:
|
||||
_check_chunk_usage_limit(tenant_id)
|
||||
except HTTPException as e:
|
||||
# Log the error and fail the indexing attempt
|
||||
task_logger.error(
|
||||
f"Chunk indexing usage limit exceeded for tenant {tenant_id}: {e}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
|
||||
mark_attempt_failed(
|
||||
index_attempt_id=index_attempt_id,
|
||||
db_session=db_session,
|
||||
failure_reason=str(e),
|
||||
)
|
||||
raise
|
||||
|
||||
task_logger.info(
|
||||
f"Processing document batch: "
|
||||
f"attempt={index_attempt_id} "
|
||||
@@ -1383,10 +1436,6 @@ def _docprocessing_task(
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
information_content_classification_model = (
|
||||
InformationContentClassificationModel()
|
||||
)
|
||||
|
||||
document_index = get_default_document_index(
|
||||
index_attempt.search_settings,
|
||||
None,
|
||||
@@ -1404,8 +1453,13 @@ def _docprocessing_task(
|
||||
)
|
||||
|
||||
# Process documents through indexing pipeline
|
||||
connector_source = (
|
||||
index_attempt.connector_credential_pair.connector.source.value
|
||||
)
|
||||
task_logger.info(
|
||||
f"Processing {len(documents)} documents through indexing pipeline"
|
||||
f"Processing {len(documents)} documents through indexing pipeline: "
|
||||
f"cc_pair_id={cc_pair_id}, source={connector_source}, "
|
||||
f"batch_num={batch_num}"
|
||||
)
|
||||
|
||||
adapter = DocumentIndexingBatchAdapter(
|
||||
@@ -1419,7 +1473,6 @@ def _docprocessing_task(
|
||||
# real work happens here!
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=True, # Documents are already filtered during extraction
|
||||
db_session=db_session,
|
||||
@@ -1429,6 +1482,23 @@ def _docprocessing_task(
|
||||
adapter=adapter,
|
||||
)
|
||||
|
||||
# Track chunk indexing usage for cloud usage limits
|
||||
if USAGE_LIMITS_ENABLED and index_pipeline_result.total_chunks > 0:
|
||||
try:
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
|
||||
with get_session_with_current_tenant() as usage_db_session:
|
||||
increment_usage(
|
||||
db_session=usage_db_session,
|
||||
usage_type=UsageType.CHUNKS_INDEXED,
|
||||
amount=index_pipeline_result.total_chunks,
|
||||
)
|
||||
usage_db_session.commit()
|
||||
except Exception as e:
|
||||
# Log but don't fail indexing if usage tracking fails
|
||||
task_logger.warning(f"Failed to track chunk indexing usage: {e}")
|
||||
|
||||
# Update batch completion and document counts atomically using database coordination
|
||||
|
||||
with get_session_with_current_tenant() as db_session, cross_batch_db_lock:
|
||||
@@ -1495,6 +1565,8 @@ def _docprocessing_task(
|
||||
|
||||
# FIX: Explicitly clear document batch from memory and force garbage collection
|
||||
# This helps prevent memory accumulation across multiple batches
|
||||
# NOTE: Thread-local event loops in embedding threads are cleaned up automatically
|
||||
# via the _cleanup_thread_local decorator in search_nlp_models.py
|
||||
del documents
|
||||
gc.collect()
|
||||
|
||||
|
||||
@@ -1,31 +1,21 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
from redis.exceptions import LockError
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -135,18 +125,9 @@ class IndexingCallback(IndexingHeartbeatInterface):
|
||||
|
||||
|
||||
def is_in_repeated_error_state(
|
||||
cc_pair_id: int, search_settings_id: int, db_session: Session
|
||||
cc_pair: ConnectorCredentialPair, search_settings_id: int, db_session: Session
|
||||
) -> bool:
|
||||
"""Checks if the cc pair / search setting combination is in a repeated error state."""
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise RuntimeError(
|
||||
f"is_in_repeated_error_state - could not find cc_pair with id={cc_pair_id}"
|
||||
)
|
||||
|
||||
# if the connector doesn't have a refresh_freq, a single failed attempt is enough
|
||||
number_of_failed_attempts_in_a_row_needed = (
|
||||
NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE
|
||||
@@ -155,7 +136,7 @@ def is_in_repeated_error_state(
|
||||
)
|
||||
|
||||
most_recent_index_attempts = get_recent_attempts_for_cc_pair(
|
||||
cc_pair_id=cc_pair_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings_id,
|
||||
limit=number_of_failed_attempts_in_a_row_needed,
|
||||
db_session=db_session,
|
||||
@@ -189,7 +170,7 @@ def should_index(
|
||||
db_session=db_session,
|
||||
)
|
||||
all_recent_errored = is_in_repeated_error_state(
|
||||
cc_pair_id=cc_pair.id,
|
||||
cc_pair=cc_pair,
|
||||
search_settings_id=search_settings_instance.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
@@ -298,112 +279,3 @@ def should_index(
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_docfetching_task(
|
||||
celery_app: Celery,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
search_settings: SearchSettings,
|
||||
reindex: bool,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the indexing task from being
|
||||
created, then creates the task.
|
||||
|
||||
Does not check for scheduling related conditions as this function
|
||||
is used to trigger indexing immediately.
|
||||
|
||||
Now uses database-based coordination instead of Redis fencing.
|
||||
"""
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
# we need to serialize any attempt to trigger indexing since it can be triggered
|
||||
# either via celery beat or manually (API call)
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
index_attempt_id = None
|
||||
try:
|
||||
# Basic status checks
|
||||
db_session.refresh(cc_pair)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
# Generate custom task ID for tracking
|
||||
custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}"
|
||||
|
||||
# Try to create a new index attempt using database coordination
|
||||
# This replaces the Redis fencing mechanism
|
||||
index_attempt_id = IndexingCoordination.try_create_index_attempt(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
celery_task_id=custom_task_id,
|
||||
from_beginning=reindex,
|
||||
)
|
||||
|
||||
if index_attempt_id is None:
|
||||
# Another indexing attempt is already running
|
||||
return None
|
||||
|
||||
# Determine which queue to use based on whether this is a user file
|
||||
# TODO: at the moment the indexing pipeline is
|
||||
# shared between user files and connectors
|
||||
queue = (
|
||||
OnyxCeleryQueues.USER_FILES_INDEXING
|
||||
if cc_pair.is_user_file
|
||||
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
|
||||
)
|
||||
|
||||
# Send the task to Celery
|
||||
result = celery_app.send_task(
|
||||
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings.id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=queue,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("send_task for connector_doc_fetching_task failed.")
|
||||
|
||||
task_logger.info(
|
||||
f"Created docfetching task: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id} "
|
||||
f"attempt_id={index_attempt_id} "
|
||||
f"celery_task_id={custom_task_id}"
|
||||
)
|
||||
|
||||
return index_attempt_id
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"try_creating_indexing_task - Unexpected exception: "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
# Clean up on failure
|
||||
if index_attempt_id is not None:
|
||||
mark_attempt_failed(index_attempt_id, db_session)
|
||||
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return index_attempt_id
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from onyx.configs.app_configs import BRAINTRUST_API_KEY
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.evals.eval import run_eval
|
||||
from onyx.evals.models import EvalConfigurationOptions
|
||||
@@ -33,3 +39,109 @@ def eval_run_task(
|
||||
except Exception:
|
||||
logger.error("Failed to run eval task")
|
||||
raise
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.SCHEDULED_EVAL_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT * 5, # Allow more time for multiple datasets
|
||||
bind=True,
|
||||
trail=False,
|
||||
)
|
||||
def scheduled_eval_task(self: Task, **kwargs: Any) -> None:
|
||||
"""
|
||||
Scheduled task to run evaluations on configured datasets.
|
||||
Runs weekly on Sunday at midnight UTC.
|
||||
|
||||
Configure via environment variables (with defaults):
|
||||
- SCHEDULED_EVAL_DATASET_NAMES: Comma-separated list of Braintrust dataset names
|
||||
- SCHEDULED_EVAL_PERMISSIONS_EMAIL: Email for search permissions (default: roshan@onyx.app)
|
||||
- SCHEDULED_EVAL_PROJECT: Braintrust project name
|
||||
"""
|
||||
if not BRAINTRUST_API_KEY:
|
||||
logger.error("BRAINTRUST_API_KEY is not configured, cannot run scheduled evals")
|
||||
return
|
||||
|
||||
if not SCHEDULED_EVAL_PROJECT:
|
||||
logger.error(
|
||||
"SCHEDULED_EVAL_PROJECT is not configured, cannot run scheduled evals"
|
||||
)
|
||||
return
|
||||
|
||||
if not SCHEDULED_EVAL_DATASET_NAMES:
|
||||
logger.info("No scheduled eval datasets configured, skipping")
|
||||
return
|
||||
|
||||
if not SCHEDULED_EVAL_PERMISSIONS_EMAIL:
|
||||
logger.error("SCHEDULED_EVAL_PERMISSIONS_EMAIL not configured")
|
||||
return
|
||||
|
||||
project_name = SCHEDULED_EVAL_PROJECT
|
||||
dataset_names = SCHEDULED_EVAL_DATASET_NAMES
|
||||
permissions_email = SCHEDULED_EVAL_PERMISSIONS_EMAIL
|
||||
|
||||
# Create a timestamp for the scheduled run
|
||||
run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
|
||||
logger.info(
|
||||
f"Starting scheduled eval pipeline for project '{project_name}' "
|
||||
f"with {len(dataset_names)} dataset(s): {dataset_names}"
|
||||
)
|
||||
|
||||
pipeline_start = datetime.now(timezone.utc)
|
||||
results: list[dict[str, Any]] = []
|
||||
|
||||
for dataset_name in dataset_names:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
error_message: str | None = None
|
||||
success = False
|
||||
|
||||
# Create informative experiment name for scheduled runs
|
||||
experiment_name = f"{dataset_name} - {run_timestamp}"
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"Running scheduled eval for dataset: {dataset_name} "
|
||||
f"(project: {project_name})"
|
||||
)
|
||||
|
||||
configuration = EvalConfigurationOptions(
|
||||
search_permissions_email=permissions_email,
|
||||
dataset_name=dataset_name,
|
||||
no_send_logs=False,
|
||||
braintrust_project=project_name,
|
||||
experiment_name=experiment_name,
|
||||
)
|
||||
|
||||
result = run_eval(
|
||||
configuration=configuration,
|
||||
remote_dataset_name=dataset_name,
|
||||
)
|
||||
success = result.success
|
||||
logger.info(f"Completed eval for {dataset_name}: success={success}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run scheduled eval for {dataset_name}")
|
||||
error_message = str(e)
|
||||
success = False
|
||||
|
||||
end_time = datetime.now(timezone.utc)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"dataset_name": dataset_name,
|
||||
"success": success,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"error_message": error_message,
|
||||
}
|
||||
)
|
||||
|
||||
pipeline_end = datetime.now(timezone.utc)
|
||||
total_duration = (pipeline_end - pipeline_start).total_seconds()
|
||||
|
||||
passed_count = sum(1 for r in results if r["success"])
|
||||
logger.info(
|
||||
f"Scheduled eval pipeline completed: {passed_count}/{len(results)} passed "
|
||||
f"in {total_duration:.1f}s"
|
||||
)
|
||||
|
||||
@@ -1,141 +1,57 @@
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import LLMProvider
|
||||
from onyx.db.models import ModelConfiguration
|
||||
|
||||
|
||||
def _process_model_list_response(model_list_json: Any) -> list[str]:
|
||||
# Handle case where response is wrapped in a "data" field
|
||||
if isinstance(model_list_json, dict):
|
||||
if "data" in model_list_json:
|
||||
model_list_json = model_list_json["data"]
|
||||
elif "models" in model_list_json:
|
||||
model_list_json = model_list_json["models"]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid response from API - expected dict with 'data' or "
|
||||
f"'models' field, got {type(model_list_json)}"
|
||||
)
|
||||
|
||||
if not isinstance(model_list_json, list):
|
||||
raise ValueError(
|
||||
f"Invalid response from API - expected list, got {type(model_list_json)}"
|
||||
)
|
||||
|
||||
# Handle both string list and object list cases
|
||||
model_names: list[str] = []
|
||||
for item in model_list_json:
|
||||
if isinstance(item, str):
|
||||
model_names.append(item)
|
||||
elif isinstance(item, dict):
|
||||
if "model_name" in item:
|
||||
model_names.append(item["model_name"])
|
||||
elif "id" in item:
|
||||
model_names.append(item["id"])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid item in model list - expected dict with model_name or id, got {type(item)}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid item in model list - expected string or dict, got {type(item)}"
|
||||
)
|
||||
|
||||
return model_names
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
name=OnyxCeleryTask.CHECK_FOR_AUTO_LLM_UPDATE,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
soft_time_limit=300, # 5 minute timeout
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
|
||||
if not LLM_MODEL_UPDATE_API_URL:
|
||||
raise ValueError("LLM model update API URL not configured")
|
||||
def check_for_auto_llm_updates(self: Task, *, tenant_id: str) -> bool | None:
|
||||
"""Periodic task to fetch LLM model updates from GitHub
|
||||
and sync them to providers in Auto mode.
|
||||
|
||||
# First fetch the models from the API
|
||||
try:
|
||||
response = requests.get(LLM_MODEL_UPDATE_API_URL)
|
||||
response.raise_for_status()
|
||||
available_models = _process_model_list_response(response.json())
|
||||
task_logger.info(f"Found available models: {available_models}")
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Failed to fetch models from API.")
|
||||
This task checks the GitHub-hosted config file and updates all
|
||||
providers that have is_auto_mode=True.
|
||||
"""
|
||||
if not AUTO_LLM_CONFIG_URL:
|
||||
task_logger.debug("AUTO_LLM_CONFIG_URL not configured, skipping")
|
||||
return None
|
||||
|
||||
# Then update the database with the fetched models
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Get the default LLM provider
|
||||
default_provider = (
|
||||
db_session.query(LLMProvider)
|
||||
.filter(LLMProvider.is_default_provider.is_(True))
|
||||
.first()
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
fetch_llm_recommendations_from_github,
|
||||
)
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
sync_llm_models_from_github,
|
||||
)
|
||||
|
||||
if not default_provider:
|
||||
task_logger.warning("No default LLM provider found")
|
||||
# Fetch config from GitHub
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
|
||||
if not config:
|
||||
task_logger.warning("Failed to fetch GitHub config")
|
||||
return None
|
||||
|
||||
# log change if any
|
||||
old_models = set(
|
||||
model_configuration.name
|
||||
for model_configuration in default_provider.model_configurations
|
||||
)
|
||||
new_models = set(available_models)
|
||||
added_models = new_models - old_models
|
||||
removed_models = old_models - new_models
|
||||
# Sync to database
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
results = sync_llm_models_from_github(db_session, config)
|
||||
|
||||
if added_models:
|
||||
task_logger.info(f"Adding models: {sorted(added_models)}")
|
||||
if removed_models:
|
||||
task_logger.info(f"Removing models: {sorted(removed_models)}")
|
||||
if results:
|
||||
task_logger.info(f"Auto mode sync results: {results}")
|
||||
else:
|
||||
task_logger.debug("No model updates applied")
|
||||
|
||||
# Update the provider's model list
|
||||
# Remove models that are no longer available
|
||||
db_session.query(ModelConfiguration).filter(
|
||||
ModelConfiguration.llm_provider_id == default_provider.id,
|
||||
ModelConfiguration.name.notin_(available_models),
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
# Add new models
|
||||
for available_model_name in available_models:
|
||||
db_session.merge(
|
||||
ModelConfiguration(
|
||||
llm_provider_id=default_provider.id,
|
||||
name=available_model_name,
|
||||
is_visible=False,
|
||||
max_input_tokens=None,
|
||||
)
|
||||
)
|
||||
|
||||
# if the default model is no longer available, set it to the first model in the list
|
||||
if default_provider.default_model_name not in available_models:
|
||||
task_logger.info(
|
||||
f"Default model {default_provider.default_model_name} not "
|
||||
f"available, setting to first model in list: {available_models[0]}"
|
||||
)
|
||||
default_provider.default_model_name = available_models[0]
|
||||
if default_provider.fast_default_model_name not in available_models:
|
||||
task_logger.info(
|
||||
f"Fast default model {default_provider.fast_default_model_name} "
|
||||
f"not available, setting to first model in list: {available_models[0]}"
|
||||
)
|
||||
default_provider.fast_default_model_name = available_models[0]
|
||||
db_session.commit()
|
||||
|
||||
if added_models or removed_models:
|
||||
task_logger.info("Updated model list for default provider.")
|
||||
except Exception:
|
||||
task_logger.exception("Error in auto LLM update task")
|
||||
raise
|
||||
|
||||
return True
|
||||
|
||||
@@ -886,9 +886,7 @@ def monitor_celery_queues_helper(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
|
||||
)
|
||||
n_docprocessing = celery_get_queue_length(OnyxCeleryQueues.DOCPROCESSING, r_celery)
|
||||
n_user_files_indexing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILES_INDEXING, r_celery
|
||||
)
|
||||
|
||||
n_user_file_processing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
@@ -924,7 +922,6 @@ def monitor_celery_queues_helper(
|
||||
f"docfetching_prefetched={len(n_docfetching_prefetched)} "
|
||||
f"docprocessing={n_docprocessing} "
|
||||
f"docprocessing_prefetched={len(n_docprocessing_prefetched)} "
|
||||
f"user_files_indexing={n_user_files_indexing} "
|
||||
f"user_file_processing={n_user_file_processing} "
|
||||
f"user_file_project_sync={n_user_file_project_sync} "
|
||||
f"user_file_delete={n_user_file_delete} "
|
||||
|
||||
@@ -55,8 +55,8 @@ class RetryDocumentIndex:
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields | None,
|
||||
user_fields: VespaDocumentUserFields | None,
|
||||
) -> int:
|
||||
return self.index.update_single(
|
||||
) -> None:
|
||||
self.index.update_single(
|
||||
doc_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
|
||||
@@ -95,7 +95,6 @@ def document_by_cc_pair_cleanup_task(
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
action = "skip"
|
||||
chunks_affected = 0
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
@@ -114,7 +113,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
|
||||
chunk_count = fetch_chunk_count_for_document(document_id, db_session)
|
||||
|
||||
chunks_affected = retry_index.delete_single(
|
||||
_ = retry_index.delete_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
@@ -157,7 +156,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(
|
||||
retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
@@ -187,7 +186,6 @@ def document_by_cc_pair_cleanup_task(
|
||||
f"doc={document_id} "
|
||||
f"action={action} "
|
||||
f"refcount={count} "
|
||||
f"chunks={chunks_affected} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import datetime
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
@@ -19,11 +18,9 @@ from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -32,28 +29,18 @@ from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import FileRecord
|
||||
from onyx.db.models import SearchDoc
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
)
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.file_store import S3BackedFileStore
|
||||
from onyx.file_store.utils import user_file_id_to_plaintext_file_name
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
@@ -257,10 +244,6 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
search_settings=current_search_settings,
|
||||
)
|
||||
|
||||
information_content_classification_model = (
|
||||
InformationContentClassificationModel()
|
||||
)
|
||||
|
||||
document_index = get_default_document_index(
|
||||
current_search_settings,
|
||||
None,
|
||||
@@ -275,7 +258,6 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
# real work happens here!
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=True,
|
||||
db_session=db_session,
|
||||
@@ -597,7 +579,7 @@ def process_single_user_file_project_sync(
|
||||
return None
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
chunks_affected = retry_index.update_single(
|
||||
retry_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
@@ -606,7 +588,7 @@ def process_single_user_file_project_sync(
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Chunks affected id={user_file_id} chunks={chunks_affected}"
|
||||
f"process_single_user_file_project_sync - User file id={user_file_id}"
|
||||
)
|
||||
|
||||
user_file.needs_project_sync = False
|
||||
@@ -626,312 +608,3 @@ def process_single_user_file_project_sync(
|
||||
file_lock.release()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_legacy_user_file_doc_id(old_id: str) -> str:
|
||||
# Convert USER_FILE_CONNECTOR__<uuid> -> FILE_CONNECTOR__<uuid> for legacy values
|
||||
user_prefix = "USER_FILE_CONNECTOR__"
|
||||
file_prefix = "FILE_CONNECTOR__"
|
||||
if old_id.startswith(user_prefix):
|
||||
remainder = old_id[len(user_prefix) :]
|
||||
return file_prefix + remainder
|
||||
return old_id
|
||||
|
||||
|
||||
def update_legacy_plaintext_file_records() -> None:
|
||||
"""Migrate legacy plaintext cache objects from int-based keys to UUID-based
|
||||
keys. Copies each S3 object to its expected UUID key and updates DB.
|
||||
|
||||
Examples:
|
||||
- Old key: bucket/schema/plaintext_<int>
|
||||
- New key: bucket/schema/plaintext_<uuid>
|
||||
"""
|
||||
|
||||
task_logger.info("update_legacy_plaintext_file_records - Starting")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
store = get_default_file_store()
|
||||
|
||||
if not isinstance(store, S3BackedFileStore):
|
||||
task_logger.info(
|
||||
"update_legacy_plaintext_file_records - Skipping non-S3 store"
|
||||
)
|
||||
return
|
||||
|
||||
s3_client = store._get_s3_client()
|
||||
bucket_name = store._get_bucket_name()
|
||||
|
||||
# Select PLAINTEXT_CACHE records whose object_key ends with 'plaintext_' + non-hyphen chars
|
||||
# Example: 'some/path/plaintext_abc123' matches; '.../plaintext_foo-bar' does not
|
||||
plaintext_records: Sequence[FileRecord] = (
|
||||
db_session.execute(
|
||||
sa.select(FileRecord).where(
|
||||
FileRecord.file_origin == FileOrigin.PLAINTEXT_CACHE,
|
||||
FileRecord.object_key.op("~")(r"plaintext_[^-]+$"),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"update_legacy_plaintext_file_records - Found {len(plaintext_records)} plaintext records to update"
|
||||
)
|
||||
|
||||
normalized = 0
|
||||
for fr in plaintext_records:
|
||||
try:
|
||||
expected_key = store._get_s3_key(fr.file_id)
|
||||
if fr.object_key == expected_key:
|
||||
continue
|
||||
|
||||
if fr.bucket_name is None:
|
||||
task_logger.warning(f"id={fr.file_id} - Bucket name is None")
|
||||
continue
|
||||
|
||||
if fr.object_key is None:
|
||||
task_logger.warning(f"id={fr.file_id} - Object key is None")
|
||||
continue
|
||||
|
||||
# Copy old object to new key
|
||||
copy_source = f"{fr.bucket_name}/{fr.object_key}"
|
||||
s3_client.copy_object(
|
||||
CopySource=copy_source,
|
||||
Bucket=bucket_name,
|
||||
Key=expected_key,
|
||||
MetadataDirective="COPY",
|
||||
)
|
||||
|
||||
# Delete old object (best-effort)
|
||||
try:
|
||||
s3_client.delete_object(Bucket=fr.bucket_name, Key=fr.object_key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Update DB record with new key
|
||||
fr.object_key = expected_key
|
||||
db_session.add(fr)
|
||||
normalized += 1
|
||||
except Exception as e:
|
||||
task_logger.warning(f"id={fr.file_id} - {e.__class__.__name__}")
|
||||
|
||||
if normalized:
|
||||
db_session.commit()
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task normalized {normalized} plaintext objects"
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
|
||||
ignore_result=True,
|
||||
bind=True,
|
||||
)
|
||||
def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Starting for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
lock: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.USER_FILE_DOCID_MIGRATION_LOCK,
|
||||
timeout=CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Lock held, skipping tenant={tenant_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
updated_count = 0
|
||||
try:
|
||||
update_legacy_plaintext_file_records()
|
||||
# Track lock renewal
|
||||
last_lock_time = time.monotonic()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_settings = get_active_search_settings(db_session)
|
||||
document_index = get_default_document_index(
|
||||
search_settings=active_settings.primary,
|
||||
secondary_search_settings=active_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(document_index)
|
||||
|
||||
# Select user files with a legacy doc id that have not been migrated
|
||||
user_files = (
|
||||
db_session.execute(
|
||||
sa.select(UserFile).where(
|
||||
sa.and_(
|
||||
UserFile.document_id.is_not(None),
|
||||
UserFile.document_id_migrated.is_(False),
|
||||
)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Found {len(user_files)} user files to migrate"
|
||||
)
|
||||
|
||||
# Query all SearchDocs that need updating
|
||||
search_docs = (
|
||||
db_session.execute(
|
||||
sa.select(SearchDoc).where(
|
||||
SearchDoc.document_id.like("%FILE_CONNECTOR__%")
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Found {len(search_docs)} search docs to update"
|
||||
)
|
||||
|
||||
# Build a map of normalized doc IDs to SearchDocs
|
||||
search_doc_map: dict[str, list[SearchDoc]] = {}
|
||||
for sd in search_docs:
|
||||
doc_id = sd.document_id
|
||||
if search_doc_map.get(doc_id) is None:
|
||||
search_doc_map[doc_id] = []
|
||||
search_doc_map[doc_id].append(sd)
|
||||
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - Built search doc map with {len(search_doc_map)} entries"
|
||||
)
|
||||
|
||||
ids_preview = list(search_doc_map.keys())[:5]
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - First few search_doc_map ids: {ids_preview if ids_preview else 'No ids found'}"
|
||||
)
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - search_doc_map total items: "
|
||||
f"{sum(len(docs) for docs in search_doc_map.values())}"
|
||||
)
|
||||
for user_file in user_files:
|
||||
# Periodically renew the Redis lock to prevent expiry mid-run
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT / 4
|
||||
):
|
||||
renewed = False
|
||||
try:
|
||||
# extend lock ttl to full timeout window
|
||||
lock.extend(CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT)
|
||||
renewed = True
|
||||
except Exception:
|
||||
# if extend fails, best-effort reacquire as a fallback
|
||||
try:
|
||||
lock.reacquire()
|
||||
renewed = True
|
||||
except Exception:
|
||||
renewed = False
|
||||
last_lock_time = current_time
|
||||
if not renewed or not lock.owned():
|
||||
task_logger.error(
|
||||
"user_file_docid_migration_task - Lost lock ownership or failed to renew; aborting for safety"
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
clean_old_doc_id = replace_invalid_doc_id_characters(
|
||||
user_file.document_id
|
||||
)
|
||||
normalized_doc_id = _normalize_legacy_user_file_doc_id(
|
||||
clean_old_doc_id
|
||||
)
|
||||
user_project_ids = [project.id for project in user_file.projects]
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Migrating user file {user_file.id} with doc_id {normalized_doc_id}"
|
||||
)
|
||||
|
||||
index_name = active_settings.primary.index_name
|
||||
|
||||
# First find the chunks count using direct Vespa query
|
||||
selection = f"{index_name}.document_id=='{normalized_doc_id}'"
|
||||
|
||||
# Count all chunks for this document
|
||||
chunk_count = _get_document_chunk_count(
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Found {chunk_count} chunks for document {normalized_doc_id}"
|
||||
)
|
||||
|
||||
# Now update Vespa chunks with the found chunk count using retry_index
|
||||
updated_chunks = retry_index.update_single(
|
||||
doc_id=str(normalized_doc_id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
fields=VespaDocumentFields(document_id=str(user_file.id)),
|
||||
user_fields=VespaDocumentUserFields(
|
||||
user_projects=user_project_ids
|
||||
),
|
||||
)
|
||||
user_file.chunk_count = updated_chunks
|
||||
|
||||
# Update the SearchDocs
|
||||
actual_doc_id = str(user_file.document_id)
|
||||
normalized_actual_doc_id = _normalize_legacy_user_file_doc_id(
|
||||
actual_doc_id
|
||||
)
|
||||
if (
|
||||
normalized_doc_id in search_doc_map
|
||||
or normalized_actual_doc_id in search_doc_map
|
||||
):
|
||||
to_update = (
|
||||
search_doc_map[normalized_doc_id]
|
||||
if normalized_doc_id in search_doc_map
|
||||
else search_doc_map[normalized_actual_doc_id]
|
||||
)
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - Updating {len(to_update)} search docs for user file {user_file.id}"
|
||||
)
|
||||
for search_doc in to_update:
|
||||
search_doc.document_id = str(user_file.id)
|
||||
db_session.add(search_doc)
|
||||
|
||||
user_file.document_id_migrated = True
|
||||
db_session.add(user_file)
|
||||
db_session.commit()
|
||||
updated_count += 1
|
||||
except Exception as per_file_exc:
|
||||
# Rollback the current transaction and continue with the next file
|
||||
db_session.rollback()
|
||||
task_logger.exception(
|
||||
f"user_file_docid_migration_task - Error migrating user file {user_file.id} - "
|
||||
f"{per_file_exc.__class__.__name__}"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Updated {updated_count} user files"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Completed for tenant={tenant_id} (updated={updated_count})"
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"user_file_docid_migration_task - Error during execution for tenant={tenant_id} "
|
||||
f"(updated={updated_count}) exception={e.__class__.__name__}"
|
||||
)
|
||||
return False
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
@@ -501,7 +501,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(
|
||||
retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
@@ -515,10 +515,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action=sync "
|
||||
f"chunks={chunks_affected} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
f"doc={document_id} " f"action=sync " f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
|
||||
except SoftTimeLimitExceeded:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -21,7 +20,6 @@ from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
|
||||
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
|
||||
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -32,11 +30,8 @@ from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
from onyx.connectors.models import DocExtractionContext
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
|
||||
@@ -49,34 +44,16 @@ from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
|
||||
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from onyx.db.index_attempt import mark_attempt_succeeded
|
||||
from onyx.db.index_attempt import transition_attempt_to_in_progress
|
||||
from onyx.db.index_attempt import update_docs_indexed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.adapters.document_indexing_adapter import (
|
||||
DocumentIndexingBatchAdapter,
|
||||
)
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
|
||||
@@ -272,583 +249,6 @@ def _check_failure_threshold(
|
||||
)
|
||||
|
||||
|
||||
# NOTE: this is the old run_indexing function that the new decoupled approach
|
||||
# is based on. Leaving this for comparison purposes, but if you see this comment
|
||||
# has been here for >2 month, please delete this function.
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
tenant_id: str,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
1. Get documents which are either new or updated from specified application
|
||||
2. Embed and index these documents into the chosen datastore (vespa)
|
||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||
"""
|
||||
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
|
||||
|
||||
start_time = time.monotonic() # jsut used for logging
|
||||
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
index_attempt_start = get_index_attempt(
|
||||
db_session_temp,
|
||||
index_attempt_id,
|
||||
eager_load_cc_pair=True,
|
||||
eager_load_search_settings=True,
|
||||
)
|
||||
if not index_attempt_start:
|
||||
raise ValueError(
|
||||
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
|
||||
)
|
||||
|
||||
if index_attempt_start.search_settings is None:
|
||||
raise ValueError(
|
||||
"Search settings must be set for indexing. This should not be possible."
|
||||
)
|
||||
|
||||
db_connector = index_attempt_start.connector_credential_pair.connector
|
||||
db_credential = index_attempt_start.connector_credential_pair.credential
|
||||
is_primary = (
|
||||
index_attempt_start.search_settings.status == IndexModelStatus.PRESENT
|
||||
)
|
||||
from_beginning = index_attempt_start.from_beginning
|
||||
has_successful_attempt = (
|
||||
index_attempt_start.connector_credential_pair.last_successful_index_time
|
||||
is not None
|
||||
)
|
||||
ctx = DocExtractionContext(
|
||||
index_name=index_attempt_start.search_settings.index_name,
|
||||
cc_pair_id=index_attempt_start.connector_credential_pair.id,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
source=db_connector.source,
|
||||
earliest_index_time=(
|
||||
db_connector.indexing_start.timestamp()
|
||||
if db_connector.indexing_start
|
||||
else 0
|
||||
),
|
||||
from_beginning=from_beginning,
|
||||
# Only update cc-pair status for primary index jobs
|
||||
# Secondary index syncs at the end when swapping
|
||||
is_primary=is_primary,
|
||||
should_fetch_permissions_during_indexing=(
|
||||
index_attempt_start.connector_credential_pair.access_type
|
||||
== AccessType.SYNC
|
||||
and source_should_fetch_permissions_during_indexing(db_connector.source)
|
||||
and is_primary
|
||||
# if we've already successfully indexed, let the doc_sync job
|
||||
# take care of doc-level permissions
|
||||
and (from_beginning or not has_successful_attempt)
|
||||
),
|
||||
search_settings_status=index_attempt_start.search_settings.status,
|
||||
doc_extraction_complete_batch_num=None,
|
||||
)
|
||||
|
||||
last_successful_index_poll_range_end = (
|
||||
ctx.earliest_index_time
|
||||
if ctx.from_beginning
|
||||
else get_last_successful_attempt_poll_range_end(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
earliest_index=ctx.earliest_index_time,
|
||||
search_settings=index_attempt_start.search_settings,
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
)
|
||||
if last_successful_index_poll_range_end > POLL_CONNECTOR_OFFSET:
|
||||
window_start = datetime.fromtimestamp(
|
||||
last_successful_index_poll_range_end, tz=timezone.utc
|
||||
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
|
||||
else:
|
||||
# don't go into "negative" time if we've never indexed before
|
||||
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
|
||||
|
||||
most_recent_attempt = next(
|
||||
iter(
|
||||
get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt_start.search_settings_id,
|
||||
db_session=db_session_temp,
|
||||
limit=1,
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# if the last attempt failed, try and use the same window. This is necessary
|
||||
# to ensure correctness with checkpointing. If we don't do this, things like
|
||||
# new slack channels could be missed (since existing slack channels are
|
||||
# cached as part of the checkpoint).
|
||||
if (
|
||||
most_recent_attempt
|
||||
and most_recent_attempt.poll_range_end
|
||||
and (
|
||||
most_recent_attempt.status == IndexingStatus.FAILED
|
||||
or most_recent_attempt.status == IndexingStatus.CANCELED
|
||||
)
|
||||
):
|
||||
window_end = most_recent_attempt.poll_range_end
|
||||
else:
|
||||
window_end = datetime.now(tz=timezone.utc)
|
||||
|
||||
# add start/end now that they have been set
|
||||
index_attempt_start.poll_range_start = window_start
|
||||
index_attempt_start.poll_range_end = window_end
|
||||
db_session_temp.add(index_attempt_start)
|
||||
db_session_temp.commit()
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=index_attempt_start.search_settings,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
information_content_classification_model = InformationContentClassificationModel()
|
||||
|
||||
document_index = get_default_document_index(
|
||||
index_attempt_start.search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
# Initialize memory tracer. NOTE: won't actually do anything if
|
||||
# `INDEXING_TRACER_INTERVAL` is 0.
|
||||
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
|
||||
memory_tracer.start()
|
||||
|
||||
index_attempt_md = IndexAttemptMetadata(
|
||||
attempt_id=index_attempt_id,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
)
|
||||
|
||||
total_failures = 0
|
||||
batch_num = 0
|
||||
net_doc_change = 0
|
||||
document_count = 0
|
||||
chunk_count = 0
|
||||
index_attempt: IndexAttempt | None = None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session_temp, index_attempt_id, eager_load_cc_pair=True
|
||||
)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
|
||||
connector_runner = _get_connector_runner(
|
||||
db_session=db_session_temp,
|
||||
attempt=index_attempt,
|
||||
batch_size=INDEX_BATCH_SIZE,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
include_permissions=ctx.should_fetch_permissions_during_indexing,
|
||||
)
|
||||
|
||||
# don't use a checkpoint if we're explicitly indexing from
|
||||
# the beginning in order to avoid weird interactions between
|
||||
# checkpointing / failure handling
|
||||
# OR
|
||||
# if the last attempt was successful
|
||||
if index_attempt.from_beginning or (
|
||||
most_recent_attempt and most_recent_attempt.status.is_successful()
|
||||
):
|
||||
checkpoint = connector_runner.connector.build_dummy_checkpoint()
|
||||
else:
|
||||
checkpoint, _ = get_latest_valid_checkpoint(
|
||||
db_session=db_session_temp,
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
window_start=window_start,
|
||||
window_end=window_end,
|
||||
connector=connector_runner.connector,
|
||||
)
|
||||
|
||||
# save the initial checkpoint to have a proper record of the
|
||||
# "last used checkpoint"
|
||||
save_checkpoint(
|
||||
db_session=db_session_temp,
|
||||
index_attempt_id=index_attempt_id,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
unresolved_errors = get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
unresolved_only=True,
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
doc_id_to_unresolved_errors: dict[str, list[IndexAttemptError]] = (
|
||||
defaultdict(list)
|
||||
)
|
||||
for error in unresolved_errors:
|
||||
if error.document_id:
|
||||
doc_id_to_unresolved_errors[error.document_id].append(error)
|
||||
|
||||
entity_based_unresolved_errors = [
|
||||
error for error in unresolved_errors if error.entity_id
|
||||
]
|
||||
|
||||
while checkpoint.has_more:
|
||||
logger.info(
|
||||
f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
|
||||
)
|
||||
for document_batch, failure, next_checkpoint in connector_runner.run(
|
||||
checkpoint
|
||||
):
|
||||
# Check if connector is disabled mid run and stop if so unless it's the secondary
|
||||
# index being built. We want to populate it even for paused connectors
|
||||
# Often paused connectors are sources that aren't updated frequently but the
|
||||
# contents still need to be initially pulled.
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise ConnectorStopSignal("Connector stop signal detected")
|
||||
|
||||
# NOTE: this progress callback runs on every loop. We've seen cases
|
||||
# where we loop many times with no new documents and eventually time
|
||||
# out, so only doing the callback after indexing isn't sufficient.
|
||||
callback.progress("_run_indexing", 0)
|
||||
|
||||
# TODO: should we move this into the above callback instead?
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
# will exception if the connector/index attempt is marked as paused/failed
|
||||
_check_connector_and_attempt_status(
|
||||
db_session_temp,
|
||||
ctx.cc_pair_id,
|
||||
ctx.search_settings_status,
|
||||
index_attempt_id,
|
||||
)
|
||||
|
||||
# save record of any failures at the connector level
|
||||
if failure is not None:
|
||||
total_failures += 1
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
create_index_attempt_error(
|
||||
index_attempt_id,
|
||||
ctx.cc_pair_id,
|
||||
failure,
|
||||
db_session_temp,
|
||||
)
|
||||
|
||||
_check_failure_threshold(
|
||||
total_failures, document_count, batch_num, failure
|
||||
)
|
||||
|
||||
# save the new checkpoint (if one is provided)
|
||||
if next_checkpoint:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
# below is all document processing logic, so if no batch we can just continue
|
||||
if document_batch is None:
|
||||
continue
|
||||
|
||||
batch_description = []
|
||||
|
||||
# Generate an ID that can be used to correlate activity between here
|
||||
# and the embedding model server
|
||||
doc_batch_cleaned = strip_null_characters(document_batch)
|
||||
for doc in doc_batch_cleaned:
|
||||
batch_description.append(doc.to_short_descriptor())
|
||||
|
||||
doc_size = 0
|
||||
for section in doc.sections:
|
||||
if (
|
||||
isinstance(section, TextSection)
|
||||
and section.text is not None
|
||||
):
|
||||
doc_size += len(section.text)
|
||||
|
||||
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Document size: doc='{doc.to_short_descriptor()}' "
|
||||
f"size={doc_size} "
|
||||
f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
|
||||
)
|
||||
|
||||
logger.debug(f"Indexing batch of documents: {batch_description}")
|
||||
|
||||
index_attempt_md.request_id = make_randomized_onyx_request_id("CIX")
|
||||
index_attempt_md.structured_id = (
|
||||
f"{tenant_id}:{ctx.cc_pair_id}:{index_attempt_id}:{batch_num}"
|
||||
)
|
||||
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
|
||||
|
||||
# real work happens here!
|
||||
adapter = DocumentIndexingBatchAdapter(
|
||||
db_session=db_session,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
tenant_id=tenant_id,
|
||||
index_attempt_metadata=index_attempt_md,
|
||||
)
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=(
|
||||
ctx.from_beginning
|
||||
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
|
||||
),
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
document_batch=doc_batch_cleaned,
|
||||
request_id=index_attempt_md.request_id,
|
||||
adapter=adapter,
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
net_doc_change += index_pipeline_result.new_docs
|
||||
chunk_count += index_pipeline_result.total_chunks
|
||||
document_count += index_pipeline_result.total_docs
|
||||
|
||||
# resolve errors for documents that were successfully indexed
|
||||
failed_document_ids = [
|
||||
failure.failed_document.document_id
|
||||
for failure in index_pipeline_result.failures
|
||||
if failure.failed_document
|
||||
]
|
||||
successful_document_ids = [
|
||||
document.id
|
||||
for document in document_batch
|
||||
if document.id not in failed_document_ids
|
||||
]
|
||||
for document_id in successful_document_ids:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
if document_id in doc_id_to_unresolved_errors:
|
||||
logger.info(
|
||||
f"Resolving IndexAttemptError for document '{document_id}'"
|
||||
)
|
||||
for error in doc_id_to_unresolved_errors[document_id]:
|
||||
error.is_resolved = True
|
||||
db_session_temp.add(error)
|
||||
db_session_temp.commit()
|
||||
|
||||
# add brand new failures
|
||||
if index_pipeline_result.failures:
|
||||
total_failures += len(index_pipeline_result.failures)
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
for failure in index_pipeline_result.failures:
|
||||
create_index_attempt_error(
|
||||
index_attempt_id,
|
||||
ctx.cc_pair_id,
|
||||
failure,
|
||||
db_session_temp,
|
||||
)
|
||||
|
||||
_check_failure_threshold(
|
||||
total_failures,
|
||||
document_count,
|
||||
batch_num,
|
||||
index_pipeline_result.failures[-1],
|
||||
)
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
# NOTE: Postgres uses the start of the transactions when computing `NOW()`
|
||||
# so we need either to commit() or to use a new session
|
||||
update_docs_indexed(
|
||||
db_session=db_session_temp,
|
||||
index_attempt_id=index_attempt_id,
|
||||
total_docs_indexed=document_count,
|
||||
new_docs_indexed=net_doc_change,
|
||||
docs_removed_from_index=0,
|
||||
)
|
||||
|
||||
if callback:
|
||||
callback.progress("_run_indexing", len(doc_batch_cleaned))
|
||||
|
||||
# Add telemetry for indexing progress
|
||||
optional_telemetry(
|
||||
record_type=RecordType.INDEXING_PROGRESS,
|
||||
data={
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": ctx.cc_pair_id,
|
||||
"current_docs_indexed": document_count,
|
||||
"current_chunks_indexed": chunk_count,
|
||||
"source": ctx.source.value,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
memory_tracer.increment_and_maybe_trace()
|
||||
|
||||
# `make sure the checkpoints aren't getting too large`at some regular interval
|
||||
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
|
||||
if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
|
||||
check_checkpoint_size(checkpoint)
|
||||
|
||||
# save latest checkpoint
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
save_checkpoint(
|
||||
db_session=db_session_temp,
|
||||
index_attempt_id=index_attempt_id,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
optional_telemetry(
|
||||
record_type=RecordType.INDEXING_COMPLETE,
|
||||
data={
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": ctx.cc_pair_id,
|
||||
"total_docs_indexed": document_count,
|
||||
"total_chunks": chunk_count,
|
||||
"time_elapsed_seconds": time.monotonic() - start_time,
|
||||
"source": ctx.source.value,
|
||||
},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Connector run exceptioned after elapsed time: "
|
||||
f"{time.monotonic() - start_time} seconds"
|
||||
)
|
||||
if isinstance(e, ConnectorValidationError):
|
||||
# On validation errors during indexing, we want to cancel the indexing attempt
|
||||
# and mark the CCPair as invalid. This prevents the connector from being
|
||||
# used in the future until the credentials are updated.
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to validation error."
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
if not index_attempt:
|
||||
# should always be set by now
|
||||
raise RuntimeError("Should never happen.")
|
||||
|
||||
VALIDATION_ERROR_THRESHOLD = 5
|
||||
|
||||
recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
search_settings_id=index_attempt.search_settings_id,
|
||||
limit=VALIDATION_ERROR_THRESHOLD,
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
num_validation_errors = len(
|
||||
[
|
||||
index_attempt
|
||||
for index_attempt in recent_index_attempts
|
||||
if index_attempt.error_msg
|
||||
and index_attempt.error_msg.startswith(
|
||||
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Connector {ctx.connector_id} has {num_validation_errors} consecutive validation"
|
||||
f" errors. Marking the CC Pair as invalid."
|
||||
)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
memory_tracer.stop()
|
||||
raise e
|
||||
|
||||
elif isinstance(e, ConnectorStopSignal):
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
logger.exception(
|
||||
f"Marking attempt {index_attempt_id} as canceled due to stop signal."
|
||||
)
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=str(e),
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
memory_tracer.stop()
|
||||
raise e
|
||||
else:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
failure_reason=str(e),
|
||||
full_exception_trace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
|
||||
memory_tracer.stop()
|
||||
raise e
|
||||
|
||||
memory_tracer.stop()
|
||||
|
||||
# we know index attempt is successful (at least partially) at this point,
|
||||
# all other cases have been short-circuited
|
||||
elapsed_time = time.monotonic() - start_time
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
# resolve entity-based errors
|
||||
for error in entity_based_unresolved_errors:
|
||||
logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
|
||||
error.is_resolved = True
|
||||
db_session_temp.add(error)
|
||||
db_session_temp.commit()
|
||||
|
||||
if total_failures == 0:
|
||||
mark_attempt_succeeded(index_attempt_id, db_session_temp)
|
||||
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id or "N/A",
|
||||
event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
|
||||
properties=None,
|
||||
db_session=db_session_temp,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Connector succeeded: "
|
||||
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
else:
|
||||
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
|
||||
logger.info(
|
||||
f"Connector completed with some errors: "
|
||||
f"failures={total_failures} "
|
||||
f"batches={batch_num} "
|
||||
f"docs={document_count} "
|
||||
f"chunks={chunk_count} "
|
||||
f"elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
run_dt=window_end,
|
||||
)
|
||||
if ctx.should_fetch_permissions_during_indexing:
|
||||
mark_cc_pair_as_permissions_synced(
|
||||
db_session=db_session_temp,
|
||||
cc_pair_id=ctx.cc_pair_id,
|
||||
start_time=window_end,
|
||||
)
|
||||
|
||||
|
||||
def run_docfetching_entrypoint(
|
||||
app: Celery,
|
||||
index_attempt_id: int,
|
||||
@@ -968,11 +368,19 @@ def connector_document_extraction(
|
||||
db_connector = index_attempt.connector_credential_pair.connector
|
||||
db_credential = index_attempt.connector_credential_pair.credential
|
||||
is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT
|
||||
|
||||
from_beginning = index_attempt.from_beginning
|
||||
has_successful_attempt = (
|
||||
index_attempt.connector_credential_pair.last_successful_index_time
|
||||
is not None
|
||||
)
|
||||
# Use higher priority for first-time indexing to ensure new connectors
|
||||
# get processed before re-indexing of existing connectors
|
||||
docprocessing_priority = (
|
||||
OnyxCeleryPriority.MEDIUM
|
||||
if has_successful_attempt
|
||||
else OnyxCeleryPriority.HIGH
|
||||
)
|
||||
|
||||
earliest_index_time = (
|
||||
db_connector.indexing_start.timestamp()
|
||||
@@ -1095,6 +503,7 @@ def connector_document_extraction(
|
||||
tenant_id,
|
||||
app,
|
||||
most_recent_attempt,
|
||||
docprocessing_priority,
|
||||
)
|
||||
last_batch_num = reissued_batch_count + completed_batches
|
||||
index_attempt.completed_batches = completed_batches
|
||||
@@ -1207,7 +616,7 @@ def connector_document_extraction(
|
||||
OnyxCeleryTask.DOCPROCESSING_TASK,
|
||||
kwargs=processing_batch_data,
|
||||
queue=OnyxCeleryQueues.DOCPROCESSING,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
priority=docprocessing_priority,
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
@@ -1358,6 +767,7 @@ def reissue_old_batches(
|
||||
tenant_id: str,
|
||||
app: Celery,
|
||||
most_recent_attempt: IndexAttempt | None,
|
||||
priority: OnyxCeleryPriority,
|
||||
) -> tuple[int, int]:
|
||||
# When loading from a checkpoint, we need to start new docprocessing tasks
|
||||
# tied to the new index attempt for any batches left over in the file store
|
||||
@@ -1385,7 +795,7 @@ def reissue_old_batches(
|
||||
"batch_num": path_info.batch_num, # use same batch num as previously
|
||||
},
|
||||
queue=OnyxCeleryQueues.DOCPROCESSING,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
priority=priority,
|
||||
)
|
||||
recent_batches = most_recent_attempt.completed_batches if most_recent_attempt else 0
|
||||
# resume from the batch num of the last attempt. This should be one more
|
||||
|
||||
@@ -63,7 +63,7 @@ To ensure the LLM follows certain specific instructions, instructions are added
|
||||
tool is used, a citation reminder is always added. Otherwise, by default there is no reminder. If the user configures reminders, those are added to the
|
||||
final message. If a search related tool just ran and the user has reminders, both appear in a single message.
|
||||
|
||||
If a search related tool is called at any point during the turn, the reminder will remain at the end until the turn is over and the agent as responded.
|
||||
If a search related tool is called at any point during the turn, the reminder will remain at the end until the turn is over and the agent has responded.
|
||||
|
||||
|
||||
## Tool Calls
|
||||
@@ -145,9 +145,83 @@ attention despite having global access.
|
||||
In a similar concept, LLM instructions in the system prompt are structured specifically so that there are coherent sections for the LLM to attend to. This is
|
||||
fairly surprising actually but if there is a line of instructions effectively saying "If you try to use some tools and find that you need more information or
|
||||
need to call additional tools, you are encouraged to do this", having this in the Tool section of the System prompt makes all the LLMs follow it well but if it's
|
||||
even just a paragraph away like near the beginning of the prompt, it is often often ignored. The difference is as drastic as a 30% follow rate to a 90% follow
|
||||
even just a paragraph away like near the beginning of the prompt, it is often ignored. The difference is as drastic as a 30% follow rate to a 90% follow
|
||||
rate even just moving the same statement a few sentences.
|
||||
|
||||
|
||||
## Other related pointers
|
||||
- How messages, files, images are stored can be found in backend/onyx/db/models.py, there is also a README.md under that directory that may be helpful.
|
||||
|
||||
---
|
||||
|
||||
# Overview of LLM flow architecture
|
||||
|
||||
**Concepts:**
|
||||
Turn: User sends a message and AI does some set of things and responds
|
||||
Step/Cycle: 1 single LLM inference given some context and some tools
|
||||
|
||||
|
||||
## 1. Top Level (process_message function):
|
||||
This function can be thought of as the set-up and validation layer. It ensures that the database is in a valid state, reads the
|
||||
messages in the session and sets up all the necessary items to run the chat loop and state containers. The major things it does
|
||||
are:
|
||||
- Validates the request
|
||||
- Builds the chat history for the session
|
||||
- Fetches any additional context such as files and images
|
||||
- Prepares all of the tools for the LLM
|
||||
- Creates the state container objects for use in the loop
|
||||
|
||||
### Wrapper (run_chat_loop_with_state_containers function):
|
||||
This wrapper is used to run the LLM flow in a background thread and monitor the emitter for stop signals. This means the top
|
||||
level is as isolated from the LLM flow as possible and can continue to yield packets as soon as they are available from the lower
|
||||
levels. This also means that if the lower levels fail, the top level will still guarantee a reasonable response to the user.
|
||||
All of the saving and database operations are abstracted away from the lower levels.
|
||||
|
||||
### Emitter
|
||||
The emitter is designed to be an object queue so that lower levels do not need to yield objects all the way back to the top.
|
||||
This way the functions can be better designed (not everything as a generator) and more easily tested. The wrapper around the
|
||||
LLM flow (run_chat_loop_with_state_containers) is used to monitor the emitter and handle packets as soon as they are available
|
||||
from the lower levels. Both the emitter and the state container are mutating state objects and only used to accumulate state.
|
||||
There should be no logic dependent on the states of these objects, especially in the lower levels. The emitter should only take
|
||||
packets and should not be used for other things.
|
||||
|
||||
### State Container
|
||||
The state container is used to accumulate state during the LLM flow. Similar to the emitter, it should not be used for logic,
|
||||
only for accumulating state. It is used to gather all of the necessary information for saving the chat turn into the database.
|
||||
So it will accumulate answer tokens, reasoning tokens, tool calls, citation info, etc. This is used at the end of the flow once
|
||||
the lower level is completed whether on its own or stopped by the user. At that point, all of the state is read and stored into
|
||||
the database. The state container can be added to by any of the underlying layers, this is fine.
|
||||
|
||||
### Stopping Generation
|
||||
A stop signal is checked every 300ms by the wrapper around the LLM flow. The signal itself
|
||||
is stored in Redis and is set by the user calling the stop endpoint. The wrapper ensures that no matter what the lower level is
|
||||
doing at the time, the thread can be killed by the top level. It does not require a cooperative cancellation from the lower level
|
||||
and in fact the lower level does not know about the stop signal at all.
|
||||
|
||||
|
||||
## 2. LLM Loop (run_llm_loop function)
|
||||
This function handles the logic of the Turn. It's essentially a while loop where context is added and modified (according what
|
||||
is outlined in the first half of this doc). Its main functionality is:
|
||||
- Translate and truncate the context for the LLM inference
|
||||
- Add context modifiers like reminders, updates to the system prompts, etc.
|
||||
- Run tool calls and gather results
|
||||
- Build some of the objects stored in the state container.
|
||||
|
||||
|
||||
## 3. LLM Step (run_llm_step function)
|
||||
This function is a single inference of the LLM. It's a wrapper around the LLM stream function which handles packet translations
|
||||
so that the Emitter can emit individual tokens as soon as they arrive. It also keeps track of the different sections since they
|
||||
do not all come at once (reasoning, answers, tool calls are all built up token by token). This layer also tracks the different
|
||||
tool calls and returns that to the LLM Loop to execute.
|
||||
|
||||
|
||||
## Things to know
|
||||
- Packets are labeled with a "turn_index" field as part of the Placement of the packet. This is not the same as the backend
|
||||
concept of a turn. The turn_index for the frontend is which block does this packet belong to. So while a reasoning + tool call
|
||||
comes from the same LLM inference (same backend LLM step), they are 2 turns to the frontend because that's how it's rendered.
|
||||
|
||||
- There are 3 representations of "message". The first is the database model ChatMessage, this one should be translated away and
|
||||
not used deep into the flow. The second is ChatMessageSimple which is the data model which should be used throughout the code
|
||||
as much as possible. If modifications/additions are needed, it should be to this object. This is the rich representation of a
|
||||
message for the code. Finally there is the LanguageModelInput representation of a message. This one is for the LLM interface
|
||||
layer and is as stripped down as possible so that the LLM interface can be clean and easy to maintain/extend.
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
"""
|
||||
Module for handling chat-related milestone tracking and telemetry.
|
||||
"""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import NO_AUTH_USER_ID
|
||||
from onyx.db.milestone import check_multi_assistant_milestone
|
||||
from onyx.db.milestone import create_milestone_if_not_exists
|
||||
from onyx.db.milestone import update_user_assistant_milestone
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
|
||||
|
||||
def process_multi_assistant_milestone(
|
||||
user: User | None,
|
||||
assistant_id: int,
|
||||
tenant_id: str,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""
|
||||
Process the multi-assistant milestone for a user.
|
||||
|
||||
This function:
|
||||
1. Creates or retrieves the multi-assistant milestone
|
||||
2. Updates the milestone with the current assistant usage
|
||||
3. Checks if the milestone was just achieved
|
||||
4. Sends telemetry if the milestone was just hit
|
||||
|
||||
Args:
|
||||
user: The user for whom to process the milestone (can be None for anonymous users)
|
||||
assistant_id: The ID of the assistant being used
|
||||
tenant_id: The current tenant ID
|
||||
db_session: Database session for queries
|
||||
"""
|
||||
# Create or retrieve the multi-assistant milestone
|
||||
multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
|
||||
user=user,
|
||||
event_type=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Update the milestone with the current assistant usage
|
||||
update_user_assistant_milestone(
|
||||
milestone=multi_assistant_milestone,
|
||||
user_id=str(user.id) if user else NO_AUTH_USER_ID,
|
||||
assistant_id=assistant_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Check if the milestone was just achieved
|
||||
_, just_hit_multi_assistant_milestone = check_multi_assistant_milestone(
|
||||
milestone=multi_assistant_milestone,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Send telemetry if the milestone was just hit
|
||||
if just_hit_multi_assistant_milestone:
|
||||
mt_cloud_telemetry(
|
||||
distinct_id=tenant_id,
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
properties=None,
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user