Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-18 00:05:47 +00:00)
Compare commits: interfaces...dump-scrip (140 commits)
Commits in this range (SHA1 only):
ca3db17b08, ffd13b1104, 1caa860f8e, 7181cc41af, 959b8c320d, 96fd0432ff, 4c73a03f57, e57713e376, 21ea320323, bac9c48e53, 7f79e34aa4, f1a81d45a1, 285755a540, 89003ad2d8, 9f93f97259, f702eebbe7, 8487e1856b, a36445f840, 7f30293b0e, 619d9528b4, 6f83c669e7, c3e5f48cb4, fdf8fe391c, f1d6bb9e02, 9a64a717dc, aa0f475e01, 75238dc353, 9e19803244, 5cabd32638, 4ccd88c331, 5a80b98320, ff109d9f5c, 4cc276aca9, 29f0df2c93, e2edcf0e0b, 9396fc547d, c089903aad, 95471f64e9, 13c1619d01, ddb5068847, 81a4f654c2, 9393c56a21, 1ee96ff99c, 6bb00d2c6b, d9cc923c6a, bfbba0f036, ccf6911f97, 15c9c2ba8e, 8b3fedf480, b8dc0749ee, d6426458c6, 941c4d6a54, 653b65da66, 503e70be02, 9c19493160, 933315646b, d2061f8a26, 6a98f0bf3c, 2f4d39d834, 40f8bcc6f8, af9ed73f00, bf28041f4e, 395d5927b7, c96f24e37c, 070519f823, a7dc1c0f3b, a947e44926, a6575b6254, 31733a9c7c, 5415e2faf1, 749f720dfd, eac79cfdf2, e3b1202731, 6df13cc2de, 682f660aa3, c4670ea86c, a6757eb49f, cd372fb585, 45fa0d9b32, 45091f2ee2, 43a3cb89b9, 9428eaed8d, dd29d989ff, f44daa2116, 212cbcb683, aaad573c3f, e1325e84ae, e759cdd4ab, 2ed6607e10, ba5b9cf395, bab23f62b8, d72e2e4081, 4ed2d08336, 24a0ceee18, d8fba38780, 5f358a1e20, 00b0c23e13, 2103ed9e81, 2c5ab72312, 672d1ca8fa, a418de4287, 349aba6c02, 18a7bdc292, c658fd4c7d, f1e87dda5b, b93edb3e89, dc4e76bd64, c4242ad17a, a4dee62660, 2d2c76ec7b, d80025138d, 90ec595936, f30e88a61b, 9c04e9269f, 8c65fcd193, f42e3eb823, 9b76ed085c, 0eb4d039ae, 3c0b66a174, 895a8e774e, c14ea4dbb9, 80b1e07586, 59b243d585, d4ae3d1cb5, ed0a86c681, e825e5732f, a93854ae70, fc8767a04f, 6c231e7ad1, bac751d4a9, 3e0f386d5b, edb6957268, 0348d11fb2, fe514eada0, e7672b89bb, c1494660e1, 7ee3df6b92, 54afed0d23, 1c776fcc73, 340ddce294
.github/workflows/check-lazy-imports.yml (vendored): 33 lines removed (the whole workflow is deleted)
@@ -1,33 +0,0 @@
-name: Check Lazy Imports
-
-concurrency:
-  group: Check-Lazy-Imports-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
-  cancel-in-progress: true
-
-on:
-  merge_group:
-  pull_request:
-    branches:
-      - main
-      - 'release/**'
-
-permissions:
-  contents: read
-
-jobs:
-  check-lazy-imports:
-    runs-on: ubuntu-latest
-    timeout-minutes: 45
-
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
-        with:
-          persist-credentials: false
-
-      - name: Set up Python
-        uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # ratchet:actions/setup-python@v6
-        with:
-          python-version: '3.11'
-
-      - name: Check lazy imports
-        run: python3 backend/scripts/check_lazy_imports.py
.github/workflows/deployment.yml (vendored): 25 lines changed
Twelve hunks (at old lines 89, 111, 140, 198, 306, 372, 485, 542, 650, 714, 907, 997) each update the pinned Checkout step:
-        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
The surrounding context is unchanged: `with: persist-credentials: false` (plus `fetch-depth: 0`, the `Setup uv` step pinned to astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7.1.4, and the runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 step in the later hunks).
.github/workflows/helm-chart-releases.yml (vendored): 2 lines changed
@@ -15,7 +15,7 @@ jobs:
     timeout-minutes: 45
     steps:
       - name: Checkout
-        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
         with:
           fetch-depth: 0
           persist-credentials: false
.github/workflows/nightly-scan-licenses.yml (vendored): 2 lines changed
@@ -28,7 +28,7 @@ jobs:
     steps:
       - name: Checkout code
-        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

[Further hunks, file header not shown in the rendered compare:]
@@ -52,7 +52,7 @@ jobs:
       test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
     steps:
       - name: Checkout code
-        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false
@@ -80,12 +80,13 @@ jobs:
     env:
       PYTHONPATH: ./backend
       MODEL_SERVER_HOST: "disabled"
+      DISABLE_TELEMETRY: "true"

     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

       - name: Checkout code
-        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false
@@ -113,6 +114,7 @@ jobs:
         run: |
           cat <<EOF > deployment/docker_compose/.env
           CODE_INTERPRETER_BETA_ENABLED=true
+          DISABLE_TELEMETRY=true
           EOF

       - name: Set up Standard Dependencies
.github/workflows/pr-helm-chart-testing.yml (vendored): 2 lines changed
@@ -24,7 +24,7 @@ jobs:
     # fetch-depth 0 is required for helm/chart-testing-action
     steps:
       - name: Checkout code
-        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
         with:
           fetch-depth: 0
           persist-credentials: false
.github/workflows/pr-integration-tests.yml (vendored): 41 lines changed
Six hunks (old lines 43, 74, 129, 183, 259, 436) update the pinned Checkout step from actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 to actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 (ratchet:actions/checkout@v6). The remaining hunk moves the inline environment variables of the "Start Docker containers" step into a generated .env file:
@@ -274,23 +274,28 @@

       # NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
       # NOTE: don't need web server for integration tests
-      - name: Start Docker containers
+      - name: Create .env file for Docker Compose
         env:
           ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
           RUN_ID: ${{ github.run_id }}
         run: |
+          cat <<EOF > deployment/docker_compose/.env
+          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
+          AUTH_TYPE=basic
+          POSTGRES_POOL_PRE_PING=true
+          POSTGRES_USE_NULL_POOL=true
+          REQUIRE_EMAIL_VERIFICATION=false
+          DISABLE_TELEMETRY=true
+          ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
+          ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
+          INTEGRATION_TESTS_MODE=true
+          CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
+          MCP_SERVER_ENABLED=true
+          EOF
+
+      - name: Start Docker containers
+        run: |
           cd deployment/docker_compose
-          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
-          AUTH_TYPE=basic \
-          POSTGRES_POOL_PRE_PING=true \
-          POSTGRES_USE_NULL_POOL=true \
-          REQUIRE_EMAIL_VERIFICATION=false \
-          DISABLE_TELEMETRY=true \
-          ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
-          ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
-          INTEGRATION_TESTS_MODE=true \
-          CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
-          MCP_SERVER_ENABLED=true \
           docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
             relational_db \
             index \
.github/workflows/pr-jest-tests.yml (vendored): 4 lines changed
@@ -16,12 +16,12 @@ jobs:
     timeout-minutes: 45
     steps:
       - name: Checkout code
-        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false

       - name: Setup node
-        uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
+        uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
         with:
           node-version: 22
           cache: "npm"
.github/workflows/pr-mit-integration-tests.yml (vendored): 35 lines changed
Five hunks (old lines 40, 70, 124, 177, 253) bump the pinned actions/checkout SHA from 1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 to 8e8c483db84b4bee98b60c0593521ed34d9990e8 (ratchet:actions/checkout@v6). The hunk at @@ -268,21 +268,26 @@ applies the same restructuring as pr-integration-tests.yml above: a new "Create .env file for Docker Compose" step writes deployment/docker_compose/.env with AUTH_TYPE=basic, POSTGRES_POOL_PRE_PING=true, POSTGRES_USE_NULL_POOL=true, REQUIRE_EMAIL_VERIFICATION=false, DISABLE_TELEMETRY=true, ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}, ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}, INTEGRATION_TESTS_MODE=true, MCP_SERVER_ENABLED=true, and the same variables are removed from the inline environment of the "Start Docker containers" step, which still runs `docker compose -f docker-compose.yml -f docker-compose.dev.yml up relational_db index ...`.
.github/workflows/pr-playwright-tests.yml (vendored): 14 lines changed
Hunks at old lines 53, 108, 163, and 229 bump the pinned actions/checkout SHA from 1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 to 8e8c483db84b4bee98b60c0593521ed34d9990e8 (ratchet:actions/checkout@v6); the hunk at line 229 also bumps actions/setup-node from 2028fbc5c25fe9cf00d9f06a71cc4710d4507903 to 395ad3262231945c25e8478fd5baf05154b1d79f (ratchet:actions/setup-node@v4). The hunk at line 465 applies the same two pin bumps inside a commented-out job.
.github/workflows/pr-python-checks.yml (vendored): 55 lines changed
@@ -17,24 +17,6 @@ permissions:
   contents: read

 jobs:
-  validate-requirements:
-    runs-on: ubuntu-slim
-    timeout-minutes: 45
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
-        with:
-          persist-credentials: false
-
-      - name: Setup uv
-        uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7.1.4
-        # TODO: Enable caching once there is a uv.lock file checked in.
-        # with:
-        #   enable-cache: true
-
-      - name: Validate requirements lock files
-        run: ./backend/scripts/compile_requirements.py --check
-
   mypy-check:
     # See https://runs-on.com/runners/linux/
     # Note: Mypy seems quite optimized for x64 compared to arm64.
@@ -45,7 +27,7 @@ jobs:
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
       - name: Checkout code
-        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false
@@ -58,35 +40,10 @@ jobs:
             backend/requirements/model_server.txt
             backend/requirements/ee.txt

-      - name: Generate OpenAPI schema
-        shell: bash
-        working-directory: backend
-        env:
-          PYTHONPATH: "."
-        run: |
-          python scripts/onyx_openapi_schema.py --filename generated/openapi.json
-
-      # needed for pulling openapitools/openapi-generator-cli
-      # otherwise, we hit the "Unauthenticated users" limit
-      # https://docs.docker.com/docker-hub/usage/
-      - name: Login to Docker Hub
-        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
-        with:
-          username: ${{ secrets.DOCKER_USERNAME }}
-          password: ${{ secrets.DOCKER_TOKEN }}
-
-      - name: Generate OpenAPI Python client
+      - name: Generate OpenAPI schema and Python client
         shell: bash
         run: |
-          docker run --rm \
-            -v "${{ github.workspace }}/backend/generated:/local" \
-            openapitools/openapi-generator-cli generate \
-            -i /local/openapi.json \
-            -g python \
-            -o /local/onyx_openapi_client \
-            --package-name onyx_openapi_client \
-            --skip-validate-spec \
-            --openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
+          ods openapi all

       - name: Cache mypy cache
         if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
@@ -103,3 +60,9 @@ jobs:
           MYPY_FORCE_COLOR: 1
           TERM: xterm-256color
         run: mypy .
+
+      - name: Run MyPy (tools/)
+        env:
+          MYPY_FORCE_COLOR: 1
+          TERM: xterm-256color
+        run: mypy tools/
.github/workflows/pr-python-connector-tests.yml (vendored): 12 lines changed
@@ -133,12 +133,13 @@ jobs:
     env:
       PYTHONPATH: ./backend
+      DISABLE_TELEMETRY: "true"

     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

       - name: Checkout code
-        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false
@@ -160,16 +161,20 @@ jobs:
             hubspot:
               - 'backend/onyx/connectors/hubspot/**'
               - 'backend/tests/daily/connectors/hubspot/**'
+              - 'uv.lock'
             salesforce:
               - 'backend/onyx/connectors/salesforce/**'
               - 'backend/tests/daily/connectors/salesforce/**'
+              - 'uv.lock'
             github:
               - 'backend/onyx/connectors/github/**'
               - 'backend/tests/daily/connectors/github/**'
+              - 'uv.lock'
             file_processing:
               - 'backend/onyx/file_processing/**'
+              - 'uv.lock'

-      - name: Run Tests (excluding HubSpot, Salesforce, and GitHub)
+      - name: Run Tests (excluding HubSpot, Salesforce, GitHub, and Coda)
         shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
         run: |
           py.test \
@@ -182,7 +187,8 @@ jobs:
             backend/tests/daily/connectors \
             --ignore backend/tests/daily/connectors/hubspot \
             --ignore backend/tests/daily/connectors/salesforce \
-            --ignore backend/tests/daily/connectors/github
+            --ignore backend/tests/daily/connectors/github \
+            --ignore backend/tests/daily/connectors/coda

       - name: Run HubSpot Connector Tests
         if: ${{ github.event_name == 'schedule' || steps.changes.outputs.hubspot == 'true' || steps.changes.outputs.file_processing == 'true' }}
.github/workflows/pr-python-model-tests.yml (vendored): 2 lines changed
@@ -39,7 +39,7 @@ jobs:
     steps:
       - name: Checkout code
-        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
         with:
           persist-credentials: false
.github/workflows/pr-python-tests.yml (vendored): 6 lines changed
One hunk (@@ -26,15 +26,13 @@ jobs:) reworks the job env block, whose entries in the rendered diff are PYTHONPATH: ./backend, REDIS_CLOUD_PYTEST_PASSWORD, SF_USERNAME, SF_PASSWORD, SF_SECURITY_TOKEN, and DISABLE_TELEMETRY: "true" (the block shrinks from 15 to 13 lines), and bumps the pinned actions/checkout SHA from 1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 to 8e8c483db84b4bee98b60c0593521ed34d9990e8 (ratchet:actions/checkout@v6) in the Checkout step that follows runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e.
.github/workflows/pr-quality-checks.yml (vendored): 23 lines changed
@@ -7,6 +7,8 @@ on:
   merge_group:
   pull_request: null
   push:
     branches:
       - main
+    tags:
+      - "v*.*.*"

@@ -15,17 +17,10 @@ permissions:

 jobs:
   quality-checks:
-    # See https://runs-on.com/runners/linux/
-    runs-on:
-      [
-        runs-on,
-        runner=1cpu-linux-arm64,
-        "run-id=${{ github.run_id }}-quality-checks",
-      ]
+    runs-on: ubuntu-latest
     timeout-minutes: 45
     steps:
-      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
-      - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+      - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
        with:
          fetch-depth: 0
          persist-credentials: false
@@ -35,7 +30,7 @@ jobs:
       - name: Setup Terraform
         uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3
       - name: Setup node
-        uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
+        uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v6
         with: # zizmor: ignore[cache-poisoning]
           node-version: 22
           cache: "npm"
@@ -43,12 +38,10 @@ jobs:
       - name: Install node dependencies
         working-directory: ./web
         run: npm ci
-      - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet:pre-commit/action@v3.0.1
-        env:
-          # uv-run is mypy's id and mypy is covered by the Python Checks which caches dependencies better.
-          SKIP: uv-run
+      - uses: j178/prek-action@91fd7d7cf70ae1dee9f4f44e7dfa5d1073fe6623 # ratchet:j178/prek-action@v1
         with:
-          extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
+          prek-version: '0.2.21'
+          extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
       - name: Check Actions
         uses: giner/check-actions@28d366c7cbbe235f9624a88aa31a628167eee28c # ratchet:giner/check-actions@v1.0.1
         with:
.github/workflows/release-devtools.yml (vendored): 2 lines changed
@@ -24,7 +24,7 @@ jobs:
           - {goos: "darwin", goarch: "arm64"}
           - {goos: "", goarch: ""}
     steps:
-      - uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+      - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
          fetch-depth: 0
.github/workflows/sync_foss.yml (vendored): 2 lines changed
@@ -14,7 +14,7 @@ jobs:
       contents: read
     steps:
       - name: Checkout main Onyx repo
-        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
         with:
           fetch-depth: 0
           persist-credentials: false
.github/workflows/tag-nightly.yml (vendored): 2 lines changed
@@ -18,7 +18,7 @@ jobs:
       # see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
       # implement here which needs an actual user's deploy key
       - name: Checkout code
-        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
         with:
           ssh-key: "${{ secrets.DEPLOY_KEY }}"
           persist-credentials: true
.github/workflows/zizmor.yml (vendored): 2 lines changed
@@ -17,7 +17,7 @@ jobs:
       security-events: write # needed for SARIF uploads
     steps:
       - name: Checkout repository
-        uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6.0.0
+        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
         with:
           persist-credentials: false
.gitignore (vendored): 3 lines added
@@ -53,3 +53,6 @@ node_modules

 # MCP configs
 .playwright-mcp
+
+# plans
+plans/
(pre-commit configuration)
@@ -5,13 +5,37 @@ default_install_hook_types:
   - post-rewrite
 repos:
   - repo: https://github.com/astral-sh/uv-pre-commit
-    rev: 569ddf04117761eb74cef7afb5143bbb96fcdfbb # frozen: 0.9.15
+    # From: https://github.com/astral-sh/uv-pre-commit/pull/53/commits/d30b4298e4fb63ce8609e29acdbcf4c9018a483c
+    rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
     hooks:
+      - id: uv-run
+        name: Check lazy imports
+        args: ["--with=onyx-devtools", "ods", "check-lazy-imports"]
+        files: ^backend/(?!\.venv/).*\.py$
       - id: uv-sync
         args: ["--locked", "--all-extras"]
       - id: uv-lock
         files: ^pyproject\.toml$
+      - id: uv-export
+        name: uv-export default.txt
+        args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "backend", "-o", "backend/requirements/default.txt"]
+        files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
+      - id: uv-export
+        name: uv-export dev.txt
+        args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "dev", "-o", "backend/requirements/dev.txt"]
+        files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
+      - id: uv-export
+        name: uv-export ee.txt
+        args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "ee", "-o", "backend/requirements/ee.txt"]
+        files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
+      - id: uv-export
+        name: uv-export model_server.txt
+        args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "model_server", "-o", "backend/requirements/model_server.txt"]
+        files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
       # NOTE: This takes ~6s on a single, large module which is prohibitively slow.
       # - id: uv-run
       #   name: mypy
-      #   args: ["mypy"]
+      #   args: ["--all-extras", "mypy"]
       #   pass_filenames: true
       #   files: ^backend/.*\.py$

@@ -52,7 +76,7 @@ repos:
         args: [ '--remove-all-unused-imports', '--remove-unused-variables', '--in-place' , '--recursive']

   - repo: https://github.com/golangci/golangci-lint
-    rev: e6ebea0145f385056bce15041d3244c0e5e15848 # frozen: v2.7.0
+    rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
     hooks:
       - id: golangci-lint
         entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"

@@ -88,12 +112,6 @@ repos:
         pass_filenames: false
         files: \.tf$

-      - id: check-lazy-imports
-        name: Check lazy imports
-        entry: python3 backend/scripts/check_lazy_imports.py
-        language: system
-        files: ^backend/(?!\.venv/).*\.py$
-
       - id: typescript-check
         name: TypeScript type check
         entry: bash -c 'cd web && npm run types:check'
.vscode/launch.template.jsonc (vendored): 7 lines changed
@@ -508,7 +508,6 @@
       ],
       "cwd": "${workspaceFolder}",
       "console": "integratedTerminal",
-      "stopOnEntry": true,
       "presentation": {
         "group": "3"
       }
@@ -554,10 +553,10 @@
       "name": "Install Python Requirements",
       "type": "node",
       "request": "launch",
-      "runtimeExecutable": "bash",
+      "runtimeExecutable": "uv",
       "runtimeArgs": [
-        "-c",
-        "pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
+        "sync",
+        "--all-extras"
       ],
       "cwd": "${workspaceFolder}",
       "console": "integratedTerminal",
(repository guidance document, file header not shown)
@@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co

 ## KEY NOTES

-- If you run into any missing python dependency errors, try running your command with `source backend/.venv/bin/activate` \
+- If you run into any missing python dependency errors, try running your command with `source .venv/bin/activate` \
   to assume the python venv.
 - To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
 - If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password

@@ -71,12 +71,12 @@ If using a higher version, sometimes some libraries will not be available (i.e.

 #### Backend: Python requirements

-Currently, we use pip and recommend creating a virtual environment.
+Currently, we use [uv](https://docs.astral.sh/uv/) and recommend creating a [virtual environment](https://docs.astral.sh/uv/pip/environments/#using-a-virtual-environment).

 For convenience here's a command for it:

 ```bash
-python -m venv .venv
+uv venv .venv --python 3.11
 source .venv/bin/activate
 ```

@@ -95,33 +95,15 @@ If using PowerShell, the command slightly differs:
 Install the required python dependencies:

 ```bash
-pip install -r backend/requirements/combined.txt
+uv sync --all-extras
 ```

-or
+Install Playwright for Python (headless browser required by the Web Connector):

 ```bash
-pip install -r backend/requirements/default.txt
-pip install -r backend/requirements/dev.txt
-pip install -r backend/requirements/ee.txt
-pip install -r backend/requirements/model_server.txt
+uv run playwright install
 ```

-Fix vscode/cursor auto-imports:
-```bash
-pip install -e .
-```
-
-Install Playwright for Python (headless browser required by the Web Connector)
-
-In the activated Python virtualenv, install Playwright for Python by running:
-
-```bash
-playwright install
-```
-
-You may have to deactivate and reactivate your virtualenv for `playwright` to appear on your path.

 #### Frontend: Node dependencies

 Onyx uses Node v22.20.0. We highly recommend you use [Node Version Manager (nvm)](https://github.com/nvm-sh/nvm)

@@ -130,7 +112,7 @@ to manage your Node installations. Once installed, you can run

 ```bash
 nvm install 22 && nvm use 22
-node -v # verify your active version
 ```

 Navigate to `onyx/web` and run:

@@ -144,21 +126,15 @@ npm i

 For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).

-With the virtual environment active, install the pre-commit library with:
+Then run:

-```bash
-pip install pre-commit
-```
-
-Then, from the `onyx/backend` directory, run:
-
 ```bash
-pre-commit install
+uv run pre-commit install
 ```

 Additionally, we use `mypy` for static type checking.
 Onyx is fully type-annotated, and we want to keep it that way!
-To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend` directory.
+To run the mypy checks manually, run `uv run mypy .` from the `onyx/backend` directory.

 ### Web
(migrations README)
@@ -7,8 +7,12 @@ Onyx migrations use a generic single-database configuration with an async dbapi.

 ## To generate new migrations:

-run from onyx/backend:
-`alembic revision --autogenerate -m <DESCRIPTION_OF_MIGRATION>`
+From onyx/backend, run:
+`alembic revision -m <DESCRIPTION_OF_MIGRATION>`
+
+Note: you cannot use the `--autogenerate` flag as the automatic schema parsing does not work.
+
+Manually populate the upgrade and downgrade in your new migration.

 More info can be found here: https://alembic.sqlalchemy.org/en/latest/autogenerate.html
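To make the manual workflow above concrete, a newly generated revision can be filled in by hand along these lines. This is only an illustrative sketch: the table and column names are hypothetical and the revision identifiers are placeholders, not values from the Onyx schema.

```python
"""add example_flag to example_table

Revision ID: aaaaaaaaaaaa
Revises: bbbbbbbbbbbb
"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic (placeholders here).
revision = "aaaaaaaaaaaa"
down_revision = "bbbbbbbbbbbb"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Hand-written: --autogenerate is not available in this repo.
    op.add_column(
        "example_table",
        sa.Column("example_flag", sa.Boolean(), nullable=False, server_default="false"),
    )


def downgrade() -> None:
    op.drop_column("example_table", "example_flag")
```

The new migration files added in this compare (below) follow exactly this shape.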
@@ -0,0 +1,29 @@ (new migration file)
+"""add is_clarification to chat_message
+
+Revision ID: 18b5b2524446
+Revises: 87c52ec39f84
+Create Date: 2025-01-16
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = "18b5b2524446"
+down_revision = "87c52ec39f84"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.add_column(
+        "chat_message",
+        sa.Column(
+            "is_clarification", sa.Boolean(), nullable=False, server_default="false"
+        ),
+    )
+
+
+def downgrade() -> None:
+    op.drop_column("chat_message", "is_clarification")
@@ -0,0 +1,27 @@ (new migration file)
+"""Add display_name to model_configuration
+
+Revision ID: 7bd55f264e1b
+Revises: e8f0d2a38171
+Create Date: 2025-12-04
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = "7bd55f264e1b"
+down_revision = "e8f0d2a38171"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.add_column(
+        "model_configuration",
+        sa.Column("display_name", sa.String(), nullable=True),
+    )
+
+
+def downgrade() -> None:
+    op.drop_column("model_configuration", "display_name")
@@ -0,0 +1,55 @@ (new migration file)
+"""update_default_system_prompt
+
+Revision ID: 87c52ec39f84
+Revises: 7bd55f264e1b
+Create Date: 2025-12-05 15:54:06.002452
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = "87c52ec39f84"
+down_revision = "7bd55f264e1b"
+branch_labels = None
+depends_on = None
+
+
+DEFAULT_PERSONA_ID = 0
+
+# ruff: noqa: E501, W605 start
+DEFAULT_SYSTEM_PROMPT = """
+You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the user's intent, ask clarifying questions when needed, think step-by-step through complex problems, provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always prioritize being truthful, nuanced, insightful, and efficient.
+
+The current date is [[CURRENT_DATETIME]].[[CITATION_GUIDANCE]]
+
+# Response Style
+You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make your responses more readable and engaging.
+You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline.
+For code you prefer to use Markdown and specify the language.
+You can use horizontal rules (---) to separate sections of your responses.
+You can use Markdown tables to format your responses for data, lists, and other structured information.
+""".lstrip()
+# ruff: noqa: E501, W605 end
+
+
+def upgrade() -> None:
+    conn = op.get_bind()
+    conn.execute(
+        sa.text(
+            """
+            UPDATE persona
+            SET system_prompt = :system_prompt
+            WHERE id = :persona_id
+            """
+        ),
+        {"system_prompt": DEFAULT_SYSTEM_PROMPT, "persona_id": DEFAULT_PERSONA_ID},
+    )
+
+
+def downgrade() -> None:
+    # We don't revert the system prompt on downgrade since we don't know
+    # what the previous value was. The new prompt is a reasonable default.
+    pass
@@ -0,0 +1,62 @@ (new migration file)
+"""update_default_tool_descriptions
+
+Revision ID: a01bf2971c5d
+Revises: 87c52ec39f84
+Create Date: 2025-12-16 15:21:25.656375
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = "a01bf2971c5d"
+down_revision = "18b5b2524446"
+branch_labels = None
+depends_on = None
+
+# new tool descriptions (12/2025)
+TOOL_DESCRIPTIONS = {
+    "SearchTool": "The Search Action allows the agent to search through connected knowledge to help build an answer.",
+    "ImageGenerationTool": (
+        "The Image Generation Action allows the agent to use DALL-E 3 or GPT-IMAGE-1 to generate images. "
+        "The action will be used when the user asks the agent to generate an image."
+    ),
+    "WebSearchTool": (
+        "The Web Search Action allows the agent "
+        "to perform internet searches for up-to-date information."
+    ),
+    "KnowledgeGraphTool": (
+        "The Knowledge Graph Search Action allows the agent to search the "
+        "Knowledge Graph for information. This tool can (for now) only be active in the KG Beta Agent, "
+        "and it requires the Knowledge Graph to be enabled."
+    ),
+    "OktaProfileTool": (
+        "The Okta Profile Action allows the agent to fetch the current user's information from Okta. "
+        "This may include the user's name, email, phone number, address, and other details such as their "
+        "manager and direct reports."
+    ),
+}
+
+
+def upgrade() -> None:
+    conn = op.get_bind()
+    conn.execute(sa.text("BEGIN"))
+
+    try:
+        for tool_id, description in TOOL_DESCRIPTIONS.items():
+            conn.execute(
+                sa.text(
+                    "UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
+                ),
+                {"description": description, "tool_id": tool_id},
+            )
+        conn.execute(sa.text("COMMIT"))
+    except Exception as e:
+        conn.execute(sa.text("ROLLBACK"))
+        raise e
+
+
+def downgrade() -> None:
+    pass
(user group database helpers)
@@ -8,6 +8,7 @@ from sqlalchemy import func
 from sqlalchemy import Select
 from sqlalchemy import select
 from sqlalchemy import update
+from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.orm import Session

 from ee.onyx.server.user_group.models import SetCuratorRequest

@@ -362,14 +363,29 @@ def _check_user_group_is_modifiable(user_group: UserGroup) -> None:

 def _add_user__user_group_relationships__no_commit(
     db_session: Session, user_group_id: int, user_ids: list[UUID]
-) -> list[User__UserGroup]:
-    """NOTE: does not commit the transaction."""
-    relationships = [
-        User__UserGroup(user_id=user_id, user_group_id=user_group_id)
-        for user_id in user_ids
-    ]
-    db_session.add_all(relationships)
-    return relationships
+) -> None:
+    """NOTE: does not commit the transaction.
+
+    This function is idempotent - it will skip users who are already in the group
+    to avoid duplicate key violations during concurrent operations or re-syncs.
+    Uses ON CONFLICT DO NOTHING to keep inserts atomic under concurrency.
+    """
+    if not user_ids:
+        return
+
+    insert_stmt = (
+        insert(User__UserGroup)
+        .values(
+            [
+                {"user_id": user_id, "user_group_id": user_group_id}
+                for user_id in user_ids
+            ]
+        )
+        .on_conflict_do_nothing(
+            index_elements=[User__UserGroup.user_group_id, User__UserGroup.user_id]
+        )
+    )
+    db_session.execute(insert_stmt)


 def _add_user_group__cc_pair_relationships__no_commit(
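For readers unfamiliar with the pattern, the idempotent insert above leans on PostgreSQL's ON CONFLICT support in SQLAlchemy. A minimal, self-contained sketch of the same technique, using a hypothetical association table rather than the Onyx models:

```python
from sqlalchemy import Column, Integer, MetaData, Table
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session

metadata = MetaData()

# Hypothetical association table with a composite primary key.
membership = Table(
    "membership",
    metadata,
    Column("group_id", Integer, primary_key=True),
    Column("user_id", Integer, primary_key=True),
)


def add_members(db_session: Session, group_id: int, user_ids: list[int]) -> None:
    """Insert memberships, silently skipping rows that already exist."""
    if not user_ids:
        return
    stmt = (
        insert(membership)
        .values([{"group_id": group_id, "user_id": uid} for uid in user_ids])
        # Rows that collide on the (group_id, user_id) primary key are ignored,
        # so concurrent re-syncs cannot raise duplicate-key errors.
        .on_conflict_do_nothing(index_elements=["group_id", "user_id"])
    )
    db_session.execute(stmt)
```

The same statement form works against ORM-mapped classes, which is what the Onyx helper does with `User__UserGroup`.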
(simple chat API handlers)
@@ -8,13 +8,13 @@ from ee.onyx.server.query_and_chat.models import (
     BasicCreateChatMessageWithHistoryRequest,
 )
 from onyx.auth.users import current_user
-from onyx.chat.chat_utils import combine_message_thread
+from onyx.chat.chat_utils import create_chat_history_chain
 from onyx.chat.models import ChatBasicResponse
 from onyx.chat.process_message import gather_stream
 from onyx.chat.process_message import stream_chat_message_objects
 from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
 from onyx.configs.constants import MessageType
 from onyx.context.search.models import OptionalSearchSetting
 from onyx.context.search.models import RetrievalDetails
 from onyx.db.chat import create_chat_session
 from onyx.db.chat import create_new_chat_message
 from onyx.db.chat import get_or_create_root_message

@@ -22,7 +22,6 @@ from onyx.db.engine.sql_engine import get_session
 from onyx.db.models import User
 from onyx.llm.factory import get_llms_for_persona
 from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.secondary_llm_flows.query_expansion import thread_based_query_rephrase
 from onyx.server.query_and_chat.models import CreateChatMessageRequest
 from onyx.utils.logger import setup_logger

@@ -75,17 +74,40 @@ def handle_simplified_chat_message(
         chat_session_id=chat_session_id, db_session=db_session
     )

+    if (
+        chat_message_req.retrieval_options is None
+        and chat_message_req.search_doc_ids is None
+    ):
+        retrieval_options: RetrievalDetails | None = RetrievalDetails(
+            run_search=OptionalSearchSetting.ALWAYS,
+            real_time=False,
+        )
+    else:
+        retrieval_options = chat_message_req.retrieval_options
+
     full_chat_msg_info = CreateChatMessageRequest(
         chat_session_id=chat_session_id,
         parent_message_id=parent_message.id,
         message=chat_message_req.message,
         file_descriptors=[],
         search_doc_ids=chat_message_req.search_doc_ids,
         retrieval_options=retrieval_options,
         # Simple API does not support reranking, hide complexity from user
         rerank_settings=None,
         query_override=chat_message_req.query_override,
         # Currently only applies to search flow not chat
         chunks_above=0,
         chunks_below=0,
         full_doc=chat_message_req.full_doc,
         structured_response_format=chat_message_req.structured_response_format,
         use_agentic_search=chat_message_req.use_agentic_search,
     )

     packets = stream_chat_message_objects(
         new_msg_req=full_chat_msg_info,
         user=user,
         db_session=db_session,
         enforce_chat_session_id_for_search_docs=False,
     )

     return gather_stream(packets)

@@ -98,7 +120,8 @@ def handle_send_message_simple_with_history(
     db_session: Session = Depends(get_session),
 ) -> ChatBasicResponse:
     """This is a Non-Streaming version that only gives back a minimal set of information.
-    takes in chat history maintained by the caller"""
+    takes in chat history maintained by the caller
+    and does query rephrasing similar to answer-with-quote"""

     if len(req.messages) == 0:
         raise HTTPException(status_code=400, detail="Messages cannot be zero length")

@@ -142,8 +165,6 @@ def handle_send_message_simple_with_history(
         provider_type=llm.config.model_provider,
     )

-    max_history_tokens = int(llm.config.max_input_tokens * CHAT_TARGET_CHUNK_PERCENTAGE)
-
     # Every chat Session begins with an empty root message
     root_message = get_or_create_root_message(
         chat_session_id=chat_session.id, db_session=db_session

@@ -162,28 +183,36 @@ def handle_send_message_simple_with_history(
     )
     db_session.commit()

-    history_str = combine_message_thread(
-        messages=msg_history,
-        max_tokens=max_history_tokens,
-        llm_tokenizer=llm_tokenizer,
-    )
-
-    rephrased_query = thread_based_query_rephrase(
-        user_query=query,
-        history_str=history_str,
-    )
+    if req.retrieval_options is None and req.search_doc_ids is None:
+        retrieval_options: RetrievalDetails | None = RetrievalDetails(
+            run_search=OptionalSearchSetting.ALWAYS,
+            real_time=False,
+        )
+    else:
+        retrieval_options = req.retrieval_options

     full_chat_msg_info = CreateChatMessageRequest(
         chat_session_id=chat_session.id,
         parent_message_id=chat_message.id,
-        message=rephrased_query,
+        message=query,
         file_descriptors=[],
         search_doc_ids=req.search_doc_ids,
         retrieval_options=retrieval_options,
         # Simple API does not support reranking, hide complexity from user
         rerank_settings=None,
         query_override=None,
         chunks_above=0,
         chunks_below=0,
         full_doc=req.full_doc,
         structured_response_format=req.structured_response_format,
         use_agentic_search=req.use_agentic_search,
     )

     packets = stream_chat_message_objects(
         new_msg_req=full_chat_msg_info,
         user=user,
         db_session=db_session,
         enforce_chat_session_id_for_search_docs=False,
     )

     return gather_stream(packets)
(simple chat API request models)
@@ -12,6 +12,7 @@ from onyx.context.search.models import BaseFilters
 from onyx.context.search.models import BasicChunkRequest
 from onyx.context.search.models import ChunkContext
 from onyx.context.search.models import InferenceChunk
+from onyx.context.search.models import RetrievalDetails
 from onyx.server.manage.models import StandardAnswer

@@ -42,10 +43,20 @@ class BasicCreateChatMessageRequest(ChunkContext):
     persona_id: int | None = None
     # New message contents
     message: str
+    # Defaults to using retrieval with no additional filters
+    retrieval_options: RetrievalDetails | None = None
+    # Allows the caller to specify the exact search query they want to use
+    # will disable Query Rewording if specified
+    query_override: str | None = None
+    # If search_doc_ids provided, then retrieval options are unused
+    search_doc_ids: list[int] | None = None
     # only works if using an OpenAI model. See the following for more details:
     # https://platform.openai.com/docs/guides/structured-outputs/introduction
     structured_response_format: dict | None = None
+
+    # If True, uses agentic search instead of basic search
+    use_agentic_search: bool = False

     @model_validator(mode="after")
     def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
         if self.chat_session_id is None and self.persona_id is None:

@@ -57,9 +68,16 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
     # Last element is the new query. All previous elements are historical context
     messages: list[ThreadMessage]
     persona_id: int
+    retrieval_options: RetrievalDetails | None = None
+    query_override: str | None = None
+    skip_rerank: bool | None = None
+    # If search_doc_ids provided, then retrieval options are unused
+    search_doc_ids: list[int] | None = None
     # only works if using an OpenAI model. See the following for more details:
     # https://platform.openai.com/docs/guides/structured-outputs/introduction
     structured_response_format: dict | None = None
+    # If True, uses agentic search instead of basic search
+    use_agentic_search: bool = False


 class SimpleDoc(BaseModel):
@@ -56,6 +56,7 @@ from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import nulls_last
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
@@ -218,7 +219,7 @@ def verify_email_is_invited(email: str) -> None:
|
||||
raise PermissionError("Email must be specified")
|
||||
|
||||
try:
|
||||
email_info = validate_email(email)
|
||||
email_info = validate_email(email, check_deliverability=False)
|
||||
except EmailUndeliverableError:
|
||||
raise PermissionError("Email is not valid")
|
||||
|
||||
@@ -226,7 +227,9 @@ def verify_email_is_invited(email: str) -> None:
|
||||
try:
|
||||
# normalized emails are now being inserted into the db
|
||||
# we can remove this normalization on read after some time has passed
|
||||
email_info_whitelist = validate_email(email_whitelist)
|
||||
email_info_whitelist = validate_email(
|
||||
email_whitelist, check_deliverability=False
|
||||
)
|
||||
except EmailNotValidError:
|
||||
continue
|
||||
|
||||
@@ -339,6 +342,39 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create, safe=safe, request=request
|
||||
) # type: ignore
|
||||
user_created = True
|
||||
except IntegrityError as error:
|
||||
# Race condition: another request created the same user after the
|
||||
# pre-insert existence check but before our commit.
|
||||
await self.user_db.session.rollback()
|
||||
logger.warning(
|
||||
"IntegrityError while creating user %s, assuming duplicate: %s",
|
||||
user_create.email,
|
||||
str(error),
|
||||
)
|
||||
try:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
except exceptions.UserNotExists:
|
||||
# Unexpected integrity error, surface it for handling upstream.
|
||||
raise error
|
||||
|
||||
if MULTI_TENANT:
|
||||
user_by_session = await db_session.get(User, user.id)
|
||||
if user_by_session:
|
||||
user = user_by_session
|
||||
|
||||
if (
|
||||
user.role.is_web_login()
|
||||
or not isinstance(user_create, UserCreate)
|
||||
or not user_create.role.is_web_login()
|
||||
):
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
user_update = UserUpdateWithRole(
|
||||
password=user_create.password,
|
||||
is_verified=user_create.is_verified,
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
|
||||
|
||||
@@ -816,10 +816,14 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
secondary_cc_pair_ids: list[int] = []
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings:
|
||||
# Include paused CC pairs during embedding swap
|
||||
# For ACTIVE_ONLY, we skip paused connectors
|
||||
include_paused = (
|
||||
secondary_search_settings.switchover_type
|
||||
!= SwitchoverType.ACTIVE_ONLY
|
||||
)
|
||||
standard_cc_pair_ids = (
|
||||
fetch_indexable_standard_connector_credential_pair_ids(
|
||||
db_session, active_cc_pairs_only=False
|
||||
db_session, active_cc_pairs_only=not include_paused
|
||||
)
|
||||
)
|
||||
user_file_cc_pair_ids = (
|
||||
|
||||
@@ -105,52 +105,49 @@ S, U1, TC, TR, R -- agent calls another tool -> S, U1, TC, TR, TC, TR, R, A1
|
||||
- Reminder moved to the end
|
||||
```
|
||||
|
||||
|
||||
## Product considerations
Project files are important for the entire duration of the chat session. If the user has uploaded project files, they are likely very intent on working with
those files. The LLM is much better at referencing documents close to the end of the context window, so they are kept there for ease of access.

User uploaded files are considered relevant for that point in time; it is ok if the Agent forgets about them as the chat gets long. If every uploaded file were
constantly moved towards the end of the chat, it would degrade quality as these stack up. Even with a single file, there is some cost to making the previous
User Message further away. This tradeoff is accepted for Projects because of the intent of the feature.

Reminders are absolutely necessary to ensure 1-2 specific instructions get followed with a very high probability. A reminder is less detailed than the system prompt
and should be very targeted for it to work reliably and also not interfere with the last user message.
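
To make the ordering above concrete, here is a minimal sketch, assuming a simplified `Msg` container and ignoring token budgets, images, and tool-call messages. The names used here (`Msg`, `assemble_prompt`) are illustrative only; in the codebase the real logic lives in `construct_message_history`.

```
# Illustrative sketch only -- not the actual implementation.
from dataclasses import dataclass


@dataclass
class Msg:
    role: str  # "system" | "user" | "assistant" | "tool"
    content: str


def assemble_prompt(
    system_prompt: Msg,
    custom_agent_prompt: Msg | None,
    history: list[Msg],  # chronological; ends with the latest user message
    project_files_msg: Msg | None,
    reminder: Msg | None,
) -> list[Msg]:
    if not history:
        raise ValueError("history must contain at least the latest user message")

    # 1. The system prompt always leads.
    result = [system_prompt]
    # 2. Custom Agent instructions are a separate message, not folded into the system prompt.
    if custom_agent_prompt:
        result.append(custom_agent_prompt)
    # 3. Everything before the latest user message comes next, in order.
    *earlier, last_user = history
    result.extend(earlier)
    # 4. Project files sit just before the latest user message so they stay
    #    near the end of the context window for the whole session.
    if project_files_msg:
        result.append(project_files_msg)
    result.append(last_user)
    # 5. The short, targeted reminder is always the very last message.
    if reminder:
        result.append(reminder)
    return result
```

The key property is that project files and the reminder are re-inserted near the end on every turn, while ordinary user-uploaded files stay where they originally appeared in the history.
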
## Reasons / Experiments
Custom Agent instructions placed in the system prompt are poorly followed. They also degrade performance of the system, especially when the instructions
are orthogonal (or even possibly contradictory) to the system prompt. For weaker models, this causes strange artifacts in tool calls and final responses
that completely ruin the user experience. Empirically, this way works better across a range of models, especially when the history gets longer.
Having the Custom Agent instructions not move means they fade more as the chat gets long, which is also not ok from a UX perspective.

Project files are important for the entire duration of the chat session. If the user has uploaded project files, they are likely very intent on working with
those files. The LLM is much better at referencing documents close to the end of the context window, so they are kept there for ease of access.

Reminders are absolutely necessary to ensure 1-2 specific instructions get followed with a very high probability. A reminder is less detailed than the system prompt
and should be very targeted for it to work reliably.

User uploaded files are considered relevant for that point in time; it is ok if the Agent forgets about them as the chat gets long. If every uploaded file were
constantly moved towards the end of the chat, it would degrade quality as these stack up. Even with a single file, there is some cost to making the previous
User Message further away. This tradeoff is accepted for Projects because of the intent of the feature.


## Other related pointers
- How messages, files, images are stored can be found in db/models.py
|
||||
|
||||
|
||||
# Appendix (just random tidbits for those interested)
|
||||
- Reminder messages are placed at the end of the prompt because all model fine tuning approaches cause the LLMs to attend very strongly to the tokens at the very
|
||||
back of the context closest to generation. This is the only way to get the LLMs to not miss critical information and for the product to be reliable. Specifically
|
||||
the built-in reminders are around citations and what tools it should call in certain situations.
|
||||
|
||||
- LLMs are able to handle changes in topic best at message boundaries. There are special tokens under the hood for this. We also use this property to slice up
|
||||
the history in the way presented above.
|
||||
|
||||
- Different LLMs vary in this but some now have a section that cannot be set via the API layer called the "System Prompt" (OpenAI terminology) which contains
|
||||
Different LLMs vary in this but some now have a section that cannot be set via the API layer called the "System Prompt" (OpenAI terminology) which contains
|
||||
information like the model cutoff date, identity, and some other basic non-changing information. The System prompt described above is in that convention called
|
||||
the "Developer Prompt". It seems the distribution of the System Prompt, by which I mean the style of wording and terms used can also affect the behavior. This
|
||||
is different between different models and not necessarily scientific so the system prompt is built from an exploration across different models. It currently
|
||||
starts with: "You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the user's intent..."
|
||||
|
||||
- The document json includes a field for the LLM to cite (it's a single number) to make citations reliable and avoid weird artifacts. It's called "document" so
|
||||
LLMs are able to handle changes in topic best at message boundaries. There are special tokens under the hood for this. We also use this property to slice up
|
||||
the history in the way presented above.
|
||||
|
||||
Reminder messages are placed at the end of the prompt because all model fine tuning approaches cause the LLMs to attend very strongly to the tokens at the very
|
||||
back of the context closest to generation. This is the only way to get the LLMs to not miss critical information and for the product to be reliable. Specifically
|
||||
the built-in reminders are around citations and what tools it should call in certain situations.
|
||||
|
||||
The document json includes a field for the LLM to cite (it's a single number) to make citations reliable and avoid weird artifacts. It's called "document" so
|
||||
that the LLM does not create weird artifacts in reasoning like "I should reference citation_id: 5 for...". It is also strategically placed so that it is easy to
|
||||
reference. It is followed by a couple short sections like the metadata and title before the long content section. It seems LLMs are still better at local
|
||||
attention despite having global access.
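
As a rough illustration of the payload shape being described, a document handed to the LLM might look something like the sketch below; the field names here are assumptions for illustration only, not the actual schema. The important parts are the single-number "document" id placed first, followed by the short title/metadata sections, and only then the long content.

```
# Illustrative only -- field names are assumptions, not the actual schema.
example_doc_for_llm = {
    "document": 5,  # the single number the LLM cites, e.g. [5]
    "title": "Q3 planning notes",
    "metadata": {"source": "confluence", "updated": "2024-06-01"},
    "content": "... full chunk text goes here ...",
}
```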
- In a similar concept, LLM instructions in the system prompt are structured specifically so that there are coherent sections for the LLM to attend to. This is
|
||||
In a similar concept, LLM instructions in the system prompt are structured specifically so that there are coherent sections for the LLM to attend to. This is
|
||||
fairly surprising actually but if there is a line of instructions effectively saying "If you try to use some tools and find that you need more information or
|
||||
need to call additional tools, you are encouraged to do this", having this in the Tool section of the System prompt makes all the LLMs follow it well but if it's
|
||||
even just a paragraph away, like near the beginning of the prompt, it is often ignored. The difference is as drastic as a 30% follow rate to a 90% follow
rate from just moving the same statement a few sentences.
|
||||
|
||||
- Custom Agent prompts are also completely separate from the system prompt. Having potentially orthogonal instructions in the system prompt (both the actual
|
||||
instructions and the writing style) can greatly deteriorate the quality of the responses. There is also a product motivation to keep it close to the end of
|
||||
generation so it's strongly followed.
|
||||
|
||||
## Other related pointers
|
||||
- How messages, files, images are stored can be found in backend/onyx/db/models.py, there is also a README.md under that directory that may be helpful.
|
||||
|
||||
@@ -26,6 +26,8 @@ class ChatStateContainer:
|
||||
self.answer_tokens: str | None = None
|
||||
# Store citation mapping for building citation_docs_info during partial saves
|
||||
self.citation_to_doc: dict[int, SearchDoc] = {}
|
||||
# True if this turn is a clarification question (deep research flow)
|
||||
self.is_clarification: bool = False
|
||||
|
||||
def add_tool_call(self, tool_call: ToolCallInfo) -> None:
|
||||
"""Add a tool call to the accumulated state."""
|
||||
@@ -43,12 +45,16 @@ class ChatStateContainer:
|
||||
"""Set the citation mapping from citation processor."""
|
||||
self.citation_to_doc = citation_to_doc
|
||||
|
||||
def set_is_clarification(self, is_clarification: bool) -> None:
|
||||
"""Set whether this turn is a clarification question."""
|
||||
self.is_clarification = is_clarification
|
||||
|
||||
|
||||
def run_chat_llm_with_state_containers(
|
||||
func: Callable[..., None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
is_connected: Callable[[], bool],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Generator[Packet, None]:
|
||||
|
||||
@@ -4,9 +4,11 @@ from collections.abc import Callable
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.datastructures import Headers
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import is_user_admin
|
||||
from onyx.background.celery.tasks.kg_processing.kg_indexing import (
|
||||
try_creating_kg_processing_task,
|
||||
)
|
||||
@@ -15,19 +17,24 @@ from onyx.background.celery.tasks.kg_processing.kg_indexing import (
|
||||
)
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.kg_config import is_kg_config_settings_enabled_valid
|
||||
from onyx.db.llm import fetch_existing_doc_sets
|
||||
from onyx.db.llm import fetch_existing_tools
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
@@ -44,6 +51,9 @@ from onyx.prompts.chat_prompts import ADDITIONAL_CONTEXT_PROMPT
|
||||
from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
@@ -54,10 +64,15 @@ logger = setup_logger()
|
||||
def prepare_chat_message_request(
|
||||
message_text: str,
|
||||
user: User | None,
|
||||
filters: BaseFilters | None,
|
||||
persona_id: int | None,
|
||||
# Does the question need to have a persona override
|
||||
persona_override_config: PersonaOverrideConfig | None,
|
||||
message_ts_to_respond_to: str | None,
|
||||
retrieval_details: RetrievalDetails | None,
|
||||
rerank_settings: RerankingDetails | None,
|
||||
db_session: Session,
|
||||
use_agentic_search: bool = False,
|
||||
skip_gen_ai_answer_generation: bool = False,
|
||||
llm_override: LLMOverride | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
) -> CreateChatMessageRequest:
|
||||
@@ -76,7 +91,15 @@ def prepare_chat_message_request(
|
||||
chat_session_id=new_chat_session.id,
|
||||
parent_message_id=None, # It's a standalone chat session each time
|
||||
message=message_text,
|
||||
filters=filters,
|
||||
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
|
||||
# Can always override the persona for the single query, if it's a normal persona
|
||||
# then it will be treated the same
|
||||
persona_override_config=persona_override_config,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=retrieval_details,
|
||||
rerank_settings=rerank_settings,
|
||||
use_agentic_search=use_agentic_search,
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
llm_override=llm_override,
|
||||
allowed_tool_ids=allowed_tool_ids,
|
||||
)
|
||||
@@ -332,69 +355,68 @@ def extract_headers(
|
||||
return extracted_headers
|
||||
|
||||
|
||||
# TODO in case it needs to be referenced later
|
||||
# def create_temporary_persona(
|
||||
# persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
|
||||
# ) -> Persona:
|
||||
# if not is_user_admin(user):
|
||||
# raise HTTPException(
|
||||
# status_code=403,
|
||||
# detail="User is not authorized to create a persona in one shot queries",
|
||||
# )
|
||||
def create_temporary_persona(
|
||||
persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
|
||||
) -> Persona:
|
||||
if not is_user_admin(user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="User is not authorized to create a persona in one shot queries",
|
||||
)
|
||||
|
||||
# """Create a temporary Persona object from the provided configuration."""
|
||||
# persona = Persona(
|
||||
# name=persona_config.name,
|
||||
# description=persona_config.description,
|
||||
# num_chunks=persona_config.num_chunks,
|
||||
# llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
# llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
# recency_bias=persona_config.recency_bias,
|
||||
# llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
# llm_model_version_override=persona_config.llm_model_version_override,
|
||||
# )
|
||||
"""Create a temporary Persona object from the provided configuration."""
|
||||
persona = Persona(
|
||||
name=persona_config.name,
|
||||
description=persona_config.description,
|
||||
num_chunks=persona_config.num_chunks,
|
||||
llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
recency_bias=persona_config.recency_bias,
|
||||
llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
llm_model_version_override=persona_config.llm_model_version_override,
|
||||
)
|
||||
|
||||
# if persona_config.prompts:
|
||||
# # Use the first prompt from the override config for embedded prompt fields
|
||||
# first_prompt = persona_config.prompts[0]
|
||||
# persona.system_prompt = first_prompt.system_prompt
|
||||
# persona.task_prompt = first_prompt.task_prompt
|
||||
# persona.datetime_aware = first_prompt.datetime_aware
|
||||
if persona_config.prompts:
|
||||
# Use the first prompt from the override config for embedded prompt fields
|
||||
first_prompt = persona_config.prompts[0]
|
||||
persona.system_prompt = first_prompt.system_prompt
|
||||
persona.task_prompt = first_prompt.task_prompt
|
||||
persona.datetime_aware = first_prompt.datetime_aware
|
||||
|
||||
# persona.tools = []
|
||||
# if persona_config.custom_tools_openapi:
|
||||
# from onyx.chat.emitter import get_default_emitter
|
||||
persona.tools = []
|
||||
if persona_config.custom_tools_openapi:
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
|
||||
# for schema in persona_config.custom_tools_openapi:
|
||||
# tools = cast(
|
||||
# list[Tool],
|
||||
# build_custom_tools_from_openapi_schema_and_headers(
|
||||
# tool_id=0, # dummy tool id
|
||||
# openapi_schema=schema,
|
||||
# emitter=get_default_emitter(),
|
||||
# ),
|
||||
# )
|
||||
# persona.tools.extend(tools)
|
||||
for schema in persona_config.custom_tools_openapi:
|
||||
tools = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
tool_id=0, # dummy tool id
|
||||
openapi_schema=schema,
|
||||
emitter=get_default_emitter(),
|
||||
),
|
||||
)
|
||||
persona.tools.extend(tools)
|
||||
|
||||
# if persona_config.tools:
|
||||
# tool_ids = [tool.id for tool in persona_config.tools]
|
||||
# persona.tools.extend(
|
||||
# fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
|
||||
# )
|
||||
if persona_config.tools:
|
||||
tool_ids = [tool.id for tool in persona_config.tools]
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
|
||||
)
|
||||
|
||||
# if persona_config.tool_ids:
|
||||
# persona.tools.extend(
|
||||
# fetch_existing_tools(
|
||||
# db_session=db_session, tool_ids=persona_config.tool_ids
|
||||
# )
|
||||
# )
|
||||
if persona_config.tool_ids:
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(
|
||||
db_session=db_session, tool_ids=persona_config.tool_ids
|
||||
)
|
||||
)
|
||||
|
||||
# fetched_docs = fetch_existing_doc_sets(
|
||||
# db_session=db_session, doc_ids=persona_config.document_set_ids
|
||||
# )
|
||||
# persona.document_sets = fetched_docs
|
||||
fetched_docs = fetch_existing_doc_sets(
|
||||
db_session=db_session, doc_ids=persona_config.document_set_ids
|
||||
)
|
||||
persona.document_sets = fetched_docs
|
||||
|
||||
# return persona
|
||||
return persona
|
||||
|
||||
|
||||
def process_kg_commands(
|
||||
@@ -455,7 +477,10 @@ def load_chat_file(
|
||||
|
||||
# Extract text content if it's a text file type (not an image)
|
||||
content_text = None
|
||||
file_type = file_descriptor["type"]
|
||||
# `FileDescriptor` is often JSON-roundtripped (e.g. JSONB / API), so `type`
|
||||
# may arrive as a raw string value instead of a `ChatFileType`.
|
||||
file_type = ChatFileType(file_descriptor["type"])
|
||||
|
||||
if file_type.is_text_file():
|
||||
try:
|
||||
content_text = content.decode("utf-8")
|
||||
@@ -686,3 +711,21 @@ def get_custom_agent_prompt(persona: Persona, chat_session: ChatSession) -> str
|
||||
return chat_session.project.instructions
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def is_last_assistant_message_clarification(chat_history: list[ChatMessage]) -> bool:
|
||||
"""Check if the last assistant message in chat history was a clarification question.
|
||||
|
||||
This is used in the deep research flow to determine whether to skip the
|
||||
clarification step when the user has already responded to a clarification.
|
||||
|
||||
Args:
|
||||
chat_history: List of ChatMessage objects in chronological order
|
||||
|
||||
Returns:
|
||||
True if the last assistant message has is_clarification=True, False otherwise
|
||||
"""
|
||||
for message in reversed(chat_history):
|
||||
if message.message_type == MessageType.ASSISTANT:
|
||||
return message.is_clarification
|
||||
return False
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -10,6 +7,9 @@ from sqlalchemy.orm import Session
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.llm_step import TOOL_CALL_MSG_ARGUMENTS
|
||||
from onyx.chat.llm_step import TOOL_CALL_MSG_FUNC_NAME
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
from onyx.chat.models import LlmStepResult
|
||||
@@ -19,38 +19,20 @@ from onyx.chat.prompt_utils import build_system_prompt
|
||||
from onyx.chat.prompt_utils import (
|
||||
get_default_base_system_prompt,
|
||||
)
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.db.models import Persona
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.message_types import AssistantMessage
|
||||
from onyx.llm.message_types import ChatCompletionMessage
|
||||
from onyx.llm.message_types import ImageContentPart
|
||||
from onyx.llm.message_types import SystemMessage
|
||||
from onyx.llm.message_types import TextContentPart
|
||||
from onyx.llm.message_types import ToolCall
|
||||
from onyx.llm.message_types import ToolMessage
|
||||
from onyx.llm.message_types import UserMessageWithParts
|
||||
from onyx.llm.message_types import UserMessageWithText
|
||||
from onyx.llm.utils import model_needs_formatting_reenabled
|
||||
from onyx.prompts.chat_prompts import IMAGE_GEN_REMINDER
|
||||
from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDone
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
@@ -63,9 +45,7 @@ from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
from onyx.tools.tool_runner import run_tool_calls
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.tracing.framework.create import trace
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -80,9 +60,6 @@ logger = setup_logger()
|
||||
# Cycle 6: No more tools available, forced to answer
|
||||
MAX_LLM_CYCLES = 6
|
||||
|
||||
TOOL_CALL_MSG_FUNC_NAME = "function_name"
|
||||
TOOL_CALL_MSG_ARGUMENTS = "arguments"
|
||||
|
||||
|
||||
def _build_project_file_citation_mapping(
|
||||
project_file_metadata: list[ProjectFileMetadata],
|
||||
@@ -127,15 +104,23 @@ def construct_message_history(
|
||||
custom_agent_prompt: ChatMessageSimple | None,
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
reminder_message: ChatMessageSimple | None,
|
||||
project_files: ExtractedProjectFiles,
|
||||
project_files: ExtractedProjectFiles | None,
|
||||
available_tokens: int,
|
||||
last_n_user_messages: int | None = None,
|
||||
) -> list[ChatMessageSimple]:
|
||||
if last_n_user_messages is not None:
|
||||
if last_n_user_messages <= 0:
|
||||
raise ValueError(
|
||||
"filtering chat history by last N user messages must be a value greater than 0"
|
||||
)
|
||||
|
||||
history_token_budget = available_tokens
|
||||
history_token_budget -= system_prompt.token_count
|
||||
history_token_budget -= (
|
||||
custom_agent_prompt.token_count if custom_agent_prompt else 0
|
||||
)
|
||||
history_token_budget -= project_files.total_token_count
|
||||
if project_files:
|
||||
history_token_budget -= project_files.total_token_count
|
||||
history_token_budget -= reminder_message.token_count if reminder_message else 0
|
||||
|
||||
if history_token_budget < 0:
|
||||
@@ -146,7 +131,7 @@ def construct_message_history(
|
||||
result = [system_prompt]
|
||||
if custom_agent_prompt:
|
||||
result.append(custom_agent_prompt)
|
||||
if project_files.project_file_texts:
|
||||
if project_files and project_files.project_file_texts:
|
||||
project_message = _create_project_files_message(
|
||||
project_files, token_counter=None
|
||||
)
|
||||
@@ -155,6 +140,26 @@ def construct_message_history(
|
||||
result.append(reminder_message)
|
||||
return result
|
||||
|
||||
# If last_n_user_messages is set, filter history to only include the last n user messages
|
||||
if last_n_user_messages is not None:
|
||||
# Find all user message indices
|
||||
user_msg_indices = [
|
||||
i
|
||||
for i, msg in enumerate(simple_chat_history)
|
||||
if msg.message_type == MessageType.USER
|
||||
]
|
||||
|
||||
if not user_msg_indices:
|
||||
raise ValueError("No user message found in simple_chat_history")
|
||||
|
||||
# If we have more than n user messages, keep only the last n
|
||||
if len(user_msg_indices) > last_n_user_messages:
|
||||
# Find the index of the n-th user message from the end
|
||||
# For example, if last_n_user_messages=2, we want the 2nd-to-last user message
|
||||
nth_user_msg_index = user_msg_indices[-(last_n_user_messages)]
|
||||
# Keep everything from that user message onwards
|
||||
simple_chat_history = simple_chat_history[nth_user_msg_index:]
|
||||
|
||||
# Find the last USER message in the history
|
||||
# The history may contain tool calls and responses after the last user message
|
||||
last_user_msg_index = None
|
||||
@@ -202,7 +207,7 @@ def construct_message_history(
|
||||
break
|
||||
|
||||
# Attach project images to the last user message
|
||||
if project_files.project_image_files:
|
||||
if project_files and project_files.project_image_files:
|
||||
existing_images = last_user_message.image_files or []
|
||||
last_user_message = ChatMessageSimple(
|
||||
message=last_user_message.message,
|
||||
@@ -224,7 +229,7 @@ def construct_message_history(
|
||||
result.append(custom_agent_prompt)
|
||||
|
||||
# 3. Add project files message (inserted before last user message)
|
||||
if project_files.project_file_texts:
|
||||
if project_files and project_files.project_file_texts:
|
||||
project_message = _create_project_files_message(
|
||||
project_files, token_counter=None
|
||||
)
|
||||
@@ -274,509 +279,6 @@ def _create_project_files_message(
|
||||
)
|
||||
|
||||
|
||||
def translate_history_to_llm_format(
|
||||
history: list[ChatMessageSimple],
|
||||
) -> LanguageModelInput:
|
||||
"""Convert a list of ChatMessageSimple to LanguageModelInput format.
|
||||
|
||||
Converts ChatMessageSimple messages to ChatCompletionMessage format,
|
||||
handling different message types and image files for multimodal support.
|
||||
"""
|
||||
messages: list[ChatCompletionMessage] = []
|
||||
|
||||
for msg in history:
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
system_msg: SystemMessage = {
|
||||
"role": "system",
|
||||
"content": msg.message,
|
||||
}
|
||||
messages.append(system_msg)
|
||||
|
||||
elif msg.message_type == MessageType.USER:
|
||||
# Handle user messages with potential images
|
||||
if msg.image_files:
|
||||
# Build content parts: text + images
|
||||
content_parts: list[TextContentPart | ImageContentPart] = [
|
||||
{"type": "text", "text": msg.message}
|
||||
]
|
||||
|
||||
# Add image parts
|
||||
for img_file in msg.image_files:
|
||||
if img_file.file_type == ChatFileType.IMAGE:
|
||||
try:
|
||||
image_type = get_image_type_from_bytes(img_file.content)
|
||||
base64_data = img_file.to_base64()
|
||||
image_url = f"data:{image_type};base64,{base64_data}"
|
||||
|
||||
image_part: ImageContentPart = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_url},
|
||||
}
|
||||
content_parts.append(image_part)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to process image file {img_file.file_id}: {e}. "
|
||||
"Skipping image."
|
||||
)
|
||||
|
||||
user_msg_with_parts: UserMessageWithParts = {
|
||||
"role": "user",
|
||||
"content": content_parts,
|
||||
}
|
||||
messages.append(user_msg_with_parts)
|
||||
else:
|
||||
# Simple text-only user message
|
||||
user_msg_text: UserMessageWithText = {
|
||||
"role": "user",
|
||||
"content": msg.message,
|
||||
}
|
||||
messages.append(user_msg_text)
|
||||
|
||||
elif msg.message_type == MessageType.ASSISTANT:
|
||||
assistant_msg: AssistantMessage = {
|
||||
"role": "assistant",
|
||||
"content": msg.message or None,
|
||||
}
|
||||
messages.append(assistant_msg)
|
||||
|
||||
elif msg.message_type == MessageType.TOOL_CALL:
|
||||
# Tool calls are represented as Assistant Messages with tool_calls field
|
||||
# Try to reconstruct tool call structure if we have tool_call_id
|
||||
tool_calls: list[ToolCall] = []
|
||||
if msg.tool_call_id:
|
||||
try:
|
||||
# Parse the message content (which should contain function_name and arguments)
|
||||
tool_call_data = json.loads(msg.message) if msg.message else {}
|
||||
|
||||
if (
|
||||
isinstance(tool_call_data, dict)
|
||||
and TOOL_CALL_MSG_FUNC_NAME in tool_call_data
|
||||
):
|
||||
function_name = tool_call_data.get(
|
||||
TOOL_CALL_MSG_FUNC_NAME, "unknown"
|
||||
)
|
||||
tool_args = tool_call_data.get(TOOL_CALL_MSG_ARGUMENTS, {})
|
||||
else:
|
||||
function_name = "unknown"
|
||||
tool_args = (
|
||||
tool_call_data if isinstance(tool_call_data, dict) else {}
|
||||
)
|
||||
|
||||
# NOTE: if the model is trained on a different tool call format, this may slightly interfere
|
||||
# with the future tool calls, if it doesn't look like this. Almost certainly not a big deal.
|
||||
tool_call: ToolCall = {
|
||||
"id": msg.tool_call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(tool_args) if tool_args else "{}",
|
||||
},
|
||||
}
|
||||
tool_calls.append(tool_call)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.warning(
|
||||
f"Failed to parse tool call data for tool_call_id {msg.tool_call_id}: {e}. "
|
||||
"Including as content-only message."
|
||||
)
|
||||
|
||||
assistant_msg_with_tool: AssistantMessage = {
|
||||
"role": "assistant",
|
||||
"content": None, # The tool call is parsed, doesn't need to be duplicated in the content
|
||||
}
|
||||
if tool_calls:
|
||||
assistant_msg_with_tool["tool_calls"] = tool_calls
|
||||
messages.append(assistant_msg_with_tool)
|
||||
|
||||
elif msg.message_type == MessageType.TOOL_CALL_RESPONSE:
|
||||
if not msg.tool_call_id:
|
||||
raise ValueError(
|
||||
f"Tool call response message encountered but tool_call_id is not available. Message: {msg}"
|
||||
)
|
||||
|
||||
tool_msg: ToolMessage = {
|
||||
"role": "tool",
|
||||
"content": msg.message,
|
||||
"tool_call_id": msg.tool_call_id,
|
||||
}
|
||||
messages.append(tool_msg)
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unknown message type {msg.message_type} in history. Skipping message."
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def _format_message_history_for_logging(
|
||||
message_history: LanguageModelInput,
|
||||
) -> str:
|
||||
"""Format message history for logging, with special handling for tool calls.
|
||||
|
||||
Tool calls are formatted as JSON with 4-space indentation for readability.
|
||||
"""
|
||||
formatted_lines = []
|
||||
|
||||
separator = "================================================"
|
||||
|
||||
# Handle string input
|
||||
if isinstance(message_history, str):
|
||||
formatted_lines.append("Message [string]:")
|
||||
formatted_lines.append(separator)
|
||||
formatted_lines.append(f"{message_history}")
|
||||
return "\n".join(formatted_lines)
|
||||
|
||||
# Handle sequence of messages
|
||||
for i, msg in enumerate(message_history):
|
||||
# Type guard: ensure msg is a dict-like object (TypedDict)
|
||||
if not isinstance(msg, dict):
|
||||
formatted_lines.append(f"Message {i + 1} [unknown]:")
|
||||
formatted_lines.append(separator)
|
||||
formatted_lines.append(f"{msg}")
|
||||
if i < len(message_history) - 1:
|
||||
formatted_lines.append(separator)
|
||||
continue
|
||||
|
||||
role = msg.get("role", "unknown")
|
||||
formatted_lines.append(f"Message {i + 1} [{role}]:")
|
||||
formatted_lines.append(separator)
|
||||
|
||||
if role == "system":
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
formatted_lines.append(f"{content}")
|
||||
|
||||
elif role == "user":
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
formatted_lines.append(f"{content}")
|
||||
elif isinstance(content, list):
|
||||
# Handle multimodal content (text + images)
|
||||
for part in content:
|
||||
if isinstance(part, dict):
|
||||
part_type = part.get("type")
|
||||
if part_type == "text":
|
||||
text = part.get("text", "")
|
||||
if isinstance(text, str):
|
||||
formatted_lines.append(f"{text}")
|
||||
elif part_type == "image_url":
|
||||
image_url_dict = part.get("image_url")
|
||||
if isinstance(image_url_dict, dict):
|
||||
url = image_url_dict.get("url", "")
|
||||
if isinstance(url, str):
|
||||
formatted_lines.append(f"[Image: {url[:50]}...]")
|
||||
|
||||
elif role == "assistant":
|
||||
content = msg.get("content")
|
||||
if content and isinstance(content, str):
|
||||
formatted_lines.append(f"{content}")
|
||||
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if tool_calls and isinstance(tool_calls, list):
|
||||
formatted_lines.append("Tool calls:")
|
||||
for tool_call in tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
tool_call_dict: dict[str, Any] = {}
|
||||
tool_call_id = tool_call.get("id")
|
||||
tool_call_type = tool_call.get("type")
|
||||
function_dict = tool_call.get("function")
|
||||
|
||||
if tool_call_id:
|
||||
tool_call_dict["id"] = tool_call_id
|
||||
if tool_call_type:
|
||||
tool_call_dict["type"] = tool_call_type
|
||||
if isinstance(function_dict, dict):
|
||||
tool_call_dict["function"] = {
|
||||
"name": function_dict.get("name", ""),
|
||||
"arguments": function_dict.get("arguments", ""),
|
||||
}
|
||||
|
||||
tool_call_json = json.dumps(tool_call_dict, indent=4)
|
||||
formatted_lines.append(tool_call_json)
|
||||
|
||||
elif role == "tool":
|
||||
content = msg.get("content", "")
|
||||
tool_call_id = msg.get("tool_call_id", "")
|
||||
if isinstance(content, str) and isinstance(tool_call_id, str):
|
||||
formatted_lines.append(f"Tool call ID: {tool_call_id}")
|
||||
formatted_lines.append(f"Response: {content}")
|
||||
|
||||
# Add separator before next message (or at end)
|
||||
if i < len(message_history) - 1:
|
||||
formatted_lines.append(separator)
|
||||
|
||||
return "\n".join(formatted_lines)
|
||||
|
||||
|
||||
def run_llm_step(
|
||||
history: list[ChatMessageSimple],
|
||||
tool_definitions: list[dict],
|
||||
tool_choice: ToolChoiceOptions,
|
||||
emitter: Emitter,
|
||||
llm: LLM,
|
||||
turn_index: int,
|
||||
citation_processor: DynamicCitationProcessor,
|
||||
state_container: ChatStateContainer,
|
||||
final_documents: list[SearchDoc] | None = None,
|
||||
) -> tuple[LlmStepResult, int]:
|
||||
# The second return value is for the turn index because reasoning counts on the frontend as a turn
|
||||
# TODO this is maybe ok but does not align too well with the backend logic
|
||||
llm_msg_history = translate_history_to_llm_format(history)
|
||||
|
||||
# Logs the entire message history to the console when LOG_ONYX_MODEL_INTERACTIONS is enabled
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
logger.info(
|
||||
f"Message history:\n{_format_message_history_for_logging(llm_msg_history)}"
|
||||
)
|
||||
|
||||
id_to_tool_call_map: dict[int, dict[str, Any]] = {}
|
||||
reasoning_start = False
|
||||
answer_start = False
|
||||
accumulated_reasoning = ""
|
||||
accumulated_answer = ""
|
||||
|
||||
with generation_span(
|
||||
model=llm.config.model_name,
|
||||
model_config={
|
||||
"base_url": str(llm.config.api_base or ""),
|
||||
"model_impl": "litellm",
|
||||
},
|
||||
) as span_generation:
|
||||
span_generation.span_data.input = cast(
|
||||
Sequence[Mapping[str, Any]], llm_msg_history
|
||||
)
|
||||
for packet in llm.stream(
|
||||
prompt=llm_msg_history,
|
||||
tools=tool_definitions,
|
||||
tool_choice=tool_choice,
|
||||
structured_response_format=None, # TODO
|
||||
):
|
||||
if packet.usage:
|
||||
usage = packet.usage
|
||||
span_generation.span_data.usage = {
|
||||
"input_tokens": usage.prompt_tokens,
|
||||
"output_tokens": usage.completion_tokens,
|
||||
"cache_read_input_tokens": usage.cache_read_input_tokens,
|
||||
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
|
||||
}
|
||||
delta = packet.choice.delta
|
||||
|
||||
# Should only happen once, frontend does not expect multiple
|
||||
# ReasoningStart or ReasoningDone packets.
|
||||
if delta.reasoning_content:
|
||||
accumulated_reasoning += delta.reasoning_content
|
||||
# Save reasoning incrementally to state container
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if not reasoning_start:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
)
|
||||
emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningDelta(reasoning=delta.reasoning_content),
|
||||
)
|
||||
)
|
||||
reasoning_start = True
|
||||
|
||||
if delta.content:
|
||||
if reasoning_start:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
)
|
||||
turn_index += 1
|
||||
reasoning_start = False
|
||||
|
||||
if not answer_start:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentResponseStart(
|
||||
final_documents=final_documents,
|
||||
),
|
||||
)
|
||||
)
|
||||
answer_start = True
|
||||
|
||||
for result in citation_processor.process_token(delta.content):
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
# Save answer incrementally to state container
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=result,
|
||||
)
|
||||
)
|
||||
|
||||
if delta.tool_calls:
|
||||
if reasoning_start:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
)
|
||||
turn_index += 1
|
||||
reasoning_start = False
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
|
||||
tool_calls = _extract_tool_call_kickoffs(id_to_tool_call_map)
|
||||
if tool_calls:
|
||||
tool_calls_list: list[ToolCall] = [
|
||||
{
|
||||
"id": kickoff.tool_call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": kickoff.tool_name,
|
||||
"arguments": json.dumps(kickoff.tool_args),
|
||||
},
|
||||
}
|
||||
for kickoff in tool_calls
|
||||
]
|
||||
|
||||
assistant_msg: AssistantMessage = {
|
||||
"role": "assistant",
|
||||
"content": accumulated_answer if accumulated_answer else None,
|
||||
"tool_calls": tool_calls_list,
|
||||
}
|
||||
span_generation.span_data.output = [assistant_msg]
|
||||
elif accumulated_answer:
|
||||
span_generation.span_data.output = [
|
||||
{"role": "assistant", "content": accumulated_answer}
|
||||
]
|
||||
# Close reasoning block if still open (stream ended with reasoning content)
|
||||
if reasoning_start:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
)
|
||||
turn_index += 1
|
||||
|
||||
# Flush any remaining content from citation processor
|
||||
if citation_processor:
|
||||
for result in citation_processor.process_token(None):
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
# Save answer incrementally to state container
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
turn_index=turn_index,
|
||||
obj=result,
|
||||
)
|
||||
)
|
||||
|
||||
# Note: Content (AgentResponseDelta) doesn't need an explicit end packet - OverallStop handles it
|
||||
# Tool calls are handled by tool execution code and emit their own packets (e.g., SectionEnd)
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
logger.debug(f"Accumulated reasoning: {accumulated_reasoning}")
|
||||
logger.debug(f"Accumulated answer: {accumulated_answer}")
|
||||
|
||||
if tool_calls:
|
||||
tool_calls_str = "\n".join(
|
||||
f" - {tc.tool_name}: {json.dumps(tc.tool_args, indent=4)}"
|
||||
for tc in tool_calls
|
||||
)
|
||||
logger.debug(f"Tool calls:\n{tool_calls_str}")
|
||||
else:
|
||||
logger.debug("Tool calls: []")
|
||||
|
||||
return (
|
||||
LlmStepResult(
|
||||
reasoning=accumulated_reasoning if accumulated_reasoning else None,
|
||||
answer=accumulated_answer if accumulated_answer else None,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
),
|
||||
turn_index,
|
||||
)
|
||||
|
||||
|
||||
def _update_tool_call_with_delta(
|
||||
tool_calls_in_progress: dict[int, dict[str, Any]],
|
||||
tool_call_delta: Any,
|
||||
) -> None:
|
||||
index = tool_call_delta.index
|
||||
|
||||
if index not in tool_calls_in_progress:
|
||||
tool_calls_in_progress[index] = {
|
||||
"id": None,
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
}
|
||||
|
||||
if tool_call_delta.id:
|
||||
tool_calls_in_progress[index]["id"] = tool_call_delta.id
|
||||
|
||||
if tool_call_delta.function:
|
||||
if tool_call_delta.function.name:
|
||||
tool_calls_in_progress[index]["name"] = tool_call_delta.function.name
|
||||
|
||||
if tool_call_delta.function.arguments:
|
||||
tool_calls_in_progress[index][
|
||||
"arguments"
|
||||
] += tool_call_delta.function.arguments
|
||||
|
||||
|
||||
def _extract_tool_call_kickoffs(
|
||||
id_to_tool_call_map: dict[int, dict[str, Any]],
|
||||
) -> list[ToolCallKickoff]:
|
||||
"""Extract ToolCallKickoff objects from the tool call map.
|
||||
|
||||
Returns a list of ToolCallKickoff objects for valid tool calls (those with both id and name).
|
||||
"""
|
||||
tool_calls: list[ToolCallKickoff] = []
|
||||
for tool_call_data in id_to_tool_call_map.values():
|
||||
if tool_call_data.get("id") and tool_call_data.get("name"):
|
||||
try:
|
||||
# Parse arguments JSON string to dict
|
||||
tool_args = (
|
||||
json.loads(tool_call_data["arguments"])
|
||||
if tool_call_data["arguments"]
|
||||
else {}
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# If parsing fails, try empty dict, most tools would fail though
|
||||
logger.error(
|
||||
f"Failed to parse tool call arguments: {tool_call_data['arguments']}"
|
||||
)
|
||||
tool_args = {}
|
||||
|
||||
tool_calls.append(
|
||||
ToolCallKickoff(
|
||||
tool_call_id=tool_call_data["id"],
|
||||
tool_name=tool_call_data["name"],
|
||||
tool_args=tool_args,
|
||||
)
|
||||
)
|
||||
return tool_calls
|
||||
|
||||
|
||||
def run_llm_loop(
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
@@ -790,6 +292,7 @@ def run_llm_loop(
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
forced_tool_id: int | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
) -> None:
|
||||
with trace("run_llm_loop", metadata={"tenant_id": get_current_tenant_id()}):
|
||||
# Fix some LiteLLM issues,
|
||||
@@ -821,7 +324,7 @@ def run_llm_loop(
|
||||
|
||||
# Pass the total budget to construct_message_history, which will handle token allocation
|
||||
available_tokens = llm.config.max_input_tokens
|
||||
tool_choice: ToolChoiceOptions = "auto"
|
||||
tool_choice: ToolChoiceOptions = ToolChoiceOptions.AUTO
|
||||
collected_tool_calls: list[ToolCallInfo] = []
|
||||
# Initialize gathered_documents with project files if present
|
||||
gathered_documents: list[SearchDoc] | None = (
|
||||
@@ -837,6 +340,7 @@ def run_llm_loop(
|
||||
should_cite_documents: bool = False
|
||||
ran_image_gen: bool = False
|
||||
just_ran_web_search: bool = False
|
||||
has_called_search_tool: bool = False
|
||||
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
|
||||
|
||||
current_tool_call_index = (
|
||||
@@ -850,14 +354,14 @@ def run_llm_loop(
|
||||
final_tools = [tool for tool in tools if tool.id == forced_tool_id]
|
||||
if not final_tools:
|
||||
raise ValueError(f"Tool {forced_tool_id} not found in tools")
|
||||
tool_choice = "required"
|
||||
tool_choice = ToolChoiceOptions.REQUIRED
|
||||
forced_tool_id = None
|
||||
elif llm_cycle_count == MAX_LLM_CYCLES - 1 or ran_image_gen:
|
||||
# Last cycle, no tools allowed, just answer!
|
||||
tool_choice = "none"
|
||||
tool_choice = ToolChoiceOptions.NONE
|
||||
final_tools = []
|
||||
else:
|
||||
tool_choice = "auto"
|
||||
tool_choice = ToolChoiceOptions.AUTO
|
||||
final_tools = tools
|
||||
|
||||
# The section below calculates the available tokens for history a bit more accurately
|
||||
@@ -939,13 +443,12 @@ def run_llm_loop(
|
||||
available_tokens=available_tokens,
|
||||
)
|
||||
|
||||
# This calls the LLM, passes in the emitter which can collect packets like reasoning, answers, etc.
|
||||
# This calls the LLM, yields packets (reasoning, answers, etc.) and returns the result
|
||||
# It also pre-processes the tool calls in preparation for running them
|
||||
llm_step_result, current_tool_call_index = run_llm_step(
|
||||
step_generator = run_llm_step(
|
||||
history=truncated_message_history,
|
||||
tool_definitions=[tool.tool_definition() for tool in final_tools],
|
||||
tool_choice=tool_choice,
|
||||
emitter=emitter,
|
||||
llm=llm,
|
||||
turn_index=current_tool_call_index,
|
||||
citation_processor=citation_processor,
|
||||
@@ -954,8 +457,21 @@ def run_llm_loop(
|
||||
# immediately yield the full set of found documents. This gives us the option to show the
|
||||
# final set of documents immediately if desired.
|
||||
final_documents=gathered_documents,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
|
||||
# Consume the generator, emitting packets and capturing the final result
|
||||
while True:
|
||||
try:
|
||||
packet = next(step_generator)
|
||||
emitter.emit(packet)
|
||||
except StopIteration as e:
|
||||
llm_step_result, current_tool_call_index = e.value
|
||||
break
|
||||
|
||||
# Type narrowing: generator always returns a result, so this can't be None
|
||||
llm_step_result = cast(LlmStepResult, llm_step_result)
|
||||
|
||||
# Save citation mapping after each LLM step for incremental state updates
|
||||
state_container.set_citation_mapping(citation_processor.citation_to_doc)
|
||||
|
||||
@@ -976,8 +492,13 @@ def run_llm_loop(
|
||||
user_info=None, # TODO, this is part of memories right now, might want to separate it out
|
||||
citation_mapping=citation_mapping,
|
||||
citation_processor=citation_processor,
|
||||
skip_search_query_expansion=has_called_search_tool,
|
||||
)
|
||||
|
||||
# Track if search tool was called (for skipping query expansion on subsequent calls)
|
||||
if tool_call.tool_name == SearchTool.NAME:
|
||||
has_called_search_tool = True
|
||||
|
||||
# Build a mapping of tool names to tool objects for getting tool_id
|
||||
tools_by_name = {tool.name: tool for tool in final_tools}
|
||||
|
||||
@@ -991,9 +512,6 @@ def run_llm_loop(
|
||||
f"Tool '{tool_call.tool_name}' not found in tools list"
|
||||
)
|
||||
|
||||
# Collect tool call info with reasoning tokens from this LLM step
|
||||
# All tool calls from the same loop iteration share the same reasoning tokens
|
||||
|
||||
# Extract search_docs if this is a search tool response
|
||||
search_docs = None
|
||||
if isinstance(tool_response.rich_response, SearchDocsResponse):
|
||||
@@ -1110,10 +628,6 @@ def run_llm_loop(
|
||||
if not llm_step_result or not llm_step_result.answer:
|
||||
raise RuntimeError("LLM did not return an answer.")
|
||||
|
||||
# Note: All state (answer, reasoning, citations, tool_calls) is saved incrementally
|
||||
# in state_container. The process_message layer will persist to DB.
|
||||
|
||||
# Signal completion
|
||||
emitter.emit(
|
||||
Packet(turn_index=current_tool_call_index, obj=OverallStop(type="stop"))
|
||||
)
|
||||
|
||||
backend/onyx/chat/llm_step.py (new file, 518 lines)
@@ -0,0 +1,518 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.models import AssistantMessage
|
||||
from onyx.llm.models import ChatCompletionMessage
|
||||
from onyx.llm.models import FunctionCall
|
||||
from onyx.llm.models import ImageContentPart
|
||||
from onyx.llm.models import ImageUrlDetail
|
||||
from onyx.llm.models import SystemMessage
|
||||
from onyx.llm.models import TextContentPart
|
||||
from onyx.llm.models import ToolCall
|
||||
from onyx.llm.models import ToolMessage
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDone
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
TOOL_CALL_MSG_FUNC_NAME = "function_name"
|
||||
TOOL_CALL_MSG_ARGUMENTS = "arguments"
|
||||
|
||||
|
||||
def _format_message_history_for_logging(
|
||||
message_history: LanguageModelInput,
|
||||
) -> str:
|
||||
"""Format message history for logging, with special handling for tool calls.
|
||||
|
||||
Tool calls are formatted as JSON with 4-space indentation for readability.
|
||||
"""
|
||||
formatted_lines = []
|
||||
|
||||
separator = "================================================"
|
||||
|
||||
# Handle string input
|
||||
if isinstance(message_history, str):
|
||||
formatted_lines.append("Message [string]:")
|
||||
formatted_lines.append(separator)
|
||||
formatted_lines.append(f"{message_history}")
|
||||
return "\n".join(formatted_lines)
|
||||
|
||||
# Handle sequence of messages
|
||||
for i, msg in enumerate(message_history):
|
||||
if isinstance(msg, SystemMessage):
|
||||
formatted_lines.append(f"Message {i + 1} [system]:")
|
||||
formatted_lines.append(separator)
|
||||
formatted_lines.append(f"{msg.content}")
|
||||
|
||||
elif isinstance(msg, UserMessage):
|
||||
formatted_lines.append(f"Message {i + 1} [user]:")
|
||||
formatted_lines.append(separator)
|
||||
if isinstance(msg.content, str):
|
||||
formatted_lines.append(f"{msg.content}")
|
||||
elif isinstance(msg.content, list):
|
||||
# Handle multimodal content (text + images)
|
||||
for part in msg.content:
|
||||
if isinstance(part, TextContentPart):
|
||||
formatted_lines.append(f"{part.text}")
|
||||
elif isinstance(part, ImageContentPart):
|
||||
url = part.image_url.url
|
||||
formatted_lines.append(f"[Image: {url[:50]}...]")
|
||||
|
||||
elif isinstance(msg, AssistantMessage):
|
||||
formatted_lines.append(f"Message {i + 1} [assistant]:")
|
||||
formatted_lines.append(separator)
|
||||
if msg.content:
|
||||
formatted_lines.append(f"{msg.content}")
|
||||
|
||||
if msg.tool_calls:
|
||||
formatted_lines.append("Tool calls:")
|
||||
for tool_call in msg.tool_calls:
|
||||
tool_call_dict: dict[str, Any] = {
|
||||
"id": tool_call.id,
|
||||
"type": tool_call.type,
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments,
|
||||
},
|
||||
}
|
||||
tool_call_json = json.dumps(tool_call_dict, indent=4)
|
||||
formatted_lines.append(tool_call_json)
|
||||
|
||||
elif isinstance(msg, ToolMessage):
|
||||
formatted_lines.append(f"Message {i + 1} [tool]:")
|
||||
formatted_lines.append(separator)
|
||||
formatted_lines.append(f"Tool call ID: {msg.tool_call_id}")
|
||||
formatted_lines.append(f"Response: {msg.content}")
|
||||
|
||||
else:
|
||||
# Fallback for unknown message types
|
||||
formatted_lines.append(f"Message {i + 1} [unknown]:")
|
||||
formatted_lines.append(separator)
|
||||
formatted_lines.append(f"{msg}")
|
||||
|
||||
# Add separator before the next message (not added after the last one)
|
||||
if i < len(message_history) - 1:
|
||||
formatted_lines.append(separator)
|
||||
|
||||
return "\n".join(formatted_lines)
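# --- Illustrative sketch (not part of this diff): exercising the formatter above on a
# tiny history. The message classes are the same ones this module already imports.
def _demo_history_logging() -> None:
    example_history = [
        SystemMessage(role="system", content="You are a helpful assistant."),
        UserMessage(role="user", content="What changed in the last release?"),
    ]
    print(_format_message_history_for_logging(example_history))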
|
||||
|
||||
|
||||
def _update_tool_call_with_delta(
|
||||
tool_calls_in_progress: dict[int, dict[str, Any]],
|
||||
tool_call_delta: Any,
|
||||
) -> None:
|
||||
index = tool_call_delta.index
|
||||
|
||||
if index not in tool_calls_in_progress:
|
||||
tool_calls_in_progress[index] = {
|
||||
"id": None,
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
}
|
||||
|
||||
if tool_call_delta.id:
|
||||
tool_calls_in_progress[index]["id"] = tool_call_delta.id
|
||||
|
||||
if tool_call_delta.function:
|
||||
if tool_call_delta.function.name:
|
||||
tool_calls_in_progress[index]["name"] = tool_call_delta.function.name
|
||||
|
||||
if tool_call_delta.function.arguments:
|
||||
tool_calls_in_progress[index][
|
||||
"arguments"
|
||||
] += tool_call_delta.function.arguments
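# --- Illustrative sketch (not part of this diff): how partial tool-call deltas are merged
# across stream chunks. _Func and _Delta are hypothetical stand-ins for the stream's delta
# objects; only the .index, .id and .function attributes used above are assumed.
from dataclasses import dataclass


@dataclass
class _Func:
    name: str | None
    arguments: str | None


@dataclass
class _Delta:
    index: int
    id: str | None
    function: _Func | None


def _demo_tool_call_accumulation() -> None:
    in_progress: dict[int, dict[str, Any]] = {}
    first = _Delta(index=0, id="call_1", function=_Func("search", '{"que'))
    second = _Delta(index=0, id=None, function=_Func(None, 'ry": "docs"}'))
    _update_tool_call_with_delta(in_progress, first)
    _update_tool_call_with_delta(in_progress, second)
    # in_progress[0] == {"id": "call_1", "name": "search", "arguments": '{"query": "docs"}'}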
|
||||
|
||||
|
||||
def _extract_tool_call_kickoffs(
|
||||
id_to_tool_call_map: dict[int, dict[str, Any]],
|
||||
) -> list[ToolCallKickoff]:
|
||||
"""Extract ToolCallKickoff objects from the tool call map.
|
||||
|
||||
Returns a list of ToolCallKickoff objects for valid tool calls (those with both id and name).
|
||||
"""
|
||||
tool_calls: list[ToolCallKickoff] = []
|
||||
for tool_call_data in id_to_tool_call_map.values():
|
||||
if tool_call_data.get("id") and tool_call_data.get("name"):
|
||||
try:
|
||||
# Parse arguments JSON string to dict
|
||||
tool_args = (
|
||||
json.loads(tool_call_data["arguments"])
|
||||
if tool_call_data["arguments"]
|
||||
else {}
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# If parsing fails, fall back to an empty dict; most tools will still fail without valid arguments
|
||||
logger.error(
|
||||
f"Failed to parse tool call arguments: {tool_call_data['arguments']}"
|
||||
)
|
||||
tool_args = {}
|
||||
|
||||
tool_calls.append(
|
||||
ToolCallKickoff(
|
||||
tool_call_id=tool_call_data["id"],
|
||||
tool_name=tool_call_data["name"],
|
||||
tool_args=tool_args,
|
||||
)
|
||||
)
|
||||
return tool_calls
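# --- Illustrative sketch (not part of this diff): turning the accumulated map into kickoffs.
# Entries missing an id or name are skipped; unparseable arguments degrade to an empty dict.
def _demo_kickoff_extraction() -> None:
    raw = {
        0: {"id": "call_1", "name": "search", "arguments": '{"query": "docs"}'},
        1: {"id": None, "name": "orphan", "arguments": ""},  # skipped: no id yet
    }
    kickoffs = _extract_tool_call_kickoffs(raw)
    assert len(kickoffs) == 1
    assert kickoffs[0].tool_args == {"query": "docs"}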
|
||||
|
||||
|
||||
def translate_history_to_llm_format(
|
||||
history: list[ChatMessageSimple],
|
||||
) -> LanguageModelInput:
|
||||
"""Convert a list of ChatMessageSimple to LanguageModelInput format.
|
||||
|
||||
Converts ChatMessageSimple messages to ChatCompletionMessage format,
|
||||
handling different message types and image files for multimodal support.
|
||||
"""
|
||||
messages: list[ChatCompletionMessage] = []
|
||||
|
||||
for msg in history:
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
system_msg = SystemMessage(
|
||||
role="system",
|
||||
content=msg.message,
|
||||
)
|
||||
messages.append(system_msg)
|
||||
|
||||
elif msg.message_type == MessageType.USER:
|
||||
# Handle user messages with potential images
|
||||
if msg.image_files:
|
||||
# Build content parts: text + images
|
||||
content_parts: list[TextContentPart | ImageContentPart] = [
|
||||
TextContentPart(
|
||||
type="text",
|
||||
text=msg.message,
|
||||
)
|
||||
]
|
||||
|
||||
# Add image parts
|
||||
for img_file in msg.image_files:
|
||||
if img_file.file_type == ChatFileType.IMAGE:
|
||||
try:
|
||||
image_type = get_image_type_from_bytes(img_file.content)
|
||||
base64_data = img_file.to_base64()
|
||||
image_url = f"data:{image_type};base64,{base64_data}"
|
||||
|
||||
image_part = ImageContentPart(
|
||||
type="image_url",
|
||||
image_url=ImageUrlDetail(
|
||||
url=image_url,
|
||||
detail=None,
|
||||
),
|
||||
)
|
||||
content_parts.append(image_part)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to process image file {img_file.file_id}: {e}. "
|
||||
"Skipping image."
|
||||
)
|
||||
user_msg = UserMessage(
|
||||
role="user",
|
||||
content=content_parts,
|
||||
)
|
||||
messages.append(user_msg)
|
||||
else:
|
||||
# Simple text-only user message
|
||||
user_msg_text = UserMessage(
|
||||
role="user",
|
||||
content=msg.message,
|
||||
)
|
||||
messages.append(user_msg_text)
|
||||
|
||||
elif msg.message_type == MessageType.ASSISTANT:
|
||||
assistant_msg = AssistantMessage(
|
||||
role="assistant",
|
||||
content=msg.message or None,
|
||||
tool_calls=None,
|
||||
)
|
||||
messages.append(assistant_msg)
|
||||
|
||||
elif msg.message_type == MessageType.TOOL_CALL:
|
||||
# Tool calls are represented as Assistant Messages with tool_calls field
|
||||
# Try to reconstruct tool call structure if we have tool_call_id
|
||||
tool_calls: list[ToolCall] = []
|
||||
if msg.tool_call_id:
|
||||
try:
|
||||
# Parse the message content (which should contain function_name and arguments)
|
||||
tool_call_data = json.loads(msg.message) if msg.message else {}
|
||||
|
||||
if (
|
||||
isinstance(tool_call_data, dict)
|
||||
and TOOL_CALL_MSG_FUNC_NAME in tool_call_data
|
||||
):
|
||||
function_name = tool_call_data.get(
|
||||
TOOL_CALL_MSG_FUNC_NAME, "unknown"
|
||||
)
|
||||
tool_args = tool_call_data.get(TOOL_CALL_MSG_ARGUMENTS, {})
|
||||
else:
|
||||
function_name = "unknown"
|
||||
tool_args = (
|
||||
tool_call_data if isinstance(tool_call_data, dict) else {}
|
||||
)
|
||||
|
||||
# NOTE: if the model was trained on a different tool-call format, reconstructing it this way
# may slightly influence future tool calls. Almost certainly not a big deal.
|
||||
tool_call = ToolCall(
|
||||
id=msg.tool_call_id,
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_name,
|
||||
arguments=json.dumps(tool_args) if tool_args else "{}",
|
||||
),
|
||||
)
|
||||
tool_calls.append(tool_call)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.warning(
|
||||
f"Failed to parse tool call data for tool_call_id {msg.tool_call_id}: {e}. "
|
||||
"Including as content-only message."
|
||||
)
|
||||
|
||||
assistant_msg_with_tool = AssistantMessage(
|
||||
role="assistant",
|
||||
content=None, # The tool call is parsed, doesn't need to be duplicated in the content
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
)
|
||||
messages.append(assistant_msg_with_tool)
|
||||
|
||||
elif msg.message_type == MessageType.TOOL_CALL_RESPONSE:
|
||||
if not msg.tool_call_id:
|
||||
raise ValueError(
|
||||
f"Tool call response message encountered but tool_call_id is not available. Message: {msg}"
|
||||
)
|
||||
|
||||
tool_msg = ToolMessage(
|
||||
role="tool",
|
||||
content=msg.message,
|
||||
tool_call_id=msg.tool_call_id,
|
||||
)
|
||||
messages.append(tool_msg)
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unknown message type {msg.message_type} in history. Skipping message."
|
||||
)
|
||||
|
||||
return messages
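# --- Illustrative sketch (not part of this diff): converting a minimal history. It assumes
# ChatMessageSimple's other fields (image_files, tool_call_id, ...) default to empty/None.
def _demo_history_translation() -> None:
    history = [
        ChatMessageSimple(message="You are helpful.", message_type=MessageType.SYSTEM),
        ChatMessageSimple(message="Hi there", message_type=MessageType.USER),
        ChatMessageSimple(message="Hello!", message_type=MessageType.ASSISTANT),
    ]
    llm_input = translate_history_to_llm_format(history)
    # llm_input -> [SystemMessage(...), UserMessage(...), AssistantMessage(...)]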
|
||||
|
||||
|
||||
def run_llm_step(
|
||||
history: list[ChatMessageSimple],
|
||||
tool_definitions: list[dict],
|
||||
tool_choice: ToolChoiceOptions,
|
||||
llm: LLM,
|
||||
turn_index: int,
|
||||
citation_processor: DynamicCitationProcessor,
|
||||
state_container: ChatStateContainer,
|
||||
final_documents: list[SearchDoc] | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
) -> Generator[Packet, None, tuple[LlmStepResult, int]]:
|
||||
# The second return value is the turn index, since reasoning counts as its own turn on the frontend.
# TODO: this works, but it does not align cleanly with the backend's notion of a turn.
|
||||
llm_msg_history = translate_history_to_llm_format(history)
|
||||
|
||||
# Set LOG_ONYX_MODEL_INTERACTIONS to log the entire message history to the console
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
logger.info(
|
||||
f"Message history:\n{_format_message_history_for_logging(llm_msg_history)}"
|
||||
)
|
||||
|
||||
id_to_tool_call_map: dict[int, dict[str, Any]] = {}
|
||||
reasoning_start = False
|
||||
answer_start = False
|
||||
accumulated_reasoning = ""
|
||||
accumulated_answer = ""
|
||||
|
||||
with generation_span(
|
||||
model=llm.config.model_name,
|
||||
model_config={
|
||||
"base_url": str(llm.config.api_base or ""),
|
||||
"model_impl": "litellm",
|
||||
},
|
||||
) as span_generation:
|
||||
span_generation.span_data.input = cast(
|
||||
Sequence[Mapping[str, Any]], llm_msg_history
|
||||
)
|
||||
for packet in llm.stream(
|
||||
prompt=llm_msg_history,
|
||||
tools=tool_definitions,
|
||||
tool_choice=tool_choice,
|
||||
structured_response_format=None, # TODO
|
||||
# reasoning_effort=ReasoningEffort.OFF, # Can set this for dev/testing.
|
||||
user_identity=user_identity,
|
||||
):
|
||||
if packet.usage:
|
||||
usage = packet.usage
|
||||
span_generation.span_data.usage = {
|
||||
"input_tokens": usage.prompt_tokens,
|
||||
"output_tokens": usage.completion_tokens,
|
||||
"cache_read_input_tokens": usage.cache_read_input_tokens,
|
||||
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
|
||||
}
|
||||
delta = packet.choice.delta
|
||||
|
||||
# Should only happen once; the frontend does not expect multiple
# ReasoningStart or ReasoningDone packets.
|
||||
if delta.reasoning_content:
|
||||
accumulated_reasoning += delta.reasoning_content
|
||||
# Save reasoning incrementally to state container
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if not reasoning_start:
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningDelta(reasoning=delta.reasoning_content),
|
||||
)
|
||||
reasoning_start = True
|
||||
|
||||
if delta.content:
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
turn_index += 1
|
||||
reasoning_start = False
|
||||
|
||||
if not answer_start:
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentResponseStart(
|
||||
final_documents=final_documents,
|
||||
),
|
||||
)
|
||||
answer_start = True
|
||||
|
||||
for result in citation_processor.process_token(delta.content):
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
# Save answer incrementally to state container
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
obj=result,
|
||||
)
|
||||
|
||||
if delta.tool_calls:
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
turn_index += 1
|
||||
reasoning_start = False
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
|
||||
tool_calls = _extract_tool_call_kickoffs(id_to_tool_call_map)
|
||||
if tool_calls:
|
||||
tool_calls_list: list[ToolCall] = [
|
||||
ToolCall(
|
||||
id=kickoff.tool_call_id,
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=kickoff.tool_name,
|
||||
arguments=json.dumps(kickoff.tool_args),
|
||||
),
|
||||
)
|
||||
for kickoff in tool_calls
|
||||
]
|
||||
|
||||
assistant_msg: AssistantMessage = AssistantMessage(
|
||||
role="assistant",
|
||||
content=accumulated_answer if accumulated_answer else None,
|
||||
tool_calls=tool_calls_list,
|
||||
)
|
||||
span_generation.span_data.output = [assistant_msg.model_dump()]
|
||||
elif accumulated_answer:
|
||||
assistant_msg_no_tools = AssistantMessage(
|
||||
role="assistant",
|
||||
content=accumulated_answer,
|
||||
tool_calls=None,
|
||||
)
|
||||
span_generation.span_data.output = [assistant_msg_no_tools.model_dump()]
|
||||
# Close reasoning block if still open (stream ended with reasoning content)
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
turn_index += 1
|
||||
|
||||
# Flush any remaining content from citation processor
|
||||
if citation_processor:
|
||||
for result in citation_processor.process_token(None):
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
# Save answer incrementally to state container
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
turn_index=turn_index,
|
||||
obj=result,
|
||||
)
|
||||
|
||||
# Note: Content (AgentResponseDelta) doesn't need an explicit end packet - OverallStop handles it
|
||||
# Tool calls are handled by tool execution code and emit their own packets (e.g., SectionEnd)
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
logger.debug(f"Accumulated reasoning: {accumulated_reasoning}")
|
||||
logger.debug(f"Accumulated answer: {accumulated_answer}")
|
||||
|
||||
if tool_calls:
|
||||
tool_calls_str = "\n".join(
|
||||
f" - {tc.tool_name}: {json.dumps(tc.tool_args, indent=4)}"
|
||||
for tc in tool_calls
|
||||
)
|
||||
logger.debug(f"Tool calls:\n{tool_calls_str}")
|
||||
else:
|
||||
logger.debug("Tool calls: []")
|
||||
|
||||
return (
|
||||
LlmStepResult(
|
||||
reasoning=accumulated_reasoning if accumulated_reasoning else None,
|
||||
answer=accumulated_answer if accumulated_answer else None,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
),
|
||||
turn_index,
|
||||
)
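# --- Illustrative sketch (not part of this diff): run_llm_step is a generator that yields
# Packet objects and *returns* (LlmStepResult, turn_index). One way a caller can both
# forward the packets and capture that return value:
def _drive_llm_step(
    step: Generator[Packet, None, tuple[LlmStepResult, int]],
    forwarded: list[Packet],
) -> tuple[LlmStepResult, int]:
    while True:
        try:
            forwarded.append(next(step))  # in real code: emit the packet to the client
        except StopIteration as stop:
            return stop.value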
|
||||
@@ -5,10 +5,12 @@ from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.enums import QueryFlow
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
@@ -100,6 +102,11 @@ class MessageResponseIDInfo(BaseModel):
|
||||
class StreamingError(BaseModel):
|
||||
error: str
|
||||
stack_trace: str | None = None
|
||||
error_code: str | None = (
|
||||
None # e.g., "RATE_LIMIT", "AUTH_ERROR", "TOOL_CALL_FAILED"
|
||||
)
|
||||
is_retryable: bool = True # Hint to frontend if retry might help
|
||||
details: dict | None = None # Additional context (tool name, model name, etc.)
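# Illustrative sketch (not part of this diff): the richer StreamingError can now carry a
# machine-readable code plus a retry hint, e.g.
#   StreamingError(
#       error="Rate limit exceeded",
#       error_code="RATE_LIMIT",
#       is_retryable=True,
#       details={"model": "gpt-4o", "provider": "openai"},
#   )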
|
||||
|
||||
|
||||
class OnyxAnswer(BaseModel):
|
||||
@@ -125,6 +132,35 @@ class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class PromptOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
datetime_aware: bool = True
|
||||
include_citations: bool = True
|
||||
|
||||
|
||||
class PersonaOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
|
||||
# Note: prompt_ids removed - prompts are now embedded in personas
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
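# Illustrative sketch (not part of this diff): a minimal override persona with one embedded
# prompt, relying on the defaults above for everything else (the ids below are hypothetical).
#   PersonaOverrideConfig(
#       name="Support Bot",
#       description="Answers from the support knowledge base",
#       prompts=[PromptOverrideConfig(name="default", system_prompt="Be concise.")],
#       document_set_ids=[1],
#   )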
|
||||
|
||||
|
||||
AnswerQuestionPossibleReturn = (
|
||||
OnyxAnswerPiece
|
||||
| CitationInfo
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
@@ -12,6 +13,7 @@ from onyx.chat.chat_state import run_chat_llm_with_state_containers
|
||||
from onyx.chat.chat_utils import convert_chat_history
|
||||
from onyx.chat.chat_utils import create_chat_history_chain
|
||||
from onyx.chat.chat_utils import get_custom_agent_prompt
|
||||
from onyx.chat.chat_utils import is_last_assistant_message_clarification
|
||||
from onyx.chat.chat_utils import load_all_chat_files
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.llm_loop import run_llm_loop
|
||||
@@ -26,6 +28,8 @@ from onyx.chat.prompt_utils import calculate_reserved_tokens
|
||||
from onyx.chat.save_chat import save_chat_turn
|
||||
from onyx.chat.stop_signal_checker import is_connected as check_stop_signal
|
||||
from onyx.chat.stop_signal_checker import reset_cancel_status
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import CitationDocInfo
|
||||
@@ -42,6 +46,7 @@ from onyx.db.models import User
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.utils import load_in_memory_chat_files
|
||||
@@ -49,6 +54,7 @@ from onyx.file_store.utils import verify_user_files
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -58,10 +64,12 @@ from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.tools.constants import SEARCH_TOOL_ID
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
from onyx.tools.tool_constructor import CustomToolConfig
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.tools.tool_constructor import SearchToolUsage
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.timing import log_function_time
|
||||
@@ -75,6 +83,10 @@ ERROR_TYPE_CANCELLED = "cancelled"
|
||||
class ToolCallException(Exception):
|
||||
"""Exception raised for errors during tool calls."""
|
||||
|
||||
def __init__(self, message: str, tool_name: str | None = None):
|
||||
super().__init__(message)
|
||||
self.tool_name = tool_name
|
||||
|
||||
|
||||
def _extract_project_file_texts_and_images(
|
||||
project_id: int | None,
|
||||
@@ -202,6 +214,46 @@ def _extract_project_file_texts_and_images(
|
||||
)
|
||||
|
||||
|
||||
def _get_project_search_availability(
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
has_project_file_texts: bool,
|
||||
forced_tool_ids: list[int] | None,
|
||||
search_tool_id: int | None,
|
||||
) -> SearchToolUsage:
|
||||
"""Determine search tool availability based on project context.
|
||||
|
||||
Args:
|
||||
project_id: The project ID if the user is in a project
|
||||
persona_id: The persona ID to check if it's the default persona
|
||||
has_project_file_texts: Whether project files are loaded in context
|
||||
forced_tool_ids: List of forced tool IDs (may be mutated to remove search tool)
|
||||
search_tool_id: The search tool ID to check against
|
||||
|
||||
Returns:
|
||||
SearchToolUsage setting indicating how search should be used
|
||||
"""
|
||||
# There are cases where the internal search tool should be disabled.
# If the user is in a project, it should not use other sources / generic search.
# If they are in a project but using a custom agent, it should use the agent setup
# (which means it can use search).
# However, if in a project and there are more files than can fit in the context,
# it should use the search tool with the project filter on.
# If no files are uploaded, search should remain enabled.
search_usage_forcing_setting = SearchToolUsage.AUTO
if project_id:
if persona_id == DEFAULT_PERSONA_ID and has_project_file_texts:
search_usage_forcing_setting = SearchToolUsage.DISABLED
|
||||
# Remove search tool from forced_tool_ids if it's present
|
||||
if forced_tool_ids and search_tool_id and search_tool_id in forced_tool_ids:
|
||||
forced_tool_ids[:] = [
|
||||
tool_id for tool_id in forced_tool_ids if tool_id != search_tool_id
|
||||
]
|
||||
elif forced_tool_ids and search_tool_id and search_tool_id in forced_tool_ids:
|
||||
search_usage_forcing_setting = SearchToolUsage.ENABLED
|
||||
return search_usage_forcing_setting
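# Illustrative sketch (not part of this diff): in a project with in-context files and the
# default persona, search is disabled and a forced search-tool id is dropped in place
# (all ids below are made up).
#   forced = [7]  # pretend 7 is the search tool's id
#   setting = _get_project_search_availability(
#       project_id=42,
#       persona_id=DEFAULT_PERSONA_ID,
#       has_project_file_texts=True,
#       forced_tool_ids=forced,
#       search_tool_id=7,
#   )
#   # setting == SearchToolUsage.DISABLED and forced == []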
|
||||
|
||||
|
||||
def _initialize_chat_session(
|
||||
message_text: str,
|
||||
files: list[FileDescriptor],
|
||||
@@ -258,10 +310,17 @@ def stream_chat_message_objects(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
# Needed to translate persona num_chunks to tokens to the LLM
|
||||
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
# For flows with search, don't include as many chunks as possible, since we need to leave space
# for the chat history; for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks.
|
||||
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
|
||||
# if specified, uses the last user message and does not create a new user message based
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
enforce_chat_session_id_for_search_docs: bool = True,
|
||||
bypass_acl: bool = False,
|
||||
# Additional context that should be included in the chat history, for example:
|
||||
# Slack threads where the conversation cannot be represented by a chain of User/Assistant
|
||||
@@ -274,10 +333,15 @@ def stream_chat_message_objects(
|
||||
tenant_id = get_current_tenant_id()
|
||||
use_existing_user_message = new_msg_req.use_existing_user_message
|
||||
|
||||
llm: LLM
|
||||
llm: LLM | None = None
|
||||
|
||||
try:
|
||||
user_id = user.id if user is not None else None
|
||||
llm_user_identifier = (
|
||||
user.email
|
||||
if user is not None and getattr(user, "email", None)
|
||||
else (str(user_id) if user_id else "anonymous_user")
|
||||
)
|
||||
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
@@ -288,8 +352,14 @@ def stream_chat_message_objects(
|
||||
|
||||
message_text = new_msg_req.message
|
||||
chat_session_id = new_msg_req.chat_session_id
|
||||
user_identity = LLMUserIdentity(
|
||||
user_id=llm_user_identifier, session_id=str(chat_session_id)
|
||||
)
|
||||
parent_id = new_msg_req.parent_message_id
|
||||
user_selected_filters = new_msg_req.filters
|
||||
reference_doc_ids = new_msg_req.search_doc_ids
|
||||
retrieval_options = new_msg_req.retrieval_options
|
||||
new_msg_req.alternate_assistant_id
|
||||
user_selected_filters = retrieval_options.filters if retrieval_options else None
|
||||
|
||||
# permanent "log" store, used primarily for debugging
|
||||
long_term_logger = LongTermLogger(
|
||||
@@ -304,6 +374,11 @@ def stream_chat_message_objects(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
if reference_doc_ids is None and retrieval_options is None:
|
||||
raise RuntimeError(
|
||||
"Must specify a set of documents for chat or specify search options"
|
||||
)
|
||||
|
||||
llm, fast_llm = get_llms_for_persona(
|
||||
persona=persona,
|
||||
user=user,
|
||||
@@ -372,19 +447,23 @@ def stream_chat_message_objects(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# There are cases where the internal search tool should be disabled
|
||||
# If the user is in a project, it should not use other sources / generic search
|
||||
# If they are in a project but using a custom agent, it should use the agent setup
|
||||
# (which means it can use search)
|
||||
# However if in a project and there are more files than can fit in the context,
|
||||
# it should use the search tool with the project filter on
|
||||
disable_internal_search = bool(
|
||||
chat_session.project_id
|
||||
and persona.id is DEFAULT_PERSONA_ID
|
||||
and (
|
||||
extracted_project_files.project_file_texts
|
||||
or not extracted_project_files.project_as_filter
|
||||
)
|
||||
# Build a mapping of tool_id to tool_name for history reconstruction
|
||||
all_tools = get_tools(db_session)
|
||||
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
|
||||
|
||||
search_tool_id = next(
|
||||
(tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID),
|
||||
None,
|
||||
)
|
||||
|
||||
# This may also mutate the new_msg_req.forced_tool_ids
|
||||
# This logic is specifically for projects
|
||||
search_usage_forcing_setting = _get_project_search_availability(
|
||||
project_id=chat_session.project_id,
|
||||
persona_id=persona.id,
|
||||
has_project_file_texts=bool(extracted_project_files.project_file_texts),
|
||||
forced_tool_ids=new_msg_req.forced_tool_ids,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
|
||||
emitter = get_default_emitter()
|
||||
@@ -413,7 +492,7 @@ def stream_chat_message_objects(
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
disable_internal_search=disable_internal_search,
|
||||
search_usage_forcing_setting=search_usage_forcing_setting,
|
||||
)
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
@@ -438,10 +517,6 @@ def stream_chat_message_objects(
|
||||
reserved_assistant_message_id=assistant_response.id,
|
||||
)
|
||||
|
||||
# Build a mapping of tool_id to tool_name for history reconstruction
|
||||
all_tools = get_tools(db_session)
|
||||
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
|
||||
|
||||
# Convert the chat history into a simple format that is free of any DB objects
|
||||
# and is easy to parse for the agent loop
|
||||
simple_chat_history = convert_chat_history(
|
||||
@@ -471,24 +546,50 @@ def stream_chat_message_objects(
|
||||
# for stop signals. run_llm_loop itself doesn't know about stopping.
|
||||
# Note: DB session is not thread safe but nothing else uses it and the
|
||||
# reference is passed directly so it's ok.
|
||||
yield from run_chat_llm_with_state_containers(
|
||||
run_llm_loop,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
is_connected=check_is_connected, # Not passed through to run_llm_loop
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
project_files=extracted_project_files,
|
||||
persona=persona,
|
||||
memories=memories,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=(
|
||||
new_msg_req.forced_tool_ids[0] if new_msg_req.forced_tool_ids else None
|
||||
),
|
||||
)
|
||||
if os.environ.get("ENABLE_DEEP_RESEARCH_LOOP"): # Dev only feature flag for now
|
||||
if chat_session.project_id:
|
||||
raise RuntimeError("Deep research is not supported for projects")
|
||||
|
||||
# Skip clarification if the last assistant message was a clarification
|
||||
# (user has already responded to a clarification question)
|
||||
skip_clarification = is_last_assistant_message_clarification(chat_history)
|
||||
|
||||
yield from run_chat_llm_with_state_containers(
|
||||
run_deep_research_llm_loop,
|
||||
is_connected=check_is_connected,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
skip_clarification=skip_clarification,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
else:
|
||||
yield from run_chat_llm_with_state_containers(
|
||||
run_llm_loop,
|
||||
is_connected=check_is_connected, # Not passed through to run_llm_loop
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
simple_chat_history=simple_chat_history,
|
||||
tools=tools,
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
project_files=extracted_project_files,
|
||||
persona=persona,
|
||||
memories=memories,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
forced_tool_id=(
|
||||
new_msg_req.forced_tool_ids[0]
|
||||
if new_msg_req.forced_tool_ids
|
||||
else None
|
||||
),
|
||||
user_identity=user_identity,
|
||||
)
|
||||
|
||||
# Determine if stopped by user
|
||||
completed_normally = check_is_connected()
|
||||
@@ -532,13 +633,18 @@ def stream_chat_message_objects(
|
||||
tool_calls=state_container.tool_calls,
|
||||
db_session=db_session,
|
||||
assistant_message=assistant_response,
|
||||
is_clarification=state_container.is_clarification,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
error_msg = str(e)
|
||||
yield StreamingError(error=error_msg)
|
||||
yield StreamingError(
|
||||
error=error_msg,
|
||||
error_code="VALIDATION_ERROR",
|
||||
is_retryable=True,
|
||||
)
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
@@ -548,9 +654,17 @@ def stream_chat_message_objects(
|
||||
stack_trace = traceback.format_exc()
|
||||
|
||||
if isinstance(e, ToolCallException):
|
||||
yield StreamingError(error=error_msg, stack_trace=stack_trace)
|
||||
yield StreamingError(
|
||||
error=error_msg,
|
||||
stack_trace=stack_trace,
|
||||
error_code="TOOL_CALL_FAILED",
|
||||
is_retryable=True,
|
||||
details={"tool_name": e.tool_name} if e.tool_name else None,
|
||||
)
|
||||
elif llm:
|
||||
client_error_msg = litellm_exception_to_error_msg(e, llm)
|
||||
client_error_msg, error_code, is_retryable = litellm_exception_to_error_msg(
|
||||
e, llm
|
||||
)
|
||||
if llm.config.api_key and len(llm.config.api_key) > 2:
|
||||
client_error_msg = client_error_msg.replace(
|
||||
llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
@@ -559,7 +673,24 @@ def stream_chat_message_objects(
|
||||
llm.config.api_key, "[REDACTED_API_KEY]"
|
||||
)
|
||||
|
||||
yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
|
||||
yield StreamingError(
|
||||
error=client_error_msg,
|
||||
stack_trace=stack_trace,
|
||||
error_code=error_code,
|
||||
is_retryable=is_retryable,
|
||||
details={
|
||||
"model": llm.config.model_name,
|
||||
"provider": llm.config.model_provider,
|
||||
},
|
||||
)
|
||||
else:
|
||||
# LLM was never initialized - early failure
|
||||
yield StreamingError(
|
||||
error="Failed to initialize the chat. Please check your configuration and try again.",
|
||||
stack_trace=stack_trace,
|
||||
error_code="INIT_FAILED",
|
||||
is_retryable=True,
|
||||
)
|
||||
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
@@ -10,17 +10,18 @@ from onyx.file_store.models import FileDescriptor
|
||||
from onyx.prompts.chat_prompts import CITATION_REMINDER
|
||||
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
|
||||
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
|
||||
from onyx.prompts.chat_prompts import GENERATE_IMAGE_GUIDANCE
|
||||
from onyx.prompts.chat_prompts import INTERNAL_SEARCH_GUIDANCE
|
||||
from onyx.prompts.chat_prompts import OPEN_URLS_GUIDANCE
|
||||
from onyx.prompts.chat_prompts import PYTHON_TOOL_GUIDANCE
|
||||
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
|
||||
from onyx.prompts.chat_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
from onyx.prompts.chat_prompts import TOOL_SECTION_HEADER
|
||||
from onyx.prompts.chat_prompts import USER_INFO_HEADER
|
||||
from onyx.prompts.chat_prompts import WEB_SEARCH_GUIDANCE
|
||||
from onyx.prompts.prompt_utils import get_company_context
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
|
||||
from onyx.prompts.tool_prompts import GENERATE_IMAGE_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import OPEN_URLS_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import PYTHON_TOOL_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
@@ -141,20 +142,12 @@ def build_system_prompt(
|
||||
if open_ai_formatting_enabled:
|
||||
system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
|
||||
|
||||
try:
|
||||
citation_guidance = (
|
||||
REQUIRE_CITATION_GUIDANCE
|
||||
if should_cite_documents or include_all_guidance
|
||||
else ""
|
||||
)
|
||||
system_prompt = system_prompt.format(
|
||||
citation_reminder_or_empty=citation_guidance
|
||||
)
|
||||
except Exception:
|
||||
# Even if the prompt is modified and there is not an explicit spot for citations, always require it
|
||||
# This is more a product decision as it's likely better to always enforce citations
|
||||
if should_cite_documents or include_all_guidance:
|
||||
system_prompt += REQUIRE_CITATION_GUIDANCE
|
||||
# Replace citation guidance placeholder if present
|
||||
system_prompt, should_append_citation_guidance = replace_citation_guidance_tag(
|
||||
system_prompt,
|
||||
should_cite_documents=should_cite_documents,
|
||||
include_all_guidance=include_all_guidance,
|
||||
)
|
||||
|
||||
company_context = get_company_context()
|
||||
if company_context or memories:
|
||||
@@ -166,7 +159,9 @@ def build_system_prompt(
|
||||
memory.strip() for memory in memories if memory.strip()
|
||||
)
|
||||
|
||||
if should_cite_documents or include_all_guidance:
|
||||
# Append citation guidance after company context if placeholder was not present
|
||||
# This maintains backward compatibility and ensures citations are always enforced when needed
|
||||
if should_append_citation_guidance:
|
||||
system_prompt += REQUIRE_CITATION_GUIDANCE
|
||||
|
||||
if include_all_guidance:
|
||||
|
||||
@@ -148,6 +148,7 @@ def save_chat_turn(
|
||||
citation_docs_info: list[CitationDocInfo],
|
||||
db_session: Session,
|
||||
assistant_message: ChatMessage,
|
||||
is_clarification: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Save a chat turn by populating the assistant_message and creating related entities.
|
||||
@@ -175,10 +176,12 @@ def save_chat_turn(
|
||||
citation_docs_info: List of citation document information for building citations mapping
|
||||
db_session: Database session for persistence
|
||||
assistant_message: The ChatMessage object to populate (should already exist in DB)
|
||||
is_clarification: Whether this assistant message is a clarification question (deep research flow)
|
||||
"""
|
||||
# 1. Update ChatMessage with message content, reasoning tokens, and token count
|
||||
assistant_message.message = message_text
|
||||
assistant_message.reasoning_tokens = reasoning_tokens
|
||||
assistant_message.is_clarification = is_clarification
|
||||
|
||||
# Calculate token count using the default tokenizer; when storing, this should not use the
# LLM-specific one, so we use a system default tokenizer here.
|
||||
|
||||
@@ -7,6 +7,7 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
# Redis key prefixes for chat session stop signals
|
||||
PREFIX = "chatsessionstop"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 24 * 60 * 60 # 24 hours - defensive TTL to prevent memory leaks
|
||||
|
||||
|
||||
def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
|
||||
@@ -24,7 +25,7 @@ def set_fence(chat_session_id: UUID, redis_client: Redis, value: bool) -> None:
|
||||
redis_client.delete(fence_key)
|
||||
return
|
||||
|
||||
redis_client.set(fence_key, 0)
|
||||
redis_client.set(fence_key, 0, ex=FENCE_TTL)
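# Illustrative sketch (not part of this diff): with the defensive TTL, a fence that is set
# but never explicitly cleared now expires on its own after FENCE_TTL (24 hours), e.g.
#   set_fence(chat_session_id, redis_client, value=True)   # key written with ex=FENCE_TTL
#   set_fence(chat_session_id, redis_client, value=False)  # presumably hits the delete branch above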
|
||||
|
||||
|
||||
def is_connected(chat_session_id: UUID, redis_client: Redis) -> bool:
|
||||
|
||||
@@ -24,6 +24,12 @@ APP_PORT = 8080
|
||||
# prefix from requests directed towards the API server. In these cases, set this to `/api`
|
||||
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
|
||||
|
||||
# Whether to send user metadata (user_id/email and session_id) to the LLM provider.
|
||||
# Disabled by default.
|
||||
SEND_USER_METADATA_TO_LLM_PROVIDER = (
|
||||
os.environ.get("SEND_USER_METADATA_TO_LLM_PROVIDER", "")
|
||||
).lower() == "true"
|
||||
|
||||
#####
|
||||
# User Facing Features Configs
|
||||
#####
|
||||
@@ -31,7 +37,6 @@ BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
|
||||
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
os.environ.get("GENERATIVE_MODEL_ACCESS_CHECK_FREQ") or 86400
|
||||
) # 1 day
|
||||
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
|
||||
|
||||
# Controls whether users can use User Knowledge (personal documents) in assistants
|
||||
DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() == "true"
|
||||
|
||||
@@ -177,6 +177,7 @@ class DocumentSource(str, Enum):
|
||||
SLAB = "slab"
|
||||
PRODUCTBOARD = "productboard"
|
||||
FILE = "file"
|
||||
CODA = "coda"
|
||||
NOTION = "notion"
|
||||
ZULIP = "zulip"
|
||||
LINEAR = "linear"
|
||||
@@ -596,6 +597,7 @@ DocumentSourceDescription: dict[DocumentSource, str] = {
|
||||
DocumentSource.SLAB: "slab data",
|
||||
DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)",
|
||||
DocumentSource.FILE: "files",
|
||||
DocumentSource.CODA: "coda - team workspace with docs, tables, and pages",
|
||||
DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \
|
||||
project management, and collaboration tools into a single, customizable platform",
|
||||
DocumentSource.ZULIP: "zulip data",
|
||||
|
||||
@@ -65,9 +65,10 @@ GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
|
||||
os.environ.get("GEN_AI_NUM_RESERVED_OUTPUT_TOKENS") or 1024
|
||||
)
|
||||
|
||||
# Typically, GenAI models nowadays are at least 4K tokens
|
||||
# Fallback token limit for models where the max context is unknown
|
||||
# Set conservatively at 32K to handle most modern models
|
||||
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
|
||||
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
|
||||
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 32000
|
||||
)
|
||||
|
||||
# This is used when computing how much context space is available for documents
|
||||
|
||||
@@ -97,28 +97,31 @@ class AsanaAPI:
|
||||
self, project_gid: str, start_date: str, start_seconds: int
|
||||
) -> Iterator[AsanaTask]:
|
||||
project = self.project_api.get_project(project_gid, opts={})
|
||||
if project["archived"]:
|
||||
logger.info(f"Skipping archived project: {project['name']} ({project_gid})")
|
||||
yield from []
|
||||
if not project["team"] or not project["team"]["gid"]:
|
||||
project_name = project.get("name", project_gid)
|
||||
team = project.get("team") or {}
|
||||
team_gid = team.get("gid")
|
||||
|
||||
if project.get("archived"):
|
||||
logger.info(f"Skipping archived project: {project_name} ({project_gid})")
|
||||
return
|
||||
if not team_gid:
|
||||
logger.info(
|
||||
f"Skipping project without a team: {project['name']} ({project_gid})"
|
||||
f"Skipping project without a team: {project_name} ({project_gid})"
|
||||
)
|
||||
yield from []
|
||||
if project["privacy_setting"] == "private":
|
||||
if self.team_gid and project["team"]["gid"] != self.team_gid:
|
||||
return
|
||||
if project.get("privacy_setting") == "private":
|
||||
if self.team_gid and team_gid != self.team_gid:
|
||||
logger.info(
|
||||
f"Skipping private project not in configured team: {project['name']} ({project_gid})"
|
||||
)
|
||||
yield from []
|
||||
else:
|
||||
logger.info(
|
||||
f"Processing private project in configured team: {project['name']} ({project_gid})"
|
||||
f"Skipping private project not in configured team: {project_name} ({project_gid})"
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
f"Processing private project in configured team: {project_name} ({project_gid})"
|
||||
)
|
||||
|
||||
simple_start_date = start_date.split(".")[0].split("+")[0]
|
||||
logger.info(
|
||||
f"Fetching tasks modified since {simple_start_date} for project: {project['name']} ({project_gid})"
|
||||
f"Fetching tasks modified since {simple_start_date} for project: {project_name} ({project_gid})"
|
||||
)
|
||||
|
||||
opts = {
|
||||
@@ -157,7 +160,7 @@ class AsanaAPI:
|
||||
link=data["permalink_url"],
|
||||
last_modified=datetime.fromisoformat(data["modified_at"]),
|
||||
project_gid=project_gid,
|
||||
project_name=project["name"],
|
||||
project_name=project_name,
|
||||
)
|
||||
yield task
|
||||
except Exception:
|
||||
|
||||
711 backend/onyx/connectors/coda/connector.py (new file)
@@ -0,0 +1,711 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rl_requests,
|
||||
)
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.utils.batching import batch_generator
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
_CODA_CALL_TIMEOUT = 30
|
||||
_CODA_BASE_URL = "https://coda.io/apis/v1"
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class CodaClientRequestFailedError(ConnectionError):
|
||||
def __init__(self, message: str, status_code: int):
|
||||
super().__init__(
|
||||
f"Coda API request failed with status {status_code}: {message}"
|
||||
)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class CodaDoc(BaseModel):
|
||||
id: str
|
||||
browser_link: str
|
||||
name: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
workspace_id: str
|
||||
workspace_name: str
|
||||
folder_id: str | None
|
||||
folder_name: str | None
|
||||
|
||||
|
||||
class CodaPage(BaseModel):
|
||||
id: str
|
||||
browser_link: str
|
||||
name: str
|
||||
content_type: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
doc_id: str
|
||||
|
||||
|
||||
class CodaTable(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
browser_link: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
doc_id: str
|
||||
|
||||
|
||||
class CodaRow(BaseModel):
|
||||
id: str
|
||||
name: Optional[str] = None
|
||||
index: Optional[int] = None
|
||||
browser_link: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
values: Dict[str, Any]
|
||||
table_id: str
|
||||
doc_id: str
|
||||
|
||||
|
||||
class CodaApiClient:
|
||||
def __init__(
|
||||
self,
|
||||
bearer_token: str,
|
||||
) -> None:
|
||||
self.bearer_token = bearer_token
|
||||
self.base_url = os.environ.get("CODA_BASE_URL", _CODA_BASE_URL)
|
||||
|
||||
def get(
|
||||
self, endpoint: str, params: Optional[dict[str, str]] = None
|
||||
) -> dict[str, Any]:
|
||||
url = self._build_url(endpoint)
|
||||
headers = self._build_headers()
|
||||
|
||||
response = rl_requests.get(
|
||||
url, headers=headers, params=params, timeout=_CODA_CALL_TIMEOUT
|
||||
)
|
||||
|
||||
try:
|
||||
json = response.json()
|
||||
except Exception:
|
||||
json = {}
|
||||
|
||||
if response.status_code >= 300:
|
||||
error = response.reason
|
||||
response_error = json.get("error", {}).get("message", "")
|
||||
if response_error:
|
||||
error = response_error
|
||||
raise CodaClientRequestFailedError(error, response.status_code)
|
||||
|
||||
return json
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
return {"Authorization": f"Bearer {self.bearer_token}"}
|
||||
|
||||
def _build_url(self, endpoint: str) -> str:
|
||||
return self.base_url.rstrip("/") + "/" + endpoint.lstrip("/")
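# Illustrative sketch (not part of this diff): callers page through Coda list endpoints by
# threading nextPageToken back into the request, mirroring the loops in the connector below.
def _iter_coda_items(
    client: "CodaApiClient", endpoint: str
) -> Generator[dict[str, Any], None, None]:
    params: Dict[str, str] = {}
    while True:
        response = client.get(endpoint, params=params)
        yield from response.get("items", [])
        next_page_token = response.get("nextPageToken")
        if not next_page_token:
            break
        params["pageToken"] = next_page_token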
|
||||
|
||||
|
||||
class CodaConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
index_page_content: bool = True,
|
||||
workspace_id: str | None = None,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.index_page_content = index_page_content
|
||||
self.workspace_id = workspace_id
|
||||
self._coda_client: CodaApiClient | None = None
|
||||
|
||||
@property
|
||||
def coda_client(self) -> CodaApiClient:
|
||||
if self._coda_client is None:
|
||||
raise ConnectorMissingCredentialError("Coda")
|
||||
return self._coda_client
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _get_doc(self, doc_id: str) -> CodaDoc:
|
||||
"""Fetch a specific Coda document by its ID."""
|
||||
logger.debug(f"Fetching Coda doc with ID: {doc_id}")
|
||||
try:
|
||||
response = self.coda_client.get(f"docs/{doc_id}")
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
raise ConnectorValidationError(f"Failed to fetch doc: {doc_id}") from e
|
||||
else:
|
||||
raise
|
||||
|
||||
return CodaDoc(
|
||||
id=response["id"],
|
||||
browser_link=response["browserLink"],
|
||||
name=response["name"],
|
||||
created_at=response["createdAt"],
|
||||
updated_at=response["updatedAt"],
|
||||
workspace_id=response["workspace"]["id"],
|
||||
workspace_name=response["workspace"]["name"],
|
||||
folder_id=response["folder"]["id"] if response.get("folder") else None,
|
||||
folder_name=response["folder"]["name"] if response.get("folder") else None,
|
||||
)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _get_page(self, doc_id: str, page_id: str) -> CodaPage:
|
||||
"""Fetch a specific page from a Coda document."""
|
||||
logger.debug(f"Fetching Coda page with ID: {page_id}")
|
||||
try:
|
||||
response = self.coda_client.get(f"docs/{doc_id}/pages/{page_id}")
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Failed to fetch page: {page_id} from doc: {doc_id}"
|
||||
) from e
|
||||
else:
|
||||
raise
|
||||
|
||||
return CodaPage(
|
||||
id=response["id"],
|
||||
doc_id=doc_id,
|
||||
browser_link=response["browserLink"],
|
||||
name=response["name"],
|
||||
content_type=response["contentType"],
|
||||
created_at=response["createdAt"],
|
||||
updated_at=response["updatedAt"],
|
||||
)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _get_table(self, doc_id: str, table_id: str) -> CodaTable:
|
||||
"""Fetch a specific table from a Coda document."""
|
||||
logger.debug(f"Fetching Coda table with ID: {table_id}")
|
||||
try:
|
||||
response = self.coda_client.get(f"docs/{doc_id}/tables/{table_id}")
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Failed to fetch table: {table_id} from doc: {doc_id}"
|
||||
) from e
|
||||
else:
|
||||
raise
|
||||
|
||||
return CodaTable(
|
||||
id=response["id"],
|
||||
name=response["name"],
|
||||
browser_link=response["browserLink"],
|
||||
created_at=response["createdAt"],
|
||||
updated_at=response["updatedAt"],
|
||||
doc_id=doc_id,
|
||||
)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _get_row(self, doc_id: str, table_id: str, row_id: str) -> CodaRow:
|
||||
"""Fetch a specific row from a Coda table."""
|
||||
logger.debug(f"Fetching Coda row with ID: {row_id}")
|
||||
try:
|
||||
response = self.coda_client.get(
|
||||
f"docs/{doc_id}/tables/{table_id}/rows/{row_id}"
|
||||
)
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Failed to fetch row: {row_id} from table: {table_id} in doc: {doc_id}"
|
||||
) from e
|
||||
else:
|
||||
raise
|
||||
|
||||
values = {}
|
||||
for col_name, col_value in response.get("values", {}).items():
|
||||
values[col_name] = col_value
|
||||
|
||||
return CodaRow(
|
||||
id=response["id"],
|
||||
name=response.get("name"),
|
||||
index=response.get("index"),
|
||||
browser_link=response["browserLink"],
|
||||
created_at=response["createdAt"],
|
||||
updated_at=response["updatedAt"],
|
||||
values=values,
|
||||
table_id=table_id,
|
||||
doc_id=doc_id,
|
||||
)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_all_docs(
|
||||
self, endpoint: str = "docs", params: Optional[Dict[str, str]] = None
|
||||
) -> List[CodaDoc]:
|
||||
"""List all Coda documents in the workspace."""
|
||||
logger.debug("Listing documents in Coda")
|
||||
|
||||
all_docs: List[CodaDoc] = []
|
||||
next_page_token: str | None = None
|
||||
params = params or {}
|
||||
|
||||
if self.workspace_id:
|
||||
params["workspaceId"] = self.workspace_id
|
||||
|
||||
while True:
|
||||
if next_page_token:
|
||||
params["pageToken"] = next_page_token
|
||||
|
||||
try:
|
||||
response = self.coda_client.get(endpoint, params=params)
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
raise ConnectorValidationError("Failed to list docs") from e
|
||||
else:
|
||||
raise
|
||||
|
||||
items = response.get("items", [])
|
||||
|
||||
for item in items:
|
||||
doc = CodaDoc(
|
||||
id=item["id"],
|
||||
browser_link=item["browserLink"],
|
||||
name=item["name"],
|
||||
created_at=item["createdAt"],
|
||||
updated_at=item["updatedAt"],
|
||||
workspace_id=item["workspace"]["id"],
|
||||
workspace_name=item["workspace"]["name"],
|
||||
folder_id=item["folder"]["id"] if item.get("folder") else None,
|
||||
folder_name=item["folder"]["name"] if item.get("folder") else None,
|
||||
)
|
||||
all_docs.append(doc)
|
||||
|
||||
next_page_token = response.get("nextPageToken")
|
||||
if not next_page_token:
|
||||
break
|
||||
|
||||
logger.debug(f"Found {len(all_docs)} docs")
|
||||
return all_docs
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_pages_in_doc(self, doc_id: str) -> List[CodaPage]:
|
||||
"""List all pages in a Coda document."""
|
||||
logger.debug(f"Listing pages in Coda doc with ID: {doc_id}")
|
||||
|
||||
pages: List[CodaPage] = []
|
||||
endpoint = f"docs/{doc_id}/pages"
|
||||
params: Dict[str, str] = {}
|
||||
next_page_token: str | None = None
|
||||
|
||||
while True:
|
||||
if next_page_token:
|
||||
params["pageToken"] = next_page_token
|
||||
|
||||
try:
|
||||
response = self.coda_client.get(endpoint, params=params)
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Failed to list pages for doc: {doc_id}"
|
||||
) from e
|
||||
else:
|
||||
raise
|
||||
|
||||
items = response.get("items", [])
|
||||
for item in items:
|
||||
# Skip hidden pages; remove this check if hidden pages should be indexed too
|
||||
if item.get("isHidden", False):
|
||||
continue
|
||||
|
||||
pages.append(
|
||||
CodaPage(
|
||||
id=item["id"],
|
||||
browser_link=item["browserLink"],
|
||||
name=item["name"],
|
||||
content_type=item["contentType"],
|
||||
created_at=item["createdAt"],
|
||||
updated_at=item["updatedAt"],
|
||||
doc_id=doc_id,
|
||||
)
|
||||
)
|
||||
|
||||
next_page_token = response.get("nextPageToken")
|
||||
if not next_page_token:
|
||||
break
|
||||
|
||||
logger.debug(f"Found {len(pages)} pages in doc {doc_id}")
|
||||
return pages
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_page_content(self, doc_id: str, page_id: str) -> str:
|
||||
"""Fetch the content of a Coda page."""
|
||||
logger.debug(f"Fetching content for page {page_id} in doc {doc_id}")
|
||||
|
||||
content_parts = []
|
||||
next_page_token: str | None = None
|
||||
params: Dict[str, str] = {}
|
||||
|
||||
while True:
|
||||
if next_page_token:
|
||||
params["pageToken"] = next_page_token
|
||||
|
||||
try:
|
||||
response = self.coda_client.get(
|
||||
f"docs/{doc_id}/pages/{page_id}/content", params=params
|
||||
)
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
logger.debug(f"No content available for page {page_id}")
|
||||
return ""
|
||||
raise
|
||||
|
||||
items = response.get("items", [])
|
||||
|
||||
for item in items:
|
||||
item_content = item.get("itemContent", {})
|
||||
|
||||
content_text = item_content.get("content", "")
|
||||
if content_text:
|
||||
content_parts.append(content_text)
|
||||
|
||||
next_page_token = response.get("nextPageToken")
|
||||
if not next_page_token:
|
||||
break
|
||||
|
||||
return "\n\n".join(content_parts)
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_tables(self, doc_id: str) -> List[CodaTable]:
|
||||
"""List all tables in a Coda document."""
|
||||
logger.debug(f"Listing tables in Coda doc with ID: {doc_id}")
|
||||
|
||||
tables: List[CodaTable] = []
|
||||
endpoint = f"docs/{doc_id}/tables"
|
||||
params: Dict[str, str] = {}
|
||||
next_page_token: str | None = None
|
||||
|
||||
while True:
|
||||
if next_page_token:
|
||||
params["pageToken"] = next_page_token
|
||||
|
||||
try:
|
||||
response = self.coda_client.get(endpoint, params=params)
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Failed to list tables for doc: {doc_id}"
|
||||
) from e
|
||||
else:
|
||||
raise
|
||||
|
||||
items = response.get("items", [])
|
||||
for item in items:
|
||||
tables.append(
|
||||
CodaTable(
|
||||
id=item["id"],
|
||||
browser_link=item["browserLink"],
|
||||
name=item["name"],
|
||||
created_at=item["createdAt"],
|
||||
updated_at=item["updatedAt"],
|
||||
doc_id=doc_id,
|
||||
)
|
||||
)
|
||||
|
||||
next_page_token = response.get("nextPageToken")
|
||||
if not next_page_token:
|
||||
break
|
||||
|
||||
logger.debug(f"Found {len(tables)} tables in doc {doc_id}")
|
||||
return tables
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _list_rows_and_values(self, doc_id: str, table_id: str) -> List[CodaRow]:
|
||||
"""List all rows and their values in a table."""
|
||||
logger.debug(f"Listing rows in Coda table: {table_id} in Coda doc: {doc_id}")
|
||||
|
||||
rows: List[CodaRow] = []
|
||||
endpoint = f"docs/{doc_id}/tables/{table_id}/rows"
|
||||
params: Dict[str, str] = {"valueFormat": "rich"}
|
||||
next_page_token: str | None = None
|
||||
|
||||
while True:
|
||||
if next_page_token:
|
||||
params["pageToken"] = next_page_token
|
||||
|
||||
try:
|
||||
response = self.coda_client.get(endpoint, params=params)
|
||||
except CodaClientRequestFailedError as e:
|
||||
if e.status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Failed to list rows for table: {table_id} in doc: {doc_id}"
|
||||
) from e
|
||||
else:
|
||||
raise
|
||||
|
||||
items = response.get("items", [])
|
||||
for item in items:
|
||||
values = {}
|
||||
for col_name, col_value in item.get("values", {}).items():
|
||||
values[col_name] = col_value
|
||||
|
||||
rows.append(
|
||||
CodaRow(
|
||||
id=item["id"],
|
||||
name=item["name"],
|
||||
index=item["index"],
|
||||
browser_link=item["browserLink"],
|
||||
created_at=item["createdAt"],
|
||||
updated_at=item["updatedAt"],
|
||||
values=values,
|
||||
table_id=table_id,
|
||||
doc_id=doc_id,
|
||||
)
|
||||
)
|
||||
|
||||
next_page_token = response.get("nextPageToken")
|
||||
if not next_page_token:
|
||||
break
|
||||
|
||||
logger.debug(f"Found {len(rows)} rows in table {table_id}")
|
||||
return rows
|
||||
|
||||
    def _convert_page_to_document(self, page: CodaPage, content: str = "") -> Document:
        """Convert a page into a Document."""
        page_updated = datetime.fromisoformat(page.updated_at).astimezone(timezone.utc)

        text_parts = [page.name, page.browser_link]
        if content:
            text_parts.append(content)

        sections = [TextSection(link=page.browser_link, text="\n\n".join(text_parts))]

        return Document(
            id=f"coda-page-{page.doc_id}-{page.id}",
            sections=cast(list[TextSection | ImageSection], sections),
            source=DocumentSource.CODA,
            semantic_identifier=page.name or f"Page {page.id}",
            doc_updated_at=page_updated,
            metadata={
                "browser_link": page.browser_link,
                "doc_id": page.doc_id,
                "content_type": page.content_type,
            },
        )

def _convert_table_with_rows_to_document(
|
||||
self, table: CodaTable, rows: List[CodaRow]
|
||||
) -> Document:
|
||||
"""Convert a table and its rows into a single Document with multiple sections (one per row)."""
|
||||
table_updated = datetime.fromisoformat(table.updated_at).astimezone(
|
||||
timezone.utc
|
||||
)
|
||||
|
||||
sections: List[TextSection] = []
|
||||
for row in rows:
|
||||
content_text = " ".join(
|
||||
str(v) if not isinstance(v, list) else " ".join(map(str, v))
|
||||
for v in row.values.values()
|
||||
)
|
||||
|
||||
row_name = row.name or f"Row {row.index or row.id}"
|
||||
text = f"{row_name}: {content_text}" if content_text else row_name
|
||||
|
||||
sections.append(TextSection(link=row.browser_link, text=text))
|
||||
|
||||
# If no rows, create a single section for the table itself
|
||||
if not sections:
|
||||
sections = [
|
||||
TextSection(link=table.browser_link, text=f"Table: {table.name}")
|
||||
]
|
||||
|
||||
return Document(
|
||||
id=f"coda-table-{table.doc_id}-{table.id}",
|
||||
sections=cast(list[TextSection | ImageSection], sections),
|
||||
source=DocumentSource.CODA,
|
||||
semantic_identifier=table.name or f"Table {table.id}",
|
||||
doc_updated_at=table_updated,
|
||||
metadata={
|
||||
"browser_link": table.browser_link,
|
||||
"doc_id": table.doc_id,
|
||||
"row_count": str(len(rows)),
|
||||
},
|
||||
)
|
||||
|
||||
    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load and validate Coda credentials."""
        self._coda_client = CodaApiClient(bearer_token=credentials["coda_bearer_token"])

        try:
            self._coda_client.get("docs", params={"limit": "1"})
        except CodaClientRequestFailedError as e:
            if e.status_code == 401:
                raise ConnectorMissingCredentialError("Invalid Coda API token")
            raise

        return None

def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""Load all documents from Coda workspace."""
|
||||
|
||||
def _iter_documents() -> Generator[Document, None, None]:
|
||||
docs = self._list_all_docs()
|
||||
logger.info(f"Found {len(docs)} Coda docs to process")
|
||||
|
||||
for doc in docs:
|
||||
logger.debug(f"Processing doc: {doc.name} ({doc.id})")
|
||||
|
||||
try:
|
||||
pages = self._list_pages_in_doc(doc.id)
|
||||
for page in pages:
|
||||
content = ""
|
||||
if self.index_page_content:
|
||||
try:
|
||||
content = self._fetch_page_content(doc.id, page.id)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch content for page {page.id}: {e}"
|
||||
)
|
||||
yield self._convert_page_to_document(page, content)
|
||||
except ConnectorValidationError as e:
|
||||
logger.warning(f"Failed to list pages for doc {doc.id}: {e}")
|
||||
|
||||
try:
|
||||
tables = self._list_tables(doc.id)
|
||||
for table in tables:
|
||||
try:
|
||||
rows = self._list_rows_and_values(doc.id, table.id)
|
||||
yield self._convert_table_with_rows_to_document(table, rows)
|
||||
except ConnectorValidationError as e:
|
||||
logger.warning(
|
||||
f"Failed to list rows for table {table.id}: {e}"
|
||||
)
|
||||
yield self._convert_table_with_rows_to_document(table, [])
|
||||
except ConnectorValidationError as e:
|
||||
logger.warning(f"Failed to list tables for doc {doc.id}: {e}")
|
||||
|
||||
return batch_generator(_iter_documents(), self.batch_size)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Polls the Coda API for documents updated between start and end timestamps.
|
||||
We refer to page and table update times to determine if they need to be re-indexed.
|
||||
"""
|
||||
|
||||
def _iter_documents() -> Generator[Document, None, None]:
|
||||
docs = self._list_all_docs()
|
||||
logger.info(
|
||||
f"Polling {len(docs)} Coda docs for updates between {start} and {end}"
|
||||
)
|
||||
|
||||
for doc in docs:
|
||||
try:
|
||||
pages = self._list_pages_in_doc(doc.id)
|
||||
for page in pages:
|
||||
page_timestamp = (
|
||||
datetime.fromisoformat(page.updated_at)
|
||||
.astimezone(timezone.utc)
|
||||
.timestamp()
|
||||
)
|
||||
if start < page_timestamp <= end:
|
||||
content = ""
|
||||
if self.index_page_content:
|
||||
try:
|
||||
content = self._fetch_page_content(doc.id, page.id)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch content for page {page.id}: {e}"
|
||||
)
|
||||
yield self._convert_page_to_document(page, content)
|
||||
except ConnectorValidationError as e:
|
||||
logger.warning(f"Failed to list pages for doc {doc.id}: {e}")
|
||||
|
||||
try:
|
||||
tables = self._list_tables(doc.id)
|
||||
for table in tables:
|
||||
table_timestamp = (
|
||||
datetime.fromisoformat(table.updated_at)
|
||||
.astimezone(timezone.utc)
|
||||
.timestamp()
|
||||
)
|
||||
|
||||
try:
|
||||
rows = self._list_rows_and_values(doc.id, table.id)
|
||||
|
||||
table_or_rows_updated = start < table_timestamp <= end
|
||||
if not table_or_rows_updated:
|
||||
for row in rows:
|
||||
row_timestamp = (
|
||||
datetime.fromisoformat(row.updated_at)
|
||||
.astimezone(timezone.utc)
|
||||
.timestamp()
|
||||
)
|
||||
if start < row_timestamp <= end:
|
||||
table_or_rows_updated = True
|
||||
break
|
||||
|
||||
if table_or_rows_updated:
|
||||
yield self._convert_table_with_rows_to_document(
|
||||
table, rows
|
||||
)
|
||||
|
||||
except ConnectorValidationError as e:
|
||||
logger.warning(
|
||||
f"Failed to list rows for table {table.id}: {e}"
|
||||
)
|
||||
if table_timestamp > start and table_timestamp <= end:
|
||||
yield self._convert_table_with_rows_to_document(
|
||||
table, []
|
||||
)
|
||||
|
||||
except ConnectorValidationError as e:
|
||||
logger.warning(f"Failed to list tables for doc {doc.id}: {e}")
|
||||
|
||||
return batch_generator(_iter_documents(), self.batch_size)
|
||||
|
||||
    def validate_connector_settings(self) -> None:
        """Validates the Coda connector settings by calling the 'whoami' endpoint."""
        try:
            response = self.coda_client.get("whoami")
            logger.info(
                f"Coda connector validated for user: {response.get('name', 'Unknown')}"
            )

            if self.workspace_id:
                params = {"workspaceId": self.workspace_id, "limit": "1"}
                self.coda_client.get("docs", params=params)
                logger.info(f"Validated access to workspace: {self.workspace_id}")

        except CodaClientRequestFailedError as e:
            if e.status_code == 401:
                raise CredentialExpiredError(
                    "Coda credential appears to be invalid or expired (HTTP 401)."
                )
            elif e.status_code == 404:
                raise ConnectorValidationError(
                    "Coda workspace not found or not accessible (HTTP 404). "
                    "Please verify the workspace_id is correct and shared with the integration."
                )
            elif e.status_code == 429:
                raise ConnectorValidationError(
                    "Validation failed due to Coda rate-limits being exceeded (HTTP 429). "
                    "Please try again later."
                )
            else:
                raise UnexpectedValidationError(
                    f"Unexpected Coda HTTP error (status={e.status_code}): {e}"
                )
        except Exception as exc:
            raise UnexpectedValidationError(
                f"Unexpected error during Coda settings validation: {exc}"
            )

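# A minimal usage sketch for the connector above (not part of the original file).
# The import path and class name come from the CONNECTOR_CLASS_MAP entry later in
# this diff; the constructor keyword arguments are assumptions based on the
# attributes the methods reference (workspace_id, batch_size, index_page_content),
# since the __init__ signature is not shown in this excerpt.
#
# from onyx.connectors.coda.connector import CodaConnector
#
# connector = CodaConnector(
#     workspace_id=None,          # or a specific Coda workspace ID
#     batch_size=50,
#     index_page_content=True,
# )
# connector.load_credentials({"coda_bearer_token": "<your-token>"})
# connector.validate_connector_settings()
# for batch in connector.load_from_state():
#     for document in batch:
#         print(document.semantic_identifier, document.doc_updated_at)
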
@@ -387,124 +387,162 @@ class ConfluenceConnector(
|
||||
attachment_docs: list[Document] = []
|
||||
page_url = ""
|
||||
|
||||
for attachment in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=attachment_query,
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
media_type: str = attachment.get("metadata", {}).get("mediaType", "")
|
||||
|
||||
# TODO(rkuo): this check is partially redundant with validate_attachment_filetype
|
||||
# and checks in convert_attachment_to_content/process_attachment
|
||||
# but doing the check here avoids an unnecessary download. Due for refactoring.
|
||||
if not self.allow_images:
|
||||
if media_type.startswith("image/"):
|
||||
logger.info(
|
||||
f"Skipping attachment because allow images is False: {attachment['title']}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not validate_attachment_filetype(
|
||||
attachment,
|
||||
try:
|
||||
for attachment in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=attachment_query,
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping attachment because it is not an accepted file type: {attachment['title']}"
|
||||
)
|
||||
continue
|
||||
media_type: str = attachment.get("metadata", {}).get("mediaType", "")
|
||||
|
||||
logger.info(
|
||||
f"Processing attachment: {attachment['title']} attached to page {page['title']}"
|
||||
)
|
||||
# Attachment document id: use the download URL for stable identity
|
||||
try:
|
||||
object_url = build_confluence_document_id(
|
||||
self.wiki_base, attachment["_links"]["download"], self.is_cloud
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Invalid attachment url for id {attachment['id']}, skipping"
|
||||
)
|
||||
logger.debug(f"Error building attachment url: {e}")
|
||||
continue
|
||||
try:
|
||||
response = convert_attachment_to_content(
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=attachment,
|
||||
page_id=page["id"],
|
||||
allow_images=self.allow_images,
|
||||
)
|
||||
if response is None:
|
||||
# TODO(rkuo): this check is partially redundant with validate_attachment_filetype
|
||||
# and checks in convert_attachment_to_content/process_attachment
|
||||
# but doing the check here avoids an unnecessary download. Due for refactoring.
|
||||
if not self.allow_images:
|
||||
if media_type.startswith("image/"):
|
||||
logger.info(
|
||||
f"Skipping attachment because allow images is False: {attachment['title']}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not validate_attachment_filetype(
|
||||
attachment,
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping attachment because it is not an accepted file type: {attachment['title']}"
|
||||
)
|
||||
continue
|
||||
|
||||
content_text, file_storage_name = response
|
||||
logger.info(
|
||||
f"Processing attachment: {attachment['title']} attached to page {page['title']}"
|
||||
)
|
||||
# Attachment document id: use the download URL for stable identity
|
||||
try:
|
||||
object_url = build_confluence_document_id(
|
||||
self.wiki_base, attachment["_links"]["download"], self.is_cloud
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Invalid attachment url for id {attachment['id']}, skipping"
|
||||
)
|
||||
logger.debug(f"Error building attachment url: {e}")
|
||||
continue
|
||||
try:
|
||||
response = convert_attachment_to_content(
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=attachment,
|
||||
page_id=page["id"],
|
||||
allow_images=self.allow_images,
|
||||
)
|
||||
if response is None:
|
||||
continue
|
||||
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
if content_text:
|
||||
sections.append(TextSection(text=content_text, link=object_url))
|
||||
elif file_storage_name:
|
||||
sections.append(
|
||||
ImageSection(link=object_url, image_file_id=file_storage_name)
|
||||
content_text, file_storage_name = response
|
||||
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
if content_text:
|
||||
sections.append(TextSection(text=content_text, link=object_url))
|
||||
elif file_storage_name:
|
||||
sections.append(
|
||||
ImageSection(
|
||||
link=object_url, image_file_id=file_storage_name
|
||||
)
|
||||
)
|
||||
|
||||
# Build attachment-specific metadata
|
||||
attachment_metadata: dict[str, str | list[str]] = {}
|
||||
if "space" in attachment:
|
||||
attachment_metadata["space"] = attachment["space"].get(
|
||||
"name", ""
|
||||
)
|
||||
labels: list[str] = []
|
||||
if "metadata" in attachment and "labels" in attachment["metadata"]:
|
||||
for label in attachment["metadata"]["labels"].get(
|
||||
"results", []
|
||||
):
|
||||
labels.append(label.get("name", ""))
|
||||
if labels:
|
||||
attachment_metadata["labels"] = labels
|
||||
page_url = page_url or build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
attachment_metadata["parent_page_id"] = page_url
|
||||
attachment_id = build_confluence_document_id(
|
||||
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
# Build attachment-specific metadata
|
||||
attachment_metadata: dict[str, str | list[str]] = {}
|
||||
if "space" in attachment:
|
||||
attachment_metadata["space"] = attachment["space"].get("name", "")
|
||||
labels: list[str] = []
|
||||
if "metadata" in attachment and "labels" in attachment["metadata"]:
|
||||
for label in attachment["metadata"]["labels"].get("results", []):
|
||||
labels.append(label.get("name", ""))
|
||||
if labels:
|
||||
attachment_metadata["labels"] = labels
|
||||
page_url = page_url or build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
attachment_metadata["parent_page_id"] = page_url
|
||||
attachment_id = build_confluence_document_id(
|
||||
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
primary_owners: list[BasicExpertInfo] | None = None
|
||||
if "version" in attachment and "by" in attachment["version"]:
|
||||
author = attachment["version"]["by"]
|
||||
display_name = author.get("displayName", "Unknown")
|
||||
email = author.get("email", "unknown@domain.invalid")
|
||||
primary_owners = [
|
||||
BasicExpertInfo(display_name=display_name, email=email)
|
||||
]
|
||||
|
||||
primary_owners: list[BasicExpertInfo] | None = None
|
||||
if "version" in attachment and "by" in attachment["version"]:
|
||||
author = attachment["version"]["by"]
|
||||
display_name = author.get("displayName", "Unknown")
|
||||
email = author.get("email", "unknown@domain.invalid")
|
||||
primary_owners = [
|
||||
BasicExpertInfo(display_name=display_name, email=email)
|
||||
]
|
||||
attachment_doc = Document(
|
||||
id=attachment_id,
|
||||
sections=sections,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=attachment.get("title", object_url),
|
||||
metadata=attachment_metadata,
|
||||
doc_updated_at=(
|
||||
datetime_from_string(attachment["version"]["when"])
|
||||
if attachment.get("version")
|
||||
and attachment["version"].get("when")
|
||||
else None
|
||||
),
|
||||
primary_owners=primary_owners,
|
||||
)
|
||||
attachment_docs.append(attachment_doc)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to extract/summarize attachment {attachment['title']}",
|
||||
exc_info=e,
|
||||
)
|
||||
if is_atlassian_date_error(e):
|
||||
# propagate error to be caught and retried
|
||||
raise
|
||||
attachment_failures.append(
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=object_url,
|
||||
document_link=object_url,
|
||||
),
|
||||
failure_message=f"Failed to extract/summarize attachment {attachment['title']} for doc {object_url}",
|
||||
exception=e,
|
||||
)
|
||||
)
|
||||
except HTTPError as e:
|
||||
# If we get a 403 after all retries, the user likely doesn't have permission
|
||||
# to access attachments on this page. Log and skip rather than failing the whole job.
|
||||
if e.response and e.response.status_code == 403:
|
||||
page_title = page.get("title", "unknown")
|
||||
page_id = page.get("id", "unknown")
|
||||
logger.warning(
|
||||
f"Permission denied (403) when fetching attachments for page '{page_title}' "
|
||||
f"(ID: {page_id}). The user may not have permission to query attachments on this page. "
|
||||
"Skipping attachments for this page."
|
||||
)
|
||||
# Build the page URL for the failure record
|
||||
try:
|
||||
page_url = build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
except Exception:
|
||||
page_url = f"page_id:{page_id}"
|
||||
|
||||
attachment_doc = Document(
|
||||
id=attachment_id,
|
||||
sections=sections,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=attachment.get("title", object_url),
|
||||
metadata=attachment_metadata,
|
||||
doc_updated_at=(
|
||||
datetime_from_string(attachment["version"]["when"])
|
||||
if attachment.get("version")
|
||||
and attachment["version"].get("when")
|
||||
else None
|
||||
),
|
||||
primary_owners=primary_owners,
|
||||
)
|
||||
attachment_docs.append(attachment_doc)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to extract/summarize attachment {attachment['title']}",
|
||||
exc_info=e,
|
||||
)
|
||||
if is_atlassian_date_error(e):
|
||||
# propagate error to be caught and retried
|
||||
raise
|
||||
attachment_failures.append(
|
||||
return [], [
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=object_url,
|
||||
document_link=object_url,
|
||||
document_id=page_id,
|
||||
document_link=page_url,
|
||||
),
|
||||
failure_message=f"Failed to extract/summarize attachment {attachment['title']} for doc {object_url}",
|
||||
failure_message=f"Permission denied (403) when fetching attachments for page '{page_title}'",
|
||||
exception=e,
|
||||
)
|
||||
)
|
||||
]
|
||||
else:
|
||||
raise
|
||||
|
||||
return attachment_docs, attachment_failures
|
||||
|
||||
|
||||
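# A standalone sketch of the "skip on 403" pattern described above (illustrative
# only; it uses plain `requests` instead of the Confluence client, and the
# attachments URL is invented for demonstration). The idea is the same: a
# permission error on one page's attachments is logged and recorded as a
# per-document failure instead of aborting the whole indexing job.
import requests


def fetch_attachments_or_skip(
    session: requests.Session, url: str
) -> tuple[list[dict], list[str]]:
    """Return (attachments, failure_messages); a 403 becomes a recorded failure, not a crash."""
    try:
        response = session.get(url, timeout=30)
        response.raise_for_status()
        return response.json().get("results", []), []
    except requests.HTTPError as e:
        if e.response is not None and e.response.status_code == 403:
            # Permission denied for this page only - record it and move on.
            return [], [f"Permission denied (403) fetching attachments from {url}"]
        raise  # anything else is unexpected and should propagate
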
@@ -579,13 +579,18 @@ class OnyxConfluence:
|
||||
while url_suffix:
|
||||
logger.debug(f"Making confluence call to {url_suffix}")
|
||||
try:
|
||||
# Only pass params if they're not already in the URL to avoid duplicate
|
||||
# params accumulating. Confluence's _links.next already includes these.
|
||||
params = {}
|
||||
if "body-format=" not in url_suffix:
|
||||
params["body-format"] = "atlas_doc_format"
|
||||
if "expand=" not in url_suffix:
|
||||
params["expand"] = "body.atlas_doc_format"
|
||||
|
||||
raw_response = self.get(
|
||||
path=url_suffix,
|
||||
advanced_mode=True,
|
||||
params={
|
||||
"body-format": "atlas_doc_format",
|
||||
"expand": "body.atlas_doc_format",
|
||||
},
|
||||
params=params,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in confluence call to {url_suffix}")
|
||||
|
||||
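# A small, self-contained sketch of the de-duplication rule described in the hunk
# above: query parameters are only added when the paginated URL does not already
# carry them, so follow-up calls driven by Confluence's `_links.next` do not
# accumulate duplicate `body-format` / `expand` values. The parameter names match
# the hunk; the helper itself is illustrative, not part of the connector.
from urllib.parse import parse_qs, urlsplit


def params_for_next_call(url_suffix: str) -> dict[str, str]:
    existing = parse_qs(urlsplit(url_suffix).query)
    params: dict[str, str] = {}
    if "body-format" not in existing:
        params["body-format"] = "atlas_doc_format"
    if "expand" not in existing:
        params["expand"] = "body.atlas_doc_format"
    return params


# First call: no query string yet, so both params are supplied.
assert params_for_next_call("/wiki/api/v2/pages") == {
    "body-format": "atlas_doc_format",
    "expand": "body.atlas_doc_format",
}
# Follow-up call from _links.next: params already embedded, nothing is re-added.
assert params_for_next_call(
    "/wiki/api/v2/pages?cursor=abc&body-format=atlas_doc_format&expand=body.atlas_doc_format"
) == {}
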
@@ -134,7 +134,7 @@ def process_onyx_metadata(
|
||||
metadata: dict[str, Any],
|
||||
) -> tuple[OnyxMetadata, dict[str, Any]]:
|
||||
"""
|
||||
Users may set Onyx metadata and custom tags in text files. https://docs.onyx.app/admin/connectors/official/file
|
||||
Users may set Onyx metadata and custom tags in text files. https://docs.onyx.app/admins/connectors/official/file
|
||||
Any unrecognized fields are treated as custom tags.
|
||||
"""
|
||||
p_owner_names = metadata.get("primary_owners")
|
||||
|
||||
@@ -155,7 +155,7 @@ def _process_file(
|
||||
content_type=file_type,
|
||||
)
|
||||
|
||||
# Each file may have file-specific ONYX_METADATA https://docs.onyx.app/admin/connectors/official/file
|
||||
# Each file may have file-specific ONYX_METADATA https://docs.onyx.app/admins/connectors/official/file
|
||||
# If so, we should add it to any metadata processed so far
|
||||
if extraction_result.metadata:
|
||||
logger.debug(
|
||||
|
||||
@@ -44,7 +44,7 @@ USER_FIELDS = "nextPageToken, users(primaryEmail)"
|
||||
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
|
||||
|
||||
# Documentation and error messages
|
||||
SCOPE_DOC_URL = "https://docs.onyx.app/admin/connectors/official/google_drive/overview"
|
||||
SCOPE_DOC_URL = "https://docs.onyx.app/admins/connectors/official/google_drive/overview"
|
||||
ONYX_SCOPE_INSTRUCTIONS = (
|
||||
"You have upgraded Onyx without updating the Google Auth scopes. "
|
||||
f"Please refer to the documentation to learn how to update the scopes: {SCOPE_DOC_URL}"
|
||||
|
||||
@@ -26,7 +26,6 @@ from onyx.utils.logger import setup_logger
|
||||
HUBSPOT_BASE_URL = "https://app.hubspot.com"
|
||||
HUBSPOT_API_URL = "https://api.hubapi.com/integrations/v1/me"
|
||||
|
||||
# Available HubSpot object types
|
||||
AVAILABLE_OBJECT_TYPES = {"tickets", "companies", "deals", "contacts"}
|
||||
|
||||
HUBSPOT_PAGE_SIZE = 100
|
||||
|
||||
@@ -68,6 +68,10 @@ CONNECTOR_CLASS_MAP = {
|
||||
module_path="onyx.connectors.slab.connector",
|
||||
class_name="SlabConnector",
|
||||
),
|
||||
DocumentSource.CODA: ConnectorMapping(
|
||||
module_path="onyx.connectors.coda.connector",
|
||||
class_name="CodaConnector",
|
||||
),
|
||||
DocumentSource.NOTION: ConnectorMapping(
|
||||
module_path="onyx.connectors.notion.connector",
|
||||
class_name="NotionConnector",
|
||||
|
||||
@@ -13,6 +13,7 @@ from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from urllib.parse import unquote
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
import msal # type: ignore[import-untyped]
|
||||
import requests
|
||||
@@ -727,46 +728,77 @@ class SharepointConnector(
|
||||
|
||||
return self._graph_client
|
||||
|
||||
@staticmethod
|
||||
def _strip_share_link_tokens(path: str) -> list[str]:
|
||||
# Share links often include a token prefix like /:f:/r/ or /:x:/r/.
|
||||
segments = [segment for segment in path.split("/") if segment]
|
||||
if segments and segments[0].startswith(":"):
|
||||
segments = segments[1:]
|
||||
if segments and segments[0] in {"r", "s", "g"}:
|
||||
segments = segments[1:]
|
||||
return segments
|
||||
|
||||
@staticmethod
|
||||
def _normalize_sharepoint_url(url: str) -> tuple[str | None, list[str]]:
|
||||
try:
|
||||
parsed = urlsplit(url)
|
||||
except ValueError:
|
||||
logger.warning(f"Sharepoint URL '{url}' could not be parsed")
|
||||
return None, []
|
||||
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
logger.warning(
|
||||
f"Sharepoint URL '{url}' is not a valid absolute URL (missing scheme or host)"
|
||||
)
|
||||
return None, []
|
||||
|
||||
path_segments = SharepointConnector._strip_share_link_tokens(parsed.path)
|
||||
return f"{parsed.scheme}://{parsed.netloc}", path_segments
|
||||
|
||||
@staticmethod
|
||||
def _extract_site_and_drive_info(site_urls: list[str]) -> list[SiteDescriptor]:
|
||||
site_data_list = []
|
||||
for url in site_urls:
|
||||
parts = url.strip().split("/")
|
||||
base_url, parts = SharepointConnector._normalize_sharepoint_url(url.strip())
|
||||
if base_url is None:
|
||||
continue
|
||||
|
||||
lower_parts = [part.lower() for part in parts]
|
||||
site_type_index = None
|
||||
if "sites" in parts:
|
||||
site_type_index = parts.index("sites")
|
||||
elif "teams" in parts:
|
||||
site_type_index = parts.index("teams")
|
||||
for site_token in ("sites", "teams"):
|
||||
if site_token in lower_parts:
|
||||
site_type_index = lower_parts.index(site_token)
|
||||
break
|
||||
|
||||
if site_type_index is not None:
|
||||
# Extract the base site URL (up to and including the site/team name)
|
||||
site_url = "/".join(parts[: site_type_index + 2])
|
||||
remaining_parts = parts[site_type_index + 2 :]
|
||||
if site_type_index is None or len(parts) <= site_type_index + 1:
|
||||
logger.warning(
|
||||
f"Site URL '{url}' is not a valid Sharepoint URL (must contain /sites/<name> or /teams/<name>)"
|
||||
)
|
||||
continue
|
||||
|
||||
# Extract drive name and folder path
|
||||
if remaining_parts:
|
||||
drive_name = unquote(remaining_parts[0])
|
||||
folder_path = (
|
||||
"/".join(unquote(part) for part in remaining_parts[1:])
|
||||
if len(remaining_parts) > 1
|
||||
else None
|
||||
)
|
||||
else:
|
||||
drive_name = None
|
||||
folder_path = None
|
||||
site_path = parts[: site_type_index + 2]
|
||||
remaining_parts = parts[site_type_index + 2 :]
|
||||
site_url = f"{base_url}/" + "/".join(site_path)
|
||||
|
||||
site_data_list.append(
|
||||
SiteDescriptor(
|
||||
url=site_url,
|
||||
drive_name=drive_name,
|
||||
folder_path=folder_path,
|
||||
)
|
||||
# Extract drive name and folder path
|
||||
if remaining_parts:
|
||||
drive_name = unquote(remaining_parts[0])
|
||||
folder_path = (
|
||||
"/".join(unquote(part) for part in remaining_parts[1:])
|
||||
if len(remaining_parts) > 1
|
||||
else None
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Site URL '{url}' is not a valid Sharepoint URL (must contain /sites/ or /teams/)"
|
||||
drive_name = None
|
||||
folder_path = None
|
||||
|
||||
site_data_list.append(
|
||||
SiteDescriptor(
|
||||
url=site_url,
|
||||
drive_name=drive_name,
|
||||
folder_path=folder_path,
|
||||
)
|
||||
)
|
||||
return site_data_list
|
||||
|
||||
def _get_drive_items_for_drive_name(
|
||||
|
||||
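# Worked example of the normalization above on a hypothetical share link (the URL
# is invented for illustration): given
#     https://contoso.sharepoint.com/:f:/r/sites/Engineering/Shared%20Documents/Specs
# _normalize_sharepoint_url() returns base_url "https://contoso.sharepoint.com" and,
# after _strip_share_link_tokens() drops the ":f:" and "r" tokens, the segments
# ["sites", "Engineering", "Shared%20Documents", "Specs"]. "sites" is found at
# index 0, so the resulting SiteDescriptor has
#     url="https://contoso.sharepoint.com/sites/Engineering",
#     drive_name="Shared Documents" (after unquoting), and folder_path="Specs".
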
@@ -99,7 +99,9 @@ DEFAULT_HEADERS = {
|
||||
"image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7"
|
||||
),
|
||||
"Accept-Language": "en-US,en;q=0.9",
|
||||
"Accept-Encoding": "gzip, deflate, br",
|
||||
# Brotli decoding has been flaky in brotlicffi/httpx for certain chunked responses;
|
||||
# stick to gzip/deflate to keep connectivity checks stable.
|
||||
"Accept-Encoding": "gzip, deflate",
|
||||
"Connection": "keep-alive",
|
||||
"Upgrade-Insecure-Requests": "1",
|
||||
"Sec-Fetch-Dest": "document",
|
||||
@@ -349,10 +351,13 @@ def start_playwright() -> Tuple[Playwright, BrowserContext]:
|
||||
|
||||
|
||||
def extract_urls_from_sitemap(sitemap_url: str) -> list[str]:
|
||||
# requests should handle brotli compression automatically
|
||||
# as long as the brotli package is available in the venv. Leaving this line here to avoid
|
||||
# a regression as someone says "Ah, looks like this brotli package isn't used anywhere, let's remove it"
|
||||
# import brotli
|
||||
try:
|
||||
response = requests.get(sitemap_url, headers=DEFAULT_HEADERS)
|
||||
response.raise_for_status()
|
||||
|
||||
soup = BeautifulSoup(response.content, "html.parser")
|
||||
urls = [
|
||||
_ensure_absolute_url(sitemap_url, loc_tag.text)
|
||||
|
||||
@@ -20,6 +20,11 @@ class OptionalSearchSetting(str, Enum):
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class QueryType(str, Enum):
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
|
||||
|
||||
class SearchType(str, Enum):
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
|
||||
@@ -1,7 +1,19 @@
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
|
||||
|
||||
class ChannelMetadata(TypedDict):
|
||||
"""Type definition for cached channel metadata."""
|
||||
|
||||
name: str
|
||||
type: ChannelType
|
||||
is_private: bool
|
||||
is_member: bool
|
||||
|
||||
|
||||
class SlackMessage(BaseModel):
|
||||
document_id: str
|
||||
|
||||
@@ -3,7 +3,10 @@ import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import ValidationError
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
@@ -13,11 +16,11 @@ from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
|
||||
from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.federated.models import SlackMessage
|
||||
from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES
|
||||
from onyx.context.search.federated.slack_search_utils import build_channel_query_filter
|
||||
from onyx.context.search.federated.slack_search_utils import build_slack_queries
|
||||
from onyx.context.search.federated.slack_search_utils import ChannelTypeString
|
||||
from onyx.context.search.federated.slack_search_utils import get_channel_type
|
||||
from onyx.context.search.federated.slack_search_utils import (
|
||||
get_channel_type_for_missing_scope,
|
||||
@@ -52,6 +55,7 @@ HIGHLIGHT_START_CHAR = "\ue000"
|
||||
HIGHLIGHT_END_CHAR = "\ue001"
|
||||
|
||||
CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours
|
||||
SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include
|
||||
CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching
|
||||
CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff)
|
||||
@@ -59,7 +63,7 @@ CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential
|
||||
|
||||
def fetch_and_cache_channel_metadata(
|
||||
access_token: str, team_id: str, include_private: bool = True
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
) -> dict[str, ChannelMetadata]:
|
||||
"""
|
||||
Fetch ALL channel metadata in one API call and cache it.
|
||||
|
||||
@@ -77,28 +81,28 @@ def fetch_and_cache_channel_metadata(
|
||||
try:
|
||||
cached = redis_client.get(cache_key)
|
||||
if cached:
|
||||
logger.info(f"Channel metadata cache HIT for team {team_id}")
|
||||
logger.debug(f"Channel metadata cache HIT for team {team_id}")
|
||||
cached_str: str = (
|
||||
cached.decode("utf-8") if isinstance(cached, bytes) else str(cached)
|
||||
)
|
||||
cached_data: dict[str, dict[str, Any]] = json.loads(cached_str)
|
||||
logger.info(f"Loaded {len(cached_data)} channels from cache")
|
||||
cached_data = cast(dict[str, ChannelMetadata], json.loads(cached_str))
|
||||
logger.debug(f"Loaded {len(cached_data)} channels from cache")
|
||||
if not include_private:
|
||||
filtered = {
|
||||
filtered: dict[str, ChannelMetadata] = {
|
||||
k: v
|
||||
for k, v in cached_data.items()
|
||||
if v.get("type") != "private_channel"
|
||||
if v.get("type") != ChannelType.PRIVATE_CHANNEL.value
|
||||
}
|
||||
logger.info(f"Filtered to {len(filtered)} channels (exclude private)")
|
||||
logger.debug(f"Filtered to {len(filtered)} channels (exclude private)")
|
||||
return filtered
|
||||
return cached_data
|
||||
except Exception as e:
|
||||
logger.warning(f"Error reading from channel metadata cache: {e}")
|
||||
|
||||
# Cache miss - fetch from Slack API with retry logic
|
||||
logger.info(f"Channel metadata cache MISS for team {team_id} - fetching from API")
|
||||
logger.debug(f"Channel metadata cache MISS for team {team_id} - fetching from API")
|
||||
slack_client = WebClient(token=access_token)
|
||||
channel_metadata: dict[str, dict[str, Any]] = {}
|
||||
channel_metadata: dict[str, ChannelMetadata] = {}
|
||||
|
||||
# Retry logic with exponential backoff
|
||||
last_exception = None
|
||||
@@ -130,7 +134,7 @@ def fetch_and_cache_channel_metadata(
|
||||
|
||||
# Determine channel type
|
||||
channel_type_enum = get_channel_type(channel_info=ch)
|
||||
channel_type = channel_type_enum.value
|
||||
channel_type = ChannelType(channel_type_enum.value)
|
||||
|
||||
channel_metadata[channel_id] = {
|
||||
"name": ch.get("name", ""),
|
||||
@@ -237,9 +241,93 @@ def get_available_channels(
|
||||
return [meta["name"] for meta in metadata.values() if meta["name"]]
|
||||
|
||||
|
||||
def get_cached_user_profile(
|
||||
access_token: str, team_id: str, user_id: str
|
||||
) -> str | None:
|
||||
"""
|
||||
Get a user's display name from cache or fetch from Slack API.
|
||||
|
||||
Uses Redis caching to avoid repeated API calls and rate limiting.
|
||||
Returns the user's real_name or email, or None if not found.
|
||||
"""
|
||||
redis_client = get_redis_client()
|
||||
cache_key = f"slack_federated_search:{team_id}:user:{user_id}"
|
||||
|
||||
# Check cache first
|
||||
try:
|
||||
cached = redis_client.get(cache_key)
|
||||
if cached is not None:
|
||||
cached_str = (
|
||||
cached.decode("utf-8") if isinstance(cached, bytes) else str(cached)
|
||||
)
|
||||
# Empty string means user was not found previously
|
||||
return cached_str if cached_str else None
|
||||
except Exception as e:
|
||||
logger.debug(f"Error reading user profile cache: {e}")
|
||||
|
||||
# Cache miss - fetch from Slack API
|
||||
slack_client = WebClient(token=access_token)
|
||||
try:
|
||||
response = slack_client.users_profile_get(user=user_id)
|
||||
response.validate()
|
||||
profile: dict[str, Any] = response.get("profile", {})
|
||||
name: str | None = profile.get("real_name") or profile.get("email")
|
||||
|
||||
# Cache the result (empty string for not found)
|
||||
try:
|
||||
redis_client.set(
|
||||
cache_key,
|
||||
name or "",
|
||||
ex=USER_PROFILE_CACHE_TTL,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error caching user profile: {e}")
|
||||
|
||||
return name
|
||||
|
||||
except SlackApiError as e:
|
||||
error_str = str(e)
|
||||
if "user_not_found" in error_str:
|
||||
logger.debug(
|
||||
f"User {user_id} not found in Slack workspace (likely deleted/deactivated)"
|
||||
)
|
||||
elif "ratelimited" in error_str:
|
||||
# Don't cache rate limit errors - we'll retry later
|
||||
logger.debug(f"Rate limited fetching user {user_id}, will retry later")
|
||||
return None
|
||||
else:
|
||||
logger.warning(f"Could not fetch profile for user {user_id}: {e}")
|
||||
|
||||
# Cache negative result to avoid repeated lookups for missing users
|
||||
try:
|
||||
redis_client.set(cache_key, "", ex=USER_PROFILE_CACHE_TTL)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def batch_get_user_profiles(
|
||||
access_token: str, team_id: str, user_ids: set[str]
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Batch fetch user profiles with caching.
|
||||
|
||||
Returns a dict mapping user_id -> display_name for users that were found.
|
||||
"""
|
||||
result: dict[str, str] = {}
|
||||
|
||||
for user_id in user_ids:
|
||||
name = get_cached_user_profile(access_token, team_id, user_id)
|
||||
if name:
|
||||
result[user_id] = name
|
||||
|
||||
return result
|
||||
|
||||
|
||||
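# Brief usage sketch (not part of the original module): how the cached batch lookup
# above can be used to swap "<@U123ABC>"-style mentions for display names. This is
# the same pattern the cached branch of get_contextualized_thread_text() follows
# further below; the token, team ID, and example text are placeholders.
def replace_user_mentions(text: str, access_token: str, team_id: str) -> str:
    user_ids = set(re.findall(r"<@([A-Z0-9]+)>", text))
    for user_id, name in batch_get_user_profiles(access_token, team_id, user_ids).items():
        text = text.replace(f"<@{user_id}>", name)
    return text


# e.g. replace_user_mentions("<@U123ABC> shipped the fix", "xoxp-...", "T0123456")
#      -> "Jane Doe shipped the fix" (when the profile lookup succeeds)
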
def _extract_channel_data_from_entities(
|
||||
entities: dict[str, Any] | None,
|
||||
channel_metadata_dict: dict[str, dict[str, Any]] | None,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None,
|
||||
) -> list[str] | None:
|
||||
"""Extract available channels list from metadata based on entity configuration.
|
||||
|
||||
@@ -264,7 +352,7 @@ def _extract_channel_data_from_entities(
|
||||
if meta["name"]
|
||||
and (
|
||||
parsed_entities.include_private_channels
|
||||
or meta.get("type") != ChannelTypeString.PRIVATE_CHANNEL.value
|
||||
or meta.get("type") != ChannelType.PRIVATE_CHANNEL.value
|
||||
)
|
||||
]
|
||||
except ValidationError:
|
||||
@@ -279,10 +367,28 @@ def _should_skip_channel(
|
||||
bot_token: str | None,
|
||||
access_token: str,
|
||||
include_dm: bool,
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> bool:
|
||||
"""Bot context filtering: skip private channels unless explicitly allowed."""
|
||||
"""Bot context filtering: skip private channels unless explicitly allowed.
|
||||
|
||||
Uses pre-fetched channel metadata when available to avoid API calls.
|
||||
"""
|
||||
if bot_token and not include_dm:
|
||||
try:
|
||||
# First try to use pre-fetched metadata from cache
|
||||
if channel_metadata_dict and channel_id in channel_metadata_dict:
|
||||
channel_meta = channel_metadata_dict[channel_id]
|
||||
channel_type_str = channel_meta.get("type", "")
|
||||
is_private_or_dm = channel_type_str in [
|
||||
ChannelType.PRIVATE_CHANNEL.value,
|
||||
ChannelType.IM.value,
|
||||
ChannelType.MPIM.value,
|
||||
]
|
||||
if is_private_or_dm and channel_id != allowed_private_channel:
|
||||
return True
|
||||
return False
|
||||
|
||||
# Fallback: API call only if not in cache (should be rare)
|
||||
token_to_use = bot_token or access_token
|
||||
channel_client = WebClient(token=token_to_use)
|
||||
channel_info = channel_client.conversations_info(channel=channel_id)
|
||||
@@ -306,6 +412,15 @@ def _should_skip_channel(
|
||||
return False
|
||||
|
||||
|
||||
class SlackQueryResult(BaseModel):
|
||||
"""Result from a single Slack query including stats."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
messages: list[SlackMessage]
|
||||
filtered_channels: list[str] # Channels filtered out during this query
|
||||
|
||||
|
||||
def query_slack(
|
||||
query_string: str,
|
||||
original_query: SearchQuery,
|
||||
@@ -316,7 +431,8 @@ def query_slack(
|
||||
include_dm: bool = False,
|
||||
entities: dict[str, Any] | None = None,
|
||||
available_channels: list[str] | None = None,
|
||||
) -> list[SlackMessage]:
|
||||
channel_metadata_dict: dict[str, ChannelMetadata] | None = None,
|
||||
) -> SlackQueryResult:
|
||||
|
||||
# Check if query has channel override (user specified channels in query)
|
||||
has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__")
|
||||
@@ -370,11 +486,11 @@ def query_slack(
|
||||
# Log token type prefix
|
||||
token_prefix = access_token[:4] if len(access_token) >= 4 else "unknown"
|
||||
logger.error(f"TOKEN TYPE ERROR: access_token type: {token_prefix}...")
|
||||
return []
|
||||
return SlackQueryResult(messages=[], filtered_channels=[])
|
||||
|
||||
# convert matches to slack messages
|
||||
slack_messages: list[SlackMessage] = []
|
||||
filtered_count = 0
|
||||
filtered_channels: list[str] = []
|
||||
for match in matches:
|
||||
text: str | None = match.get("text")
|
||||
permalink: str | None = match.get("permalink")
|
||||
@@ -402,9 +518,14 @@ def query_slack(
|
||||
|
||||
# Apply channel filtering if needed
|
||||
if _should_skip_channel(
|
||||
channel_id, allowed_private_channel, bot_token, access_token, include_dm
|
||||
channel_id,
|
||||
allowed_private_channel,
|
||||
bot_token,
|
||||
access_token,
|
||||
include_dm,
|
||||
channel_metadata_dict,
|
||||
):
|
||||
filtered_count += 1
|
||||
filtered_channels.append(f"{channel_name}({channel_id})")
|
||||
continue
|
||||
|
||||
# generate thread id and document id
|
||||
@@ -459,22 +580,28 @@ def query_slack(
|
||||
)
|
||||
)
|
||||
|
||||
if filtered_count > 0:
|
||||
logger.info(
|
||||
f"Channel filtering applied: {filtered_count} messages filtered out, {len(slack_messages)} messages kept"
|
||||
)
|
||||
|
||||
return slack_messages
|
||||
return SlackQueryResult(
|
||||
messages=slack_messages, filtered_channels=filtered_channels
|
||||
)
|
||||
|
||||
|
||||
def merge_slack_messages(
|
||||
slack_messages: list[list[SlackMessage]],
|
||||
) -> tuple[list[SlackMessage], dict[str, SlackMessage]]:
|
||||
query_results: list[SlackQueryResult],
|
||||
) -> tuple[list[SlackMessage], dict[str, SlackMessage], set[str]]:
|
||||
"""Merge messages from multiple query results, deduplicating by document_id.
|
||||
|
||||
Returns:
|
||||
Tuple of (merged_messages, docid_to_message, all_filtered_channels)
|
||||
"""
|
||||
merged_messages: list[SlackMessage] = []
|
||||
docid_to_message: dict[str, SlackMessage] = {}
|
||||
all_filtered_channels: set[str] = set()
|
||||
|
||||
for messages in slack_messages:
|
||||
for message in messages:
|
||||
for result in query_results:
|
||||
# Collect filtered channels from all queries
|
||||
all_filtered_channels.update(result.filtered_channels)
|
||||
|
||||
for message in result.messages:
|
||||
if message.document_id in docid_to_message:
|
||||
# update the score and highlighted texts, rest should be identical
|
||||
docid_to_message[message.document_id].slack_score = max(
|
||||
@@ -493,10 +620,12 @@ def merge_slack_messages(
|
||||
# re-sort by score
|
||||
merged_messages.sort(key=lambda x: x.slack_score, reverse=True)
|
||||
|
||||
return merged_messages, docid_to_message
|
||||
return merged_messages, docid_to_message, all_filtered_channels
|
||||
|
||||
|
||||
def get_contextualized_thread_text(message: SlackMessage, access_token: str) -> str:
|
||||
def get_contextualized_thread_text(
|
||||
message: SlackMessage, access_token: str, team_id: str | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Retrieves the initial thread message as well as the text following the message
|
||||
and combines them into a single string. If the slack query fails, returns the
|
||||
@@ -505,6 +634,11 @@ def get_contextualized_thread_text(message: SlackMessage, access_token: str) ->
|
||||
The idea is that the message (the one that actually matched the search), the
|
||||
initial thread message, and the replies to the message are important in answering
|
||||
the user's query.
|
||||
|
||||
Args:
|
||||
message: The SlackMessage to get context for
|
||||
access_token: Slack OAuth access token
|
||||
team_id: Slack team ID for caching user profiles (optional but recommended)
|
||||
"""
|
||||
channel_id = message.channel_id
|
||||
thread_id = message.thread_id
|
||||
@@ -582,26 +716,33 @@ def get_contextualized_thread_text(message: SlackMessage, access_token: str) ->
|
||||
thread_text += "\n..."
|
||||
break
|
||||
|
||||
# replace user ids with names in the thread text
|
||||
# replace user ids with names in the thread text using cached lookups
|
||||
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
|
||||
for userid in userids:
|
||||
try:
|
||||
response = slack_client.users_profile_get(user=userid)
|
||||
response.validate()
|
||||
profile: dict[str, Any] = response.get("profile", {})
|
||||
name: str | None = profile.get("real_name") or profile.get("email")
|
||||
except SlackApiError as e:
|
||||
# user_not_found is common for deleted users, bots, etc. - not critical
|
||||
if "user_not_found" in str(e):
|
||||
logger.debug(
|
||||
f"User {userid} not found in Slack workspace (likely deleted/deactivated)"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Could not fetch profile for user {userid}: {e}")
|
||||
continue
|
||||
if not name:
|
||||
continue
|
||||
thread_text = thread_text.replace(f"<@{userid}>", name)
|
||||
|
||||
if team_id:
|
||||
# Use cached batch lookup when team_id is available
|
||||
user_profiles = batch_get_user_profiles(access_token, team_id, userids)
|
||||
for userid, name in user_profiles.items():
|
||||
thread_text = thread_text.replace(f"<@{userid}>", name)
|
||||
else:
|
||||
# Fallback to individual lookups (no caching) when team_id not available
|
||||
for userid in userids:
|
||||
try:
|
||||
response = slack_client.users_profile_get(user=userid)
|
||||
response.validate()
|
||||
profile: dict[str, Any] = response.get("profile", {})
|
||||
user_name: str | None = profile.get("real_name") or profile.get("email")
|
||||
except SlackApiError as e:
|
||||
if "user_not_found" in str(e):
|
||||
logger.debug(
|
||||
f"User {userid} not found in Slack workspace (likely deleted/deactivated)"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Could not fetch profile for user {userid}: {e}")
|
||||
continue
|
||||
if not user_name:
|
||||
continue
|
||||
thread_text = thread_text.replace(f"<@{userid}>", user_name)
|
||||
|
||||
return thread_text
|
||||
|
||||
@@ -654,9 +795,9 @@ def slack_retrieval(
|
||||
entities = entities or {}
|
||||
|
||||
if not entities:
|
||||
logger.info("No entity configuration found, using defaults")
|
||||
logger.debug("No entity configuration found, using defaults")
|
||||
else:
|
||||
logger.info(f"Using entity configuration: {entities}")
|
||||
logger.debug(f"Using entity configuration: {entities}")
|
||||
|
||||
# Extract limit from entity config if not explicitly provided
|
||||
query_limit = limit
|
||||
@@ -665,7 +806,7 @@ def slack_retrieval(
|
||||
parsed_entities = SlackEntities(**entities)
|
||||
if limit is None:
|
||||
query_limit = parsed_entities.max_messages_per_query
|
||||
logger.info(f"Using max_messages_per_query from config: {query_limit}")
|
||||
logger.debug(f"Using max_messages_per_query from config: {query_limit}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error parsing entities for limit: {e}")
|
||||
if limit is None:
|
||||
@@ -703,7 +844,7 @@ def slack_retrieval(
|
||||
include_dm = True
|
||||
if channel_type == ChannelType.PRIVATE_CHANNEL:
|
||||
allowed_private_channel = slack_event_context.channel_id
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"Private channel context: will only allow messages from {allowed_private_channel} + public channels"
|
||||
)
|
||||
|
||||
@@ -721,6 +862,7 @@ def slack_retrieval(
|
||||
include_dm,
|
||||
entities,
|
||||
available_channels,
|
||||
channel_metadata_dict,
|
||||
),
|
||||
)
|
||||
for query_string in query_strings
|
||||
@@ -757,6 +899,7 @@ def slack_retrieval(
|
||||
include_dm,
|
||||
dm_entities,
|
||||
available_channels,
|
||||
channel_metadata_dict,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -764,19 +907,26 @@ def slack_retrieval(
|
||||
# Execute searches in parallel
|
||||
results = run_functions_tuples_in_parallel(search_tasks)
|
||||
|
||||
# Calculate stats for consolidated logging
|
||||
total_raw_messages = sum(len(r.messages) for r in results)
|
||||
|
||||
# Merge and post-filter results
|
||||
slack_messages, docid_to_message = merge_slack_messages(results)
|
||||
slack_messages, docid_to_message, query_filtered_channels = merge_slack_messages(
|
||||
results
|
||||
)
|
||||
messages_after_dedup = len(slack_messages)
|
||||
|
||||
# Post-filter by channel type (DM, private channel, etc.)
|
||||
# NOTE: We must post-filter because Slack's search.messages API only supports
|
||||
# filtering by channel NAME (via in:#channel syntax), not by channel TYPE.
|
||||
# There's no way to specify "only public channels" or "exclude DMs" in the query.
|
||||
# Start with channels filtered during query execution, then add post-filter channels
|
||||
filtered_out_channels: set[str] = set(query_filtered_channels)
|
||||
if entities and team_id:
|
||||
# Use pre-fetched channel metadata to avoid cache misses
|
||||
# Pass it directly instead of relying on Redis cache
|
||||
|
||||
filtered_messages = []
|
||||
removed_count = 0
|
||||
for msg in slack_messages:
|
||||
# Pass pre-fetched metadata to avoid cache lookups
|
||||
channel_type = get_channel_type(
|
||||
@@ -786,22 +936,37 @@ def slack_retrieval(
|
||||
if should_include_message(channel_type, entities):
|
||||
filtered_messages.append(msg)
|
||||
else:
|
||||
removed_count += 1
|
||||
# Track unique channel name for summary
|
||||
channel_name = msg.metadata.get("channel", msg.channel_id)
|
||||
filtered_out_channels.add(f"{channel_name}({msg.channel_id})")
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(
|
||||
f"Post-filtering removed {removed_count} messages: "
|
||||
f"{len(slack_messages)} -> {len(filtered_messages)}"
|
||||
)
|
||||
slack_messages = filtered_messages
|
||||
|
||||
slack_messages = slack_messages[: limit or len(slack_messages)]
|
||||
|
||||
# Log consolidated summary with request ID for correlation
|
||||
request_id = (
|
||||
slack_event_context.message_ts[:10]
|
||||
if slack_event_context and slack_event_context.message_ts
|
||||
else "no-ctx"
|
||||
)
|
||||
logger.info(
|
||||
f"[req:{request_id}] Slack federated search: {len(search_tasks)} queries, "
|
||||
f"{total_raw_messages} raw msgs -> {messages_after_dedup} after dedup -> "
|
||||
f"{len(slack_messages)} final"
|
||||
+ (
|
||||
f", filtered channels: {sorted(filtered_out_channels)}"
|
||||
if filtered_out_channels
|
||||
else ""
|
||||
)
|
||||
)
|
||||
|
||||
if not slack_messages:
|
||||
return []
|
||||
|
||||
thread_texts: list[str] = run_functions_tuples_in_parallel(
|
||||
[
|
||||
(get_contextualized_thread_text, (slack_message, access_token))
|
||||
(get_contextualized_thread_text, (slack_message, access_token, team_id))
|
||||
for slack_message in slack_messages
|
||||
]
|
||||
)
|
||||
|
||||
@@ -4,17 +4,16 @@ import re
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import ValidationError
|
||||
|
||||
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
|
||||
from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.federated_connectors.slack.models import SlackEntities
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
from onyx.prompts.federated_search import SLACK_DATE_EXTRACTION_PROMPT
|
||||
from onyx.prompts.federated_search import SLACK_QUERY_EXPANSION_PROMPT
|
||||
@@ -34,35 +33,25 @@ WORD_PUNCTUATION = ".,!?;:\"'#"
|
||||
|
||||
RECENCY_KEYWORDS = ["recent", "latest", "newest", "last"]
|
||||
|
||||
|
||||
class ChannelTypeString(str, Enum):
|
||||
"""String representations of Slack channel types."""
|
||||
|
||||
IM = "im"
|
||||
MPIM = "mpim"
|
||||
PRIVATE_CHANNEL = "private_channel"
|
||||
PUBLIC_CHANNEL = "public_channel"
|
||||
|
||||
|
||||
# All Slack channel types for fetching metadata
|
||||
ALL_CHANNEL_TYPES = [
|
||||
ChannelTypeString.PUBLIC_CHANNEL.value,
|
||||
ChannelTypeString.IM.value,
|
||||
ChannelTypeString.MPIM.value,
|
||||
ChannelTypeString.PRIVATE_CHANNEL.value,
|
||||
ChannelType.PUBLIC_CHANNEL.value,
|
||||
ChannelType.IM.value,
|
||||
ChannelType.MPIM.value,
|
||||
ChannelType.PRIVATE_CHANNEL.value,
|
||||
]
|
||||
|
||||
# Map Slack API scopes to their corresponding channel types
|
||||
# This is used for graceful degradation when scopes are missing
|
||||
SCOPE_TO_CHANNEL_TYPE_MAP = {
|
||||
"mpim:read": ChannelTypeString.MPIM.value,
|
||||
"mpim:history": ChannelTypeString.MPIM.value,
|
||||
"im:read": ChannelTypeString.IM.value,
|
||||
"im:history": ChannelTypeString.IM.value,
|
||||
"groups:read": ChannelTypeString.PRIVATE_CHANNEL.value,
|
||||
"groups:history": ChannelTypeString.PRIVATE_CHANNEL.value,
|
||||
"channels:read": ChannelTypeString.PUBLIC_CHANNEL.value,
|
||||
"channels:history": ChannelTypeString.PUBLIC_CHANNEL.value,
|
||||
"mpim:read": ChannelType.MPIM.value,
|
||||
"mpim:history": ChannelType.MPIM.value,
|
||||
"im:read": ChannelType.IM.value,
|
||||
"im:history": ChannelType.IM.value,
|
||||
"groups:read": ChannelType.PRIVATE_CHANNEL.value,
|
||||
"groups:history": ChannelType.PRIVATE_CHANNEL.value,
|
||||
"channels:read": ChannelType.PUBLIC_CHANNEL.value,
|
||||
"channels:history": ChannelType.PUBLIC_CHANNEL.value,
|
||||
}
|
||||
|
||||
|
||||
@@ -201,9 +190,7 @@ def extract_date_range_from_query(
|
||||
|
||||
try:
|
||||
prompt = SLACK_DATE_EXTRACTION_PROMPT.format(query=query)
|
||||
response = message_to_string(
|
||||
llm.invoke_langchain([HumanMessage(content=prompt)])
|
||||
)
|
||||
response = llm_response_to_string(llm.invoke(prompt))
|
||||
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
@@ -334,7 +321,7 @@ def build_channel_query_filter(
|
||||
def get_channel_type(
|
||||
channel_info: dict[str, Any] | None = None,
|
||||
channel_id: str | None = None,
|
||||
channel_metadata: dict[str, dict[str, Any]] | None = None,
|
||||
channel_metadata: dict[str, ChannelMetadata] | None = None,
|
||||
) -> ChannelType:
|
||||
"""
|
||||
Determine channel type from channel info dict or by looking up channel_id.
|
||||
@@ -361,11 +348,11 @@ def get_channel_type(
|
||||
ch_meta = channel_metadata.get(channel_id)
|
||||
if ch_meta:
|
||||
type_str = ch_meta.get("type")
|
||||
if type_str == ChannelTypeString.IM.value:
|
||||
if type_str == ChannelType.IM.value:
|
||||
return ChannelType.IM
|
||||
elif type_str == ChannelTypeString.MPIM.value:
|
||||
elif type_str == ChannelType.MPIM.value:
|
||||
return ChannelType.MPIM
|
||||
elif type_str == ChannelTypeString.PRIVATE_CHANNEL.value:
|
||||
elif type_str == ChannelType.PRIVATE_CHANNEL.value:
|
||||
return ChannelType.PRIVATE_CHANNEL
|
||||
return ChannelType.PUBLIC_CHANNEL
|
||||
|
||||
@@ -594,9 +581,7 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
)
|
||||
|
||||
try:
|
||||
response = message_to_string(
|
||||
llm.invoke_langchain([HumanMessage(content=prompt)])
|
||||
)
|
||||
response = llm_response_to_string(llm.invoke(prompt))
|
||||
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
@@ -610,7 +595,7 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
logger.debug("No content keywords extracted from query expansion")
|
||||
return [""]
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"Expanded query into {len(rephrased_queries)} queries: {rephrased_queries}"
|
||||
)
|
||||
return rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
An explanation of how the history of messages, tool calls, and docs are stored in the database:
|
||||
|
||||
Messages are grouped by a chat session, a tree structured is used to allow edits and for the
|
||||
user to switch between branches. Each ChatMessage is either a user message of an assistant message.
|
||||
user to switch between branches. Each ChatMessage is either a user message or an assistant message.
|
||||
It should always alternate between the two, System messages, custom agent prompt injections, and
|
||||
reminder messages are injected dynamically after the chat session is loaded into memory. The user
|
||||
and assistant messages are stored in pairs, though it is ok if the user message is stored and the
|
||||
|
||||
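A minimal sketch of the message tree the paragraph above describes (illustrative
only; the field names here are invented and do not mirror the actual ChatMessage
ORM model): each message points at its parent, an edit adds a sibling under the
same parent, and "switching branches" just means walking a different child chain
down from that point.

    from dataclasses import dataclass, field

    @dataclass
    class MessageNode:
        role: str                      # "user" or "assistant", alternating by depth
        text: str
        parent: "MessageNode | None" = None
        children: list["MessageNode"] = field(default_factory=list)

        def add_child(self, role: str, text: str) -> "MessageNode":
            child = MessageNode(role=role, text=text, parent=self)
            self.children.append(child)
            return child

    roots: list[MessageNode] = []      # one entry per branch of the first user turn
    q1 = MessageNode(role="user", text="How do I reset my password?")
    roots.append(q1)
    a1 = q1.add_child("assistant", "Settings -> Security -> Reset password.")
    # Editing q1 starts a sibling branch rather than overwriting the original:
    q1_edit = MessageNode(role="user", text="How do I reset my 2FA device?")
    roots.append(q1_edit)
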
@@ -17,8 +17,10 @@ from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.app_configs import USER_FILE_INDEXING_LIMIT
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
@@ -445,10 +447,12 @@ def set_cc_pair_repeated_error_state(
|
||||
values: dict = {"in_repeated_error_state": in_repeated_error_state}
|
||||
|
||||
# When entering repeated error state, also pause the connector
|
||||
# to prevent continued indexing retry attempts.
|
||||
# to prevent continued indexing retry attempts burning through embedding credits.
|
||||
# However, don't pause if there's an active manual indexing trigger,
|
||||
# which indicates the user wants to retry immediately.
|
||||
if in_repeated_error_state:
|
||||
# NOTE: only for Cloud, since most self-hosted users use self-hosted embedding
|
||||
# models. Also, they are more prone to repeated failures -> eventual success.
|
||||
if in_repeated_error_state and AUTH_TYPE == AuthType.CLOUD:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
|
||||
@@ -251,15 +251,27 @@ def upsert_llm_provider(
|
||||
|
||||
db_session.flush()
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
|
||||
for model_configuration in llm_provider_upsert_request.model_configurations:
|
||||
# If max_input_tokens is not provided, look it up from LiteLLM
|
||||
max_input_tokens = model_configuration.max_input_tokens
|
||||
if max_input_tokens is None:
|
||||
max_input_tokens = get_max_input_tokens(
|
||||
model_name=model_configuration.name,
|
||||
model_provider=llm_provider_upsert_request.provider,
|
||||
)
|
||||
|
||||
db_session.execute(
|
||||
insert(ModelConfiguration)
|
||||
.values(
|
||||
llm_provider_id=existing_llm_provider.id,
|
||||
name=model_configuration.name,
|
||||
is_visible=model_configuration.is_visible,
|
||||
max_input_tokens=model_configuration.max_input_tokens,
|
||||
max_input_tokens=max_input_tokens,
|
||||
supports_image_input=model_configuration.supports_image_input,
|
||||
display_name=model_configuration.display_name,
|
||||
)
|
||||
.on_conflict_do_nothing()
|
||||
)
|
||||
@@ -289,6 +301,56 @@ def upsert_llm_provider(
|
||||
return full_llm_provider
|
||||
|
||||
|
||||
def sync_model_configurations(
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
models: list[dict],
|
||||
) -> int:
|
||||
"""Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama).
|
||||
|
||||
This inserts NEW models from the source API without overwriting existing ones.
|
||||
User preferences (is_visible, max_input_tokens) are preserved for existing models.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
provider_name: Name of the LLM provider
|
||||
models: List of model dicts with keys: name, display_name, max_input_tokens, supports_image_input
|
||||
|
||||
Returns:
|
||||
Number of new models added
|
||||
"""
|
||||
provider = fetch_existing_llm_provider(name=provider_name, db_session=db_session)
|
||||
if not provider:
|
||||
raise ValueError(f"LLM Provider '{provider_name}' not found")
|
||||
|
||||
# Get existing model names to count new additions
|
||||
existing_names = {mc.name for mc in provider.model_configurations}
|
||||
|
||||
new_count = 0
|
||||
for model in models:
|
||||
model_name = model["name"]
|
||||
if model_name not in existing_names:
|
||||
# Insert new model with is_visible=False (user must explicitly enable)
|
||||
db_session.execute(
|
||||
insert(ModelConfiguration)
|
||||
.values(
|
||||
llm_provider_id=provider.id,
|
||||
name=model_name,
|
||||
is_visible=False,
|
||||
max_input_tokens=model.get("max_input_tokens"),
|
||||
supports_image_input=model.get("supports_image_input", False),
|
||||
display_name=model.get("display_name"),
|
||||
)
|
||||
.on_conflict_do_nothing()
|
||||
)
|
||||
new_count += 1
|
||||
|
||||
if new_count > 0:
|
||||
db_session.commit()
|
||||
|
||||
return new_count
|
||||
|
||||
|
||||
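For orientation, a hedged usage sketch of sync_model_configurations (the provider name, model payload, and session object are assumptions for illustration; real callers build the model list from the provider's API):

# Illustrative only: the model dicts would normally come from e.g. the OpenRouter model list API.
new_models = [
    {
        "name": "openrouter/example-model",          # hypothetical model id
        "display_name": "Example Model",
        "max_input_tokens": 128000,
        "supports_image_input": False,
    }
]
added = sync_model_configurations(
    db_session=db_session,          # an existing SQLAlchemy Session
    provider_name="OpenRouter",     # must match an existing LLMProvider name
    models=new_models,
)
logger.info(f"Synced models: {added} new configuration(s) added")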
def fetch_existing_embedding_providers(
|
||||
db_session: Session,
|
||||
) -> list[CloudEmbeddingProviderModel]:
|
||||
|
||||
@@ -2141,6 +2141,8 @@ class ChatMessage(Base):
|
||||
time_sent: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
# True if this assistant message is a clarification question (deep research flow)
|
||||
is_clarification: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# Relationships
|
||||
chat_session: Mapped[ChatSession] = relationship("ChatSession")
|
||||
@@ -2437,6 +2439,11 @@ class ModelConfiguration(Base):
|
||||
|
||||
supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
|
||||
# Human-readable display name for the model.
|
||||
# For dynamic providers (OpenRouter, Bedrock, Ollama), this comes from the source API.
|
||||
# For static providers (OpenAI, Anthropic), this may be null and will fall back to LiteLLM.
|
||||
display_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
llm_provider: Mapped["LLMProvider"] = relationship(
|
||||
"LLMProvider",
|
||||
back_populates="model_configurations",
|
||||
|
||||
@@ -41,7 +41,7 @@ from onyx.server.features.persona.models import MinimalPersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
||||
from onyx.server.features.persona.models import PersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from onyx.server.features.tool.models import should_expose_tool_to_fe
|
||||
from onyx.server.features.tool.tool_visibility import should_expose_tool_to_fe
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
@@ -354,6 +354,17 @@ def _build_persona_filters(
|
||||
include_slack_bot_personas: bool,
|
||||
include_deleted: bool,
|
||||
) -> Select[tuple[Persona]]:
|
||||
"""Filters which Personas are included in the query.
|
||||
|
||||
Args:
|
||||
stmt: The base query to filter.
|
||||
include_default: If True, includes builtin/default personas.
|
||||
include_slack_bot_personas: If True, includes Slack bot personas.
|
||||
include_deleted: If True, includes deleted personas.
|
||||
|
||||
Returns:
|
||||
The modified query with the filters applied.
|
||||
"""
|
||||
if not include_default:
|
||||
stmt = stmt.where(Persona.builtin_persona.is_(False))
|
||||
if not include_slack_bot_personas:
|
||||
@@ -405,6 +416,9 @@ def get_persona_snapshots_for_user(
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.user),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.users),
|
||||
selectinload(Persona.groups),
|
||||
)
|
||||
|
||||
results = db_session.scalars(stmt).all()
|
||||
@@ -648,15 +662,14 @@ def get_raw_personas_for_user(
|
||||
include_slack_bot_personas: bool = False,
|
||||
include_deleted: bool = False,
|
||||
) -> Sequence[Persona]:
|
||||
stmt = select(Persona)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
stmt = _build_persona_filters(
|
||||
stmt, include_default, include_slack_bot_personas, include_deleted
|
||||
stmt = _build_persona_base_query(
|
||||
user, get_editable, include_default, include_slack_bot_personas, include_deleted
|
||||
)
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def get_personas(db_session: Session) -> Sequence[Persona]:
|
||||
"""WARNING: Unsafe, can fetch personas from all users."""
|
||||
stmt = select(Persona).distinct()
|
||||
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
|
||||
stmt = stmt.where(Persona.deleted.is_(False))
|
||||
@@ -701,19 +714,54 @@ def mark_delete_persona_by_name(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_all_personas_display_priority(
|
||||
def update_personas_display_priority(
|
||||
display_priority_map: dict[int, int],
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
commit_db_txn: bool = False,
|
||||
) -> None:
|
||||
"""Updates the display priority of all lives Personas"""
|
||||
personas = get_personas(db_session=db_session)
|
||||
available_persona_ids = {persona.id for persona in personas}
|
||||
if available_persona_ids != set(display_priority_map.keys()):
|
||||
raise ValueError("Invalid persona IDs provided")
|
||||
"""Updates the display priorities of the specified Personas.
|
||||
|
||||
for persona in personas:
|
||||
persona.display_priority = display_priority_map[persona.id]
|
||||
db_session.commit()
|
||||
Args:
|
||||
display_priority_map: A map of persona IDs to intended display
|
||||
priorities.
|
||||
db_session: Database session for executing queries.
|
||||
user: The user to filter personas for. If None and auth is disabled,
|
||||
assumes the user is an admin. Otherwise, if None shows only public
|
||||
personas.
|
||||
commit_db_txn: If True, commits the database transaction after
|
||||
updating the display priorities. Defaults to False.
|
||||
|
||||
Raises:
|
||||
ValueError: The caller tried to update a persona for which the user does
|
||||
not have access.
|
||||
"""
|
||||
# No-op to save a query if it is not necessary.
|
||||
if len(display_priority_map) == 0:
|
||||
return
|
||||
|
||||
personas = get_raw_personas_for_user(
|
||||
user,
|
||||
db_session,
|
||||
get_editable=False,
|
||||
include_default=True,
|
||||
include_slack_bot_personas=True,
|
||||
include_deleted=True,
|
||||
)
|
||||
available_personas_map: dict[int, Persona] = {
|
||||
persona.id: persona for persona in personas
|
||||
}
|
||||
|
||||
for persona_id, priority in display_priority_map.items():
|
||||
if persona_id not in available_personas_map:
|
||||
raise ValueError(
|
||||
f"Invalid persona ID provided: Persona with ID {persona_id} was not found for this user."
|
||||
)
|
||||
|
||||
available_personas_map[persona_id].display_priority = priority
|
||||
|
||||
if commit_db_txn:
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def upsert_persona(
|
||||
@@ -1053,7 +1101,7 @@ def get_persona_by_id(
|
||||
def get_personas_by_ids(
|
||||
persona_ids: list[int], db_session: Session
|
||||
) -> Sequence[Persona]:
|
||||
"""Unsafe, can fetch personas from all users"""
|
||||
"""WARNING: Unsafe, can fetch personas from all users."""
|
||||
if not persona_ids:
|
||||
return []
|
||||
personas = db_session.scalars(
|
||||
|
||||
@@ -4,6 +4,7 @@ from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -43,17 +44,26 @@ def create_or_add_document_tag(
|
||||
if not document:
|
||||
raise ValueError("Invalid Document, cannot attach Tags")
|
||||
|
||||
# Use upsert to avoid race condition when multiple workers try to create the same tag
|
||||
insert_stmt = pg_insert(Tag).values(
|
||||
tag_key=tag_key,
|
||||
tag_value=tag_value,
|
||||
source=source,
|
||||
is_list=False,
|
||||
)
|
||||
insert_stmt = insert_stmt.on_conflict_do_nothing(
|
||||
index_elements=["tag_key", "tag_value", "source", "is_list"]
|
||||
)
|
||||
db_session.execute(insert_stmt)
|
||||
|
||||
# Now fetch the tag (either just inserted or already existed)
|
||||
tag_stmt = select(Tag).where(
|
||||
Tag.tag_key == tag_key,
|
||||
Tag.tag_value == tag_value,
|
||||
Tag.source == source,
|
||||
Tag.is_list.is_(False),
|
||||
)
|
||||
tag = db_session.execute(tag_stmt).scalar_one_or_none()
|
||||
|
||||
if not tag:
|
||||
tag = Tag(tag_key=tag_key, tag_value=tag_value, source=source, is_list=False)
|
||||
db_session.add(tag)
|
||||
tag = db_session.execute(tag_stmt).scalar_one()
|
||||
|
||||
if tag not in document.tags:
|
||||
document.tags.append(tag)
|
||||
@@ -79,31 +89,27 @@ def create_or_add_document_tag_list(
|
||||
if not document:
|
||||
raise ValueError("Invalid Document, cannot attach Tags")
|
||||
|
||||
existing_tags_stmt = select(Tag).where(
|
||||
# Use upsert to avoid race condition when multiple workers try to create the same tags
|
||||
for tag_value in valid_tag_values:
|
||||
insert_stmt = pg_insert(Tag).values(
|
||||
tag_key=tag_key,
|
||||
tag_value=tag_value,
|
||||
source=source,
|
||||
is_list=True,
|
||||
)
|
||||
insert_stmt = insert_stmt.on_conflict_do_nothing(
|
||||
index_elements=["tag_key", "tag_value", "source", "is_list"]
|
||||
)
|
||||
db_session.execute(insert_stmt)
|
||||
|
||||
# Now fetch all tags (either just inserted or already existed)
|
||||
all_tags_stmt = select(Tag).where(
|
||||
Tag.tag_key == tag_key,
|
||||
Tag.tag_value.in_(valid_tag_values),
|
||||
Tag.source == source,
|
||||
Tag.is_list.is_(True),
|
||||
)
|
||||
existing_tags = list(db_session.execute(existing_tags_stmt).scalars().all())
|
||||
existing_tag_values = {tag.tag_value for tag in existing_tags}
|
||||
|
||||
new_tags = []
|
||||
for tag_value in valid_tag_values:
|
||||
if tag_value not in existing_tag_values:
|
||||
new_tag = Tag(
|
||||
tag_key=tag_key, tag_value=tag_value, source=source, is_list=True
|
||||
)
|
||||
db_session.add(new_tag)
|
||||
new_tags.append(new_tag)
|
||||
existing_tag_values.add(tag_value)
|
||||
|
||||
if new_tags:
|
||||
logger.debug(
|
||||
f"Created new tags: {', '.join([f'{tag.tag_key}:{tag.tag_value}' for tag in new_tags])}"
|
||||
)
|
||||
|
||||
all_tags = existing_tags + new_tags
|
||||
all_tags = list(db_session.execute(all_tags_stmt).scalars().all())
|
||||
|
||||
for tag in all_tags:
|
||||
if tag not in document.tags:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Type
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -49,7 +50,12 @@ def get_tools(
|
||||
query = query.where(Tool.enabled.is_(True))
|
||||
|
||||
if only_openapi:
|
||||
query = query.where(Tool.openapi_schema.is_not(None))
|
||||
query = query.where(
|
||||
Tool.openapi_schema.is_not(None),
|
||||
# Avoid showing the user rows that have the JSON literal `null` stored in this column.
# Tools from MCP servers have no OpenAPI schema but store `null` here, so exclude them.
|
||||
func.jsonb_typeof(Tool.openapi_schema) == "object",
|
||||
)
|
||||
|
||||
return list(db_session.scalars(query).all())
|
||||
|
||||
|
||||
0    backend/onyx/deep_research/__init__.py    Normal file
254  backend/onyx/deep_research/dr_loop.py     Normal file
@@ -0,0 +1,254 @@
|
||||
# TODO: Notes for potential extensions and future improvements:
|
||||
# 1. Allow tools that aren't search specific tools
|
||||
# 2. Use user provided custom prompts
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.llm_loop import construct_message_history
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.models import ToolChoiceOptions
|
||||
from onyx.llm.utils import model_is_reasoning_model
|
||||
from onyx.prompts.deep_research.orchestration_layer import CLARIFICATION_PROMPT
|
||||
from onyx.prompts.deep_research.orchestration_layer import ORCHESTRATOR_PROMPT
|
||||
from onyx.prompts.deep_research.orchestration_layer import ORCHESTRATOR_PROMPT_REASONING
|
||||
from onyx.prompts.deep_research.orchestration_layer import RESEARCH_PLAN_PROMPT
|
||||
from onyx.prompts.prompt_utils import get_current_llm_day_time
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import DeepResearchPlanDelta
|
||||
from onyx.server.query_and_chat.streaming_models import DeepResearchPlanStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
MAX_USER_MESSAGES_FOR_CONTEXT = 5
|
||||
MAX_ORCHESTRATOR_CYCLES = 8
|
||||
|
||||
|
||||
def run_deep_research_llm_loop(
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
simple_chat_history: list[ChatMessageSimple],
|
||||
tools: list[Tool],
|
||||
custom_agent_prompt: str | None,
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
skip_clarification: bool = False,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
) -> None:
|
||||
# Here for lazy load LiteLLM
|
||||
from onyx.llm.litellm_singleton.config import initialize_litellm
|
||||
|
||||
# An approximate limit. In extreme cases it may still fail but this should allow deep research
|
||||
# to work in most cases.
|
||||
if llm.config.max_input_tokens < 25000:
|
||||
raise RuntimeError(
|
||||
"Cannot run Deep Research with an LLM that has less than 25,000 max input tokens"
|
||||
)
|
||||
|
||||
initialize_litellm()
|
||||
|
||||
available_tokens = llm.config.max_input_tokens
|
||||
|
||||
llm_step_result: LlmStepResult | None = None
|
||||
|
||||
# Filter tools to only allow web search, internal search, and open URL
|
||||
allowed_tool_names = {SearchTool.NAME, WebSearchTool.NAME, OpenURLTool.NAME}
|
||||
tools = [tool for tool in tools if tool.name in allowed_tool_names]
|
||||
|
||||
#########################################################
|
||||
# CLARIFICATION STEP (optional)
|
||||
#########################################################
|
||||
if not skip_clarification:
|
||||
clarification_prompt = CLARIFICATION_PROMPT.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False)
|
||||
)
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=clarification_prompt,
|
||||
token_count=300, # Skips the exact token count but has enough leeway
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
)
|
||||
|
||||
step_generator = run_llm_step(
|
||||
history=truncated_message_history,
|
||||
tool_definitions=get_clarification_tool_definitions(),
|
||||
tool_choice=ToolChoiceOptions.AUTO,
|
||||
llm=llm,
|
||||
turn_index=0,
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=DynamicCitationProcessor(),
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
|
||||
# Consume the generator, emitting packets and capturing the final result
|
||||
while True:
|
||||
try:
|
||||
packet = next(step_generator)
|
||||
emitter.emit(packet)
|
||||
except StopIteration as e:
|
||||
llm_step_result, _ = e.value
|
||||
break
|
||||
|
||||
# Type narrowing: generator always returns a result, so this can't be None
|
||||
llm_step_result = cast(LlmStepResult, llm_step_result)
|
||||
|
||||
if not llm_step_result.tool_calls:
|
||||
# Mark this turn as a clarification question
|
||||
state_container.set_is_clarification(True)
|
||||
|
||||
emitter.emit(Packet(turn_index=0, obj=OverallStop(type="stop")))
|
||||
|
||||
# If a clarification is asked, we need to end this turn and wait on user input
|
||||
return
|
||||
|
||||
#########################################################
|
||||
# RESEARCH PLAN STEP
|
||||
#########################################################
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=RESEARCH_PLAN_PROMPT.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False)
|
||||
),
|
||||
token_count=300,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
)
|
||||
|
||||
research_plan_generator = run_llm_step(
|
||||
history=truncated_message_history,
|
||||
tool_definitions=[],
|
||||
tool_choice=ToolChoiceOptions.NONE,
|
||||
llm=llm,
|
||||
turn_index=0,
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=DynamicCitationProcessor(),
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
packet = next(research_plan_generator)
|
||||
# Translate AgentResponseStart/Delta packets to DeepResearchPlanStart/Delta
|
||||
if isinstance(packet.obj, AgentResponseStart):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
turn_index=packet.turn_index,
|
||||
obj=DeepResearchPlanStart(),
|
||||
)
|
||||
)
|
||||
elif isinstance(packet.obj, AgentResponseDelta):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
turn_index=packet.turn_index,
|
||||
obj=DeepResearchPlanDelta(content=packet.obj.content),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Pass through other packet types (e.g., ReasoningStart, ReasoningDelta, etc.)
|
||||
emitter.emit(packet)
|
||||
except StopIteration as e:
|
||||
llm_step_result, _ = e.value
|
||||
break
|
||||
llm_step_result = cast(LlmStepResult, llm_step_result)
|
||||
|
||||
research_plan = llm_step_result.answer
|
||||
|
||||
#########################################################
|
||||
# RESEARCH EXECUTION STEP
|
||||
#########################################################
|
||||
is_reasoning_model = model_is_reasoning_model(
|
||||
llm.config.model_name, llm.config.model_provider
|
||||
)
|
||||
|
||||
orchestrator_prompt_template = (
|
||||
ORCHESTRATOR_PROMPT if not is_reasoning_model else ORCHESTRATOR_PROMPT_REASONING
|
||||
)
|
||||
|
||||
token_count_prompt = orchestrator_prompt_template.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
current_cycle_count=1,
|
||||
max_cycles=MAX_ORCHESTRATOR_CYCLES,
|
||||
research_plan=research_plan,
|
||||
)
|
||||
orchestration_tokens = token_counter(token_count_prompt)
|
||||
|
||||
for cycle in range(MAX_ORCHESTRATOR_CYCLES):
|
||||
orchestrator_prompt = orchestrator_prompt_template.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
current_cycle_count=cycle,
|
||||
max_cycles=MAX_ORCHESTRATOR_CYCLES,
|
||||
research_plan=research_plan,
|
||||
)
|
||||
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=orchestrator_prompt,
|
||||
token_count=orchestration_tokens,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
)
|
||||
|
||||
research_plan_generator = run_llm_step(
|
||||
history=truncated_message_history,
|
||||
tool_definitions=[],
|
||||
tool_choice=ToolChoiceOptions.AUTO,
|
||||
llm=llm,
|
||||
turn_index=cycle,
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=DynamicCitationProcessor(),
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
18   backend/onyx/deep_research/dr_mock_tools.py   Normal file
@@ -0,0 +1,18 @@
|
||||
GENERATE_PLAN_TOOL_NAME = "generate_plan"
|
||||
|
||||
|
||||
def get_clarification_tool_definitions() -> list[dict]:
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": GENERATE_PLAN_TOOL_NAME,
|
||||
"description": "No clarification needed, generate a research plan for the user's query.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
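To show how this mock tool definition is meant to be consumed, here is a sketch of one way to interpret the model's response (the loop in dr_loop.py above simply checks whether any tool call was made at all):

# Sketch: interpret a tool call to generate_plan as "no clarification needed".
# The tool-call dicts follow the OpenAI-style shape ({"function": {"name": ...}}),
# which is an assumption for this illustration.
def model_wants_research_plan(tool_calls: list[dict]) -> bool:
    return any(
        call.get("function", {}).get("name") == GENERATE_PLAN_TOOL_NAME
        for call in tool_calls
    )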
325  backend/onyx/document_index/interfaces_new.py   Normal file
@@ -0,0 +1,325 @@
|
||||
import abc
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.context.search.enums import QueryType
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
# NOTE: "Document" in the naming convention is used to refer to the entire document as represented in Onyx.
|
||||
# What is actually stored in the index is the document chunks. By the terminology of most search engines / vector
|
||||
# databases, the individual objects stored are called documents, but in this case it refers to a chunk.
|
||||
|
||||
# Outside of searching and update capabilities, the document index must also implement the ability to port all of
|
||||
# the documents over to a secondary index. This allows for embedding models to be updated and for porting documents
|
||||
# to happen in the background while the primary index still serves the main traffic.
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Main interfaces - these are what you should inherit from
|
||||
"DocumentIndex",
|
||||
# Data models - used in method signatures
|
||||
"DocumentInsertionRecord",
|
||||
"DocumentSectionRequest",
|
||||
"IndexingMetadata",
|
||||
"MetadataUpdateRequest",
|
||||
# Capability mixins - for custom compositions or type checking
|
||||
"SchemaVerifiable",
|
||||
"Indexable",
|
||||
"Deletable",
|
||||
"Updatable",
|
||||
"IdRetrievalCapable",
|
||||
"HybridCapable",
|
||||
"RandomCapable",
|
||||
]
|
||||
|
||||
|
||||
class DocumentInsertionRecord(BaseModel):
|
||||
"""
|
||||
Result of indexing a document
|
||||
"""
|
||||
|
||||
model_config = {"frozen": True}
|
||||
|
||||
document_id: str
|
||||
already_existed: bool
|
||||
|
||||
|
||||
class DocumentSectionRequest(BaseModel):
|
||||
"""
|
||||
Request for a document section or whole document
|
||||
If no min_chunk_ind is provided it should start at the beginning of the document
|
||||
If no max_chunk_ind is provided it should go to the end of the document
|
||||
"""
|
||||
|
||||
model_config = {"frozen": True}
|
||||
|
||||
document_id: str
|
||||
min_chunk_ind: int | None = None
|
||||
max_chunk_ind: int | None = None
|
||||
|
||||
|
||||
class IndexingMetadata(BaseModel):
|
||||
"""
|
||||
Information about chunk counts for efficient cleaning / updating of document chunks. A common pattern to ensure
|
||||
that no chunks are left over is to delete all of the chunks for a document and then re-index the document. This
|
||||
information allows us to only delete the extra "tail" chunks when the document has gotten shorter.
|
||||
"""
|
||||
|
||||
# The tuple is (old_chunk_cnt, new_chunk_cnt)
|
||||
doc_id_to_chunk_cnt_diff: dict[str, tuple[int, int]]
|
||||
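As a small worked example of the "tail chunk" cleanup this metadata enables (editorial sketch, not part of the interface):

# Sketch: when a document shrinks from old_cnt to new_cnt chunks, only the trailing
# chunk indices [new_cnt, old_cnt) are stale and need to be deleted.
def stale_tail_chunk_indices(old_cnt: int, new_cnt: int) -> range:
    return range(new_cnt, old_cnt) if old_cnt > new_cnt else range(0)

# e.g. stale_tail_chunk_indices(10, 7) -> chunks 7, 8, 9 should be removed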
|
||||
|
||||
class MetadataUpdateRequest(BaseModel):
|
||||
"""
|
||||
Updates to the documents that can happen without there being an update to the contents of the document.
|
||||
"""
|
||||
|
||||
document_ids: list[str]
|
||||
# Passed in to help with potential optimizations of the implementation
|
||||
doc_id_to_chunk_cnt: dict[str, int]
|
||||
# For the ones that are None, there is no update required to that field
|
||||
access: DocumentAccess | None = None
|
||||
document_sets: set[str] | None = None
|
||||
boost: float | None = None
|
||||
hidden: bool | None = None
|
||||
secondary_index_updated: bool | None = None
|
||||
project_ids: set[int] | None = None
|
||||
|
||||
|
||||
class SchemaVerifiable(abc.ABC):
|
||||
"""
|
||||
Class must implement document index schema verification. For example, verify that all of the
|
||||
necessary attributes for indexing, querying, filtering, and fields to return from search are
|
||||
all valid in the schema.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
tenant_id: int | None,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.index_name = index_name
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
@abc.abstractmethod
|
||||
def verify_and_create_index_if_necessary(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
) -> None:
|
||||
"""
|
||||
Verify that the document index exists and is consistent with the expectations in the code. For certain search
|
||||
engines, the schema needs to be created before indexing can happen. This call should create the schema if it
|
||||
does not exist.
|
||||
|
||||
Parameters:
|
||||
- embedding_dim: Vector dimensionality for the vector similarity part of the search
|
||||
- embedding_precision: Precision of the vector similarity part of the search
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Indexable(abc.ABC):
|
||||
"""
|
||||
Class must implement the ability to index document chunks
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self,
|
||||
chunks: Iterator[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
"""
|
||||
Takes a list of document chunks and indexes them in the document index. This is often a batch operation
|
||||
including chunks from multiple documents.
|
||||
|
||||
NOTE: When a document is reindexed/updated here and has gotten shorter, it is important to delete the extra
|
||||
chunks at the end to ensure there are no stale chunks in the index.
|
||||
|
||||
NOTE: The chunks of a document are never split across separate index() calls. So there is
|
||||
no worry of receiving the first 0 through n chunks in one index call and the next n through
|
||||
m chunks of a document in the next index call.
|
||||
|
||||
Parameters:
|
||||
- chunks: Document chunks with all of the information needed for indexing to the document index.
|
||||
- indexing_metadata: Information about chunk counts for efficient cleaning / updating
|
||||
|
||||
Returns:
|
||||
List of document ids which map to unique documents and are used for deduping chunks
|
||||
when updating, as well as if the document is newly indexed or already existed and
|
||||
just updated
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Deletable(abc.ABC):
|
||||
"""
|
||||
Class must implement the ability to delete a document by a given unique document id. Note that the document id is the
|
||||
unique identifier for the document as represented in Onyx, not in the document index.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(
|
||||
self,
|
||||
db_doc_id: str,
|
||||
*,
|
||||
# Passed in in case it helps the efficiency of the delete implementation
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
"""
|
||||
Given a single document, hard delete all of the chunks for the document from the document index
|
||||
|
||||
Parameters:
|
||||
- db_doc_id: document id as represented in Onyx
|
||||
- chunk_count: number of chunks in the document
|
||||
|
||||
Returns:
|
||||
number of chunks deleted
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Updatable(abc.ABC):
|
||||
"""
|
||||
Class must implement the ability to update certain attributes of a document without needing to
|
||||
update all of the fields. Specifically, needs to be able to update:
|
||||
- Access Control List
|
||||
- Document-set membership
|
||||
- Boost value (learning from feedback mechanism)
|
||||
- Whether the document is hidden or not, hidden documents are not returned from search
|
||||
- Which Projects the document is a part of
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, update_requests: list[MetadataUpdateRequest]) -> None:
|
||||
"""
|
||||
Updates some set of chunks. The document and fields to update are specified in the update
|
||||
requests. Each update request in the list applies its changes to a list of document ids.
|
||||
None values mean that the field does not need an update.
|
||||
|
||||
Parameters:
|
||||
- update_requests: for a list of document ids in the update request, apply the same updates
|
||||
to all of the documents with those ids. This is for bulk handling efficiency. Many
|
||||
updates are done at the connector level, which can cover many documents for that connector
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IdRetrievalCapable(abc.ABC):
|
||||
"""
|
||||
Class must implement the ability to retrieve either:
|
||||
- All of the chunks of a document IN ORDER given a document id. Caller assumes it to be in order.
|
||||
- A specific section (continuous set of chunks) for some document.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def id_based_retrieval(
|
||||
self,
|
||||
chunk_requests: list[DocumentSectionRequest],
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
Fetch chunk(s) based on document id
|
||||
|
||||
NOTE: This is used to reconstruct a full document or an extended (multi-chunk) section
|
||||
of a document. Downstream currently assumes that the chunking does not introduce overlaps
|
||||
between the chunks. If there are overlaps for the chunks, then the reconstructed document
|
||||
or extended section will have duplicate segments.
|
||||
|
||||
NOTE: This should be used after a search call to get more context around returned chunks.
|
||||
There are no filters here since the calling code should not be calling this on arbitrary
|
||||
documents.
|
||||
|
||||
Parameters:
|
||||
- chunk_requests: requests containing the document id and the chunk range to retrieve
|
||||
|
||||
Returns:
|
||||
list of sections from the documents specified
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class HybridCapable(abc.ABC):
|
||||
"""
|
||||
Class must implement hybrid (keyword + vector) search functionality
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def hybrid_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
query_embedding: Embedding,
|
||||
final_keywords: list[str] | None,
|
||||
query_type: QueryType,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
Run hybrid search and return a list of inference chunks.
|
||||
|
||||
Parameters:
|
||||
- query: unmodified user query. This may be needed for getting the matching highlighted
|
||||
keywords or for logging purposes
|
||||
- query_embedding: vector representation of the query, must be of the correct
|
||||
dimensionality for the primary index
|
||||
- final_keywords: Final keywords to be used from the query, defaults to query if not set
|
||||
- query_type: Semantic or keyword type query, may use different scoring logic for each
|
||||
- filters: Filters for things like permissions, source type, time, etc.
|
||||
- num_to_retrieve: number of highest matching chunks to return
|
||||
- offset: number of highest matching chunks to skip (kind of like pagination)
|
||||
|
||||
Returns:
|
||||
Score ranked (highest first) list of highest matching chunks
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RandomCapable(abc.ABC):
|
||||
"""Class must implement random document retrieval capability.
|
||||
This is currently just used for porting the documents to a secondary index."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def random_retrieval(
|
||||
self,
|
||||
filters: IndexFilters | None = None,
|
||||
num_to_retrieve: int = 100,
|
||||
dirty: bool | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Retrieve random chunks matching the filters"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DocumentIndex(
|
||||
SchemaVerifiable,
|
||||
Indexable,
|
||||
Updatable,
|
||||
Deletable,
|
||||
HybridCapable,
|
||||
IdRetrievalCapable,
|
||||
RandomCapable,
|
||||
abc.ABC,
|
||||
):
|
||||
"""
|
||||
A valid document index that can plug into all Onyx flows must implement all of these
|
||||
functionalities.
|
||||
|
||||
As a high level summary, document indices need to be able to
|
||||
- Verify the schema definition is valid
|
||||
- Index new documents
|
||||
- Update specific attributes of existing documents
|
||||
- Delete documents
|
||||
- Run hybrid search
|
||||
- Retrieve document or sections of documents based on document id
|
||||
- Retrieve sets of random documents
|
||||
"""
|
||||
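For orientation, a hedged sketch of how a caller might exercise the combined interface (the document_index instance, embedding, and filter/enum values are assumptions; this is not code from the repository):

# Assumed: `document_index` is some concrete DocumentIndex and `query_embedding`
# was produced with the primary index's embedding model and dimensionality.
chunks = document_index.hybrid_retrieval(
    query="quarterly revenue summary",
    query_embedding=query_embedding,
    final_keywords=None,                  # fall back to the raw query
    query_type=QueryType.SEMANTIC,        # assumed enum member
    filters=IndexFilters(access_control_list=None),
    num_to_retrieve=10,
)
if chunks:
    # Pull the full surrounding document for the top hit.
    sections = document_index.id_based_retrieval(
        [DocumentSectionRequest(document_id=chunks[0].document_id)]
    )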
@@ -11,6 +11,7 @@ from sqlalchemy.orm.session import SessionTransaction
|
||||
from onyx.chat.chat_utils import prepare_chat_message_request
|
||||
from onyx.chat.process_message import gather_stream
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.evals.models import EvalationAck
|
||||
@@ -72,11 +73,15 @@ def _get_answer(
|
||||
request = prepare_chat_message_request(
|
||||
message_text=eval_input["message"],
|
||||
user=user,
|
||||
filters=None,
|
||||
persona_id=None,
|
||||
persona_override_config=full_configuration.persona_override_config,
|
||||
message_ts_to_respond_to=None,
|
||||
retrieval_details=RetrievalDetails(),
|
||||
rerank_settings=None,
|
||||
db_session=db_session,
|
||||
skip_gen_ai_answer_generation=False,
|
||||
llm_override=full_configuration.llm,
|
||||
use_agentic_search=False,
|
||||
allowed_tool_ids=full_configuration.allowed_tool_ids,
|
||||
)
|
||||
packets = stream_chat_message_objects(
|
||||
|
||||
@@ -6,6 +6,9 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import PromptOverrideConfig
|
||||
from onyx.chat.models import ToolConfig
|
||||
from onyx.db.tools import get_builtin_tool
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.tools.built_in_tools import BUILT_IN_TOOL_MAP
|
||||
@@ -13,6 +16,7 @@ from onyx.tools.built_in_tools import BUILT_IN_TOOL_MAP
|
||||
|
||||
class EvalConfiguration(BaseModel):
|
||||
builtin_tool_types: list[str] = Field(default_factory=list)
|
||||
persona_override_config: PersonaOverrideConfig | None = None
|
||||
llm: LLMOverride = Field(default_factory=LLMOverride)
|
||||
search_permissions_email: str | None = None
|
||||
allowed_tool_ids: list[int]
|
||||
@@ -20,6 +24,7 @@ class EvalConfiguration(BaseModel):
|
||||
|
||||
class EvalConfigurationOptions(BaseModel):
|
||||
builtin_tool_types: list[str] = list(BUILT_IN_TOOL_MAP.keys())
|
||||
persona_override_config: PersonaOverrideConfig | None = None
|
||||
llm: LLMOverride = LLMOverride(
|
||||
model_provider="Default",
|
||||
model_version="gpt-4.1",
|
||||
@@ -30,7 +35,25 @@ class EvalConfigurationOptions(BaseModel):
|
||||
no_send_logs: bool = False
|
||||
|
||||
def get_configuration(self, db_session: Session) -> EvalConfiguration:
|
||||
persona_override_config = self.persona_override_config or PersonaOverrideConfig(
|
||||
name="Eval",
|
||||
description="A persona for evaluation",
|
||||
tools=[
|
||||
ToolConfig(id=get_builtin_tool(db_session, BUILT_IN_TOOL_MAP[tool]).id)
|
||||
for tool in self.builtin_tool_types
|
||||
],
|
||||
prompts=[
|
||||
PromptOverrideConfig(
|
||||
name="Default",
|
||||
description="Default prompt for evaluation",
|
||||
system_prompt="You are a helpful assistant.",
|
||||
task_prompt="",
|
||||
datetime_aware=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
return EvalConfiguration(
|
||||
persona_override_config=persona_override_config,
|
||||
llm=self.llm,
|
||||
search_permissions_email=self.search_permissions_email,
|
||||
allowed_tool_ids=[
|
||||
|
||||
@@ -41,48 +41,64 @@ def get_federated_retrieval_functions(
|
||||
) -> list[FederatedRetrievalInfo]:
|
||||
# Check for Slack bot context first (regardless of user_id)
|
||||
if slack_context:
|
||||
logger.info("Slack context detected, checking for Slack bot setup...")
|
||||
logger.debug("Slack context detected, checking for Slack bot setup...")
|
||||
|
||||
# If document_set_names is specified, check if any Slack federated connector
|
||||
# is associated with those document sets before enabling Slack federated search
|
||||
if document_set_names:
|
||||
slack_federated_mappings = (
|
||||
get_federated_connector_document_set_mappings_by_document_set_names(
|
||||
db_session, document_set_names
|
||||
)
|
||||
# Slack federated search requires a Slack federated connector to be linked
|
||||
# via document sets. If no document sets are provided, skip Slack federated search.
|
||||
if not document_set_names:
|
||||
logger.debug(
|
||||
"Skipping Slack federated search: no document sets provided, "
|
||||
"Slack federated connector must be linked via document sets"
|
||||
)
|
||||
# Check if any of the mappings are for a Slack federated connector
|
||||
has_slack_federated_connector = any(
|
||||
mapping.federated_connector.source
|
||||
return []
|
||||
|
||||
# Check if any Slack federated connector is associated with the document sets
|
||||
# and extract its config (entities) for channel filtering
|
||||
slack_federated_connector_config: dict[str, Any] | None = None
|
||||
slack_federated_mappings = (
|
||||
get_federated_connector_document_set_mappings_by_document_set_names(
|
||||
db_session, document_set_names
|
||||
)
|
||||
)
|
||||
for mapping in slack_federated_mappings:
|
||||
if (
|
||||
mapping.federated_connector is not None
|
||||
and mapping.federated_connector.source
|
||||
== FederatedConnectorSource.FEDERATED_SLACK
|
||||
for mapping in slack_federated_mappings
|
||||
if mapping.federated_connector is not None
|
||||
)
|
||||
if not has_slack_federated_connector:
|
||||
logger.info(
|
||||
f"Skipping Slack federated search: document sets {document_set_names} "
|
||||
"are not associated with any Slack federated connector"
|
||||
):
|
||||
slack_federated_connector_config = (
|
||||
mapping.federated_connector.config or {}
|
||||
)
|
||||
# Return empty list - no Slack federated search for this context
|
||||
return []
|
||||
logger.debug(
|
||||
f"Found Slack federated connector config: {slack_federated_connector_config}"
|
||||
)
|
||||
break
|
||||
|
||||
if slack_federated_connector_config is None:
|
||||
logger.debug(
|
||||
f"Skipping Slack federated search: document sets {document_set_names} "
|
||||
"are not associated with any Slack federated connector"
|
||||
)
|
||||
# Return empty list - no Slack federated search for this context
|
||||
return []
|
||||
|
||||
try:
|
||||
slack_bots = fetch_slack_bots(db_session)
|
||||
logger.info(f"Found {len(slack_bots)} Slack bots")
|
||||
logger.debug(f"Found {len(slack_bots)} Slack bots")
|
||||
|
||||
# First try to find a bot with user token
|
||||
tenant_slack_bot = next(
|
||||
(bot for bot in slack_bots if bot.enabled and bot.user_token), None
|
||||
)
|
||||
if tenant_slack_bot:
|
||||
logger.info(f"Selected bot with user_token: {tenant_slack_bot.name}")
|
||||
logger.debug(f"Selected bot with user_token: {tenant_slack_bot.name}")
|
||||
else:
|
||||
# Fall back to any enabled bot without user token
|
||||
tenant_slack_bot = next(
|
||||
(bot for bot in slack_bots if bot.enabled), None
|
||||
)
|
||||
if tenant_slack_bot:
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"Selected bot without user_token: {tenant_slack_bot.name} (limited functionality)"
|
||||
)
|
||||
else:
|
||||
@@ -113,16 +129,23 @@ def get_federated_retrieval_functions(
|
||||
# Capture variables by value to avoid lambda closure issues
|
||||
bot_token = tenant_slack_bot.bot_token
|
||||
|
||||
# Use connector config for channel filtering (guaranteed to exist at this point)
|
||||
connector_entities = slack_federated_connector_config
|
||||
logger.debug(
|
||||
f"Using Slack federated connector entities for bot context: {connector_entities}"
|
||||
)
|
||||
|
||||
def create_slack_retrieval_function(
|
||||
conn: FederatedConnector,
|
||||
token: str,
|
||||
ctx: SlackContext,
|
||||
bot_tok: str,
|
||||
entities: dict[str, Any],
|
||||
) -> Callable[[ChunkIndexRequest], list[InferenceChunk]]:
|
||||
def retrieval_fn(query: ChunkIndexRequest) -> list[InferenceChunk]:
|
||||
return conn.search(
|
||||
query,
|
||||
{}, # Empty entities for Slack context
|
||||
entities, # Use connector-level entities for channel filtering
|
||||
access_token=token,
|
||||
limit=None, # Let connector use its own max_messages_per_query config
|
||||
slack_event_context=ctx,
|
||||
@@ -134,12 +157,16 @@ def get_federated_retrieval_functions(
|
||||
federated_retrieval_infos_slack.append(
|
||||
FederatedRetrievalInfo(
|
||||
retrieval_function=create_slack_retrieval_function(
|
||||
connector, access_token, slack_context, bot_token
|
||||
connector,
|
||||
access_token,
|
||||
slack_context,
|
||||
bot_token,
|
||||
connector_entities,
|
||||
),
|
||||
source=FederatedConnectorSource.FEDERATED_SLACK,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"Added Slack federated search for bot, returning {len(federated_retrieval_infos_slack)} retrieval functions"
|
||||
)
|
||||
return federated_retrieval_infos_slack
|
||||
@@ -177,7 +204,7 @@ def get_federated_retrieval_functions(
|
||||
|
||||
# If no source types are specified, don't use any federated connectors
|
||||
if source_types is None:
|
||||
logger.info("No source types specified, skipping all federated connectors")
|
||||
logger.debug("No source types specified, skipping all federated connectors")
|
||||
return []
|
||||
|
||||
federated_retrieval_infos: list[FederatedRetrievalInfo] = []
|
||||
|
||||
@@ -290,7 +290,7 @@ class SlackFederatedConnector(FederatedConnector):
|
||||
Returns:
|
||||
Search results in SlackSearchResponse format
|
||||
"""
|
||||
logger.info(f"Slack federated search called with entities: {entities}")
|
||||
logger.debug(f"Slack federated search called with entities: {entities}")
|
||||
|
||||
# Get team_id from Slack API for caching and filtering
|
||||
team_id = None
|
||||
@@ -302,7 +302,7 @@ class SlackFederatedConnector(FederatedConnector):
|
||||
# Cast response.data to dict for type checking
|
||||
auth_data: dict[str, Any] = auth_response.data # type: ignore
|
||||
team_id = auth_data.get("team_id")
|
||||
logger.info(f"Slack team_id: {team_id}")
|
||||
logger.debug(f"Slack team_id: {team_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not fetch team_id from Slack API: {e}")
|
||||
|
||||
|
||||
@@ -25,17 +25,17 @@ class SlackEntities(BaseModel):
|
||||
|
||||
# Direct message filtering
|
||||
include_dm: bool = Field(
|
||||
default=False,
|
||||
default=True,
|
||||
description="Include user direct messages in search results",
|
||||
)
|
||||
include_group_dm: bool = Field(
|
||||
default=False,
|
||||
default=True,
|
||||
description="Include group direct messages (multi-person DMs) in search results",
|
||||
)
|
||||
|
||||
# Private channel filtering
|
||||
include_private_channels: bool = Field(
|
||||
default=False,
|
||||
default=True,
|
||||
description="Include private channels in search results (user must have access)",
|
||||
)
|
||||
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.app_configs import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
|
||||
from onyx.configs.app_configs import IMAGE_SUMMARIZATION_USER_PROMPT
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.llm.models import ChatCompletionMessage
|
||||
from onyx.llm.models import ContentPart
|
||||
from onyx.llm.models import ImageContentPart
|
||||
from onyx.llm.models import ImageUrlDetail
|
||||
from onyx.llm.models import SystemMessage
|
||||
from onyx.llm.models import TextContentPart
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -97,22 +101,24 @@ def _summarize_image(
|
||||
) -> str:
|
||||
"""Use default LLM (if it is multimodal) to generate a summary of an image."""
|
||||
|
||||
messages: list[BaseMessage] = []
|
||||
messages: list[ChatCompletionMessage] = []
|
||||
|
||||
if system_prompt:
|
||||
messages.append(SystemMessage(content=system_prompt))
|
||||
|
||||
content: list[ContentPart] = []
|
||||
if query:
|
||||
content.append(TextContentPart(text=query))
|
||||
content.append(ImageContentPart(image_url=ImageUrlDetail(url=encoded_image)))
|
||||
|
||||
messages.append(
|
||||
HumanMessage(
|
||||
content=[
|
||||
{"type": "text", "text": query},
|
||||
{"type": "image_url", "image_url": {"url": encoded_image}},
|
||||
],
|
||||
UserMessage(
|
||||
content=content,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
return message_to_string(llm.invoke_langchain(messages))
|
||||
return llm_response_to_string(llm.invoke(messages))
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Summarization failed. Messages: {messages}"
|
||||
|
||||
@@ -298,17 +298,17 @@ def verify_user_files(
|
||||
|
||||
for file_descriptor in user_files:
|
||||
# Check if this file descriptor has a user_file_id
|
||||
if "user_file_id" in file_descriptor and file_descriptor["user_file_id"]:
|
||||
if file_descriptor.get("user_file_id"):
|
||||
try:
|
||||
user_file_ids.append(UUID(file_descriptor["user_file_id"]))
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f"Invalid user_file_id in file descriptor: {file_descriptor.get('user_file_id')}"
|
||||
f"Invalid user_file_id in file descriptor: {file_descriptor['user_file_id']}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# This is a project file - use the 'id' field which is the file_id
|
||||
if "id" in file_descriptor and file_descriptor["id"]:
|
||||
if file_descriptor.get("id"):
|
||||
project_file_ids.append(file_descriptor["id"])
|
||||
|
||||
# Verify user files (existing logic)
|
||||
|
||||
@@ -54,8 +54,8 @@ from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.factory import get_default_llm_with_vision
|
||||
from onyx.llm.factory import get_llm_for_contextual_rag
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.llm.utils import MAX_CONTEXT_TOKENS
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
@@ -542,8 +542,8 @@ def add_document_summaries(
|
||||
doc_tokens = tokenizer.encode(chunks_by_doc[0].source_document.get_text_content())
|
||||
doc_content = tokenizer_trim_middle(doc_tokens, trunc_doc_tokens, tokenizer)
|
||||
summary_prompt = DOCUMENT_SUMMARY_PROMPT.format(document=doc_content)
|
||||
doc_summary = message_to_string(
|
||||
llm.invoke_langchain(summary_prompt, max_tokens=MAX_CONTEXT_TOKENS)
|
||||
doc_summary = llm_response_to_string(
|
||||
llm.invoke(summary_prompt, max_tokens=MAX_CONTEXT_TOKENS)
|
||||
)
|
||||
|
||||
for chunk in chunks_by_doc:
|
||||
@@ -583,8 +583,8 @@ def add_chunk_summaries(
|
||||
if not doc_info:
|
||||
# This happens if the document is too long AND document summaries are turned off
|
||||
# In this case we compute a doc summary using the LLM
|
||||
doc_info = message_to_string(
|
||||
llm.invoke_langchain(
|
||||
doc_info = llm_response_to_string(
|
||||
llm.invoke(
|
||||
DOCUMENT_SUMMARY_PROMPT.format(document=doc_content),
|
||||
max_tokens=MAX_CONTEXT_TOKENS,
|
||||
)
|
||||
@@ -595,8 +595,8 @@ def add_chunk_summaries(
|
||||
def assign_context(chunk: DocAwareChunk) -> None:
|
||||
context_prompt2 = CONTEXTUAL_RAG_PROMPT2.format(chunk=chunk.content)
|
||||
try:
|
||||
chunk.chunk_context = message_to_string(
|
||||
llm.invoke_langchain(
|
||||
chunk.chunk_context = llm_response_to_string(
|
||||
llm.invoke(
|
||||
context_prompt1 + context_prompt2,
|
||||
max_tokens=MAX_CONTEXT_TOKENS,
|
||||
)
|
||||
|
||||
@@ -80,7 +80,11 @@ class PgRedisKVStore(KeyValueStore):
|
||||
value = None
|
||||
|
||||
try:
|
||||
self.redis_client.set(REDIS_KEY_PREFIX + key, json.dumps(value))
|
||||
self.redis_client.set(
|
||||
REDIS_KEY_PREFIX + key,
|
||||
json.dumps(value),
|
||||
ex=KV_REDIS_KEY_EXPIRATION,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to set value in Redis for key '{key}': {str(e)}")
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import json
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCallTypes
|
||||
from onyx.configs.kg_configs import KG_METADATA_TRACKING_THRESHOLD
|
||||
@@ -31,7 +29,7 @@ from onyx.kg.utils.formatting_utils import make_relationship_id
|
||||
from onyx.kg.utils.formatting_utils import make_relationship_type_id
|
||||
from onyx.kg.vespa.vespa_interactions import get_document_vespa_contents
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.prompts.kg_prompts import CALL_CHUNK_PREPROCESSING_PROMPT
|
||||
from onyx.prompts.kg_prompts import CALL_DOCUMENT_CLASSIFICATION_PROMPT
|
||||
from onyx.prompts.kg_prompts import GENERAL_CHUNK_PREPROCESSING_PROMPT
|
||||
@@ -418,14 +416,10 @@ def kg_classify_document(
|
||||
|
||||
# classify with LLM
|
||||
primary_llm, _ = get_default_llms()
|
||||
msg = [HumanMessage(content=prompt)]
|
||||
try:
|
||||
raw_classification_result = primary_llm.invoke_langchain(msg)
|
||||
raw_classification_result = llm_response_to_string(primary_llm.invoke(prompt))
|
||||
classification_result = (
|
||||
message_to_string(raw_classification_result)
|
||||
.replace("```json", "")
|
||||
.replace("```", "")
|
||||
.strip()
|
||||
raw_classification_result.replace("```json", "").replace("```", "").strip()
|
||||
)
|
||||
# no json parsing here because of reasoning output
|
||||
classification_class = classification_result.split("CATEGORY:")[1].strip()
|
||||
@@ -486,12 +480,10 @@ def kg_deep_extract_chunks(
|
||||
|
||||
# extract with LLM
|
||||
_, fast_llm = get_default_llms()
|
||||
msg = [HumanMessage(content=prompt)]
|
||||
try:
|
||||
raw_extraction_result = fast_llm.invoke_langchain(msg)
|
||||
raw_extraction_result = llm_response_to_string(fast_llm.invoke(prompt))
|
||||
cleaned_response = (
|
||||
message_to_string(raw_extraction_result)
|
||||
.replace("{{", "{")
|
||||
raw_extraction_result.replace("{{", "{")
|
||||
.replace("}}", "}")
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
|
||||
@@ -1,45 +1,23 @@
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from httpx import RemoteProtocolError
|
||||
from langchain.schema.language_model import (
|
||||
LanguageModelInput as LangChainLanguageModelInput,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import BaseMessageChunk
|
||||
from langchain_core.messages import ChatMessage
|
||||
from langchain_core.messages import ChatMessageChunk
|
||||
from langchain_core.messages import FunctionMessage
|
||||
from langchain_core.messages import FunctionMessageChunk
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import HumanMessageChunk
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.messages import SystemMessageChunk
|
||||
from langchain_core.messages.tool import ToolCallChunk
|
||||
from langchain_core.messages.tool import ToolMessage
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
|
||||
from onyx.configs.app_configs import SEND_USER_METADATA_TO_LLM_PROVIDER
|
||||
from onyx.configs.chat_configs import QA_TIMEOUT
|
||||
from onyx.configs.model_configs import (
|
||||
DISABLE_LITELLM_STREAMING,
|
||||
)
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.configs.model_configs import LITELLM_EXTRA_BODY
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.llm.interfaces import STANDARD_TOOL_CHOICE_OPTIONS
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ReasoningEffort
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.llm_provider_options import AZURE_PROVIDER_NAME
|
||||
from onyx.llm.llm_provider_options import OLLAMA_PROVIDER_NAME
|
||||
@@ -47,6 +25,8 @@ from onyx.llm.llm_provider_options import VERTEX_CREDENTIALS_FILE_KWARG
|
||||
from onyx.llm.llm_provider_options import VERTEX_LOCATION_KWARG
|
||||
from onyx.llm.model_response import ModelResponse
|
||||
from onyx.llm.model_response import ModelResponseStream
|
||||
from onyx.llm.models import CLAUDE_REASONING_BUDGET_TOKENS
|
||||
from onyx.llm.models import OPENAI_REASONING_EFFORT
|
||||
from onyx.llm.utils import is_true_openai_model
|
||||
from onyx.llm.utils import model_is_reasoning_model
|
||||
from onyx.server.utils import mask_string
|
||||
@@ -57,14 +37,13 @@ from onyx.utils.special_types import JSON_ro
|
||||
logger = setup_logger()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import CustomStreamWrapper, Message
|
||||
from litellm import CustomStreamWrapper
|
||||
|
||||
|
||||
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
|
||||
LEGACY_MAX_TOKENS_KWARG = "max_tokens"
|
||||
STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens"
|
||||
|
||||
LegacyPromptDict = Sequence[str | list[str] | dict[str, Any] | tuple[str, str]]
|
||||
MAX_LITELLM_USER_ID_LENGTH = 64
|
||||
|
||||
|
||||
class LLMTimeoutError(Exception):
|
||||
@@ -79,199 +58,30 @@ class LLMRateLimitError(Exception):
|
||||
"""
|
||||
|
||||
|
||||
def _base_msg_to_role(msg: BaseMessage) -> str:
|
||||
if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk):
|
||||
return "user"
|
||||
if isinstance(msg, AIMessage) or isinstance(msg, AIMessageChunk):
|
||||
return "assistant"
|
||||
if isinstance(msg, SystemMessage) or isinstance(msg, SystemMessageChunk):
|
||||
return "system"
|
||||
if isinstance(msg, FunctionMessage) or isinstance(msg, FunctionMessageChunk):
|
||||
return "function"
|
||||
return "unknown"
|
||||
def _prompt_to_dicts(prompt: LanguageModelInput) -> list[dict[str, Any]]:
|
||||
"""Convert Pydantic message models to dictionaries for LiteLLM.
|
||||
|
||||
|
||||
def _convert_litellm_message_to_langchain_message(
|
||||
litellm_message: "Message",
|
||||
) -> BaseMessage:
|
||||
from onyx.llm.litellm_singleton import litellm
|
||||
|
||||
# Extracting the basic attributes from the litellm message
|
||||
content = litellm_message.content or ""
|
||||
role = litellm_message.role
|
||||
|
||||
# Handling function calls and tool calls if present
|
||||
tool_calls = (
|
||||
cast(
|
||||
list[litellm.ChatCompletionMessageToolCall],
|
||||
litellm_message.tool_calls,
|
||||
)
|
||||
if hasattr(litellm_message, "tool_calls")
|
||||
else []
|
||||
)
|
||||
|
||||
# Create the appropriate langchain message based on the role
|
||||
if role == "user":
|
||||
return HumanMessage(content=content)
|
||||
elif role == "assistant":
|
||||
return AIMessage(
|
||||
content=content,
|
||||
tool_calls=(
|
||||
[
|
||||
{
|
||||
"name": tool_call.function.name or "",
|
||||
"args": json.loads(tool_call.function.arguments),
|
||||
"id": tool_call.id,
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
if tool_calls
|
||||
else []
|
||||
),
|
||||
)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=content)
|
||||
else:
|
||||
raise ValueError(f"Unknown role type received: {role}")
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
"""Adapted from langchain_community.chat_models.litellm._convert_message_to_dict"""
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
message_dict["tool_calls"] = [
|
||||
{
|
||||
"id": tool_call.get("id"),
|
||||
"function": {
|
||||
"name": tool_call["name"],
|
||||
"arguments": json.dumps(tool_call["args"]),
|
||||
},
|
||||
"type": "function",
|
||||
"index": tool_call.get("index", 0),
|
||||
}
|
||||
for tool_call in message.tool_calls
|
||||
]
|
||||
if "function_call" in message.additional_kwargs:
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": message.content,
|
||||
"name": message.name,
|
||||
}
|
||||
elif isinstance(message, ToolMessage):
|
||||
message_dict = {
|
||||
"tool_call_id": message.tool_call_id,
|
||||
"role": "tool",
|
||||
"name": message.name or "",
|
||||
"content": message.content,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
if "name" in message.additional_kwargs:
|
||||
message_dict["name"] = message.additional_kwargs["name"]
|
||||
return message_dict
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: dict[str, Any],
|
||||
curr_msg: BaseMessage | None,
|
||||
stop_reason: str | None = None,
|
||||
) -> BaseMessageChunk:
|
||||
from litellm.utils import ChatCompletionDeltaToolCall
|
||||
|
||||
"""Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk"""
|
||||
role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else "unknown")
|
||||
content = _dict.get("content") or ""
|
||||
additional_kwargs = {}
|
||||
if _dict.get("function_call"):
|
||||
additional_kwargs.update({"function_call": dict(_dict["function_call"])})
|
||||
tool_calls = cast(list[ChatCompletionDeltaToolCall] | None, _dict.get("tool_calls"))
|
||||
|
||||
if role == "user":
|
||||
return HumanMessageChunk(content=content)
|
||||
# NOTE: if tool calls are present, then it's an assistant.
|
||||
# In Ollama, the role will be None for tool-calls
|
||||
elif role == "assistant" or tool_calls:
|
||||
if tool_calls:
|
||||
tool_call = tool_calls[0]
|
||||
tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
|
||||
idx = tool_call.index
|
||||
|
||||
tool_call_chunk = ToolCallChunk(
|
||||
name=tool_name,
|
||||
id=tool_call.id,
|
||||
args=tool_call.function.arguments,
|
||||
index=idx,
|
||||
)
|
||||
|
||||
return AIMessageChunk(
|
||||
content=content,
|
||||
tool_call_chunks=[tool_call_chunk],
|
||||
additional_kwargs={
|
||||
"usage_metadata": {"stop": stop_reason},
|
||||
**additional_kwargs,
|
||||
},
|
||||
)
|
||||
|
||||
return AIMessageChunk(
|
||||
content=content,
|
||||
additional_kwargs={
|
||||
"usage_metadata": {"stop": stop_reason},
|
||||
**additional_kwargs,
|
||||
},
|
||||
)
|
||||
elif role == "system":
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role == "function":
|
||||
return FunctionMessageChunk(content=content, name=_dict["name"])
|
||||
elif role:
|
||||
return ChatMessageChunk(content=content, role=role)
|
||||
|
||||
raise ValueError(f"Unknown role: {role}")
|
||||
|
||||
|
||||
def _prompt_to_dict(
|
||||
prompt: LanguageModelInput | LangChainLanguageModelInput,
|
||||
) -> LegacyPromptDict:
|
||||
# NOTE: this must go first, since it is also a Sequence
|
||||
LiteLLM expects messages to be dictionaries (with .get() method),
|
||||
not Pydantic models. This function serializes the messages.
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
return [_convert_message_to_dict(HumanMessage(content=prompt))]
|
||||
|
||||
if isinstance(prompt, (list, Sequence)):
|
||||
normalized_prompt: list[str | list[str] | dict[str, Any] | tuple[str, str]] = []
|
||||
for msg in prompt:
|
||||
if isinstance(msg, BaseMessage):
|
||||
normalized_prompt.append(_convert_message_to_dict(msg))
|
||||
elif isinstance(msg, dict):
|
||||
normalized_prompt.append(dict(msg))
|
||||
else:
|
||||
normalized_prompt.append(msg)
|
||||
return normalized_prompt
|
||||
|
||||
if isinstance(prompt, BaseMessage):
|
||||
return [_convert_message_to_dict(prompt)]
|
||||
|
||||
if isinstance(prompt, PromptValue):
|
||||
return [_convert_message_to_dict(message) for message in prompt.to_messages()]
|
||||
|
||||
raise TypeError(f"Unsupported prompt type: {type(prompt)}")
|
||||
return [{"role": "user", "content": prompt}]
|
||||
return [msg.model_dump(exclude_none=True) for msg in prompt]
|
||||
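As a quick illustration of what _prompt_to_dicts produces, here is a hedged, self-contained sketch using small Pydantic models shaped like the new message types (the real UserMessage/SystemMessage classes live in onyx/llm/models.py, shown later in this diff):

# Illustrative sketch, not part of the diff; mirrors the serialization logic above.
from typing import Literal
from pydantic import BaseModel

class SystemMessage(BaseModel):
    role: Literal["system"] = "system"
    content: str

class UserMessage(BaseModel):
    role: Literal["user"] = "user"
    content: str

def prompt_to_dicts(prompt):
    # A bare string is wrapped as a single user message; Pydantic models are dumped
    # to plain dicts, which is the shape LiteLLM expects (objects with .get()).
    if isinstance(prompt, str):
        return [{"role": "user", "content": prompt}]
    return [msg.model_dump(exclude_none=True) for msg in prompt]

prompt = [SystemMessage(content="You are terse."), UserMessage(content="Hi!")]
print(prompt_to_dicts(prompt))
# [{'role': 'system', 'content': 'You are terse.'}, {'role': 'user', 'content': 'Hi!'}]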
|
||||
|
||||
def _prompt_as_json(
|
||||
prompt: LanguageModelInput | LangChainLanguageModelInput,
|
||||
*,
|
||||
is_legacy_langchain: bool,
|
||||
) -> JSON_ro:
|
||||
prompt_payload = _prompt_to_dict(prompt) if is_legacy_langchain else prompt
|
||||
return cast(JSON_ro, prompt_payload)
|
||||
def _prompt_as_json(prompt: LanguageModelInput) -> JSON_ro:
|
||||
return cast(JSON_ro, _prompt_to_dicts(prompt))
|
||||
|
||||
|
||||
def _truncate_litellm_user_id(user_id: str) -> str:
|
||||
if len(user_id) <= MAX_LITELLM_USER_ID_LENGTH:
|
||||
return user_id
|
||||
logger.warning(
|
||||
"LLM user id exceeds %d chars (len=%d); truncating for provider compatibility.",
|
||||
MAX_LITELLM_USER_ID_LENGTH,
|
||||
len(user_id),
|
||||
)
|
||||
return user_id[:MAX_LITELLM_USER_ID_LENGTH]
|
||||
|
||||
|
||||
class LitellmLLM(LLM):
|
||||
@@ -371,18 +181,12 @@ class LitellmLLM(LLM):
|
||||
dump["credentials_file"] = mask_string(credentials_file)
|
||||
return dump
|
||||
|
||||
def log_model_configs(self) -> None:
|
||||
logger.debug(f"Config: {self._safe_model_config()}")
|
||||
|
||||
def _record_call(
|
||||
self,
|
||||
prompt: LanguageModelInput | LangChainLanguageModelInput,
|
||||
is_legacy_langchain: bool = False,
|
||||
prompt: LanguageModelInput,
|
||||
) -> None:
|
||||
if self._long_term_logger:
|
||||
prompt_json = _prompt_as_json(
|
||||
prompt, is_legacy_langchain=is_legacy_langchain
|
||||
)
|
||||
prompt_json = _prompt_as_json(prompt)
|
||||
self._long_term_logger.record(
|
||||
{
|
||||
"prompt": prompt_json,
|
||||
@@ -393,14 +197,11 @@ class LitellmLLM(LLM):
|
||||
|
||||
def _record_result(
|
||||
self,
|
||||
prompt: LanguageModelInput | LangChainLanguageModelInput,
|
||||
prompt: LanguageModelInput,
|
||||
model_output: BaseMessage,
|
||||
is_legacy_langchain: bool,
|
||||
) -> None:
|
||||
if self._long_term_logger:
|
||||
prompt_json = _prompt_as_json(
|
||||
prompt, is_legacy_langchain=is_legacy_langchain
|
||||
)
|
||||
prompt_json = _prompt_as_json(prompt)
|
||||
tool_calls = (
|
||||
model_output.tool_calls if hasattr(model_output, "tool_calls") else []
|
||||
)
|
||||
@@ -416,14 +217,11 @@ class LitellmLLM(LLM):
|
||||
|
||||
def _record_error(
|
||||
self,
|
||||
prompt: LanguageModelInput | LangChainLanguageModelInput,
|
||||
prompt: LanguageModelInput,
|
||||
error: Exception,
|
||||
is_legacy_langchain: bool,
|
||||
) -> None:
|
||||
if self._long_term_logger:
|
||||
prompt_json = _prompt_as_json(
|
||||
prompt, is_legacy_langchain=is_legacy_langchain
|
||||
)
|
||||
prompt_json = _prompt_as_json(prompt)
|
||||
self._long_term_logger.record(
|
||||
{
|
||||
"prompt": prompt_json,
|
||||
@@ -440,48 +238,27 @@ class LitellmLLM(LLM):
|
||||
|
||||
def _completion(
|
||||
self,
|
||||
prompt: LanguageModelInput | LangChainLanguageModelInput,
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None,
|
||||
tool_choice: ToolChoiceOptions | None,
|
||||
stream: bool,
|
||||
parallel_tool_calls: bool,
|
||||
reasoning_effort: str | None = None,
|
||||
reasoning_effort: ReasoningEffort | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
is_legacy_langchain: bool = False,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
) -> Union["ModelResponse", "CustomStreamWrapper"]:
|
||||
# litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
|
||||
# to a dict representation
|
||||
processed_prompt: LegacyPromptDict | LanguageModelInput
|
||||
if is_legacy_langchain:
|
||||
processed_prompt = _prompt_to_dict(prompt)
|
||||
else:
|
||||
processed_prompt = cast(LanguageModelInput, prompt)
|
||||
|
||||
# Record the original prompt (not the processed one) for logging
|
||||
original_prompt = prompt
|
||||
self._record_call(original_prompt, is_legacy_langchain)
|
||||
self._record_call(prompt)
|
||||
from onyx.llm.litellm_singleton import litellm
|
||||
from litellm.exceptions import Timeout, RateLimitError
|
||||
|
||||
tool_choice_formatted: dict[str, Any] | str | None
|
||||
if not tools:
|
||||
tool_choice_formatted = None
|
||||
elif tool_choice and tool_choice not in STANDARD_TOOL_CHOICE_OPTIONS:
|
||||
tool_choice_formatted = {
|
||||
"type": "function",
|
||||
"function": {"name": tool_choice},
|
||||
}
|
||||
else:
|
||||
tool_choice_formatted = tool_choice
|
||||
|
||||
is_reasoning = model_is_reasoning_model(
|
||||
self.config.model_name, self.config.model_provider
|
||||
)
|
||||
|
||||
# Needed to get reasoning tokens from the model
|
||||
if not is_legacy_langchain and (
|
||||
if (
|
||||
is_true_openai_model(self.config.model_provider, self.config.model_name)
|
||||
or self.config.model_provider == AZURE_PROVIDER_NAME
|
||||
):
|
||||
@@ -489,6 +266,29 @@ class LitellmLLM(LLM):
|
||||
else:
|
||||
model_provider = self.config.model_provider
|
||||
|
||||
completion_kwargs: dict[str, Any] = self._model_kwargs
|
||||
if SEND_USER_METADATA_TO_LLM_PROVIDER and user_identity:
|
||||
completion_kwargs = dict(self._model_kwargs)
|
||||
|
||||
if user_identity.user_id:
|
||||
completion_kwargs["user"] = _truncate_litellm_user_id(
|
||||
user_identity.user_id
|
||||
)
|
||||
|
||||
if user_identity.session_id:
|
||||
existing_metadata = completion_kwargs.get("metadata")
|
||||
metadata: dict[str, Any] | None
|
||||
if existing_metadata is None:
|
||||
metadata = {}
|
||||
elif isinstance(existing_metadata, dict):
|
||||
metadata = dict(existing_metadata)
|
||||
else:
|
||||
metadata = None
|
||||
|
||||
if metadata is not None:
|
||||
metadata["session_id"] = user_identity.session_id
|
||||
completion_kwargs["metadata"] = metadata
|
||||
|
||||
try:
|
||||
return litellm.completion(
|
||||
mock_response=MOCK_LLM_RESPONSE,
|
||||
@@ -502,9 +302,9 @@ class LitellmLLM(LLM):
|
||||
api_version=self._api_version or None,
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
# actual input
|
||||
messages=processed_prompt,
|
||||
messages=_prompt_to_dicts(prompt),
|
||||
tools=tools,
|
||||
tool_choice=tool_choice_formatted,
|
||||
tool_choice=tool_choice if tools else None,
|
||||
# streaming choice
|
||||
stream=stream,
|
||||
# model params
|
||||
@@ -532,8 +332,16 @@ class LitellmLLM(LLM):
|
||||
# Anthropic Claude uses `thinking` with budget_tokens for extended thinking
|
||||
# This applies to Claude models on any provider (anthropic, vertex_ai, bedrock)
|
||||
**(
|
||||
{"thinking": {"type": "enabled", "budget_tokens": 10000}}
|
||||
{
|
||||
"thinking": {
|
||||
"type": "enabled",
|
||||
"budget_tokens": CLAUDE_REASONING_BUDGET_TOKENS[
|
||||
reasoning_effort
|
||||
],
|
||||
}
|
||||
}
|
||||
if reasoning_effort
|
||||
and reasoning_effort != ReasoningEffort.OFF
|
||||
and is_reasoning
|
||||
and "claude" in self.config.model_name.lower()
|
||||
else {}
|
||||
@@ -541,8 +349,9 @@ class LitellmLLM(LLM):
|
||||
# OpenAI and other providers use reasoning_effort
|
||||
# (litellm maps this to thinking_level for Gemini 3 models)
|
||||
**(
|
||||
{"reasoning_effort": reasoning_effort}
|
||||
{"reasoning_effort": OPENAI_REASONING_EFFORT[reasoning_effort]}
|
||||
if reasoning_effort
|
||||
and reasoning_effort != ReasoningEffort.OFF
|
||||
and is_reasoning
|
||||
and "claude" not in self.config.model_name.lower()
|
||||
else {}
|
||||
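To make the two branches above easier to compare, here is a small sketch of the kwargs they produce, using the ReasoningEffort enum and the CLAUDE_REASONING_BUDGET_TOKENS / OPENAI_REASONING_EFFORT mappings added in onyx/llm/models.py later in this diff (the model-capability check via model_is_reasoning_model is omitted for brevity):

# Illustrative sketch, not part of the diff; shows how one effort level is translated per provider.
from enum import Enum

class ReasoningEffort(str, Enum):
    OFF = "off"
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"

CLAUDE_REASONING_BUDGET_TOKENS = {ReasoningEffort.OFF: 0, ReasoningEffort.LOW: 1000,
                                  ReasoningEffort.MEDIUM: 5000, ReasoningEffort.HIGH: 10000}
OPENAI_REASONING_EFFORT = {ReasoningEffort.OFF: "none", ReasoningEffort.LOW: "low",
                           ReasoningEffort.MEDIUM: "medium", ReasoningEffort.HIGH: "high"}

def reasoning_kwargs(model_name: str, effort: ReasoningEffort | None) -> dict:
    if not effort or effort == ReasoningEffort.OFF:
        return {}
    if "claude" in model_name.lower():
        # Claude models use extended thinking with an explicit token budget.
        return {"thinking": {"type": "enabled",
                             "budget_tokens": CLAUDE_REASONING_BUDGET_TOKENS[effort]}}
    # OpenAI (and, via litellm, Gemini) take a plain reasoning_effort string.
    return {"reasoning_effort": OPENAI_REASONING_EFFORT[effort]}

print(reasoning_kwargs("claude-sonnet-4-5", ReasoningEffort.HIGH))  # thinking budget 10000
print(reasoning_kwargs("gpt-5", ReasoningEffort.LOW))               # reasoning_effort "low"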
@@ -553,11 +362,11 @@ class LitellmLLM(LLM):
|
||||
else {}
|
||||
),
|
||||
**({self._max_token_param: max_tokens} if max_tokens else {}),
|
||||
**self._model_kwargs,
|
||||
**completion_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
self._record_error(original_prompt, e, is_legacy_langchain)
|
||||
self._record_error(prompt, e)
|
||||
# for break pointing
|
||||
if isinstance(e, Timeout):
|
||||
raise LLMTimeoutError(e)
|
||||
@@ -587,134 +396,7 @@ class LitellmLLM(LLM):
|
||||
max_input_tokens=self._max_input_tokens,
|
||||
)
|
||||
|
||||
def _invoke_implementation_langchain(
|
||||
self,
|
||||
prompt: LangChainLanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> BaseMessage:
|
||||
from litellm import ModelResponse
|
||||
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
|
||||
response = cast(
|
||||
ModelResponse,
|
||||
self._completion(
|
||||
is_legacy_langchain=True,
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=False,
|
||||
structured_response_format=structured_response_format,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
parallel_tool_calls=False,
|
||||
),
|
||||
)
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "message"):
|
||||
output = _convert_litellm_message_to_langchain_message(choice.message)
|
||||
if output:
|
||||
self._record_result(prompt, output, is_legacy_langchain=True)
|
||||
return output
|
||||
else:
|
||||
raise ValueError("Unexpected response choice type")
|
||||
|
||||
def _stream_implementation_langchain(
|
||||
self,
|
||||
prompt: LangChainLanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
from litellm import CustomStreamWrapper
|
||||
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
|
||||
if DISABLE_LITELLM_STREAMING:
|
||||
yield self.invoke_langchain(
|
||||
prompt,
|
||||
tools,
|
||||
tool_choice,
|
||||
structured_response_format,
|
||||
timeout_override,
|
||||
max_tokens,
|
||||
)
|
||||
return
|
||||
|
||||
output = None
|
||||
response = cast(
|
||||
CustomStreamWrapper,
|
||||
self._completion(
|
||||
is_legacy_langchain=True,
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=True,
|
||||
structured_response_format=structured_response_format,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
parallel_tool_calls=False,
|
||||
reasoning_effort="minimal",
|
||||
),
|
||||
)
|
||||
try:
|
||||
for part in response:
|
||||
if not part["choices"]:
|
||||
continue
|
||||
|
||||
choice = part["choices"][0]
|
||||
message_chunk = _convert_delta_to_message_chunk(
|
||||
choice["delta"],
|
||||
output,
|
||||
stop_reason=choice["finish_reason"],
|
||||
)
|
||||
|
||||
if output is None:
|
||||
output = message_chunk
|
||||
else:
|
||||
output += message_chunk
|
||||
|
||||
yield message_chunk
|
||||
|
||||
except RemoteProtocolError:
|
||||
raise RuntimeError(
|
||||
"The AI model failed partway through generation, please try again."
|
||||
)
|
||||
|
||||
if output:
|
||||
self._record_result(prompt, output, is_legacy_langchain=True)
|
||||
|
||||
if LOG_ONYX_MODEL_INTERACTIONS and output:
|
||||
content = output.content or ""
|
||||
if isinstance(output, AIMessage):
|
||||
if content:
|
||||
log_msg = content
|
||||
elif output.tool_calls:
|
||||
log_msg = "Tool Calls: " + str(
|
||||
[
|
||||
{
|
||||
key: value
|
||||
for key, value in tool_call.items()
|
||||
if key != "index"
|
||||
}
|
||||
for tool_call in output.tool_calls
|
||||
]
|
||||
)
|
||||
else:
|
||||
log_msg = ""
|
||||
logger.debug(f"Raw Model Output:\n{log_msg}")
|
||||
else:
|
||||
logger.debug(f"Raw Model Output:\n{content}")
|
||||
|
||||
def _invoke_implementation(
|
||||
def invoke(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
@@ -722,15 +404,13 @@ class LitellmLLM(LLM):
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
reasoning_effort: str | None = "medium",
|
||||
reasoning_effort: ReasoningEffort | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
) -> ModelResponse:
|
||||
from litellm import ModelResponse as LiteLLMModelResponse
|
||||
|
||||
from onyx.llm.model_response import from_litellm_model_response
|
||||
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
|
||||
response = cast(
|
||||
LiteLLMModelResponse,
|
||||
self._completion(
|
||||
@@ -743,12 +423,13 @@ class LitellmLLM(LLM):
|
||||
max_tokens=max_tokens,
|
||||
parallel_tool_calls=True,
|
||||
reasoning_effort=reasoning_effort,
|
||||
user_identity=user_identity,
|
||||
),
|
||||
)
|
||||
|
||||
return from_litellm_model_response(response)
|
||||
|
||||
def _stream_implementation(
|
||||
def stream(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
@@ -756,14 +437,12 @@ class LitellmLLM(LLM):
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
reasoning_effort: str | None = "medium",
|
||||
reasoning_effort: ReasoningEffort | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
) -> Iterator[ModelResponseStream]:
|
||||
from litellm import CustomStreamWrapper as LiteLLMCustomStreamWrapper
|
||||
from onyx.llm.model_response import from_litellm_model_response_stream
|
||||
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
|
||||
response = cast(
|
||||
LiteLLMCustomStreamWrapper,
|
||||
self._completion(
|
||||
@@ -776,6 +455,7 @@ class LitellmLLM(LLM):
|
||||
max_tokens=max_tokens,
|
||||
parallel_tool_calls=True,
|
||||
reasoning_effort=reasoning_effort,
|
||||
user_identity=user_identity,
|
||||
),
|
||||
)
|
||||
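A hedged usage sketch of the reworked interface, assuming an already-constructed LitellmLLM instance (construction arguments are omitted) and the module paths introduced elsewhere in this diff:

# Illustrative sketch, not part of the diff; `llm` is assumed to be a configured LitellmLLM.
from onyx.llm.interfaces import LLM
from onyx.llm.models import ReasoningEffort, SystemMessage, UserMessage
from onyx.llm.utils import llm_response_to_string

def ask(llm: LLM, question: str) -> str:
    # invoke() now takes plain strings or the Pydantic message models and returns an
    # onyx ModelResponse (not a LangChain BaseMessage), so the text is pulled out with
    # llm_response_to_string, as done in the kg/indexing code earlier in this diff.
    response = llm.invoke(
        [SystemMessage(content="Answer briefly."), UserMessage(content=question)],
        reasoning_effort=ReasoningEffort.LOW,
    )
    return llm_response_to_string(response)

def stream_answer(llm: LLM, question: str) -> None:
    # stream() yields ModelResponseStream chunks instead of BaseMessage chunks.
    for chunk in llm.stream(question):
        ...  # consume chunks as they arrive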
|
||||
|
||||
@@ -21,16 +21,21 @@ PROVIDER_DISPLAY_NAMES: dict[str, str] = {
|
||||
"deepseek": "DeepSeek",
|
||||
"xai": "xAI",
|
||||
"mistral": "Mistral",
|
||||
"mistralai": "Mistral", # Alias used by some providers
|
||||
"cohere": "Cohere",
|
||||
"perplexity": "Perplexity",
|
||||
"amazon": "Amazon",
|
||||
"meta": "Meta",
|
||||
"meta-llama": "Meta", # Alias used by some providers
|
||||
"ai21": "AI21",
|
||||
"nvidia": "NVIDIA",
|
||||
"databricks": "Databricks",
|
||||
"alibaba": "Alibaba",
|
||||
"qwen": "Qwen",
|
||||
"microsoft": "Microsoft",
|
||||
"gemini": "Gemini",
|
||||
"stability": "Stability",
|
||||
"writer": "Writer",
|
||||
}
|
||||
|
||||
# Map vendors to their brand names (used for provider_display_name generation)
|
||||
@@ -45,6 +50,11 @@ VENDOR_BRAND_NAMES: dict[str, str] = {
|
||||
"deepseek": "DeepSeek",
|
||||
"xai": "Grok",
|
||||
"perplexity": "Sonar",
|
||||
"ai21": "Jamba",
|
||||
"nvidia": "Nemotron",
|
||||
"qwen": "Qwen",
|
||||
"alibaba": "Qwen",
|
||||
"writer": "Palmyra",
|
||||
}
|
||||
|
||||
# Aggregator providers that host models from multiple vendors
|
||||
@@ -52,6 +62,155 @@ AGGREGATOR_PROVIDERS: set[str] = {
|
||||
"bedrock",
|
||||
"bedrock_converse",
|
||||
"openrouter",
|
||||
"ollama_chat",
|
||||
"vertex_ai",
|
||||
"azure",
|
||||
}
|
||||
|
||||
# Model family name mappings for display name generation
|
||||
# Used by Bedrock display name generator
|
||||
BEDROCK_MODEL_NAME_MAPPINGS: dict[str, str] = {
|
||||
"claude": "Claude",
|
||||
"llama": "Llama",
|
||||
"mistral": "Mistral",
|
||||
"mixtral": "Mixtral",
|
||||
"titan": "Titan",
|
||||
"nova": "Nova",
|
||||
"jamba": "Jamba",
|
||||
"command": "Command",
|
||||
"deepseek": "DeepSeek",
|
||||
}
|
||||
|
||||
# Used by Ollama display name generator
|
||||
OLLAMA_MODEL_NAME_MAPPINGS: dict[str, str] = {
|
||||
"llama": "Llama",
|
||||
"qwen": "Qwen",
|
||||
"mistral": "Mistral",
|
||||
"deepseek": "DeepSeek",
|
||||
"gemma": "Gemma",
|
||||
"phi": "Phi",
|
||||
"codellama": "Code Llama",
|
||||
"starcoder": "StarCoder",
|
||||
"wizardcoder": "WizardCoder",
|
||||
"vicuna": "Vicuna",
|
||||
"orca": "Orca",
|
||||
"dolphin": "Dolphin",
|
||||
"nous": "Nous",
|
||||
"neural": "Neural",
|
||||
"mixtral": "Mixtral",
|
||||
"falcon": "Falcon",
|
||||
"yi": "Yi",
|
||||
"command": "Command",
|
||||
"zephyr": "Zephyr",
|
||||
"openchat": "OpenChat",
|
||||
"solar": "Solar",
|
||||
}
|
||||
|
||||
# Bedrock model token limits (AWS doesn't expose this via API)
|
||||
# Note: Many Bedrock model IDs include context length suffix (e.g., ":200k")
|
||||
# which is parsed first. This mapping is for models without suffixes.
|
||||
# Sources:
|
||||
# - LiteLLM model_prices_and_context_window.json
|
||||
# - AWS Bedrock documentation and announcement blogs
|
||||
BEDROCK_MODEL_TOKEN_LIMITS: dict[str, int] = {
|
||||
# Anthropic Claude models (new naming: claude-{tier}-{version})
|
||||
"claude-opus-4": 200000,
|
||||
"claude-sonnet-4": 200000,
|
||||
"claude-haiku-4": 200000,
|
||||
# Anthropic Claude models (old naming: claude-{version})
|
||||
"claude-4": 200000,
|
||||
"claude-3-7": 200000,
|
||||
"claude-3-5": 200000,
|
||||
"claude-3": 200000,
|
||||
"claude-v2": 100000,
|
||||
"claude-instant": 100000,
|
||||
# Amazon Nova models (from LiteLLM)
|
||||
"nova-premier": 1000000,
|
||||
"nova-pro": 300000,
|
||||
"nova-lite": 300000,
|
||||
"nova-2-lite": 1000000, # Nova 2 Lite has 1M context
|
||||
"nova-2-sonic": 128000,
|
||||
"nova-micro": 128000,
|
||||
# Amazon Titan models (from LiteLLM: all text models are 42K)
|
||||
"titan-text-premier": 42000,
|
||||
"titan-text-express": 42000,
|
||||
"titan-text-lite": 42000,
|
||||
"titan-tg1": 8000,
|
||||
# Meta Llama models (Llama 3 base = 8K, Llama 3.1+ = 128K)
|
||||
"llama4": 128000,
|
||||
"llama3-3": 128000,
|
||||
"llama3-2": 128000,
|
||||
"llama3-1": 128000,
|
||||
"llama3-8b": 8000,
|
||||
"llama3-70b": 8000,
|
||||
# Mistral models (Large 2+ = 128K, original Large/Small = 32K)
|
||||
"mistral-large-3": 128000,
|
||||
"mistral-large-2407": 128000, # Mistral Large 2
|
||||
"mistral-large-2402": 32000, # Original Mistral Large
|
||||
"mistral-large": 128000, # Default to newer version
|
||||
"mistral-small": 32000,
|
||||
"mistral-7b": 32000,
|
||||
"mixtral-8x7b": 32000,
|
||||
"pixtral": 128000,
|
||||
"ministral": 128000,
|
||||
"magistral": 128000,
|
||||
"voxtral": 32000,
|
||||
# Cohere models
|
||||
"command-r-plus": 128000,
|
||||
"command-r": 128000,
|
||||
# DeepSeek models
|
||||
"deepseek": 64000,
|
||||
# Google Gemma models
|
||||
"gemma-3": 128000,
|
||||
"gemma-2": 8000,
|
||||
"gemma": 8000,
|
||||
# Qwen models
|
||||
"qwen3": 128000,
|
||||
"qwen2": 128000,
|
||||
# NVIDIA models
|
||||
"nemotron": 128000,
|
||||
# Writer Palmyra models
|
||||
"palmyra": 128000,
|
||||
# Moonshot Kimi
|
||||
"kimi": 128000,
|
||||
# Minimax
|
||||
"minimax": 128000,
|
||||
# OpenAI (via Bedrock)
|
||||
"gpt-oss": 128000,
|
||||
# AI21 models (from LiteLLM: Jamba 1.5 = 256K, Jamba Instruct = 70K)
|
||||
"jamba-1-5": 256000,
|
||||
"jamba-instruct": 70000,
|
||||
"jamba": 256000, # Default to newer version
|
||||
}
|
||||
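The comment above notes that model IDs carrying an explicit context-length suffix (e.g. ":200k") are parsed first and this table is only the fallback. A hypothetical lookup helper along those lines, for illustration only (the real parsing code is not part of this hunk):

# Hypothetical helper, not part of the diff; the actual lookup logic lives elsewhere in the repo.
import re

def bedrock_max_tokens(model_id: str, limits: dict[str, int], default: int = 32000) -> int:
    # 1) Honor an explicit ":<n>k" suffix, e.g. "some.model:200k".
    suffix = re.search(r":(\d+)k$", model_id.lower())
    if suffix:
        return int(suffix.group(1)) * 1000
    # 2) Otherwise fall back to the longest matching key from the table above.
    model = model_id.lower()
    best: tuple[str, int] | None = None
    for key, value in limits.items():
        if key in model and (best is None or len(key) > len(best[0])):
            best = (key, value)
    return best[1] if best else default

limits = {"claude-3-5": 200000, "nova-micro": 128000}
print(bedrock_max_tokens("anthropic.claude-3-5-sonnet-20241022-v2:0", limits))  # 200000
print(bedrock_max_tokens("some.model:200k", limits))                            # 200000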
|
||||
|
||||
# Ollama model prefix to vendor mapping (for grouping models by vendor)
|
||||
OLLAMA_MODEL_TO_VENDOR: dict[str, str] = {
|
||||
"llama": "Meta",
|
||||
"codellama": "Meta",
|
||||
"qwen": "Alibaba",
|
||||
"qwq": "Alibaba",
|
||||
"mistral": "Mistral",
|
||||
"ministral": "Mistral",
|
||||
"mixtral": "Mistral",
|
||||
"deepseek": "DeepSeek",
|
||||
"gemma": "Google",
|
||||
"phi": "Microsoft",
|
||||
"command": "Cohere",
|
||||
"aya": "Cohere",
|
||||
"falcon": "TII",
|
||||
"yi": "01.AI",
|
||||
"starcoder": "BigCode",
|
||||
"wizardcoder": "WizardLM",
|
||||
"vicuna": "LMSYS",
|
||||
"openchat": "OpenChat",
|
||||
"solar": "Upstage",
|
||||
"orca": "Microsoft",
|
||||
"dolphin": "Cognitive Computations",
|
||||
"nous": "Nous Research",
|
||||
"neural": "Intel",
|
||||
"zephyr": "HuggingFace",
|
||||
"granite": "IBM",
|
||||
"nemotron": "NVIDIA",
|
||||
"smollm": "HuggingFace",
|
||||
}
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
class GenAIDisabledException(Exception):
|
||||
def __init__(self, message: str = "Generative AI has been turned off") -> None:
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
@@ -2,7 +2,7 @@ from collections.abc import Callable
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
@@ -15,7 +15,6 @@ from onyx.db.llm import fetch_user_group_ids
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.chat_llm import LitellmLLM
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.llm.llm_provider_options import OLLAMA_API_KEY_CONFIG_KEY
|
||||
@@ -109,7 +108,7 @@ def get_llm_config_for_persona(
|
||||
|
||||
|
||||
def get_llms_for_persona(
|
||||
persona: Persona | None,
|
||||
persona: Persona | PersonaOverrideConfig | None,
|
||||
user: User | None,
|
||||
llm_override: LLMOverride | None = None,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
@@ -136,18 +135,22 @@ def get_llms_for_persona(
|
||||
if not provider_model:
|
||||
raise ValueError("No LLM provider found")
|
||||
|
||||
# Only check access control for database Persona entities, not PersonaOverrideConfig
|
||||
# PersonaOverrideConfig is used for temporary overrides and doesn't have access restrictions
|
||||
persona_model = persona if isinstance(persona, Persona) else None
|
||||
|
||||
# Fetch user group IDs for access control check
|
||||
user_group_ids = fetch_user_group_ids(db_session, user)
|
||||
|
||||
if not can_user_access_llm_provider(
|
||||
provider_model,
|
||||
user_group_ids,
|
||||
persona,
|
||||
persona_model,
|
||||
):
|
||||
logger.warning(
|
||||
"User %s with persona %s cannot access provider %s. Falling back to default provider.",
|
||||
getattr(user, "id", None),
|
||||
getattr(persona, "id", None),
|
||||
getattr(persona_model, "id", None),
|
||||
provider_model.name,
|
||||
)
|
||||
return get_default_llms(
|
||||
@@ -197,8 +200,6 @@ def get_default_llm_with_vision(
|
||||
|
||||
Returns None if no providers exist or if no provider supports images.
|
||||
"""
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
raise GenAIDisabledException()
|
||||
|
||||
def create_vision_llm(provider: LLMProviderView, model: str) -> LLM:
|
||||
"""Helper to create an LLM if the provider supports image input."""
|
||||
@@ -316,9 +317,6 @@ def get_default_llms(
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
) -> tuple[LLM, LLM]:
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
raise GenAIDisabledException()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
llm_provider = fetch_default_provider(db_session)
|
||||
|
||||
|
||||
@@ -1,30 +1,22 @@
|
||||
import abc
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
from typing import Union
|
||||
|
||||
from braintrust import traced
|
||||
from langchain.schema.language_model import (
|
||||
LanguageModelInput as LangChainLanguageModelInput,
|
||||
)
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from onyx.configs.app_configs import LOG_INDIVIDUAL_MODEL_TOKENS
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.llm.message_types import ChatCompletionMessage
|
||||
from onyx.llm.model_response import ModelResponse
|
||||
from onyx.llm.model_response import ModelResponseStream
|
||||
from onyx.llm.models import LanguageModelInput
|
||||
from onyx.llm.models import ReasoningEffort
|
||||
from onyx.llm.models import ToolChoiceOptions
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
STANDARD_TOOL_CHOICE_OPTIONS = ("required", "auto", "none")
|
||||
ToolChoiceOptions = Union[Literal["required", "auto", "none"], str]
|
||||
LanguageModelInput = Union[Sequence[ChatCompletionMessage], str]
|
||||
|
||||
class LLMUserIdentity(BaseModel):
|
||||
user_id: str | None = None
|
||||
session_id: str | None = None
|
||||
|
||||
|
||||
class LLMConfig(BaseModel):
|
||||
@@ -41,60 +33,12 @@ class LLMConfig(BaseModel):
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
|
||||
def log_prompt(prompt: LangChainLanguageModelInput) -> None:
|
||||
if isinstance(prompt, list):
|
||||
for ind, msg in enumerate(prompt):
|
||||
if isinstance(msg, AIMessageChunk):
|
||||
if msg.content:
|
||||
log_msg = msg.content
|
||||
elif msg.tool_call_chunks:
|
||||
log_msg = "Tool Calls: " + str(
|
||||
[
|
||||
{
|
||||
key: value
|
||||
for key, value in tool_call.items()
|
||||
if key != "index"
|
||||
}
|
||||
for tool_call in msg.tool_call_chunks
|
||||
]
|
||||
)
|
||||
else:
|
||||
log_msg = ""
|
||||
logger.debug(f"Message {ind}:\n{log_msg}")
|
||||
else:
|
||||
logger.debug(f"Message {ind}:\n{msg.content}")
|
||||
if isinstance(prompt, str):
|
||||
logger.debug(f"Prompt:\n{prompt}")
|
||||
|
||||
|
||||
class LLM(abc.ABC):
|
||||
"""Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
|
||||
to use these implementations to connect to a variety of LLM providers."""
|
||||
|
||||
@property
|
||||
def requires_warm_up(self) -> bool:
|
||||
"""Is this model running in memory and needs an initial call to warm it up?"""
|
||||
return False
|
||||
|
||||
@property
|
||||
def requires_api_key(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def config(self) -> LLMConfig:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def log_model_configs(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def _precall(self, prompt: LangChainLanguageModelInput) -> None:
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
raise Exception("Generative AI is disabled")
|
||||
if LOG_ONYX_MODEL_INTERACTIONS:
|
||||
log_prompt(prompt)
|
||||
|
||||
@traced(name="invoke llm", type="llm")
|
||||
def invoke(
|
||||
self,
|
||||
@@ -104,72 +48,9 @@ class LLM(abc.ABC):
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
reasoning_effort: ReasoningEffort | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
) -> "ModelResponse":
|
||||
return self._invoke_implementation(
|
||||
prompt,
|
||||
tools,
|
||||
tool_choice,
|
||||
structured_response_format,
|
||||
timeout_override,
|
||||
max_tokens,
|
||||
)
|
||||
|
||||
@traced(name="invoke llm", type="llm")
|
||||
def invoke_langchain(
|
||||
self,
|
||||
prompt: LangChainLanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> BaseMessage:
|
||||
self._precall(prompt)
|
||||
# TODO add a postcall to log model outputs independent of concrete class
|
||||
# implementation
|
||||
return self._invoke_implementation_langchain(
|
||||
prompt,
|
||||
tools,
|
||||
tool_choice,
|
||||
structured_response_format,
|
||||
timeout_override,
|
||||
max_tokens,
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _invoke_implementation(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> "ModelResponse":
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def _stream_implementation(
|
||||
self,
|
||||
prompt: LanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> Iterator[ModelResponseStream]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def _invoke_implementation_langchain(
|
||||
self,
|
||||
prompt: LangChainLanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> BaseMessage:
|
||||
raise NotImplementedError
|
||||
|
||||
def stream(
|
||||
@@ -180,54 +61,7 @@ class LLM(abc.ABC):
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
reasoning_effort: ReasoningEffort | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
) -> Iterator[ModelResponseStream]:
|
||||
return self._stream_implementation(
|
||||
prompt,
|
||||
tools,
|
||||
tool_choice,
|
||||
structured_response_format,
|
||||
timeout_override,
|
||||
max_tokens,
|
||||
)
|
||||
|
||||
def stream_langchain(
|
||||
self,
|
||||
prompt: LangChainLanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
self._precall(prompt)
|
||||
# TODO add a postcall to log model outputs independent of concrete class
|
||||
# implementation
|
||||
messages = self._stream_implementation_langchain(
|
||||
prompt,
|
||||
tools,
|
||||
tool_choice,
|
||||
structured_response_format,
|
||||
timeout_override,
|
||||
max_tokens,
|
||||
)
|
||||
|
||||
tokens = []
|
||||
for message in messages:
|
||||
if LOG_INDIVIDUAL_MODEL_TOKENS:
|
||||
tokens.append(message.content)
|
||||
yield message
|
||||
|
||||
if LOG_INDIVIDUAL_MODEL_TOKENS and tokens:
|
||||
logger.debug(f"Model Tokens: {tokens}")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _stream_implementation_langchain(
|
||||
self,
|
||||
prompt: LangChainLanguageModelInput,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -606,6 +606,56 @@ def _patch_openai_responses_transform_response() -> None:
|
||||
LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_openai_responses_tool_content_type() -> None:
|
||||
"""
|
||||
Patches LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text
|
||||
to use 'input_text' type for tool messages instead of 'output_text'.
|
||||
|
||||
The OpenAI Responses API only accepts 'input_text', 'input_image', and 'input_file'
|
||||
in the function_call_output.output array. The default litellm implementation
|
||||
incorrectly uses 'output_text' for tool messages, causing 400 Bad Request errors.
|
||||
|
||||
See: https://github.com/BerriAI/litellm/issues/17507
|
||||
|
||||
This should be removed once litellm releases a fix for this issue.
|
||||
"""
|
||||
original_method = (
|
||||
LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text
|
||||
)
|
||||
|
||||
if (
|
||||
getattr(
|
||||
original_method,
|
||||
"__name__",
|
||||
"",
|
||||
)
|
||||
== "_patched_convert_content_str_to_input_text"
|
||||
):
|
||||
return
|
||||
|
||||
def _patched_convert_content_str_to_input_text(
|
||||
self: Any, content: str, role: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert string content to the appropriate Responses API format.
|
||||
|
||||
For user, system, and tool messages, use 'input_text' type.
|
||||
For assistant messages, use 'output_text' type.
|
||||
|
||||
Tool messages go into function_call_output.output, which only accepts
|
||||
'input_text', 'input_image', and 'input_file' types.
|
||||
"""
|
||||
if role in ("user", "system", "tool"):
|
||||
return {"type": "input_text", "text": content}
|
||||
else:
|
||||
return {"type": "output_text", "text": content}
|
||||
|
||||
_patched_convert_content_str_to_input_text.__name__ = (
|
||||
"_patched_convert_content_str_to_input_text"
|
||||
)
|
||||
LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text = _patched_convert_content_str_to_input_text # type: ignore[method-assign]
|
||||
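The effect of the patch in isolation: tool, user, and system content is tagged as input_text, while assistant content stays output_text. A minimal standalone check of that mapping, independent of litellm:

# Standalone illustration of the patched mapping; not part of the diff above.
def convert_content_str(content: str, role: str) -> dict:
    # Responses API: function_call_output.output only accepts input_* types,
    # so tool output must be tagged input_text rather than output_text.
    if role in ("user", "system", "tool"):
        return {"type": "input_text", "text": content}
    return {"type": "output_text", "text": content}

assert convert_content_str("42", "tool") == {"type": "input_text", "text": "42"}
assert convert_content_str("The answer is 42.", "assistant") == {"type": "output_text", "text": "The answer is 42."}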
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -615,11 +665,13 @@ def apply_monkey_patches() -> None:
|
||||
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
|
||||
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
|
||||
- Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
|
||||
- Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
|
||||
"""
|
||||
_patch_ollama_transform_request()
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_chunk_parser()
|
||||
_patch_openai_responses_transform_response()
|
||||
_patch_openai_responses_tool_content_type()
|
||||
|
||||
|
||||
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
|
||||
|
||||
@@ -56,6 +56,15 @@ class WellKnownLLMProviderDescriptor(BaseModel):
|
||||
|
||||
|
||||
OPENAI_PROVIDER_NAME = "openai"
|
||||
# Curated list of OpenAI models to show by default in the UI
|
||||
OPENAI_VISIBLE_MODEL_NAMES = {
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"o1",
|
||||
"o3-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
}
|
||||
|
||||
BEDROCK_PROVIDER_NAME = "bedrock"
|
||||
BEDROCK_DEFAULT_MODEL = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
@@ -125,6 +134,12 @@ _IGNORABLE_ANTHROPIC_MODELS = {
|
||||
"claude-instant-1",
|
||||
"anthropic/claude-3-5-sonnet-20241022",
|
||||
}
|
||||
# Curated list of Anthropic models to show by default in the UI
|
||||
ANTHROPIC_VISIBLE_MODEL_NAMES = {
|
||||
"claude-opus-4-5",
|
||||
"claude-sonnet-4-5",
|
||||
"claude-haiku-4-5",
|
||||
}
|
||||
|
||||
AZURE_PROVIDER_NAME = "azure"
|
||||
|
||||
@@ -134,53 +149,113 @@ VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
|
||||
VERTEX_LOCATION_KWARG = "vertex_location"
|
||||
VERTEXAI_DEFAULT_MODEL = "gemini-2.5-flash"
|
||||
VERTEXAI_DEFAULT_FAST_MODEL = "gemini-2.5-flash-lite"
|
||||
# Curated list of Vertex AI models to show by default in the UI
|
||||
VERTEXAI_VISIBLE_MODEL_NAMES = {
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.5-pro",
|
||||
}
|
||||
|
||||
|
||||
def is_obsolete_model(model_name: str, provider: str) -> bool:
|
||||
"""Check if a model is obsolete and should be filtered out.
|
||||
|
||||
Filters models that are 2+ major versions behind or deprecated.
|
||||
This is the single source of truth for obsolete model detection.
|
||||
"""
|
||||
model_lower = model_name.lower()
|
||||
|
||||
# OpenAI obsolete models
|
||||
if provider == "openai":
|
||||
# GPT-3 models are obsolete
|
||||
if "gpt-3" in model_lower:
|
||||
return True
|
||||
# Legacy models
|
||||
deprecated = {
|
||||
"text-davinci-003",
|
||||
"text-davinci-002",
|
||||
"text-curie-001",
|
||||
"text-babbage-001",
|
||||
"text-ada-001",
|
||||
"davinci",
|
||||
"curie",
|
||||
"babbage",
|
||||
"ada",
|
||||
}
|
||||
if model_lower in deprecated:
|
||||
return True
|
||||
|
||||
# Anthropic obsolete models
|
||||
if provider == "anthropic":
|
||||
if "claude-2" in model_lower or "claude-instant" in model_lower:
|
||||
return True
|
||||
|
||||
# Vertex AI obsolete models
|
||||
if provider == "vertex_ai":
|
||||
if "gemini-1.0" in model_lower:
|
||||
return True
|
||||
if "palm" in model_lower or "bison" in model_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
"""Lazy-load provider model mappings to avoid importing litellm at module level."""
|
||||
"""Lazy-load provider model mappings to avoid importing litellm at module level.
|
||||
|
||||
Dynamic providers (Bedrock, Ollama, OpenRouter) return empty lists here
|
||||
because their models are fetched directly from the source API, which is
|
||||
more up-to-date than LiteLLM's static lists.
|
||||
"""
|
||||
return {
|
||||
OPENAI_PROVIDER_NAME: get_openai_model_names(),
|
||||
BEDROCK_PROVIDER_NAME: get_bedrock_model_names(),
|
||||
BEDROCK_PROVIDER_NAME: [], # Dynamic - fetched from AWS API
|
||||
ANTHROPIC_PROVIDER_NAME: get_anthropic_model_names(),
|
||||
VERTEXAI_PROVIDER_NAME: get_vertexai_model_names(),
|
||||
OLLAMA_PROVIDER_NAME: [],
|
||||
OPENROUTER_PROVIDER_NAME: get_openrouter_model_names(),
|
||||
OLLAMA_PROVIDER_NAME: [], # Dynamic - fetched from Ollama API
|
||||
OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API
|
||||
}
|
||||
|
||||
|
||||
def get_openai_model_names() -> list[str]:
|
||||
"""Get OpenAI model names dynamically from litellm."""
|
||||
import re
|
||||
import litellm
|
||||
|
||||
# TODO: remove these lists once we have a comprehensive model configuration page
|
||||
# The ideal flow should be: fetch all available models --> filter by type
|
||||
# --> allow user to modify filters and select models based on current context
|
||||
non_chat_model_terms = {
|
||||
"embed",
|
||||
"audio",
|
||||
"tts",
|
||||
"whisper",
|
||||
"dall-e",
|
||||
"image",
|
||||
"moderation",
|
||||
"sora",
|
||||
"container",
|
||||
}
|
||||
deprecated_model_terms = {"babbage", "davinci", "gpt-3.5", "gpt-4-"}
|
||||
excluded_terms = non_chat_model_terms | deprecated_model_terms
|
||||
|
||||
# NOTE: We are explicitly excluding all "timestamped" models
|
||||
# because they are mostly just noise in the admin configuration panel
|
||||
# e.g. gpt-4o-2025-07-16, gpt-3.5-turbo-0613, etc.
|
||||
date_pattern = re.compile(r"-\d{4}")
|
||||
|
||||
def is_valid_model(model: str) -> bool:
|
||||
model_lower = model.lower()
|
||||
return not any(
|
||||
ex in model_lower for ex in excluded_terms
|
||||
) and not date_pattern.search(model)
|
||||
|
||||
return sorted(
|
||||
[
|
||||
# Strip openai/ prefix if present
|
||||
model.replace("openai/", "") if model.startswith("openai/") else model
|
||||
(
|
||||
model.removeprefix("openai/")
|
||||
for model in litellm.open_ai_chat_completion_models
|
||||
if "embed" not in model.lower()
|
||||
and "audio" not in model.lower()
|
||||
and "tts" not in model.lower()
|
||||
and "whisper" not in model.lower()
|
||||
and "dall-e" not in model.lower()
|
||||
and "moderation" not in model.lower()
|
||||
and "sora" not in model.lower() # video generation
|
||||
and "container" not in model.lower() # not a model
|
||||
],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
|
||||
def get_bedrock_model_names() -> list[str]:
|
||||
"""Get Bedrock model names dynamically from litellm."""
|
||||
import litellm
|
||||
|
||||
# bedrock_converse_models are just extensions of the bedrock_models
|
||||
return sorted(
|
||||
[
|
||||
model
|
||||
for model in litellm.bedrock_models.union(litellm.bedrock_converse_models)
|
||||
if "/" not in model and "embed" not in model.lower()
|
||||
],
|
||||
if is_valid_model(model)
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
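The -\d{4} pattern above is what drops the timestamped aliases. A small sketch of the filter's behaviour on a few example model names (the names are illustrative only):

# Illustrative sketch of the filter above; model names are examples only.
import re

non_chat_model_terms = {"embed", "audio", "tts", "whisper", "dall-e", "image", "moderation", "sora", "container"}
deprecated_model_terms = {"babbage", "davinci", "gpt-3.5", "gpt-4-"}
excluded_terms = non_chat_model_terms | deprecated_model_terms
date_pattern = re.compile(r"-\d{4}")

def is_valid_model(model: str) -> bool:
    model_lower = model.lower()
    return not any(ex in model_lower for ex in excluded_terms) and not date_pattern.search(model)

print(is_valid_model("gpt-4o"))                  # True  - kept
print(is_valid_model("gpt-4o-2024-08-06"))       # False - timestamped alias
print(is_valid_model("gpt-3.5-turbo"))           # False - deprecated family
print(is_valid_model("text-embedding-3-large"))  # False - not a chat model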
|
||||
@@ -194,6 +269,7 @@ def get_anthropic_model_names() -> list[str]:
|
||||
model
|
||||
for model in litellm.anthropic_models
|
||||
if model not in _IGNORABLE_ANTHROPIC_MODELS
|
||||
and not is_obsolete_model(model, ANTHROPIC_PROVIDER_NAME)
|
||||
],
|
||||
reverse=True,
|
||||
)
|
||||
@@ -239,21 +315,12 @@ def get_vertexai_model_names() -> list[str]:
|
||||
and "/" not in model # filter out prefixed models like openai/gpt-oss
|
||||
and "search_api" not in model.lower() # not a model
|
||||
and "-maas" not in model.lower() # marketplace models
|
||||
and not is_obsolete_model(model, VERTEXAI_PROVIDER_NAME)
|
||||
],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
|
||||
def get_openrouter_model_names() -> list[str]:
|
||||
"""Get OpenRouter model names dynamically from litellm."""
|
||||
import litellm
|
||||
|
||||
return sorted(
|
||||
[model for model in litellm.openrouter_models if "embed" not in model.lower()],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
|
||||
def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
return [
|
||||
WellKnownLLMProviderDescriptor(
|
||||
@@ -413,7 +480,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
name=VERTEX_LOCATION_KWARG,
|
||||
display_name="Location",
|
||||
description="The location of the Vertex AI model. Please refer to the "
|
||||
"[Vertex AI configuration docs](https://docs.onyx.app/admin/ai_models/google_ai) for all possible values.",
|
||||
"[Vertex AI configuration docs](https://docs.onyx.app/admins/ai_models/google_ai) for all possible values.",
|
||||
is_required=False,
|
||||
is_secret=False,
|
||||
key_type=CustomConfigKeyType.TEXT_INPUT,
|
||||
@@ -488,20 +555,46 @@ def get_provider_display_name(provider_name: str) -> str:
|
||||
)
|
||||
|
||||
|
||||
def _get_visible_models_for_provider(provider_name: str) -> set[str]:
|
||||
"""Get the set of models that should be visible by default for a provider."""
|
||||
_PROVIDER_TO_VISIBLE_MODELS: dict[str, set[str]] = {
|
||||
OPENAI_PROVIDER_NAME: OPENAI_VISIBLE_MODEL_NAMES,
|
||||
ANTHROPIC_PROVIDER_NAME: ANTHROPIC_VISIBLE_MODEL_NAMES,
|
||||
VERTEXAI_PROVIDER_NAME: VERTEXAI_VISIBLE_MODEL_NAMES,
|
||||
}
|
||||
return _PROVIDER_TO_VISIBLE_MODELS.get(provider_name, set())
|
||||
|
||||
|
||||
def fetch_model_configurations_for_provider(
|
||||
provider_name: str,
|
||||
) -> list[ModelConfigurationView]:
|
||||
# No models are marked visible by default - the default model logic
|
||||
# in the frontend/backend will handle making default models visible.
|
||||
return [
|
||||
ModelConfigurationView(
|
||||
name=model_name,
|
||||
is_visible=False,
|
||||
max_input_tokens=None,
|
||||
supports_image_input=model_supports_image_input(
|
||||
model_name=model_name,
|
||||
model_provider=provider_name,
|
||||
),
|
||||
"""Fetch model configurations for a static provider (OpenAI, Anthropic, Vertex AI).
|
||||
|
||||
Looks up max_input_tokens from LiteLLM's model_cost. If not found, stores None
|
||||
and the runtime will use the fallback (32000).
|
||||
|
||||
Models in the curated visible lists (OPENAI_VISIBLE_MODEL_NAMES, etc.) are
|
||||
marked as is_visible=True by default.
|
||||
"""
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
|
||||
visible_models = _get_visible_models_for_provider(provider_name)
|
||||
configs = []
|
||||
for model_name in fetch_models_for_provider(provider_name):
|
||||
max_input_tokens = get_max_input_tokens(
|
||||
model_name=model_name,
|
||||
model_provider=provider_name,
|
||||
)
|
||||
for model_name in fetch_models_for_provider(provider_name)
|
||||
]
|
||||
|
||||
configs.append(
|
||||
ModelConfigurationView(
|
||||
name=model_name,
|
||||
is_visible=model_name in visible_models,
|
||||
max_input_tokens=max_input_tokens,
|
||||
supports_image_input=model_supports_image_input(
|
||||
model_name=model_name,
|
||||
model_provider=provider_name,
|
||||
),
|
||||
)
|
||||
)
|
||||
return configs
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
from typing import Literal
|
||||
from typing import NotRequired
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
# Content part structures for multimodal messages
|
||||
class TextContentPart(TypedDict):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
||||
class ImageUrlDetail(TypedDict):
|
||||
url: str
|
||||
detail: NotRequired[Literal["auto", "low", "high"]]
|
||||
|
||||
|
||||
class ImageContentPart(TypedDict):
|
||||
type: Literal["image_url"]
|
||||
image_url: ImageUrlDetail
|
||||
|
||||
|
||||
ContentPart = TextContentPart | ImageContentPart
|
||||
|
||||
|
||||
# Tool call structures
|
||||
class FunctionCall(TypedDict):
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class ToolCall(TypedDict):
|
||||
id: str
|
||||
type: Literal["function"]
|
||||
function: FunctionCall
|
||||
|
||||
|
||||
# Message types
|
||||
class SystemMessage(TypedDict):
|
||||
role: Literal["system"]
|
||||
content: str
|
||||
|
||||
|
||||
class UserMessageWithText(TypedDict):
|
||||
role: Literal["user"]
|
||||
content: str
|
||||
|
||||
|
||||
class UserMessageWithParts(TypedDict):
|
||||
role: Literal["user"]
|
||||
content: list[ContentPart]
|
||||
|
||||
|
||||
UserMessage = UserMessageWithText | UserMessageWithParts
|
||||
|
||||
|
||||
class AssistantMessage(TypedDict):
|
||||
role: Literal["assistant"]
|
||||
content: NotRequired[str | None]
|
||||
tool_calls: NotRequired[list[ToolCall]]
|
||||
|
||||
|
||||
class ToolMessage(TypedDict):
|
||||
role: Literal["tool"]
|
||||
content: str
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
# Union type for all OpenAI Chat Completions messages
|
||||
ChatCompletionMessage = SystemMessage | UserMessage | AssistantMessage | ToolMessage
|
||||
@@ -2621,6 +2621,28 @@
|
||||
"model_vendor": "openai",
|
||||
"model_version": "2025-10-06"
|
||||
},
|
||||
"gpt-5.2-pro-2025-12-11": {
|
||||
"display_name": "GPT-5.2 Pro",
|
||||
"model_vendor": "openai",
|
||||
"model_version": "2025-12-11"
|
||||
},
|
||||
"gpt-5.2-pro": {
|
||||
"display_name": "GPT-5.2 Pro",
|
||||
"model_vendor": "openai"
|
||||
},
|
||||
"gpt-5.2-chat-latest": {
|
||||
"display_name": "GPT 5.2 Chat",
|
||||
"model_vendor": "openai"
|
||||
},
|
||||
"gpt-5.2-2025-12-11": {
|
||||
"display_name": "GPT 5.2",
|
||||
"model_vendor": "openai",
|
||||
"model_version": "2025-12-11"
|
||||
},
|
||||
"gpt-5.2": {
|
||||
"display_name": "GPT 5.2",
|
||||
"model_vendor": "openai"
|
||||
},
|
||||
"gpt-5.1": {
|
||||
"display_name": "GPT 5.1",
|
||||
"model_vendor": "openai"
|
||||
|
||||
backend/onyx/llm/models.py (new file, 104 lines)
@@ -0,0 +1,104 @@
from enum import Enum
from typing import Literal

from pydantic import BaseModel


class ToolChoiceOptions(str, Enum):
    REQUIRED = "required"
    AUTO = "auto"
    NONE = "none"


class ReasoningEffort(str, Enum):
    """Reasoning effort levels for models that support extended thinking.

    Different providers map these values differently:
    - OpenAI: Uses "low", "medium", "high" directly for reasoning_effort. Recently added "none" for 5 series
      which is like "minimal"
    - Claude: Uses budget_tokens with different values for each level
    - Gemini: Uses "none", "low", "medium", "high" for thinking_budget (via litellm mapping)
    """

    OFF = "off"
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"


# Budget tokens for Claude extended thinking at each reasoning effort level
CLAUDE_REASONING_BUDGET_TOKENS: dict[ReasoningEffort, int] = {
    ReasoningEffort.OFF: 0,
    ReasoningEffort.LOW: 1000,
    ReasoningEffort.MEDIUM: 5000,
    ReasoningEffort.HIGH: 10000,
}

# OpenAI reasoning effort mapping (direct string values)
OPENAI_REASONING_EFFORT: dict[ReasoningEffort, str] = {
    ReasoningEffort.OFF: "none",  # this only works for the 5 series though
    ReasoningEffort.LOW: "low",
    ReasoningEffort.MEDIUM: "medium",
    ReasoningEffort.HIGH: "high",
}
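
A minimal sketch of how these mappings might be applied when building provider-specific request kwargs. The helper below is illustrative only — the diff defines just the enum and the two lookup tables, and the exact parameter shape each provider expects (e.g. Anthropic's extended-thinking budget) is an assumption here:

def reasoning_kwargs(provider: str, effort: ReasoningEffort) -> dict:
    # Illustrative helper (not part of the diff): translate a ReasoningEffort
    # into the argument style each provider family uses.
    if provider == "anthropic":
        budget = CLAUDE_REASONING_BUDGET_TOKENS[effort]
        if budget == 0:
            return {}  # extended thinking stays disabled
        # Assumed Anthropic-style extended-thinking parameter shape.
        return {"thinking": {"type": "enabled", "budget_tokens": budget}}
    # OpenAI-style providers take the string value directly.
    return {"reasoning_effort": OPENAI_REASONING_EFFORT[effort]}

For example, reasoning_kwargs("anthropic", ReasoningEffort.MEDIUM) would yield {"thinking": {"type": "enabled", "budget_tokens": 5000}}.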


# Content part structures for multimodal messages
# The classes in this mirror the OpenAI Chat Completions message types and work well with routers like LiteLLM
class TextContentPart(BaseModel):
    type: Literal["text"] = "text"
    text: str


class ImageUrlDetail(BaseModel):
    url: str
    detail: Literal["auto", "low", "high"] | None = None


class ImageContentPart(BaseModel):
    type: Literal["image_url"] = "image_url"
    image_url: ImageUrlDetail


ContentPart = TextContentPart | ImageContentPart


# Tool call structures
class FunctionCall(BaseModel):
    name: str
    arguments: str


class ToolCall(BaseModel):
    type: Literal["function"] = "function"
    id: str
    function: FunctionCall


# Message types
class SystemMessage(BaseModel):
    role: Literal["system"] = "system"
    content: str


class UserMessage(BaseModel):
    role: Literal["user"] = "user"
    content: str | list[ContentPart]


class AssistantMessage(BaseModel):
    role: Literal["assistant"] = "assistant"
    content: str | None = None
    tool_calls: list[ToolCall] | None = None


class ToolMessage(BaseModel):
    role: Literal["tool"] = "tool"
    content: str
    tool_call_id: str


# Union type for all OpenAI Chat Completions messages
ChatCompletionMessage = SystemMessage | UserMessage | AssistantMessage | ToolMessage
# Allows for passing in a string directly. This is provided for convenience and is wrapped as a UserMessage.
LanguageModelInput = list[ChatCompletionMessage] | str
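
For illustration, a short example of building a LanguageModelInput with these Pydantic models — the message text, image URL, and tool call are made-up values; only the class names and fields come from the file above:

# Illustrative only: a system prompt, a multimodal user turn, and a tool round-trip.
messages: list[ChatCompletionMessage] = [
    SystemMessage(content="You are a helpful assistant."),
    UserMessage(
        content=[
            TextContentPart(text="What is in this image?"),
            ImageContentPart(
                image_url=ImageUrlDetail(url="https://example.com/cat.png", detail="low")
            ),
        ]
    ),
    AssistantMessage(
        tool_calls=[
            ToolCall(id="call_1", function=FunctionCall(name="search", arguments='{"query": "cats"}'))
        ]
    ),
    ToolMessage(content="No results.", tool_call_id="call_1"),
]
# model_dump(exclude_none=True) produces dicts in the Chat Completions shape that routers like LiteLLM accept.
payload = [m.model_dump(exclude_none=True) for m in messages]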

@@ -6,10 +6,6 @@ from typing import Any
from typing import cast
from typing import TYPE_CHECKING

from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy import select

from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
@@ -23,6 +19,7 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import LLMProvider
from onyx.db.models import ModelConfiguration
from onyx.llm.interfaces import LLM
from onyx.llm.model_response import ModelResponse
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_TOKEN_ESTIMATE
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_TOKEN_ESTIMATE
from onyx.utils.logger import setup_logger
@@ -88,7 +85,15 @@ def litellm_exception_to_error_msg(
    custom_error_msg_mappings: (
        dict[str, str] | None
    ) = LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS,
) -> str:
) -> tuple[str, str, bool]:
    """Convert a LiteLLM exception to a user-friendly error message with classification.

    Returns:
        tuple: (error_message, error_code, is_retryable)
        - error_message: User-friendly error description
        - error_code: Categorized error code for frontend display
        - is_retryable: Whether the user should try again
    """
    from litellm.exceptions import BadRequestError
    from litellm.exceptions import AuthenticationError
    from litellm.exceptions import PermissionDeniedError
@@ -105,25 +110,37 @@ def litellm_exception_to_error_msg(

    core_exception = _unwrap_nested_exception(e)
    error_msg = str(core_exception)
    error_code = "UNKNOWN_ERROR"
    is_retryable = True

    if custom_error_msg_mappings:
        for error_msg_pattern, custom_error_msg in custom_error_msg_mappings.items():
            if error_msg_pattern in error_msg:
                return custom_error_msg
                return custom_error_msg, "CUSTOM_ERROR", True

    if isinstance(core_exception, BadRequestError):
        error_msg = "Bad request: The server couldn't process your request. Please check your input."
        error_code = "BAD_REQUEST"
        is_retryable = True
    elif isinstance(core_exception, AuthenticationError):
        error_msg = "Authentication failed: Please check your API key and credentials."
        error_code = "AUTH_ERROR"
        is_retryable = False
    elif isinstance(core_exception, PermissionDeniedError):
        error_msg = (
            "Permission denied: You don't have the necessary permissions for this operation."
            "Permission denied: You don't have the necessary permissions for this operation. "
            "Ensure you have access to this model."
        )
        error_code = "PERMISSION_DENIED"
        is_retryable = False
    elif isinstance(core_exception, NotFoundError):
        error_msg = "Resource not found: The requested resource doesn't exist."
        error_code = "NOT_FOUND"
        is_retryable = False
    elif isinstance(core_exception, UnprocessableEntityError):
        error_msg = "Unprocessable entity: The server couldn't process your request due to semantic errors."
        error_code = "UNPROCESSABLE_ENTITY"
        is_retryable = True
    elif isinstance(core_exception, RateLimitError):
        provider_name = (
            llm.config.model_provider
@@ -154,6 +171,8 @@ def litellm_exception_to_error_msg(
            if upstream_detail
            else f"{provider_name} rate limit exceeded: Please slow down your requests and try again later."
        )
        error_code = "RATE_LIMIT"
        is_retryable = True
    elif isinstance(core_exception, ServiceUnavailableError):
        provider_name = (
            llm.config.model_provider
@@ -171,6 +190,8 @@ def litellm_exception_to_error_msg(
        else:
            # Generic 503 Service Unavailable
            error_msg = f"{provider_name} service error: {str(core_exception)}"
        error_code = "SERVICE_UNAVAILABLE"
        is_retryable = True
    elif isinstance(core_exception, ContextWindowExceededError):
        error_msg = (
            "Context window exceeded: Your input is too long for the model to process."
@@ -181,58 +202,51 @@ def litellm_exception_to_error_msg(
                model_name=llm.config.model_name,
                model_provider=llm.config.model_provider,
            )
            error_msg += f"Your invoked model ({llm.config.model_name}) has a maximum context size of {max_context}"
            error_msg += f" Your invoked model ({llm.config.model_name}) has a maximum context size of {max_context}."
        except Exception:
            logger.warning(
                "Unable to get maximum input token for LiteLLM excpetion handling"
                "Unable to get maximum input token for LiteLLM exception handling"
            )
        error_code = "CONTEXT_TOO_LONG"
        is_retryable = False
    elif isinstance(core_exception, ContentPolicyViolationError):
        error_msg = "Content policy violation: Your request violates the content policy. Please revise your input."
        error_code = "CONTENT_POLICY"
        is_retryable = False
    elif isinstance(core_exception, APIConnectionError):
        error_msg = "API connection error: Failed to connect to the API. Please check your internet connection."
        error_code = "CONNECTION_ERROR"
        is_retryable = True
    elif isinstance(core_exception, BudgetExceededError):
        error_msg = (
            "Budget exceeded: You've exceeded your allocated budget for API usage."
        )
        error_code = "BUDGET_EXCEEDED"
        is_retryable = False
    elif isinstance(core_exception, Timeout):
        error_msg = "Request timed out: The operation took too long to complete. Please try again."
        error_code = "CONNECTION_ERROR"
        is_retryable = True
    elif isinstance(core_exception, APIError):
        error_msg = (
            "API error: An error occurred while communicating with the API. "
            f"Details: {str(core_exception)}"
        )
        error_code = "API_ERROR"
        is_retryable = True
    elif not fallback_to_error_msg:
        error_msg = "An unexpected error occurred while processing your request. Please try again later."
        return error_msg
        error_code = "UNKNOWN_ERROR"
        is_retryable = True

    return error_msg, error_code, is_retryable
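
A brief sketch of what a call site might do with the new three-part return value. The surrounding llm object and the invoke call are illustrative, and the positional arguments (the exception and the llm) are inferred from the body above rather than taken from a full signature shown in this hunk:

# Illustrative call-site handling (not from the diff).
try:
    response = llm.invoke("Hello")
except Exception as e:
    error_msg, error_code, is_retryable = litellm_exception_to_error_msg(e, llm)
    if is_retryable:
        logger.warning(f"Retryable LLM error ({error_code}): {error_msg}")
    else:
        logger.error(f"Permanent LLM error ({error_code}): {error_msg}")
        raise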


def dict_based_prompt_to_langchain_prompt(
    messages: list[dict[str, str]],
) -> list[BaseMessage]:
    prompt: list[BaseMessage] = []
    for message in messages:
        role = message.get("role")
        content = message.get("content")
        if not role:
            raise ValueError(f"Message missing `role`: {message}")
        if not content:
            raise ValueError(f"Message missing `content`: {message}")
        elif role == "user":
            prompt.append(HumanMessage(content=content))
        elif role == "system":
            prompt.append(SystemMessage(content=content))
        elif role == "assistant":
            prompt.append(AIMessage(content=content))
        else:
            raise ValueError(f"Unknown role: {role}")
    return prompt


def message_to_string(message: BaseMessage) -> str:
    if not isinstance(message.content, str):
def llm_response_to_string(message: ModelResponse) -> str:
    if not isinstance(message.choice.message.content, str):
        raise RuntimeError("LLM message not in expected format.")

    return message.content
    return message.choice.message.content


def check_number_of_tokens(
@@ -255,7 +269,7 @@ def test_llm(llm: LLM) -> str | None:
    error_msg = None
    for _ in range(2):
        try:
            llm.invoke_langchain("Do not respond")
            llm.invoke("Do not respond")
            return None
        except Exception as e:
            error_msg = str(e)
@@ -432,77 +446,74 @@ def get_llm_contextual_cost(
    return usd_per_prompt + usd_per_completion


def get_llm_max_tokens(
def llm_max_input_tokens(
    model_map: dict,
    model_name: str,
    model_provider: str,
) -> int:
    """Best effort attempt to get the max tokens for the LLM"""
    """Best effort attempt to get the max input tokens for the LLM."""
    if GEN_AI_MAX_TOKENS:
        # This is an override, so always return this
        logger.info(f"Using override GEN_AI_MAX_TOKENS: {GEN_AI_MAX_TOKENS}")
        return GEN_AI_MAX_TOKENS

    try:
        model_obj = find_model_obj(
            model_map,
            model_provider,
            model_name,
        )
        if not model_obj:
            raise RuntimeError(
                f"No litellm entry found for {model_provider}/{model_name}"
            )

        if "max_input_tokens" in model_obj:
            max_tokens = model_obj["max_input_tokens"]
            return max_tokens

        if "max_tokens" in model_obj:
            max_tokens = model_obj["max_tokens"]
            return max_tokens

        logger.error(f"No max tokens found for LLM: {model_name}")
        raise RuntimeError("No max tokens found for LLM")
    except Exception:
        logger.exception(
            f"Failed to get max tokens for LLM with name {model_name}. Defaulting to {GEN_AI_MODEL_FALLBACK_MAX_TOKENS}."
    model_obj = find_model_obj(
        model_map,
        model_provider,
        model_name,
    )
    if not model_obj:
        logger.warning(
            f"Model '{model_name}' not found in LiteLLM. "
            f"Falling back to {GEN_AI_MODEL_FALLBACK_MAX_TOKENS} tokens."
        )
        return GEN_AI_MODEL_FALLBACK_MAX_TOKENS

    if "max_input_tokens" in model_obj:
        return model_obj["max_input_tokens"]

    if "max_tokens" in model_obj:
        return model_obj["max_tokens"]

    logger.warning(
        f"No max tokens found for '{model_name}'. "
        f"Falling back to {GEN_AI_MODEL_FALLBACK_MAX_TOKENS} tokens."
    )
    return GEN_AI_MODEL_FALLBACK_MAX_TOKENS


def get_llm_max_output_tokens(
    model_map: dict,
    model_name: str,
    model_provider: str,
) -> int:
    """Best effort attempt to get the max output tokens for the LLM"""
    try:
        model_obj = model_map.get(f"{model_provider}/{model_name}")
        if not model_obj:
            model_obj = model_map[model_name]
        else:
            pass
    """Best effort attempt to get the max output tokens for the LLM."""
    default_output_tokens = int(GEN_AI_MODEL_FALLBACK_MAX_TOKENS)

        if "max_output_tokens" in model_obj:
            max_output_tokens = model_obj["max_output_tokens"]
            return max_output_tokens
    model_obj = model_map.get(f"{model_provider}/{model_name}")
    if not model_obj:
        model_obj = model_map.get(model_name)

        # Fallback to a fraction of max_tokens if max_output_tokens is not specified
        if "max_tokens" in model_obj:
            max_output_tokens = int(model_obj["max_tokens"] * 0.1)
            return max_output_tokens

        logger.error(f"No max output tokens found for LLM: {model_name}")
        raise RuntimeError("No max output tokens found for LLM")
    except Exception:
        default_output_tokens = int(GEN_AI_MODEL_FALLBACK_MAX_TOKENS)
        logger.exception(
            f"Failed to get max output tokens for LLM with name {model_name}. "
            f"Defaulting to {default_output_tokens} (fallback max tokens)."
    if not model_obj:
        logger.warning(
            f"Model '{model_name}' not found in LiteLLM. "
            f"Falling back to {default_output_tokens} output tokens."
        )
        return default_output_tokens

    if "max_output_tokens" in model_obj:
        return model_obj["max_output_tokens"]

    # Fallback to a fraction of max_tokens if max_output_tokens is not specified
    if "max_tokens" in model_obj:
        return int(model_obj["max_tokens"] * 0.1)

    logger.warning(
        f"No max output tokens found for '{model_name}'. "
        f"Falling back to {default_output_tokens} output tokens."
    )
    return default_output_tokens
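
For reference, a small sketch of exercising the two lookups together. The model name is made up, and get_model_map() is assumed to return LiteLLM's model-cost dictionary, as it is used elsewhere in this file:

# Illustrative only: resolve context and output budgets for one model.
model_map = get_model_map()
max_in = llm_max_input_tokens(model_map=model_map, model_name="gpt-4o", model_provider="openai")
max_out = get_llm_max_output_tokens(model_map=model_map, model_name="gpt-4o", model_provider="openai")
# Unknown models now fall back to GEN_AI_MODEL_FALLBACK_MAX_TOKENS instead of raising.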


def get_max_input_tokens(
    model_name: str,
@@ -518,7 +529,7 @@ def get_max_input_tokens(
    litellm_model_map = get_model_map()

    input_toks = (
        get_llm_max_tokens(
        llm_max_input_tokens(
            model_name=model_name,
            model_provider=model_provider,
            model_map=litellm_model_map,
@@ -536,6 +547,19 @@ def get_max_input_tokens_from_llm_provider(
    llm_provider: "LLMProviderView",
    model_name: str,
) -> int:
    """Get max input tokens for a model, with fallback chain.

    Fallback order:
    1. Use max_input_tokens from model_configuration (populated from source APIs
       like OpenRouter, Ollama, or our Bedrock mapping)
    2. Look up in litellm.model_cost dictionary
    3. Fall back to GEN_AI_MODEL_FALLBACK_MAX_TOKENS (32000)

    Most dynamic providers (OpenRouter, Ollama) provide context_length via their
    APIs. Bedrock doesn't expose this, so we parse from model ID suffix (:200k)
    or use BEDROCK_MODEL_TOKEN_LIMITS mapping. The 32000 fallback is only hit for
    unknown models not in any of these sources.
    """
    max_input_tokens = None
    for model_configuration in llm_provider.model_configurations:
        if model_configuration.name == model_name:
@@ -550,6 +574,54 @@ def get_max_input_tokens_from_llm_provider(
    )
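
The body of this function is mostly elided by the hunk above, so the following is only a condensed sketch of the fallback chain the docstring describes — the provider attribute name and the None check on max_input_tokens are assumptions, not the actual implementation:

# Rough shape of the documented fallback chain (illustrative, not the real body):
for model_configuration in llm_provider.model_configurations:
    if model_configuration.name == model_name and model_configuration.max_input_tokens:
        # 1. Explicit value stored on the matching model configuration row.
        return model_configuration.max_input_tokens
# 2./3. get_max_input_tokens falls back to LiteLLM's model_cost and then to
# GEN_AI_MODEL_FALLBACK_MAX_TOKENS (32000) for unknown models.
return get_max_input_tokens(model_name=model_name, model_provider=llm_provider.provider)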


def get_bedrock_token_limit(model_id: str) -> int:
    """Look up token limit for a Bedrock model.

    AWS Bedrock API doesn't expose token limits directly. This function
    attempts to determine the limit from multiple sources.

    Lookup order:
    1. Parse from model ID suffix (e.g., ":200k" → 200000)
    2. Check LiteLLM's model_cost dictionary
    3. Fall back to our hardcoded BEDROCK_MODEL_TOKEN_LIMITS mapping
    4. Default to 32000 if not found anywhere
    """
    from onyx.llm.constants import BEDROCK_MODEL_TOKEN_LIMITS

    model_id_lower = model_id.lower()

    # 1. Try to parse context length from model ID suffix
    # Format: "model-name:version:NNNk" where NNN is the context length in thousands
    # Examples: ":200k", ":128k", ":1000k", ":8k", ":4k"
    context_match = re.search(r":(\d+)k\b", model_id_lower)
    if context_match:
        return int(context_match.group(1)) * 1000

    # 2. Check LiteLLM's model_cost dictionary
    try:
        model_map = get_model_map()
        # Try with bedrock/ prefix first, then without
        for key in [f"bedrock/{model_id}", model_id]:
            if key in model_map:
                model_info = model_map[key]
                if "max_input_tokens" in model_info:
                    return model_info["max_input_tokens"]
                if "max_tokens" in model_info:
                    return model_info["max_tokens"]
    except Exception:
        pass  # Fall through to mapping

    # 3. Try our hardcoded mapping (longest match first)
    for pattern, limit in sorted(
        BEDROCK_MODEL_TOKEN_LIMITS.items(), key=lambda x: -len(x[0])
    ):
        if pattern in model_id_lower:
            return limit

    # 4. Default fallback
    return GEN_AI_MODEL_FALLBACK_MAX_TOKENS
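
A couple of illustrative calls, with made-up model IDs, to show how the lookup order plays out:

# Illustrative only: the ":200k" suffix is parsed before any table lookup.
get_bedrock_token_limit("anthropic.claude-sonnet-4-5-20250929-v1:0:200k")  # -> 200000 via the suffix parse
get_bedrock_token_limit("made-up-model-with-no-entry")  # -> 32000 fallback, assuming it matches nothing in LiteLLM or the hardcoded mapping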


def model_supports_image_input(model_name: str, model_provider: str) -> bool:
    # First, try to read an explicit configuration from the model_configuration table
    try:
@@ -643,22 +715,32 @@ def is_true_openai_model(model_provider: str, model_name: str) -> bool:
    """

    # NOTE: not using the OPENAI_PROVIDER_NAME constant here due to circular import issues
    if model_provider != "openai":
    if model_provider != "openai" and model_provider != "litellm_proxy":
        return False

    model_map = get_model_map()

    def _check_if_model_name_is_openai_provider(model_name: str) -> bool:
        return (
            model_name in model_map
            and model_map[model_name].get("litellm_provider") == "openai"
        )

    try:
        model_map = get_model_map()

        # Check if any model exists in litellm's registry with openai prefix
        # If it's registered as "openai/model-name", it's a real OpenAI model
        if f"openai/{model_name}" in model_map:
            return True

        if (
            model_name in model_map
            and model_map[model_name].get("litellm_provider") == "openai"
        ):
        if _check_if_model_name_is_openai_provider(model_name):
            return True

        if model_name.startswith("azure/"):
            model_name_with_azure_removed = "/".join(model_name.split("/")[1:])
            if _check_if_model_name_is_openai_provider(model_name_with_azure_removed):
                return True

        return False

    except Exception:

@@ -38,7 +38,6 @@ from onyx.configs.app_configs import APP_HOST
from onyx.configs.app_configs import APP_PORT
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
@@ -271,9 +270,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET:
        logger.notice("Both OAuth Client ID and Secret are configured.")

    if DISABLE_GENERATIVE_AI:
        logger.notice("Generative AI Q&A disabled")

    # Initialize tracing if credentials are provided
    setup_braintrust_if_creds_available()
    setup_langfuse_if_creds_available()

@@ -16,7 +16,6 @@ from slack_sdk.models.blocks.basic_components import MarkdownTextObject
from slack_sdk.models.blocks.block_elements import ImageElement

from onyx.chat.models import ChatBasicResponse
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import SearchFeedbackType
@@ -255,9 +254,7 @@ def _build_documents_blocks(
    message_id: int | None,
    num_docs_to_display: int = ONYX_BOT_NUM_DOCS_TO_DISPLAY,
) -> list[Block]:
    header_text = (
        "Retrieved Documents" if DISABLE_GENERATIVE_AI else "Reference Documents"
    )
    header_text = "Reference Documents"
    seen_docs_identifiers = set()
    section_blocks: list[Block] = [HeaderBlock(text=header_text)]
    included_docs = 0

@@ -19,7 +19,9 @@ from onyx.configs.onyxbot_configs import ONYX_BOT_DISABLE_DOCS_ONLY_ANSWER
from onyx.configs.onyxbot_configs import ONYX_BOT_DISPLAY_ERROR_MSGS
from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_RETRIES
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import RetrievalDetails
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import SlackChannelConfig
from onyx.db.models import User
@@ -208,13 +210,31 @@ def handle_regular_answer(
        time_cutoff=None,
    )

    # Default True because no other ways to apply filters in Slack (no nice UI)
    # Commenting this out because this is only available to the slackbot for now
    # later we plan to implement this at the persona level where this will get
    # commented back in
    # auto_detect_filters = (
    #     persona.llm_filter_extraction if persona is not None else True
    # )
    auto_detect_filters = slack_channel_config.enable_auto_filters
    retrieval_details = RetrievalDetails(
        run_search=OptionalSearchSetting.ALWAYS,
        real_time=False,
        filters=filters,
        enable_auto_detect_filters=auto_detect_filters,
    )

    with get_session_with_current_tenant() as db_session:
        answer_request = prepare_chat_message_request(
            message_text=user_message.message,
            user=user,
            filters=filters,
            persona_id=persona.id,
            # This is not used in the Slack flow, only in the answer API
            persona_override_config=None,
            message_ts_to_respond_to=message_ts_to_respond_to,
            retrieval_details=retrieval_details,
            rerank_settings=None,  # Rerank customization supported in Slack flow
            db_session=db_session,
        )

@@ -872,8 +872,12 @@ def build_request_details(
            channel_type=channel_type,
            channel_id=channel,
            user_id=sender_id or "unknown",
            message_ts=message_ts,
        )
        logger.info(
            f"build_request_details: Capturing Slack context: "
            f"channel_type={channel_type} channel_id={channel} message_ts={message_ts}"
        )
        logger.info(f"build_request_details: Capturing Slack context: {slack_context}")

        if thread_ts != message_ts and thread_ts is not None:
            thread_messages = read_slack_thread(
@@ -930,9 +934,11 @@ def build_request_details(
            channel_type=channel_type,
            channel_id=channel,
            user_id=sender,
            message_ts=None,  # Slash commands don't have a message timestamp
        )
        logger.info(
            f"build_request_details: Capturing Slack context for slash command: {slack_context}"
            f"build_request_details: Capturing Slack context for slash command: "
            f"channel_type={channel_type} channel_id={channel}"
        )

        single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER)
@@ -1102,7 +1108,7 @@ def _get_socket_client(
    slack_bot_tokens: SlackBotTokens, tenant_id: str, slack_bot_id: int
) -> TenantSocketModeClient:
    # For more info on how to set this up, checkout the docs:
    # https://docs.onyx.app/admin/getting_started/slack_bot_setup
    # https://docs.onyx.app/admins/getting_started/slack_bot_setup

    # use the retry handlers built into the slack sdk
    connection_error_retry_handler = ConnectionErrorRetryHandler()

@@ -22,6 +22,7 @@ class SlackContext(BaseModel):
    channel_type: ChannelType
    channel_id: str
    user_id: str
    message_ts: str | None = None  # Used as request ID for log correlation


class SlackMessageInfo(BaseModel):