Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-22 10:15:46 +00:00)

Compare commits: 88 commits (overflow-d ... dump-scrip)
| SHA1 |
|---|
| ca3db17b08 |
| ffd13b1104 |
| 1caa860f8e |
| 7181cc41af |
| 959b8c320d |
| 96fd0432ff |
| 4c73a03f57 |
| e57713e376 |
| 21ea320323 |
| bac9c48e53 |
| 7f79e34aa4 |
| f1a81d45a1 |
| 285755a540 |
| 89003ad2d8 |
| 9f93f97259 |
| f702eebbe7 |
| 8487e1856b |
| a36445f840 |
| 7f30293b0e |
| 619d9528b4 |
| 6f83c669e7 |
| c3e5f48cb4 |
| fdf8fe391c |
| f1d6bb9e02 |
| 9a64a717dc |
| aa0f475e01 |
| 75238dc353 |
| 9e19803244 |
| 5cabd32638 |
| 4ccd88c331 |
| 5a80b98320 |
| ff109d9f5c |
| 4cc276aca9 |
| 29f0df2c93 |
| e2edcf0e0b |
| 9396fc547d |
| c089903aad |
| 95471f64e9 |
| 13c1619d01 |
| ddb5068847 |
| 81a4f654c2 |
| 9393c56a21 |
| 1ee96ff99c |
| 6bb00d2c6b |
| d9cc923c6a |
| bfbba0f036 |
| ccf6911f97 |
| 15c9c2ba8e |
| 8b3fedf480 |
| b8dc0749ee |
| d6426458c6 |
| 941c4d6a54 |
| 653b65da66 |
| 503e70be02 |
| 9c19493160 |
| 933315646b |
| d2061f8a26 |
| 6a98f0bf3c |
| 2f4d39d834 |
| 40f8bcc6f8 |
| af9ed73f00 |
| bf28041f4e |
| 395d5927b7 |
| c96f24e37c |
| 070519f823 |
| a7dc1c0f3b |
| a947e44926 |
| a6575b6254 |
| 31733a9c7c |
| 5415e2faf1 |
| 749f720dfd |
| eac79cfdf2 |
| e3b1202731 |
| 6df13cc2de |
| 682f660aa3 |
| c4670ea86c |
| a6757eb49f |
| cd372fb585 |
| 45fa0d9b32 |
| 45091f2ee2 |
| 43a3cb89b9 |
| 9428eaed8d |
| dd29d989ff |
| f44daa2116 |
| 212cbcb683 |
| aaad573c3f |
| e1325e84ae |
| e759cdd4ab |
.github/workflows/check-lazy-imports.yml (vendored, 33 changed lines)

@@ -1,33 +0,0 @@
name: Check Lazy Imports
concurrency:
  group: Check-Lazy-Imports-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
  cancel-in-progress: true

on:
  merge_group:
  pull_request:
    branches:
      - main
      - 'release/**'

permissions:
  contents: read

jobs:
  check-lazy-imports:
    runs-on: ubuntu-latest
    timeout-minutes: 45

    steps:
      - name: Checkout code
        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Set up Python
        uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # ratchet:actions/setup-python@v6
        with:
          python-version: '3.11'

      - name: Check lazy imports
        run: python3 backend/scripts/check_lazy_imports.py
.github/workflows/deployment.yml (vendored, 25 changed lines)

@@ -89,9 +89,10 @@ jobs:
if: ${{ !startsWith(github.ref_name, 'nightly-latest') && github.event_name != 'workflow_dispatch' }}
steps:
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false
fetch-depth: 0
- name: Setup uv
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7.1.4

@@ -111,7 +112,7 @@ jobs:
timeout-minutes: 10
steps:
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -140,7 +141,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -198,7 +199,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -306,7 +307,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -372,7 +373,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -485,7 +486,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -542,7 +543,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -650,7 +651,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -714,7 +715,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -907,7 +908,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -997,7 +998,7 @@ jobs:
timeout-minutes: 90
steps:
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false
.github/workflows/helm-chart-releases.yml (vendored, 2 changed lines)

@@ -15,7 +15,7 @@ jobs:
timeout-minutes: 45
steps:
- name: Checkout
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
.github/workflows/nightly-scan-licenses.yml (vendored, 2 changed lines)

@@ -28,7 +28,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -52,7 +52,7 @@ jobs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -80,12 +80,13 @@ jobs:
env:
PYTHONPATH: ./backend
MODEL_SERVER_HOST: "disabled"
DISABLE_TELEMETRY: "true"
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -113,6 +114,7 @@ jobs:
run: |
cat <<EOF > deployment/docker_compose/.env
CODE_INTERPRETER_BETA_ENABLED=true
DISABLE_TELEMETRY=true
EOF
- name: Set up Standard Dependencies
.github/workflows/pr-helm-chart-testing.yml (vendored, 2 changed lines)

@@ -24,7 +24,7 @@ jobs:
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
.github/workflows/pr-integration-tests.yml (vendored, 41 changed lines)

@@ -43,7 +43,7 @@ jobs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -74,7 +74,7 @@ jobs:
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -129,7 +129,7 @@ jobs:
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -183,7 +183,7 @@ jobs:
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -259,7 +259,7 @@ jobs:
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -274,23 +274,28 @@ jobs:
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
# NOTE: don't need web server for integration tests
- name: Start Docker containers
- name: Create .env file for Docker Compose
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
MCP_SERVER_ENABLED=true
EOF
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
POSTGRES_POOL_PRE_PING=true \
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
INTEGRATION_TESTS_MODE=true \
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
MCP_SERVER_ENABLED=true \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
relational_db \
index \

@@ -436,7 +441,7 @@ jobs:
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false
.github/workflows/pr-jest-tests.yml (vendored, 4 changed lines)

@@ -16,12 +16,12 @@ jobs:
timeout-minutes: 45
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm"
.github/workflows/pr-mit-integration-tests.yml (vendored, 35 changed lines)

@@ -40,7 +40,7 @@ jobs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -70,7 +70,7 @@ jobs:
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -124,7 +124,7 @@ jobs:
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -177,7 +177,7 @@ jobs:
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -253,7 +253,7 @@ jobs:
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -268,21 +268,26 @@ jobs:
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
# NOTE: don't need web server for integration tests
- name: Start Docker containers
- name: Create .env file for Docker Compose
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
MCP_SERVER_ENABLED=true
EOF
- name: Start Docker containers
run: |
cd deployment/docker_compose
AUTH_TYPE=basic \
POSTGRES_POOL_PRE_PING=true \
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
INTEGRATION_TESTS_MODE=true \
MCP_SERVER_ENABLED=true \
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
relational_db \
index \
.github/workflows/pr-playwright-tests.yml (vendored, 14 changed lines)

@@ -53,7 +53,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -108,7 +108,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -163,7 +163,7 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -229,13 +229,13 @@ jobs:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
- name: Setup node
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: 'npm'

@@ -465,12 +465,12 @@ jobs:
# ]
# steps:
# - name: Checkout code
# uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
# uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
# with:
# fetch-depth: 0
# - name: Setup node
# uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
# uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
# with:
# node-version: 22
.github/workflows/pr-python-checks.yml (vendored, 31 changed lines)

@@ -27,7 +27,7 @@ jobs:
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -40,35 +40,10 @@ jobs:
backend/requirements/model_server.txt
backend/requirements/ee.txt
- name: Generate OpenAPI schema
shell: bash
working-directory: backend
env:
PYTHONPATH: "."
run: |
python scripts/onyx_openapi_schema.py --filename generated/openapi.json
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Generate OpenAPI Python client
- name: Generate OpenAPI schema and Python client
shell: bash
run: |
docker run --rm \
-v "${{ github.workspace }}/backend/generated:/local" \
openapitools/openapi-generator-cli generate \
-i /local/openapi.json \
-g python \
-o /local/onyx_openapi_client \
--package-name onyx_openapi_client \
--skip-validate-spec \
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
ods openapi all
- name: Cache mypy cache
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
.github/workflows/pr-python-connector-tests.yml (vendored, 12 changed lines)

@@ -133,12 +133,13 @@ jobs:
env:
PYTHONPATH: ./backend
DISABLE_TELEMETRY: "true"
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

@@ -160,16 +161,20 @@ jobs:
hubspot:
- 'backend/onyx/connectors/hubspot/**'
- 'backend/tests/daily/connectors/hubspot/**'
- 'uv.lock'
salesforce:
- 'backend/onyx/connectors/salesforce/**'
- 'backend/tests/daily/connectors/salesforce/**'
- 'uv.lock'
github:
- 'backend/onyx/connectors/github/**'
- 'backend/tests/daily/connectors/github/**'
- 'uv.lock'
file_processing:
- 'backend/onyx/file_processing/**'
- 'uv.lock'
- name: Run Tests (excluding HubSpot, Salesforce, and GitHub)
- name: Run Tests (excluding HubSpot, Salesforce, GitHub, and Coda)
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test \

@@ -182,7 +187,8 @@ jobs:
backend/tests/daily/connectors \
--ignore backend/tests/daily/connectors/hubspot \
--ignore backend/tests/daily/connectors/salesforce \
--ignore backend/tests/daily/connectors/github
--ignore backend/tests/daily/connectors/github \
--ignore backend/tests/daily/connectors/coda
- name: Run HubSpot Connector Tests
if: ${{ github.event_name == 'schedule' || steps.changes.outputs.hubspot == 'true' || steps.changes.outputs.file_processing == 'true' }}
.github/workflows/pr-python-model-tests.yml (vendored, 2 changed lines)

@@ -39,7 +39,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false
.github/workflows/pr-python-tests.yml (vendored, 6 changed lines)

@@ -26,15 +26,13 @@ jobs:
env:
PYTHONPATH: ./backend
REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
SF_USERNAME: ${{ secrets.SF_USERNAME }}
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
DISABLE_TELEMETRY: "true"
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false
.github/workflows/pr-quality-checks.yml (vendored, 4 changed lines)

@@ -7,6 +7,8 @@ on:
merge_group:
pull_request: null
push:
branches:
- main
tags:
- "v*.*.*"

@@ -39,7 +41,7 @@ jobs:
- uses: j178/prek-action@91fd7d7cf70ae1dee9f4f44e7dfa5d1073fe6623 # ratchet:j178/prek-action@v1
with:
prek-version: '0.2.21'
extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
- name: Check Actions
uses: giner/check-actions@28d366c7cbbe235f9624a88aa31a628167eee28c # ratchet:giner/check-actions@v1.0.1
with:
.github/workflows/release-devtools.yml (vendored, 2 changed lines)

@@ -24,7 +24,7 @@ jobs:
- {goos: "darwin", goarch: "arm64"}
- {goos: "", goarch: ""}
steps:
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false
fetch-depth: 0
.github/workflows/sync_foss.yml (vendored, 2 changed lines)

@@ -14,7 +14,7 @@ jobs:
contents: read
steps:
- name: Checkout main Onyx repo
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: false
.github/workflows/tag-nightly.yml (vendored, 2 changed lines)

@@ -18,7 +18,7 @@ jobs:
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
# implement here which needs an actual user's deploy key
- name: Checkout code
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
ssh-key: "${{ secrets.DEPLOY_KEY }}"
persist-credentials: true
.github/workflows/zizmor.yml (vendored, 2 changed lines)

@@ -17,7 +17,7 @@ jobs:
security-events: write # needed for SARIF uploads
steps:
- name: Checkout repository
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6.0.0
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
with:
persist-credentials: false
.gitignore (vendored, 3 changed lines)

@@ -53,3 +53,6 @@ node_modules
# MCP configs
.playwright-mcp

# plans
plans/
@@ -5,8 +5,13 @@ default_install_hook_types:
- post-rewrite
repos:
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 569ddf04117761eb74cef7afb5143bbb96fcdfbb # frozen: 0.9.15
# From: https://github.com/astral-sh/uv-pre-commit/pull/53/commits/d30b4298e4fb63ce8609e29acdbcf4c9018a483c
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
hooks:
- id: uv-run
name: Check lazy imports
args: ["--with=onyx-devtools", "ods", "check-lazy-imports"]
files: ^backend/(?!\.venv/).*\.py$
- id: uv-sync
args: ["--locked", "--all-extras"]
- id: uv-lock

@@ -14,19 +19,19 @@ repos:
- id: uv-export
name: uv-export default.txt
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "backend", "-o", "backend/requirements/default.txt"]
files: ^(pyproject\.toml|uv\.lock)$
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
- id: uv-export
name: uv-export dev.txt
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "dev", "-o", "backend/requirements/dev.txt"]
files: ^(pyproject\.toml|uv\.lock)$
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
- id: uv-export
name: uv-export ee.txt
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "ee", "-o", "backend/requirements/ee.txt"]
files: ^(pyproject\.toml|uv\.lock)$
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
- id: uv-export
name: uv-export model_server.txt
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "model_server", "-o", "backend/requirements/model_server.txt"]
files: ^(pyproject\.toml|uv\.lock)$
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
# - id: uv-run
# name: mypy

@@ -71,7 +76,7 @@ repos:
args: [ '--remove-all-unused-imports', '--remove-unused-variables', '--in-place' , '--recursive']

- repo: https://github.com/golangci/golangci-lint
rev: e6ebea0145f385056bce15041d3244c0e5e15848 # frozen: v2.7.0
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
hooks:
- id: golangci-lint
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"

@@ -107,12 +112,6 @@ repos:
pass_filenames: false
files: \.tf$

- id: check-lazy-imports
name: Check lazy imports
entry: python3 backend/scripts/check_lazy_imports.py
language: system
files: ^backend/(?!\.venv/).*\.py$

- id: typescript-check
name: TypeScript type check
entry: bash -c 'cd web && npm run types:check'
.vscode/launch.template.jsonc (vendored, 1 changed line)

@@ -508,7 +508,6 @@
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"stopOnEntry": true,
"presentation": {
"group": "3"
}
@@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co

## KEY NOTES

- If you run into any missing python dependency errors, try running your command with `source backend/.venv/bin/activate` \
- If you run into any missing python dependency errors, try running your command with `source .venv/bin/activate` \
to assume the python venv.
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
@@ -7,8 +7,12 @@ Onyx migrations use a generic single-database configuration with an async dbapi.

## To generate new migrations:

run from onyx/backend:
`alembic revision --autogenerate -m <DESCRIPTION_OF_MIGRATION>`
From onyx/backend, run:
`alembic revision -m <DESCRIPTION_OF_MIGRATION>`

Note: you cannot use the `--autogenerate` flag as the automatic schema parsing does not work.

Manually populate the upgrade and downgrade in your new migration.

More info can be found here: https://alembic.sqlalchemy.org/en/latest/autogenerate.html
@@ -0,0 +1,29 @@
"""add is_clarification to chat_message

Revision ID: 18b5b2524446
Revises: 87c52ec39f84
Create Date: 2025-01-16

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "18b5b2524446"
down_revision = "87c52ec39f84"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "chat_message",
        sa.Column(
            "is_clarification", sa.Boolean(), nullable=False, server_default="false"
        ),
    )


def downgrade() -> None:
    op.drop_column("chat_message", "is_clarification")
@@ -0,0 +1,62 @@
"""update_default_tool_descriptions

Revision ID: a01bf2971c5d
Revises: 87c52ec39f84
Create Date: 2025-12-16 15:21:25.656375

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "a01bf2971c5d"
down_revision = "18b5b2524446"
branch_labels = None
depends_on = None

# new tool descriptions (12/2025)
TOOL_DESCRIPTIONS = {
    "SearchTool": "The Search Action allows the agent to search through connected knowledge to help build an answer.",
    "ImageGenerationTool": (
        "The Image Generation Action allows the agent to use DALL-E 3 or GPT-IMAGE-1 to generate images. "
        "The action will be used when the user asks the agent to generate an image."
    ),
    "WebSearchTool": (
        "The Web Search Action allows the agent "
        "to perform internet searches for up-to-date information."
    ),
    "KnowledgeGraphTool": (
        "The Knowledge Graph Search Action allows the agent to search the "
        "Knowledge Graph for information. This tool can (for now) only be active in the KG Beta Agent, "
        "and it requires the Knowledge Graph to be enabled."
    ),
    "OktaProfileTool": (
        "The Okta Profile Action allows the agent to fetch the current user's information from Okta. "
        "This may include the user's name, email, phone number, address, and other details such as their "
        "manager and direct reports."
    ),
}


def upgrade() -> None:
    conn = op.get_bind()
    conn.execute(sa.text("BEGIN"))

    try:
        for tool_id, description in TOOL_DESCRIPTIONS.items():
            conn.execute(
                sa.text(
                    "UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
                ),
                {"description": description, "tool_id": tool_id},
            )
        conn.execute(sa.text("COMMIT"))
    except Exception as e:
        conn.execute(sa.text("ROLLBACK"))
        raise e


def downgrade() -> None:
    pass
@@ -8,6 +8,7 @@ from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session

from ee.onyx.server.user_group.models import SetCuratorRequest

@@ -362,14 +363,29 @@ def _check_user_group_is_modifiable(user_group: UserGroup) -> None:

def _add_user__user_group_relationships__no_commit(
db_session: Session, user_group_id: int, user_ids: list[UUID]
) -> list[User__UserGroup]:
"""NOTE: does not commit the transaction."""
relationships = [
User__UserGroup(user_id=user_id, user_group_id=user_group_id)
for user_id in user_ids
]
db_session.add_all(relationships)
return relationships
) -> None:
"""NOTE: does not commit the transaction.

This function is idempotent - it will skip users who are already in the group
to avoid duplicate key violations during concurrent operations or re-syncs.
Uses ON CONFLICT DO NOTHING to keep inserts atomic under concurrency.
"""
if not user_ids:
return

insert_stmt = (
insert(User__UserGroup)
.values(
[
{"user_id": user_id, "user_group_id": user_group_id}
for user_id in user_ids
]
)
.on_conflict_do_nothing(
index_elements=[User__UserGroup.user_group_id, User__UserGroup.user_id]
)
)
db_session.execute(insert_stmt)


def _add_user_group__cc_pair_relationships__no_commit(
@@ -219,7 +219,7 @@ def verify_email_is_invited(email: str) -> None:
raise PermissionError("Email must be specified")

try:
email_info = validate_email(email)
email_info = validate_email(email, check_deliverability=False)
except EmailUndeliverableError:
raise PermissionError("Email is not valid")

@@ -227,7 +227,9 @@ def verify_email_is_invited(email: str) -> None:
try:
# normalized emails are now being inserted into the db
# we can remove this normalization on read after some time has passed
email_info_whitelist = validate_email(email_whitelist)
email_info_whitelist = validate_email(
email_whitelist, check_deliverability=False
)
except EmailNotValidError:
continue
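For context on the hunk above: `email_validator.validate_email` performs a DNS deliverability check by default, and `check_deliverability=False` keeps the syntax/normalization validation while skipping the network lookup. A minimal standalone sketch (not Onyx code; the helper name is made up, and the `.normalized` attribute assumes email-validator 2.x):

```python
from email_validator import EmailNotValidError, validate_email


def normalize_invited_email(raw: str) -> str | None:
    """Validate syntax and normalize the address without any DNS/MX lookup."""
    try:
        # check_deliverability=False avoids the DNS query, keeping the check
        # fast and deterministic (useful offline and in tests).
        return validate_email(raw, check_deliverability=False).normalized
    except EmailNotValidError:
        return None
```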
@@ -105,52 +105,49 @@ S, U1, TC, TR, R -- agent calls another tool -> S, U1, TC, TR, TC, TR, R, A1
- Reminder moved to the end
```

## Product considerations
Project files are important to the entire duration of the chat session. If the user has uploaded project files, they are likely very intent on working with
those files. The LLM is much better at referencing documents close to the end of the context window so keeping it there for ease of access.

User uploaded files are considered relevant for that point in time, it is ok if the Agent forgets about it as the chat gets long. If every uploaded file is
constantly moved towards the end of the chat, it would degrade quality as these stack up. Even with a single file, there is some cost of making the previous
User Message further away. This tradeoff is accepted for Projects because of the intent of the feature.

Reminder are absolutely necessary to ensure 1-2 specific instructions get followed with a very high probability. It is less detailed than the system prompt
and should be very targetted for it to work reliably and also not interfere with the last user message.

## Reasons / Experiments
Custom Agent instructions being placed in the system prompt is poorly followed. It also degrade performance of the system especially when the instructions
are orthogonal (or even possibly contradictory) to the system prompt. For weaker models, it causes strange artifacts in tool calls and final responses
that completely ruins the user experience. Empirically, this way works better across a range of models especially when the history gets longer.
Having the Custom Agent instructions not move means it fades more as the chat gets long which is also not ok from a UX perspective.

Project files are important to the entire duration of the chat session. If the user has uploaded project files, they are likely very intent on working with
those files. The LLM is much better at referencing documents close to the end of the context window so keeping it there for ease of access.

Reminder are absolutely necessary to ensure 1-2 specific instructions get followed with a very high probability. It is less detailed than the system prompt
and should be very targetted for it to work reliably.

User uploaded files are considered relevant for that point in time, it is ok if the Agent forgets about it as the chat gets long. If every uploaded file is
constantly moved towards the end of the chat, it would degrade quality as these stack up. Even with a single file, there is some cost of making the previous
User Message further away. This tradeoff is accepted for Projects because of the intent of the feature.

## Other related pointers
- How messages, files, images are stored can be found in db/models.py

# Appendix (just random tidbits for those interested)
- Reminder messages are placed at the end of the prompt because all model fine tuning approaches cause the LLMs to attend very strongly to the tokens at the very
back of the context closest to generation. This is the only way to get the LLMs to not miss critical information and for the product to be reliable. Specifically
the built-in reminders are around citations and what tools it should call in certain situations.

- LLMs are able to handle changes in topic best at message boundaries. There are special tokens under the hood for this. We also use this property to slice up
the history in the way presented above.

- Different LLMs vary in this but some now have a section that cannot be set via the API layer called the "System Prompt" (OpenAI terminology) which contains
Different LLMs vary in this but some now have a section that cannot be set via the API layer called the "System Prompt" (OpenAI terminology) which contains
information like the model cutoff date, identity, and some other basic non-changing information. The System prompt described above is in that convention called
the "Developer Prompt". It seems the distribution of the System Prompt, by which I mean the style of wording and terms used can also affect the behavior. This
is different between different models and not necessarily scientific so the system prompt is built from an exploration across different models. It currently
starts with: "You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the user's intent..."

- The document json includes a field for the LLM to cite (it's a single number) to make citations reliable and avoid weird artifacts. It's called "document" so
LLMs are able to handle changes in topic best at message boundaries. There are special tokens under the hood for this. We also use this property to slice up
the history in the way presented above.

Reminder messages are placed at the end of the prompt because all model fine tuning approaches cause the LLMs to attend very strongly to the tokens at the very
back of the context closest to generation. This is the only way to get the LLMs to not miss critical information and for the product to be reliable. Specifically
the built-in reminders are around citations and what tools it should call in certain situations.

The document json includes a field for the LLM to cite (it's a single number) to make citations reliable and avoid weird artifacts. It's called "document" so
that the LLM does not create weird artifacts in reasoning like "I should reference citation_id: 5 for...". It is also strategically placed so that it is easy to
reference. It is followed by a couple short sections like the metadata and title before the long content section. It seems LLMs are still better at local
attention despite having global access.

- In a similar concept, LLM instructions in the system prompt are structured specifically so that there are coherent sections for the LLM to attend to. This is
In a similar concept, LLM instructions in the system prompt are structured specifically so that there are coherent sections for the LLM to attend to. This is
fairly surprising actually but if there is a line of instructions effectively saying "If you try to use some tools and find that you need more information or
need to call additional tools, you are encouraged to do this", having this in the Tool section of the System prompt makes all the LLMs follow it well but if it's
even just a paragraph away like near the beginning of the prompt, it is often often ignored. The difference is as drastic as a 30% follow rate to a 90% follow
rate even just moving the same statement a few sentences.

- Custom Agent prompts are also completely separate from the system prompt. Having potentially orthogonal instructions in the system prompt (both the actual
instructions and the writing style) can greatly deteriorate the quality of the responses. There is also a product motivation to keep it close to the end of
generation so it's strongly followed.

## Other related pointers
- How messages, files, images are stored can be found in backend/onyx/db/models.py, there is also a README.md under that directory that may be helpful.
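To make the ordering this doc describes concrete, here is a minimal sketch of how a message list following that layout could be assembled. It is illustrative only; the names (`assemble_messages`, the `dict` message shape) are placeholders and not the actual Onyx types, which live in the diff hunks that follow.

```python
def assemble_messages(
    system_prompt: dict,
    custom_agent_prompt: dict | None,
    history: list[dict],  # prior user/assistant/tool messages, oldest first
    project_files_msg: dict | None,
    last_user_msg: dict,
    reminder_msg: dict | None,
) -> list[dict]:
    """Order: system prompt, custom agent instructions, chat history,
    project files, last user message, reminder.

    The reminder goes last because models attend most strongly to the end of
    the context; project files sit just before the last user message so they
    stay easy to reference without pushing every uploaded file forward."""
    messages = [system_prompt]
    if custom_agent_prompt:
        messages.append(custom_agent_prompt)
    messages.extend(history)
    if project_files_msg:
        messages.append(project_files_msg)
    messages.append(last_user_msg)
    if reminder_msg:
        messages.append(reminder_msg)
    return messages
```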
@@ -26,6 +26,8 @@ class ChatStateContainer:
self.answer_tokens: str | None = None
# Store citation mapping for building citation_docs_info during partial saves
self.citation_to_doc: dict[int, SearchDoc] = {}
# True if this turn is a clarification question (deep research flow)
self.is_clarification: bool = False

def add_tool_call(self, tool_call: ToolCallInfo) -> None:
"""Add a tool call to the accumulated state."""

@@ -43,6 +45,10 @@ class ChatStateContainer:
"""Set the citation mapping from citation processor."""
self.citation_to_doc = citation_to_doc

def set_is_clarification(self, is_clarification: bool) -> None:
"""Set whether this turn is a clarification question."""
self.is_clarification = is_clarification


def run_chat_llm_with_state_containers(
func: Callable[..., None],
@@ -477,7 +477,10 @@ def load_chat_file(

# Extract text content if it's a text file type (not an image)
content_text = None
file_type = file_descriptor["type"]
# `FileDescriptor` is often JSON-roundtripped (e.g. JSONB / API), so `type`
# may arrive as a raw string value instead of a `ChatFileType`.
file_type = ChatFileType(file_descriptor["type"])

if file_type.is_text_file():
try:
content_text = content.decode("utf-8")

@@ -708,3 +711,21 @@ def get_custom_agent_prompt(persona: Persona, chat_session: ChatSession) -> str
return chat_session.project.instructions
else:
return None


def is_last_assistant_message_clarification(chat_history: list[ChatMessage]) -> bool:
"""Check if the last assistant message in chat history was a clarification question.

This is used in the deep research flow to determine whether to skip the
clarification step when the user has already responded to a clarification.

Args:
chat_history: List of ChatMessage objects in chronological order

Returns:
True if the last assistant message has is_clarification=True, False otherwise
"""
for message in reversed(chat_history):
if message.message_type == MessageType.ASSISTANT:
return message.is_clarification
return False
@@ -25,6 +25,7 @@ from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchDocsResponse
from onyx.db.models import Persona
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import model_needs_formatting_reenabled
from onyx.prompts.chat_prompts import IMAGE_GEN_REMINDER

@@ -103,15 +104,23 @@ def construct_message_history(
custom_agent_prompt: ChatMessageSimple | None,
simple_chat_history: list[ChatMessageSimple],
reminder_message: ChatMessageSimple | None,
project_files: ExtractedProjectFiles,
project_files: ExtractedProjectFiles | None,
available_tokens: int,
last_n_user_messages: int | None = None,
) -> list[ChatMessageSimple]:
if last_n_user_messages is not None:
if last_n_user_messages <= 0:
raise ValueError(
"filtering chat history by last N user messages must be a value greater than 0"
)

history_token_budget = available_tokens
history_token_budget -= system_prompt.token_count
history_token_budget -= (
custom_agent_prompt.token_count if custom_agent_prompt else 0
)
history_token_budget -= project_files.total_token_count
if project_files:
history_token_budget -= project_files.total_token_count
history_token_budget -= reminder_message.token_count if reminder_message else 0

if history_token_budget < 0:

@@ -122,7 +131,7 @@ def construct_message_history(
result = [system_prompt]
if custom_agent_prompt:
result.append(custom_agent_prompt)
if project_files.project_file_texts:
if project_files and project_files.project_file_texts:
project_message = _create_project_files_message(
project_files, token_counter=None
)

@@ -131,6 +140,26 @@ def construct_message_history(
result.append(reminder_message)
return result

# If last_n_user_messages is set, filter history to only include the last n user messages
if last_n_user_messages is not None:
# Find all user message indices
user_msg_indices = [
i
for i, msg in enumerate(simple_chat_history)
if msg.message_type == MessageType.USER
]

if not user_msg_indices:
raise ValueError("No user message found in simple_chat_history")

# If we have more than n user messages, keep only the last n
if len(user_msg_indices) > last_n_user_messages:
# Find the index of the n-th user message from the end
# For example, if last_n_user_messages=2, we want the 2nd-to-last user message
nth_user_msg_index = user_msg_indices[-(last_n_user_messages)]
# Keep everything from that user message onwards
simple_chat_history = simple_chat_history[nth_user_msg_index:]

# Find the last USER message in the history
# The history may contain tool calls and responses after the last user message
last_user_msg_index = None

@@ -178,7 +207,7 @@ def construct_message_history(
break

# Attach project images to the last user message
if project_files.project_image_files:
if project_files and project_files.project_image_files:
existing_images = last_user_message.image_files or []
last_user_message = ChatMessageSimple(
message=last_user_message.message,

@@ -200,7 +229,7 @@ def construct_message_history(
result.append(custom_agent_prompt)

# 3. Add project files message (inserted before last user message)
if project_files.project_file_texts:
if project_files and project_files.project_file_texts:
project_message = _create_project_files_message(
project_files, token_counter=None
)

@@ -263,6 +292,7 @@ def run_llm_loop(
token_counter: Callable[[str], int],
db_session: Session,
forced_tool_id: int | None = None,
user_identity: LLMUserIdentity | None = None,
) -> None:
with trace("run_llm_loop", metadata={"tenant_id": get_current_tenant_id()}):
# Fix some LiteLLM issues,

@@ -310,6 +340,7 @@ def run_llm_loop(
should_cite_documents: bool = False
ran_image_gen: bool = False
just_ran_web_search: bool = False
has_called_search_tool: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL

current_tool_call_index = (

@@ -426,6 +457,7 @@ def run_llm_loop(
# immediately yield the full set of found documents. This gives us the option to show the
# final set of documents immediately if desired.
final_documents=gathered_documents,
user_identity=user_identity,
)

# Consume the generator, emitting packets and capturing the final result

@@ -460,8 +492,13 @@ def run_llm_loop(
user_info=None, # TODO, this is part of memories right now, might want to separate it out
citation_mapping=citation_mapping,
citation_processor=citation_processor,
skip_search_query_expansion=has_called_search_tool,
)

# Track if search tool was called (for skipping query expansion on subsequent calls)
if tool_call.tool_name == SearchTool.NAME:
has_called_search_tool = True

# Build a mapping of tool names to tool objects for getting tool_id
tools_by_name = {tool.name: tool for tool in final_tools}
@@ -15,6 +15,7 @@ from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.models import AssistantMessage
|
||||
from onyx.llm.models import ChatCompletionMessage
|
||||
@@ -332,6 +333,7 @@ def run_llm_step(
|
||||
citation_processor: DynamicCitationProcessor,
|
||||
state_container: ChatStateContainer,
|
||||
final_documents: list[SearchDoc] | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
) -> Generator[Packet, None, tuple[LlmStepResult, int]]:
|
||||
# The second return value is for the turn index because reasoning counts on the frontend as a turn
|
||||
# TODO this is maybe ok but does not align too well with the backend logic

|
||||
@@ -365,6 +367,7 @@ def run_llm_step(
|
||||
tool_choice=tool_choice,
|
||||
structured_response_format=None, # TODO
|
||||
# reasoning_effort=ReasoningEffort.OFF, # Can set this for dev/testing.
|
||||
user_identity=user_identity,
|
||||
):
|
||||
if packet.usage:
|
||||
usage = packet.usage
|
||||
|
||||
@@ -102,6 +102,11 @@ class MessageResponseIDInfo(BaseModel):
|
||||
class StreamingError(BaseModel):
|
||||
error: str
|
||||
stack_trace: str | None = None
|
||||
error_code: str | None = (
|
||||
None # e.g., "RATE_LIMIT", "AUTH_ERROR", "TOOL_CALL_FAILED"
|
||||
)
|
||||
is_retryable: bool = True # Hint to frontend if retry might help
|
||||
details: dict | None = None # Additional context (tool name, model name, etc.)
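With the new fields, an error packet for a failed tool call can carry machine-readable context for the frontend. A hedged sketch of emitting one; the import path is assumed and the tool name is hypothetical, but the field names match the model above:

# Assumes StreamingError is importable from this models module.
error_packet = StreamingError(
    error="Web search tool timed out",
    error_code="TOOL_CALL_FAILED",
    is_retryable=True,
    details={"tool_name": "run_web_search"},  # hypothetical tool name
)
print(error_packet.model_dump_json())  # pydantic v2 serialization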
|
||||
|
||||
|
||||
class OnyxAnswer(BaseModel):
|
||||
|
||||
@@ -13,6 +13,7 @@ from onyx.chat.chat_state import run_chat_llm_with_state_containers
|
||||
from onyx.chat.chat_utils import convert_chat_history
|
||||
from onyx.chat.chat_utils import create_chat_history_chain
|
||||
from onyx.chat.chat_utils import get_custom_agent_prompt
|
||||
from onyx.chat.chat_utils import is_last_assistant_message_clarification
|
||||
from onyx.chat.chat_utils import load_all_chat_files
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.llm_loop import run_llm_loop
|
||||
@@ -53,6 +54,7 @@ from onyx.file_store.utils import verify_user_files
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -62,10 +64,12 @@ from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.tools.constants import SEARCH_TOOL_ID
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
from onyx.tools.tool_constructor import CustomToolConfig
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.tools.tool_constructor import SearchToolUsage
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.timing import log_function_time
|
||||
@@ -79,6 +83,10 @@ ERROR_TYPE_CANCELLED = "cancelled"
|
||||
class ToolCallException(Exception):
|
||||
"""Exception raised for errors during tool calls."""
|
||||
|
||||
def __init__(self, message: str, tool_name: str | None = None):
|
||||
super().__init__(message)
|
||||
self.tool_name = tool_name
|
||||
|
||||
|
||||
def _extract_project_file_texts_and_images(
|
||||
project_id: int | None,
|
||||
@@ -206,6 +214,46 @@ def _extract_project_file_texts_and_images(
|
||||
)
|
||||
|
||||
|
||||
def _get_project_search_availability(
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
has_project_file_texts: bool,
|
||||
forced_tool_ids: list[int] | None,
|
||||
search_tool_id: int | None,
|
||||
) -> SearchToolUsage:
|
||||
"""Determine search tool availability based on project context.
|
||||
|
||||
Args:
|
||||
project_id: The project ID if the user is in a project
|
||||
persona_id: The persona ID to check if it's the default persona
|
||||
has_project_file_texts: Whether project files are loaded in context
|
||||
forced_tool_ids: List of forced tool IDs (may be mutated to remove search tool)
|
||||
search_tool_id: The search tool ID to check against
|
||||
|
||||
Returns:
|
||||
SearchToolUsage setting indicating how search should be used
|
||||
"""
|
||||
# There are cases where the internal search tool should be disabled
|
||||
# If the user is in a project, it should not use other sources / generic search
|
||||
# If they are in a project but using a custom agent, it should use the agent setup
|
||||
# (which means it can use search)
|
||||
# However if in a project and there are more files than can fit in the context,
|
||||
# it should use the search tool with the project filter on
|
||||
# If no files are uploaded, search should remain enabled
|
||||
search_usage_forcing_setting = SearchToolUsage.AUTO
|
||||
if project_id:
|
||||
if persona_id == DEFAULT_PERSONA_ID and has_project_file_texts:
|
||||
search_usage_forcing_setting = SearchToolUsage.DISABLED
|
||||
# Remove search tool from forced_tool_ids if it's present
|
||||
if forced_tool_ids and search_tool_id and search_tool_id in forced_tool_ids:
|
||||
forced_tool_ids[:] = [
|
||||
tool_id for tool_id in forced_tool_ids if tool_id != search_tool_id
|
||||
]
|
||||
elif forced_tool_ids and search_tool_id and search_tool_id in forced_tool_ids:
|
||||
search_usage_forcing_setting = SearchToolUsage.ENABLED
|
||||
return search_usage_forcing_setting
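A hedged usage sketch of the helper above, covering the project-with-files case. The project id and the search tool id 7 are stand-ins for values resolved at runtime; DEFAULT_PERSONA_ID and SearchToolUsage are the same names imported elsewhere in this diff:

forced_tool_ids = [7]  # hypothetical: 7 is the search tool's DB id here

setting = _get_project_search_availability(
    project_id=42,                  # user is chatting inside a project
    persona_id=DEFAULT_PERSONA_ID,  # default persona, not a custom agent
    has_project_file_texts=True,    # all project files fit in context
    forced_tool_ids=forced_tool_ids,
    search_tool_id=7,
)
# Search is disabled and the forced list is mutated in place to drop the search tool.
assert setting == SearchToolUsage.DISABLED
assert forced_tool_ids == []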
|
||||
|
||||
|
||||
def _initialize_chat_session(
|
||||
message_text: str,
|
||||
files: list[FileDescriptor],
|
||||
@@ -285,10 +333,15 @@ def stream_chat_message_objects(
|
||||
tenant_id = get_current_tenant_id()
|
||||
use_existing_user_message = new_msg_req.use_existing_user_message
|
||||
|
||||
llm: LLM
|
||||
llm: LLM | None = None
|
||||
|
||||
try:
|
||||
user_id = user.id if user is not None else None
|
||||
llm_user_identifier = (
|
||||
user.email
|
||||
if user is not None and getattr(user, "email", None)
|
||||
else (str(user_id) if user_id else "anonymous_user")
|
||||
)
|
||||
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
@@ -299,6 +352,9 @@ def stream_chat_message_objects(
|
||||
|
||||
message_text = new_msg_req.message
|
||||
chat_session_id = new_msg_req.chat_session_id
|
||||
user_identity = LLMUserIdentity(
|
||||
user_id=llm_user_identifier, session_id=str(chat_session_id)
|
||||
)
|
||||
parent_id = new_msg_req.parent_message_id
|
||||
reference_doc_ids = new_msg_req.search_doc_ids
|
||||
retrieval_options = new_msg_req.retrieval_options
|
||||
@@ -391,19 +447,23 @@ def stream_chat_message_objects(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# There are cases where the internal search tool should be disabled
|
||||
# If the user is in a project, it should not use other sources / generic search
|
||||
# If they are in a project but using a custom agent, it should use the agent setup
|
||||
# (which means it can use search)
|
||||
# However if in a project and there are more files than can fit in the context,
|
||||
# it should use the search tool with the project filter on
|
||||
disable_internal_search = bool(
|
||||
chat_session.project_id
|
||||
and persona.id is DEFAULT_PERSONA_ID
|
||||
and (
|
||||
extracted_project_files.project_file_texts
|
||||
or not extracted_project_files.project_as_filter
|
||||
)
|
||||
# Build a mapping of tool_id to tool_name for history reconstruction
|
||||
all_tools = get_tools(db_session)
|
||||
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
|
||||
|
||||
search_tool_id = next(
|
||||
(tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
|
||||
None,
|
||||
)
|
||||
|
||||
# This may also mutate the new_msg_req.forced_tool_ids
|
||||
# This logic is specifically for projects
|
||||
search_usage_forcing_setting = _get_project_search_availability(
|
||||
project_id=chat_session.project_id,
|
||||
persona_id=persona.id,
|
||||
has_project_file_texts=bool(extracted_project_files.project_file_texts),
|
||||
forced_tool_ids=new_msg_req.forced_tool_ids,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
|
||||
emitter = get_default_emitter()
|
||||
@@ -432,7 +492,7 @@ def stream_chat_message_objects(
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
disable_internal_search=disable_internal_search,
|
||||
search_usage_forcing_setting=search_usage_forcing_setting,
|
||||
)
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
@@ -457,10 +517,6 @@ def stream_chat_message_objects(
|
||||
reserved_assistant_message_id=assistant_response.id,
|
||||
)
|
||||
|
||||
# Build a mapping of tool_id to tool_name for history reconstruction
|
||||
all_tools = get_tools(db_session)
|
||||
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
|
||||
|
||||
# Convert the chat history into a simple format that is free of any DB objects
|
||||
# and is easy to parse for the agent loop
|
||||
simple_chat_history = convert_chat_history(
|
||||
@@ -491,6 +547,13 @@ def stream_chat_message_objects(
|
||||
# Note: DB session is not thread safe but nothing else uses it and the
|
||||
# reference is passed directly so it's ok.
|
||||
if os.environ.get("ENABLE_DEEP_RESEARCH_LOOP"): # Dev only feature flag for now
|
||||
if chat_session.project_id:
|
||||
raise RuntimeError("Deep research is not supported for projects")
|
||||
|
||||
# Skip clarification if the last assistant message was a clarification
|
||||
# (user has already responded to a clarification question)
|
||||
skip_clarification = is_last_assistant_message_clarification(chat_history)
|
||||
|
||||
yield from run_chat_llm_with_state_containers(
|
||||
run_deep_research_llm_loop,
|
||||
is_connected=check_is_connected,
|
||||
@@ -502,6 +565,8 @@ def stream_chat_message_objects(
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
else:
|
||||
yield from run_chat_llm_with_state_containers(
|
||||
@@ -523,6 +588,7 @@ def stream_chat_message_objects(
|
||||
if new_msg_req.forced_tool_ids
|
||||
else None
|
||||
),
|
||||
user_identity=user_identity,
|
||||
)
|
||||
|
||||
# Determine if stopped by user
|
||||
@@ -567,13 +633,18 @@ def stream_chat_message_objects(
|
||||
tool_calls=state_container.tool_calls,
|
||||
db_session=db_session,
|
||||
assistant_message=assistant_response,
|
||||
is_clarification=state_container.is_clarification,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
error_msg = str(e)
|
||||
yield StreamingError(error=error_msg)
|
||||
yield StreamingError(
|
||||
error=error_msg,
|
||||
error_code="VALIDATION_ERROR",
|
||||
is_retryable=True,
|
||||
)
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
@@ -583,9 +654,17 @@ def stream_chat_message_objects(
|
||||
stack_trace = traceback.format_exc()
|
||||
|
||||
if isinstance(e, ToolCallException):
|
||||
yield StreamingError(error=error_msg, stack_trace=stack_trace)
|
||||
yield StreamingError(
|
||||
error=error_msg,
|
||||
stack_trace=stack_trace,
|
||||
error_code="TOOL_CALL_FAILED",
|
||||
is_retryable=True,
|
||||
details={"tool_name": e.tool_name} if e.tool_name else None,
|
||||
)
|
||||
elif llm:
|
||||
client_error_msg = litellm_exception_to_error_msg(e, llm)
|
||||
client_error_msg, error_code, is_retryable = litellm_exception_to_error_msg(
|
||||
e, llm
|
||||
)
|
||||
if llm.config.api_key and len(llm.config.api_key) > 2:
|
||||
client_error_msg = client_error_msg.replace(
|
||||
llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
@@ -594,7 +673,24 @@ def stream_chat_message_objects(
|
||||
llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
)
|
||||
|
||||
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
|
||||
yield StreamingError(
|
||||
error=client_error_msg,
|
||||
stack_trace=stack_trace,
|
||||
error_code=error_code,
|
||||
is_retryable=is_retryable,
|
||||
details={
|
||||
"model": llm.config.model_name,
|
||||
"provider": llm.config.model_provider,
|
||||
},
|
||||
)
|
||||
else:
|
||||
# LLM was never initialized - early failure
|
||||
yield StreamingError(
|
||||
error="Failed to initialize the chat. Please check your configuration and try again.",
|
||||
stack_trace=stack_trace,
|
||||
error_code="INIT_FAILED",
|
||||
is_retryable=True,
|
||||
)
|
||||
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
@@ -148,6 +148,7 @@ def save_chat_turn(
|
||||
citation_docs_info: list[CitationDocInfo],
|
||||
db_session: Session,
|
||||
assistant_message: ChatMessage,
|
||||
is_clarification: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Save a chat turn by populating the assistant_message and creating related entities.
|
||||
@@ -175,10 +176,12 @@ def save_chat_turn(
|
||||
citation_docs_info: List of citation document information for building citations mapping
|
||||
db_session: Database session for persistence
|
||||
assistant_message: The ChatMessage object to populate (should already exist in DB)
|
||||
is_clarification: Whether this assistant message is a clarification question (deep research flow)
|
||||
"""
|
||||
# 1. Update ChatMessage with message content, reasoning tokens, and token count
|
||||
assistant_message.message = message_text
|
||||
assistant_message.reasoning_tokens = reasoning_tokens
|
||||
assistant_message.is_clarification = is_clarification
|
||||
|
||||
# Calculate token count using the default tokenizer. When storing, this should not use the LLM
|
||||
# specific one so we use a system default tokenizer here.
|
||||
|
||||
@@ -7,6 +7,7 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
# Redis key prefixes for chat session stop signals
|
||||
PREFIX = "chatsessionstop"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 24 * 60 * 60 # 24 hours - defensive TTL to prevent memory leaks
|
||||
|
||||
|
||||
def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
|
||||
@@ -24,7 +25,7 @@ def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
|
||||
redis_client.delete(fence_key)
|
||||
return
|
||||
|
||||
redis_client.set(fence_key, 0)
|
||||
redis_client.set(fence_key, 0, ex=FENCE_TTL)
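A small sketch of what the added ex=FENCE_TTL buys: the stop fence now expires on its own if it is never explicitly cleared. The key shape below is illustrative; the real fence key is derived from the prefix and chat session id above.

from redis import Redis

r = Redis()
fence_key = "chatsessionstop_fence:<chat-session-uuid>"  # illustrative key shape
r.set(fence_key, 0, ex=24 * 60 * 60)   # same 24h defensive TTL as FENCE_TTL
print(r.ttl(fence_key))                 # ~86400 seconds until Redis drops the key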
|
||||
|
||||
|
||||
def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:
|
||||
|
||||
@@ -24,6 +24,12 @@ APP_PORT = 8080
|
||||
# prefix from requests directed towards the API server. In these cases, set this to `/api`
|
||||
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
|
||||
|
||||
# Whether to send user metadata (user_id/email and session_id) to the LLM provider.
|
||||
# Disabled by default.
|
||||
SEND_USER_METADATA_TO_LLM_PROVIDER = (
|
||||
os.environ.get("SEND_USER_METADATA_TO_LLM_PROVIDER", "")
|
||||
).lower() == "true"
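The flag is opt-in: anything other than the literal string "true" (case-insensitive) leaves it disabled. A quick sketch of the parsing behaviour:

import os

for raw in ("true", "True", "1", ""):
    enabled = raw.lower() == "true"
    print(raw, "->", enabled)  # only "true"/"True" yield True; "1" does not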
|
||||
|
||||
#####
|
||||
# User Facing Features Configs
|
||||
#####
|
||||
|
||||
@@ -177,6 +177,7 @@ class DocumentSource(str, Enum):
|
||||
SLAB = "slab"
|
||||
PRODUCTBOARD = "productboard"
|
||||
FILE = "file"
|
||||
CODA = "coda"
|
||||
NOTION = "notion"
|
||||
ZULIP = "zulip"
|
||||
LINEAR = "linear"
|
||||
@@ -596,6 +597,7 @@ DocumentSourceDescription: dict[DocumentSource, str] = {
|
||||
DocumentSource.SLAB: "slab data",
|
||||
DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)",
|
||||
DocumentSource.FILE: "files",
|
||||
DocumentSource.CODA: "coda - team workspace with docs, tables, and pages",
|
||||
DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \
|
||||
project management, and collaboration tools into a single, customizable platform",
|
||||
DocumentSource.ZULIP: "zulip data",
|
||||
|
||||
@@ -65,9 +65,10 @@ GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
|
||||
os.environ.get("GEN_AI_NUM_RESERVED_OUTPUT_TOKENS") or 1024
|
||||
)
|
||||
|
||||
# Typically, GenAI models nowadays are at least 4K tokens
|
||||
# Fallback token limit for models where the max context is unknown
|
||||
# Set conservatively at 32K to handle most modern models
|
||||
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
|
||||
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
|
||||
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 32000
|
||||
)
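A hedged sketch of how such a fallback is typically consumed when a model's context window is not in the known-model map; the lookup table here is a stand-in, not Onyx's actual model registry:

KNOWN_CONTEXT_WINDOWS = {"gpt-4o": 128_000}  # illustrative only

def max_context_tokens(model_name: str) -> int:
    # Unknown models fall back to the conservative 32K default configured above.
    return KNOWN_CONTEXT_WINDOWS.get(model_name, GEN_AI_MODEL_FALLBACK_MAX_TOKENS)

print(max_context_tokens("some-self-hosted-model"))  # 32000 unless overridden via env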
|
||||
|
||||
# This is used when computing how much context space is available for documents
|
||||
|
||||
@@ -97,28 +97,31 @@ class AsanaAPI:
|
||||
self, project_gid: str, start_date: str, start_seconds: int
|
||||
) -> Iterator[AsanaTask]:
|
||||
project = self.project_api.get_project(project_gid, opts={})
|
||||
if project["archived"]:
|
||||
logger.info(f"Skipping archived project: {project['name']} ({project_gid})")
|
||||
yield from []
|
||||
if not project["team"] or not project["team"]["gid"]:
|
||||
project_name = project.get("name", project_gid)
|
||||
team = project.get("team") or {}
|
||||
team_gid = team.get("gid")
|
||||
|
||||
if project.get("archived"):
|
||||
logger.info(f"Skipping archived project: {project_name} ({project_gid})")
|
||||
return
|
||||
if not team_gid:
|
||||
logger.info(
|
||||
f"Skipping project without a team: {project['name']} ({project_gid})"
|
||||
f"Skipping project without a team: {project_name} ({project_gid})"
|
||||
)
|
||||
yield from []
|
||||
if project["privacy_setting"] == "private":
|
||||
if self.team_gid and project["team"]["gid"] != self.team_gid:
|
||||
return
|
||||
if project.get("privacy_setting") == "private":
|
||||
if self.team_gid and team_gid != self.team_gid:
|
||||
logger.info(
|
||||
f"Skipping private project not in configured team: {project['name']} ({project_gid})"
|
||||
)
|
||||
yield from []
|
||||
else:
|
||||
logger.info(
|
||||
f"Processing private project in configured team: {project['name']} ({project_gid})"
|
||||
f"Skipping private project not in configured team: {project_name} ({project_gid})"
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
f"Processing private project in configured team: {project_name} ({project_gid})"
|
||||
)
|
||||
|
||||
simple_start_date = start_date.split(".")[0].split("+")[0]
|
||||
logger.info(
|
||||
f"Fetching tasks modified since {simple_start_date} for project: {project['name']} ({project_gid})"
|
||||
f"Fetching tasks modified since {simple_start_date} for project: {project_name} ({project_gid})"
|
||||
)
|
||||
|
||||
opts = {
|
||||
@@ -157,7 +160,7 @@ class AsanaAPI:
|
||||
link=data["permalink_url"],
|
||||
last_modified=datetime.fromisoformat(data["modified_at"]),
|
||||
project_gid=project_gid,
|
||||
project_name=project["name"],
|
||||
project_name=project_name,
|
||||
)
|
||||
yield task
|
||||
except Exception:
|
||||
|
||||
0
backend/onyx/connectors/coda/__init__.py
Normal file
711
backend/onyx/connectors/coda/connector.py
Normal file
@@ -0,0 +1,711 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rl_requests,
|
||||
)
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.batching import batch_generator
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
_CODA_CALL_TIMEOUT = 30
|
||||
_CODA_BASE_URL = "https://coda.io/apis/v1"
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class CodaClientRequestFailedError(ConnectionError):
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(
|
||||
f"Coda API request failed with status {status_code}: {message}"
|
||||
)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class CodaDoc(BaseModel):
|
||||
id: str
|
||||
browser_link: str
|
||||
name: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
workspace_id: str
|
||||
workspace_name: str
|
||||
folder_id: str | None
|
||||
folder_name: str | None
|
||||
|
||||
|
||||
class CodaPage(BaseModel):
|
||||
id: str
|
||||
browser_link: str
|
||||
name: str
|
||||
content_type: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
doc_id: str
|
||||
|
||||
|
||||
class CodaTable(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
browser_link: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
doc_id: str
|
||||
|
||||
|
||||
class CodaRow(BaseModel):
|
||||
id: str
|
||||
name: Optional[str] = None
|
||||
index: Optional[int] = None
|
||||
browser_link: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
values: Dict[str, Any]
|
||||
table_id: str
|
||||
doc_id: str
|
||||
|
||||
|
||||
class CodaApiClient:
|
||||
def __init__(
|
||||
self,
|
||||
bearer_token: str,
|
||||
) -> None:
|
||||
self.bearer_token = bearer_token
|
||||
self.base_url = os.environ.get("CODA_BASE_URL", _CODA_BASE_URL)
|
||||
|
||||
def get(
|
||||
self, endpoint: str, params: Optional[dict[str, str]] = None
|
||||
) -> dict[str, Any]:
|
||||
url = self._build_url(endpoint)
|
||||
headers = self._build_headers()
|
||||
|
||||
response = rl_requests.get(
|
||||
url, headers=headers, params=params, timeout=_CODA_CALL_TIMEOUT
|
||||
)
|
||||
|
||||
try:
|
||||
json = response.json()
|
||||
except Exception:
|
||||
json = {}
|
||||
|
||||
if response.status_code >= 300:
|
||||
error = response.reason
|
||||
response_error = json.get("error", {}).get("message", "")
|
||||
if response_error:
|
||||
error = response_error
|
||||
raise CodaClientRequestFailedError(error, response.status_code)
|
||||
|
||||
return json
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
return {"Authorization": f"Bearer {self.bearer_token}"}
|
||||
|
||||
def _build_url(self, endpoint: str) -> str:
|
||||
return self.base_url.rstrip("/") + "/" + endpoint.lstrip("/")
|
||||
|
||||
|
||||
class CodaConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
index_page_content: bool = True,
|
||||
workspace_id: str | None = None,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.index_page_content = index_page_content
|
||||
self.workspace_id = workspace_id
|
||||
self._coda_client: CodaApiClient | None = None
|
||||
|
||||
@property
|
||||
def coda_client(self) -> CodaApiClient:
|
||||
if self._coda_client is None:
|
||||
raise ConnectorMissingCredentialError("Coda")
|
||||
return self._coda_client
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _get_doc(self, doc_id: str) -> CodaDoc:
|
||||
"""Fetch a specific Coda document by its ID."""
|
||||
logger.debug(f"Fetching Coda doc with ID: {doc_id}")
|
||||
try:
|
||||
response = self.coda_client.get(f"docs/{doc_id}")
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
raise ConnectorValidationError(f"Failed to fetch doc: {doc_id}") from e
|
||||
else:
|
||||
raise
|
||||
|
||||
return CodaDoc(
|
||||
id=response["id"],
|
||||
browser_link=response["browserLink"],
|
||||
name=response["name"],
|
||||
created_at=response["createdAt"],
|
||||
updated_at=response["updatedAt"],
|
||||
workspace_id=response["workspace"]["id"],
|
||||
workspace_name=response["workspace"]["name"],
|
||||
folder_id=response["folder"]["id"] if response.get("folder") else None,
|
||||
folder_name=response["folder"]["name"] if response.get("folder") else None,
|
||||
)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _get_page(self, doc_id: str, page_id: str) -> CodaPage:
|
||||
"""Fetch a specific page from a Coda document."""
|
||||
logger.debug(f"Fetching Coda page with ID: {page_id}")
|
||||
try:
|
||||
response = self.coda_client.get(f"docs/{doc_id}/pages/{page_id}")
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Failed to fetch page: {page_id} from doc: {doc_id}"
|
||||
) from e
|
||||
else:
|
||||
raise
|
||||
|
||||
return CodaPage(
|
||||
id=response["id"],
|
||||
doc_id=doc_id,
|
||||
browser_link=response["browserLink"],
|
||||
name=response["name"],
|
||||
content_type=response["contentType"],
|
||||
created_at=response["createdAt"],
|
||||
updated_at=response["updatedAt"],
|
||||
)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _get_table(self, doc_id: str, table_id: str) -> CodaTable:
|
||||
"""Fetch a specific table from a Coda document."""
|
||||
logger.debug(f"Fetching Coda table with ID: {table_id}")
|
||||
try:
|
||||
response = self.coda_client.get(f"docs/{doc_id}/tables/{table_id}")
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Failed to fetch table: {table_id} from doc: {doc_id}"
|
||||
) from e
|
||||
else:
|
||||
raise
|
||||
|
||||
return CodaTable(
|
||||
id=response["id"],
|
||||
name=response["name"],
|
||||
browser_link=response["browserLink"],
|
||||
created_at=response["createdAt"],
|
||||
updated_at=response["updatedAt"],
|
||||
doc_id=doc_id,
|
||||
)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _get_row(self, doc_id: str, table_id: str, row_id: str) -> CodaRow:
|
||||
"""Fetch a specific row from a Coda table."""
|
||||
logger.debug(f"Fetching Coda row with ID: {row_id}")
|
||||
try:
|
||||
response = self.coda_client.get(
|
||||
f"docs/{doc_id}/tables/{table_id}/rows/{row_id}"
|
||||
)
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Failed to fetch row: {row_id} from table: {table_id} in doc: {doc_id}"
|
||||
) from e
|
||||
else:
|
||||
raise
|
||||
|
||||
values = {}
|
||||
for col_name, col_value in response.get("values", {}).items():
|
||||
values[col_name] = col_value
|
||||
|
||||
return CodaRow(
|
||||
id=response["id"],
|
||||
name=response.get("name"),
|
||||
index=response.get("index"),
|
||||
browser_link=response["browserLink"],
|
||||
created_at=response["createdAt"],
|
||||
updated_at=response["updatedAt"],
|
||||
values=values,
|
||||
table_id=table_id,
|
||||
doc_id=doc_id,
|
||||
)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_all_docs(
|
||||
self, endpoint: str = "docs", params: Optional[Dict[str, str]] = None
|
||||
) -> List[CodaDoc]:
|
||||
"""List all Coda documents in the workspace."""
|
||||
logger.debug("Listing documents in Coda")
|
||||
|
||||
all_docs: List[CodaDoc] = []
|
||||
next_page_token: str | None = None
|
||||
params = params or {}
|
||||
|
||||
if self.workspace_id:
|
||||
params["workspaceId"] = self.workspace_id
|
||||
|
||||
while True:
|
||||
if next_page_token:
|
||||
params["pageToken"] = next_page_token
|
||||
|
||||
try:
|
||||
response = self.coda_client.get(endpoint, params=params)
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
raise ConnectorValidationError("Failed to list docs") from e
|
||||
else:
|
||||
raise
|
||||
|
||||
items = response.get("items", [])
|
||||
|
||||
for item in items:
|
||||
doc = CodaDoc(
|
||||
id=item["id"],
|
||||
browser_link=item["browserLink"],
|
||||
name=item["name"],
|
||||
created_at=item["createdAt"],
|
||||
updated_at=item["updatedAt"],
|
||||
workspace_id=item["workspace"]["id"],
|
||||
workspace_name=item["workspace"]["name"],
|
||||
folder_id=item["folder"]["id"] if item.get("folder") else None,
|
||||
folder_name=item["folder"]["name"] if item.get("folder") else None,
|
||||
)
|
||||
all_docs.append(doc)
|
||||
|
||||
next_page_token = response.get("nextPageToken")
|
||||
if not next_page_token:
|
||||
break
|
||||
|
||||
logger.debug(f"Found {len(all_docs)} docs")
|
||||
return all_docs
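The listing methods in this connector all share the same cursor pattern: request a page, collect items, and follow nextPageToken until the API stops returning one. A condensed sketch of that loop, assuming any client with the same get(endpoint, params) contract as CodaApiClient; the helper name is illustrative:

from collections.abc import Iterator
from typing import Any, Dict, Optional

def iter_coda_items(
    client: "CodaApiClient", endpoint: str, params: Optional[Dict[str, str]] = None
) -> Iterator[dict[str, Any]]:
    params = dict(params or {})
    next_page_token: str | None = None
    while True:
        if next_page_token:
            params["pageToken"] = next_page_token
        response = client.get(endpoint, params=params)
        yield from response.get("items", [])
        next_page_token = response.get("nextPageToken")
        if not next_page_token:
            break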
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_pages_in_doc(self, doc_id: str) -> List[CodaPage]:
|
||||
"""List all pages in a Coda document."""
|
||||
logger.debug(f"Listing pages in Coda doc with ID: {doc_id}")
|
||||
|
||||
pages: List[CodaPage] = []
|
||||
endpoint = f"docs/{doc_id}/pages"
|
||||
params: Dict[str, str] = {}
|
||||
next_page_token: str | None = None
|
||||
|
||||
while True:
|
||||
if next_page_token:
|
||||
params["pageToken"] = next_page_token
|
||||
|
||||
try:
|
||||
response = self.coda_client.get(endpoint, params=params)
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Failed to list pages for doc: {doc_id}"
|
||||
) from e
|
||||
else:
|
||||
raise
|
||||
|
||||
items = response.get("items", [])
|
||||
for item in items:
|
||||
# can be removed if we don't care to skip hidden pages
|
||||
if item.get("isHidden", False):
|
||||
continue
|
||||
|
||||
pages.append(
|
||||
CodaPage(
|
||||
id=item["id"],
|
||||
browser_link=item["browserLink"],
|
||||
name=item["name"],
|
||||
content_type=item["contentType"],
|
||||
created_at=item["createdAt"],
|
||||
updated_at=item["updatedAt"],
|
||||
doc_id=doc_id,
|
||||
)
|
||||
)
|
||||
|
||||
next_page_token = response.get("nextPageToken")
|
||||
if not next_page_token:
|
||||
break
|
||||
|
||||
logger.debug(f"Found {len(pages)} pages in doc {doc_id}")
|
||||
return pages
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_page_content(self, doc_id: str, page_id: str) -> str:
|
||||
"""Fetch the content of a Coda page."""
|
||||
logger.debug(f"Fetching content for page {page_id} in doc {doc_id}")
|
||||
|
||||
content_parts = []
|
||||
next_page_token: str | None = None
|
||||
params: Dict[str, str] = {}
|
||||
|
||||
while True:
|
||||
if next_page_token:
|
||||
params["pageToken"] = next_page_token
|
||||
|
||||
try:
|
||||
response = self.coda_client.get(
|
||||
f"docs/{doc_id}/pages/{page_id}/content", params=params
|
||||
)
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
logger.debug(f"No content available for page {page_id}")
|
||||
return ""
|
||||
raise
|
||||
|
||||
items = response.get("items", [])
|
||||
|
||||
for item in items:
|
||||
item_content = item.get("itemContent", {})
|
||||
|
||||
content_text = item_content.get("content", "")
|
||||
if content_text:
|
||||
content_parts.append(content_text)
|
||||
|
||||
next_page_token = response.get("nextPageToken")
|
||||
if not next_page_token:
|
||||
break
|
||||
|
||||
return "\n\n".join(content_parts)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_tables(self, doc_id: str) -> List[CodaTable]:
|
||||
"""List all tables in a Coda document."""
|
||||
logger.debug(f"Listing tables in Coda doc with ID: {doc_id}")
|
||||
|
||||
tables: List[CodaTable] = []
|
||||
endpoint = f"docs/{doc_id}/tables"
|
||||
params: Dict[str, str] = {}
|
||||
next_page_token: str | None = None
|
||||
|
||||
while True:
|
||||
if next_page_token:
|
||||
params["pageToken"] = next_page_token
|
||||
|
||||
try:
|
||||
response = self.coda_client.get(endpoint, params=params)
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Failed to list tables for doc: {doc_id}"
|
||||
) from e
|
||||
else:
|
||||
raise
|
||||
|
||||
items = response.get("items", [])
|
||||
for item in items:
|
||||
tables.append(
|
||||
CodaTable(
|
||||
id=item["id"],
|
||||
browser_link=item["browserLink"],
|
||||
name=item["name"],
|
||||
created_at=item["createdAt"],
|
||||
updated_at=item["updatedAt"],
|
||||
doc_id=doc_id,
|
||||
)
|
||||
)
|
||||
|
||||
next_page_token = response.get("nextPageToken")
|
||||
if not next_page_token:
|
||||
break
|
||||
|
||||
logger.debug(f"Found {len(tables)} tables in doc {doc_id}")
|
||||
return tables
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_rows_and_values(self, doc_id: str, table_id: str) -> List[CodaRow]:
|
||||
"""List all rows and their values in a table."""
|
||||
logger.debug(f"Listing rows in Coda table: {table_id} in Coda doc: {doc_id}")
|
||||
|
||||
rows: List[CodaRow] = []
|
||||
endpoint = f"docs/{doc_id}/tables/{table_id}/rows"
|
||||
params: Dict[str, str] = {"valueFormat": "rich"}
|
||||
next_page_token: str | None = None
|
||||
|
||||
while True:
|
||||
if next_page_token:
|
||||
params["pageToken"] = next_page_token
|
||||
|
||||
try:
|
||||
response = self.coda_client.get(endpoint, params=params)
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Failed to list rows for table: {table_id} in doc: {doc_id}"
|
||||
) from e
|
||||
else:
|
||||
raise
|
||||
|
||||
items = response.get("items", [])
|
||||
for item in items:
|
||||
values = {}
|
||||
for col_name, col_value in item.get("values", {}).items():
|
||||
values[col_name] = col_value
|
||||
|
||||
rows.append(
|
||||
CodaRow(
|
||||
id=item["id"],
|
||||
name=item["name"],
|
||||
index=item["index"],
|
||||
browser_link=item["browserLink"],
|
||||
created_at=item["createdAt"],
|
||||
updated_at=item["updatedAt"],
|
||||
values=values,
|
||||
table_id=table_id,
|
||||
doc_id=doc_id,
|
||||
)
|
||||
)
|
||||
|
||||
next_page_token = response.get("nextPageToken")
|
||||
if not next_page_token:
|
||||
break
|
||||
|
||||
logger.debug(f"Found {len(rows)} rows in table {table_id}")
|
||||
return rows
|
||||
|
||||
def _convert_page_to_document(self, page: CodaPage, content: str = "") -> Document:
|
||||
"""Convert a page into a Document."""
|
||||
page_updated = datetime.fromisoformat(page.updated_at).astimezone(timezone.utc)
|
||||
|
||||
text_parts = [page.name, page.browser_link]
|
||||
if content:
|
||||
text_parts.append(content)
|
||||
|
||||
sections = [TextSection(link=page.browser_link, text="\n\n".join(text_parts))]
|
||||
|
||||
return Document(
|
||||
id=f"coda-page-{page.doc_id}-{page.id}",
|
||||
sections=cast(list[TextSection | ImageSection], sections),
|
||||
source=DocumentSource.CODA,
|
||||
semantic_identifier=page.name or f"Page {page.id}",
|
||||
doc_updated_at=page_updated,
|
||||
metadata={
|
||||
"browser_link": page.browser_link,
|
||||
"doc_id": page.doc_id,
|
||||
"content_type": page.content_type,
|
||||
},
|
||||
)
|
||||
|
||||
def _convert_table_with_rows_to_document(
|
||||
self, table: CodaTable, rows: List[CodaRow]
|
||||
) -> Document:
|
||||
"""Convert a table and its rows into a single Document with multiple sections (one per row)."""
|
||||
table_updated = datetime.fromisoformat(table.updated_at).astimezone(
|
||||
timezone.utc
|
||||
)
|
||||
|
||||
sections: List[TextSection] = []
|
||||
for row in rows:
|
||||
content_text = " ".join(
|
||||
str(v) if not isinstance(v, list) else " ".join(map(str, v))
|
||||
for v in row.values.values()
|
||||
)
|
||||
|
||||
row_name = row.name or f"Row {row.index or row.id}"
|
||||
text = f"{row_name}: {content_text}" if content_text else row_name
|
||||
|
||||
sections.append(TextSection(link=row.browser_link, text=text))
|
||||
|
||||
# If no rows, create a single section for the table itself
|
||||
if not sections:
|
||||
sections = [
|
||||
TextSection(link=table.browser_link, text=f"Table: {table.name}")
|
||||
]
|
||||
|
||||
return Document(
|
||||
id=f"coda-table-{table.doc_id}-{table.id}",
|
||||
sections=cast(list[TextSection | ImageSection], sections),
|
||||
source=DocumentSource.CODA,
|
||||
semantic_identifier=table.name or f"Table {table.id}",
|
||||
doc_updated_at=table_updated,
|
||||
metadata={
|
||||
"browser_link": table.browser_link,
|
||||
"doc_id": table.doc_id,
|
||||
"row_count": str(len(rows)),
|
||||
},
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load and validate Coda credentials."""
|
||||
self._coda_client = CodaApiClient(bearer_token=credentials["coda_bearer_token"])
|
||||
|
||||
try:
|
||||
self._coda_client.get("docs", params={"limit": "1"})
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 401:
|
||||
raise ConnectorMissingCredentialError("Invalid Coda API token")
|
||||
raise
|
||||
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""Load all documents from Coda workspace."""
|
||||
|
||||
def _iter_documents() -> Generator[Document, None, None]:
|
||||
docs = self._list_all_docs()
|
||||
logger.info(f"Found {len(docs)} Coda docs to process")
|
||||
|
||||
for doc in docs:
|
||||
logger.debug(f"Processing doc: {doc.name} ({doc.id})")
|
||||
|
||||
try:
|
||||
pages = self._list_pages_in_doc(doc.id)
|
||||
for page in pages:
|
||||
content = ""
|
||||
if self.index_page_content:
|
||||
try:
|
||||
content = self._fetch_page_content(doc.id, page.id)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch content for page {page.id}: {e}"
|
||||
)
|
||||
yield self._convert_page_to_document(page, content)
|
||||
except ConnectorValidationError as e:
|
||||
logger.warning(f"Failed to list pages for doc {doc.id}: {e}")
|
||||
|
||||
try:
|
||||
tables = self._list_tables(doc.id)
|
||||
for table in tables:
|
||||
try:
|
||||
rows = self._list_rows_and_values(doc.id, table.id)
|
||||
yield self._convert_table_with_rows_to_document(table, rows)
|
||||
except ConnectorValidationError as e:
|
||||
logger.warning(
|
||||
f"Failed to list rows for table {table.id}: {e}"
|
||||
)
|
||||
yield self._convert_table_with_rows_to_document(table, [])
|
||||
except ConnectorValidationError as e:
|
||||
logger.warning(f"Failed to list tables for doc {doc.id}: {e}")
|
||||
|
||||
return batch_generator(_iter_documents(), self.batch_size)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Polls the Coda API for documents updated between start and end timestamps.
|
||||
We refer to page and table update times to determine if they need to be re-indexed.
|
||||
"""
|
||||
|
||||
def _iter_documents() -> Generator[Document, None, None]:
|
||||
docs = self._list_all_docs()
|
||||
logger.info(
|
||||
f"Polling {len(docs)} Coda docs for updates between {start} and {end}"
|
||||
)
|
||||
|
||||
for doc in docs:
|
||||
try:
|
||||
pages = self._list_pages_in_doc(doc.id)
|
||||
for page in pages:
|
||||
page_timestamp = (
|
||||
datetime.fromisoformat(page.updated_at)
|
||||
.astimezone(timezone.utc)
|
||||
.timestamp()
|
||||
)
|
||||
if start < page_timestamp <= end:
|
||||
content = ""
|
||||
if self.index_page_content:
|
||||
try:
|
||||
content = self._fetch_page_content(doc.id, page.id)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch content for page {page.id}: {e}"
|
||||
)
|
||||
yield self._convert_page_to_document(page, content)
|
||||
except ConnectorValidationError as e:
|
||||
logger.warning(f"Failed to list pages for doc {doc.id}: {e}")
|
||||
|
||||
try:
|
||||
tables = self._list_tables(doc.id)
|
||||
for table in tables:
|
||||
table_timestamp = (
|
||||
datetime.fromisoformat(table.updated_at)
|
||||
.astimezone(timezone.utc)
|
||||
.timestamp()
|
||||
)
|
||||
|
||||
try:
|
||||
rows = self._list_rows_and_values(doc.id, table.id)
|
||||
|
||||
table_or_rows_updated = start < table_timestamp <= end
|
||||
if not table_or_rows_updated:
|
||||
for row in rows:
|
||||
row_timestamp = (
|
||||
datetime.fromisoformat(row.updated_at)
|
||||
.astimezone(timezone.utc)
|
||||
.timestamp()
|
||||
)
|
||||
if start < row_timestamp <= end:
|
||||
table_or_rows_updated = True
|
||||
break
|
||||
|
||||
if table_or_rows_updated:
|
||||
yield self._convert_table_with_rows_to_document(
|
||||
table, rows
|
||||
)
|
||||
|
||||
except ConnectorValidationError as e:
|
||||
logger.warning(
|
||||
f"Failed to list rows for table {table.id}: {e}"
|
||||
)
|
||||
if table_timestamp > start and table_timestamp <= end:
|
||||
yield self._convert_table_with_rows_to_document(
|
||||
table, []
|
||||
)
|
||||
|
||||
except ConnectorValidationError as e:
|
||||
logger.warning(f"Failed to list tables for doc {doc.id}: {e}")
|
||||
|
||||
return batch_generator(_iter_documents(), self.batch_size)
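The polling logic above reduces to one timestamp test repeated for pages, tables, and rows: parse the ISO-8601 updated_at, normalize to UTC, and check that it falls in the (start, end] window. Isolated as a sketch:

from datetime import datetime, timezone

def updated_in_window(updated_at: str, start: float, end: float) -> bool:
    ts = datetime.fromisoformat(updated_at).astimezone(timezone.utc).timestamp()
    return start < ts <= end

print(updated_in_window("2024-05-01T12:00:00+00:00", 1714500000, 1714600000))  # True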
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validates the Coda connector settings calling the 'whoami' endpoint."""
|
||||
try:
|
||||
response = self.coda_client.get("whoami")
|
||||
logger.info(
|
||||
f"Coda connector validated for user: {response.get('name', 'Unknown')}"
|
||||
)
|
||||
|
||||
if self.workspace_id:
|
||||
params = {"workspaceId": self.workspace_id, "limit": "1"}
|
||||
self.coda_client.get("docs", params=params)
|
||||
logger.info(f"Validated access to workspace: {self.workspace_id}")
|
||||
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Coda credential appears to be invalid or expired (HTTP 401)."
|
||||
)
|
||||
elif e.status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
"Coda workspace not found or not accessible (HTTP 404). "
|
||||
"Please verify the workspace_id is correct and shared with the integration."
|
||||
)
|
||||
elif e.status_code == 429:
|
||||
raise ConnectorValidationError(
|
||||
"Validation failed due to Coda rate-limits being exceeded (HTTP 429). "
|
||||
"Please try again later."
|
||||
)
|
||||
else:
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected Coda HTTP error (status={e.status_code}): {e}"
|
||||
)
|
||||
except Exception as exc:
|
||||
raise UnexpectedValidationError(
|
||||
f"Unexpected error during Coda settings validation: {exc}"
|
||||
)
|
||||
@@ -387,124 +387,162 @@ class ConfluenceConnector(
|
||||
attachment_docs: list[Document] = []
|
||||
page_url = ""
|
||||
|
||||
for attachment in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=attachment_query,
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
media_type: str = attachment.get("metadata", {}).get("mediaType", "")
|
||||
|
||||
# TODO(rkuo): this check is partially redundant with validate_attachment_filetype
|
||||
# and checks in convert_attachment_to_content/process_attachment
|
||||
# but doing the check here avoids an unnecessary download. Due for refactoring.
|
||||
if not self.allow_images:
|
||||
if media_type.startswith("image/"):
|
||||
logger.info(
|
||||
f"Skipping attachment because allow images is False: {attachment['title']}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not validate_attachment_filetype(
|
||||
attachment,
|
||||
try:
|
||||
for attachment in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=attachment_query,
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping attachment because it is not an accepted file type: {attachment['title']}"
|
||||
)
|
||||
continue
|
||||
media_type: str = attachment.get("metadata", {}).get("mediaType", "")
|
||||
|
||||
logger.info(
|
||||
f"Processing attachment: {attachment['title']} attached to page {page['title']}"
|
||||
)
|
||||
# Attachment document id: use the download URL for stable identity
|
||||
try:
|
||||
object_url = build_confluence_document_id(
|
||||
self.wiki_base, attachment["_links"]["download"], self.is_cloud
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Invalid attachment url for id {attachment['id']}, skipping"
|
||||
)
|
||||
logger.debug(f"Error building attachment url: {e}")
|
||||
continue
|
||||
try:
|
||||
response = convert_attachment_to_content(
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=attachment,
|
||||
page_id=page["id"],
|
||||
allow_images=self.allow_images,
|
||||
)
|
||||
if response is None:
|
||||
# TODO(rkuo): this check is partially redundant with validate_attachment_filetype
|
||||
# and checks in convert_attachment_to_content/process_attachment
|
||||
# but doing the check here avoids an unnecessary download. Due for refactoring.
|
||||
if not self.allow_images:
|
||||
if media_type.startswith("image/"):
|
||||
logger.info(
|
||||
f"Skipping attachment because allow images is False: {attachment['title']}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not validate_attachment_filetype(
|
||||
attachment,
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping attachment because it is not an accepted file type: {attachment['title']}"
|
||||
)
|
||||
continue
|
||||
|
||||
content_text, file_storage_name = response
|
||||
logger.info(
|
||||
f"Processing attachment: {attachment['title']} attached to page {page['title']}"
|
||||
)
|
||||
# Attachment document id: use the download URL for stable identity
|
||||
try:
|
||||
object_url = build_confluence_document_id(
|
||||
self.wiki_base, attachment["_links"]["download"], self.is_cloud
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Invalid attachment url for id {attachment['id']}, skipping"
|
||||
)
|
||||
logger.debug(f"Error building attachment url: {e}")
|
||||
continue
|
||||
try:
|
||||
response = convert_attachment_to_content(
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=attachment,
|
||||
page_id=page["id"],
|
||||
allow_images=self.allow_images,
|
||||
)
|
||||
if response is None:
|
||||
continue
|
||||
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
if content_text:
|
||||
sections.append(TextSection(text=content_text, link=object_url))
|
||||
elif file_storage_name:
|
||||
sections.append(
|
||||
ImageSection(link=object_url, image_file_id=file_storage_name)
|
||||
content_text, file_storage_name = response
|
||||
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
if content_text:
|
||||
sections.append(TextSection(text=content_text, link=object_url))
|
||||
elif file_storage_name:
|
||||
sections.append(
|
||||
ImageSection(
|
||||
link=object_url, image_file_id=file_storage_name
|
||||
)
|
||||
)
|
||||
|
||||
# Build attachment-specific metadata
|
||||
attachment_metadata: dict[str, str | list[str]] = {}
|
||||
if "space" in attachment:
|
||||
attachment_metadata["space"] = attachment["space"].get(
|
||||
"name", ""
|
||||
)
|
||||
labels: list[str] = []
|
||||
if "metadata" in attachment and "labels" in attachment["metadata"]:
|
||||
for label in attachment["metadata"]["labels"].get(
|
||||
"results", []
|
||||
):
|
||||
labels.append(label.get("name", ""))
|
||||
if labels:
|
||||
attachment_metadata["labels"] = labels
|
||||
page_url = page_url or build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
attachment_metadata["parent_page_id"] = page_url
|
||||
attachment_id = build_confluence_document_id(
|
||||
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
# Build attachment-specific metadata
|
||||
attachment_metadata: dict[str, str | list[str]] = {}
|
||||
if "space" in attachment:
|
||||
attachment_metadata["space"] = attachment["space"].get("name", "")
|
||||
labels: list[str] = []
|
||||
if "metadata" in attachment and "labels" in attachment["metadata"]:
|
||||
for label in attachment["metadata"]["labels"].get("results", []):
|
||||
labels.append(label.get("name", ""))
|
||||
if labels:
|
||||
attachment_metadata["labels"] = labels
|
||||
page_url = page_url or build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
attachment_metadata["parent_page_id"] = page_url
|
||||
attachment_id = build_confluence_document_id(
|
||||
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
primary_owners: list[BasicExpertInfo] | None = None
|
||||
if "version" in attachment and "by" in attachment["version"]:
|
||||
author = attachment["version"]["by"]
|
||||
display_name = author.get("displayName", "Unknown")
|
||||
email = author.get("email", "unknown@domain.invalid")
|
||||
primary_owners = [
|
||||
BasicExpertInfo(display_name=display_name, email=email)
|
||||
]
|
||||
|
||||
primary_owners: list[BasicExpertInfo] | None = None
|
||||
if "version" in attachment and "by" in attachment["version"]:
|
||||
author = attachment["version"]["by"]
|
||||
display_name = author.get("displayName", "Unknown")
|
||||
email = author.get("email", "unknown@domain.invalid")
|
||||
primary_owners = [
|
||||
BasicExpertInfo(display_name=display_name, email=email)
|
||||
]
|
||||
attachment_doc = Document(
|
||||
id=attachment_id,
|
||||
sections=sections,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=attachment.get("title", object_url),
|
||||
metadata=attachment_metadata,
|
||||
doc_updated_at=(
|
||||
datetime_from_string(attachment["version"]["when"])
|
||||
if attachment.get("version")
|
||||
and attachment["version"].get("when")
|
||||
else None
|
||||
),
|
||||
primary_owners=primary_owners,
|
||||
)
|
||||
attachment_docs.append(attachment_doc)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to extract/summarize attachment {attachment['title']}",
|
||||
exc_info=e,
|
||||
)
|
||||
if is_atlassian_date_error(e):
|
||||
# propagate error to be caught and retried
|
||||
raise
|
||||
attachment_failures.append(
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=object_url,
|
||||
document_link=object_url,
|
||||
),
|
||||
failure_message=f"Failed to extract/summarize attachment {attachment['title']} for doc {object_url}",
|
||||
exception=e,
|
||||
)
|
||||
)
|
||||
except HTTPError as e:
|
||||
# If we get a 403 after all retries, the user likely doesn't have permission
|
||||
# to access attachments on this page. Log and skip rather than failing the whole job.
|
||||
if e.response and e.response.status_code == 403:
|
||||
page_title = page.get("title", "unknown")
|
||||
page_id = page.get("id", "unknown")
|
||||
logger.warning(
|
||||
f"Permission denied (403) when fetching attachments for page '{page_title}' "
|
||||
f"(ID: {page_id}). The user may not have permission to query attachments on this page. "
|
||||
"Skipping attachments for this page."
|
||||
)
|
||||
# Build the page URL for the failure record
|
||||
try:
|
||||
page_url = build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
except Exception:
|
||||
page_url = f"page_id:{page_id}"
|
||||
|
||||
attachment_doc = Document(
|
||||
id=attachment_id,
|
||||
sections=sections,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=attachment.get("title", object_url),
|
||||
metadata=attachment_metadata,
|
||||
doc_updated_at=(
|
||||
datetime_from_string(attachment["version"]["when"])
|
||||
if attachment.get("version")
|
||||
and attachment["version"].get("when")
|
||||
else None
|
||||
),
|
||||
primary_owners=primary_owners,
|
||||
)
|
||||
attachment_docs.append(attachment_doc)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to extract/summarize attachment {attachment['title']}",
|
||||
exc_info=e,
|
||||
)
|
||||
if is_atlassian_date_error(e):
|
||||
# propagate error to be caught and retried
|
||||
raise
|
||||
attachment_failures.append(
|
||||
return [], [
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=object_url,
|
||||
document_link=object_url,
|
||||
document_id=page_id,
|
||||
document_link=page_url,
|
||||
),
|
||||
failure_message=f"Failed to extract/summarize attachment {attachment['title']} for doc {object_url}",
|
||||
failure_message=f"Permission denied (403) when fetching attachments for page '{page_title}'",
|
||||
exception=e,
|
||||
)
|
||||
)
|
||||
]
|
||||
else:
|
||||
raise
|
||||
|
||||
return attachment_docs, attachment_failures
|
||||
|
||||
|
||||
@@ -579,13 +579,18 @@ class OnyxConfluence:
|
||||
while url_suffix:
|
||||
logger.debug(f"Making confluence call to {url_suffix}")
|
||||
try:
|
||||
# Only pass params if they're not already in the URL to avoid duplicate
|
||||
# params accumulating. Confluence's _links.next already includes these.
|
||||
params = {}
|
||||
if "body-format=" not in url_suffix:
|
||||
params["body-format"] = "atlas_doc_format"
|
||||
if "expand=" not in url_suffix:
|
||||
params["expand"] = "body.atlas_doc_format"
|
||||
|
||||
raw_response = self.get(
|
||||
path=url_suffix,
|
||||
advanced_mode=True,
|
||||
params={
|
||||
"body-format": "atlas_doc_format",
|
||||
"expand": "body.atlas_doc_format",
|
||||
},
|
||||
params=params,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in confluence call to {url_suffix}")
|
||||
|
||||
@@ -26,7 +26,6 @@ from onyx.utils.logger import setup_logger
|
||||
HUBSPOT_BASE_URL = "https://app.hubspot.com"
|
||||
HUBSPOT_API_URL = "https://api.hubapi.com/integrations/v1/me"
|
||||
|
||||
# Available HubSpot object types
|
||||
AVAILABLE_OBJECT_TYPES = {"tickets", "companies", "deals", "contacts"}
|
||||
|
||||
HUBSPOT_PAGE_SIZE = 100
|
||||
|
||||
@@ -68,6 +68,10 @@ CONNECTOR_CLASS_MAP = {
|
||||
module_path="onyx.connectors.slab.connector",
|
||||
class_name="SlabConnector",
|
||||
),
|
||||
DocumentSource.CODA: ConnectorMapping(
|
||||
module_path="onyx.connectors.coda.connector",
|
||||
class_name="CodaConnector",
|
||||
),
|
||||
DocumentSource.NOTION: ConnectorMapping(
|
||||
module_path="onyx.connectors.notion.connector",
|
||||
class_name="NotionConnector",
|
||||
|
||||
@@ -99,7 +99,9 @@ DEFAULT_HEADERS = {
|
||||
"image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7"
|
||||
),
|
||||
"Accept-Language": "en-US,en;q=0.9",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
# Brotli decoding has been flaky in brotlicffi/httpx for certain chunked responses;
|
||||
# stick to gzip/deflate to keep connectivity checks stable.
|
||||
"Accept-Encoding": "gzip, deflate",
|
||||
"Connection": "keep-alive",
|
||||
"Upgrade-Insecure-Requests": "1",
|
||||
"Sec-Fetch-Dest": "document",
|
||||
|
||||
@@ -20,6 +20,11 @@ class OptionalSearchSetting(str, Enum):
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class QueryType(str, Enum):
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
|
||||
|
||||
class SearchType(str, Enum):
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
An explanation of how the history of messages, tool calls, and docs are stored in the database:
|
||||
|
||||
Messages are grouped by a chat session; a tree structure is used to allow edits and for the
|
||||
user to switch between branches. Each ChatMessage is either a user message of an assistant message.
|
||||
user to switch between branches. Each ChatMessage is either a user message or an assistant message.
|
||||
It should always alternate between the two. System messages, custom agent prompt injections, and
|
||||
reminder messages are injected dynamically after the chat session is loaded into memory. The user
|
||||
and assistant messages are stored in pairs, though it is ok if the user message is stored and the
|
||||
|
||||
@@ -2141,6 +2141,8 @@ class ChatMessage(Base):
|
||||
time_sent: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
# True if this assistant message is a clarification question (deep research flow)
|
||||
is_clarification: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# Relationships
|
||||
chat_session: Mapped[ChatSession] = relationship("ChatSession")
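A new boolean column like this typically ships with a schema migration. A hypothetical Alembic sketch; the "chat_message" table name and the server default are assumptions, not taken from the diff:

import sqlalchemy as sa
from alembic import op

def upgrade() -> None:
    op.add_column(
        "chat_message",
        sa.Column(
            "is_clarification", sa.Boolean(), nullable=False, server_default=sa.false()
        ),
    )

def downgrade() -> None:
    op.drop_column("chat_message", "is_clarification")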
|
||||
|
||||
@@ -1,16 +1,47 @@
|
||||
# TODO: Notes for potential extensions and future improvements:
|
||||
# 1. Allow tools that aren't search specific tools
|
||||
# 2. Use user provided custom prompts
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.llm_loop import construct_message_history
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.models import ToolChoiceOptions
|
||||
from onyx.llm.utils import model_is_reasoning_model
|
||||
from onyx.prompts.deep_research.orchestration_layer import CLARIFICATION_PROMPT
|
||||
from onyx.prompts.deep_research.orchestration_layer import ORCHESTRATOR_PROMPT
|
||||
from onyx.prompts.deep_research.orchestration_layer import ORCHESTRATOR_PROMPT_REASONING
|
||||
from onyx.prompts.deep_research.orchestration_layer import RESEARCH_PLAN_PROMPT
|
||||
from onyx.prompts.prompt_utils import get_current_llm_day_time
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import DeepResearchPlanDelta
|
||||
from onyx.server.query_and_chat.streaming_models import DeepResearchPlanStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
MAX_USER_MESSAGES_FOR_CONTEXT = 5
|
||||
MAX_ORCHESTRATOR_CYCLES = 8
|
||||
|
||||
|
||||
def run_deep_research_llm_loop(
|
||||
emitter: Emitter,
|
||||
@@ -21,8 +52,203 @@ def run_deep_research_llm_loop(
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
skip_clarification: bool = False,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
) -> None:
|
||||
# Here for lazy load LiteLLM
|
||||
from onyx.llm.litellm_singleton.config import initialize_litellm
|
||||
|
||||
# An approximate limit. In extreme cases it may still fail but this should allow deep research
|
||||
# to work in most cases.
|
||||
if llm.config.max_input_tokens < 25000:
|
||||
raise RuntimeError(
|
||||
"Cannot run Deep Research with an LLM that has less than 25,000 max input tokens"
|
||||
)
|
||||
|
||||
initialize_litellm()
|
||||
|
||||
available_tokens = llm.config.max_input_tokens
|
||||
|
||||
llm_step_result: LlmStepResult | None = None
|
||||
|
||||
# Filter tools to only allow web search, internal search, and open URL
|
||||
allowed_tool_names = {SearchTool.NAME, WebSearchTool.NAME, OpenURLTool.NAME}
|
||||
[tool for tool in tools if tool.name in allowed_tool_names]
|
||||
|
||||
#########################################################
|
||||
# CLARIFICATION STEP (optional)
|
||||
#########################################################
|
||||
if not skip_clarification:
|
||||
clarification_prompt = CLARIFICATION_PROMPT.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False)
|
||||
)
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=clarification_prompt,
|
||||
token_count=300, # Skips the exact token count but has enough leeway
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
)
|
||||
|
||||
step_generator = run_llm_step(
|
||||
history=truncated_message_history,
|
||||
tool_definitions=get_clarification_tool_definitions(),
|
||||
tool_choice=ToolChoiceOptions.AUTO,
|
||||
llm=llm,
|
||||
turn_index=0,
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=DynamicCitationProcessor(),
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
|
||||
# Consume the generator, emitting packets and capturing the final result
|
||||
while True:
|
||||
try:
|
||||
packet = next(step_generator)
|
||||
emitter.emit(packet)
|
||||
except StopIteration as e:
|
||||
llm_step_result, _ = e.value
|
||||
break
|
||||
|
||||
# Type narrowing: generator always returns a result, so this can't be None
|
||||
llm_step_result = cast(LlmStepResult, llm_step_result)
|
||||
|
||||
if not llm_step_result.tool_calls:
|
||||
# Mark this turn as a clarification question
|
||||
state_container.set_is_clarification(True)
|
||||
|
||||
emitter.emit(Packet(turn_index=0, obj=OverallStop(type="stop")))
|
||||
|
||||
# If a clarification is asked, we need to end this turn and wait on user input
|
||||
return
|
||||
|
||||
#########################################################
|
||||
# RESEARCH PLAN STEP
|
||||
#########################################################
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=RESEARCH_PLAN_PROMPT.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False)
|
||||
),
|
||||
token_count=300,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
)
|
||||
|
||||
research_plan_generator = run_llm_step(
|
||||
history=truncated_message_history,
|
||||
tool_definitions=[],
|
||||
tool_choice=ToolChoiceOptions.NONE,
|
||||
llm=llm,
|
||||
turn_index=0,
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=DynamicCitationProcessor(),
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
packet = next(research_plan_generator)
|
||||
# Translate AgentResponseStart/Delta packets to DeepResearchPlanStart/Delta
|
||||
if isinstance(packet.obj, AgentResponseStart):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
turn_index=packet.turn_index,
|
||||
obj=DeepResearchPlanStart(),
|
||||
)
|
||||
)
|
||||
elif isinstance(packet.obj, AgentResponseDelta):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
turn_index=packet.turn_index,
|
||||
obj=DeepResearchPlanDelta(content=packet.obj.content),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Pass through other packet types (e.g., ReasoningStart, ReasoningDelta, etc.)
|
||||
emitter.emit(packet)
|
||||
except StopIteration as e:
|
||||
llm_step_result, _ = e.value
|
||||
break
|
||||
llm_step_result = cast(LlmStepResult, llm_step_result)
|
||||
|
||||
research_plan = llm_step_result.answer
|
||||
|
||||
#########################################################
|
||||
# RESEARCH EXECUTION STEP
|
||||
#########################################################
|
||||
is_reasoning_model = model_is_reasoning_model(
|
||||
llm.config.model_name, llm.config.model_provider
|
||||
)
|
||||
|
||||
orchestrator_prompt_template = (
|
||||
ORCHESTRATOR_PROMPT if not is_reasoning_model else ORCHESTRATOR_PROMPT_REASONING
|
||||
)
|
||||
|
||||
token_count_prompt = orchestrator_prompt_template.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
current_cycle_count=1,
|
||||
max_cycles=MAX_ORCHESTRATOR_CYCLES,
|
||||
research_plan=research_plan,
|
||||
)
|
||||
orchestration_tokens = token_counter(token_count_prompt)
|
||||
|
||||
for cycle in range(MAX_ORCHESTRATOR_CYCLES):
|
||||
orchestrator_prompt = orchestrator_prompt_template.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
current_cycle_count=cycle,
|
||||
max_cycles=MAX_ORCHESTRATOR_CYCLES,
|
||||
research_plan=research_plan,
|
||||
)
|
||||
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=orchestrator_prompt,
|
||||
token_count=orchestration_tokens,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
)
|
||||
|
||||
research_plan_generator = run_llm_step(
|
||||
history=truncated_message_history,
|
||||
tool_definitions=[],
|
||||
tool_choice=ToolChoiceOptions.AUTO,
|
||||
llm=llm,
|
||||
turn_index=cycle,
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=DynamicCitationProcessor(),
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
|
||||
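The clarification, plan, and orchestration steps above all drain `run_llm_step` the same way: iterate with `next()`, emit each packet, and recover the step's return value from `StopIteration.value`. A self-contained sketch of that Python idiom, with illustrative names only:

```python
# Self-contained illustration of the generator-consumption pattern used above:
# a generator yields streaming packets and *returns* a final result, which the
# caller recovers from StopIteration.value. Names are illustrative only.
from collections.abc import Generator


def fake_llm_step() -> Generator[str, None, tuple[str, int]]:
    yield "packet-1"
    yield "packet-2"
    return ("final answer", 42)  # becomes StopIteration.value


step = fake_llm_step()
result: tuple[str, int] | None = None
while True:
    try:
        packet = next(step)
        print("emit:", packet)  # stands in for emitter.emit(packet)
    except StopIteration as e:
        result = e.value
        break

assert result == ("final answer", 42)
```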
backend/onyx/deep_research/dr_mock_tools.py (new file, 18 lines)
@@ -0,0 +1,18 @@
GENERATE_PLAN_TOOL_NAME = "generate_plan"


def get_clarification_tool_definitions() -> list[dict]:
    return [
        {
            "type": "function",
            "function": {
                "name": GENERATE_PLAN_TOOL_NAME,
                "description": "No clarification needed, generate a research plan for the user's query.",
                "parameters": {
                    "type": "object",
                    "properties": {},
                    "required": [],
                },
            },
        }
    ]
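The new file above defines a single OpenAI-style function tool with no parameters. A rough sketch of how such a definition is typically attached to a chat-completions request, assuming an OpenAI-compatible payload (this is illustrative, not Onyx's actual LLM call path):

```python
# Rough sketch: how a tool definition like the one above is typically attached
# to an OpenAI-style chat-completions request. The payload shape is
# illustrative; it is not Onyx's actual LLM call path.
tools = get_clarification_tool_definitions()  # defined in the new file above

request_payload = {
    "model": "gpt-4o",
    "messages": [
        {"role": "system", "content": "Ask a clarifying question if needed."},
        {"role": "user", "content": "Research the history of vector databases."},
    ],
    "tools": tools,
    # model may answer in text (a clarification) or call generate_plan
    "tool_choice": "auto",
}

# If the response contains a tool call named GENERATE_PLAN_TOOL_NAME, the
# orchestration code above skips clarification and moves on to planning.
```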
backend/onyx/document_index/interfaces_new.py (new file, 325 lines)
@@ -0,0 +1,325 @@
|
||||
import abc
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.context.search.enums import QueryType
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
# NOTE: "Document" in the naming convention is used to refer to the entire document as represented in Onyx.
|
||||
# What is actually stored in the index is the document chunks. By the terminology of most search engines / vector
|
||||
# databases, the individual objects stored are called documents, but in this case it refers to a chunk.
|
||||
|
||||
# Outside of searching and update capabilities, the document index must also implement the ability to port all of
|
||||
# the documents over to a secondary index. This allows for embedding models to be updated and for porting documents
|
||||
# to happen in the background while the primary index still serves the main traffic.
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Main interfaces - these are what you should inherit from
|
||||
"DocumentIndex",
|
||||
# Data models - used in method signatures
|
||||
"DocumentInsertionRecord",
|
||||
"DocumentSectionRequest",
|
||||
"IndexingMetadata",
|
||||
"MetadataUpdateRequest",
|
||||
# Capability mixins - for custom compositions or type checking
|
||||
"SchemaVerifiable",
|
||||
"Indexable",
|
||||
"Deletable",
|
||||
"Updatable",
|
||||
"IdRetrievalCapable",
|
||||
"HybridCapable",
|
||||
"RandomCapable",
|
||||
]
|
||||
|
||||
|
||||
class DocumentInsertionRecord(BaseModel):
|
||||
"""
|
||||
Result of indexing a document
|
||||
"""
|
||||
|
||||
model_config = {"frozen": True}
|
||||
|
||||
document_id: str
|
||||
already_existed: bool
|
||||
|
||||
|
||||
class DocumentSectionRequest(BaseModel):
|
||||
"""
|
||||
Request for a document section or whole document
|
||||
If no min_chunk_ind is provided it should start at the beginning of the document
|
||||
If no max_chunk_ind is provided it should go to the end of the document
|
||||
"""
|
||||
|
||||
model_config = {"frozen": True}
|
||||
|
||||
document_id: str
|
||||
min_chunk_ind: int | None = None
|
||||
max_chunk_ind: int | None = None
|
||||
|
||||
|
||||
class IndexingMetadata(BaseModel):
|
||||
"""
|
||||
Information about chunk counts for efficient cleaning / updating of document chunks. A common pattern to ensure
|
||||
that no chunks are left over is to delete all of the chunks for a document and then re-index the document. This
|
||||
information allows us to only delete the extra "tail" chunks when the document has gotten shorter.
|
||||
"""
|
||||
|
||||
# The tuple is (old_chunk_cnt, new_chunk_cnt)
|
||||
doc_id_to_chunk_cnt_diff: dict[str, tuple[int, int]]
|
||||
|
||||
|
||||
class MetadataUpdateRequest(BaseModel):
|
||||
"""
|
||||
Updates to the documents that can happen without there being an update to the contents of the document.
|
||||
"""
|
||||
|
||||
document_ids: list[str]
|
||||
# Passed in to help with potential optimizations of the implementation
|
||||
doc_id_to_chunk_cnt: dict[str, int]
|
||||
# For the ones that are None, there is no update required to that field
|
||||
access: DocumentAccess | None = None
|
||||
document_sets: set[str] | None = None
|
||||
boost: float | None = None
|
||||
hidden: bool | None = None
|
||||
secondary_index_updated: bool | None = None
|
||||
project_ids: set[int] | None = None
|
||||
|
||||
|
||||
class SchemaVerifiable(abc.ABC):
|
||||
"""
|
||||
Class must implement document index schema verification. For example, verify that all of the
|
||||
necessary attributes for indexing, querying, filtering, and fields to return from search are
|
||||
all valid in the schema.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
tenant_id: int | None,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.index_name = index_name
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
@abc.abstractmethod
|
||||
def verify_and_create_index_if_necessary(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
) -> None:
|
||||
"""
|
||||
Verify that the document index exists and is consistent with the expectations in the code. For certain search
|
||||
engines, the schema needs to be created before indexing can happen. This call should create the schema if it
|
||||
does not exist.
|
||||
|
||||
Parameters:
|
||||
- embedding_dim: Vector dimensionality for the vector similarity part of the search
|
||||
- embedding_precision: Precision of the vector similarity part of the search
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Indexable(abc.ABC):
|
||||
"""
|
||||
Class must implement the ability to index document chunks
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self,
|
||||
chunks: Iterator[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
"""
|
||||
Takes a list of document chunks and indexes them in the document index. This is often a batch operation
|
||||
including chunks from multiple documents.
|
||||
|
||||
NOTE: When a document is reindexed/updated here and has gotten shorter, it is important to delete the extra
|
||||
chunks at the end to ensure there are no stale chunks in the index.
|
||||
|
||||
NOTE: The chunks of a document are never separated into separate index() calls. So there is
|
||||
no worry of receiving the first 0 through n chunks in one index call and the next n through
|
||||
m chunks of a document in the next index call.
|
||||
|
||||
Parameters:
|
||||
- chunks: Document chunks with all of the information needed for indexing to the document index.
|
||||
- indexing_metadata: Information about chunk counts for efficient cleaning / updating
|
||||
|
||||
Returns:
|
||||
List of document ids which map to unique documents and are used for deduping chunks
|
||||
when updating, as well as if the document is newly indexed or already existed and
|
||||
just updated
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Deletable(abc.ABC):
|
||||
"""
|
||||
Class must implement the ability to delete document by a given unique document id. Note that the document id is the
|
||||
unique identifier for the document as represented in Onyx, not in the document index.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(
|
||||
self,
|
||||
db_doc_id: str,
|
||||
*,
|
||||
# Passed in in case it helps the efficiency of the delete implementation
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
"""
|
||||
Given a single document, hard delete all of the chunks for the document from the document index
|
||||
|
||||
Parameters:
|
||||
- doc_id: document id as represented in Onyx
|
||||
- chunk_count: number of chunks in the document
|
||||
|
||||
Returns:
|
||||
number of chunks deleted
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Updatable(abc.ABC):
|
||||
"""
|
||||
Class must implement the ability to update certain attributes of a document without needing to
|
||||
update all of the fields. Specifically, needs to be able to update:
|
||||
- Access Control List
|
||||
- Document-set membership
|
||||
- Boost value (learning from feedback mechanism)
|
||||
- Whether the document is hidden or not, hidden documents are not returned from search
|
||||
- Which Projects the document is a part of
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, update_requests: list[MetadataUpdateRequest]) -> None:
|
||||
"""
|
||||
Updates some set of chunks. The document and fields to update are specified in the update
|
||||
requests. Each update request in the list applies its changes to a list of document ids.
|
||||
None values mean that the field does not need an update.
|
||||
|
||||
Parameters:
|
||||
- update_requests: for a list of document ids in the update request, apply the same updates
|
||||
to all of the documents with those ids. This is for bulk handling efficiency. Many
|
||||
updates are done at the connector level which have many documents for the connector
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IdRetrievalCapable(abc.ABC):
|
||||
"""
|
||||
Class must implement the ability to retrieve either:
|
||||
- All of the chunks of a document IN ORDER given a document id. Caller assumes it to be in order.
|
||||
- A specific section (continuous set of chunks) for some document.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def id_based_retrieval(
|
||||
self,
|
||||
chunk_requests: list[DocumentSectionRequest],
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
Fetch chunk(s) based on document id
|
||||
|
||||
NOTE: This is used to reconstruct a full document or an extended (multi-chunk) section
|
||||
of a document. Downstream currently assumes that the chunking does not introduce overlaps
|
||||
between the chunks. If there are overlaps for the chunks, then the reconstructed document
|
||||
or extended section will have duplicate segments.
|
||||
|
||||
NOTE: This should be used after a search call to get more context around returned chunks.
|
||||
There is no filters here since the calling code should not be calling this on arbitrary
|
||||
documents.
|
||||
|
||||
Parameters:
|
||||
- chunk_requests: requests containing the document id and the chunk range to retrieve
|
||||
|
||||
Returns:
|
||||
list of sections from the documents specified
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class HybridCapable(abc.ABC):
|
||||
"""
|
||||
Class must implement hybrid (keyword + vector) search functionality
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def hybrid_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
query_embedding: Embedding,
|
||||
final_keywords: list[str] | None,
|
||||
query_type: QueryType,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
Run hybrid search and return a list of inference chunks.
|
||||
|
||||
Parameters:
|
||||
- query: unmodified user query. This may be needed for getting the matching highlighted
|
||||
keywords or for logging purposes
|
||||
- query_embedding: vector representation of the query, must be of the correct
|
||||
dimensionality for the primary index
|
||||
- final_keywords: Final keywords to be used from the query, defaults to query if not set
|
||||
- query_type: Semantic or keyword type query, may use different scoring logic for each
|
||||
- filters: Filters for things like permissions, source type, time, etc.
|
||||
- num_to_retrieve: number of highest matching chunks to return
|
||||
- offset: number of highest matching chunks to skip (kind of like pagination)
|
||||
|
||||
Returns:
|
||||
Score ranked (highest first) list of highest matching chunks
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RandomCapable(abc.ABC):
|
||||
"""Class must implement random document retrieval capability.
|
||||
This currently is just used for porting the documents to a secondary index."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def random_retrieval(
|
||||
self,
|
||||
filters: IndexFilters | None = None,
|
||||
num_to_retrieve: int = 100,
|
||||
dirty: bool | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Retrieve random chunks matching the filters"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DocumentIndex(
|
||||
SchemaVerifiable,
|
||||
Indexable,
|
||||
Updatable,
|
||||
Deletable,
|
||||
HybridCapable,
|
||||
IdRetrievalCapable,
|
||||
RandomCapable,
|
||||
abc.ABC,
|
||||
):
|
||||
"""
|
||||
A valid document index that can plug into all Onyx flows must implement all of these
|
||||
functionalities.
|
||||
|
||||
As a high level summary, document indices need to be able to
|
||||
- Verify the schema definition is valid
|
||||
- Index new documents
|
||||
- Update specific attributes of existing documents
|
||||
- Delete documents
|
||||
- Run hybrid search
|
||||
- Retrieve document or sections of documents based on document id
|
||||
- Retrieve sets of random documents
|
||||
"""
|
||||
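To make the composition above concrete, a backend adopts the new interface by subclassing `DocumentIndex` and implementing every abstract method. A skeletal sketch, assuming the names defined in interfaces_new.py above; the bodies are placeholders rather than a real search-engine backend:

```python
# Skeletal sketch of an implementer of the interface above. It assumes the
# names defined in interfaces_new.py; method bodies are placeholders only.
class MyDocumentIndex(DocumentIndex):
    def verify_and_create_index_if_necessary(
        self, embedding_dim: int, embedding_precision: EmbeddingPrecision
    ) -> None:
        ...  # create / validate the schema in the backing search engine

    def index(
        self,
        chunks: Iterator[DocMetadataAwareIndexChunk],
        indexing_metadata: IndexingMetadata,
    ) -> set[DocumentInsertionRecord]:
        ...  # write chunks, trimming stale tail chunks for shortened docs

    def delete(self, db_doc_id: str, *, chunk_count: int | None) -> int:
        ...  # hard delete all chunks of one document, return count deleted

    def update(self, update_requests: list[MetadataUpdateRequest]) -> None:
        ...  # bulk metadata-only updates (ACL, doc sets, boost, hidden, projects)

    def id_based_retrieval(
        self, chunk_requests: list[DocumentSectionRequest]
    ) -> list[InferenceChunk]:
        ...  # return requested chunk ranges in order

    def hybrid_retrieval(
        self,
        query: str,
        query_embedding: Embedding,
        final_keywords: list[str] | None,
        query_type: QueryType,
        filters: IndexFilters,
        num_to_retrieve: int,
        offset: int = 0,
    ) -> list[InferenceChunk]:
        ...  # combined keyword + vector search, score-ranked

    def random_retrieval(
        self,
        filters: IndexFilters | None = None,
        num_to_retrieve: int = 100,
        dirty: bool | None = None,
    ) -> list[InferenceChunk]:
        ...  # used when porting documents to a secondary index
```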
@@ -25,17 +25,17 @@ class SlackEntities(BaseModel):

# Direct message filtering
include_dm: bool = Field(
default=False,
default=True,
description="Include user direct messages in search results",
)
include_group_dm: bool = Field(
default=False,
default=True,
description="Include group direct messages (multi-person DMs) in search results",
)

# Private channel filtering
include_private_channels: bool = Field(
default=False,
default=True,
description="Include private channels in search results (user must have access)",
)

@@ -298,17 +298,17 @@ def verify_user_files(

for file_descriptor in user_files:
# Check if this file descriptor has a user_file_id
if "user_file_id" in file_descriptor and file_descriptor["user_file_id"]:
if file_descriptor.get("user_file_id"):
try:
user_file_ids.append(UUID(file_descriptor["user_file_id"]))
except (ValueError, TypeError):
logger.warning(
f"Invalid user_file_id in file descriptor: {file_descriptor.get('user_file_id')}"
f"Invalid user_file_id in file descriptor: {file_descriptor['user_file_id']}"
)
continue
else:
# This is a project file - use the 'id' field which is the file_id
if "id" in file_descriptor and file_descriptor["id"]:
if file_descriptor.get("id"):
project_file_ids.append(file_descriptor["id"])

# Verify user files (existing logic)

@@ -80,7 +80,11 @@ class PgRedisKVStore(KeyValueStore):
value = None

try:
self.redis_client.set(REDIS_KEY_PREFIX + key, json.dumps(value))
self.redis_client.set(
REDIS_KEY_PREFIX + key,
json.dumps(value),
ex=KV_REDIS_KEY_EXPIRATION,
)
except Exception as e:
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")

@@ -9,12 +9,14 @@ from typing import Union
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
|
||||
from onyx.configs.app_configs import SEND_USER_METADATA_TO_LLM_PROVIDER
|
||||
from onyx.configs.chat_configs import QA_TIMEOUT
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.configs.model_configs import LITELLM_EXTRA_BODY
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ReasoningEffort
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.llm_provider_options import AZURE_PROVIDER_NAME
|
||||
@@ -41,6 +43,7 @@ if TYPE_CHECKING:
|
||||
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
|
||||
LEGACY_MAX_TOKENS_KWARG = "max_tokens"
|
||||
STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens"
|
||||
MAX_LITELLM_USER_ID_LENGTH = 64
|
||||
|
||||
|
||||
class LLMTimeoutError(Exception):
|
||||
@@ -70,6 +73,17 @@ def _prompt_as_json(prompt: LanguageModelInput) -> JSON_ro:
|
||||
return cast(JSON_ro, _prompt_to_dicts(prompt))
|
||||
|
||||
|
||||
def _truncate_litellm_user_id(user_id: str) -> str:
|
||||
if len(user_id) <= MAX_LITELLM_USER_ID_LENGTH:
|
||||
return user_id
|
||||
logger.warning(
|
||||
"LLM user id exceeds %d chars (len=%d); truncating for provider compatibility.",
|
||||
MAX_LITELLM_USER_ID_LENGTH,
|
||||
len(user_id),
|
||||
)
|
||||
return user_id[:MAX_LITELLM_USER_ID_LENGTH]
|
||||
|
||||
|
||||
class LitellmLLM(LLM):
|
||||
"""Uses Litellm library to allow easy configuration to use a multitude of LLMs
|
||||
See https://python.langchain.com/docs/integrations/chat/litellm"""
|
||||
@@ -233,6 +247,7 @@ class LitellmLLM(LLM):
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
) -> Union["ModelResponse", "CustomStreamWrapper"]:
|
||||
self._record_call(prompt)
|
||||
from onyx.llm.litellm_singleton import litellm
|
||||
@@ -251,6 +266,29 @@ class LitellmLLM(LLM):
|
||||
else:
|
||||
model_provider = self.config.model_provider
|
||||
|
||||
completion_kwargs: dict[str, Any] = self._model_kwargs
|
||||
if SEND_USER_METADATA_TO_LLM_PROVIDER and user_identity:
|
||||
completion_kwargs = dict(self._model_kwargs)
|
||||
|
||||
if user_identity.user_id:
|
||||
completion_kwargs["user"] = _truncate_litellm_user_id(
|
||||
user_identity.user_id
|
||||
)
|
||||
|
||||
if user_identity.session_id:
|
||||
existing_metadata = completion_kwargs.get("metadata")
|
||||
metadata: dict[str, Any] | None
|
||||
if existing_metadata is None:
|
||||
metadata = {}
|
||||
elif isinstance(existing_metadata, dict):
|
||||
metadata = dict(existing_metadata)
|
||||
else:
|
||||
metadata = None
|
||||
|
||||
if metadata is not None:
|
||||
metadata["session_id"] = user_identity.session_id
|
||||
completion_kwargs["metadata"] = metadata
|
||||
|
||||
try:
|
||||
return litellm.completion(
|
||||
mock_response=MOCK_LLM_RESPONSE,
|
||||
@@ -324,7 +362,7 @@ class LitellmLLM(LLM):
|
||||
else {}
|
||||
),
|
||||
**({self._max_token_param: max_tokens} if max_tokens else {}),
|
||||
**self._model_kwargs,
|
||||
**completion_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -367,6 +405,7 @@ class LitellmLLM(LLM):
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
reasoning_effort: ReasoningEffort | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
) -> ModelResponse:
|
||||
from litellm import ModelResponse as LiteLLMModelResponse
|
||||
|
||||
@@ -384,6 +423,7 @@ class LitellmLLM(LLM):
|
||||
max_tokens=max_tokens,
|
||||
parallel_tool_calls=True,
|
||||
reasoning_effort=reasoning_effort,
|
||||
user_identity=user_identity,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -398,6 +438,7 @@ class LitellmLLM(LLM):
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
reasoning_effort: ReasoningEffort | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
) -> Iterator[ModelResponseStream]:
|
||||
from litellm import CustomStreamWrapper as LiteLLMCustomStreamWrapper
|
||||
from onyx.llm.model_response import from_litellm_model_response_stream
|
||||
@@ -414,6 +455,7 @@ class LitellmLLM(LLM):
|
||||
max_tokens=max_tokens,
|
||||
parallel_tool_calls=True,
|
||||
reasoning_effort=reasoning_effort,
|
||||
user_identity=user_identity,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -14,6 +14,11 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class LLMUserIdentity(BaseModel):
|
||||
user_id: str | None = None
|
||||
session_id: str | None = None
|
||||
|
||||
|
||||
class LLMConfig(BaseModel):
|
||||
model_provider: str
|
||||
model_name: str
|
||||
@@ -44,6 +49,7 @@ class LLM(abc.ABC):
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
reasoning_effort: ReasoningEffort | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
) -> "ModelResponse":
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -56,5 +62,6 @@ class LLM(abc.ABC):
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
reasoning_effort: ReasoningEffort | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
) -> Iterator[ModelResponseStream]:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -606,6 +606,56 @@ def _patch_openai_responses_transform_response() -> None:
|
||||
LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_openai_responses_tool_content_type() -> None:
|
||||
"""
|
||||
Patches LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text
|
||||
to use 'input_text' type for tool messages instead of 'output_text'.
|
||||
|
||||
The OpenAI Responses API only accepts 'input_text', 'input_image', and 'input_file'
|
||||
in the function_call_output.output array. The default litellm implementation
|
||||
incorrectly uses 'output_text' for tool messages, causing 400 Bad Request errors.
|
||||
|
||||
See: https://github.com/BerriAI/litellm/issues/17507
|
||||
|
||||
This should be removed once litellm releases a fix for this issue.
|
||||
"""
|
||||
original_method = (
|
||||
LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text
|
||||
)
|
||||
|
||||
if (
|
||||
getattr(
|
||||
original_method,
|
||||
"__name__",
|
||||
"",
|
||||
)
|
||||
== "_patched_convert_content_str_to_input_text"
|
||||
):
|
||||
return
|
||||
|
||||
def _patched_convert_content_str_to_input_text(
|
||||
self: Any, content: str, role: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert string content to the appropriate Responses API format.
|
||||
|
||||
For user, system, and tool messages, use 'input_text' type.
|
||||
For assistant messages, use 'output_text' type.
|
||||
|
||||
Tool messages go into function_call_output.output, which only accepts
|
||||
'input_text', 'input_image', and 'input_file' types.
|
||||
"""
|
||||
if role in ("user", "system", "tool"):
|
||||
return {"type": "input_text", "text": content}
|
||||
else:
|
||||
return {"type": "output_text", "text": content}
|
||||
|
||||
_patched_convert_content_str_to_input_text.__name__ = (
|
||||
"_patched_convert_content_str_to_input_text"
|
||||
)
|
||||
LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text = _patched_convert_content_str_to_input_text # type: ignore[method-assign]
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -615,11 +665,13 @@ def apply_monkey_patches() -> None:
|
||||
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
|
||||
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
|
||||
- Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
|
||||
- Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
|
||||
"""
|
||||
_patch_ollama_transform_request()
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_chunk_parser()
|
||||
_patch_openai_responses_transform_response()
|
||||
_patch_openai_responses_tool_content_type()
|
||||
|
||||
|
||||
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
|
||||
|
||||
@@ -56,6 +56,15 @@ class WellKnownLLMProviderDescriptor(BaseModel):
|
||||
|
||||
|
||||
OPENAI_PROVIDER_NAME = "openai"
|
||||
# Curated list of OpenAI models to show by default in the UI
|
||||
OPENAI_VISIBLE_MODEL_NAMES = {
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"o1",
|
||||
"o3-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
}
|
||||
|
||||
BEDROCK_PROVIDER_NAME = "bedrock"
|
||||
BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
@@ -125,6 +134,12 @@ _IGNORABLE_ANTHROPIC_MODELS = {
|
||||
"claude-instant-1",
|
||||
"anthropic/claude-3-5-sonnet-20241022",
|
||||
}
|
||||
# Curated list of Anthropic models to show by default in the UI
|
||||
ANTHROPIC_VISIBLE_MODEL_NAMES = {
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-haiku-4-5",
|
||||
}
|
||||
|
||||
AZURE_PROVIDER_NAME = "azure"
|
||||
|
||||
@@ -134,6 +149,55 @@ VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
|
||||
VERTEX_LOCATION_KWARG = "vertex_location"
|
||||
VERTEXAI_DEFAULT_MODEL = "gemini-2.5-flash"
|
||||
VERTEXAI_DEFAULT_FAST_MODEL = "gemini-2.5-flash-lite"
|
||||
# Curated list of Vertex AI models to show by default in the UI
|
||||
VERTEXAI_VISIBLE_MODEL_NAMES = {
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.5-pro",
|
||||
}
|
||||
|
||||
|
||||
def is_obsolete_model(model_name: str, provider: str) -> bool:
|
||||
"""Check if a model is obsolete and should be filtered out.
|
||||
|
||||
Filters models that are 2+ major versions behind or deprecated.
|
||||
This is the single source of truth for obsolete model detection.
|
||||
"""
|
||||
model_lower = model_name.lower()
|
||||
|
||||
# OpenAI obsolete models
|
||||
if provider == "openai":
|
||||
# GPT-3 models are obsolete
|
||||
if "gpt-3" in model_lower:
|
||||
return True
|
||||
# Legacy models
|
||||
deprecated = {
|
||||
"text-davinci-003",
|
||||
"text-davinci-002",
|
||||
"text-curie-001",
|
||||
"text-babbage-001",
|
||||
"text-ada-001",
|
||||
"davinci",
|
||||
"curie",
|
||||
"babbage",
|
||||
"ada",
|
||||
}
|
||||
if model_lower in deprecated:
|
||||
return True
|
||||
|
||||
# Anthropic obsolete models
|
||||
if provider == "anthropic":
|
||||
if "claude-2" in model_lower or "claude-instant" in model_lower:
|
||||
return True
|
||||
|
||||
# Vertex AI obsolete models
|
||||
if provider == "vertex_ai":
|
||||
if "gemini-1.0" in model_lower:
|
||||
return True
|
||||
if "palm" in model_lower or "bison" in model_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
@@ -155,22 +219,43 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
|
||||
def get_openai_model_names() -> list[str]:
|
||||
"""Get OpenAI model names dynamically from litellm."""
|
||||
import re
|
||||
import litellm
|
||||
|
||||
# TODO: remove these lists once we have a comprehensive model configuration page
|
||||
# The ideal flow should be: fetch all available models --> filter by type
|
||||
# --> allow user to modify filters and select models based on current context
|
||||
non_chat_model_terms = {
|
||||
"embed",
|
||||
"audio",
|
||||
"tts",
|
||||
"whisper",
|
||||
"dall-e",
|
||||
"image",
|
||||
"moderation",
|
||||
"sora",
|
||||
"container",
|
||||
}
|
||||
deprecated_model_terms = {"babbage", "davinci", "gpt-3.5", "gpt-4-"}
|
||||
excluded_terms = non_chat_model_terms | deprecated_model_terms
|
||||
|
||||
# NOTE: We are explicitly excluding all "timestamped" models
|
||||
# because they are mostly just noise in the admin configuration panel
|
||||
# e.g. gpt-4o-2025-07-16, gpt-3.5-turbo-0613, etc.
|
||||
date_pattern = re.compile(r"-\d{4}")
|
||||
|
||||
def is_valid_model(model: str) -> bool:
|
||||
model_lower = model.lower()
|
||||
return not any(
|
||||
ex in model_lower for ex in excluded_terms
|
||||
) and not date_pattern.search(model)
|
||||
|
||||
return sorted(
|
||||
[
|
||||
# Strip openai/ prefix if present
|
||||
model.replace("openai/", "") if model.startswith("openai/") else model
|
||||
(
|
||||
model.removeprefix("openai/")
|
||||
for model in litellm.open_ai_chat_completion_models
|
||||
if "embed" not in model.lower()
|
||||
and "audio" not in model.lower()
|
||||
and "tts" not in model.lower()
|
||||
and "whisper" not in model.lower()
|
||||
and "dall-e" not in model.lower()
|
||||
and "moderation" not in model.lower()
|
||||
and "sora" not in model.lower() # video generation
|
||||
and "container" not in model.lower() # not a model
|
||||
],
|
||||
if is_valid_model(model)
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
@@ -184,6 +269,7 @@ def get_anthropic_model_names() -> list[str]:
|
||||
model
|
||||
for model in litellm.anthropic_models
|
||||
if model not in _IGNORABLE_ANTHROPIC_MODELS
|
||||
and not is_obsolete_model(model, ANTHROPIC_PROVIDER_NAME)
|
||||
],
|
||||
reverse=True,
|
||||
)
|
||||
@@ -229,6 +315,7 @@ def get_vertexai_model_names() -> list[str]:
|
||||
and "/" not in model # filter out prefixed models like openai/gpt-oss
|
||||
and "search_api" not in model.lower() # not a model
|
||||
and "-maas" not in model.lower() # marketplace models
|
||||
and not is_obsolete_model(model, VERTEXAI_PROVIDER_NAME)
|
||||
],
|
||||
reverse=True,
|
||||
)
|
||||
@@ -468,18 +555,30 @@ def get_provider_display_name(provider_name: str) -> str:
|
||||
)
|
||||
|
||||
|
||||
def _get_visible_models_for_provider(provider_name: str) -> set[str]:
|
||||
"""Get the set of models that should be visible by default for a provider."""
|
||||
_PROVIDER_TO_VISIBLE_MODELS: dict[str, set[str]] = {
|
||||
OPENAI_PROVIDER_NAME: OPENAI_VISIBLE_MODEL_NAMES,
|
||||
ANTHROPIC_PROVIDER_NAME: ANTHROPIC_VISIBLE_MODEL_NAMES,
|
||||
VERTEXAI_PROVIDER_NAME: VERTEXAI_VISIBLE_MODEL_NAMES,
|
||||
}
|
||||
return _PROVIDER_TO_VISIBLE_MODELS.get(provider_name, set())
|
||||
|
||||
|
||||
def fetch_model_configurations_for_provider(
|
||||
provider_name: str,
|
||||
) -> list[ModelConfigurationView]:
|
||||
"""Fetch model configurations for a static provider (OpenAI, Anthropic, Vertex AI).
|
||||
|
||||
Looks up max_input_tokens from LiteLLM's model_cost. If not found, stores None
|
||||
and the runtime will use the fallback (4096).
|
||||
and the runtime will use the fallback (32000).
|
||||
|
||||
Models in the curated visible lists (OPENAI_VISIBLE_MODEL_NAMES, etc.) are
|
||||
marked as is_visible=True by default.
|
||||
"""
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
|
||||
# No models are marked visible by default - the default model logic
|
||||
# in the frontend/backend will handle making default models visible.
|
||||
visible_models = _get_visible_models_for_provider(provider_name)
|
||||
configs = []
|
||||
for model_name in fetch_models_for_provider(provider_name):
|
||||
max_input_tokens = get_max_input_tokens(
|
||||
@@ -490,7 +589,7 @@ def fetch_model_configurations_for_provider(
|
||||
configs.append(
|
||||
ModelConfigurationView(
|
||||
name=model_name,
|
||||
is_visible=False,
|
||||
is_visible=model_name in visible_models,
|
||||
max_input_tokens=max_input_tokens,
|
||||
supports_image_input=model_supports_image_input(
|
||||
model_name=model_name,
|
||||
|
||||
@@ -2621,6 +2621,28 @@
|
||||
"model_vendor": "openai",
|
||||
"model_version": "2025-10-06"
|
||||
},
|
||||
"gpt-5.2-pro-2025-12-11": {
|
||||
"display_name": "GPT-5.2 Pro",
|
||||
"model_vendor": "openai",
|
||||
"model_version": "2025-12-11"
|
||||
},
|
||||
"gpt-5.2-pro": {
|
||||
"display_name": "GPT-5.2 Pro",
|
||||
"model_vendor": "openai"
|
||||
},
|
||||
"gpt-5.2-chat-latest": {
|
||||
"display_name": "GPT 5.2 Chat",
|
||||
"model_vendor": "openai"
|
||||
},
|
||||
"gpt-5.2-2025-12-11": {
|
||||
"display_name": "GPT 5.2",
|
||||
"model_vendor": "openai",
|
||||
"model_version": "2025-12-11"
|
||||
},
|
||||
"gpt-5.2": {
|
||||
"display_name": "GPT 5.2",
|
||||
"model_vendor": "openai"
|
||||
},
|
||||
"gpt-5.1": {
|
||||
"display_name": "GPT 5.1",
|
||||
"model_vendor": "openai"
|
||||
|
||||
@@ -85,7 +85,15 @@ def litellm_exception_to_error_msg(
|
||||
custom_error_msg_mappings: (
|
||||
dict[str, str] | None
|
||||
) = LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS,
|
||||
) -> str:
|
||||
) -> tuple[str, str, bool]:
|
||||
"""Convert a LiteLLM exception to a user-friendly error message with classification.
|
||||
|
||||
Returns:
|
||||
tuple: (error_message, error_code, is_retryable)
|
||||
- error_message: User-friendly error description
|
||||
- error_code: Categorized error code for frontend display
|
||||
- is_retryable: Whether the user should try again
|
||||
"""
|
||||
from litellm.exceptions import BadRequestError
|
||||
from litellm.exceptions import AuthenticationError
|
||||
from litellm.exceptions import PermissionDeniedError
|
||||
@@ -102,25 +110,37 @@ def litellm_exception_to_error_msg(
|
||||
|
||||
core_exception = _unwrap_nested_exception(e)
|
||||
error_msg = str(core_exception)
|
||||
error_code = "UNKNOWN_ERROR"
|
||||
is_retryable = True
|
||||
|
||||
if custom_error_msg_mappings:
|
||||
for error_msg_pattern, custom_error_msg in custom_error_msg_mappings.items():
|
||||
if error_msg_pattern in error_msg:
|
||||
return custom_error_msg
|
||||
return custom_error_msg, "CUSTOM_ERROR", True
|
||||
|
||||
if isinstance(core_exception, BadRequestError):
|
||||
error_msg = "Bad request: The server couldn't process your request. Please check your input."
|
||||
error_code = "BAD_REQUEST"
|
||||
is_retryable = True
|
||||
elif isinstance(core_exception, AuthenticationError):
|
||||
error_msg = "Authentication failed: Please check your API key and credentials."
|
||||
error_code = "AUTH_ERROR"
|
||||
is_retryable = False
|
||||
elif isinstance(core_exception, PermissionDeniedError):
|
||||
error_msg = (
|
||||
"Permission denied: You don't have the necessary permissions for this operation."
|
||||
"Permission denied: You don't have the necessary permissions for this operation. "
|
||||
"Ensure you have access to this model."
|
||||
)
|
||||
error_code = "PERMISSION_DENIED"
|
||||
is_retryable = False
|
||||
elif isinstance(core_exception, NotFoundError):
|
||||
error_msg = "Resource not found: The requested resource doesn't exist."
|
||||
error_code = "NOT_FOUND"
|
||||
is_retryable = False
|
||||
elif isinstance(core_exception, UnprocessableEntityError):
|
||||
error_msg = "Unprocessable entity: The server couldn't process your request due to semantic errors."
|
||||
error_code = "UNPROCESSABLE_ENTITY"
|
||||
is_retryable = True
|
||||
elif isinstance(core_exception, RateLimitError):
|
||||
provider_name = (
|
||||
llm.config.model_provider
|
||||
@@ -151,6 +171,8 @@ def litellm_exception_to_error_msg(
|
||||
if upstream_detail
|
||||
else f"{provider_name} rate limit exceeded: Please slow down your requests and try again later."
|
||||
)
|
||||
error_code = "RATE_LIMIT"
|
||||
is_retryable = True
|
||||
elif isinstance(core_exception, ServiceUnavailableError):
|
||||
provider_name = (
|
||||
llm.config.model_provider
|
||||
@@ -168,6 +190,8 @@ def litellm_exception_to_error_msg(
|
||||
else:
|
||||
# Generic 503 Service Unavailable
|
||||
error_msg = f"{provider_name} service error: {str(core_exception)}"
|
||||
error_code = "SERVICE_UNAVAILABLE"
|
||||
is_retryable = True
|
||||
elif isinstance(core_exception, ContextWindowExceededError):
|
||||
error_msg = (
|
||||
"Context window exceeded: Your input is too long for the model to process."
|
||||
@@ -178,29 +202,44 @@ def litellm_exception_to_error_msg(
|
||||
model_name=llm.config.model_name,
|
||||
model_provider=llm.config.model_provider,
|
||||
)
|
||||
error_msg += f"Your invoked model ({llm.config.model_name}) has a maximum context size of {max_context}"
|
||||
error_msg += f" Your invoked model ({llm.config.model_name}) has a maximum context size of {max_context}."
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Unable to get maximum input token for LiteLLM excpetion handling"
|
||||
"Unable to get maximum input token for LiteLLM exception handling"
|
||||
)
|
||||
error_code = "CONTEXT_TOO_LONG"
|
||||
is_retryable = False
|
||||
elif isinstance(core_exception, ContentPolicyViolationError):
|
||||
error_msg = "Content policy violation: Your request violates the content policy. Please revise your input."
|
||||
error_code = "CONTENT_POLICY"
|
||||
is_retryable = False
|
||||
elif isinstance(core_exception, APIConnectionError):
|
||||
error_msg = "API connection error: Failed to connect to the API. Please check your internet connection."
|
||||
error_code = "CONNECTION_ERROR"
|
||||
is_retryable = True
|
||||
elif isinstance(core_exception, BudgetExceededError):
|
||||
error_msg = (
|
||||
"Budget exceeded: You've exceeded your allocated budget for API usage."
|
||||
)
|
||||
error_code = "BUDGET_EXCEEDED"
|
||||
is_retryable = False
|
||||
elif isinstance(core_exception, Timeout):
|
||||
error_msg = "Request timed out: The operation took too long to complete. Please try again."
|
||||
error_code = "CONNECTION_ERROR"
|
||||
is_retryable = True
|
||||
elif isinstance(core_exception, APIError):
|
||||
error_msg = (
|
||||
"API error: An error occurred while communicating with the API. "
|
||||
f"Details: {str(core_exception)}"
|
||||
)
|
||||
error_code = "API_ERROR"
|
||||
is_retryable = True
|
||||
elif not fallback_to_error_msg:
|
||||
error_msg = "An unexpected error occurred while processing your request. Please try again later."
|
||||
return error_msg
|
||||
error_code = "UNKNOWN_ERROR"
|
||||
is_retryable = True
|
||||
|
||||
return error_msg, error_code, is_retryable
|
||||
|
||||
|
||||
def llm_response_to_string(message: ModelResponse) -> str:
|
||||
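Because `litellm_exception_to_error_msg` now returns a `(message, code, retryable)` triple instead of a bare string, every caller must unpack three values. A brief sketch of the calling pattern; the handler around it is illustrative, not an actual Onyx call site:

```python
# Brief sketch of consuming the new (message, code, retryable) return shape.
# The handler around it is illustrative; only the unpacking mirrors the
# function above.
def handle_llm_error(e: Exception, llm: "LLM") -> None:
    error_msg, error_code, is_retryable = litellm_exception_to_error_msg(e, llm)
    if is_retryable:
        logger.warning(f"Retryable LLM error ({error_code}): {error_msg}")
    else:
        logger.error(f"Non-retryable LLM error ({error_code}): {error_msg}")
        raise e
```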
@@ -514,11 +553,11 @@ def get_max_input_tokens_from_llm_provider(
|
||||
1. Use max_input_tokens from model_configuration (populated from source APIs
|
||||
like OpenRouter, Ollama, or our Bedrock mapping)
|
||||
2. Look up in litellm.model_cost dictionary
|
||||
3. Fall back to GEN_AI_MODEL_FALLBACK_MAX_TOKENS (4096)
|
||||
3. Fall back to GEN_AI_MODEL_FALLBACK_MAX_TOKENS (32000)
|
||||
|
||||
Most dynamic providers (OpenRouter, Ollama) provide context_length via their
|
||||
APIs. Bedrock doesn't expose this, so we parse from model ID suffix (:200k)
|
||||
or use BEDROCK_MODEL_TOKEN_LIMITS mapping. The 4096 fallback is only hit for
|
||||
or use BEDROCK_MODEL_TOKEN_LIMITS mapping. The 32000 fallback is only hit for
|
||||
unknown models not in any of these sources.
|
||||
"""
|
||||
max_input_tokens = None
|
||||
@@ -545,7 +584,7 @@ def get_bedrock_token_limit(model_id: str) -> int:
|
||||
1. Parse from model ID suffix (e.g., ":200k" → 200000)
|
||||
2. Check LiteLLM's model_cost dictionary
|
||||
3. Fall back to our hardcoded BEDROCK_MODEL_TOKEN_LIMITS mapping
|
||||
4. Default to 4096 if not found anywhere
|
||||
4. Default to 32000 if not found anywhere
|
||||
"""
|
||||
from onyx.llm.constants import BEDROCK_MODEL_TOKEN_LIMITS
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ class RedisConnectorDelete:
|
||||
|
||||
PREFIX = "connectordeletion"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
@@ -78,7 +79,7 @@ class RedisConnectorDelete:
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, payload.model_dump_json())
|
||||
self.redis.set(self.fence_key, payload.model_dump_json(), ex=self.FENCE_TTL)
|
||||
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
|
||||
def set_active(self) -> None:
|
||||
|
||||
@@ -43,6 +43,7 @@ class RedisConnectorPermissionSync:
|
||||
PREFIX = "connectordocpermissionsync"
|
||||
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
|
||||
# phase 1 - geneartor task and progress signals
|
||||
GENERATORTASK_PREFIX = f"{PREFIX}+generator" # connectorpermissions+generator
|
||||
@@ -126,7 +127,7 @@ class RedisConnectorPermissionSync:
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, payload.model_dump_json())
|
||||
self.redis.set(self.fence_key, payload.model_dump_json(), ex=self.FENCE_TTL)
|
||||
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
|
||||
def set_active(self) -> None:
|
||||
@@ -162,7 +163,7 @@ class RedisConnectorPermissionSync:
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.generator_complete_key, payload)
|
||||
self.redis.set(self.generator_complete_key, payload, ex=self.FENCE_TTL)
|
||||
|
||||
def update_db(
|
||||
self,
|
||||
|
||||
@@ -25,6 +25,7 @@ class RedisConnectorExternalGroupSync:
|
||||
PREFIX = "connectorexternalgroupsync"
|
||||
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
|
||||
# phase 1 - geneartor task and progress signals
|
||||
GENERATORTASK_PREFIX = f"{PREFIX}+generator" # connectorexternalgroupsync+generator
|
||||
@@ -110,7 +111,7 @@ class RedisConnectorExternalGroupSync:
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, payload.model_dump_json())
|
||||
self.redis.set(self.fence_key, payload.model_dump_json(), ex=self.FENCE_TTL)
|
||||
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
|
||||
def set_active(self) -> None:
|
||||
@@ -147,7 +148,7 @@ class RedisConnectorExternalGroupSync:
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.generator_complete_key, payload)
|
||||
self.redis.set(self.generator_complete_key, payload, ex=self.FENCE_TTL)
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
|
||||
@@ -33,6 +33,7 @@ class RedisConnectorPrune:
|
||||
PREFIX = "connectorpruning"
|
||||
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
|
||||
# phase 1 - geneartor task and progress signals
|
||||
GENERATORTASK_PREFIX = f"{PREFIX}+generator" # connectorpruning+generator
|
||||
@@ -115,7 +116,7 @@ class RedisConnectorPrune:
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, payload.model_dump_json())
|
||||
self.redis.set(self.fence_key, payload.model_dump_json(), ex=self.FENCE_TTL)
|
||||
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
|
||||
def set_active(self) -> None:
|
||||
@@ -148,7 +149,7 @@ class RedisConnectorPrune:
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.generator_complete_key, payload)
|
||||
self.redis.set(self.generator_complete_key, payload, ex=self.FENCE_TTL)
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
|
||||
@@ -7,6 +7,7 @@ class RedisConnectorStop:
|
||||
|
||||
PREFIX = "connectorstop"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
|
||||
# if this timeout is exceeded, the caller may decide to take more
|
||||
# drastic measures
|
||||
@@ -30,7 +31,7 @@ class RedisConnectorStop:
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, 0)
|
||||
self.redis.set(self.fence_key, 0, ex=self.FENCE_TTL)
|
||||
|
||||
@property
|
||||
def timed_out(self) -> bool:
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.redis.redis_object_helper import RedisObjectHelper
|
||||
class RedisDocumentSet(RedisObjectHelper):
|
||||
PREFIX = "documentset"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
@@ -36,7 +37,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, payload)
|
||||
self.redis.set(self.fence_key, payload, ex=self.FENCE_TTL)
|
||||
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
|
||||
@property
|
||||
|
||||
@@ -22,6 +22,7 @@ from onyx.utils.variable_functionality import global_version
|
||||
class RedisUserGroup(RedisObjectHelper):
|
||||
PREFIX = "usergroup"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
@@ -40,7 +41,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
self.redis.delete(self.fence_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.fence_key, payload)
|
||||
self.redis.set(self.fence_key, payload, ex=self.FENCE_TTL)
|
||||
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
|
||||
@property
|
||||
|
||||
@@ -1796,6 +1796,19 @@ def update_mcp_server_with_tools(
|
||||
status_code=400, detail="MCP server has no admin connection config"
|
||||
)
|
||||
|
||||
name_changed = request.name is not None and request.name != mcp_server.name
|
||||
description_changed = (
|
||||
request.description is not None
|
||||
and request.description != mcp_server.description
|
||||
)
|
||||
if name_changed or description_changed:
|
||||
mcp_server = update_mcp_server__no_commit(
|
||||
server_id=mcp_server.id,
|
||||
db_session=db_session,
|
||||
name=request.name if name_changed else None,
|
||||
description=request.description if description_changed else None,
|
||||
)
|
||||
|
||||
selected_names = set(request.selected_tools or [])
|
||||
updated_tools = _sync_tools_for_server(
|
||||
mcp_server,
|
||||
@@ -1807,6 +1820,7 @@ def update_mcp_server_with_tools(
|
||||
|
||||
return MCPServerUpdateResponse(
|
||||
server_id=mcp_server.id,
|
||||
server_name=mcp_server.name,
|
||||
updated_tools=updated_tools,
|
||||
)
|
||||
|
||||
|
||||
@@ -134,6 +134,10 @@ class MCPToolCreateRequest(BaseModel):
|
||||
|
||||
class MCPToolUpdateRequest(BaseModel):
|
||||
server_id: int = Field(..., description="ID of the MCP server")
|
||||
name: Optional[str] = Field(None, description="Updated name of the MCP server")
|
||||
description: Optional[str] = Field(
|
||||
None, description="Updated description of the MCP server"
|
||||
)
|
||||
selected_tools: Optional[List[str]] = Field(
|
||||
None, description="List of selected tool names to create"
|
||||
)
|
||||
@@ -328,6 +332,7 @@ class MCPServerUpdateResponse(BaseModel):
|
||||
"""Response for updating multiple MCP tools"""
|
||||
|
||||
server_id: int
|
||||
server_name: str
|
||||
updated_tools: int
|
||||
|
||||
|
||||
|
||||
@@ -200,6 +200,9 @@ def get_agents_admin_paginated(
|
||||
get_editable: bool = Query(
|
||||
False, description="If true, only returns editable personas."
|
||||
),
|
||||
include_default: bool = Query(
|
||||
True, description="If true, includes builtin/default personas."
|
||||
),
|
||||
) -> PaginatedReturn[PersonaSnapshot]:
|
||||
"""Paginated endpoint for listing agents (formerly personas) (admin view).
|
||||
|
||||
@@ -212,6 +215,7 @@ def get_agents_admin_paginated(
|
||||
page_num=page_num,
|
||||
page_size=page_size,
|
||||
get_editable=get_editable,
|
||||
include_default=include_default,
|
||||
include_deleted=include_deleted,
|
||||
)
|
||||
|
||||
@@ -219,6 +223,7 @@ def get_agents_admin_paginated(
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
get_editable=get_editable,
|
||||
include_default=include_default,
|
||||
include_deleted=include_deleted,
|
||||
)
|
||||
|
||||
@@ -441,6 +446,9 @@ def get_agents_paginated(
|
||||
get_editable: bool = Query(
|
||||
False, description="If true, only returns editable personas."
|
||||
),
|
||||
include_default: bool = Query(
|
||||
True, description="If true, includes builtin/default personas."
|
||||
),
|
||||
) -> PaginatedReturn[MinimalPersonaSnapshot]:
|
||||
"""Paginated endpoint for listing agents available to the user.
|
||||
|
||||
@@ -456,6 +464,7 @@ def get_agents_paginated(
|
||||
page_num=page_num,
|
||||
page_size=page_size,
|
||||
get_editable=get_editable,
|
||||
include_default=include_default,
|
||||
include_deleted=include_deleted,
|
||||
)
|
||||
|
||||
@@ -463,6 +472,7 @@ def get_agents_paginated(
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
get_editable=get_editable,
|
||||
include_default=include_default,
|
||||
include_deleted=include_deleted,
|
||||
)
|
||||
|
||||
|
||||
@@ -149,7 +149,7 @@ def test_llm_configuration(
|
||||
)
|
||||
|
||||
if error:
|
||||
client_error_msg = litellm_exception_to_error_msg(
|
||||
client_error_msg, _error_code, _is_retryable = litellm_exception_to_error_msg(
|
||||
error, llm, fallback_to_error_msg=True
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=client_error_msg)
|
||||
|
||||
@@ -8,6 +8,9 @@ from pydantic import field_validator
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
from onyx.llm.utils import litellm_thinks_model_supports_image_input
|
||||
from onyx.llm.utils import model_is_reasoning_model
|
||||
from onyx.server.manage.llm.utils import DYNAMIC_LLM_PROVIDERS
|
||||
from onyx.server.manage.llm.utils import extract_vendor_from_model_name
|
||||
from onyx.server.manage.llm.utils import filter_model_configurations
|
||||
from onyx.server.manage.llm.utils import is_reasoning_model
|
||||
|
||||
|
||||
@@ -66,6 +69,7 @@ class LLMProviderDescriptor(BaseModel):
|
||||
from onyx.llm.llm_provider_options import get_provider_display_name
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
|
||||
return cls(
|
||||
name=llm_provider_model.name,
|
||||
provider=provider,
|
||||
@@ -75,11 +79,8 @@ class LLMProviderDescriptor(BaseModel):
|
||||
is_default_provider=llm_provider_model.is_default_provider,
|
||||
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
|
||||
default_vision_model=llm_provider_model.default_vision_model,
|
||||
model_configurations=list(
|
||||
ModelConfigurationView.from_model(
|
||||
model_configuration, llm_provider_model.provider
|
||||
)
|
||||
for model_configuration in llm_provider_model.model_configurations
|
||||
model_configurations=filter_model_configurations(
|
||||
llm_provider_model.model_configurations, provider
|
||||
),
|
||||
)
|
||||
|
||||
@@ -138,10 +139,12 @@ class LLMProviderView(LLMProvider):
|
||||
except Exception:
|
||||
personas = []
|
||||
|
||||
provider = llm_provider_model.provider
|
||||
|
||||
return cls(
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
provider=llm_provider_model.provider,
|
||||
provider=provider,
|
||||
api_key=llm_provider_model.api_key,
|
||||
api_base=llm_provider_model.api_base,
|
||||
api_version=llm_provider_model.api_version,
|
||||
@@ -155,11 +158,8 @@ class LLMProviderView(LLMProvider):
|
||||
groups=groups,
|
||||
personas=personas,
|
||||
deployment_name=llm_provider_model.deployment_name,
|
||||
model_configurations=list(
|
||||
ModelConfigurationView.from_model(
|
||||
model_configuration, llm_provider_model.provider
|
||||
)
|
||||
for model_configuration in llm_provider_model.model_configurations
|
||||
model_configurations=filter_model_configurations(
|
||||
llm_provider_model.model_configurations, provider
|
||||
),
|
||||
)
|
||||
|
||||
@@ -184,54 +184,6 @@ class ModelConfigurationUpsertRequest(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
# Dynamic providers fetch models directly from source APIs (not LiteLLM)
|
||||
DYNAMIC_LLM_PROVIDERS = {"openrouter", "bedrock", "ollama_chat"}
|
||||
|
||||
|
||||
def _extract_vendor_from_model_name(model_name: str, provider: str) -> str | None:
|
||||
"""Extract vendor from model name for aggregator providers.
|
||||
|
||||
Examples:
|
||||
- OpenRouter: "anthropic/claude-3-5-sonnet" → "Anthropic"
|
||||
- Bedrock: "anthropic.claude-3-5-sonnet-..." → "Anthropic"
|
||||
- Bedrock: "us.anthropic.claude-..." → "Anthropic"
|
||||
- Ollama: "llama3:70b" → "Meta"
|
||||
- Ollama: "qwen2.5:7b" → "Alibaba"
|
||||
"""
|
||||
from onyx.llm.constants import OLLAMA_MODEL_TO_VENDOR
|
||||
from onyx.llm.constants import PROVIDER_DISPLAY_NAMES
|
||||
|
||||
if provider == "openrouter":
|
||||
# Format: "vendor/model-name" e.g., "anthropic/claude-3-5-sonnet"
|
||||
if "/" in model_name:
|
||||
vendor_key = model_name.split("/")[0].lower()
|
||||
return PROVIDER_DISPLAY_NAMES.get(vendor_key, vendor_key.title())
|
||||
|
||||
elif provider == "bedrock":
|
||||
# Format: "vendor.model-name" or "region.vendor.model-name"
|
||||
parts = model_name.split(".")
|
||||
if len(parts) >= 2:
|
||||
# Check if first part is a region (us, eu, global, etc.)
|
||||
if parts[0] in ("us", "eu", "global", "ap", "apac"):
|
||||
vendor_key = parts[1].lower() if len(parts) > 2 else parts[0].lower()
|
||||
else:
|
||||
vendor_key = parts[0].lower()
|
||||
return PROVIDER_DISPLAY_NAMES.get(vendor_key, vendor_key.title())
|
||||
|
||||
elif provider == "ollama_chat":
|
||||
# Format: "model-name:tag" e.g., "llama3:70b", "qwen2.5:7b"
|
||||
# Extract base name (before colon)
|
||||
base_name = model_name.split(":")[0].lower()
|
||||
# Match against known model prefixes
|
||||
for prefix, vendor in OLLAMA_MODEL_TO_VENDOR.items():
|
||||
if base_name.startswith(prefix):
|
||||
return vendor
|
||||
# Fallback: capitalize the base name as vendor
|
||||
return base_name.split("-")[0].title()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ModelConfigurationView(BaseModel):
|
||||
name: str
|
||||
is_visible: bool
|
||||
@@ -257,7 +209,7 @@ class ModelConfigurationView(BaseModel):
|
||||
and model_configuration_model.display_name
|
||||
):
|
||||
# Extract vendor from model name for grouping (e.g., "Anthropic", "OpenAI")
|
||||
vendor = _extract_vendor_from_model_name(
|
||||
vendor = extract_vendor_from_model_name(
|
||||
model_configuration_model.name, provider_name
|
||||
)
|
||||
|
||||
@@ -349,7 +301,7 @@ class BedrockModelsRequest(BaseModel):
|
||||
class BedrockFinalModelResponse(BaseModel):
|
||||
name: str # Model ID (e.g., "anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
display_name: str # Human-readable name from AWS (e.g., "Claude 3.5 Sonnet v2")
|
||||
max_input_tokens: int # From LiteLLM, our mapping, or default 4096
|
||||
max_input_tokens: int # From LiteLLM, our mapping, or default 32000
|
||||
supports_image_input: bool
|
||||
|
||||
|
||||
|
||||
@@ -12,6 +12,12 @@ from typing import TypedDict
|
||||
|
||||
from onyx.llm.constants import BEDROCK_MODEL_NAME_MAPPINGS
|
||||
from onyx.llm.constants import OLLAMA_MODEL_NAME_MAPPINGS
|
||||
from onyx.llm.constants import OLLAMA_MODEL_TO_VENDOR
|
||||
from onyx.llm.constants import PROVIDER_DISPLAY_NAMES
|
||||
|
||||
|
||||
# Dynamic providers fetch models directly from source APIs (not LiteLLM)
|
||||
DYNAMIC_LLM_PROVIDERS = {"openrouter", "bedrock", "ollama_chat"}
|
||||
|
||||
|
||||
class ModelMetadata(TypedDict):
|
||||
@@ -235,3 +241,104 @@ def is_reasoning_model(model_id: str, display_name: str) -> bool:
|
||||
"""
|
||||
combined = f"{model_id} {display_name}".lower()
|
||||
return any(pattern in combined for pattern in REASONING_MODEL_PATTERNS)
|
||||
|
||||
|
||||
def extract_base_model_name(model: str) -> str | None:
|
||||
"""Extract base model name by removing date suffixes.
|
||||
|
||||
Returns None if no date suffix was found.
|
||||
"""
|
||||
patterns = [
|
||||
r"-\d{8}$", # -20250929
|
||||
r"-\d{4}-\d{2}-\d{2}$", # -2024-08-06
|
||||
r"@\d{8}$", # @20250219
|
||||
]
|
||||
for pattern in patterns:
|
||||
if re.search(pattern, model):
|
||||
return re.sub(pattern, "", model)
|
||||
return None
|
||||
|
||||
|
||||
def should_filter_as_dated_duplicate(
|
||||
model_name: str, all_model_names: set[str]
|
||||
) -> bool:
|
||||
"""Check if this model is a dated variant and a non-dated version exists."""
|
||||
base = extract_base_model_name(model_name)
|
||||
if base and base in all_model_names:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def filter_model_configurations(
|
||||
model_configurations: list,
|
||||
provider: str,
|
||||
) -> list:
|
||||
"""Filter out obsolete and dated duplicate models from configurations.
|
||||
|
||||
Args:
|
||||
model_configurations: List of ModelConfiguration DB models
|
||||
provider: The provider name (e.g., "openai", "anthropic")
|
||||
|
||||
Returns:
|
||||
List of ModelConfigurationView objects with obsolete/duplicate models removed
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from onyx.llm.llm_provider_options import is_obsolete_model
|
||||
from onyx.server.manage.llm.models import ModelConfigurationView
|
||||
|
||||
all_model_names = {mc.name for mc in model_configurations}
|
||||
|
||||
filtered_configs = []
|
||||
for model_configuration in model_configurations:
|
||||
# Skip obsolete models
|
||||
if is_obsolete_model(model_configuration.name, provider):
|
||||
continue
|
||||
# Skip dated duplicates when non-dated version exists
|
||||
if should_filter_as_dated_duplicate(model_configuration.name, all_model_names):
|
||||
continue
|
||||
filtered_configs.append(
|
||||
ModelConfigurationView.from_model(model_configuration, provider)
|
||||
)
|
||||
|
||||
return filtered_configs
|
||||
|
||||
|
||||
def extract_vendor_from_model_name(model_name: str, provider: str) -> str | None:
|
||||
"""Extract vendor from model name for aggregator providers.
|
||||
|
||||
Examples:
|
||||
- OpenRouter: "anthropic/claude-3-5-sonnet" → "Anthropic"
|
||||
- Bedrock: "anthropic.claude-3-5-sonnet-..." → "Anthropic"
|
||||
- Bedrock: "us.anthropic.claude-..." → "Anthropic"
|
||||
- Ollama: "llama3:70b" → "Meta"
|
||||
- Ollama: "qwen2.5:7b" → "Alibaba"
|
||||
"""
|
||||
if provider == "openrouter":
|
||||
# Format: "vendor/model-name" e.g., "anthropic/claude-3-5-sonnet"
|
||||
if "/" in model_name:
|
||||
vendor_key = model_name.split("/")[0].lower()
|
||||
return PROVIDER_DISPLAY_NAMES.get(vendor_key, vendor_key.title())
|
||||
|
||||
elif provider == "bedrock":
|
||||
# Format: "vendor.model-name" or "region.vendor.model-name"
|
||||
parts = model_name.split(".")
|
||||
if len(parts) >= 2:
|
||||
# Check if first part is a region (us, eu, global, etc.)
|
||||
if parts[0] in ("us", "eu", "global", "ap", "apac"):
|
||||
vendor_key = parts[1].lower() if len(parts) > 2 else parts[0].lower()
|
||||
else:
|
||||
vendor_key = parts[0].lower()
|
||||
return PROVIDER_DISPLAY_NAMES.get(vendor_key, vendor_key.title())
|
||||
|
||||
elif provider == "ollama_chat":
|
||||
# Format: "model-name:tag" e.g., "llama3:70b", "qwen2.5:7b"
|
||||
# Extract base name (before colon)
|
||||
base_name = model_name.split(":")[0].lower()
|
||||
# Match against known model prefixes
|
||||
for prefix, vendor in OLLAMA_MODEL_TO_VENDOR.items():
|
||||
if base_name.startswith(prefix):
|
||||
return vendor
|
||||
# Fallback: capitalize the base name as vendor
|
||||
return base_name.split("-")[0].title()
|
||||
|
||||
return None
|
||||
|
||||
@@ -361,7 +361,8 @@ def bulk_invite_users(

try:
for email in emails:
email_info = validate_email(email)
# Allow syntactically valid emails without DNS deliverability checks; tests use test domains
email_info = validate_email(email, check_deliverability=False)
new_invited_emails.append(email_info.normalized)

except (EmailUndeliverableError, EmailNotValidError) as e:

@@ -162,36 +162,13 @@ def test_search_provider(
|
||||
status_code=400, detail="Unable to build provider configuration."
|
||||
)
|
||||
|
||||
# Actually test the API key by making a real search call
|
||||
# Run the API client's test_connection method to ensure the connection is valid.
|
||||
try:
|
||||
test_results = provider.search("test")
|
||||
if not test_results or not any(result.link for result in test_results):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API key validation failed: search returned no results.",
|
||||
)
|
||||
return provider.test_connection()
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if (
|
||||
"api" in error_msg.lower()
|
||||
or "key" in error_msg.lower()
|
||||
or "auth" in error_msg.lower()
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid API key: {error_msg}",
|
||||
) from e
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"API key validation failed: {error_msg}",
|
||||
) from e
|
||||
|
||||
logger.info(
|
||||
f"Web search provider test succeeded for {request.provider_type.value}."
|
||||
)
|
||||
return {"status": "ok"}
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
|
||||
@admin_router.get("/content-providers", response_model=list[WebContentProviderView])
|
||||
|
||||
@@ -34,6 +34,9 @@ class StreamingType(Enum):
|
||||
REASONING_DONE = "reasoning_done"
|
||||
CITATION_INFO = "citation_info"
|
||||
|
||||
DEEP_RESEARCH_PLAN_START = "deep_research_plan_start"
|
||||
DEEP_RESEARCH_PLAN_DELTA = "deep_research_plan_delta"
|
||||
|
||||
|
||||
class BaseObj(BaseModel):
|
||||
type: str = ""
|
||||
@@ -222,6 +225,20 @@ class CustomToolDelta(BaseObj):
|
||||
file_ids: list[str] | None = None
|
||||
|
||||
|
||||
class DeepResearchPlanStart(BaseObj):
|
||||
type: Literal["deep_research_plan_start"] = (
|
||||
StreamingType.DEEP_RESEARCH_PLAN_START.value
|
||||
)
|
||||
|
||||
|
||||
class DeepResearchPlanDelta(BaseObj):
|
||||
type: Literal["deep_research_plan_delta"] = (
|
||||
StreamingType.DEEP_RESEARCH_PLAN_DELTA.value
|
||||
)
|
||||
|
||||
content: str
|
||||
|
||||
|
||||
"""Packet"""
|
||||
|
||||
# Discriminated union of all possible packet object types
|
||||
@@ -254,6 +271,9 @@ PacketObj = Union[
|
||||
ReasoningDone,
|
||||
# Citation Packets
|
||||
CitationInfo,
|
||||
# Deep Research Packets
|
||||
DeepResearchPlanStart,
|
||||
DeepResearchPlanDelta,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,9 @@ from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

# TTL for settings keys - 30 days
SETTINGS_TTL = 30 * 24 * 60 * 60

def load_settings() -> Settings:
kv_store = get_kv_store()
@@ -41,7 +44,9 @@ def load_settings() -> Settings:
# Default to False
anonymous_user_enabled = False
# Optionally store the default back to Redis
redis_client.set(OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0")
redis_client.set(
OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0", ex=SETTINGS_TTL
)
except Exception as e:
# Log the error and reset to default
logger.error(f"Error loading anonymous user setting from Redis: {str(e)}")
@@ -66,6 +71,7 @@ def store_settings(settings: Settings) -> None:
redis_client.set(
OnyxRedisLocks.ANONYMOUS_USER_ENABLED,
"1" if settings.anonymous_user_enabled else "0",
ex=SETTINGS_TTL,
)

get_kv_store().store(KV_SETTINGS_KEY, settings.model_dump())

@@ -106,11 +106,16 @@ class SearchToolOverrideKwargs(BaseModel):
# To know what citation number to start at for constructing the string to the LLM
starting_citation_num: int
# This is needed because the LLM won't be able to do a really detailed semantic query well
# without help and a specific custom prompt for this
original_query: str | None = None
message_history: list[ChatMinimalTextMessage] | None = None
memories: list[str] | None = None
user_info: str | None = None

# Used for tool calls after the first one but in the same chat turn. The reason for this is that if the initial pass through
# the custom flow did not yield good results, we don't want to go through it again. In that case, we defer entirely to the LLM
skip_query_expansion: bool = False

# Number of results to return in the richer object format so that it can be rendered in the UI
num_hits: int | None = NUM_RETURNED_HITS
# Number of chunks (token approx) to include in the string to the LLM

@@ -1,3 +1,4 @@
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
@@ -23,6 +24,7 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.oauth_config import get_oauth_config
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.tools import get_builtin_tool
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
@@ -65,6 +67,12 @@ class CustomToolConfig(BaseModel):
|
||||
additional_headers: dict[str, str] | None = None
|
||||
|
||||
|
||||
class SearchToolUsage(str, Enum):
|
||||
DISABLED = "disabled"
|
||||
ENABLED = "enabled"
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
|
||||
"""Helper function to get image generation LLM config based on available providers"""
|
||||
if llm and llm.config.api_key and llm.config.model_provider == "openai":
|
||||
@@ -127,7 +135,7 @@ def construct_tools(
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
disable_internal_search: bool = False,
|
||||
search_usage_forcing_setting: SearchToolUsage = SearchToolUsage.AUTO,
|
||||
) -> dict[int, list[Tool]]:
|
||||
"""Constructs tools based on persona configuration and available APIs.
|
||||
|
||||
@@ -146,6 +154,7 @@ def construct_tools(
|
||||
if user and user.oauth_accounts:
|
||||
user_oauth_token = user.oauth_accounts[0].access_token
|
||||
|
||||
added_search_tool = False
|
||||
for db_tool_model in persona.tools:
|
||||
# If allowed_tool_ids is specified, skip tools not in the allowed list
|
||||
if allowed_tool_ids is not None and db_tool_model.id not in allowed_tool_ids:
|
||||
@@ -171,7 +180,8 @@ def construct_tools(
|
||||
|
||||
# Handle Internal Search Tool
|
||||
if tool_cls.__name__ == SearchTool.__name__:
|
||||
if disable_internal_search:
|
||||
added_search_tool = True
|
||||
if search_usage_forcing_setting == SearchToolUsage.DISABLED:
|
||||
continue
|
||||
|
||||
if not search_tool_config:
|
||||
@@ -180,7 +190,6 @@ def construct_tools(
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
|
||||
# TODO concerning passing the db_session here.
|
||||
search_tool = SearchTool(
|
||||
tool_id=db_tool_model.id,
|
||||
db_session=db_session,
|
||||
@@ -371,6 +380,36 @@ def construct_tools(
|
||||
f"Tool '{expected_tool_name}' not found in MCP server '{mcp_server.name}'"
|
||||
)
|
||||
|
||||
if (
|
||||
not added_search_tool
|
||||
and search_usage_forcing_setting == SearchToolUsage.ENABLED
|
||||
):
|
||||
# Get the database tool model for SearchTool
|
||||
search_tool_db_model = get_builtin_tool(db_session, SearchTool)
|
||||
|
||||
# Use the passed-in config if available, otherwise create a new one
|
||||
if not search_tool_config:
|
||||
search_tool_config = SearchToolConfig()
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
search_tool = SearchTool(
|
||||
tool_id=search_tool_db_model.id,
|
||||
db_session=db_session,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona=persona,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id=search_tool_config.project_id,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
)
|
||||
|
||||
tool_dict[search_tool_db_model.id] = [search_tool]
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
@@ -376,7 +376,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
try:
|
||||
llm_queries = cast(list[str], llm_kwargs[QUERIES_FIELD])
|
||||
|
||||
# Run semantic and keyword query expansion in parallel
|
||||
# Run semantic and keyword query expansion in parallel (unless skipped)
|
||||
# Use message history, memories, and user info from override_kwargs
|
||||
message_history = (
|
||||
override_kwargs.message_history
|
||||
@@ -386,31 +386,41 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
memories = override_kwargs.memories
|
||||
user_info = override_kwargs.user_info
|
||||
|
||||
# Start timing for query expansion/rephrase
|
||||
query_expansion_start_time = time.time()
|
||||
# Skip query expansion if this is a repeat search call
|
||||
if override_kwargs.skip_query_expansion:
|
||||
logger.debug(
|
||||
"Search tool - Skipping query expansion (repeat search call)"
|
||||
)
|
||||
semantic_query = None
|
||||
keyword_queries: list[str] = []
|
||||
else:
|
||||
# Start timing for query expansion/rephrase
|
||||
query_expansion_start_time = time.time()
|
||||
|
||||
functions_with_args: list[tuple[Callable, tuple]] = [
|
||||
(
|
||||
semantic_query_rephrase,
|
||||
(message_history, self.llm, user_info, memories),
|
||||
),
|
||||
(
|
||||
keyword_query_expansion,
|
||||
(message_history, self.llm, user_info, memories),
|
||||
),
|
||||
]
|
||||
functions_with_args: list[tuple[Callable, tuple]] = [
|
||||
(
|
||||
semantic_query_rephrase,
|
||||
(message_history, self.llm, user_info, memories),
|
||||
),
|
||||
(
|
||||
keyword_query_expansion,
|
||||
(message_history, self.llm, user_info, memories),
|
||||
),
|
||||
]
|
||||
|
||||
expansion_results = run_functions_tuples_in_parallel(functions_with_args)
|
||||
expansion_results = run_functions_tuples_in_parallel(
|
||||
functions_with_args
|
||||
)
|
||||
|
||||
# End timing for query expansion/rephrase
|
||||
query_expansion_elapsed = time.time() - query_expansion_start_time
|
||||
logger.debug(
|
||||
f"Search tool - Query expansion/rephrase took {query_expansion_elapsed:.3f} seconds"
|
||||
)
|
||||
semantic_query = expansion_results[0] # str
|
||||
keyword_queries = (
|
||||
expansion_results[1] if expansion_results[1] is not None else []
|
||||
) # list[str]
|
||||
# End timing for query expansion/rephrase
|
||||
query_expansion_elapsed = time.time() - query_expansion_start_time
|
||||
logger.debug(
|
||||
f"Search tool - Query expansion/rephrase took {query_expansion_elapsed:.3f} seconds"
|
||||
)
|
||||
semantic_query = expansion_results[0] # str
|
||||
keyword_queries = (
|
||||
expansion_results[1] if expansion_results[1] is not None else []
|
||||
) # list[str]
|
||||
|
||||
# Prepare queries with their weights and hybrid_alpha settings
|
||||
# Group 1: Keyword queries (use hybrid_alpha=0.2)
|
||||
|
||||
@@ -2,6 +2,7 @@ from collections.abc import Sequence
|
||||
|
||||
from exa_py import Exa
|
||||
from exa_py.api import HighlightsContentsOptions
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.tools.tool_implementations.open_url.models import WebContent
|
||||
@@ -12,8 +13,11 @@ from onyx.tools.tool_implementations.web_search.models import (
|
||||
from onyx.tools.tool_implementations.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# TODO can probably break this up
|
||||
class ExaClient(WebSearchProvider, WebContentProvider):
|
||||
@@ -48,6 +52,35 @@ class ExaClient(WebSearchProvider, WebContentProvider):
|
||||
for result in response.results
|
||||
]
|
||||
|
||||
def test_connection(self) -> dict[str, str]:
|
||||
try:
|
||||
test_results = self.search("test")
|
||||
if not test_results or not any(result.link for result in test_results):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API key validation failed: search returned no results.",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if (
|
||||
"api" in error_msg.lower()
|
||||
or "key" in error_msg.lower()
|
||||
or "auth" in error_msg.lower()
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid Exa API key: {error_msg}",
|
||||
) from e
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Exa API key validation failed: {error_msg}",
|
||||
) from e
|
||||
|
||||
logger.info("Web search provider test succeeded for Exa.")
|
||||
return {"status": "ok"}
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def contents(self, urls: Sequence[str]) -> list[WebContent]:
|
||||
response = self.exa.get_contents(
|
||||
|
||||
@@ -4,6 +4,7 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.tools.tool_implementations.web_search.models import (
|
||||
WebSearchProvider,
|
||||
@@ -28,7 +29,7 @@ class GooglePSEClient(WebSearchProvider):
|
||||
) -> None:
|
||||
self._api_key = api_key
|
||||
self._search_engine_id = search_engine_id
|
||||
self._num_results = num_results
|
||||
self._num_results = min(num_results, 10) # Google API max is 10
|
||||
self._timeout_seconds = timeout_seconds
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
@@ -119,3 +120,38 @@ class GooglePSEClient(WebSearchProvider):
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
# TODO: I'm not really satisfied with how tailored this is to the particulars of Google PSE.
|
||||
# In particular, I think this might flatten errors that are caused by the API key vs. ones caused
|
||||
# by the search engine ID, or by other factors.
|
||||
# I (David Edelstein) don't feel knowledgeable enough about the return behavior of the Google PSE API
|
||||
# to ensure that we have nicely descriptive and actionable error messages. (Like, what's up with the
|
||||
# thing where 200 status codes can have error messages in the response body?)
|
||||
def test_connection(self) -> dict[str, str]:
|
||||
try:
|
||||
test_results = self.search("test")
|
||||
if not test_results or not any(result.link for result in test_results):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Google PSE validation failed: search returned no results.",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if (
|
||||
"api" in error_msg.lower()
|
||||
or "key" in error_msg.lower()
|
||||
or "auth" in error_msg.lower()
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid Google PSE API key: {error_msg}",
|
||||
) from e
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Google PSE validation failed: {error_msg}",
|
||||
) from e
|
||||
|
||||
logger.info("Web search provider test succeeded for Google PSE.")
|
||||
return {"status": "ok"}
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.tools.tool_implementations.web_search.models import (
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.tools.tool_implementations.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class SearXNGClient(WebSearchProvider):
|
||||
def __init__(
|
||||
self,
|
||||
searxng_base_url: str,
|
||||
num_results: int = 10,
|
||||
) -> None:
|
||||
logger.debug(f"Initializing SearXNGClient with base URL: {searxng_base_url}")
|
||||
self._searxng_base_url = searxng_base_url
|
||||
self._num_results = num_results
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
payload = {
|
||||
"q": query,
|
||||
"format": "json",
|
||||
}
|
||||
logger.debug(
|
||||
f"Searching with payload: {payload} to {self._searxng_base_url}/search"
|
||||
)
|
||||
response = requests.post(
|
||||
f"{self._searxng_base_url}/search",
|
||||
data=payload,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
results = response.json()
|
||||
result_list = results.get("results", [])
|
||||
# SearXNG doesn't support limiting results via API parameters,
|
||||
# so we limit client-side after receiving the response
|
||||
limited_results = result_list[: self._num_results]
|
||||
return [
|
||||
WebSearchResult(
|
||||
title=result["title"],
|
||||
link=result["url"],
|
||||
snippet=result["content"],
|
||||
)
|
||||
for result in limited_results
|
||||
]
|
||||
|
||||
def test_connection(self) -> dict[str, str]:
|
||||
try:
|
||||
logger.debug(f"Testing connection to {self._searxng_base_url}/config")
|
||||
response = requests.get(f"{self._searxng_base_url}/config")
|
||||
logger.debug(f"Response: {response.status_code}, text: {response.text}")
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
status_code = e.response.status_code
|
||||
logger.debug(
|
||||
f"HTTPError: status_code={status_code}, e.response={e.response.status_code if e.response else None}, error={e}"
|
||||
)
|
||||
if status_code == 429:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
"This SearXNG instance does not allow API requests. "
|
||||
"Use a private instance and configure it to allow bots."
|
||||
),
|
||||
) from e
|
||||
elif status_code == 404:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="This SearXNG instance was not found. Please check the URL and try again.",
|
||||
) from e
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"SearXNG connection failed (status {status_code}): {str(e)}",
|
||||
) from e
|
||||
|
||||
# Not a sure way to check if this is a SearXNG instance as opposed to some other website that
|
||||
# happens to have a /config endpoint containing a "brand" key with a "GIT_URL" key with value
|
||||
# "https://github.com/searxng/searxng". I don't think that would happen by coincidence, so I
|
||||
# think this is a good enough check for now. I'm open for suggestions on improvements.
|
||||
config = response.json()
|
||||
if (
|
||||
config.get("brand", {}).get("GIT_URL")
|
||||
!= "https://github.com/searxng/searxng"
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="This does not appear to be a SearXNG instance. Please check the URL and try again.",
|
||||
)
|
||||
|
||||
# Test that JSON mode is enabled by performing a simple search
|
||||
self._test_json_mode()
|
||||
|
||||
logger.info("Web search provider test succeeded for SearXNG.")
|
||||
return {"status": "ok"}
|
||||
|
||||
def _test_json_mode(self) -> None:
|
||||
"""Test that JSON format is enabled in SearXNG settings.
|
||||
|
||||
SearXNG requires JSON format to be explicitly enabled in settings.yml.
|
||||
If it's not enabled, the search endpoint returns a 403.
|
||||
"""
|
||||
try:
|
||||
payload = {
|
||||
"q": "test",
|
||||
"format": "json",
|
||||
}
|
||||
response = requests.post(
|
||||
f"{self._searxng_base_url}/search",
|
||||
data=payload,
|
||||
timeout=5,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
status_code = e.response.status_code if e.response is not None else None
|
||||
if status_code == 403:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
"Got a 403 response when trying to reach SearXNG. This likely means that "
|
||||
"JSON format is not enabled on this SearXNG instance. "
|
||||
"Please enable JSON format in your SearXNG settings.yml file by adding "
|
||||
"'json' to the 'search.formats' list."
|
||||
),
|
||||
) from e
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to test search on SearXNG instance (status {status_code}): {str(e)}",
|
||||
) from e
|
||||
@@ -3,6 +3,7 @@ from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.tools.tool_implementations.open_url.models import WebContent
|
||||
@@ -13,8 +14,11 @@ from onyx.tools.tool_implementations.web_search.models import (
|
||||
from onyx.tools.tool_implementations.web_search.models import (
|
||||
WebSearchResult,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
SERPER_SEARCH_URL = "https://google.serper.dev/search"
|
||||
SERPER_CONTENTS_URL = "https://scrape.serper.dev"
|
||||
|
||||
@@ -56,6 +60,35 @@ class SerperClient(WebSearchProvider, WebContentProvider):
|
||||
for result in organic_results
|
||||
]
|
||||
|
||||
def test_connection(self) -> dict[str, str]:
|
||||
try:
|
||||
test_results = self.search("test")
|
||||
if not test_results or not any(result.link for result in test_results):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API key validation failed: search returned no results.",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
if (
|
||||
"api" in error_msg.lower()
|
||||
or "key" in error_msg.lower()
|
||||
or "auth" in error_msg.lower()
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid Serper API key: {error_msg}",
|
||||
) from e
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Serper API key validation failed: {error_msg}",
|
||||
) from e
|
||||
|
||||
logger.info("Web search provider test succeeded for Serper.")
|
||||
return {"status": "ok"}
|
||||
|
||||
def contents(self, urls: Sequence[str]) -> list[WebContent]:
|
||||
if not urls:
|
||||
return []
|
||||
|
||||
@@ -41,6 +41,10 @@ class WebSearchProvider:
def search(self, query: str) -> Sequence[WebSearchResult]:
pass

@abstractmethod
def test_connection(self) -> dict[str, str]:
pass


class WebContentProviderConfig(BaseModel):
timeout_seconds: int | None = None

@@ -13,6 +13,9 @@ from onyx.tools.tool_implementations.web_search.clients.exa_client import (
|
||||
from onyx.tools.tool_implementations.web_search.clients.google_pse_client import (
|
||||
GooglePSEClient,
|
||||
)
|
||||
from onyx.tools.tool_implementations.web_search.clients.searxng_client import (
|
||||
SearXNGClient,
|
||||
)
|
||||
from onyx.tools.tool_implementations.web_search.clients.serper_client import (
|
||||
SerperClient,
|
||||
)
|
||||
@@ -55,6 +58,14 @@ def build_search_provider_from_config(
|
||||
num_results=num_results,
|
||||
timeout_seconds=int(config.get("timeout_seconds") or 10),
|
||||
)
|
||||
if provider_type == WebSearchProviderType.SEARXNG:
|
||||
searxng_base_url = config.get("searxng_base_url")
|
||||
if not searxng_base_url:
|
||||
raise ValueError("Please provide a URL for your private SearXNG instance.")
|
||||
return SearXNGClient(
|
||||
searxng_base_url,
|
||||
num_results=num_results,
|
||||
)
|
||||
|
||||
|
||||
def _build_search_provider(provider_model: InternetSearchProvider) -> WebSearchProvider:
|
||||
|
||||
@@ -88,6 +88,8 @@ def run_tool_calls(
|
||||
user_info: str | None,
|
||||
citation_mapping: dict[int, str],
|
||||
citation_processor: DynamicCitationProcessor,
|
||||
# Skip query expansion for repeat search tool calls
|
||||
skip_search_query_expansion: bool = False,
|
||||
) -> tuple[
|
||||
list[ToolResponse], dict[int, str]
|
||||
]: # return also the updated citation mapping
|
||||
@@ -136,6 +138,7 @@ def run_tool_calls(
|
||||
message_history=minimal_history,
|
||||
memories=memories,
|
||||
user_info=user_info,
|
||||
skip_query_expansion=skip_search_query_expansion,
|
||||
)
|
||||
|
||||
elif isinstance(tool, WebSearchTool):
|
||||
|
||||
@@ -260,7 +260,7 @@ fastmcp==2.13.3
|
||||
# via onyx
|
||||
fastuuid==0.14.0
|
||||
# via litellm
|
||||
filelock==3.15.4
|
||||
filelock==3.20.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# onyx
|
||||
@@ -344,14 +344,15 @@ greenlet==3.2.4
|
||||
# sqlalchemy
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.76.0
|
||||
grpcio==1.67.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.76.0
|
||||
# litellm
|
||||
grpcio-status==1.67.1
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -393,7 +394,7 @@ httpx-sse==0.4.3
|
||||
# via
|
||||
# cohere
|
||||
# mcp
|
||||
hubspot-api-client==8.1.0
|
||||
hubspot-api-client==11.1.0
|
||||
# via onyx
|
||||
huggingface-hub==0.35.3
|
||||
# via
|
||||
@@ -485,7 +486,7 @@ langsmith==0.3.45
|
||||
# langchain-core
|
||||
lazy-imports==1.0.1
|
||||
# via onyx
|
||||
litellm==1.79.0
|
||||
litellm==1.80.10
|
||||
# via onyx
|
||||
locket==1.0.0
|
||||
# via
|
||||
@@ -593,7 +594,7 @@ office365-rest-python-client==2.5.9
|
||||
# via onyx
|
||||
onnxruntime==1.20.1
|
||||
# via magika
|
||||
openai==2.6.1
|
||||
openai==2.8.1
|
||||
# via
|
||||
# exa-py
|
||||
# langfuse
|
||||
@@ -700,7 +701,7 @@ proto-plus==1.26.1
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.1
|
||||
protobuf==5.29.5
|
||||
# via
|
||||
# ddtrace
|
||||
# google-api-core
|
||||
@@ -897,6 +898,7 @@ requests==2.32.5
|
||||
# google-cloud-bigquery
|
||||
# google-cloud-storage
|
||||
# google-genai
|
||||
# hubspot-api-client
|
||||
# huggingface-hub
|
||||
# jira
|
||||
# jsonschema-path
|
||||
@@ -1088,7 +1090,6 @@ typing-extensions==4.15.0
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# jira
|
||||
# langchain-core
|
||||
@@ -1141,7 +1142,7 @@ unstructured-client==0.25.4
|
||||
# unstructured
|
||||
uritemplate==4.2.0
|
||||
# via google-api-python-client
|
||||
urllib3==2.6.0
|
||||
urllib3==2.6.1
|
||||
# via
|
||||
# asana
|
||||
# botocore
|
||||
|
||||
@@ -101,12 +101,14 @@ executing==2.2.1
|
||||
faker==37.1.0
|
||||
# via onyx
|
||||
fastapi==0.116.1
|
||||
# via onyx
|
||||
# via
|
||||
# onyx
|
||||
# onyx-devtools
|
||||
fastavro==1.12.1
|
||||
# via cohere
|
||||
fastuuid==0.14.0
|
||||
# via litellm
|
||||
filelock==3.15.4
|
||||
filelock==3.20.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# virtualenv
|
||||
@@ -163,14 +165,15 @@ googleapis-common-protos==1.72.0
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.76.0
|
||||
grpcio==1.67.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.76.0
|
||||
# litellm
|
||||
grpcio-status==1.67.1
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -231,7 +234,7 @@ jupyter-core==5.9.1
|
||||
# via
|
||||
# ipykernel
|
||||
# jupyter-client
|
||||
litellm==1.79.0
|
||||
litellm==1.80.10
|
||||
# via onyx
|
||||
manygo==0.2.0
|
||||
# via onyx
|
||||
@@ -262,12 +265,16 @@ numpy==1.26.4
|
||||
# pandas-stubs
|
||||
# shapely
|
||||
# voyageai
|
||||
onyx-devtools==0.1.0
|
||||
onyx-devtools==0.2.0
|
||||
# via onyx
|
||||
openai==2.6.1
|
||||
openai==2.8.1
|
||||
# via
|
||||
# litellm
|
||||
# onyx
|
||||
openapi-generator-cli==7.17.0
|
||||
# via
|
||||
# onyx
|
||||
# onyx-devtools
|
||||
packaging==24.2
|
||||
# via
|
||||
# black
|
||||
@@ -317,7 +324,7 @@ proto-plus==1.26.1
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.1
|
||||
protobuf==5.29.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
@@ -505,7 +512,6 @@ typing-extensions==4.15.0
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# mypy
|
||||
@@ -520,7 +526,7 @@ typing-inspection==0.4.2
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via faker
|
||||
urllib3==2.6.0
|
||||
urllib3==2.6.1
|
||||
# via
|
||||
# botocore
|
||||
# requests
|
||||
|
||||
@@ -77,7 +77,7 @@ fastavro==1.12.1
|
||||
# via cohere
|
||||
fastuuid==0.14.0
|
||||
# via litellm
|
||||
filelock==3.15.4
|
||||
filelock==3.20.1
|
||||
# via huggingface-hub
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
@@ -132,14 +132,15 @@ googleapis-common-protos==1.72.0
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.76.0
|
||||
grpcio==1.67.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.76.0
|
||||
# litellm
|
||||
grpcio-status==1.67.1
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -180,7 +181,7 @@ jsonschema==4.25.1
|
||||
# via litellm
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
litellm==1.79.0
|
||||
litellm==1.80.10
|
||||
# via onyx
|
||||
markupsafe==3.0.3
|
||||
# via jinja2
|
||||
@@ -195,7 +196,7 @@ numpy==1.26.4
|
||||
# via
|
||||
# shapely
|
||||
# voyageai
|
||||
openai==2.6.1
|
||||
openai==2.8.1
|
||||
# via
|
||||
# litellm
|
||||
# onyx
|
||||
@@ -223,7 +224,7 @@ proto-plus==1.26.1
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.1
|
||||
protobuf==5.29.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
@@ -328,7 +329,6 @@ typing-extensions==4.15.0
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# openai
|
||||
# pydantic
|
||||
@@ -338,7 +338,7 @@ typing-extensions==4.15.0
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.2
|
||||
# via pydantic
|
||||
urllib3==2.6.0
|
||||
urllib3==2.6.1
|
||||
# via
|
||||
# botocore
|
||||
# requests
|
||||
|
||||
@@ -112,7 +112,7 @@ fastavro==1.12.1
|
||||
# via cohere
|
||||
fastuuid==0.14.0
|
||||
# via litellm
|
||||
filelock==3.15.4
|
||||
filelock==3.20.1
|
||||
# via
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
@@ -175,14 +175,15 @@ googleapis-common-protos==1.72.0
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.76.0
|
||||
grpcio==1.67.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.76.0
|
||||
# litellm
|
||||
grpcio-status==1.67.1
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -237,7 +238,7 @@ jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
kombu==5.5.4
|
||||
# via celery
|
||||
litellm==1.79.0
|
||||
litellm==1.80.10
|
||||
# via onyx
|
||||
markupsafe==3.0.3
|
||||
# via jinja2
|
||||
@@ -301,7 +302,7 @@ nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform
|
||||
# torch
|
||||
nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||
# via torch
|
||||
openai==2.6.1
|
||||
openai==2.8.1
|
||||
# via
|
||||
# litellm
|
||||
# onyx
|
||||
@@ -341,7 +342,7 @@ proto-plus==1.26.1
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==6.33.1
|
||||
protobuf==5.29.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
@@ -504,7 +505,6 @@ typing-extensions==4.15.0
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# openai
|
||||
# pydantic
|
||||
@@ -520,7 +520,7 @@ tzdata==2025.2
|
||||
# via
|
||||
# kombu
|
||||
# pandas
|
||||
urllib3==2.6.0
|
||||
urllib3==2.6.1
|
||||
# via
|
||||
# botocore
|
||||
# requests
|
||||
|
||||
@@ -1,335 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import NamedTuple
|
||||
from typing import Set
|
||||
|
||||
# Configure the logger
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, # Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", # Log format
|
||||
handlers=[logging.StreamHandler()], # Output logs to console
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LazyImportSettings:
|
||||
"""Settings for which files to ignore when checking for lazy imports."""
|
||||
|
||||
ignore_files: Set[str] | None = None
|
||||
|
||||
|
||||
# Common ignore directories (virtual envs, caches) used across collectors
|
||||
_IGNORE_DIRECTORIES: Set[str] = {".venv", "venv", ".env", "env", "__pycache__"}
|
||||
|
||||
|
||||
# Map of modules to lazy import -> settings for what to ignore
|
||||
_LAZY_IMPORT_MODULES_TO_IGNORE_SETTINGS: Dict[str, LazyImportSettings] = {
|
||||
"google.genai": LazyImportSettings(),
|
||||
"openai": LazyImportSettings(),
|
||||
"markitdown": LazyImportSettings(),
|
||||
"tiktoken": LazyImportSettings(),
|
||||
"transformers": LazyImportSettings(ignore_files={"model_server/main.py"}),
|
||||
"setfit": LazyImportSettings(),
|
||||
"unstructured": LazyImportSettings(),
|
||||
"onyx.llm.litellm_singleton": LazyImportSettings(),
|
||||
"litellm": LazyImportSettings(
|
||||
ignore_files={
|
||||
"onyx/llm/litellm_singleton/__init__.py",
|
||||
"onyx/llm/litellm_singleton/config.py",
|
||||
"onyx/llm/litellm_singleton/monkey_patches.py",
|
||||
}
|
||||
),
|
||||
"nltk": LazyImportSettings(),
|
||||
"trafilatura": LazyImportSettings(),
|
||||
"pypdf": LazyImportSettings(),
|
||||
"unstructured_client": LazyImportSettings(),
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class EagerImportResult:
|
||||
"""Result of checking a file for eager imports."""
|
||||
|
||||
violation_lines: List[tuple[int, str]] # (line_number, line_content) tuples
|
||||
violated_modules: Set[str] # modules that were actually violated
|
||||
|
||||
|
||||
def find_eager_imports(
|
||||
file_path: Path, protected_modules: Set[str]
|
||||
) -> EagerImportResult:
|
||||
"""
|
||||
Find eager imports of protected modules in a given file.
|
||||
|
||||
Eager imports are top-level (module-level) imports that happen immediately
|
||||
when the module is loaded, as opposed to lazy imports that happen inside
|
||||
functions only when called.
|
||||
|
||||
Args:
|
||||
file_path: Path to Python file to check
|
||||
protected_modules: Set of module names that should only be imported lazily
|
||||
|
||||
Returns:
|
||||
EagerImportResult containing violations list and violated modules set
|
||||
"""
|
||||
violation_lines = []
|
||||
violated_modules = set()
|
||||
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
lines = content.split("\n")
|
||||
|
||||
for line_num, line in enumerate(lines, 1):
|
||||
stripped = line.strip()
|
||||
|
||||
# Skip comments and empty lines
|
||||
if not stripped or stripped.startswith("#"):
|
||||
continue
|
||||
|
||||
# Only check imports at module level (indentation == 0)
|
||||
current_indent = len(line) - len(line.lstrip())
|
||||
if current_indent == 0:
|
||||
# Check for eager imports of protected modules
|
||||
for module in protected_modules:
|
||||
# Pattern 1: import module
|
||||
if re.match(rf"^import\s+{re.escape(module)}(\s|$|\.)", stripped):
|
||||
violation_lines.append((line_num, line))
|
||||
violated_modules.add(module)
|
||||
|
||||
# Pattern 2: from module import ...
|
||||
elif re.match(rf"^from\s+{re.escape(module)}(\s|\.|$)", stripped):
|
||||
violation_lines.append((line_num, line))
|
||||
violated_modules.add(module)
|
||||
|
||||
# Pattern 3: from ... import module (less common but possible)
|
||||
elif re.search(
|
||||
rf"^from\s+[\w.]+\s+import\s+.*\b{re.escape(module)}\b",
|
||||
stripped,
|
||||
):
|
||||
violation_lines.append((line_num, line))
|
||||
violated_modules.add(module)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error reading {file_path}: {e}")
|
||||
|
||||
return EagerImportResult(
|
||||
violation_lines=violation_lines, violated_modules=violated_modules
|
||||
)
|
||||
|
||||
|
||||
def find_python_files(backend_dir: Path) -> List[Path]:
|
||||
"""
|
||||
Find all Python files in the backend directory, excluding test files.
|
||||
|
||||
Args:
|
||||
backend_dir: Path to the backend directory to search
|
||||
|
||||
Returns:
|
||||
List of Python file paths to check
|
||||
"""
|
||||
|
||||
return _collect_python_files([backend_dir], backend_dir)
|
||||
|
||||
|
||||
def _is_valid_python_file(file_path: Path) -> bool:
|
||||
"""
|
||||
Apply shared filtering rules:
|
||||
- Must be a Python file
|
||||
- Exclude tests and common virtualenv/cache directories
|
||||
"""
|
||||
if file_path.suffix != ".py":
|
||||
return False
|
||||
|
||||
path_parts = file_path.parts
|
||||
if (
|
||||
"tests" in path_parts
|
||||
or file_path.name.startswith("test_")
|
||||
or file_path.name.endswith("_test.py")
|
||||
):
|
||||
return False
|
||||
|
||||
if any(ignored_dir in path_parts for ignored_dir in _IGNORE_DIRECTORIES):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _collect_python_files(start_points: List[Path], backend_dir: Path) -> List[Path]:
|
||||
"""
|
||||
Given a list of directories/files, collect Python files that pass shared filters.
|
||||
Constrains collection to within backend_dir.
|
||||
"""
|
||||
collected: List[Path] = []
|
||||
backend_real = backend_dir.resolve()
|
||||
|
||||
for p in start_points:
|
||||
try:
|
||||
p = p.resolve()
|
||||
except Exception:
|
||||
# If resolve fails, skip the path
|
||||
continue
|
||||
|
||||
try:
|
||||
_ = p.relative_to(backend_real)
|
||||
except Exception:
|
||||
# Skip anything outside backend directory to mirror pre-commit filter
|
||||
logger.debug(f"Skipping path outside backend directory: {p}")
|
||||
continue
|
||||
|
||||
if p.is_dir():
|
||||
for file_path in p.glob("**/*.py"):
|
||||
if _is_valid_python_file(file_path):
|
||||
collected.append(file_path)
|
||||
else:
|
||||
if _is_valid_python_file(p):
|
||||
collected.append(p)
|
||||
|
||||
return collected
|
||||
|
||||
|
||||
def should_check_file_for_module(
|
||||
file_path: Path, backend_dir: Path, settings: LazyImportSettings
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a file should be checked for a specific module's imports.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to check
|
||||
backend_dir: Path to the backend directory
|
||||
settings: Settings containing files to ignore for this module
|
||||
|
||||
Returns:
|
||||
True if the file should be checked, False if it should be ignored
|
||||
"""
|
||||
if not settings.ignore_files:
|
||||
# Empty set means check everywhere
|
||||
return True
|
||||
|
||||
# Get relative path from backend directory
|
||||
rel_path = file_path.relative_to(backend_dir)
|
||||
rel_path_str = rel_path.as_posix()
|
||||
|
||||
return rel_path_str not in settings.ignore_files
|
||||
|
||||
|
||||
def _collect_python_files_from_args(
|
||||
provided_paths: List[str], backend_dir: Path
|
||||
) -> List[Path]:
|
||||
"""
|
||||
From a list of provided file or directory paths, collect Python files to check.
|
||||
Only files under the backend directory are considered. Test files and venv dirs
|
||||
are excluded using the same rules as find_python_files.
|
||||
"""
|
||||
if not provided_paths:
|
||||
return []
|
||||
|
||||
normalized: List[Path] = []
|
||||
for raw in provided_paths:
|
||||
p = Path(raw)
|
||||
if not p.exists():
|
||||
logger.debug(f"Ignoring non-existent path: {raw}")
|
||||
continue
|
||||
normalized.append(p)
|
||||
|
||||
return _collect_python_files(normalized, backend_dir)
|
||||
|
||||
|
||||
class Args(NamedTuple):
|
||||
paths: List[str]
|
||||
|
||||
|
||||
def _parse_args() -> Args:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Check that specified modules are only lazily imported. "
|
||||
"Optionally provide files or directories to limit the check; "
|
||||
"if none are provided, all backend Python files are scanned."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"paths",
|
||||
nargs="*",
|
||||
help="Optional file or directory paths to check (relative to repo root).",
|
||||
)
|
||||
parsed = parser.parse_args()
|
||||
return Args(paths=list(parsed.paths))
|
||||
|
||||
|
||||
def main(
|
||||
modules_to_lazy_import: Dict[str, LazyImportSettings],
|
||||
provided_paths: List[str] | None = None,
|
||||
) -> None:
|
||||
backend_dir = Path(__file__).parent.parent # Go up from scripts/ to backend/
|
||||
|
||||
logger.info(
|
||||
f"Checking for direct imports of lazy modules: {', '.join(modules_to_lazy_import.keys())}"
|
||||
)
|
||||
|
||||
# Determine Python files to check
|
||||
if provided_paths:
|
||||
target_python_files = _collect_python_files_from_args(
|
||||
provided_paths, backend_dir
|
||||
)
|
||||
if not target_python_files:
|
||||
logger.info("No matching Python files to check based on provided paths.")
|
||||
return
|
||||
else:
|
||||
target_python_files = find_python_files(backend_dir)
|
||||
|
||||
violations_found = False
|
||||
all_violated_modules = set()
|
||||
|
||||
# Check each Python file for each module with its specific ignore directories
|
||||
for file_path in target_python_files:
|
||||
# Determine which modules should be checked for this file
|
||||
modules_to_check = set()
|
||||
for module_name, settings in modules_to_lazy_import.items():
|
||||
if should_check_file_for_module(file_path, backend_dir, settings):
|
||||
modules_to_check.add(module_name)
|
||||
|
||||
if not modules_to_check:
|
||||
# This file is ignored for all modules
|
||||
continue
|
||||
|
||||
result = find_eager_imports(file_path, modules_to_check)
|
||||
|
||||
if result.violation_lines:
|
||||
violations_found = True
|
||||
all_violated_modules.update(result.violated_modules)
|
||||
rel_path = file_path.relative_to(backend_dir)
|
||||
logger.error(f"\n❌ Eager import violations found in {rel_path}:")
|
||||
|
||||
for line_num, line in result.violation_lines:
|
||||
logger.error(f" Line {line_num}: {line.strip()}")
|
||||
|
||||
# Suggest fix only for violated modules
|
||||
if result.violated_modules:
|
||||
logger.error(
|
||||
f" 💡 You must lazy import {', '.join(sorted(result.violated_modules))} within functions when needed"
|
||||
)
|
||||
|
||||
if violations_found:
|
||||
violated_modules_str = ", ".join(sorted(all_violated_modules))
|
||||
raise RuntimeError(
|
||||
f"Found eager imports of {violated_modules_str}. You must import them only when needed."
|
||||
)
|
||||
else:
|
||||
logger.info("✅ All lazy modules are properly imported!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
args = _parse_args()
|
||||
main(_LAZY_IMPORT_MODULES_TO_IGNORE_SETTINGS, provided_paths=args.paths)
|
||||
sys.exit(0)
|
||||
except RuntimeError:
|
||||
sys.exit(1)
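A minimal invocation sketch for the checker above (the script path and the example arguments are assumptions based on the argparse help; paths are relative to the repo root):

```bash
# Scan every backend Python file for eager imports of the configured lazy modules
python3 backend/scripts/check_lazy_imports.py

# Limit the scan to specific files or directories (hypothetical example paths)
python3 backend/scripts/check_lazy_imports.py backend/onyx/server backend/onyx/db/models.py
```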
|
||||
230
backend/scripts/dump/README.md
Normal file
@@ -0,0 +1,230 @@
|
||||
# Onyx Data Backup & Restore Scripts
|
||||
|
||||
Scripts for backing up and restoring PostgreSQL, Vespa, and MinIO data from an Onyx deployment.
|
||||
|
||||
## Overview
|
||||
|
||||
Two backup modes are supported:
|
||||
|
||||
| Mode | Description | Pros | Cons |
|
||||
|------|-------------|------|------|
|
||||
| **volume** | Exports Docker volumes directly | Fast, complete, preserves everything | Services must be stopped for consistency |
|
||||
| **api** | Uses pg_dump and Vespa REST API | Services can stay running, more portable | Slower for large datasets |
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Backup (from a running instance)
|
||||
|
||||
```bash
|
||||
# Full backup using volume mode (recommended for complete backups)
|
||||
# Note: For consistency, stop services first
|
||||
docker compose -f deployment/docker_compose/docker-compose.yml stop
|
||||
./scripts/dump_data.sh --mode volume --output ./backups
|
||||
docker compose -f deployment/docker_compose/docker-compose.yml start
|
||||
|
||||
# Or use API mode (services can stay running)
|
||||
./scripts/dump_data.sh --mode api --output ./backups
|
||||
```
|
||||
|
||||
### Restore (to a local instance)
|
||||
|
||||
```bash
|
||||
# Restore from latest backup
|
||||
./scripts/restore_data.sh --input ./backups/latest
|
||||
|
||||
# Restore from specific backup
|
||||
./scripts/restore_data.sh --input ./backups/20240115_120000
|
||||
|
||||
# Force restore without confirmation
|
||||
./scripts/restore_data.sh --input ./backups/latest --force
|
||||
```
|
||||
|
||||
## Detailed Usage
|
||||
|
||||
### dump_data.sh
|
||||
|
||||
```
|
||||
Usage: ./scripts/dump_data.sh [OPTIONS]
|
||||
|
||||
Options:
|
||||
--mode <volume|api> Backup mode (default: volume)
|
||||
--output <dir> Output directory (default: ./onyx_backup)
|
||||
--project <name> Docker Compose project name (default: onyx)
|
||||
--postgres-only Only backup PostgreSQL
|
||||
--vespa-only Only backup Vespa
|
||||
--minio-only Only backup MinIO
|
||||
--no-minio Skip MinIO backup
|
||||
--help Show help message
|
||||
```
|
||||
|
||||
**Examples:**
|
||||
|
||||
```bash
|
||||
# Default volume backup
|
||||
./scripts/dump_data.sh
|
||||
|
||||
# API-based backup
|
||||
./scripts/dump_data.sh --mode api
|
||||
|
||||
# Only backup PostgreSQL
|
||||
./scripts/dump_data.sh --postgres-only --mode api
|
||||
|
||||
# Custom output directory
|
||||
./scripts/dump_data.sh --output /mnt/backups/onyx
|
||||
|
||||
# Different project name (if using custom docker compose project)
|
||||
./scripts/dump_data.sh --project my-onyx-instance
|
||||
```
|
||||
|
||||
### restore_data.sh
|
||||
|
||||
```
|
||||
Usage: ./scripts/restore_data.sh [OPTIONS]
|
||||
|
||||
Options:
|
||||
--input <dir> Backup directory (required)
|
||||
--project <name> Docker Compose project name (default: onyx)
|
||||
--postgres-only Only restore PostgreSQL
|
||||
--vespa-only Only restore Vespa
|
||||
--minio-only Only restore MinIO
|
||||
--no-minio Skip MinIO restore
|
||||
--force Skip confirmation prompts
|
||||
--help Show help message
|
||||
```
|
||||
|
||||
**Examples:**
|
||||
|
||||
```bash
|
||||
# Restore all components
|
||||
./scripts/restore_data.sh --input ./onyx_backup/latest
|
||||
|
||||
# Restore only PostgreSQL
|
||||
./scripts/restore_data.sh --input ./onyx_backup/latest --postgres-only
|
||||
|
||||
# Non-interactive restore
|
||||
./scripts/restore_data.sh --input ./onyx_backup/latest --force
|
||||
```
|
||||
|
||||
## Backup Directory Structure
|
||||
|
||||
After running a backup, the output directory contains:
|
||||
|
||||
```
|
||||
onyx_backup/
|
||||
├── 20240115_120000/ # Timestamp-named backup
|
||||
│ ├── metadata.json # Backup metadata
|
||||
│ ├── postgres_volume.tar.gz # PostgreSQL data (volume mode)
|
||||
│ ├── postgres_dump.backup # PostgreSQL dump (api mode)
|
||||
│ ├── vespa_volume.tar.gz # Vespa data (volume mode)
|
||||
│ ├── vespa_documents.jsonl # Vespa documents (api mode)
|
||||
│ ├── minio_volume.tar.gz # MinIO data (volume mode)
|
||||
│ └── minio_data.tar.gz # MinIO data (api mode)
|
||||
└── latest -> 20240115_120000 # Symlink to latest backup
|
||||
```
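To sanity-check what a backup actually contains before relying on it (file names follow the default layout above; `jq` is assumed to be installed):

```bash
# Inspect the recorded backup settings
jq . ./onyx_backup/latest/metadata.json

# List the first few entries of a volume archive without extracting it
tar tzf ./onyx_backup/latest/postgres_volume.tar.gz | head
```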
|
||||
|
||||
## Environment Variables
|
||||
|
||||
You can customize behavior with environment variables:
|
||||
|
||||
```bash
|
||||
# PostgreSQL settings
|
||||
export POSTGRES_USER=postgres
|
||||
export POSTGRES_PASSWORD=password
|
||||
export POSTGRES_DB=postgres
|
||||
export POSTGRES_PORT=5432
|
||||
|
||||
# Vespa settings
|
||||
export VESPA_HOST=localhost
|
||||
export VESPA_PORT=8081
|
||||
export VESPA_INDEX=danswer_index
|
||||
```
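As a sketch, an API-mode backup against a deployment that exposes PostgreSQL on a non-default port (the values are placeholders, not recommendations):

```bash
export POSTGRES_USER=postgres
export POSTGRES_PORT=5433   # placeholder: whatever your deployment exposes
./scripts/dump_data.sh --mode api --output ./backups
```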
|
||||
|
||||
## Typical Workflows
|
||||
|
||||
### Migrate to a new server
|
||||
|
||||
```bash
|
||||
# On source server
|
||||
./scripts/dump_data.sh --mode volume --output ./migration_backup
|
||||
tar czhf onyx_backup.tar.gz ./migration_backup/latest  # -h dereferences the 'latest' symlink
|
||||
|
||||
# Transfer to new server
|
||||
scp onyx_backup.tar.gz newserver:/opt/onyx/
|
||||
|
||||
# On new server
|
||||
cd /opt/onyx
|
||||
tar xzf onyx_backup.tar.gz
|
||||
./scripts/restore_data.sh --input ./migration_backup/latest --force
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
### Create a development copy from production
|
||||
|
||||
```bash
|
||||
# On production (use API mode to avoid downtime)
|
||||
./scripts/dump_data.sh --mode api --output ./prod_backup
|
||||
|
||||
# Copy to dev machine
|
||||
rsync -avzL ./prod_backup/latest/ dev-machine:/home/dev/onyx_backup/  # -L and the trailing slash copy the backup contents, not the symlink
|
||||
|
||||
# On dev machine
|
||||
./scripts/restore_data.sh --input /home/dev/onyx_backup --force
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
|
||||
```
|
||||
|
||||
### Scheduled backups (cron)
|
||||
|
||||
```bash
|
||||
# Add to crontab: crontab -e
|
||||
# Daily backup at 2 AM
|
||||
0 2 * * * cd /opt/onyx && ./scripts/dump_data.sh --mode api --output /backups/onyx >> /var/log/onyx-backup.log 2>&1
|
||||
|
||||
# Weekly cleanup (keep last 7 days)
|
||||
0 3 * * 0 find /backups/onyx -maxdepth 1 -type d -mtime +7 -exec rm -rf {} \;
|
||||
```
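Before relying on the weekly cleanup entry, you can dry-run the same `find` expression and confirm it only matches old backup directories:

```bash
# Print what the cleanup job would delete (nothing is removed)
find /backups/onyx -maxdepth 1 -type d -mtime +7 -print
```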
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Volume not found" error
|
||||
|
||||
Ensure the Docker Compose project name matches:
|
||||
```bash
|
||||
docker volume ls | grep db_volume
|
||||
# If it shows "myproject_db_volume", use --project myproject
|
||||
```
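The scripts also accept `--volume-prefix` (see the script help) for cases where the volume names do not start with the compose project name; a hedged example:

```bash
# Volumes are named "myproject_db_volume", "myproject_vespa_volume", ...
./scripts/dump_data.sh --project myproject --volume-prefix myproject --output ./backups
```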
|
||||
|
||||
### "Container not running" error (API mode)
|
||||
|
||||
Start the required services:
|
||||
```bash
|
||||
cd deployment/docker_compose
|
||||
docker compose up -d relational_db index minio
|
||||
```
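You can confirm the expected containers are up before retrying (container names assume the default `onyx` compose project):

```bash
docker ps --format '{{.Names}}' | grep -E 'relational_db|index|minio'
# Expect onyx-relational_db-1, onyx-index-1 and onyx-minio-1 in the output
```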
|
||||
|
||||
### Vespa restore fails with "not ready"
|
||||
|
||||
Vespa takes time to initialize. Wait and retry:
|
||||
```bash
|
||||
# Check Vespa health
|
||||
curl http://localhost:8081/state/v1/health
|
||||
```
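A small polling sketch that waits for the health endpoint to report `up` before retrying the restore (it mirrors the wait loop inside `restore_data.sh`; the 60-second limit is arbitrary):

```bash
# Wait up to 60 seconds for Vespa to become healthy
for _ in $(seq 1 30); do
  curl -s http://localhost:8081/state/v1/health | grep -q '"code":"up"' && break
  sleep 2
done
```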
|
||||
|
||||
### PostgreSQL restore shows warnings
|
||||
|
||||
`pg_restore` often shows warnings about objects that don't exist (when using `--clean`). These are usually safe to ignore if the restore completes.
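A quick way to confirm the data actually landed is to list a few tables through the running database container (the service name and credentials follow the defaults used elsewhere in this guide):

```bash
cd deployment/docker_compose
docker compose exec relational_db psql -U postgres -d postgres -c '\dt' | head
```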
|
||||
|
||||
## Alternative: Python Script
|
||||
|
||||
For more control, you can also use the existing Python script:
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
|
||||
# Save state
|
||||
python -m scripts.save_load_state --save --checkpoint_dir ../onyx_checkpoint
|
||||
|
||||
# Load state
|
||||
python -m scripts.save_load_state --load --checkpoint_dir ../onyx_checkpoint
|
||||
```
|
||||
|
||||
See `backend/scripts/save_load_state.py` for the Python implementation.
|
||||
478
backend/scripts/dump/dump_data.sh
Executable file
@@ -0,0 +1,478 @@
|
||||
#!/bin/bash
|
||||
# =============================================================================
|
||||
# Onyx Data Dump Script
|
||||
# =============================================================================
|
||||
# This script creates a backup of PostgreSQL, Vespa, and MinIO data.
|
||||
#
|
||||
# Two modes available:
|
||||
# - volume: Exports Docker volumes directly (faster, complete backup)
|
||||
# - api: Uses pg_dump and Vespa API (more portable)
|
||||
#
|
||||
# Usage:
|
||||
# ./dump_data.sh [OPTIONS]
|
||||
#
|
||||
# Options:
|
||||
# --mode <volume|api> Backup mode (default: volume)
|
||||
# --output <dir> Output directory (default: ./onyx_backup)
|
||||
# --project <name> Docker Compose project name (default: onyx)
|
||||
# --volume-prefix <name> Volume name prefix (default: same as project name)
|
||||
# --compose-dir <dir> Docker Compose directory (for service management)
|
||||
# --postgres-only Only backup PostgreSQL
|
||||
# --vespa-only Only backup Vespa
|
||||
# --minio-only Only backup MinIO
|
||||
# --no-minio Skip MinIO backup
|
||||
# --no-restart Don't restart services after backup (volume mode)
|
||||
# --help Show this help message
|
||||
#
|
||||
# Examples:
|
||||
# ./dump_data.sh # Full volume backup
|
||||
# ./dump_data.sh --mode api # API-based backup
|
||||
# ./dump_data.sh --output /tmp/backup # Custom output directory
|
||||
# ./dump_data.sh --postgres-only --mode api # Only PostgreSQL via pg_dump
|
||||
# ./dump_data.sh --volume-prefix myprefix # Use custom volume prefix
|
||||
# =============================================================================
|
||||
|
||||
set -e
|
||||
|
||||
# Default configuration
|
||||
MODE="volume"
|
||||
OUTPUT_DIR="./onyx_backup"
|
||||
PROJECT_NAME="onyx"
|
||||
VOLUME_PREFIX="" # Will default to PROJECT_NAME if not set
|
||||
COMPOSE_DIR="" # Docker Compose directory for service management
|
||||
BACKUP_POSTGRES=true
|
||||
BACKUP_VESPA=true
|
||||
BACKUP_MINIO=true
|
||||
NO_RESTART=false
|
||||
|
||||
# PostgreSQL defaults
|
||||
POSTGRES_USER="${POSTGRES_USER:-postgres}"
|
||||
POSTGRES_PASSWORD="${POSTGRES_PASSWORD:-password}"
|
||||
POSTGRES_DB="${POSTGRES_DB:-postgres}"
|
||||
POSTGRES_PORT="${POSTGRES_PORT:-5432}"
|
||||
|
||||
# Vespa defaults
|
||||
VESPA_HOST="${VESPA_HOST:-localhost}"
|
||||
VESPA_PORT="${VESPA_PORT:-8081}"
|
||||
VESPA_INDEX="${VESPA_INDEX:-danswer_index}"
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
log_info() {
|
||||
echo -e "${BLUE}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
log_success() {
|
||||
echo -e "${GREEN}[SUCCESS]${NC} $1"
|
||||
}
|
||||
|
||||
log_warning() {
|
||||
echo -e "${YELLOW}[WARNING]${NC} $1"
|
||||
}
|
||||
|
||||
log_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
show_help() {
|
||||
head -35 "$0" | tail -32
|
||||
exit 0
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--mode)
|
||||
MODE="$2"
|
||||
shift 2
|
||||
;;
|
||||
--output)
|
||||
OUTPUT_DIR="$2"
|
||||
shift 2
|
||||
;;
|
||||
--project)
|
||||
PROJECT_NAME="$2"
|
||||
shift 2
|
||||
;;
|
||||
--volume-prefix)
|
||||
VOLUME_PREFIX="$2"
|
||||
shift 2
|
||||
;;
|
||||
--compose-dir)
|
||||
COMPOSE_DIR="$2"
|
||||
shift 2
|
||||
;;
|
||||
--no-restart)
|
||||
NO_RESTART=true
|
||||
shift
|
||||
;;
|
||||
--postgres-only)
|
||||
BACKUP_POSTGRES=true
|
||||
BACKUP_VESPA=false
|
||||
BACKUP_MINIO=false
|
||||
shift
|
||||
;;
|
||||
--vespa-only)
|
||||
BACKUP_POSTGRES=false
|
||||
BACKUP_VESPA=true
|
||||
BACKUP_MINIO=false
|
||||
shift
|
||||
;;
|
||||
--minio-only)
|
||||
BACKUP_POSTGRES=false
|
||||
BACKUP_VESPA=false
|
||||
BACKUP_MINIO=true
|
||||
shift
|
||||
;;
|
||||
--no-minio)
|
||||
BACKUP_MINIO=false
|
||||
shift
|
||||
;;
|
||||
--help)
|
||||
show_help
|
||||
;;
|
||||
*)
|
||||
log_error "Unknown option: $1"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Validate mode
|
||||
if [[ "$MODE" != "volume" && "$MODE" != "api" ]]; then
|
||||
log_error "Invalid mode: $MODE. Use 'volume' or 'api'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set VOLUME_PREFIX to PROJECT_NAME if not specified
|
||||
if [[ -z "$VOLUME_PREFIX" ]]; then
|
||||
VOLUME_PREFIX="$PROJECT_NAME"
|
||||
fi
|
||||
|
||||
# Create output directory with timestamp
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
BACKUP_DIR="${OUTPUT_DIR}/${TIMESTAMP}"
|
||||
mkdir -p "$BACKUP_DIR"
|
||||
|
||||
log_info "Starting Onyx data backup..."
|
||||
log_info "Mode: $MODE"
|
||||
log_info "Output directory: $BACKUP_DIR"
|
||||
log_info "Project name: $PROJECT_NAME"
|
||||
log_info "Volume prefix: $VOLUME_PREFIX"
|
||||
|
||||
# Get container names
|
||||
POSTGRES_CONTAINER="${PROJECT_NAME}-relational_db-1"
|
||||
VESPA_CONTAINER="${PROJECT_NAME}-index-1"
|
||||
MINIO_CONTAINER="${PROJECT_NAME}-minio-1"
|
||||
|
||||
# Track which services were stopped
|
||||
STOPPED_SERVICES=()
|
||||
|
||||
# =============================================================================
|
||||
# Service management functions
|
||||
# =============================================================================
|
||||
|
||||
stop_service() {
|
||||
local service=$1
|
||||
local container="${PROJECT_NAME}-${service}-1"
|
||||
|
||||
if docker ps --format '{{.Names}}' | grep -q "^${container}$"; then
|
||||
log_info "Stopping ${service}..."
|
||||
if [[ -n "$COMPOSE_DIR" ]]; then
|
||||
docker compose -p "$PROJECT_NAME" -f "${COMPOSE_DIR}/docker-compose.yml" stop "$service" 2>/dev/null || \
|
||||
docker stop "$container"
|
||||
else
|
||||
docker stop "$container"
|
||||
fi
|
||||
STOPPED_SERVICES+=("$service")
|
||||
fi
|
||||
}
|
||||
|
||||
start_services() {
|
||||
if [[ ${#STOPPED_SERVICES[@]} -eq 0 ]]; then
|
||||
return
|
||||
fi
|
||||
|
||||
log_info "Restarting services: ${STOPPED_SERVICES[*]}"
|
||||
|
||||
if [[ -n "$COMPOSE_DIR" ]]; then
|
||||
docker compose -p "$PROJECT_NAME" -f "${COMPOSE_DIR}/docker-compose.yml" start "${STOPPED_SERVICES[@]}" 2>/dev/null || {
|
||||
# Fallback to starting containers directly
|
||||
for service in "${STOPPED_SERVICES[@]}"; do
|
||||
docker start "${PROJECT_NAME}-${service}-1" 2>/dev/null || true
|
||||
done
|
||||
}
|
||||
else
|
||||
for service in "${STOPPED_SERVICES[@]}"; do
|
||||
docker start "${PROJECT_NAME}-${service}-1" 2>/dev/null || true
|
||||
done
|
||||
fi
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# Volume-based backup functions
|
||||
# =============================================================================
|
||||
|
||||
backup_postgres_volume() {
|
||||
log_info "Backing up PostgreSQL volume..."
|
||||
|
||||
local volume_name="${VOLUME_PREFIX}_db_volume"
|
||||
|
||||
# Check if volume exists
|
||||
if ! docker volume inspect "$volume_name" &>/dev/null; then
|
||||
log_error "PostgreSQL volume '$volume_name' not found"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Export volume to tar
|
||||
docker run --rm \
|
||||
-v "${volume_name}:/source:ro" \
|
||||
-v "${BACKUP_DIR}:/backup" \
|
||||
alpine tar czf /backup/postgres_volume.tar.gz -C /source .
|
||||
|
||||
log_success "PostgreSQL volume backed up to postgres_volume.tar.gz"
|
||||
}
|
||||
|
||||
backup_vespa_volume() {
|
||||
log_info "Backing up Vespa volume..."
|
||||
|
||||
local volume_name="${VOLUME_PREFIX}_vespa_volume"
|
||||
|
||||
# Check if volume exists
|
||||
if ! docker volume inspect "$volume_name" &>/dev/null; then
|
||||
log_error "Vespa volume '$volume_name' not found"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Export volume to tar
|
||||
docker run --rm \
|
||||
-v "${volume_name}:/source:ro" \
|
||||
-v "${BACKUP_DIR}:/backup" \
|
||||
alpine tar czf /backup/vespa_volume.tar.gz -C /source .
|
||||
|
||||
log_success "Vespa volume backed up to vespa_volume.tar.gz"
|
||||
}
|
||||
|
||||
backup_minio_volume() {
|
||||
log_info "Backing up MinIO volume..."
|
||||
|
||||
local volume_name="${VOLUME_PREFIX}_minio_data"
|
||||
|
||||
# Check if volume exists
|
||||
if ! docker volume inspect "$volume_name" &>/dev/null; then
|
||||
log_error "MinIO volume '$volume_name' not found"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Export volume to tar
|
||||
docker run --rm \
|
||||
-v "${volume_name}:/source:ro" \
|
||||
-v "${BACKUP_DIR}:/backup" \
|
||||
alpine tar czf /backup/minio_volume.tar.gz -C /source .
|
||||
|
||||
log_success "MinIO volume backed up to minio_volume.tar.gz"
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# API-based backup functions
|
||||
# =============================================================================
|
||||
|
||||
backup_postgres_api() {
|
||||
log_info "Backing up PostgreSQL via pg_dump..."
|
||||
|
||||
# Check if container is running
|
||||
if ! docker ps --format '{{.Names}}' | grep -q "^${POSTGRES_CONTAINER}$"; then
|
||||
log_error "PostgreSQL container '$POSTGRES_CONTAINER' is not running"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Create dump using pg_dump inside container
|
||||
docker exec "$POSTGRES_CONTAINER" \
|
||||
pg_dump -U "$POSTGRES_USER" -F c -b -v "$POSTGRES_DB" \
|
||||
> "${BACKUP_DIR}/postgres_dump.backup"
|
||||
|
||||
log_success "PostgreSQL backed up to postgres_dump.backup"
|
||||
}
|
||||
|
||||
backup_vespa_api() {
|
||||
log_info "Backing up Vespa via API..."
|
||||
|
||||
local endpoint="http://${VESPA_HOST}:${VESPA_PORT}/document/v1/default/${VESPA_INDEX}/docid"
|
||||
local output_file="${BACKUP_DIR}/vespa_documents.jsonl"
|
||||
local continuation=""
|
||||
local total_docs=0
|
||||
|
||||
# Check if Vespa is accessible
|
||||
if ! curl -s -o /dev/null -w "%{http_code}" "$endpoint" | grep -q "200\|404"; then
|
||||
# Fall back to the container's published port (localhost:8081) if the configured host/port is unreachable
if docker ps --format '{{.Names}}' | grep -q "^${VESPA_CONTAINER}$"; then
log_warning "Vespa not accessible on $VESPA_HOST:$VESPA_PORT, falling back to localhost:8081..."
|
||||
endpoint="http://localhost:8081/document/v1/default/${VESPA_INDEX}/docid"
|
||||
else
|
||||
log_error "Cannot connect to Vespa at $endpoint"
|
||||
return 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# Clear output file
|
||||
> "$output_file"
|
||||
|
||||
# Fetch documents with pagination
|
||||
while true; do
|
||||
local url="$endpoint"
|
||||
if [[ -n "$continuation" ]]; then
|
||||
url="${endpoint}?continuation=${continuation}"
|
||||
fi
|
||||
|
||||
local response
|
||||
response=$(curl -s "$url")
|
||||
|
||||
# Extract continuation token
|
||||
continuation=$(echo "$response" | jq -r '.continuation // empty')
|
||||
|
||||
# Extract and save documents
|
||||
local docs
|
||||
docs=$(echo "$response" | jq -c '.documents[]? | {update: .id, create: true, fields: .fields}')
|
||||
|
||||
if [[ -n "$docs" ]]; then
|
||||
echo "$docs" >> "$output_file"
|
||||
local count
|
||||
count=$(echo "$docs" | wc -l)
|
||||
total_docs=$((total_docs + count))
|
||||
log_info " Fetched $total_docs documents so far..."
|
||||
fi
|
||||
|
||||
# Check if we're done
|
||||
if [[ -z "$continuation" ]]; then
|
||||
break
|
||||
fi
|
||||
done
|
||||
|
||||
log_success "Vespa backed up to vespa_documents.jsonl ($total_docs documents)"
|
||||
}
|
||||
|
||||
backup_minio_api() {
|
||||
log_info "Backing up MinIO data..."
|
||||
|
||||
local minio_dir="${BACKUP_DIR}/minio_data"
|
||||
mkdir -p "$minio_dir"
|
||||
|
||||
# Check if mc (MinIO client) is available
|
||||
if command -v mc &>/dev/null; then
|
||||
# Configure mc alias for local minio
|
||||
mc alias set onyx-backup http://localhost:9004 minioadmin minioadmin 2>/dev/null || true
|
||||
|
||||
# Mirror all buckets
|
||||
mc mirror onyx-backup/ "$minio_dir/" 2>/dev/null || {
|
||||
log_warning "mc mirror failed, falling back to volume backup"
|
||||
backup_minio_volume
|
||||
return
|
||||
}
|
||||
else
|
||||
# Fallback: copy from container
|
||||
if docker ps --format '{{.Names}}' | grep -q "^${MINIO_CONTAINER}$"; then
|
||||
docker cp "${MINIO_CONTAINER}:/data/." "$minio_dir/"
|
||||
else
|
||||
log_warning "MinIO container not running and mc not available, using volume backup"
|
||||
backup_minio_volume
|
||||
return
|
||||
fi
|
||||
fi
|
||||
|
||||
# Compress the data
|
||||
tar czf "${BACKUP_DIR}/minio_data.tar.gz" -C "$minio_dir" .
|
||||
rm -rf "$minio_dir"
|
||||
|
||||
log_success "MinIO backed up to minio_data.tar.gz"
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# Main backup logic
|
||||
# =============================================================================
|
||||
|
||||
# Save metadata
|
||||
cat > "${BACKUP_DIR}/metadata.json" << EOF
|
||||
{
|
||||
"timestamp": "$TIMESTAMP",
|
||||
"mode": "$MODE",
|
||||
"project_name": "$PROJECT_NAME",
|
||||
"volume_prefix": "$VOLUME_PREFIX",
|
||||
"postgres_db": "$POSTGRES_DB",
|
||||
"vespa_index": "$VESPA_INDEX",
|
||||
"components": {
|
||||
"postgres": $BACKUP_POSTGRES,
|
||||
"vespa": $BACKUP_VESPA,
|
||||
"minio": $BACKUP_MINIO
|
||||
}
|
||||
}
|
||||
EOF
|
||||
|
||||
# Run backups based on mode
|
||||
if [[ "$MODE" == "volume" ]]; then
|
||||
log_info "Using volume-based backup"
|
||||
|
||||
# Stop services for consistent backup
|
||||
log_info "Stopping services for consistent backup..."
|
||||
if $BACKUP_POSTGRES; then
|
||||
stop_service "relational_db"
|
||||
fi
|
||||
if $BACKUP_VESPA; then
|
||||
stop_service "index"
|
||||
fi
|
||||
if $BACKUP_MINIO; then
|
||||
stop_service "minio"
|
||||
fi
|
||||
|
||||
# Perform backups
|
||||
if $BACKUP_POSTGRES; then
|
||||
backup_postgres_volume || log_warning "PostgreSQL backup failed"
|
||||
fi
|
||||
|
||||
if $BACKUP_VESPA; then
|
||||
backup_vespa_volume || log_warning "Vespa backup failed"
|
||||
fi
|
||||
|
||||
if $BACKUP_MINIO; then
|
||||
backup_minio_volume || log_warning "MinIO backup failed"
|
||||
fi
|
||||
|
||||
# Restart services unless --no-restart was specified
|
||||
if [[ "$NO_RESTART" != true ]]; then
|
||||
start_services
|
||||
else
|
||||
log_info "Skipping service restart (--no-restart specified)"
|
||||
log_info "Stopped services: ${STOPPED_SERVICES[*]}"
|
||||
fi
|
||||
else
|
||||
log_info "Using API-based backup (services must be running)"
|
||||
|
||||
if $BACKUP_POSTGRES; then
|
||||
backup_postgres_api || log_warning "PostgreSQL backup failed"
|
||||
fi
|
||||
|
||||
if $BACKUP_VESPA; then
|
||||
backup_vespa_api || log_warning "Vespa backup failed"
|
||||
fi
|
||||
|
||||
if $BACKUP_MINIO; then
|
||||
backup_minio_api || log_warning "MinIO backup failed"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Calculate total size
|
||||
TOTAL_SIZE=$(du -sh "$BACKUP_DIR" | cut -f1)
|
||||
|
||||
log_success "==================================="
|
||||
log_success "Backup completed!"
|
||||
log_success "Location: $BACKUP_DIR"
|
||||
log_success "Total size: $TOTAL_SIZE"
|
||||
log_success "==================================="
|
||||
|
||||
# Create a symlink to latest backup
|
||||
ln -sfn "$TIMESTAMP" "${OUTPUT_DIR}/latest"
|
||||
log_info "Symlink created: ${OUTPUT_DIR}/latest -> $TIMESTAMP"
|
||||
580
backend/scripts/dump/restore_data.sh
Executable file
@@ -0,0 +1,580 @@
|
||||
#!/bin/bash
|
||||
# =============================================================================
|
||||
# Onyx Data Restore Script
|
||||
# =============================================================================
|
||||
# This script restores PostgreSQL, Vespa, and MinIO data from a backup.
|
||||
#
|
||||
# The script auto-detects the backup mode based on files present:
|
||||
# - *_volume.tar.gz files -> volume restore
|
||||
# - postgres_dump.backup / vespa_documents.jsonl -> api restore
|
||||
#
|
||||
# Usage:
|
||||
# ./restore_data.sh [OPTIONS]
|
||||
#
|
||||
# Options:
|
||||
# --input <dir> Backup directory (required, or use 'latest')
|
||||
# --project <name> Docker Compose project name (default: onyx)
|
||||
# --volume-prefix <name> Volume name prefix (default: same as project name)
|
||||
# --compose-dir <dir> Docker Compose directory (for service management)
|
||||
# --postgres-only Only restore PostgreSQL
|
||||
# --vespa-only Only restore Vespa
|
||||
# --minio-only Only restore MinIO
|
||||
# --no-minio Skip MinIO restore
|
||||
# --no-restart Don't restart services after restore (volume mode)
|
||||
# --force Skip confirmation prompts
|
||||
# --help Show this help message
|
||||
#
|
||||
# Examples:
|
||||
# ./restore_data.sh --input ./onyx_backup/latest
|
||||
# ./restore_data.sh --input ./onyx_backup/20240115_120000 --force
|
||||
# ./restore_data.sh --input ./onyx_backup/latest --postgres-only
|
||||
# ./restore_data.sh --input ./backup --volume-prefix myprefix
|
||||
#
|
||||
# WARNING: This will overwrite existing data in the target instance!
|
||||
# =============================================================================
|
||||
|
||||
set -e
|
||||
|
||||
# Default configuration
|
||||
INPUT_DIR=""
|
||||
PROJECT_NAME="onyx"
|
||||
VOLUME_PREFIX="" # Will default to PROJECT_NAME if not set
|
||||
COMPOSE_DIR="" # Docker Compose directory for service management
|
||||
RESTORE_POSTGRES=true
|
||||
RESTORE_VESPA=true
|
||||
RESTORE_MINIO=true
|
||||
FORCE=false
|
||||
NO_RESTART=false
|
||||
|
||||
# PostgreSQL defaults
|
||||
POSTGRES_USER="${POSTGRES_USER:-postgres}"
|
||||
POSTGRES_PASSWORD="${POSTGRES_PASSWORD:-password}"
|
||||
POSTGRES_DB="${POSTGRES_DB:-postgres}"
|
||||
POSTGRES_PORT="${POSTGRES_PORT:-5432}"
|
||||
|
||||
# Vespa defaults
|
||||
VESPA_HOST="${VESPA_HOST:-localhost}"
|
||||
VESPA_PORT="${VESPA_PORT:-8081}"
|
||||
VESPA_INDEX="${VESPA_INDEX:-danswer_index}"
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
log_info() {
|
||||
echo -e "${BLUE}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
log_success() {
|
||||
echo -e "${GREEN}[SUCCESS]${NC} $1"
|
||||
}
|
||||
|
||||
log_warning() {
|
||||
echo -e "${YELLOW}[WARNING]${NC} $1"
|
||||
}
|
||||
|
||||
log_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
show_help() {
|
||||
head -36 "$0" | tail -33
|
||||
exit 0
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--input)
|
||||
INPUT_DIR="$2"
|
||||
shift 2
|
||||
;;
|
||||
--project)
|
||||
PROJECT_NAME="$2"
|
||||
shift 2
|
||||
;;
|
||||
--volume-prefix)
|
||||
VOLUME_PREFIX="$2"
|
||||
shift 2
|
||||
;;
|
||||
--compose-dir)
|
||||
COMPOSE_DIR="$2"
|
||||
shift 2
|
||||
;;
|
||||
--no-restart)
|
||||
NO_RESTART=true
|
||||
shift
|
||||
;;
|
||||
--postgres-only)
|
||||
RESTORE_POSTGRES=true
|
||||
RESTORE_VESPA=false
|
||||
RESTORE_MINIO=false
|
||||
shift
|
||||
;;
|
||||
--vespa-only)
|
||||
RESTORE_POSTGRES=false
|
||||
RESTORE_VESPA=true
|
||||
RESTORE_MINIO=false
|
||||
shift
|
||||
;;
|
||||
--minio-only)
|
||||
RESTORE_POSTGRES=false
|
||||
RESTORE_VESPA=false
|
||||
RESTORE_MINIO=true
|
||||
shift
|
||||
;;
|
||||
--no-minio)
|
||||
RESTORE_MINIO=false
|
||||
shift
|
||||
;;
|
||||
--force)
|
||||
FORCE=true
|
||||
shift
|
||||
;;
|
||||
--help)
|
||||
show_help
|
||||
;;
|
||||
*)
|
||||
log_error "Unknown option: $1"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Validate input directory
|
||||
if [[ -z "$INPUT_DIR" ]]; then
|
||||
log_error "Input directory is required. Use --input <dir>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Resolve symlinks (e.g., 'latest')
|
||||
INPUT_DIR=$(cd "$INPUT_DIR" && pwd)
|
||||
|
||||
if [[ ! -d "$INPUT_DIR" ]]; then
|
||||
log_error "Input directory not found: $INPUT_DIR"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Load metadata if available
|
||||
METADATA_FILE="${INPUT_DIR}/metadata.json"
|
||||
if [[ -f "$METADATA_FILE" ]]; then
|
||||
log_info "Loading backup metadata..."
|
||||
BACKUP_MODE=$(jq -r '.mode // "unknown"' "$METADATA_FILE")
|
||||
BACKUP_TIMESTAMP=$(jq -r '.timestamp // "unknown"' "$METADATA_FILE")
|
||||
log_info " Backup timestamp: $BACKUP_TIMESTAMP"
|
||||
log_info " Backup mode: $BACKUP_MODE"
|
||||
fi
|
||||
|
||||
# Set VOLUME_PREFIX to PROJECT_NAME if not specified
|
||||
if [[ -z "$VOLUME_PREFIX" ]]; then
|
||||
VOLUME_PREFIX="$PROJECT_NAME"
|
||||
fi
|
||||
|
||||
log_info "Volume prefix: $VOLUME_PREFIX"
|
||||
|
||||
# Track which services were stopped
|
||||
STOPPED_SERVICES=()
|
||||
|
||||
# =============================================================================
|
||||
# Service management functions
|
||||
# =============================================================================
|
||||
|
||||
stop_service() {
|
||||
local service=$1
|
||||
local container="${PROJECT_NAME}-${service}-1"
|
||||
|
||||
if docker ps --format '{{.Names}}' | grep -q "^${container}$"; then
|
||||
log_info "Stopping ${service}..."
|
||||
if [[ -n "$COMPOSE_DIR" ]]; then
|
||||
docker compose -p "$PROJECT_NAME" -f "${COMPOSE_DIR}/docker-compose.yml" stop "$service" 2>/dev/null || \
|
||||
docker stop "$container"
|
||||
else
|
||||
docker stop "$container"
|
||||
fi
|
||||
STOPPED_SERVICES+=("$service")
|
||||
fi
|
||||
}
|
||||
|
||||
start_services() {
|
||||
if [[ ${#STOPPED_SERVICES[@]} -eq 0 ]]; then
|
||||
return
|
||||
fi
|
||||
|
||||
log_info "Restarting services: ${STOPPED_SERVICES[*]}"
|
||||
|
||||
if [[ -n "$COMPOSE_DIR" ]]; then
|
||||
docker compose -p "$PROJECT_NAME" -f "${COMPOSE_DIR}/docker-compose.yml" start "${STOPPED_SERVICES[@]}" 2>/dev/null || {
|
||||
# Fallback to starting containers directly
|
||||
for service in "${STOPPED_SERVICES[@]}"; do
|
||||
docker start "${PROJECT_NAME}-${service}-1" 2>/dev/null || true
|
||||
done
|
||||
}
|
||||
else
|
||||
for service in "${STOPPED_SERVICES[@]}"; do
|
||||
docker start "${PROJECT_NAME}-${service}-1" 2>/dev/null || true
|
||||
done
|
||||
fi
|
||||
}
|
||||
|
||||
# Auto-detect backup mode based on files present
|
||||
detect_backup_mode() {
|
||||
if [[ -f "${INPUT_DIR}/postgres_volume.tar.gz" ]] || [[ -f "${INPUT_DIR}/vespa_volume.tar.gz" ]]; then
|
||||
echo "volume"
|
||||
elif [[ -f "${INPUT_DIR}/postgres_dump.backup" ]] || [[ -f "${INPUT_DIR}/vespa_documents.jsonl" ]]; then
|
||||
echo "api"
|
||||
else
|
||||
echo "unknown"
|
||||
fi
|
||||
}
|
||||
|
||||
DETECTED_MODE=$(detect_backup_mode)
|
||||
log_info "Detected backup mode: $DETECTED_MODE"
|
||||
|
||||
# Get container names
|
||||
POSTGRES_CONTAINER="${PROJECT_NAME}-relational_db-1"
|
||||
VESPA_CONTAINER="${PROJECT_NAME}-index-1"
|
||||
MINIO_CONTAINER="${PROJECT_NAME}-minio-1"
|
||||
|
||||
# Confirmation prompt
|
||||
if [[ "$FORCE" != true ]]; then
|
||||
echo ""
|
||||
log_warning "==================================="
|
||||
log_warning "WARNING: This will overwrite existing data!"
|
||||
log_warning "==================================="
|
||||
echo ""
|
||||
echo "Restore configuration:"
|
||||
echo " Input directory: $INPUT_DIR"
|
||||
echo " Project name: $PROJECT_NAME"
|
||||
echo " Restore PostgreSQL: $RESTORE_POSTGRES"
|
||||
echo " Restore Vespa: $RESTORE_VESPA"
|
||||
echo " Restore MinIO: $RESTORE_MINIO"
|
||||
echo ""
|
||||
read -p "Are you sure you want to continue? (yes/no): " confirm
|
||||
if [[ "$confirm" != "yes" ]]; then
|
||||
log_info "Restore cancelled."
|
||||
exit 0
|
||||
fi
|
||||
fi
|
||||
|
||||
# =============================================================================
|
||||
# Volume-based restore functions
|
||||
# =============================================================================
|
||||
|
||||
restore_postgres_volume() {
|
||||
log_info "Restoring PostgreSQL from volume backup..."
|
||||
|
||||
local volume_name="${VOLUME_PREFIX}_db_volume"
|
||||
local backup_file="${INPUT_DIR}/postgres_volume.tar.gz"
|
||||
|
||||
if [[ ! -f "$backup_file" ]]; then
|
||||
log_error "PostgreSQL volume backup not found: $backup_file"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Remove existing volume and create new one
|
||||
log_info "Recreating PostgreSQL volume..."
|
||||
docker volume rm "$volume_name" 2>/dev/null || true
|
||||
docker volume create "$volume_name"
|
||||
|
||||
# Restore volume from tar
|
||||
docker run --rm \
|
||||
-v "${volume_name}:/target" \
|
||||
-v "${INPUT_DIR}:/backup:ro" \
|
||||
alpine sh -c "cd /target && tar xzf /backup/postgres_volume.tar.gz"
|
||||
|
||||
log_success "PostgreSQL volume restored"
|
||||
}
|
||||
|
||||
restore_vespa_volume() {
|
||||
log_info "Restoring Vespa from volume backup..."
|
||||
|
||||
local volume_name="${VOLUME_PREFIX}_vespa_volume"
|
||||
local backup_file="${INPUT_DIR}/vespa_volume.tar.gz"
|
||||
|
||||
if [[ ! -f "$backup_file" ]]; then
|
||||
log_error "Vespa volume backup not found: $backup_file"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Remove existing volume and create new one
|
||||
log_info "Recreating Vespa volume..."
|
||||
docker volume rm "$volume_name" 2>/dev/null || true
|
||||
docker volume create "$volume_name"
|
||||
|
||||
# Restore volume from tar
|
||||
docker run --rm \
|
||||
-v "${volume_name}:/target" \
|
||||
-v "${INPUT_DIR}:/backup:ro" \
|
||||
alpine sh -c "cd /target && tar xzf /backup/vespa_volume.tar.gz"
|
||||
|
||||
log_success "Vespa volume restored"
|
||||
}
|
||||
|
||||
restore_minio_volume() {
|
||||
log_info "Restoring MinIO from volume backup..."
|
||||
|
||||
local volume_name="${VOLUME_PREFIX}_minio_data"
|
||||
local backup_file="${INPUT_DIR}/minio_volume.tar.gz"
|
||||
|
||||
if [[ ! -f "$backup_file" ]]; then
|
||||
log_error "MinIO volume backup not found: $backup_file"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Remove existing volume and create new one
|
||||
log_info "Recreating MinIO volume..."
|
||||
docker volume rm "$volume_name" 2>/dev/null || true
|
||||
docker volume create "$volume_name"
|
||||
|
||||
# Restore volume from tar
|
||||
docker run --rm \
|
||||
-v "${volume_name}:/target" \
|
||||
-v "${INPUT_DIR}:/backup:ro" \
|
||||
alpine sh -c "cd /target && tar xzf /backup/minio_volume.tar.gz"
|
||||
|
||||
log_success "MinIO volume restored"
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# API-based restore functions
|
||||
# =============================================================================
|
||||
|
||||
restore_postgres_api() {
|
||||
log_info "Restoring PostgreSQL from pg_dump backup..."
|
||||
|
||||
local backup_file="${INPUT_DIR}/postgres_dump.backup"
|
||||
|
||||
if [[ ! -f "$backup_file" ]]; then
|
||||
log_error "PostgreSQL dump not found: $backup_file"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Check if container is running
|
||||
if ! docker ps --format '{{.Names}}' | grep -q "^${POSTGRES_CONTAINER}$"; then
|
||||
log_error "PostgreSQL container '$POSTGRES_CONTAINER' is not running"
|
||||
log_info "Please start the containers first: docker compose up -d relational_db"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Copy backup file to container
|
||||
log_info "Copying backup file to container..."
|
||||
docker cp "$backup_file" "${POSTGRES_CONTAINER}:/tmp/postgres_dump.backup"
|
||||
|
||||
# Drop and recreate database (optional, pg_restore --clean should handle this)
|
||||
log_info "Restoring database..."
|
||||
|
||||
# Use pg_restore with --clean to drop objects before recreating
|
||||
docker exec "$POSTGRES_CONTAINER" \
|
||||
pg_restore -U "$POSTGRES_USER" -d "$POSTGRES_DB" \
|
||||
--clean --if-exists --no-owner --no-privileges \
|
||||
/tmp/postgres_dump.backup 2>&1 || {
|
||||
# pg_restore may return non-zero even on success due to warnings
|
||||
log_warning "pg_restore completed with warnings (this is often normal)"
|
||||
}
|
||||
|
||||
# Cleanup
|
||||
docker exec "$POSTGRES_CONTAINER" rm -f /tmp/postgres_dump.backup
|
||||
|
||||
log_success "PostgreSQL restored"
|
||||
}
|
||||
|
||||
restore_vespa_api() {
|
||||
log_info "Restoring Vespa from JSONL backup..."
|
||||
|
||||
local backup_file="${INPUT_DIR}/vespa_documents.jsonl"
|
||||
|
||||
if [[ ! -f "$backup_file" ]]; then
|
||||
log_error "Vespa backup not found: $backup_file"
|
||||
return 1
|
||||
fi
|
||||
|
||||
local endpoint="http://${VESPA_HOST}:${VESPA_PORT}/document/v1/default/${VESPA_INDEX}/docid"
|
||||
local total_docs=0
|
||||
local failed_docs=0
|
||||
|
||||
# Check if Vespa is accessible
|
||||
if ! curl -s -o /dev/null -w "%{http_code}" "http://${VESPA_HOST}:${VESPA_PORT}/state/v1/health" | grep -q "200"; then
|
||||
log_error "Cannot connect to Vespa at ${VESPA_HOST}:${VESPA_PORT}"
|
||||
log_info "Please ensure Vespa is running and accessible"
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Wait for Vespa to be fully ready
|
||||
log_info "Waiting for Vespa to be fully ready..."
|
||||
local max_wait=60
|
||||
local waited=0
|
||||
while ! curl -s "http://${VESPA_HOST}:${VESPA_PORT}/state/v1/health" | grep -q '"status":{"code":"up"}'; do
|
||||
if [[ $waited -ge $max_wait ]]; then
|
||||
log_error "Vespa did not become ready within ${max_wait} seconds"
|
||||
return 1
|
||||
fi
|
||||
sleep 2
|
||||
waited=$((waited + 2))
|
||||
done
|
||||
|
||||
# Restore documents
|
||||
log_info "Restoring documents..."
|
||||
while IFS= read -r line; do
|
||||
if [[ -z "$line" ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
# Extract document ID
|
||||
local doc_id
|
||||
doc_id=$(echo "$line" | jq -r '.update' | sed 's/.*:://')
|
||||
|
||||
# Post document
|
||||
local response
|
||||
response=$(curl -s -w "\n%{http_code}" -X POST \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "$line" \
|
||||
"${endpoint}/${doc_id}")
|
||||
|
||||
local http_code
|
||||
http_code=$(echo "$response" | tail -1)
|
||||
|
||||
total_docs=$((total_docs + 1))
|
||||
|
||||
if [[ "$http_code" != "200" ]]; then
|
||||
failed_docs=$((failed_docs + 1))
|
||||
if [[ $failed_docs -le 5 ]]; then
|
||||
log_warning "Failed to restore document $doc_id (HTTP $http_code)"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Progress update
|
||||
if [[ $((total_docs % 100)) -eq 0 ]]; then
|
||||
log_info " Restored $total_docs documents..."
|
||||
fi
|
||||
done < "$backup_file"
|
||||
|
||||
if [[ $failed_docs -gt 0 ]]; then
|
||||
log_warning "Vespa restored with $failed_docs failures out of $total_docs documents"
|
||||
else
|
||||
log_success "Vespa restored ($total_docs documents)"
|
||||
fi
|
||||
}
|
||||
|
||||
restore_minio_api() {
|
||||
log_info "Restoring MinIO data..."
|
||||
|
||||
local backup_file="${INPUT_DIR}/minio_data.tar.gz"
|
||||
|
||||
if [[ ! -f "$backup_file" ]]; then
|
||||
log_warning "MinIO backup not found: $backup_file"
|
||||
# Try volume backup as fallback
|
||||
if [[ -f "${INPUT_DIR}/minio_volume.tar.gz" ]]; then
|
||||
log_info "Found volume backup, using that instead"
|
||||
restore_minio_volume
|
||||
return
|
||||
fi
|
||||
return 1
|
||||
fi
|
||||
|
||||
# Extract to temp directory
|
||||
local temp_dir
|
||||
temp_dir=$(mktemp -d)
|
||||
tar xzf "$backup_file" -C "$temp_dir"
|
||||
|
||||
# Check if mc (MinIO client) is available
|
||||
if command -v mc &>/dev/null; then
|
||||
# Configure mc alias for local minio
|
||||
mc alias set onyx-restore http://localhost:9004 minioadmin minioadmin 2>/dev/null || true
|
||||
|
||||
# Mirror data to minio
|
||||
mc mirror "$temp_dir/" onyx-restore/ 2>/dev/null || {
|
||||
log_warning "mc mirror failed"
|
||||
}
|
||||
else
|
||||
# Fallback: copy to container
|
||||
if docker ps --format '{{.Names}}' | grep -q "^${MINIO_CONTAINER}$"; then
|
||||
docker cp "$temp_dir/." "${MINIO_CONTAINER}:/data/"
|
||||
else
|
||||
log_error "MinIO container not running and mc not available"
|
||||
rm -rf "$temp_dir"
|
||||
return 1
|
||||
fi
|
||||
fi
|
||||
|
||||
rm -rf "$temp_dir"
|
||||
log_success "MinIO restored"
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# Main restore logic
|
||||
# =============================================================================
|
||||
|
||||
log_info "Starting Onyx data restore..."
|
||||
log_info "Input directory: $INPUT_DIR"
|
||||
log_info "Project name: $PROJECT_NAME"
|
||||
|
||||
# Run restores based on detected mode
|
||||
if [[ "$DETECTED_MODE" == "volume" ]]; then
|
||||
log_info "Using volume-based restore"
|
||||
|
||||
# Stop services before restore
|
||||
log_info "Stopping services for restore..."
|
||||
if $RESTORE_POSTGRES; then
|
||||
stop_service "relational_db"
|
||||
fi
|
||||
if $RESTORE_VESPA; then
|
||||
stop_service "index"
|
||||
fi
|
||||
if $RESTORE_MINIO; then
|
||||
stop_service "minio"
|
||||
fi
|
||||
|
||||
# Perform restores
|
||||
if $RESTORE_POSTGRES; then
|
||||
restore_postgres_volume || log_warning "PostgreSQL restore failed"
|
||||
fi
|
||||
|
||||
if $RESTORE_VESPA; then
|
||||
restore_vespa_volume || log_warning "Vespa restore failed"
|
||||
fi
|
||||
|
||||
if $RESTORE_MINIO; then
|
||||
restore_minio_volume || log_warning "MinIO restore failed"
|
||||
fi
|
||||
|
||||
# Restart services unless --no-restart was specified
|
||||
if [[ "$NO_RESTART" != true ]]; then
|
||||
start_services
|
||||
else
|
||||
log_info "Skipping service restart (--no-restart specified)"
|
||||
log_info "Stopped services: ${STOPPED_SERVICES[*]}"
|
||||
fi
|
||||
|
||||
elif [[ "$DETECTED_MODE" == "api" ]]; then
|
||||
log_info "Using API-based restore"
|
||||
log_info "Services must be running for API restore"
|
||||
|
||||
if $RESTORE_POSTGRES; then
|
||||
restore_postgres_api || log_warning "PostgreSQL restore failed"
|
||||
fi
|
||||
|
||||
if $RESTORE_VESPA; then
|
||||
restore_vespa_api || log_warning "Vespa restore failed"
|
||||
fi
|
||||
|
||||
if $RESTORE_MINIO; then
|
||||
restore_minio_api || log_warning "MinIO restore failed"
|
||||
fi
|
||||
|
||||
else
|
||||
log_error "Could not detect backup mode. Ensure backup files exist in $INPUT_DIR"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
log_success "==================================="
|
||||
log_success "Restore completed!"
|
||||
log_success "==================================="
|
||||
|
||||
# Post-restore recommendations
|
||||
echo ""
|
||||
log_info "Post-restore steps:"
|
||||
log_info " 1. Run database migrations if needed: docker compose -p $PROJECT_NAME exec api_server alembic upgrade head"
|
||||
log_info " 2. Verify data integrity in the application"
|
||||
@@ -55,7 +55,7 @@ else
|
||||
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin minio/minio server /data --console-address ":9001"
|
||||
fi
|
||||
|
||||
# Ensure alembic runs in the correct directory
|
||||
# Ensure alembic runs in the correct directory (backend/)
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
PARENT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||
cd "$PARENT_DIR"
|
||||
@@ -63,6 +63,13 @@ cd "$PARENT_DIR"
|
||||
# Give Postgres a second to start
|
||||
sleep 1
|
||||
|
||||
# Alembic should be configured in the virtualenv for this repo
|
||||
if [[ -f "../.venv/bin/activate" ]]; then
|
||||
source ../.venv/bin/activate
|
||||
else
|
||||
echo "Warning: Python virtual environment not found at .venv/bin/activate; alembic may not work."
|
||||
fi
|
||||
|
||||
# Run Alembic upgrade
|
||||
echo "Running Alembic migration..."
|
||||
alembic upgrade head
|
||||
|
||||
80
backend/scripts/tenant_cleanup/QUICK_START_NO_BASTION.md
Normal file
@@ -0,0 +1,80 @@
|
||||
# Quick Start: Tenant Cleanup Without Bastion
|
||||
|
||||
## TL;DR - The Commands You Need
|
||||
|
||||
```bash
|
||||
# Navigate to backend directory
|
||||
cd onyx/backend
|
||||
|
||||
# Step 1: Generate CSV of tenants to clean (5-10 min)
|
||||
PYTHONPATH=. python scripts/tenant_cleanup/no_bastion_analyze_tenants.py
|
||||
|
||||
# Step 2: Mark connectors for deletion (1-2 min)
|
||||
PYTHONPATH=. python scripts/tenant_cleanup/no_bastion_mark_connectors.py \
|
||||
--csv gated_tenants_no_query_3mo_*.csv \
|
||||
--force \
|
||||
--concurrency 16
|
||||
|
||||
# ⏰ WAIT 6+ hours for background deletion to complete
|
||||
|
||||
# Step 3: Final cleanup (1-2 min)
|
||||
PYTHONPATH=. python scripts/tenant_cleanup/no_bastion_cleanup_tenants.py \
|
||||
--csv gated_tenants_no_query_3mo_*.csv \
|
||||
--force
|
||||
```
|
||||
|
||||
## What Changed?
|
||||
|
||||
Instead of the original scripts that require bastion access:
|
||||
- `analyze_current_tenants.py` → `no_bastion_analyze_tenants.py`
|
||||
- `mark_connectors_for_deletion.py` → `no_bastion_mark_connectors.py`
|
||||
- `cleanup_tenants.py` → `no_bastion_cleanup_tenants.py`
|
||||
|
||||
**No environment variables needed!** All queries run directly from pods.
|
||||
|
||||
## What You Need
|
||||
|
||||
✅ `kubectl` access to your cluster
|
||||
✅ Running `celery-worker-user-file-processing` pods
|
||||
✅ Permission to exec into pods
|
||||
|
||||
❌ No bastion host required
|
||||
❌ No SSH keys required
|
||||
❌ No environment variables required
|
||||
|
||||
## Test Your Setup
|
||||
|
||||
```bash
|
||||
# Check if you can find worker pods
|
||||
kubectl get po | grep celery-worker-user-file-processing | grep Running
|
||||
|
||||
# If you see pods, you're ready to go!
|
||||
```
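Since every script execs into a worker, it is worth confirming exec permission up front; a minimal sketch that reuses the pod filter above:

```bash
# Grab one running worker pod and make sure `kubectl exec` is allowed
POD=$(kubectl get po | grep celery-worker-user-file-processing | grep Running | head -1 | awk '{print $1}')
kubectl exec "$POD" -- true && echo "exec into $POD works"
```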
|
||||
|
||||
## Important Notes
|
||||
|
||||
1. **Step 2 triggers background deletion** - the actual document deletion happens asynchronously via Celery workers
|
||||
2. **You MUST wait** between Step 2 and Step 3 for deletion to complete (can take 6+ hours)
|
||||
3. **Monitor deletion progress** with: `kubectl logs -f <celery-worker-pod>` (see the example after this list)
|
||||
4. **All scripts verify tenant status** - they'll refuse to process active (non-GATED_ACCESS) tenants
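For example, to follow deletion progress on one of the workers between Step 2 and Step 3 (pod selection reuses the same filter as the setup check):

```bash
POD=$(kubectl get po | grep celery-worker-user-file-processing | grep Running | head -1 | awk '{print $1}')
kubectl logs -f "$POD"
```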
|
||||
|
||||
## Files Generated
|
||||
|
||||
- `gated_tenants_no_query_3mo_YYYYMMDD_HHMMSS.csv` - List of tenants to clean
|
||||
- `cleaned_tenants.csv` - Successfully cleaned tenants with timestamps
|
||||
|
||||
## Safety First
|
||||
|
||||
The scripts include multiple safety checks:
|
||||
- ✅ Verifies tenant status before any operation
|
||||
- ✅ Checks documents are deleted before dropping schemas
|
||||
- ✅ Prompts for confirmation on dangerous operations (unless `--force`)
|
||||
- ✅ Records all successful operations in real-time
|
||||
|
||||
## Need More Details?
|
||||
|
||||
See [NO_BASTION_README.md](./NO_BASTION_README.md) for:
|
||||
- Detailed explanations of each step
|
||||
- Troubleshooting guide
|
||||
- How it works under the hood
|
||||
- Performance characteristics
|
||||
Some files were not shown because too many files have changed in this diff.