mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
208 Commits
v2.5.9
...
dump-scrip
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca3db17b08 | ||
|
|
ffd13b1104 | ||
|
|
1caa860f8e | ||
|
|
7181cc41af | ||
|
|
959b8c320d | ||
|
|
96fd0432ff | ||
|
|
4c73a03f57 | ||
|
|
e57713e376 | ||
|
|
21ea320323 | ||
|
|
bac9c48e53 | ||
|
|
7f79e34aa4 | ||
|
|
f1a81d45a1 | ||
|
|
285755a540 | ||
|
|
89003ad2d8 | ||
|
|
9f93f97259 | ||
|
|
f702eebbe7 | ||
|
|
8487e1856b | ||
|
|
a36445f840 | ||
|
|
7f30293b0e | ||
|
|
619d9528b4 | ||
|
|
6f83c669e7 | ||
|
|
c3e5f48cb4 | ||
|
|
fdf8fe391c | ||
|
|
f1d6bb9e02 | ||
|
|
9a64a717dc | ||
|
|
aa0f475e01 | ||
|
|
75238dc353 | ||
|
|
9e19803244 | ||
|
|
5cabd32638 | ||
|
|
4ccd88c331 | ||
|
|
5a80b98320 | ||
|
|
ff109d9f5c | ||
|
|
4cc276aca9 | ||
|
|
29f0df2c93 | ||
|
|
e2edcf0e0b | ||
|
|
9396fc547d | ||
|
|
c089903aad | ||
|
|
95471f64e9 | ||
|
|
13c1619d01 | ||
|
|
ddb5068847 | ||
|
|
81a4f654c2 | ||
|
|
9393c56a21 | ||
|
|
1ee96ff99c | ||
|
|
6bb00d2c6b | ||
|
|
d9cc923c6a | ||
|
|
bfbba0f036 | ||
|
|
ccf6911f97 | ||
|
|
15c9c2ba8e | ||
|
|
8b3fedf480 | ||
|
|
b8dc0749ee | ||
|
|
d6426458c6 | ||
|
|
941c4d6a54 | ||
|
|
653b65da66 | ||
|
|
503e70be02 | ||
|
|
9c19493160 | ||
|
|
933315646b | ||
|
|
d2061f8a26 | ||
|
|
6a98f0bf3c | ||
|
|
2f4d39d834 | ||
|
|
40f8bcc6f8 | ||
|
|
af9ed73f00 | ||
|
|
bf28041f4e | ||
|
|
395d5927b7 | ||
|
|
c96f24e37c | ||
|
|
070519f823 | ||
|
|
a7dc1c0f3b | ||
|
|
a947e44926 | ||
|
|
a6575b6254 | ||
|
|
31733a9c7c | ||
|
|
5415e2faf1 | ||
|
|
749f720dfd | ||
|
|
eac79cfdf2 | ||
|
|
e3b1202731 | ||
|
|
6df13cc2de | ||
|
|
682f660aa3 | ||
|
|
c4670ea86c | ||
|
|
a6757eb49f | ||
|
|
cd372fb585 | ||
|
|
45fa0d9b32 | ||
|
|
45091f2ee2 | ||
|
|
43a3cb89b9 | ||
|
|
9428eaed8d | ||
|
|
dd29d989ff | ||
|
|
f44daa2116 | ||
|
|
212cbcb683 | ||
|
|
aaad573c3f | ||
|
|
e1325e84ae | ||
|
|
e759cdd4ab | ||
|
|
2ed6607e10 | ||
|
|
ba5b9cf395 | ||
|
|
bab23f62b8 | ||
|
|
d72e2e4081 | ||
|
|
4ed2d08336 | ||
|
|
24a0ceee18 | ||
|
|
d8fba38780 | ||
|
|
5f358a1e20 | ||
|
|
00b0c23e13 | ||
|
|
2103ed9e81 | ||
|
|
2c5ab72312 | ||
|
|
672d1ca8fa | ||
|
|
a418de4287 | ||
|
|
349aba6c02 | ||
|
|
18a7bdc292 | ||
|
|
c658fd4c7d | ||
|
|
f1e87dda5b | ||
|
|
b93edb3e89 | ||
|
|
dc4e76bd64 | ||
|
|
c4242ad17a | ||
|
|
a4dee62660 | ||
|
|
2d2c76ec7b | ||
|
|
d80025138d | ||
|
|
90ec595936 | ||
|
|
f30e88a61b | ||
|
|
9c04e9269f | ||
|
|
8c65fcd193 | ||
|
|
f42e3eb823 | ||
|
|
9b76ed085c | ||
|
|
0eb4d039ae | ||
|
|
3c0b66a174 | ||
|
|
895a8e774e | ||
|
|
c14ea4dbb9 | ||
|
|
80b1e07586 | ||
|
|
59b243d585 | ||
|
|
d4ae3d1cb5 | ||
|
|
ed0a86c681 | ||
|
|
e825e5732f | ||
|
|
a93854ae70 | ||
|
|
fc8767a04f | ||
|
|
6c231e7ad1 | ||
|
|
bac751d4a9 | ||
|
|
3e0f386d5b | ||
|
|
edb6957268 | ||
|
|
0348d11fb2 | ||
|
|
fe514eada0 | ||
|
|
e7672b89bb | ||
|
|
c1494660e1 | ||
|
|
7ee3df6b92 | ||
|
|
54afed0d23 | ||
|
|
1c776fcc73 | ||
|
|
340ddce294 | ||
|
|
e166c1b095 | ||
|
|
84be68ef7c | ||
|
|
90e9af82bf | ||
|
|
7f36fb2a4c | ||
|
|
307464a736 | ||
|
|
1d5c8bdb20 | ||
|
|
6de626ecc3 | ||
|
|
6663c81aa6 | ||
|
|
35ca94c17e | ||
|
|
431f652be8 | ||
|
|
6535d85ceb | ||
|
|
3a349d6ab3 | ||
|
|
ddae686dc7 | ||
|
|
0e42891cbf | ||
|
|
823b28b4a7 | ||
|
|
828036ceb8 | ||
|
|
2a40ceab26 | ||
|
|
f03f2bff78 | ||
|
|
f9a548fbe9 | ||
|
|
8b45f911ff | ||
|
|
ae64ded7bb | ||
|
|
7287e3490d | ||
|
|
7681c11585 | ||
|
|
365e31a7f3 | ||
|
|
dd33886946 | ||
|
|
6cdd5b7d3e | ||
|
|
7b6ae2b72a | ||
|
|
629502ef6a | ||
|
|
927e8addb5 | ||
|
|
14712af431 | ||
|
|
4b38b91674 | ||
|
|
508c248032 | ||
|
|
45db59eab1 | ||
|
|
5a14055a29 | ||
|
|
a698f01cab | ||
|
|
4e4bf197cf | ||
|
|
517b0d1e70 | ||
|
|
7b2b163d4e | ||
|
|
29b28c8352 | ||
|
|
83b624b658 | ||
|
|
d3cd68014a | ||
|
|
64d9fd97ec | ||
|
|
7a9e2ebec6 | ||
|
|
51a69d7e55 | ||
|
|
f19362ce27 | ||
|
|
0c3330c105 | ||
|
|
81cb0f2518 | ||
|
|
beb4e619e7 | ||
|
|
0fa1d5b0ca | ||
|
|
1e30882222 | ||
|
|
42996a63fe | ||
|
|
4a38068192 | ||
|
|
97f66b68c1 | ||
|
|
aeafd83cd1 | ||
|
|
0ba9a873e9 | ||
|
|
b72bac993f | ||
|
|
9572c63089 | ||
|
|
c4505cdb06 | ||
|
|
9055691c38 | ||
|
|
1afa7b0689 | ||
|
|
72c96a502e | ||
|
|
093b399472 | ||
|
|
d89dd3c76b | ||
|
|
a24d0aa26d | ||
|
|
5e581c2c60 | ||
|
|
17ea20ef5c | ||
|
|
0b8207ef4c | ||
|
|
c26da8dc75 |
4
.github/dependabot.yml
vendored
4
.github/dependabot.yml
vendored
@@ -5,7 +5,7 @@ updates:
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
cooldown:
|
||||
default-days: 4
|
||||
default-days: 7
|
||||
open-pull-requests-limit: 3
|
||||
assignees:
|
||||
- "jmelahman"
|
||||
@@ -16,7 +16,7 @@ updates:
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
cooldown:
|
||||
default-days: 4
|
||||
default-days: 7
|
||||
open-pull-requests-limit: 3
|
||||
assignees:
|
||||
- "jmelahman"
|
||||
|
||||
33
.github/workflows/check-lazy-imports.yml
vendored
33
.github/workflows/check-lazy-imports.yml
vendored
@@ -1,33 +0,0 @@
|
||||
name: Check Lazy Imports
|
||||
concurrency:
|
||||
group: Check-Lazy-Imports-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- 'release/**'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
check-lazy-imports:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Check lazy imports
|
||||
run: python3 backend/scripts/check_lazy_imports.py
|
||||
61
.github/workflows/deployment.yml
vendored
61
.github/workflows/deployment.yml
vendored
@@ -83,6 +83,47 @@ jobs:
|
||||
echo "sanitized-tag=$SANITIZED_TAG"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
check-version-tag:
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
if: ${{ !startsWith(github.ref_name, 'nightly-latest') && github.event_name != 'workflow_dispatch' }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7.1.4
|
||||
with:
|
||||
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
|
||||
enable-cache: false
|
||||
|
||||
- name: Validate tag is versioned correctly
|
||||
run: |
|
||||
uv run --no-sync --with release-tag tag --check
|
||||
|
||||
notify-slack-on-tag-check-failure:
|
||||
needs:
|
||||
- check-version-tag
|
||||
if: always() && needs.check-version-tag.result == 'failure' && github.event_name != 'workflow_dispatch'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
failed-jobs: "• check-version-tag"
|
||||
title: "🚨 Version Tag Check Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
build-web-amd64:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-web == 'true'
|
||||
@@ -100,7 +141,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -158,7 +199,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -266,7 +307,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -332,7 +373,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -445,7 +486,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -502,7 +543,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -610,7 +651,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -674,7 +715,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -867,7 +908,7 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -957,7 +998,7 @@ jobs:
|
||||
timeout-minutes: 90
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
2
.github/workflows/helm-chart-releases.yml
vendored
2
.github/workflows/helm-chart-releases.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
4
.github/workflows/nightly-scan-licenses.yml
vendored
4
.github/workflows/nightly-scan-licenses.yml
vendored
@@ -28,12 +28,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
|
||||
uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # ratchet:actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
|
||||
@@ -7,6 +7,9 @@ on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -29,6 +32,9 @@ env:
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
|
||||
|
||||
# Jira
|
||||
JIRA_ADMIN_API_TOKEN: ${{ secrets.JIRA_ADMIN_API_TOKEN }}
|
||||
|
||||
# LLMs
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
@@ -46,7 +52,7 @@ jobs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -74,12 +80,13 @@ jobs:
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
MODEL_SERVER_HOST: "disabled"
|
||||
DISABLE_TELEMETRY: "true"
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -107,6 +114,7 @@ jobs:
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
CODE_INTERPRETER_BETA_ENABLED=true
|
||||
DISABLE_TELEMETRY=true
|
||||
EOF
|
||||
|
||||
- name: Set up Standard Dependencies
|
||||
@@ -162,7 +170,7 @@ jobs:
|
||||
|
||||
- name: Upload Docker logs
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v5
|
||||
with:
|
||||
name: docker-logs-${{ matrix.test-dir }}
|
||||
path: docker-logs/
|
||||
|
||||
5
.github/workflows/pr-helm-chart-testing.yml
vendored
5
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -7,6 +7,9 @@ on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
permissions:
|
||||
@@ -21,7 +24,7 @@ jobs:
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
107
.github/workflows/pr-integration-tests.yml
vendored
107
.github/workflows/pr-integration-tests.yml
vendored
@@ -9,6 +9,9 @@ on:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -40,7 +43,7 @@ jobs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -71,10 +74,24 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
@@ -95,9 +112,13 @@ jobs:
|
||||
push: true
|
||||
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
|
||||
type=registry,ref=onyxdotapp/onyx-backend:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
@@ -108,10 +129,24 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
@@ -132,9 +167,14 @@ jobs:
|
||||
push: true
|
||||
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
|
||||
type=registry,ref=onyxdotapp/onyx-model-server:latest
|
||||
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
cache-to: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
|
||||
|
||||
build-integration-image:
|
||||
@@ -143,7 +183,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -159,16 +199,40 @@ jobs:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Build and push integration test image with Docker Bake
|
||||
env:
|
||||
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
TAG: integration-test-${{ github.run_id }}
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
run: |
|
||||
cd backend && docker buildx bake --push \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
|
||||
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
|
||||
integration
|
||||
|
||||
@@ -195,7 +259,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -210,23 +274,28 @@ jobs:
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
# NOTE: don't need web server for integration tests
|
||||
- name: Start Docker containers
|
||||
- name: Create .env file for Docker Compose
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
|
||||
MCP_SERVER_ENABLED=true
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
POSTGRES_POOL_PRE_PING=true \
|
||||
POSTGRES_USE_NULL_POOL=true \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 \
|
||||
MCP_SERVER_ENABLED=true \
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
@@ -372,7 +441,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
9
.github/workflows/pr-jest-tests.yml
vendored
9
.github/workflows/pr-jest-tests.yml
vendored
@@ -3,7 +3,8 @@ concurrency:
|
||||
group: Run-Jest-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on: push
|
||||
on:
|
||||
push:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -15,15 +16,15 @@ jobs:
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
|
||||
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: 'npm'
|
||||
cache: "npm"
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
|
||||
- name: Install node dependencies
|
||||
|
||||
104
.github/workflows/pr-mit-integration-tests.yml
vendored
104
.github/workflows/pr-mit-integration-tests.yml
vendored
@@ -6,6 +6,9 @@ concurrency:
|
||||
on:
|
||||
merge_group:
|
||||
types: [checks_requested]
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -37,7 +40,7 @@ jobs:
|
||||
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -67,10 +70,24 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
@@ -91,9 +108,14 @@ jobs:
|
||||
push: true
|
||||
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
|
||||
type=registry,ref=onyxdotapp/onyx-backend:latest
|
||||
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
|
||||
cache-to: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-model-server-image:
|
||||
@@ -102,10 +124,24 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
@@ -126,9 +162,14 @@ jobs:
|
||||
push: true
|
||||
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
|
||||
type=registry,ref=onyxdotapp/onyx-model-server:latest
|
||||
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
cache-to: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
|
||||
build-integration-image:
|
||||
runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache"]
|
||||
@@ -136,10 +177,24 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
@@ -156,12 +211,22 @@ jobs:
|
||||
env:
|
||||
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
TAG: integration-test-${{ github.run_id }}
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
run: |
|
||||
cd backend && docker buildx bake --push \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
|
||||
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
|
||||
integration
|
||||
|
||||
@@ -188,7 +253,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -203,21 +268,26 @@ jobs:
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
# NOTE: don't need web server for integration tests
|
||||
- name: Start Docker containers
|
||||
- name: Create .env file for Docker Compose
|
||||
env:
|
||||
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
MCP_SERVER_ENABLED=true
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
AUTH_TYPE=basic \
|
||||
POSTGRES_POOL_PRE_PING=true \
|
||||
POSTGRES_USE_NULL_POOL=true \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
MCP_SERVER_ENABLED=true \
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
|
||||
relational_db \
|
||||
index \
|
||||
|
||||
77
.github/workflows/pr-playwright-tests.yml
vendored
77
.github/workflows/pr-playwright-tests.yml
vendored
@@ -3,7 +3,8 @@ concurrency:
|
||||
group: Run-Playwright-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on: push
|
||||
on:
|
||||
push:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -52,10 +53,24 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
@@ -76,9 +91,14 @@ jobs:
|
||||
tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-${{ github.run_id }}
|
||||
push: true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache
|
||||
type=registry,ref=onyxdotapp/onyx-web-server:latest
|
||||
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache,mode=max
|
||||
cache-to: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache,mode=max
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
build-backend-image:
|
||||
@@ -88,10 +108,24 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
@@ -112,9 +146,13 @@ jobs:
|
||||
tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-backend-${{ github.run_id }}
|
||||
push: true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
|
||||
type=registry,ref=onyxdotapp/onyx-backend:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
@@ -125,10 +163,24 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
@@ -149,9 +201,14 @@ jobs:
|
||||
tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-${{ github.run_id }}
|
||||
push: true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
|
||||
type=registry,ref=onyxdotapp/onyx-model-server:latest
|
||||
cache-to: type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
cache-to: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
|
||||
|
||||
playwright-tests:
|
||||
@@ -172,13 +229,13 @@ jobs:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
|
||||
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: 'npm'
|
||||
@@ -408,12 +465,12 @@ jobs:
|
||||
# ]
|
||||
# steps:
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
# uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
# with:
|
||||
# fetch-depth: 0
|
||||
|
||||
# - name: Setup node
|
||||
# uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # ratchet:actions/setup-node@v4
|
||||
# uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
|
||||
# with:
|
||||
# node-version: 22
|
||||
|
||||
|
||||
65
.github/workflows/pr-python-checks.yml
vendored
65
.github/workflows/pr-python-checks.yml
vendored
@@ -9,29 +9,14 @@ on:
|
||||
branches:
|
||||
- main
|
||||
- 'release/**'
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
validate-requirements:
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@caf0cab7a618c569241d31dcd442f54681755d39 # ratchet:astral-sh/setup-uv@v3
|
||||
# TODO: Enable caching once there is a uv.lock file checked in.
|
||||
# with:
|
||||
# enable-cache: true
|
||||
|
||||
- name: Validate requirements lock files
|
||||
run: ./backend/scripts/compile_requirements.py --check
|
||||
|
||||
mypy-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# Note: Mypy seems quite optimized for x64 compared to arm64.
|
||||
@@ -42,7 +27,7 @@ jobs:
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -55,35 +40,10 @@ jobs:
|
||||
backend/requirements/model_server.txt
|
||||
backend/requirements/ee.txt
|
||||
|
||||
- name: Generate OpenAPI schema
|
||||
shell: bash
|
||||
working-directory: backend
|
||||
env:
|
||||
PYTHONPATH: "."
|
||||
run: |
|
||||
python scripts/onyx_openapi_schema.py --filename generated/openapi.json
|
||||
|
||||
# needed for pulling openapitools/openapi-generator-cli
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Generate OpenAPI Python client
|
||||
- name: Generate OpenAPI schema and Python client
|
||||
shell: bash
|
||||
run: |
|
||||
docker run --rm \
|
||||
-v "${{ github.workspace }}/backend/generated:/local" \
|
||||
openapitools/openapi-generator-cli generate \
|
||||
-i /local/openapi.json \
|
||||
-g python \
|
||||
-o /local/onyx_openapi_client \
|
||||
--package-name onyx_openapi_client \
|
||||
--skip-validate-spec \
|
||||
--openapi-normalizer "SIMPLIFY_ONEOF_ANYOF=true,SET_OAS3_NULLABLE=true"
|
||||
ods openapi all
|
||||
|
||||
- name: Cache mypy cache
|
||||
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
|
||||
@@ -101,11 +61,8 @@ jobs:
|
||||
TERM: xterm-256color
|
||||
run: mypy .
|
||||
|
||||
- name: Check import order with reorder-python-imports
|
||||
working-directory: ./backend
|
||||
run: |
|
||||
find ./onyx -name "*.py" | xargs reorder-python-imports --py311-plus
|
||||
|
||||
- name: Check code formatting with Black
|
||||
working-directory: ./backend
|
||||
run: black --check .
|
||||
- name: Run MyPy (tools/)
|
||||
env:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
TERM: xterm-256color
|
||||
run: mypy tools/
|
||||
|
||||
15
.github/workflows/pr-python-connector-tests.yml
vendored
15
.github/workflows/pr-python-connector-tests.yml
vendored
@@ -7,6 +7,9 @@ on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
schedule:
|
||||
# This cron expression runs the job daily at 16:00 UTC (9am PT)
|
||||
- cron: "0 16 * * *"
|
||||
@@ -130,12 +133,13 @@ jobs:
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
DISABLE_TELEMETRY: "true"
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -157,16 +161,20 @@ jobs:
|
||||
hubspot:
|
||||
- 'backend/onyx/connectors/hubspot/**'
|
||||
- 'backend/tests/daily/connectors/hubspot/**'
|
||||
- 'uv.lock'
|
||||
salesforce:
|
||||
- 'backend/onyx/connectors/salesforce/**'
|
||||
- 'backend/tests/daily/connectors/salesforce/**'
|
||||
- 'uv.lock'
|
||||
github:
|
||||
- 'backend/onyx/connectors/github/**'
|
||||
- 'backend/tests/daily/connectors/github/**'
|
||||
- 'uv.lock'
|
||||
file_processing:
|
||||
- 'backend/onyx/file_processing/**'
|
||||
- 'uv.lock'
|
||||
|
||||
- name: Run Tests (excluding HubSpot, Salesforce, and GitHub)
|
||||
- name: Run Tests (excluding HubSpot, Salesforce, GitHub, and Coda)
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test \
|
||||
@@ -179,7 +187,8 @@ jobs:
|
||||
backend/tests/daily/connectors \
|
||||
--ignore backend/tests/daily/connectors/hubspot \
|
||||
--ignore backend/tests/daily/connectors/salesforce \
|
||||
--ignore backend/tests/daily/connectors/github
|
||||
--ignore backend/tests/daily/connectors/github \
|
||||
--ignore backend/tests/daily/connectors/coda
|
||||
|
||||
- name: Run HubSpot Connector Tests
|
||||
if: ${{ github.event_name == 'schedule' || steps.changes.outputs.hubspot == 'true' || steps.changes.outputs.file_processing == 'true' }}
|
||||
|
||||
4
.github/workflows/pr-python-model-tests.yml
vendored
4
.github/workflows/pr-python-model-tests.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
@@ -61,7 +61,7 @@ jobs:
|
||||
docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
|
||||
uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # ratchet:actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
|
||||
9
.github/workflows/pr-python-tests.yml
vendored
9
.github/workflows/pr-python-tests.yml
vendored
@@ -9,6 +9,9 @@ on:
|
||||
branches:
|
||||
- main
|
||||
- 'release/**'
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -23,15 +26,13 @@ jobs:
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
|
||||
SF_USERNAME: ${{ secrets.SF_USERNAME }}
|
||||
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
|
||||
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
|
||||
DISABLE_TELEMETRY: "true"
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
|
||||
30
.github/workflows/pr-quality-checks.yml
vendored
30
.github/workflows/pr-quality-checks.yml
vendored
@@ -6,32 +6,42 @@ concurrency:
|
||||
on:
|
||||
merge_group:
|
||||
pull_request: null
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
quality-checks:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-quality-checks"]
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # ratchet:actions/setup-python@v6
|
||||
- uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # ratchet:actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Setup Terraform
|
||||
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3
|
||||
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet:pre-commit/action@v3.0.1
|
||||
env:
|
||||
# uv-run is mypy's id and mypy is covered by the Python Checks which caches dependencies better.
|
||||
SKIP: uv-run
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v6
|
||||
with: # zizmor: ignore[cache-poisoning]
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
cache-dependency-path: ./web/package-lock.json
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
- uses: j178/prek-action@91fd7d7cf70ae1dee9f4f44e7dfa5d1073fe6623 # ratchet:j178/prek-action@v1
|
||||
with:
|
||||
extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
|
||||
prek-version: '0.2.21'
|
||||
extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
|
||||
- name: Check Actions
|
||||
uses: giner/check-actions@28d366c7cbbe235f9624a88aa31a628167eee28c # ratchet:giner/check-actions@v1.0.1
|
||||
with:
|
||||
|
||||
40
.github/workflows/release-devtools.yml
vendored
Normal file
40
.github/workflows/release-devtools.yml
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
name: Release Devtools
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "ods/v*.*.*"
|
||||
|
||||
jobs:
|
||||
pypi:
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: release-devtools
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 10
|
||||
strategy:
|
||||
matrix:
|
||||
os-arch:
|
||||
- {goos: "linux", goarch: "amd64"}
|
||||
- {goos: "linux", goarch: "arm64"}
|
||||
- {goos: "windows", goarch: "amd64"}
|
||||
- {goos: "windows", goarch: "arm64"}
|
||||
- {goos: "darwin", goarch: "amd64"}
|
||||
- {goos: "darwin", goarch: "arm64"}
|
||||
- {goos: "", goarch: ""}
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
- uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
- run: |
|
||||
GOOS="${{ matrix.os-arch.goos }}" \
|
||||
GOARCH="${{ matrix.os-arch.goarch }}" \
|
||||
uv build --wheel
|
||||
working-directory: tools/ods
|
||||
- run: uv publish
|
||||
working-directory: tools/ods
|
||||
2
.github/workflows/sync_foss.yml
vendored
2
.github/workflows/sync_foss.yml
vendored
@@ -14,7 +14,7 @@ jobs:
|
||||
contents: read
|
||||
steps:
|
||||
- name: Checkout main Onyx repo
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
2
.github/workflows/tag-nightly.yml
vendored
2
.github/workflows/tag-nightly.yml
vendored
@@ -18,7 +18,7 @@ jobs:
|
||||
# see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we
|
||||
# implement here which needs an actual user's deploy key
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
ssh-key: "${{ secrets.DEPLOY_KEY }}"
|
||||
persist-credentials: true
|
||||
|
||||
8
.github/workflows/zizmor.yml
vendored
8
.github/workflows/zizmor.yml
vendored
@@ -17,15 +17,17 @@ jobs:
|
||||
security-events: write # needed for SARIF uploads
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # ratchet:actions/checkout@v6.0.0
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@5a7eac68fb9809dea845d802897dc5c723910fa3 # ratchet:astral-sh/setup-uv@v7.1.3
|
||||
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7.1.4
|
||||
with:
|
||||
enable-cache: false
|
||||
|
||||
- name: Run zizmor
|
||||
run: uvx zizmor==1.16.3 --format=sarif . > results.sarif
|
||||
run: uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -49,5 +49,10 @@ CLAUDE.md
|
||||
# Local .terraform.lock.hcl file
|
||||
.terraform.lock.hcl
|
||||
|
||||
node_modules
|
||||
|
||||
# MCP configs
|
||||
.playwright-mcp
|
||||
|
||||
# plans
|
||||
plans/
|
||||
|
||||
@@ -5,36 +5,60 @@ default_install_hook_types:
|
||||
- post-rewrite
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/uv-pre-commit
|
||||
# This revision is from https://github.com/astral-sh/uv-pre-commit/pull/53
|
||||
# From: https://github.com/astral-sh/uv-pre-commit/pull/53/commits/d30b4298e4fb63ce8609e29acdbcf4c9018a483c
|
||||
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
|
||||
hooks:
|
||||
- id: uv-sync
|
||||
- id: uv-run
|
||||
name: mypy
|
||||
args: ["mypy"]
|
||||
pass_filenames: true
|
||||
files: ^backend/.*\.py$
|
||||
name: Check lazy imports
|
||||
args: ["--with=onyx-devtools", "ods", "check-lazy-imports"]
|
||||
files: ^backend/(?!\.venv/).*\.py$
|
||||
- id: uv-sync
|
||||
args: ["--locked", "--all-extras"]
|
||||
- id: uv-lock
|
||||
files: ^pyproject\.toml$
|
||||
- id: uv-export
|
||||
name: uv-export default.txt
|
||||
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "backend", "-o", "backend/requirements/default.txt"]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-export
|
||||
name: uv-export dev.txt
|
||||
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "dev", "-o", "backend/requirements/dev.txt"]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-export
|
||||
name: uv-export ee.txt
|
||||
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "ee", "-o", "backend/requirements/ee.txt"]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
- id: uv-export
|
||||
name: uv-export model_server.txt
|
||||
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "model_server", "-o", "backend/requirements/model_server.txt"]
|
||||
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
|
||||
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
|
||||
# - id: uv-run
|
||||
# name: mypy
|
||||
# args: ["--all-extras", "mypy"]
|
||||
# pass_filenames: true
|
||||
# files: ^backend/.*\.py$
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.6.0
|
||||
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
files: ^.github/
|
||||
|
||||
- repo: https://github.com/rhysd/actionlint
|
||||
rev: v1.7.8
|
||||
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
|
||||
hooks:
|
||||
- id: actionlint
|
||||
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 25.1.0
|
||||
rev: 8a737e727ac5ab2f1d4cf5876720ed276dc8dc4b # frozen: 25.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3.11
|
||||
|
||||
# this is a fork which keeps compatibility with black
|
||||
- repo: https://github.com/wimglenn/reorder-python-imports-black
|
||||
rev: v3.14.0
|
||||
rev: f55cd27f90f0cf0ee775002c2383ce1c7820013d # frozen: v3.14.0
|
||||
hooks:
|
||||
- id: reorder-python-imports
|
||||
args: ['--py311-plus', '--application-directories=backend/']
|
||||
@@ -46,26 +70,32 @@ repos:
|
||||
# These settings will remove unused imports with side effects
|
||||
# Note: The repo currently does not and should not have imports with side effects
|
||||
- repo: https://github.com/PyCQA/autoflake
|
||||
rev: v2.3.1
|
||||
rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1
|
||||
hooks:
|
||||
- id: autoflake
|
||||
args: [ '--remove-all-unused-imports', '--remove-unused-variables', '--in-place' , '--recursive']
|
||||
|
||||
- repo: https://github.com/golangci/golangci-lint
|
||||
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
|
||||
hooks:
|
||||
- id: golangci-lint
|
||||
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.11.4
|
||||
rev: 971923581912ef60a6b70dbf0c3e9a39563c9d47 # frozen: v0.11.4
|
||||
hooks:
|
||||
- id: ruff
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||
rev: v3.1.0
|
||||
rev: ffb6a759a979008c0e6dff86e39f4745a2d9eac4 # frozen: v3.1.0
|
||||
hooks:
|
||||
- id: prettier
|
||||
types_or: [html, css, javascript, ts, tsx]
|
||||
language_version: system
|
||||
|
||||
- repo: https://github.com/sirwart/ripsecrets
|
||||
rev: v0.1.11
|
||||
rev: 7d94620933e79b8acaa0cd9e60e9864b07673d86 # frozen: v0.1.11
|
||||
hooks:
|
||||
- id: ripsecrets
|
||||
args:
|
||||
@@ -82,8 +112,9 @@ repos:
|
||||
pass_filenames: false
|
||||
files: \.tf$
|
||||
|
||||
- id: check-lazy-imports
|
||||
name: Check lazy imports
|
||||
entry: python3 backend/scripts/check_lazy_imports.py
|
||||
- id: typescript-check
|
||||
name: TypeScript type check
|
||||
entry: bash -c 'cd web && npm run types:check'
|
||||
language: system
|
||||
files: ^backend/(?!\.venv/).*\.py$
|
||||
pass_filenames: false
|
||||
files: ^web/.*\.(ts|tsx)$
|
||||
|
||||
20
.vscode/env_template.txt
vendored
20
.vscode/env_template.txt
vendored
@@ -5,11 +5,8 @@
|
||||
# For local dev, often user Authentication is not needed
|
||||
AUTH_TYPE=disabled
|
||||
|
||||
# Skip warm up for dev
|
||||
SKIP_WARM_UP=True
|
||||
|
||||
# Always keep these on for Dev
|
||||
# Logs all model prompts to stdout
|
||||
# Logs model prompts, reasoning, and answer to stdout
|
||||
LOG_ONYX_MODEL_INTERACTIONS=True
|
||||
# More verbose logging
|
||||
LOG_LEVEL=debug
|
||||
@@ -37,31 +34,16 @@ OPENAI_API_KEY=<REPLACE THIS>
|
||||
GEN_AI_MODEL_VERSION=gpt-4o
|
||||
FAST_GEN_AI_MODEL_VERSION=gpt-4o
|
||||
|
||||
# For Onyx Slack Bot, overrides the UI values so no need to set this up via UI every time
|
||||
# Only needed if using OnyxBot
|
||||
#ONYX_BOT_SLACK_APP_TOKEN=<REPLACE THIS>
|
||||
#ONYX_BOT_SLACK_BOT_TOKEN=<REPLACE THIS>
|
||||
|
||||
|
||||
# Python stuff
|
||||
PYTHONPATH=../backend
|
||||
PYTHONUNBUFFERED=1
|
||||
|
||||
|
||||
# Internet Search
|
||||
EXA_API_KEY=<REPLACE THIS>
|
||||
|
||||
|
||||
# Enable the full set of Danswer Enterprise Edition features
|
||||
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
|
||||
|
||||
# Agent Search configs # TODO: Remove give proper namings
|
||||
AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ranking effort
|
||||
AGENT_RERANKING_STATS=True
|
||||
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
|
||||
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20
|
||||
|
||||
# S3 File Store Configuration (MinIO for local development)
|
||||
S3_ENDPOINT_URL=http://localhost:9004
|
||||
S3_FILE_STORE_BUCKET_NAME=onyx-file-store-bucket
|
||||
|
||||
17
.vscode/launch.template.jsonc
vendored
17
.vscode/launch.template.jsonc
vendored
@@ -133,8 +133,6 @@
|
||||
},
|
||||
"consoleTitle": "API Server Console"
|
||||
},
|
||||
// For the listener to access the Slack API,
|
||||
// ONYX_BOT_SLACK_APP_TOKEN & ONYX_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
"consoleName": "Slack Bot",
|
||||
@@ -510,7 +508,6 @@
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"stopOnEntry": true,
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
}
|
||||
@@ -556,10 +553,10 @@
|
||||
"name": "Install Python Requirements",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "bash",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"-c",
|
||||
"pip install -r backend/requirements/default.txt && pip install -r backend/requirements/dev.txt && pip install -r backend/requirements/ee.txt && pip install -r backend/requirements/model_server.txt"
|
||||
"sync",
|
||||
"--all-extras"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
@@ -572,14 +569,14 @@
|
||||
"name": "Onyx OpenAPI Schema Generator",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "scripts/onyx_openapi_schema.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"program": "backend/scripts/onyx_openapi_schema.py",
|
||||
"cwd": "${workspaceFolder}",
|
||||
"envFile": "${workspaceFolder}/.env",
|
||||
"env": {
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
"PYTHONPATH": "backend"
|
||||
},
|
||||
"args": ["--filename", "generated/openapi.json"]
|
||||
"args": ["--filename", "backend/generated/openapi.json", "--generate-python-client"]
|
||||
},
|
||||
{
|
||||
// script to debug multi tenant db issues
|
||||
|
||||
@@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
||||
|
||||
## KEY NOTES
|
||||
|
||||
- If you run into any missing python dependency errors, try running your command with `source backend/.venv/bin/activate` \
|
||||
- If you run into any missing python dependency errors, try running your command with `source .venv/bin/activate` \
|
||||
to assume the python venv.
|
||||
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
|
||||
|
||||
@@ -71,12 +71,12 @@ If using a higher version, sometimes some libraries will not be available (i.e.
|
||||
|
||||
#### Backend: Python requirements
|
||||
|
||||
Currently, we use pip and recommend creating a virtual environment.
|
||||
Currently, we use [uv](https://docs.astral.sh/uv/) and recommend creating a [virtual environment](https://docs.astral.sh/uv/pip/environments/#using-a-virtual-environment).
|
||||
|
||||
For convenience here's a command for it:
|
||||
|
||||
```bash
|
||||
python -m venv .venv
|
||||
uv venv .venv --python 3.11
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
@@ -95,33 +95,15 @@ If using PowerShell, the command slightly differs:
|
||||
Install the required python dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r backend/requirements/combined.txt
|
||||
uv sync --all-extras
|
||||
```
|
||||
|
||||
or
|
||||
Install Playwright for Python (headless browser required by the Web Connector):
|
||||
|
||||
```bash
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install -r backend/requirements/ee.txt
|
||||
pip install -r backend/requirements/model_server.txt
|
||||
uv run playwright install
|
||||
```
|
||||
|
||||
Fix vscode/cursor auto-imports:
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector)
|
||||
|
||||
In the activated Python virtualenv, install Playwright for Python by running:
|
||||
|
||||
```bash
|
||||
playwright install
|
||||
```
|
||||
|
||||
You may have to deactivate and reactivate your virtualenv for `playwright` to appear on your path.
|
||||
|
||||
#### Frontend: Node dependencies
|
||||
|
||||
Onyx uses Node v22.20.0. We highly recommend you use [Node Version Manager (nvm)](https://github.com/nvm-sh/nvm)
|
||||
@@ -130,7 +112,7 @@ to manage your Node installations. Once installed, you can run
|
||||
```bash
|
||||
nvm install 22 && nvm use 22
|
||||
node -v # verify your active version
|
||||
```
|
||||
```
|
||||
|
||||
Navigate to `onyx/web` and run:
|
||||
|
||||
@@ -144,21 +126,15 @@ npm i
|
||||
|
||||
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
|
||||
|
||||
With the virtual environment active, install the pre-commit library with:
|
||||
Then run:
|
||||
|
||||
```bash
|
||||
pip install pre-commit
|
||||
```
|
||||
|
||||
Then, from the `onyx/backend` directory, run:
|
||||
|
||||
```bash
|
||||
pre-commit install
|
||||
uv run pre-commit install
|
||||
```
|
||||
|
||||
Additionally, we use `mypy` for static type checking.
|
||||
Onyx is fully type-annotated, and we want to keep it that way!
|
||||
To run the mypy checks manually, run `python -m mypy .` from the `onyx/backend` directory.
|
||||
To run the mypy checks manually, run `uv run mypy .` from the `onyx/backend` directory.
|
||||
|
||||
### Web
|
||||
|
||||
|
||||
@@ -7,8 +7,12 @@ Onyx migrations use a generic single-database configuration with an async dbapi.
|
||||
|
||||
## To generate new migrations:
|
||||
|
||||
run from onyx/backend:
|
||||
`alembic revision --autogenerate -m <DESCRIPTION_OF_MIGRATION>`
|
||||
From onyx/backend, run:
|
||||
`alembic revision -m <DESCRIPTION_OF_MIGRATION>`
|
||||
|
||||
Note: you cannot use the `--autogenerate` flag as the automatic schema parsing does not work.
|
||||
|
||||
Manually populate the upgrade and downgrade in your new migration.
|
||||
|
||||
More info can be found here: https://alembic.sqlalchemy.org/en/latest/autogenerate.html
|
||||
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
"""add is_clarification to chat_message
|
||||
|
||||
Revision ID: 18b5b2524446
|
||||
Revises: 87c52ec39f84
|
||||
Create Date: 2025-01-16
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "18b5b2524446"
|
||||
down_revision = "87c52ec39f84"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column(
|
||||
"is_clarification", sa.Boolean(), nullable=False, server_default="false"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_message", "is_clarification")
|
||||
104
backend/alembic/versions/4f8a2b3c1d9e_add_open_url_tool.py
Normal file
104
backend/alembic/versions/4f8a2b3c1d9e_add_open_url_tool.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""add_open_url_tool
|
||||
|
||||
Revision ID: 4f8a2b3c1d9e
|
||||
Revises: a852cbe15577
|
||||
Create Date: 2025-11-24 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4f8a2b3c1d9e"
|
||||
down_revision = "a852cbe15577"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
OPEN_URL_TOOL = {
|
||||
"name": "OpenURLTool",
|
||||
"display_name": "Open URL",
|
||||
"description": (
|
||||
"The Open URL Action allows the agent to fetch and read contents of web pages."
|
||||
),
|
||||
"in_code_tool_id": "OpenURLTool",
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Check if tool already exists
|
||||
existing = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
|
||||
{"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
|
||||
).fetchone()
|
||||
|
||||
if existing:
|
||||
tool_id = existing[0]
|
||||
# Update existing tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
OPEN_URL_TOOL,
|
||||
)
|
||||
else:
|
||||
# Insert new tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
|
||||
"""
|
||||
),
|
||||
OPEN_URL_TOOL,
|
||||
)
|
||||
# Get the newly inserted tool's id
|
||||
result = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
|
||||
{"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
|
||||
).fetchone()
|
||||
tool_id = result[0] # type: ignore
|
||||
|
||||
# Associate the tool with all existing personas
|
||||
# Get all persona IDs
|
||||
persona_ids = conn.execute(sa.text("SELECT id FROM persona")).fetchall()
|
||||
|
||||
for (persona_id,) in persona_ids:
|
||||
# Check if association already exists
|
||||
exists = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT 1 FROM persona__tool
|
||||
WHERE persona_id = :persona_id AND tool_id = :tool_id
|
||||
"""
|
||||
),
|
||||
{"persona_id": persona_id, "tool_id": tool_id},
|
||||
).fetchone()
|
||||
|
||||
if not exists:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (:persona_id, :tool_id)
|
||||
"""
|
||||
),
|
||||
{"persona_id": persona_id, "tool_id": tool_id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# We don't remove the tool on downgrade since it's fine to have it around.
|
||||
# If we upgrade again, it will be a no-op.
|
||||
pass
|
||||
@@ -0,0 +1,55 @@
|
||||
"""update_default_persona_prompt
|
||||
|
||||
Revision ID: 5e6f7a8b9c0d
|
||||
Revises: 4f8a2b3c1d9e
|
||||
Create Date: 2025-11-30 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5e6f7a8b9c0d"
|
||||
down_revision = "4f8a2b3c1d9e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
# ruff: noqa: E501, W605 start
|
||||
DEFAULT_SYSTEM_PROMPT = """
|
||||
You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the user's intent, ask clarifying questions when needed, think step-by-step through complex problems, provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always prioritize being truthful, nuanced, insightful, and efficient.
|
||||
|
||||
The current date is [[CURRENT_DATETIME]].{citation_reminder_or_empty}
|
||||
|
||||
# Response Style
|
||||
You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make your responses more readable and engaging.
|
||||
You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline.
|
||||
For code you prefer to use Markdown and specify the language.
|
||||
You can use horizontal rules (---) to separate sections of your responses.
|
||||
You can use Markdown tables to format your responses for data, lists, and other structured information.
|
||||
""".lstrip()
|
||||
# ruff: noqa: E501, W605 end
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET system_prompt = :system_prompt
|
||||
WHERE id = :persona_id
|
||||
"""
|
||||
),
|
||||
{"system_prompt": DEFAULT_SYSTEM_PROMPT, "persona_id": DEFAULT_PERSONA_ID},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# We don't revert the system prompt on downgrade since we don't know
|
||||
# what the previous value was. The new prompt is a reasonable default.
|
||||
pass
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Add display_name to model_configuration
|
||||
|
||||
Revision ID: 7bd55f264e1b
|
||||
Revises: e8f0d2a38171
|
||||
Create Date: 2025-12-04
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7bd55f264e1b"
|
||||
down_revision = "e8f0d2a38171"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"model_configuration",
|
||||
sa.Column("display_name", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("model_configuration", "display_name")
|
||||
@@ -0,0 +1,55 @@
|
||||
"""update_default_system_prompt
|
||||
|
||||
Revision ID: 87c52ec39f84
|
||||
Revises: 7bd55f264e1b
|
||||
Create Date: 2025-12-05 15:54:06.002452
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "87c52ec39f84"
|
||||
down_revision = "7bd55f264e1b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
# ruff: noqa: E501, W605 start
|
||||
DEFAULT_SYSTEM_PROMPT = """
|
||||
You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the user's intent, ask clarifying questions when needed, think step-by-step through complex problems, provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always prioritize being truthful, nuanced, insightful, and efficient.
|
||||
|
||||
The current date is [[CURRENT_DATETIME]].[[CITATION_GUIDANCE]]
|
||||
|
||||
# Response Style
|
||||
You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make your responses more readable and engaging.
|
||||
You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline.
|
||||
For code you prefer to use Markdown and specify the language.
|
||||
You can use horizontal rules (---) to separate sections of your responses.
|
||||
You can use Markdown tables to format your responses for data, lists, and other structured information.
|
||||
""".lstrip()
|
||||
# ruff: noqa: E501, W605 end
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET system_prompt = :system_prompt
|
||||
WHERE id = :persona_id
|
||||
"""
|
||||
),
|
||||
{"system_prompt": DEFAULT_SYSTEM_PROMPT, "persona_id": DEFAULT_PERSONA_ID},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# We don't revert the system prompt on downgrade since we don't know
|
||||
# what the previous value was. The new prompt is a reasonable default.
|
||||
pass
|
||||
@@ -0,0 +1,62 @@
|
||||
"""update_default_tool_descriptions
|
||||
|
||||
Revision ID: a01bf2971c5d
|
||||
Revises: 87c52ec39f84
|
||||
Create Date: 2025-12-16 15:21:25.656375
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a01bf2971c5d"
|
||||
down_revision = "18b5b2524446"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# new tool descriptions (12/2025)
|
||||
TOOL_DESCRIPTIONS = {
|
||||
"SearchTool": "The Search Action allows the agent to search through connected knowledge to help build an answer.",
|
||||
"ImageGenerationTool": (
|
||||
"The Image Generation Action allows the agent to use DALL-E 3 or GPT-IMAGE-1 to generate images. "
|
||||
"The action will be used when the user asks the agent to generate an image."
|
||||
),
|
||||
"WebSearchTool": (
|
||||
"The Web Search Action allows the agent "
|
||||
"to perform internet searches for up-to-date information."
|
||||
),
|
||||
"KnowledgeGraphTool": (
|
||||
"The Knowledge Graph Search Action allows the agent to search the "
|
||||
"Knowledge Graph for information. This tool can (for now) only be active in the KG Beta Agent, "
|
||||
"and it requires the Knowledge Graph to be enabled."
|
||||
),
|
||||
"OktaProfileTool": (
|
||||
"The Okta Profile Action allows the agent to fetch the current user's information from Okta. "
|
||||
"This may include the user's name, email, phone number, address, and other details such as their "
|
||||
"manager and direct reports."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
for tool_id, description in TOOL_DESCRIPTIONS.items():
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
|
||||
),
|
||||
{"description": description, "tool_id": tool_id},
|
||||
)
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
except Exception as e:
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
417
backend/alembic/versions/a852cbe15577_new_chat_history.py
Normal file
417
backend/alembic/versions/a852cbe15577_new_chat_history.py
Normal file
@@ -0,0 +1,417 @@
|
||||
"""New Chat History
|
||||
|
||||
Revision ID: a852cbe15577
|
||||
Revises: 6436661d5b65
|
||||
Create Date: 2025-11-08 15:16:37.781308
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a852cbe15577"
|
||||
down_revision = "6436661d5b65"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1. Drop old research/agent tables (CASCADE handles dependencies)
|
||||
op.execute("DROP TABLE IF EXISTS research_agent_iteration_sub_step CASCADE")
|
||||
op.execute("DROP TABLE IF EXISTS research_agent_iteration CASCADE")
|
||||
op.execute("DROP TABLE IF EXISTS agent__sub_query__search_doc CASCADE")
|
||||
op.execute("DROP TABLE IF EXISTS agent__sub_query CASCADE")
|
||||
op.execute("DROP TABLE IF EXISTS agent__sub_question CASCADE")
|
||||
|
||||
# 2. ChatMessage table changes
|
||||
# Rename columns and add FKs
|
||||
op.alter_column(
|
||||
"chat_message", "parent_message", new_column_name="parent_message_id"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_chat_message_parent_message_id",
|
||||
"chat_message",
|
||||
"chat_message",
|
||||
["parent_message_id"],
|
||||
["id"],
|
||||
)
|
||||
op.alter_column(
|
||||
"chat_message",
|
||||
"latest_child_message",
|
||||
new_column_name="latest_child_message_id",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_chat_message_latest_child_message_id",
|
||||
"chat_message",
|
||||
"chat_message",
|
||||
["latest_child_message_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Add new column
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("reasoning_tokens", sa.Text(), nullable=True)
|
||||
)
|
||||
|
||||
# Drop old columns
|
||||
op.drop_column("chat_message", "rephrased_query")
|
||||
op.drop_column("chat_message", "alternate_assistant_id")
|
||||
op.drop_column("chat_message", "overridden_model")
|
||||
op.drop_column("chat_message", "is_agentic")
|
||||
op.drop_column("chat_message", "refined_answer_improvement")
|
||||
op.drop_column("chat_message", "research_type")
|
||||
op.drop_column("chat_message", "research_plan")
|
||||
op.drop_column("chat_message", "research_answer_purpose")
|
||||
|
||||
# 3. ToolCall table changes
|
||||
# Drop the unique constraint first
|
||||
op.drop_constraint("uq_tool_call_message_id", "tool_call", type_="unique")
|
||||
|
||||
# Delete orphaned tool_call rows (those without valid chat_message)
|
||||
op.execute(
|
||||
"DELETE FROM tool_call WHERE message_id NOT IN (SELECT id FROM chat_message)"
|
||||
)
|
||||
|
||||
# Add chat_session_id as nullable first, populate, then make NOT NULL
|
||||
op.add_column(
|
||||
"tool_call",
|
||||
sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
|
||||
# Populate chat_session_id from the related chat_message
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE tool_call
|
||||
SET chat_session_id = chat_message.chat_session_id
|
||||
FROM chat_message
|
||||
WHERE tool_call.message_id = chat_message.id
|
||||
"""
|
||||
)
|
||||
|
||||
# Now make it NOT NULL and add FK
|
||||
op.alter_column("tool_call", "chat_session_id", nullable=False)
|
||||
op.create_foreign_key(
|
||||
"fk_tool_call_chat_session_id",
|
||||
"tool_call",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Rename message_id and make nullable, recreate FK with CASCADE
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.alter_column(
|
||||
"tool_call",
|
||||
"message_id",
|
||||
new_column_name="parent_chat_message_id",
|
||||
nullable=True,
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_tool_call_parent_chat_message_id",
|
||||
"tool_call",
|
||||
"chat_message",
|
||||
["parent_chat_message_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Add parent_tool_call_id with FK
|
||||
op.add_column(
|
||||
"tool_call", sa.Column("parent_tool_call_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_tool_call_parent_tool_call_id",
|
||||
"tool_call",
|
||||
"tool_call",
|
||||
["parent_tool_call_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Add other new columns
|
||||
op.add_column(
|
||||
"tool_call",
|
||||
sa.Column("turn_number", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
op.add_column(
|
||||
"tool_call",
|
||||
sa.Column("tool_call_id", sa.String(), nullable=False, server_default=""),
|
||||
)
|
||||
op.add_column("tool_call", sa.Column("reasoning_tokens", sa.Text(), nullable=True))
|
||||
op.add_column(
|
||||
"tool_call",
|
||||
sa.Column("tool_call_tokens", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
op.add_column(
|
||||
"tool_call",
|
||||
sa.Column("generated_images", postgresql.JSONB(), nullable=True),
|
||||
)
|
||||
|
||||
# Rename columns
|
||||
op.alter_column(
|
||||
"tool_call", "tool_arguments", new_column_name="tool_call_arguments"
|
||||
)
|
||||
op.alter_column("tool_call", "tool_result", new_column_name="tool_call_response")
|
||||
|
||||
# Change tool_call_response type from JSONB to Text
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE tool_call
|
||||
ALTER COLUMN tool_call_response TYPE TEXT
|
||||
USING tool_call_response::text
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop old columns
|
||||
op.drop_column("tool_call", "tool_name")
|
||||
|
||||
# 4. Create new association table
|
||||
op.create_table(
|
||||
"tool_call__search_doc",
|
||||
sa.Column("tool_call_id", sa.Integer(), nullable=False),
|
||||
sa.Column("search_doc_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["tool_call_id"], ["tool_call.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["search_doc_id"], ["search_doc.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("tool_call_id", "search_doc_id"),
|
||||
)
|
||||
|
||||
# 5. Persona table change
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"replace_base_system_prompt",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Reverse persona changes
|
||||
op.drop_column("persona", "replace_base_system_prompt")
|
||||
|
||||
# Drop new association table
|
||||
op.drop_table("tool_call__search_doc")
|
||||
|
||||
# Reverse ToolCall changes
|
||||
op.add_column(
|
||||
"tool_call",
|
||||
sa.Column("tool_name", sa.String(), nullable=False, server_default=""),
|
||||
)
|
||||
|
||||
# Change tool_call_response back to JSONB
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE tool_call
|
||||
ALTER COLUMN tool_call_response TYPE JSONB
|
||||
USING tool_call_response::jsonb
|
||||
"""
|
||||
)
|
||||
|
||||
op.alter_column("tool_call", "tool_call_response", new_column_name="tool_result")
|
||||
op.alter_column(
|
||||
"tool_call", "tool_call_arguments", new_column_name="tool_arguments"
|
||||
)
|
||||
|
||||
op.drop_column("tool_call", "generated_images")
|
||||
op.drop_column("tool_call", "tool_call_tokens")
|
||||
op.drop_column("tool_call", "reasoning_tokens")
|
||||
op.drop_column("tool_call", "tool_call_id")
|
||||
op.drop_column("tool_call", "turn_number")
|
||||
|
||||
op.drop_constraint(
|
||||
"fk_tool_call_parent_tool_call_id", "tool_call", type_="foreignkey"
|
||||
)
|
||||
op.drop_column("tool_call", "parent_tool_call_id")
|
||||
|
||||
op.drop_constraint(
|
||||
"fk_tool_call_parent_chat_message_id", "tool_call", type_="foreignkey"
|
||||
)
|
||||
op.alter_column(
|
||||
"tool_call",
|
||||
"parent_chat_message_id",
|
||||
new_column_name="message_id",
|
||||
nullable=False,
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"tool_call_message_id_fkey",
|
||||
"tool_call",
|
||||
"chat_message",
|
||||
["message_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
op.drop_constraint("fk_tool_call_chat_session_id", "tool_call", type_="foreignkey")
|
||||
op.drop_column("tool_call", "chat_session_id")
|
||||
|
||||
op.create_unique_constraint("uq_tool_call_message_id", "tool_call", ["message_id"])
|
||||
|
||||
# Reverse ChatMessage changes
|
||||
# Note: research_answer_purpose and research_type were originally String columns,
|
||||
# not Enum types (see migrations 5ae8240accb3 and f8a9b2c3d4e5)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("research_answer_purpose", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("research_plan", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("research_type", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("refined_answer_improvement", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message",
|
||||
sa.Column("is_agentic", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("overridden_model", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("rephrased_query", sa.Text(), nullable=True)
|
||||
)
|
||||
|
||||
op.drop_column("chat_message", "reasoning_tokens")
|
||||
|
||||
op.drop_constraint(
|
||||
"fk_chat_message_latest_child_message_id", "chat_message", type_="foreignkey"
|
||||
)
|
||||
op.alter_column(
|
||||
"chat_message",
|
||||
"latest_child_message_id",
|
||||
new_column_name="latest_child_message",
|
||||
)
|
||||
|
||||
op.drop_constraint(
|
||||
"fk_chat_message_parent_message_id", "chat_message", type_="foreignkey"
|
||||
)
|
||||
op.alter_column(
|
||||
"chat_message", "parent_message_id", new_column_name="parent_message"
|
||||
)
|
||||
|
||||
# Recreate agent sub question and sub query tables
|
||||
op.create_table(
|
||||
"agent__sub_question",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("primary_question_id", sa.Integer(), nullable=False),
|
||||
sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("sub_question", sa.Text(), nullable=False),
|
||||
sa.Column("level", sa.Integer(), nullable=False),
|
||||
sa.Column("level_question_num", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("sub_answer", sa.Text(), nullable=False),
|
||||
sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"agent__sub_query",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("parent_question_id", sa.Integer(), nullable=False),
|
||||
sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column("sub_query", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["parent_question_id"], ["agent__sub_question.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"agent__sub_query__search_doc",
|
||||
sa.Column("sub_query_id", sa.Integer(), nullable=False),
|
||||
sa.Column("search_doc_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["sub_query_id"], ["agent__sub_query.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["search_doc_id"], ["search_doc.id"]),
|
||||
sa.PrimaryKeyConstraint("sub_query_id", "search_doc_id"),
|
||||
)
|
||||
|
||||
# Recreate research agent tables
|
||||
op.create_table(
|
||||
"research_agent_iteration",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("primary_question_id", sa.Integer(), nullable=False),
|
||||
sa.Column("iteration_nr", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("purpose", sa.String(), nullable=True),
|
||||
sa.Column("reasoning", sa.String(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"primary_question_id",
|
||||
"iteration_nr",
|
||||
name="_research_agent_iteration_unique_constraint",
|
||||
),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"research_agent_iteration_sub_step",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("primary_question_id", sa.Integer(), nullable=False),
|
||||
sa.Column("iteration_nr", sa.Integer(), nullable=False),
|
||||
sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("sub_step_instructions", sa.String(), nullable=True),
|
||||
sa.Column("sub_step_tool_id", sa.Integer(), nullable=True),
|
||||
sa.Column("reasoning", sa.String(), nullable=True),
|
||||
sa.Column("sub_answer", sa.String(), nullable=True),
|
||||
sa.Column("cited_doc_results", postgresql.JSONB(), nullable=False),
|
||||
sa.Column("claims", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("is_web_fetch", sa.Boolean(), nullable=True),
|
||||
sa.Column("queries", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("generated_images", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("additional_data", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("file_ids", postgresql.JSONB(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["primary_question_id", "iteration_nr"],
|
||||
[
|
||||
"research_agent_iteration.primary_question_id",
|
||||
"research_agent_iteration.iteration_nr",
|
||||
],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(["sub_step_tool_id"], ["tool.id"], ondelete="SET NULL"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
@@ -0,0 +1,115 @@
|
||||
"""add status to mcp server and make auth fields nullable
|
||||
|
||||
Revision ID: e8f0d2a38171
|
||||
Revises: ed9e44312505
|
||||
Create Date: 2025-11-28 11:15:37.667340
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from onyx.db.enums import ( # type: ignore[import-untyped]
|
||||
MCPTransport,
|
||||
MCPAuthenticationType,
|
||||
MCPAuthenticationPerformer,
|
||||
MCPServerStatus,
|
||||
)
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e8f0d2a38171"
|
||||
down_revision = "ed9e44312505"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Make auth fields nullable
|
||||
op.alter_column(
|
||||
"mcp_server",
|
||||
"transport",
|
||||
existing_type=sa.Enum(MCPTransport, name="mcp_transport", native_enum=False),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"mcp_server",
|
||||
"auth_type",
|
||||
existing_type=sa.Enum(
|
||||
MCPAuthenticationType, name="mcp_authentication_type", native_enum=False
|
||||
),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"mcp_server",
|
||||
"auth_performer",
|
||||
existing_type=sa.Enum(
|
||||
MCPAuthenticationPerformer,
|
||||
name="mcp_authentication_performer",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Add status column with default
|
||||
op.add_column(
|
||||
"mcp_server",
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum(MCPServerStatus, name="mcp_server_status", native_enum=False),
|
||||
nullable=False,
|
||||
server_default="CREATED",
|
||||
),
|
||||
)
|
||||
|
||||
# For existing records, mark status as CONNECTED
|
||||
bind = op.get_bind()
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE mcp_server
|
||||
SET status = 'CONNECTED'
|
||||
WHERE status != 'CONNECTED'
|
||||
and admin_connection_config_id IS NOT NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove status column
|
||||
op.drop_column("mcp_server", "status")
|
||||
|
||||
# Make auth fields non-nullable (set defaults first)
|
||||
op.execute(
|
||||
"UPDATE mcp_server SET transport = 'STREAMABLE_HTTP' WHERE transport IS NULL"
|
||||
)
|
||||
op.execute("UPDATE mcp_server SET auth_type = 'NONE' WHERE auth_type IS NULL")
|
||||
op.execute(
|
||||
"UPDATE mcp_server SET auth_performer = 'ADMIN' WHERE auth_performer IS NULL"
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"mcp_server",
|
||||
"transport",
|
||||
existing_type=sa.Enum(MCPTransport, name="mcp_transport", native_enum=False),
|
||||
nullable=False,
|
||||
)
|
||||
op.alter_column(
|
||||
"mcp_server",
|
||||
"auth_type",
|
||||
existing_type=sa.Enum(
|
||||
MCPAuthenticationType, name="mcp_authentication_type", native_enum=False
|
||||
),
|
||||
nullable=False,
|
||||
)
|
||||
op.alter_column(
|
||||
"mcp_server",
|
||||
"auth_performer",
|
||||
existing_type=sa.Enum(
|
||||
MCPAuthenticationPerformer,
|
||||
name="mcp_authentication_performer",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=False,
|
||||
)
|
||||
34
backend/alembic/versions/ed9e44312505_add_icon_name_field.py
Normal file
34
backend/alembic/versions/ed9e44312505_add_icon_name_field.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Add icon_name field
|
||||
|
||||
Revision ID: ed9e44312505
|
||||
Revises: 5e6f7a8b9c0d
|
||||
Create Date: 2025-12-03 16:35:07.828393
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ed9e44312505"
|
||||
down_revision = "5e6f7a8b9c0d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add icon_name column
|
||||
op.add_column("persona", sa.Column("icon_name", sa.String(), nullable=True))
|
||||
|
||||
# Remove old icon columns
|
||||
op.drop_column("persona", "icon_shape")
|
||||
op.drop_column("persona", "icon_color")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Re-add old icon columns
|
||||
op.add_column("persona", sa.Column("icon_color", sa.String(), nullable=True))
|
||||
op.add_column("persona", sa.Column("icon_shape", sa.Integer(), nullable=True))
|
||||
|
||||
# Remove icon_name column
|
||||
op.drop_column("persona", "icon_name")
|
||||
@@ -41,6 +41,10 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
|
||||
JIRA_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("JIRA_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60
|
||||
)
|
||||
# In seconds, default is 30 minutes
|
||||
JIRA_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("JIRA_PERMISSION_GROUP_SYNC_FREQUENCY") or 30 * 60
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
|
||||
@@ -199,10 +199,7 @@ def fetch_persona_message_analytics(
|
||||
ChatMessage.chat_session_id == ChatSession.id,
|
||||
)
|
||||
.where(
|
||||
or_(
|
||||
ChatMessage.alternate_assistant_id == persona_id,
|
||||
ChatSession.persona_id == persona_id,
|
||||
),
|
||||
ChatSession.persona_id == persona_id,
|
||||
ChatMessage.time_sent >= start,
|
||||
ChatMessage.time_sent <= end,
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
@@ -231,10 +228,7 @@ def fetch_persona_unique_users(
|
||||
ChatMessage.chat_session_id == ChatSession.id,
|
||||
)
|
||||
.where(
|
||||
or_(
|
||||
ChatMessage.alternate_assistant_id == persona_id,
|
||||
ChatSession.persona_id == persona_id,
|
||||
),
|
||||
ChatSession.persona_id == persona_id,
|
||||
ChatMessage.time_sent >= start,
|
||||
ChatMessage.time_sent <= end,
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
@@ -265,10 +259,7 @@ def fetch_assistant_message_analytics(
|
||||
ChatMessage.chat_session_id == ChatSession.id,
|
||||
)
|
||||
.where(
|
||||
or_(
|
||||
ChatMessage.alternate_assistant_id == assistant_id,
|
||||
ChatSession.persona_id == assistant_id,
|
||||
),
|
||||
ChatSession.persona_id == assistant_id,
|
||||
ChatMessage.time_sent >= start,
|
||||
ChatMessage.time_sent <= end,
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
@@ -299,10 +290,7 @@ def fetch_assistant_unique_users(
|
||||
ChatMessage.chat_session_id == ChatSession.id,
|
||||
)
|
||||
.where(
|
||||
or_(
|
||||
ChatMessage.alternate_assistant_id == assistant_id,
|
||||
ChatSession.persona_id == assistant_id,
|
||||
),
|
||||
ChatSession.persona_id == assistant_id,
|
||||
ChatMessage.time_sent >= start,
|
||||
ChatMessage.time_sent <= end,
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
@@ -332,10 +320,7 @@ def fetch_assistant_unique_users_total(
|
||||
ChatMessage.chat_session_id == ChatSession.id,
|
||||
)
|
||||
.where(
|
||||
or_(
|
||||
ChatMessage.alternate_assistant_id == assistant_id,
|
||||
ChatSession.persona_id == assistant_id,
|
||||
),
|
||||
ChatSession.persona_id == assistant_id,
|
||||
ChatMessage.time_sent >= start,
|
||||
ChatMessage.time_sent <= end,
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
|
||||
@@ -55,18 +55,7 @@ def get_empty_chat_messages_entries__paginated(
|
||||
|
||||
# Get assistant name (from session persona, or alternate if specified)
|
||||
assistant_name = None
|
||||
if message.alternate_assistant_id:
|
||||
# If there's an alternate assistant, we need to fetch it
|
||||
from onyx.db.models import Persona
|
||||
|
||||
alternate_persona = (
|
||||
db_session.query(Persona)
|
||||
.filter(Persona.id == message.alternate_assistant_id)
|
||||
.first()
|
||||
)
|
||||
if alternate_persona:
|
||||
assistant_name = alternate_persona.name
|
||||
elif chat_session.persona:
|
||||
if chat_session.persona:
|
||||
assistant_name = chat_session.persona.name
|
||||
|
||||
message_skeletons.append(
|
||||
|
||||
@@ -8,6 +8,7 @@ from sqlalchemy import func
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
@@ -362,14 +363,29 @@ def _check_user_group_is_modifiable(user_group: UserGroup) -> None:
|
||||
|
||||
def _add_user__user_group_relationships__no_commit(
|
||||
db_session: Session, user_group_id: int, user_ids: list[UUID]
|
||||
) -> list[User__UserGroup]:
|
||||
"""NOTE: does not commit the transaction."""
|
||||
relationships = [
|
||||
User__UserGroup(user_id=user_id, user_group_id=user_group_id)
|
||||
for user_id in user_ids
|
||||
]
|
||||
db_session.add_all(relationships)
|
||||
return relationships
|
||||
) -> None:
|
||||
"""NOTE: does not commit the transaction.
|
||||
|
||||
This function is idempotent - it will skip users who are already in the group
|
||||
to avoid duplicate key violations during concurrent operations or re-syncs.
|
||||
Uses ON CONFLICT DO NOTHING to keep inserts atomic under concurrency.
|
||||
"""
|
||||
if not user_ids:
|
||||
return
|
||||
|
||||
insert_stmt = (
|
||||
insert(User__UserGroup)
|
||||
.values(
|
||||
[
|
||||
{"user_id": user_id, "user_group_id": user_group_id}
|
||||
for user_id in user_ids
|
||||
]
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[User__UserGroup.user_group_id, User__UserGroup.user_id]
|
||||
)
|
||||
)
|
||||
db_session.execute(insert_stmt)
|
||||
|
||||
|
||||
def _add_user_group__cc_pair_relationships__no_commit(
|
||||
|
||||
@@ -3,12 +3,15 @@ from collections.abc import Generator
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
from onyx.configs.app_configs import CONFLUENCE_USE_ONYX_USERS_FOR_GROUP_SYNC
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.users import get_all_users
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -19,7 +22,7 @@ def _build_group_member_email_map(
|
||||
) -> dict[str, set[str]]:
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
for user in confluence_client.paginated_cql_user_retrieval():
|
||||
logger.debug(f"Processing groups for user: {user}")
|
||||
logger.info(f"Processing groups for user: {user}")
|
||||
|
||||
email = user.email
|
||||
if not email:
|
||||
@@ -31,6 +34,8 @@ def _build_group_member_email_map(
|
||||
confluence_client=confluence_client,
|
||||
user_name=user_name,
|
||||
)
|
||||
else:
|
||||
logger.error(f"user result missing username field: {user}")
|
||||
|
||||
if not email:
|
||||
# If we still don't have an email, skip this user
|
||||
@@ -64,6 +69,92 @@ def _build_group_member_email_map(
|
||||
return group_member_emails
|
||||
|
||||
|
||||
def _build_group_member_email_map_from_onyx_users(
|
||||
confluence_client: OnyxConfluence,
|
||||
) -> dict[str, set[str]]:
|
||||
"""Hacky, but it's the only way to do this as long as the
|
||||
Confluence APIs are broken.
|
||||
|
||||
This is fixed in Confluence Data Center 10.1.0, so first choice
|
||||
is to tell users to upgrade to 10.1.0.
|
||||
https://jira.atlassian.com/browse/CONFSERVER-95999
|
||||
"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# don't include external since they are handled by the "through confluence"
|
||||
# user fetching mechanism
|
||||
user_emails = [
|
||||
user.email for user in get_all_users(db_session, include_external=False)
|
||||
]
|
||||
|
||||
def _infer_username_from_email(email: str) -> str:
|
||||
return email.split("@")[0]
|
||||
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
for email in user_emails:
|
||||
logger.info(f"Processing groups for user with email: {email}")
|
||||
try:
|
||||
user_name = _infer_username_from_email(email)
|
||||
response = confluence_client.get_user_details_by_username(user_name)
|
||||
user_key = response.get("userKey")
|
||||
if not user_key:
|
||||
logger.error(f"User key not found for user with email {email}")
|
||||
continue
|
||||
|
||||
all_users_groups: set[str] = set()
|
||||
for group in confluence_client.paginated_groups_by_user_retrieval(user_key):
|
||||
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
|
||||
group_id = group["name"]
|
||||
group_member_emails.setdefault(group_id, set()).add(email)
|
||||
all_users_groups.add(group_id)
|
||||
|
||||
if not all_users_groups:
|
||||
msg = f"No groups found for user with email: {email}"
|
||||
logger.error(msg)
|
||||
else:
|
||||
logger.info(
|
||||
f"Found groups {all_users_groups} for user with email {email}"
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Error getting user details for user with email {email}")
|
||||
|
||||
return group_member_emails
|
||||
|
||||
|
||||
def _build_final_group_to_member_email_map(
|
||||
confluence_client: OnyxConfluence,
|
||||
cc_pair_id: int,
|
||||
# if set, will infer confluence usernames from onyx users in addition to using the
|
||||
# confluence users API. This is a hacky workaround for the fact that the Confluence
|
||||
# users API is broken before Confluence Data Center 10.1.0.
|
||||
use_onyx_users: bool = CONFLUENCE_USE_ONYX_USERS_FOR_GROUP_SYNC,
|
||||
) -> dict[str, set[str]]:
|
||||
group_to_member_email_map = _build_group_member_email_map(
|
||||
confluence_client=confluence_client,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
group_to_member_email_map_from_onyx_users = (
|
||||
(
|
||||
_build_group_member_email_map_from_onyx_users(
|
||||
confluence_client=confluence_client,
|
||||
)
|
||||
)
|
||||
if use_onyx_users
|
||||
else {}
|
||||
)
|
||||
|
||||
all_group_ids = set(group_to_member_email_map.keys()) | set(
|
||||
group_to_member_email_map_from_onyx_users.keys()
|
||||
)
|
||||
final_group_to_member_email_map = {}
|
||||
for group_id in all_group_ids:
|
||||
group_member_emails = group_to_member_email_map.get(
|
||||
group_id, set()
|
||||
) | group_to_member_email_map_from_onyx_users.get(group_id, set())
|
||||
final_group_to_member_email_map[group_id] = group_member_emails
|
||||
|
||||
return final_group_to_member_email_map
|
||||
|
||||
|
||||
def confluence_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
@@ -87,13 +178,12 @@ def confluence_group_sync(
|
||||
confluence_client._probe_connection(**probe_kwargs)
|
||||
confluence_client._initialize_connection(**final_kwargs)
|
||||
|
||||
group_member_email_map = _build_group_member_email_map(
|
||||
confluence_client=confluence_client,
|
||||
cc_pair_id=cc_pair.id,
|
||||
group_to_member_email_map = _build_final_group_to_member_email_map(
|
||||
confluence_client, cc_pair.id
|
||||
)
|
||||
|
||||
all_found_emails = set()
|
||||
for group_id, group_member_emails in group_member_email_map.items():
|
||||
for group_id, group_member_emails in group_to_member_email_map.items():
|
||||
yield (
|
||||
ExternalUserGroup(
|
||||
id=group_id,
|
||||
|
||||
136
backend/ee/onyx/external_permissions/jira/group_sync.py
Normal file
136
backend/ee/onyx/external_permissions/jira/group_sync.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from jira import JIRA
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.connectors.jira.utils import build_jira_client
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_jira_group_members_email(
|
||||
jira_client: JIRA,
|
||||
group_name: str,
|
||||
) -> list[str]:
|
||||
"""Get all member emails for a Jira group.
|
||||
|
||||
Filters out app accounts (bots, integrations) and only returns real user emails.
|
||||
"""
|
||||
emails: list[str] = []
|
||||
|
||||
try:
|
||||
# group_members returns an OrderedDict of account_id -> member_info
|
||||
members = jira_client.group_members(group=group_name)
|
||||
|
||||
if not members:
|
||||
logger.warning(f"No members found for group {group_name}")
|
||||
return emails
|
||||
|
||||
for account_id, member_info in members.items():
|
||||
# member_info is a dict with keys like 'fullname', 'email', 'active'
|
||||
email = member_info.get("email")
|
||||
|
||||
# Skip "hidden" emails - these are typically app accounts
|
||||
if email and email != "hidden":
|
||||
emails.append(email)
|
||||
else:
|
||||
# For cloud, we might need to fetch user details separately
|
||||
try:
|
||||
user = jira_client.user(id=account_id)
|
||||
|
||||
# Skip app accounts (bots, integrations, etc.)
|
||||
if hasattr(user, "accountType") and user.accountType == "app":
|
||||
logger.info(
|
||||
f"Skipping app account {account_id} for group {group_name}"
|
||||
)
|
||||
continue
|
||||
|
||||
if hasattr(user, "emailAddress") and user.emailAddress:
|
||||
emails.append(user.emailAddress)
|
||||
else:
|
||||
logger.warning(f"User {account_id} has no email address")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch email for user {account_id} in group {group_name}: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching members for group {group_name}: {e}")
|
||||
|
||||
return emails
|
||||
|
||||
|
||||
def _build_group_member_email_map(
|
||||
jira_client: JIRA,
|
||||
) -> dict[str, set[str]]:
|
||||
"""Build a map of group names to member emails."""
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
|
||||
try:
|
||||
# Get all groups from Jira - returns a list of group name strings
|
||||
group_names = jira_client.groups()
|
||||
|
||||
if not group_names:
|
||||
logger.warning("No groups found in Jira")
|
||||
return group_member_emails
|
||||
|
||||
logger.info(f"Found {len(group_names)} groups in Jira")
|
||||
|
||||
for group_name in group_names:
|
||||
if not group_name:
|
||||
continue
|
||||
|
||||
member_emails = _get_jira_group_members_email(
|
||||
jira_client=jira_client,
|
||||
group_name=group_name,
|
||||
)
|
||||
|
||||
if member_emails:
|
||||
group_member_emails[group_name] = set(member_emails)
|
||||
logger.debug(
|
||||
f"Found {len(member_emails)} members for group {group_name}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"No members found for group {group_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error building group member email map: {e}")
|
||||
|
||||
return group_member_emails
|
||||
|
||||
|
||||
def jira_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
"""
|
||||
Sync Jira groups and their members.
|
||||
|
||||
This function fetches all groups from Jira and yields ExternalUserGroup
|
||||
objects containing the group ID and member emails.
|
||||
"""
|
||||
jira_base_url = cc_pair.connector.connector_specific_config.get("jira_base_url", "")
|
||||
scoped_token = cc_pair.connector.connector_specific_config.get(
|
||||
"scoped_token", False
|
||||
)
|
||||
|
||||
if not jira_base_url:
|
||||
raise ValueError("No jira_base_url found in connector config")
|
||||
|
||||
jira_client = build_jira_client(
|
||||
credentials=cc_pair.credential.credential_json,
|
||||
jira_base=jira_base_url,
|
||||
scoped_token=scoped_token,
|
||||
)
|
||||
|
||||
group_member_email_map = _build_group_member_email_map(jira_client=jira_client)
|
||||
if not group_member_email_map:
|
||||
raise ValueError(f"No groups with members found for cc_pair_id={cc_pair.id}")
|
||||
|
||||
for group_id, group_member_emails in group_member_email_map.items():
|
||||
yield ExternalUserGroup(
|
||||
id=group_id,
|
||||
user_emails=list(group_member_emails),
|
||||
)
|
||||
@@ -16,6 +16,10 @@ HolderMap = dict[str, list[Holder]]
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_role_id(holder: Holder) -> str | None:
|
||||
return holder.get("value") or holder.get("parameter")
|
||||
|
||||
|
||||
def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
|
||||
"""
|
||||
A "Holder" in JIRA is a person / entity who "holds" the corresponding permission.
|
||||
@@ -110,80 +114,137 @@ def _get_user_emails(user_holders: list[Holder]) -> list[str]:
|
||||
return emails
|
||||
|
||||
|
||||
def _get_user_emails_from_project_roles(
|
||||
def _get_user_emails_and_groups_from_project_roles(
|
||||
jira_client: JIRA,
|
||||
jira_project: str,
|
||||
project_role_holders: list[Holder],
|
||||
) -> list[str]:
|
||||
# NOTE (@raunakab) a `parallel_yield` may be helpful here...?
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
Get user emails and group names from project roles.
|
||||
Returns a tuple of (emails, group_names).
|
||||
"""
|
||||
# Get role IDs - Cloud uses "value", Data Center uses "parameter"
|
||||
role_ids = []
|
||||
for holder in project_role_holders:
|
||||
role_id = _get_role_id(holder)
|
||||
if role_id:
|
||||
role_ids.append(role_id)
|
||||
else:
|
||||
logger.warning(f"No value or parameter in projectRole holder: {holder}")
|
||||
|
||||
roles = [
|
||||
jira_client.project_role(project=jira_project, id=project_role_holder["value"])
|
||||
for project_role_holder in project_role_holders
|
||||
if "value" in project_role_holder
|
||||
jira_client.project_role(project=jira_project, id=role_id)
|
||||
for role_id in role_ids
|
||||
]
|
||||
|
||||
emails = []
|
||||
groups = []
|
||||
|
||||
for role in roles:
|
||||
if not hasattr(role, "actors"):
|
||||
logger.warning(f"Project role {role} has no actors attribute")
|
||||
continue
|
||||
|
||||
for actor in role.actors:
|
||||
if not hasattr(actor, "actorUser") or not hasattr(
|
||||
actor.actorUser, "accountId"
|
||||
):
|
||||
# Handle group actors
|
||||
if hasattr(actor, "actorGroup"):
|
||||
group_name = getattr(actor.actorGroup, "name", None) or getattr(
|
||||
actor.actorGroup, "displayName", None
|
||||
)
|
||||
if group_name:
|
||||
groups.append(group_name)
|
||||
continue
|
||||
|
||||
user = jira_client.user(id=actor.actorUser.accountId)
|
||||
if not hasattr(user, "accountType") or user.accountType != "atlassian":
|
||||
# Handle user actors
|
||||
if hasattr(actor, "actorUser"):
|
||||
account_id = getattr(actor.actorUser, "accountId", None)
|
||||
if not account_id:
|
||||
logger.error(f"No accountId in actorUser: {actor.actorUser}")
|
||||
continue
|
||||
|
||||
user = jira_client.user(id=account_id)
|
||||
if not hasattr(user, "accountType") or user.accountType != "atlassian":
|
||||
logger.info(
|
||||
f"Skipping user {account_id} because it is not an atlassian user"
|
||||
)
|
||||
continue
|
||||
|
||||
if not hasattr(user, "emailAddress"):
|
||||
msg = f"User's email address was not able to be retrieved; {actor.actorUser.accountId=}"
|
||||
if hasattr(user, "displayName"):
|
||||
msg += f" {actor.displayName=}"
|
||||
logger.warning(msg)
|
||||
continue
|
||||
|
||||
emails.append(user.emailAddress)
|
||||
continue
|
||||
|
||||
if not hasattr(user, "emailAddress"):
|
||||
msg = f"User's email address was not able to be retrieved; {actor.actorUser.accountId=}"
|
||||
if hasattr(user, "displayName"):
|
||||
msg += f" {actor.displayName=}"
|
||||
logger.warn(msg)
|
||||
continue
|
||||
logger.debug(f"Skipping actor type: {actor}")
|
||||
|
||||
emails.append(user.emailAddress)
|
||||
|
||||
return emails
|
||||
return emails, groups
|
||||
|
||||
|
||||
def _build_external_access_from_holder_map(
|
||||
jira_client: JIRA, jira_project: str, holder_map: HolderMap
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
# Note:
|
||||
If the `holder_map` contains an instance of "anyone", then this is a public JIRA project.
|
||||
Otherwise, we fetch the "projectRole"s (i.e., the user-groups in JIRA speak), and the user emails.
|
||||
"""
|
||||
Build ExternalAccess from the holder map.
|
||||
|
||||
Holder types handled:
|
||||
- "anyone": Public project, anyone can access
|
||||
- "applicationRole": All users with a Jira license can access (treated as public)
|
||||
- "user": Specific users with access
|
||||
- "projectRole": Project roles containing users and/or groups
|
||||
- "group": Groups directly assigned in the permission scheme
|
||||
"""
|
||||
# Public access - anyone can view
|
||||
if "anyone" in holder_map:
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(), external_user_group_ids=set(), is_public=True
|
||||
)
|
||||
|
||||
# applicationRole means all users with a Jira license can access - treat as public
|
||||
if "applicationRole" in holder_map:
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(), external_user_group_ids=set(), is_public=True
|
||||
)
|
||||
|
||||
# Get emails from explicit user holders
|
||||
user_emails = (
|
||||
_get_user_emails(user_holders=holder_map["user"])
|
||||
if "user" in holder_map
|
||||
else []
|
||||
)
|
||||
project_role_user_emails = (
|
||||
_get_user_emails_from_project_roles(
|
||||
jira_client=jira_client,
|
||||
jira_project=jira_project,
|
||||
project_role_holders=holder_map["projectRole"],
|
||||
|
||||
# Get emails and groups from project roles
|
||||
project_role_user_emails: list[str] = []
|
||||
project_role_groups: list[str] = []
|
||||
if "projectRole" in holder_map:
|
||||
project_role_user_emails, project_role_groups = (
|
||||
_get_user_emails_and_groups_from_project_roles(
|
||||
jira_client=jira_client,
|
||||
jira_project=jira_project,
|
||||
project_role_holders=holder_map["projectRole"],
|
||||
)
|
||||
)
|
||||
if "projectRole" in holder_map
|
||||
else []
|
||||
)
|
||||
|
||||
# Get groups directly assigned in permission scheme (common in Data Center)
|
||||
# Format: {'type': 'group', 'parameter': 'group-name', 'expand': 'group'}
|
||||
direct_groups: list[str] = []
|
||||
if "group" in holder_map:
|
||||
for group_holder in holder_map["group"]:
|
||||
group_name = _get_role_id(group_holder)
|
||||
if group_name:
|
||||
direct_groups.append(group_name)
|
||||
else:
|
||||
logger.error(f"No parameter/value in group holder: {group_holder}")
|
||||
|
||||
external_user_emails = set(user_emails + project_role_user_emails)
|
||||
external_user_group_ids = set(project_role_groups + direct_groups)
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=external_user_emails,
|
||||
external_user_group_ids=set(),
|
||||
external_user_group_ids=external_user_group_ids,
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
@@ -197,9 +258,11 @@ def get_project_permissions(
|
||||
)
|
||||
|
||||
if not hasattr(project_permissions, "permissions"):
|
||||
logger.error(f"Project {jira_project} has no permissions attribute")
|
||||
return None
|
||||
|
||||
if not isinstance(project_permissions.permissions, list):
|
||||
logger.error(f"Project {jira_project} permissions is not a list")
|
||||
return None
|
||||
|
||||
holder_map = _build_holder_map(permissions=project_permissions.permissions)
|
||||
|
||||
@@ -15,6 +15,7 @@ from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
|
||||
from onyx.connectors.sharepoint.connector import sleep_and_retry
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -511,8 +512,8 @@ def get_external_access_from_sharepoint(
|
||||
f"Failed to get SharePoint list item ID for item {drive_item.id}"
|
||||
)
|
||||
|
||||
if drive_name == "Shared Documents":
|
||||
drive_name = "Documents"
|
||||
if drive_name in SHARED_DOCUMENTS_MAP_REVERSE:
|
||||
drive_name = SHARED_DOCUMENTS_MAP_REVERSE[drive_name]
|
||||
|
||||
item = client_context.web.lists.get_by_title(drive_name).items.get_by_id(
|
||||
item_id
|
||||
|
||||
@@ -11,6 +11,7 @@ from ee.onyx.configs.app_configs import GITHUB_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import JIRA_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import JIRA_PERMISSION_GROUP_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
@@ -23,6 +24,7 @@ from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
|
||||
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
|
||||
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
|
||||
from ee.onyx.external_permissions.jira.doc_sync import jira_doc_sync
|
||||
from ee.onyx.external_permissions.jira.group_sync import jira_group_sync
|
||||
from ee.onyx.external_permissions.perm_sync_types import CensoringFuncType
|
||||
from ee.onyx.external_permissions.perm_sync_types import DocSyncFuncType
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
@@ -110,6 +112,11 @@ _SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = {
|
||||
doc_sync_func=jira_doc_sync,
|
||||
initial_index_should_sync=True,
|
||||
),
|
||||
group_sync_config=GroupSyncConfig(
|
||||
group_sync_frequency=JIRA_PERMISSION_GROUP_SYNC_FREQUENCY,
|
||||
group_sync_func=jira_group_sync,
|
||||
group_sync_is_cc_pair_agnostic=True,
|
||||
),
|
||||
),
|
||||
# Groups are not needed for Slack.
|
||||
# All channel access is done at the individual user level.
|
||||
|
||||
@@ -8,12 +8,10 @@ from ee.onyx.server.query_and_chat.models import (
|
||||
BasicCreateChatMessageWithHistoryRequest,
|
||||
)
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import combine_message_thread
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.chat_utils import create_chat_history_chain
|
||||
from onyx.chat.models import ChatBasicResponse
|
||||
from onyx.chat.process_message import gather_stream
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import OptionalSearchSetting
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
@@ -24,7 +22,6 @@ from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.secondary_llm_flows.query_expansion import thread_based_query_rephrase
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -69,9 +66,9 @@ def handle_simplified_chat_message(
|
||||
chat_session_id = chat_message_req.chat_session_id
|
||||
|
||||
try:
|
||||
parent_message, _ = create_chat_chain(
|
||||
parent_message = create_chat_history_chain(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
)[-1]
|
||||
except Exception:
|
||||
parent_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
@@ -168,8 +165,6 @@ def handle_send_message_simple_with_history(
|
||||
provider_type=llm.config.model_provider,
|
||||
)
|
||||
|
||||
max_history_tokens = int(llm.config.max_input_tokens * CHAT_TARGET_CHUNK_PERCENTAGE)
|
||||
|
||||
# Every chat Session begins with an empty root message
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
@@ -188,17 +183,6 @@ def handle_send_message_simple_with_history(
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
history_str = combine_message_thread(
|
||||
messages=msg_history,
|
||||
max_tokens=max_history_tokens,
|
||||
llm_tokenizer=llm_tokenizer,
|
||||
)
|
||||
|
||||
rephrased_query = req.query_override or thread_based_query_rephrase(
|
||||
user_query=query,
|
||||
history_str=history_str,
|
||||
)
|
||||
|
||||
if req.retrieval_options is None and req.search_doc_ids is None:
|
||||
retrieval_options: RetrievalDetails | None = RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
@@ -216,7 +200,7 @@ def handle_send_message_simple_with_history(
|
||||
retrieval_options=retrieval_options,
|
||||
# Simple API does not support reranking, hide complexity from user
|
||||
rerank_settings=None,
|
||||
query_override=rephrased_query,
|
||||
query_override=None,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
full_doc=req.full_doc,
|
||||
|
||||
@@ -8,10 +8,29 @@ from pydantic import model_validator
|
||||
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import BasicChunkRequest
|
||||
from onyx.context.search.models import ChunkContext
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.server.manage.models import StandardAnswer
|
||||
from onyx.server.query_and_chat.streaming_models import SubQuestionIdentifier
|
||||
|
||||
|
||||
class StandardAnswerRequest(BaseModel):
|
||||
message: str
|
||||
slack_bot_categories: list[str]
|
||||
|
||||
|
||||
class StandardAnswerResponse(BaseModel):
|
||||
standard_answers: list[StandardAnswer] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DocumentSearchRequest(BasicChunkRequest):
|
||||
user_selected_filters: BaseFilters | None = None
|
||||
|
||||
|
||||
class DocumentSearchResponse(BaseModel):
|
||||
top_documents: list[InferenceChunk]
|
||||
|
||||
|
||||
class BasicCreateChatMessageRequest(ChunkContext):
|
||||
@@ -71,17 +90,17 @@ class SimpleDoc(BaseModel):
|
||||
metadata: dict | None
|
||||
|
||||
|
||||
class AgentSubQuestion(SubQuestionIdentifier):
|
||||
class AgentSubQuestion(BaseModel):
|
||||
sub_question: str
|
||||
document_ids: list[str]
|
||||
|
||||
|
||||
class AgentAnswer(SubQuestionIdentifier):
|
||||
class AgentAnswer(BaseModel):
|
||||
answer: str
|
||||
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
|
||||
|
||||
|
||||
class AgentSubQuery(SubQuestionIdentifier):
|
||||
class AgentSubQuery(BaseModel):
|
||||
sub_query: str
|
||||
query_id: int
|
||||
|
||||
@@ -127,12 +146,3 @@ class AgentSubQuery(SubQuestionIdentifier):
|
||||
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
|
||||
)
|
||||
return sorted_dict
|
||||
|
||||
|
||||
class StandardAnswerRequest(BaseModel):
|
||||
message: str
|
||||
slack_bot_categories: list[str]
|
||||
|
||||
|
||||
class StandardAnswerResponse(BaseModel):
|
||||
standard_answers: list[StandardAnswer] = Field(default_factory=list)
|
||||
|
||||
@@ -24,7 +24,7 @@ from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_display_email
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.background.task_utils import construct_query_history_report_name
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.chat_utils import create_chat_history_chain
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import FileType
|
||||
@@ -123,10 +123,9 @@ def snapshot_from_chat_session(
|
||||
) -> ChatSessionSnapshot | None:
|
||||
try:
|
||||
# Older chats may not have the right structure
|
||||
last_message, messages = create_chat_chain(
|
||||
messages = create_chat_history_chain(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
)
|
||||
messages.append(last_message)
|
||||
except RuntimeError:
|
||||
return None
|
||||
|
||||
|
||||
@@ -38,10 +38,8 @@ from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import UserTenantMapping
|
||||
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.llm_provider_options import ANTHROPIC_VISIBLE_MODEL_NAMES
|
||||
from onyx.llm.llm_provider_options import get_anthropic_model_names
|
||||
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
|
||||
from onyx.llm.llm_provider_options import OPEN_AI_VISIBLE_MODEL_NAMES
|
||||
from onyx.llm.llm_provider_options import get_openai_model_names
|
||||
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
@@ -275,7 +273,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=name,
|
||||
is_visible=name in ANTHROPIC_VISIBLE_MODEL_NAMES,
|
||||
is_visible=False,
|
||||
max_input_tokens=None,
|
||||
)
|
||||
for name in get_anthropic_model_names()
|
||||
@@ -302,10 +300,10 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
model_configurations=[
|
||||
ModelConfigurationUpsertRequest(
|
||||
name=model_name,
|
||||
is_visible=model_name in OPEN_AI_VISIBLE_MODEL_NAMES,
|
||||
is_visible=False,
|
||||
max_input_tokens=None,
|
||||
)
|
||||
for model_name in OPEN_AI_MODEL_NAMES
|
||||
for model_name in get_openai_model_names()
|
||||
],
|
||||
api_key_changed=True,
|
||||
)
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import FunctionMessage
|
||||
|
||||
from onyx.llm.message_types import AssistantMessage
|
||||
from onyx.llm.message_types import ChatCompletionMessage
|
||||
from onyx.llm.message_types import FunctionCall
|
||||
from onyx.llm.message_types import SystemMessage
|
||||
from onyx.llm.message_types import ToolCall
|
||||
from onyx.llm.message_types import ToolMessage
|
||||
from onyx.llm.message_types import UserMessageWithText
|
||||
|
||||
|
||||
HUMAN = "human"
|
||||
SYSTEM = "system"
|
||||
AI = "ai"
|
||||
FUNCTION = "function"
|
||||
|
||||
|
||||
def base_messages_to_chat_completion_msgs(
|
||||
msgs: Sequence[BaseMessage],
|
||||
) -> list[ChatCompletionMessage]:
|
||||
return [_base_message_to_chat_completion_msg(msg) for msg in msgs]
|
||||
|
||||
|
||||
def _base_message_to_chat_completion_msg(
|
||||
msg: BaseMessage,
|
||||
) -> ChatCompletionMessage:
|
||||
if msg.type == HUMAN:
|
||||
content = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
user_msg: UserMessageWithText = {"role": "user", "content": content}
|
||||
return user_msg
|
||||
if msg.type == SYSTEM:
|
||||
content = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
system_msg: SystemMessage = {"role": "system", "content": content}
|
||||
return system_msg
|
||||
if msg.type == AI:
|
||||
content = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
assistant_msg: AssistantMessage = {
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
}
|
||||
if isinstance(msg, AIMessage) and msg.tool_calls:
|
||||
assistant_msg["tool_calls"] = [
|
||||
ToolCall(
|
||||
id=tool_call.get("id") or "",
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=tool_call["name"],
|
||||
arguments=json.dumps(tool_call["args"]),
|
||||
),
|
||||
)
|
||||
for tool_call in msg.tool_calls
|
||||
]
|
||||
return assistant_msg
|
||||
if msg.type == FUNCTION:
|
||||
function_message = cast(FunctionMessage, msg)
|
||||
content = (
|
||||
function_message.content
|
||||
if isinstance(function_message.content, str)
|
||||
else str(function_message.content)
|
||||
)
|
||||
tool_msg: ToolMessage = {
|
||||
"role": "tool",
|
||||
"content": content,
|
||||
"tool_call_id": function_message.name or "",
|
||||
}
|
||||
return tool_msg
|
||||
raise ValueError(f"Unexpected message type: {msg.type}")
|
||||
@@ -1,47 +0,0 @@
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import TypeAlias
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.llm.model_response import ModelResponseStream
|
||||
|
||||
|
||||
class ToolCallStreamItem(BaseModel):
|
||||
call_id: str | None = None
|
||||
|
||||
id: str | None = None
|
||||
|
||||
name: str | None = None
|
||||
|
||||
arguments: str | None = None
|
||||
|
||||
type: Literal["function_call"] = "function_call"
|
||||
|
||||
index: int | None = None
|
||||
|
||||
|
||||
class ToolCallOutputStreamItem(BaseModel):
|
||||
call_id: str | None = None
|
||||
|
||||
output: Any
|
||||
|
||||
type: Literal["function_call_output"] = "function_call_output"
|
||||
|
||||
|
||||
RunItemStreamEventDetails: TypeAlias = ToolCallStreamItem | ToolCallOutputStreamItem
|
||||
|
||||
|
||||
class RunItemStreamEvent(BaseModel):
|
||||
type: Literal[
|
||||
"message_start",
|
||||
"message_done",
|
||||
"reasoning_start",
|
||||
"reasoning_done",
|
||||
"tool_call",
|
||||
"tool_call_output",
|
||||
]
|
||||
details: RunItemStreamEventDetails | None = None
|
||||
|
||||
|
||||
StreamEvent: TypeAlias = ModelResponseStream | RunItemStreamEvent
|
||||
@@ -1,365 +0,0 @@
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import onyx.tracing.framework._error_tracing as _error_tracing
|
||||
from onyx.agents.agent_framework.models import RunItemStreamEvent
|
||||
from onyx.agents.agent_framework.models import StreamEvent
|
||||
from onyx.agents.agent_framework.models import ToolCallOutputStreamItem
|
||||
from onyx.agents.agent_framework.models import ToolCallStreamItem
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.message_types import ChatCompletionMessage
|
||||
from onyx.llm.message_types import ToolCall
|
||||
from onyx.llm.model_response import ModelResponseStream
|
||||
from onyx.tools.tool import RunContextWrapper
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tracing.framework.create import agent_span
|
||||
from onyx.tracing.framework.create import function_span
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.tracing.framework.spans import SpanError
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryResult:
|
||||
stream: Iterator[StreamEvent]
|
||||
new_messages_stateful: list[ChatCompletionMessage]
|
||||
|
||||
|
||||
def _serialize_tool_output(output: Any) -> str:
|
||||
if isinstance(output, str):
|
||||
return output
|
||||
try:
|
||||
return json.dumps(output)
|
||||
except TypeError:
|
||||
return str(output)
|
||||
|
||||
|
||||
def _parse_tool_calls_from_message_content(
|
||||
content: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Parse JSON content that represents tool call instructions."""
|
||||
try:
|
||||
parsed_content = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
if isinstance(parsed_content, dict):
|
||||
candidates = [parsed_content]
|
||||
elif isinstance(parsed_content, list):
|
||||
candidates = [item for item in parsed_content if isinstance(item, dict)]
|
||||
else:
|
||||
return []
|
||||
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
for candidate in candidates:
|
||||
name = candidate.get("name")
|
||||
arguments = candidate.get("arguments")
|
||||
|
||||
if not isinstance(name, str) or arguments is None:
|
||||
continue
|
||||
|
||||
if not isinstance(arguments, dict):
|
||||
continue
|
||||
|
||||
call_id = candidate.get("id")
|
||||
arguments_str = json.dumps(arguments)
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": call_id,
|
||||
"name": name,
|
||||
"arguments": arguments_str,
|
||||
}
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _try_convert_content_to_tool_calls_for_non_tool_calling_llms(
|
||||
tool_calls_in_progress: dict[int, dict[str, Any]],
|
||||
content_parts: list[str],
|
||||
structured_response_format: dict | None,
|
||||
next_synthetic_tool_call_id: Callable[[], str],
|
||||
) -> None:
|
||||
"""Populate tool_calls_in_progress when a non-tool-calling LLM returns JSON content describing tool calls."""
|
||||
if tool_calls_in_progress or not content_parts or structured_response_format:
|
||||
return
|
||||
|
||||
tool_calls_from_content = _parse_tool_calls_from_message_content(
|
||||
"".join(content_parts)
|
||||
)
|
||||
|
||||
if not tool_calls_from_content:
|
||||
return
|
||||
|
||||
content_parts.clear()
|
||||
|
||||
for index, tool_call_data in enumerate(tool_calls_from_content):
|
||||
call_id = tool_call_data["id"] or next_synthetic_tool_call_id()
|
||||
tool_calls_in_progress[index] = {
|
||||
"id": call_id,
|
||||
"name": tool_call_data["name"],
|
||||
"arguments": tool_call_data["arguments"],
|
||||
}
|
||||
|
||||
|
||||
def _update_tool_call_with_delta(
|
||||
tool_calls_in_progress: dict[int, dict[str, Any]],
|
||||
tool_call_delta: Any,
|
||||
) -> None:
|
||||
index = tool_call_delta.index
|
||||
|
||||
if index not in tool_calls_in_progress:
|
||||
tool_calls_in_progress[index] = {
|
||||
"id": None,
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
}
|
||||
|
||||
if tool_call_delta.id:
|
||||
tool_calls_in_progress[index]["id"] = tool_call_delta.id
|
||||
|
||||
if tool_call_delta.function:
|
||||
if tool_call_delta.function.name:
|
||||
tool_calls_in_progress[index]["name"] = tool_call_delta.function.name
|
||||
|
||||
if tool_call_delta.function.arguments:
|
||||
tool_calls_in_progress[index][
|
||||
"arguments"
|
||||
] += tool_call_delta.function.arguments
|
||||
|
||||
|
||||
def query(
|
||||
llm_with_default_settings: LLM,
|
||||
messages: LanguageModelInput,
|
||||
tools: Sequence[Tool],
|
||||
context: Any,
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
) -> QueryResult:
|
||||
tool_definitions = [tool.tool_definition() for tool in tools]
|
||||
tools_by_name = {tool.name: tool for tool in tools}
|
||||
|
||||
new_messages_stateful: list[ChatCompletionMessage] = []
|
||||
|
||||
current_span = agent_span(
|
||||
name="agent_framework_query",
|
||||
output_type="dict" if structured_response_format else "str",
|
||||
)
|
||||
current_span.start(mark_as_current=True)
|
||||
current_span.span_data.tools = [t.name for t in tools]
|
||||
|
||||
def stream_generator() -> Iterator[StreamEvent]:
|
||||
message_started = False
|
||||
reasoning_started = False
|
||||
|
||||
tool_calls_in_progress: dict[int, dict[str, Any]] = {}
|
||||
|
||||
content_parts: list[str] = []
|
||||
|
||||
synthetic_tool_call_counter = 0
|
||||
|
||||
def _next_synthetic_tool_call_id() -> str:
|
||||
nonlocal synthetic_tool_call_counter
|
||||
call_id = f"synthetic_tool_call_{synthetic_tool_call_counter}"
|
||||
synthetic_tool_call_counter += 1
|
||||
return call_id
|
||||
|
||||
with generation_span( # type: ignore[misc]
|
||||
model=llm_with_default_settings.config.model_name,
|
||||
model_config={
|
||||
"base_url": str(llm_with_default_settings.config.api_base or ""),
|
||||
"model_impl": "litellm",
|
||||
},
|
||||
) as span_generation:
|
||||
# Only set input if messages is a sequence (not a string)
|
||||
# ChatCompletionMessage TypedDicts are compatible with Mapping[str, Any] at runtime
|
||||
if isinstance(messages, Sequence) and not isinstance(messages, str):
|
||||
# Convert ChatCompletionMessage sequence to Sequence[Mapping[str, Any]]
|
||||
span_generation.span_data.input = [dict(msg) for msg in messages] # type: ignore[assignment]
|
||||
for chunk in llm_with_default_settings.stream(
|
||||
prompt=messages,
|
||||
tools=tool_definitions,
|
||||
tool_choice=tool_choice,
|
||||
structured_response_format=structured_response_format,
|
||||
):
|
||||
assert isinstance(chunk, ModelResponseStream)
|
||||
usage = getattr(chunk, "usage", None)
|
||||
if usage:
|
||||
span_generation.span_data.usage = {
|
||||
"input_tokens": usage.prompt_tokens,
|
||||
"output_tokens": usage.completion_tokens,
|
||||
"cache_read_input_tokens": usage.cache_read_input_tokens,
|
||||
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
|
||||
}
|
||||
|
||||
delta = chunk.choice.delta
|
||||
finish_reason = chunk.choice.finish_reason
|
||||
|
||||
if delta.reasoning_content:
|
||||
if not reasoning_started:
|
||||
yield RunItemStreamEvent(type="reasoning_start")
|
||||
reasoning_started = True
|
||||
|
||||
if delta.content:
|
||||
if reasoning_started:
|
||||
yield RunItemStreamEvent(type="reasoning_done")
|
||||
reasoning_started = False
|
||||
content_parts.append(delta.content)
|
||||
if not message_started:
|
||||
yield RunItemStreamEvent(type="message_start")
|
||||
message_started = True
|
||||
|
||||
if delta.tool_calls:
|
||||
if reasoning_started:
|
||||
yield RunItemStreamEvent(type="reasoning_done")
|
||||
reasoning_started = False
|
||||
if message_started:
|
||||
yield RunItemStreamEvent(type="message_done")
|
||||
message_started = False
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
_update_tool_call_with_delta(
|
||||
tool_calls_in_progress, tool_call_delta
|
||||
)
|
||||
|
||||
yield chunk
|
||||
|
||||
if not finish_reason:
|
||||
continue
|
||||
|
||||
if reasoning_started:
|
||||
yield RunItemStreamEvent(type="reasoning_done")
|
||||
reasoning_started = False
|
||||
if message_started:
|
||||
yield RunItemStreamEvent(type="message_done")
|
||||
message_started = False
|
||||
|
||||
if tool_choice != "none":
|
||||
_try_convert_content_to_tool_calls_for_non_tool_calling_llms(
|
||||
tool_calls_in_progress,
|
||||
content_parts,
|
||||
structured_response_format,
|
||||
_next_synthetic_tool_call_id,
|
||||
)
|
||||
|
||||
if content_parts:
|
||||
new_messages_stateful.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "".join(content_parts),
|
||||
}
|
||||
)
|
||||
span_generation.span_data.output = new_messages_stateful
|
||||
|
||||
# Execute tool calls outside of the stream loop and generation_span
|
||||
if tool_calls_in_progress:
|
||||
sorted_tool_calls = sorted(tool_calls_in_progress.items())
|
||||
|
||||
# Build tool calls for the message and execute tools
|
||||
assistant_tool_calls: list[ToolCall] = []
|
||||
tool_outputs: dict[str, str] = {}
|
||||
|
||||
for _, tool_call_data in sorted_tool_calls:
|
||||
call_id = tool_call_data["id"]
|
||||
name = tool_call_data["name"]
|
||||
arguments_str = tool_call_data["arguments"]
|
||||
|
||||
if call_id is None or name is None:
|
||||
continue
|
||||
|
||||
assistant_tool_calls.append(
|
||||
{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"arguments": arguments_str,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
yield RunItemStreamEvent(
|
||||
type="tool_call",
|
||||
details=ToolCallStreamItem(
|
||||
call_id=call_id,
|
||||
name=name,
|
||||
arguments=arguments_str,
|
||||
),
|
||||
)
|
||||
|
||||
if name in tools_by_name:
|
||||
tool = tools_by_name[name]
|
||||
arguments = json.loads(arguments_str)
|
||||
|
||||
run_context = RunContextWrapper(context=context)
|
||||
|
||||
# TODO: Instead of executing sequentially, execute in parallel
|
||||
# In practice, it's not a must right now since we don't use parallel
|
||||
# tool calls, so kicking the can down the road for now.
|
||||
with function_span(tool.name) as span_fn:
|
||||
span_fn.span_data.input = arguments
|
||||
try:
|
||||
output = tool.run_v2(run_context, **arguments)
|
||||
tool_outputs[call_id] = _serialize_tool_output(output)
|
||||
span_fn.span_data.output = output
|
||||
except Exception as e:
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Error running tool",
|
||||
data={"tool_name": tool.name, "error": str(e)},
|
||||
)
|
||||
)
|
||||
# Treat the error as the tool output so the framework can continue
|
||||
error_output = f"Error: {str(e)}"
|
||||
tool_outputs[call_id] = error_output
|
||||
output = error_output
|
||||
|
||||
yield RunItemStreamEvent(
|
||||
type="tool_call_output",
|
||||
details=ToolCallOutputStreamItem(
|
||||
call_id=call_id,
|
||||
output=output,
|
||||
),
|
||||
)
|
||||
else:
|
||||
not_found_output = f"Tool {name} not found"
|
||||
tool_outputs[call_id] = _serialize_tool_output(not_found_output)
|
||||
yield RunItemStreamEvent(
|
||||
type="tool_call_output",
|
||||
details=ToolCallOutputStreamItem(
|
||||
call_id=call_id,
|
||||
output=not_found_output,
|
||||
),
|
||||
)
|
||||
|
||||
new_messages_stateful.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": assistant_tool_calls,
|
||||
}
|
||||
)
|
||||
|
||||
for _, tool_call_data in sorted_tool_calls:
|
||||
call_id = tool_call_data["id"]
|
||||
|
||||
if call_id in tool_outputs:
|
||||
new_messages_stateful.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": tool_outputs[call_id],
|
||||
"tool_call_id": call_id,
|
||||
}
|
||||
)
|
||||
current_span.finish(reset_current=True)
|
||||
|
||||
return QueryResult(
|
||||
stream=stream_generator(),
|
||||
new_messages_stateful=new_messages_stateful,
|
||||
)
|
||||
@@ -1,167 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
|
||||
from onyx.agents.agent_sdk.message_types import AgentSDKMessage
|
||||
from onyx.agents.agent_sdk.message_types import AssistantMessageWithContent
|
||||
from onyx.agents.agent_sdk.message_types import ImageContent
|
||||
from onyx.agents.agent_sdk.message_types import InputTextContent
|
||||
from onyx.agents.agent_sdk.message_types import SystemMessage
|
||||
from onyx.agents.agent_sdk.message_types import UserMessage
|
||||
|
||||
|
||||
# TODO: Currently, we only support native API input for images. For other
|
||||
# files, we process the content and share it as text in the message. In
|
||||
# the future, we might support native file uploads for other types of files.
|
||||
def base_messages_to_agent_sdk_msgs(
|
||||
msgs: Sequence[BaseMessage],
|
||||
is_responses_api: bool,
|
||||
) -> list[AgentSDKMessage]:
|
||||
return [_base_message_to_agent_sdk_msg(msg, is_responses_api) for msg in msgs]
|
||||
|
||||
|
||||
def _base_message_to_agent_sdk_msg(
|
||||
msg: BaseMessage, is_responses_api: bool
|
||||
) -> AgentSDKMessage:
|
||||
message_type_to_agent_sdk_role = {
|
||||
"human": "user",
|
||||
"system": "system",
|
||||
"ai": "assistant",
|
||||
}
|
||||
role = message_type_to_agent_sdk_role[msg.type]
|
||||
|
||||
# Convert content to Agent SDK format
|
||||
content = msg.content
|
||||
|
||||
if isinstance(content, str):
|
||||
# For system/user/assistant messages, use InputTextContent
|
||||
if role in ("system", "user"):
|
||||
input_text_content: list[InputTextContent | ImageContent] = [
|
||||
InputTextContent(type="input_text", text=content)
|
||||
]
|
||||
if role == "system":
|
||||
# SystemMessage only accepts InputTextContent
|
||||
system_msg: SystemMessage = {
|
||||
"role": "system",
|
||||
"content": [InputTextContent(type="input_text", text=content)],
|
||||
}
|
||||
return system_msg
|
||||
else: # user
|
||||
user_msg: UserMessage = {
|
||||
"role": "user",
|
||||
"content": input_text_content,
|
||||
}
|
||||
return user_msg
|
||||
else: # assistant
|
||||
assistant_msg: AssistantMessageWithContent
|
||||
if is_responses_api:
|
||||
from onyx.agents.agent_sdk.message_types import OutputTextContent
|
||||
|
||||
assistant_msg = {
|
||||
"role": "assistant",
|
||||
"content": [OutputTextContent(type="output_text", text=content)],
|
||||
}
|
||||
else:
|
||||
assistant_msg = {
|
||||
"role": "assistant",
|
||||
"content": [InputTextContent(type="input_text", text=content)],
|
||||
}
|
||||
return assistant_msg
|
||||
elif isinstance(content, list):
|
||||
# For lists, we need to process based on the role
|
||||
if role == "assistant":
|
||||
# For responses API, use OutputTextContent; otherwise use InputTextContent
|
||||
assistant_content: list[InputTextContent | OutputTextContent] = []
|
||||
|
||||
if is_responses_api:
|
||||
from onyx.agents.agent_sdk.message_types import OutputTextContent
|
||||
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
assistant_content.append(
|
||||
OutputTextContent(type="output_text", text=item)
|
||||
)
|
||||
elif isinstance(item, dict) and item.get("type") == "text":
|
||||
assistant_content.append(
|
||||
OutputTextContent(
|
||||
type="output_text", text=item.get("text", "")
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected item type for assistant message: {type(item)}. Item: {item}"
|
||||
)
|
||||
else:
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
assistant_content.append(
|
||||
InputTextContent(type="input_text", text=item)
|
||||
)
|
||||
elif isinstance(item, dict) and item.get("type") == "text":
|
||||
assistant_content.append(
|
||||
InputTextContent(
|
||||
type="input_text", text=item.get("text", "")
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected item type for assistant message: {type(item)}. Item: {item}"
|
||||
)
|
||||
|
||||
assistant_msg_list: AssistantMessageWithContent = {
|
||||
"role": "assistant",
|
||||
"content": assistant_content,
|
||||
}
|
||||
return assistant_msg_list
|
||||
else: # system or user - use InputTextContent
|
||||
input_content: list[InputTextContent | ImageContent] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
input_content.append(InputTextContent(type="input_text", text=item))
|
||||
elif isinstance(item, dict):
|
||||
item_type = item.get("type")
|
||||
if item_type == "text":
|
||||
input_content.append(
|
||||
InputTextContent(
|
||||
type="input_text", text=item.get("text", "")
|
||||
)
|
||||
)
|
||||
elif item_type == "image_url":
|
||||
# Convert image_url to input_image format
|
||||
image_url = item.get("image_url", {})
|
||||
if isinstance(image_url, dict):
|
||||
url = image_url.get("url", "")
|
||||
else:
|
||||
url = image_url
|
||||
input_content.append(
|
||||
ImageContent(
|
||||
type="input_image", image_url=url, detail="auto"
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected item type: {item_type}")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected item type: {type(item)}. Item: {item}"
|
||||
)
|
||||
|
||||
if role == "system":
|
||||
# SystemMessage only accepts InputTextContent (no images)
|
||||
text_only_content = [
|
||||
c for c in input_content if c["type"] == "input_text"
|
||||
]
|
||||
system_msg_list: SystemMessage = {
|
||||
"role": "system",
|
||||
"content": text_only_content, # type: ignore[typeddict-item]
|
||||
}
|
||||
return system_msg_list
|
||||
else: # user
|
||||
user_msg_list: UserMessage = {
|
||||
"role": "user",
|
||||
"content": input_content,
|
||||
}
|
||||
return user_msg_list
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected content type: {type(content)}. Content: {content}"
|
||||
)
|
||||
@@ -1,125 +0,0 @@
|
||||
"""Strongly typed message structures for Agent SDK messages."""
|
||||
|
||||
from typing import Literal
|
||||
from typing import NotRequired
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class InputTextContent(TypedDict):
|
||||
type: Literal["input_text"]
|
||||
text: str
|
||||
|
||||
|
||||
class OutputTextContent(TypedDict):
|
||||
type: Literal["output_text"]
|
||||
text: str
|
||||
|
||||
|
||||
TextContent = InputTextContent | OutputTextContent
|
||||
|
||||
|
||||
class ImageContent(TypedDict):
|
||||
type: Literal["input_image"]
|
||||
image_url: str
|
||||
detail: str
|
||||
|
||||
|
||||
# Tool call structures
|
||||
class ToolCallFunction(TypedDict):
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class ToolCall(TypedDict):
|
||||
id: str
|
||||
type: Literal["function"]
|
||||
function: ToolCallFunction
|
||||
|
||||
|
||||
# Message types
|
||||
class SystemMessage(TypedDict):
|
||||
role: Literal["system"]
|
||||
content: list[InputTextContent] # System messages use input text
|
||||
|
||||
|
||||
class UserMessage(TypedDict):
|
||||
role: Literal["user"]
|
||||
content: list[
|
||||
InputTextContent | ImageContent
|
||||
] # User messages use input text or images
|
||||
|
||||
|
||||
class AssistantMessageWithContent(TypedDict):
|
||||
role: Literal["assistant"]
|
||||
content: list[
|
||||
InputTextContent | OutputTextContent
|
||||
] # Assistant messages use output_text for responses API compatibility
|
||||
|
||||
|
||||
class AssistantMessageWithToolCalls(TypedDict):
|
||||
role: Literal["assistant"]
|
||||
tool_calls: list[ToolCall]
|
||||
|
||||
|
||||
class AssistantMessageDuringAgentRun(TypedDict):
|
||||
role: Literal["assistant"]
|
||||
id: str
|
||||
content: (
|
||||
list[InputTextContent | OutputTextContent] | list[ToolCall]
|
||||
) # Assistant runtime messages receive output_text from agents SDK for responses API compatibility
|
||||
status: Literal["completed", "failed", "in_progress"]
|
||||
type: Literal["message"]
|
||||
|
||||
|
||||
class ToolMessage(TypedDict):
|
||||
role: Literal["tool"]
|
||||
content: str
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class FunctionCallMessage(TypedDict):
|
||||
"""Agent SDK function call message format."""
|
||||
|
||||
type: Literal["function_call"]
|
||||
id: NotRequired[str]
|
||||
call_id: str
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
class FunctionCallOutputMessage(TypedDict):
|
||||
"""Agent SDK function call output message format."""
|
||||
|
||||
type: Literal["function_call_output"]
|
||||
call_id: str
|
||||
output: str
|
||||
|
||||
|
||||
class SummaryText(TypedDict):
|
||||
"""Summary text item in reasoning messages."""
|
||||
|
||||
text: str
|
||||
type: Literal["summary_text"]
|
||||
|
||||
|
||||
class ReasoningMessage(TypedDict):
|
||||
"""Agent SDK reasoning message format."""
|
||||
|
||||
id: str
|
||||
type: Literal["reasoning"]
|
||||
summary: list[SummaryText]
|
||||
|
||||
|
||||
# Union type for all Agent SDK messages
|
||||
AgentSDKMessage = (
|
||||
SystemMessage
|
||||
| UserMessage
|
||||
| AssistantMessageWithContent
|
||||
| AssistantMessageWithToolCalls
|
||||
| AssistantMessageDuringAgentRun
|
||||
| ToolMessage
|
||||
| FunctionCallMessage
|
||||
| FunctionCallOutputMessage
|
||||
| ReasoningMessage
|
||||
)
|
||||
@@ -1,36 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from agents.models.openai_responses import Converter as OpenAIResponsesConverter
|
||||
|
||||
|
||||
# TODO: I am very sad that I have to monkey patch this :(
|
||||
# Basically, OpenAI agents sdk doesn't convert the tool choice correctly
|
||||
# when they have a built-in tool in their framework, like they do for web_search
|
||||
# and image_generation.
|
||||
# Going to open up a thread with OpenAI agents team to see what they recommend
|
||||
# or what we can fix.
|
||||
# A discussion is warranted, but we likely want to just write our own LitellmModel for
|
||||
# the OpenAI agents SDK since they probably don't really care about Litellm and will
|
||||
# prioritize functionality for their own models.
|
||||
def monkey_patch_convert_tool_choice_to_ignore_openai_hosted_web_search() -> None:
|
||||
if (
|
||||
getattr(OpenAIResponsesConverter.convert_tool_choice, "__name__", "")
|
||||
== "_patched_convert_tool_choice"
|
||||
):
|
||||
return
|
||||
|
||||
orig_func = OpenAIResponsesConverter.convert_tool_choice.__func__ # type: ignore[attr-defined]
|
||||
|
||||
def _patched_convert_tool_choice(cls: type, tool_choice: Any) -> Any:
|
||||
# Handle OpenAI hosted tools that we have custom implementations for
|
||||
# Without this patch, the library uses special formatting that breaks our custom tools
|
||||
# See: https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice-hosted_tool-type
|
||||
if tool_choice == "web_search":
|
||||
return {"type": "function", "name": "web_search"}
|
||||
if tool_choice == "image_generation":
|
||||
return {"type": "function", "name": "image_generation"}
|
||||
return orig_func(cls, tool_choice)
|
||||
|
||||
OpenAIResponsesConverter.convert_tool_choice = classmethod( # type: ignore[method-assign, assignment]
|
||||
_patched_convert_tool_choice
|
||||
)
|
||||
@@ -1,178 +0,0 @@
|
||||
import asyncio
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from typing import Generic
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
from agents import Agent
|
||||
from agents import RunResultStreaming
|
||||
from agents import TContext
|
||||
from agents.run import Runner
|
||||
|
||||
from onyx.agents.agent_sdk.message_types import AgentSDKMessage
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class SyncAgentStream(Generic[T]):
|
||||
"""
|
||||
Convert an async streamed run into a sync iterator with cooperative cancellation.
|
||||
Runs the Agent in a background thread.
|
||||
|
||||
Usage:
|
||||
adapter = SyncStreamAdapter(
|
||||
agent=agent,
|
||||
input=input,
|
||||
context=context,
|
||||
max_turns=100,
|
||||
queue_maxsize=0, # optional backpressure
|
||||
)
|
||||
for ev in adapter: # sync iteration
|
||||
...
|
||||
# or cancel from elsewhere:
|
||||
adapter.cancel()
|
||||
"""
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
agent: Agent,
|
||||
input: Sequence[AgentSDKMessage],
|
||||
context: TContext | None = None,
|
||||
max_turns: int = 100,
|
||||
queue_maxsize: int = 0,
|
||||
) -> None:
|
||||
self._agent = agent
|
||||
self._input = input
|
||||
self._context = context
|
||||
self._max_turns = max_turns
|
||||
|
||||
self._q: "queue.Queue[object]" = queue.Queue(maxsize=queue_maxsize)
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self.streamed: RunResultStreaming | None = None
|
||||
self._exc: Optional[BaseException] = None
|
||||
self._cancel_requested = threading.Event()
|
||||
self._started = threading.Event()
|
||||
self._done = threading.Event()
|
||||
|
||||
self._start_thread()
|
||||
|
||||
# ---------- public sync API ----------
|
||||
|
||||
def __iter__(self) -> Iterator[T]:
|
||||
try:
|
||||
while True:
|
||||
item = self._q.get()
|
||||
if item is self._SENTINEL:
|
||||
# If the consumer thread raised, surface it now
|
||||
if self._exc is not None:
|
||||
raise self._exc
|
||||
# Normal completion
|
||||
return
|
||||
yield item # type: ignore[misc,return-value]
|
||||
finally:
|
||||
# Ensure we fully clean up whether we exited due to exception,
|
||||
# StopIteration, or external cancel.
|
||||
self.close()
|
||||
|
||||
def cancel(self) -> bool:
|
||||
"""
|
||||
Cooperatively cancel the underlying streamed run and shut down.
|
||||
Safe to call multiple times and from any thread.
|
||||
"""
|
||||
self._cancel_requested.set()
|
||||
loop = self._loop
|
||||
streamed = self.streamed
|
||||
if loop is not None and streamed is not None and not self._done.is_set():
|
||||
loop.call_soon_threadsafe(streamed.cancel)
|
||||
return True
|
||||
return False
|
||||
|
||||
def close(self, *, wait: bool = True) -> None:
|
||||
"""Idempotent shutdown."""
|
||||
self.cancel()
|
||||
# ask the loop to stop if it's still running
|
||||
loop = self._loop
|
||||
if loop is not None and loop.is_running():
|
||||
try:
|
||||
loop.call_soon_threadsafe(loop.stop)
|
||||
except Exception:
|
||||
pass
|
||||
# join the thread
|
||||
if wait and self._thread is not None and self._thread.is_alive():
|
||||
self._thread.join(timeout=5.0)
|
||||
|
||||
# ---------- internals ----------
|
||||
|
||||
def _start_thread(self) -> None:
|
||||
t = run_in_background(self._thread_main)
|
||||
self._thread = t
|
||||
# Optionally wait until the loop/worker is started so .cancel() is safe soon after init
|
||||
self._started.wait(timeout=1.0)
|
||||
|
||||
def _thread_main(self) -> None:
|
||||
loop = asyncio.new_event_loop()
|
||||
self._loop = loop
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
async def worker() -> None:
|
||||
try:
|
||||
# Start the streamed run inside the loop thread
|
||||
self.streamed = Runner.run_streamed(
|
||||
self._agent,
|
||||
self._input, # type: ignore[arg-type]
|
||||
context=self._context,
|
||||
max_turns=self._max_turns,
|
||||
)
|
||||
|
||||
# If cancel was requested before we created _streamed, honor it now
|
||||
if self._cancel_requested.is_set():
|
||||
await self.streamed.cancel() # type: ignore[func-returns-value]
|
||||
|
||||
# Consume async events and forward into the thread-safe queue
|
||||
async for ev in self.streamed.stream_events():
|
||||
# Early exit if a late cancel arrives
|
||||
if self._cancel_requested.is_set():
|
||||
# Try to cancel gracefully; don't break until cancel takes effect
|
||||
try:
|
||||
await self.streamed.cancel() # type: ignore[func-returns-value]
|
||||
except Exception:
|
||||
pass
|
||||
break
|
||||
# This put() may block if queue_maxsize > 0 (backpressure)
|
||||
self._q.put(ev)
|
||||
|
||||
except BaseException as e:
|
||||
# Save exception to surface on the sync iterator side
|
||||
self._exc = e
|
||||
finally:
|
||||
# Signal end-of-stream
|
||||
self._q.put(self._SENTINEL)
|
||||
self._done.set()
|
||||
|
||||
# Mark started and run the worker to completion
|
||||
self._started.set()
|
||||
try:
|
||||
loop.run_until_complete(worker())
|
||||
finally:
|
||||
try:
|
||||
# Drain pending tasks/callbacks safely
|
||||
pending = asyncio.all_tasks(loop=loop)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
if pending:
|
||||
loop.run_until_complete(
|
||||
asyncio.gather(*pending, return_exceptions=True)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
self._loop = None
|
||||
@@ -1,21 +0,0 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CoreState(BaseModel):
|
||||
"""
|
||||
This is the core state that is shared across all subgraphs.
|
||||
"""
|
||||
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
current_step_nr: int = 1
|
||||
|
||||
|
||||
class SubgraphCoreState(BaseModel):
|
||||
"""
|
||||
This is the core state that is shared across all subgraphs.
|
||||
"""
|
||||
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
@@ -1,62 +0,0 @@
|
||||
from collections.abc import Hashable
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import (
|
||||
ObjectResearchInformationUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import (
|
||||
SearchSourcesObjectsUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
|
||||
|
||||
def parallel_object_source_research_edge(
|
||||
state: SearchSourcesObjectsUpdate, config: RunnableConfig
|
||||
) -> list[Send | Hashable]:
|
||||
"""
|
||||
LangGraph edge to parallelize the research for an individual object and source
|
||||
"""
|
||||
|
||||
search_objects = state.analysis_objects
|
||||
search_sources = state.analysis_sources
|
||||
|
||||
object_source_combinations = [
|
||||
(object, source) for object in search_objects for source in search_sources
|
||||
]
|
||||
|
||||
return [
|
||||
Send(
|
||||
"research_object_source",
|
||||
ObjectSourceInput(
|
||||
object_source_combination=object_source_combination,
|
||||
log_messages=[],
|
||||
),
|
||||
)
|
||||
for object_source_combination in object_source_combinations
|
||||
]
|
||||
|
||||
|
||||
def parallel_object_research_consolidation_edge(
|
||||
state: ObjectResearchInformationUpdate, config: RunnableConfig
|
||||
) -> list[Send | Hashable]:
|
||||
"""
|
||||
LangGraph edge to parallelize the research for an individual object and source
|
||||
"""
|
||||
cast(GraphConfig, config["metadata"]["config"])
|
||||
object_research_information_results = state.object_research_information_results
|
||||
|
||||
return [
|
||||
Send(
|
||||
"consolidate_object_research",
|
||||
ObjectInformationInput(
|
||||
object_information=object_information,
|
||||
log_messages=[],
|
||||
),
|
||||
)
|
||||
for object_information in object_research_information_results
|
||||
]
|
||||
@@ -1,103 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dc_search_analysis.edges import (
|
||||
parallel_object_research_consolidation_edge,
|
||||
)
|
||||
from onyx.agents.agent_search.dc_search_analysis.edges import (
|
||||
parallel_object_source_research_edge,
|
||||
)
|
||||
from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
|
||||
search_objects,
|
||||
)
|
||||
from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
|
||||
research_object_source,
|
||||
)
|
||||
from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
|
||||
structure_research_by_object,
|
||||
)
|
||||
from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
|
||||
consolidate_object_research,
|
||||
)
|
||||
from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
|
||||
consolidate_research,
|
||||
)
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import MainInput
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import MainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
test_mode = False
|
||||
|
||||
|
||||
def divide_and_conquer_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for the knowledge graph search process.
|
||||
"""
|
||||
|
||||
graph = StateGraph(
|
||||
state_schema=MainState,
|
||||
input=MainInput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node(
|
||||
"search_objects",
|
||||
search_objects,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
"structure_research_by_source",
|
||||
structure_research_by_object,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
"research_object_source",
|
||||
research_object_source,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
"consolidate_object_research",
|
||||
consolidate_object_research,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
"consolidate_research",
|
||||
consolidate_research,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="search_objects")
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="search_objects",
|
||||
path=parallel_object_source_research_edge,
|
||||
path_map=["research_object_source"],
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="research_object_source",
|
||||
end_key="structure_research_by_source",
|
||||
)
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source="structure_research_by_source",
|
||||
path=parallel_object_research_consolidation_edge,
|
||||
path_map=["consolidate_object_research"],
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="consolidate_object_research",
|
||||
end_key="consolidate_research",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="consolidate_research",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
@@ -1,146 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
|
||||
from onyx.agents.agent_search.dc_search_analysis.ops import research
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import MainState
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import (
|
||||
SearchSourcesObjectsUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
|
||||
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
|
||||
from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
|
||||
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def search_objects(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> SearchSourcesObjectsUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
|
||||
if search_tool is None or graph_config.inputs.persona is None:
|
||||
raise ValueError("Search tool and persona must be provided for DivCon search")
|
||||
|
||||
try:
|
||||
instructions = graph_config.inputs.persona.system_prompt or ""
|
||||
|
||||
agent_1_instructions = extract_section(
|
||||
instructions, "Agent Step 1:", "Agent Step 2:"
|
||||
)
|
||||
if agent_1_instructions is None:
|
||||
raise ValueError("Agent 1 instructions not found")
|
||||
|
||||
agent_1_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
|
||||
|
||||
agent_1_task = extract_section(
|
||||
agent_1_instructions, "Task:", "Independent Research Sources:"
|
||||
)
|
||||
if agent_1_task is None:
|
||||
raise ValueError("Agent 1 task not found")
|
||||
|
||||
agent_1_independent_sources_str = extract_section(
|
||||
agent_1_instructions, "Independent Research Sources:", "Output Objective:"
|
||||
)
|
||||
if agent_1_independent_sources_str is None:
|
||||
raise ValueError("Agent 1 Independent Research Sources not found")
|
||||
|
||||
document_sources = strings_to_document_sources(
|
||||
[
|
||||
x.strip().lower()
|
||||
for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
|
||||
]
|
||||
)
|
||||
|
||||
agent_1_output_objective = extract_section(
|
||||
agent_1_instructions, "Output Objective:"
|
||||
)
|
||||
if agent_1_output_objective is None:
|
||||
raise ValueError("Agent 1 output objective not found")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Agent 1 instructions not found or not formatted correctly: {e}"
|
||||
)
|
||||
|
||||
# Extract objects
|
||||
|
||||
if agent_1_base_data is None:
|
||||
# Retrieve chunks for objects
|
||||
|
||||
retrieved_docs = research(question, search_tool)[:10]
|
||||
|
||||
document_texts_list = []
|
||||
for doc_num, doc in enumerate(retrieved_docs):
|
||||
chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
|
||||
document_texts_list.append(chunk_text)
|
||||
|
||||
document_texts = "\n\n".join(document_texts_list)
|
||||
|
||||
dc_object_extraction_prompt = DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT.format(
|
||||
question=question,
|
||||
task=agent_1_task,
|
||||
document_text=document_texts,
|
||||
objects_of_interest=agent_1_output_objective,
|
||||
)
|
||||
else:
|
||||
dc_object_extraction_prompt = DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT.format(
|
||||
question=question,
|
||||
task=agent_1_task,
|
||||
base_data=agent_1_base_data,
|
||||
objects_of_interest=agent_1_output_objective,
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=trim_prompt_piece(
|
||||
config=graph_config.tooling.primary_llm.config,
|
||||
prompt_piece=dc_object_extraction_prompt,
|
||||
reserved_str="",
|
||||
),
|
||||
)
|
||||
]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
30,
|
||||
primary_llm.invoke_langchain,
|
||||
prompt=msg,
|
||||
timeout_override=30,
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content)
|
||||
.replace("```json\n", "")
|
||||
.replace("\n```", "")
|
||||
.replace("\n", "")
|
||||
)
|
||||
cleaned_response = cleaned_response.split("OBJECTS:")[1]
|
||||
object_list = [x.strip() for x in cleaned_response.split(";")]
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in search_objects: {e}")
|
||||
|
||||
return SearchSourcesObjectsUpdate(
|
||||
analysis_objects=object_list,
|
||||
analysis_sources=document_sources,
|
||||
log_messages=["Agent 1 Task done"],
|
||||
)
|
||||
@@ -1,180 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
|
||||
from onyx.agents.agent_search.dc_search_analysis.ops import research
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import (
|
||||
ObjectSourceResearchUpdate,
|
||||
)
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def research_object_source(
|
||||
state: ObjectSourceInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> ObjectSourceResearchUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
object, document_source = state.object_source_combination
|
||||
|
||||
if search_tool is None or graph_config.inputs.persona is None:
|
||||
raise ValueError("Search tool and persona must be provided for DivCon search")
|
||||
|
||||
try:
|
||||
instructions = graph_config.inputs.persona.system_prompt or ""
|
||||
|
||||
agent_2_instructions = extract_section(
|
||||
instructions, "Agent Step 2:", "Agent Step 3:"
|
||||
)
|
||||
if agent_2_instructions is None:
|
||||
raise ValueError("Agent 2 instructions not found")
|
||||
|
||||
agent_2_task = extract_section(
|
||||
agent_2_instructions, "Task:", "Independent Research Sources:"
|
||||
)
|
||||
if agent_2_task is None:
|
||||
raise ValueError("Agent 2 task not found")
|
||||
|
||||
agent_2_time_cutoff = extract_section(
|
||||
agent_2_instructions, "Time Cutoff:", "Research Topics:"
|
||||
)
|
||||
|
||||
agent_2_research_topics = extract_section(
|
||||
agent_2_instructions, "Research Topics:", "Output Objective"
|
||||
)
|
||||
|
||||
agent_2_output_objective = extract_section(
|
||||
agent_2_instructions, "Output Objective:"
|
||||
)
|
||||
if agent_2_output_objective is None:
|
||||
raise ValueError("Agent 2 output objective not found")
|
||||
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"Agent 1 instructions not found or not formatted correctly: {e}"
|
||||
)
|
||||
|
||||
# Populate prompt
|
||||
|
||||
# Retrieve chunks for objects
|
||||
|
||||
if agent_2_time_cutoff is not None and agent_2_time_cutoff.strip() != "":
|
||||
if agent_2_time_cutoff.strip().endswith("d"):
|
||||
try:
|
||||
days = int(agent_2_time_cutoff.strip()[:-1])
|
||||
agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(
|
||||
days=days
|
||||
)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
|
||||
)
|
||||
else:
|
||||
agent_2_source_start_time = None
|
||||
|
||||
document_sources = [document_source] if document_source else None
|
||||
|
||||
if len(question.strip()) > 0:
|
||||
research_area = f"{question} for {object}"
|
||||
elif agent_2_research_topics and len(agent_2_research_topics.strip()) > 0:
|
||||
research_area = f"{agent_2_research_topics} for {object}"
|
||||
else:
|
||||
research_area = object
|
||||
|
||||
retrieved_docs = research(
|
||||
question=research_area,
|
||||
search_tool=search_tool,
|
||||
document_sources=document_sources,
|
||||
time_cutoff=agent_2_source_start_time,
|
||||
)
|
||||
|
||||
# Generate document text
|
||||
|
||||
document_texts_list = []
|
||||
for doc_num, doc in enumerate(retrieved_docs):
|
||||
chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
|
||||
document_texts_list.append(chunk_text)
|
||||
|
||||
document_texts = "\n\n".join(document_texts_list)
|
||||
|
||||
# Built prompt
|
||||
|
||||
today = datetime.now().strftime("%A, %Y-%m-%d")
|
||||
|
||||
dc_object_source_research_prompt = (
|
||||
DC_OBJECT_SOURCE_RESEARCH_PROMPT.format(
|
||||
today=today,
|
||||
question=question,
|
||||
task=agent_2_task,
|
||||
document_text=document_texts,
|
||||
format=agent_2_output_objective,
|
||||
)
|
||||
.replace("---object---", object)
|
||||
.replace("---source---", document_source.value)
|
||||
)
|
||||
|
||||
# Run LLM
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=trim_prompt_piece(
|
||||
config=graph_config.tooling.primary_llm.config,
|
||||
prompt_piece=dc_object_source_research_prompt,
|
||||
reserved_str="",
|
||||
),
|
||||
)
|
||||
]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
30,
|
||||
primary_llm.invoke_langchain,
|
||||
prompt=msg,
|
||||
timeout_override=30,
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
cleaned_response = str(llm_response.content).replace("```json\n", "")
|
||||
cleaned_response = cleaned_response.split("RESEARCH RESULTS:")[1]
|
||||
object_research_results = {
|
||||
"object": object,
|
||||
"source": document_source.value,
|
||||
"research_result": cleaned_response,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in research_object_source: {e}")
|
||||
|
||||
logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
|
||||
|
||||
return ObjectSourceResearchUpdate(
|
||||
object_source_research_results=[object_research_results],
|
||||
log_messages=["Agent Step 2 done for one object"],
|
||||
)
|
||||
@@ -1,48 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import MainState
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import (
|
||||
ObjectResearchInformationUpdate,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def structure_research_by_object(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> ObjectResearchInformationUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
|
||||
object_source_research_results = state.object_source_research_results
|
||||
|
||||
object_research_information_results: List[Dict[str, str]] = []
|
||||
object_research_information_results_list: Dict[str, List[str]] = defaultdict(list)
|
||||
|
||||
for object_source_research in object_source_research_results:
|
||||
object = object_source_research["object"]
|
||||
source = object_source_research["source"]
|
||||
research_result = object_source_research["research_result"]
|
||||
|
||||
object_research_information_results_list[object].append(
|
||||
f"Source: {source}\n{research_result}"
|
||||
)
|
||||
|
||||
for object, information in object_research_information_results_list.items():
|
||||
object_research_information_results.append(
|
||||
{"object": object, "information": "\n".join(information)}
|
||||
)
|
||||
|
||||
logger.debug("DivCon Step A3 - Object Research Information Structuring - completed")
|
||||
|
||||
return ObjectResearchInformationUpdate(
|
||||
object_research_information_results=object_research_information_results,
|
||||
log_messages=["A3 - Object Research Information structured"],
|
||||
)
|
||||
@@ -1,103 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def consolidate_object_research(
|
||||
state: ObjectInformationInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> ObjectResearchUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
|
||||
if search_tool is None or graph_config.inputs.persona is None:
|
||||
raise ValueError("Search tool and persona must be provided for DivCon search")
|
||||
|
||||
instructions = graph_config.inputs.persona.system_prompt or ""
|
||||
|
||||
agent_4_instructions = extract_section(
|
||||
instructions, "Agent Step 4:", "Agent Step 5:"
|
||||
)
|
||||
if agent_4_instructions is None:
|
||||
raise ValueError("Agent 4 instructions not found")
|
||||
agent_4_output_objective = extract_section(
|
||||
agent_4_instructions, "Output Objective:"
|
||||
)
|
||||
if agent_4_output_objective is None:
|
||||
raise ValueError("Agent 4 output objective not found")
|
||||
|
||||
object_information = state.object_information
|
||||
|
||||
object = object_information["object"]
|
||||
information = object_information["information"]
|
||||
|
||||
# Create a prompt for the object consolidation
|
||||
|
||||
dc_object_consolidation_prompt = DC_OBJECT_CONSOLIDATION_PROMPT.format(
|
||||
question=question,
|
||||
object=object,
|
||||
information=information,
|
||||
format=agent_4_output_objective,
|
||||
)
|
||||
|
||||
# Run LLM
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=trim_prompt_piece(
|
||||
config=graph_config.tooling.primary_llm.config,
|
||||
prompt_piece=dc_object_consolidation_prompt,
|
||||
reserved_str="",
|
||||
),
|
||||
)
|
||||
]
|
||||
primary_llm = graph_config.tooling.primary_llm
|
||||
# Grader
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
30,
|
||||
primary_llm.invoke_langchain,
|
||||
prompt=msg,
|
||||
timeout_override=30,
|
||||
max_tokens=300,
|
||||
)
|
||||
|
||||
cleaned_response = str(llm_response.content).replace("```json\n", "")
|
||||
consolidated_information = cleaned_response.split("INFORMATION:")[1]
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in consolidate_object_research: {e}")
|
||||
|
||||
object_research_results = {
|
||||
"object": object,
|
||||
"research_result": consolidated_information,
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
"DivCon Step A4 - Object Research Consolidation - completed for an object"
|
||||
)
|
||||
|
||||
return ObjectResearchUpdate(
|
||||
object_research_results=[object_research_results],
|
||||
log_messages=["Agent Source Consilidation done"],
|
||||
)
|
||||
@@ -1,127 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import MainState
|
||||
from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
trim_prompt_piece,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
|
||||
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def consolidate_research(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> ResearchUpdate:
|
||||
"""
|
||||
LangGraph node to start the agentic search process.
|
||||
"""
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
search_tool = graph_config.tooling.search_tool
|
||||
|
||||
if search_tool is None or graph_config.inputs.persona is None:
|
||||
raise ValueError("Search tool and persona must be provided for DivCon search")
|
||||
|
||||
# Populate prompt
|
||||
instructions = graph_config.inputs.persona.system_prompt or ""
|
||||
|
||||
try:
|
||||
agent_5_instructions = extract_section(
|
||||
instructions, "Agent Step 5:", "Agent End"
|
||||
)
|
||||
if agent_5_instructions is None:
|
||||
raise ValueError("Agent 5 instructions not found")
|
||||
agent_5_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
|
||||
agent_5_task = extract_section(
|
||||
agent_5_instructions, "Task:", "Independent Research Sources:"
|
||||
)
|
||||
if agent_5_task is None:
|
||||
raise ValueError("Agent 5 task not found")
|
||||
agent_5_output_objective = extract_section(
|
||||
agent_5_instructions, "Output Objective:"
|
||||
)
|
||||
if agent_5_output_objective is None:
|
||||
raise ValueError("Agent 5 output objective not found")
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"Instructions for Agent Step 5 were not properly formatted: {e}"
|
||||
)
|
||||
|
||||
research_result_list = []
|
||||
|
||||
if agent_5_task.strip() == "*concatenate*":
|
||||
object_research_results = state.object_research_results
|
||||
|
||||
for object_research_result in object_research_results:
|
||||
object = object_research_result["object"]
|
||||
research_result = object_research_result["research_result"]
|
||||
research_result_list.append(f"Object: {object}\n\n{research_result}")
|
||||
|
||||
research_results = "\n\n".join(research_result_list)
|
||||
|
||||
else:
|
||||
raise NotImplementedError("Only '*concatenate*' is currently supported")
|
||||
|
||||
# Create a prompt for the object consolidation
|
||||
|
||||
if agent_5_base_data is None:
|
||||
dc_formatting_prompt = DC_FORMATTING_NO_BASE_DATA_PROMPT.format(
|
||||
text=research_results,
|
||||
format=agent_5_output_objective,
|
||||
)
|
||||
else:
|
||||
dc_formatting_prompt = DC_FORMATTING_WITH_BASE_DATA_PROMPT.format(
|
||||
base_data=agent_5_base_data,
|
||||
text=research_results,
|
||||
format=agent_5_output_objective,
|
||||
)
|
||||
|
||||
# Run LLM
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=trim_prompt_piece(
|
||||
config=graph_config.tooling.primary_llm.config,
|
||||
prompt_piece=dc_formatting_prompt,
|
||||
reserved_str="",
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
_ = run_with_timeout(
|
||||
60,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=msg,
|
||||
event_name="initial_agent_answer",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=30,
|
||||
max_tokens=None,
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in consolidate_research: {e}")
|
||||
|
||||
logger.debug("DivCon Step A5 - Final Generation - completed")
|
||||
|
||||
return ResearchUpdate(
|
||||
research_results=research_results,
|
||||
log_messages=["Agent Source Consilidation done"],
|
||||
)
|
||||
@@ -1,61 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
|
||||
def research(
|
||||
question: str,
|
||||
search_tool: SearchTool,
|
||||
document_sources: list[DocumentSource] | None = None,
|
||||
time_cutoff: datetime | None = None,
|
||||
) -> list[LlmDoc]:
|
||||
# new db session to avoid concurrency issues
|
||||
|
||||
callback_container: list[list[InferenceSection]] = []
|
||||
retrieved_docs: list[LlmDoc] = []
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=question,
|
||||
override_kwargs=SearchToolOverrideKwargs(
|
||||
force_no_rerank=False,
|
||||
alternate_db_session=db_session,
|
||||
retrieved_sections_callback=callback_container.append,
|
||||
skip_query_analysis=True,
|
||||
document_sources=document_sources,
|
||||
time_cutoff=time_cutoff,
|
||||
),
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
|
||||
break
|
||||
return retrieved_docs
|
||||
|
||||
|
||||
def extract_section(
|
||||
text: str, start_marker: str, end_marker: str | None = None
|
||||
) -> str | None:
|
||||
"""Extract text between markers, returning None if markers not found"""
|
||||
parts = text.split(start_marker)
|
||||
|
||||
if len(parts) == 1:
|
||||
return None
|
||||
|
||||
after_start = parts[1].strip()
|
||||
|
||||
if not end_marker:
|
||||
return after_start
|
||||
|
||||
extract = after_start.split(end_marker)[0]
|
||||
|
||||
return extract.strip()
|
||||
@@ -1,72 +0,0 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import Dict
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
|
||||
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
|
||||
### States ###
|
||||
class LoggerUpdate(BaseModel):
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
class SearchSourcesObjectsUpdate(LoggerUpdate):
|
||||
analysis_objects: list[str] = []
|
||||
analysis_sources: list[DocumentSource] = []
|
||||
|
||||
|
||||
class ObjectSourceInput(LoggerUpdate):
|
||||
object_source_combination: tuple[str, DocumentSource]
|
||||
|
||||
|
||||
class ObjectSourceResearchUpdate(LoggerUpdate):
|
||||
object_source_research_results: Annotated[list[Dict[str, str]], add] = []
|
||||
|
||||
|
||||
class ObjectInformationInput(LoggerUpdate):
|
||||
object_information: Dict[str, str]
|
||||
|
||||
|
||||
class ObjectResearchInformationUpdate(LoggerUpdate):
|
||||
object_research_information_results: Annotated[list[Dict[str, str]], add] = []
|
||||
|
||||
|
||||
class ObjectResearchUpdate(LoggerUpdate):
|
||||
object_research_results: Annotated[list[Dict[str, str]], add] = []
|
||||
|
||||
|
||||
class ResearchUpdate(LoggerUpdate):
|
||||
research_results: str | None = None
|
||||
|
||||
|
||||
## Graph Input State
|
||||
class MainInput(CoreState):
|
||||
pass
|
||||
|
||||
|
||||
## Graph State
|
||||
class MainState(
|
||||
# This includes the core state
|
||||
MainInput,
|
||||
ToolChoiceInput,
|
||||
ToolCallUpdate,
|
||||
ToolChoiceUpdate,
|
||||
SearchSourcesObjectsUpdate,
|
||||
ObjectSourceResearchUpdate,
|
||||
ObjectResearchInformationUpdate,
|
||||
ObjectResearchUpdate,
|
||||
ResearchUpdate,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
## Graph Output State - presently not used
|
||||
class MainOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
@@ -1,36 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RefinementSubQuestion(BaseModel):
|
||||
sub_question: str
|
||||
sub_question_id: str
|
||||
verified: bool
|
||||
answered: bool
|
||||
answer: str
|
||||
|
||||
|
||||
class AgentTimings(BaseModel):
|
||||
base_duration_s: float | None
|
||||
refined_duration_s: float | None
|
||||
full_duration_s: float | None
|
||||
|
||||
|
||||
class AgentBaseMetrics(BaseModel):
|
||||
num_verified_documents_total: int | None
|
||||
num_verified_documents_core: int | None
|
||||
verified_avg_score_core: float | None
|
||||
num_verified_documents_base: int | float | None
|
||||
verified_avg_score_base: float | None = None
|
||||
base_doc_boost_factor: float | None = None
|
||||
support_boost_factor: float | None = None
|
||||
duration_s: float | None = None
|
||||
|
||||
|
||||
class AgentRefinedMetrics(BaseModel):
|
||||
refined_doc_boost_factor: float | None = None
|
||||
refined_question_boost_factor: float | None = None
|
||||
duration_s: float | None = None
|
||||
|
||||
|
||||
class AgentAdditionalMetrics(BaseModel):
|
||||
pass
|
||||
@@ -1,61 +0,0 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.graph import END
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
|
||||
|
||||
def decision_router(state: MainState) -> list[Send | Hashable] | DRPath | str:
|
||||
if not state.tools_used:
|
||||
raise IndexError("state.tools_used cannot be empty")
|
||||
|
||||
# next_tool is either a generic tool name or a DRPath string
|
||||
next_tool_name = state.tools_used[-1]
|
||||
|
||||
available_tools = state.available_tools
|
||||
if not available_tools:
|
||||
raise ValueError("No tool is available. This should not happen.")
|
||||
|
||||
if next_tool_name in available_tools:
|
||||
next_tool_path = available_tools[next_tool_name].path
|
||||
elif next_tool_name == DRPath.END.value:
|
||||
return END
|
||||
elif next_tool_name == DRPath.LOGGER.value:
|
||||
return DRPath.LOGGER
|
||||
elif next_tool_name == DRPath.CLOSER.value:
|
||||
return DRPath.CLOSER
|
||||
else:
|
||||
return DRPath.ORCHESTRATOR
|
||||
|
||||
# handle invalid paths
|
||||
if next_tool_path == DRPath.CLARIFIER:
|
||||
raise ValueError("CLARIFIER is not a valid path during iteration")
|
||||
|
||||
# handle tool calls without a query
|
||||
if (
|
||||
next_tool_path
|
||||
in (
|
||||
DRPath.INTERNAL_SEARCH,
|
||||
DRPath.WEB_SEARCH,
|
||||
DRPath.KNOWLEDGE_GRAPH,
|
||||
DRPath.IMAGE_GENERATION,
|
||||
)
|
||||
and len(state.query_list) == 0
|
||||
):
|
||||
return DRPath.CLOSER
|
||||
|
||||
return next_tool_path
|
||||
|
||||
|
||||
def completeness_router(state: MainState) -> DRPath | str:
|
||||
if not state.tools_used:
|
||||
raise IndexError("tools_used cannot be empty")
|
||||
|
||||
# go to closer if path is CLOSER or no queries
|
||||
next_path = state.tools_used[-1]
|
||||
|
||||
if next_path == DRPath.ORCHESTRATOR.value:
|
||||
return DRPath.ORCHESTRATOR
|
||||
return DRPath.LOGGER
|
||||
@@ -1,31 +0,0 @@
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
|
||||
MAX_CHAT_HISTORY_MESSAGES = (
|
||||
3 # note: actual count is x2 to account for user and assistant messages
|
||||
)
|
||||
|
||||
MAX_DR_PARALLEL_SEARCH = 4
|
||||
|
||||
# TODO: test more, generally not needed/adds unnecessary iterations
|
||||
MAX_NUM_CLOSER_SUGGESTIONS = (
|
||||
0 # how many times the closer can send back to the orchestrator
|
||||
)
|
||||
|
||||
CLARIFICATION_REQUEST_PREFIX = "PLEASE CLARIFY:"
|
||||
HIGH_LEVEL_PLAN_PREFIX = "The Plan:"
|
||||
|
||||
AVERAGE_TOOL_COSTS: dict[DRPath, float] = {
|
||||
DRPath.INTERNAL_SEARCH: 1.0,
|
||||
DRPath.KNOWLEDGE_GRAPH: 2.0,
|
||||
DRPath.WEB_SEARCH: 1.5,
|
||||
DRPath.IMAGE_GENERATION: 3.0,
|
||||
DRPath.GENERIC_TOOL: 1.5, # TODO: see todo in OrchestratorTool
|
||||
DRPath.CLOSER: 0.0,
|
||||
}
|
||||
|
||||
DR_TIME_BUDGET_BY_TYPE = {
|
||||
ResearchType.THOUGHTFUL: 3.0,
|
||||
ResearchType.DEEP: 12.0,
|
||||
ResearchType.FAST: 0.5,
|
||||
}
|
||||
@@ -1,112 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import DRPromptPurpose
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
from onyx.prompts.dr_prompts import GET_CLARIFICATION_PROMPT
|
||||
from onyx.prompts.dr_prompts import KG_TYPES_DESCRIPTIONS
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
|
||||
from onyx.prompts.dr_prompts import ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
|
||||
from onyx.prompts.dr_prompts import TOOL_DIFFERENTIATION_HINTS
|
||||
from onyx.prompts.dr_prompts import TOOL_QUESTION_HINTS
|
||||
from onyx.prompts.prompt_template import PromptTemplate
|
||||
|
||||
|
||||
def get_dr_prompt_orchestration_templates(
|
||||
purpose: DRPromptPurpose,
|
||||
research_type: ResearchType,
|
||||
available_tools: dict[str, OrchestratorTool],
|
||||
entity_types_string: str | None = None,
|
||||
relationship_types_string: str | None = None,
|
||||
reasoning_result: str | None = None,
|
||||
tool_calls_string: str | None = None,
|
||||
) -> PromptTemplate:
|
||||
available_tools = available_tools or {}
|
||||
tool_names = list(available_tools.keys())
|
||||
tool_description_str = "\n\n".join(
|
||||
f"- {tool_name}: {tool.description}"
|
||||
for tool_name, tool in available_tools.items()
|
||||
)
|
||||
tool_cost_str = "\n".join(
|
||||
f"{tool_name}: {tool.cost}" for tool_name, tool in available_tools.items()
|
||||
)
|
||||
|
||||
tool_differentiations: list[str] = [
|
||||
TOOL_DIFFERENTIATION_HINTS[(tool_1, tool_2)]
|
||||
for tool_1 in available_tools
|
||||
for tool_2 in available_tools
|
||||
if (tool_1, tool_2) in TOOL_DIFFERENTIATION_HINTS
|
||||
]
|
||||
tool_differentiation_hint_string = (
|
||||
"\n".join(tool_differentiations) or "(No differentiating hints available)"
|
||||
)
|
||||
# TODO: add tool deliniation pairs for custom tools as well
|
||||
|
||||
tool_question_hint_string = (
|
||||
"\n".join(
|
||||
"- " + TOOL_QUESTION_HINTS[tool]
|
||||
for tool in available_tools
|
||||
if tool in TOOL_QUESTION_HINTS
|
||||
)
|
||||
or "(No examples available)"
|
||||
)
|
||||
|
||||
if DRPath.KNOWLEDGE_GRAPH.value in available_tools and (
|
||||
entity_types_string or relationship_types_string
|
||||
):
|
||||
|
||||
kg_types_descriptions = KG_TYPES_DESCRIPTIONS.build(
|
||||
possible_entities=entity_types_string or "",
|
||||
possible_relationships=relationship_types_string or "",
|
||||
)
|
||||
else:
|
||||
kg_types_descriptions = "(The Knowledge Graph is not used.)"
|
||||
|
||||
if purpose == DRPromptPurpose.PLAN:
|
||||
if research_type == ResearchType.THOUGHTFUL:
|
||||
raise ValueError("plan generation is not supported for FAST time budget")
|
||||
base_template = ORCHESTRATOR_DEEP_INITIAL_PLAN_PROMPT
|
||||
|
||||
elif purpose == DRPromptPurpose.NEXT_STEP_REASONING:
|
||||
if research_type == ResearchType.THOUGHTFUL:
|
||||
base_template = ORCHESTRATOR_FAST_ITERATIVE_REASONING_PROMPT
|
||||
else:
|
||||
raise ValueError(
|
||||
"reasoning is not separately required for DEEP time budget"
|
||||
)
|
||||
|
||||
elif purpose == DRPromptPurpose.NEXT_STEP_PURPOSE:
|
||||
base_template = ORCHESTRATOR_NEXT_STEP_PURPOSE_PROMPT
|
||||
|
||||
elif purpose == DRPromptPurpose.NEXT_STEP:
|
||||
if research_type == ResearchType.THOUGHTFUL:
|
||||
base_template = ORCHESTRATOR_FAST_ITERATIVE_DECISION_PROMPT
|
||||
else:
|
||||
base_template = ORCHESTRATOR_DEEP_ITERATIVE_DECISION_PROMPT
|
||||
|
||||
elif purpose == DRPromptPurpose.CLARIFICATION:
|
||||
if research_type == ResearchType.THOUGHTFUL:
|
||||
raise ValueError("clarification is not supported for FAST time budget")
|
||||
base_template = GET_CLARIFICATION_PROMPT
|
||||
|
||||
else:
|
||||
# for mypy, clearly a mypy bug
|
||||
raise ValueError(f"Invalid purpose: {purpose}")
|
||||
|
||||
return base_template.partial_build(
|
||||
num_available_tools=str(len(tool_names)),
|
||||
available_tools=", ".join(tool_names),
|
||||
tool_choice_options=" or ".join(tool_names),
|
||||
current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
kg_types_descriptions=kg_types_descriptions,
|
||||
tool_descriptions=tool_description_str,
|
||||
tool_differentiation_hints=tool_differentiation_hint_string,
|
||||
tool_question_hints=tool_question_hint_string,
|
||||
average_tool_costs=tool_cost_str,
|
||||
reasoning_result=reasoning_result or "(No reasoning result provided.)",
|
||||
tool_calls_string=tool_calls_string or "(No tool calls provided.)",
|
||||
)
|
||||
@@ -1,32 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ResearchType(str, Enum):
|
||||
"""Research type options for agent search operations"""
|
||||
|
||||
# BASIC = "BASIC"
|
||||
LEGACY_AGENTIC = "LEGACY_AGENTIC" # only used for legacy agentic search migrations
|
||||
THOUGHTFUL = "THOUGHTFUL"
|
||||
DEEP = "DEEP"
|
||||
FAST = "FAST"
|
||||
|
||||
|
||||
class ResearchAnswerPurpose(str, Enum):
|
||||
"""Research answer purpose options for agent search operations"""
|
||||
|
||||
ANSWER = "ANSWER"
|
||||
CLARIFICATION_REQUEST = "CLARIFICATION_REQUEST"
|
||||
|
||||
|
||||
class DRPath(str, Enum):
|
||||
CLARIFIER = "Clarifier"
|
||||
ORCHESTRATOR = "Orchestrator"
|
||||
INTERNAL_SEARCH = "Internal Search"
|
||||
GENERIC_TOOL = "Generic Tool"
|
||||
KNOWLEDGE_GRAPH = "Knowledge Graph Search"
|
||||
WEB_SEARCH = "Web Search"
|
||||
IMAGE_GENERATION = "Image Generation"
|
||||
GENERIC_INTERNAL_TOOL = "Generic Internal Tool"
|
||||
CLOSER = "Closer"
|
||||
LOGGER = "Logger"
|
||||
END = "End"
|
||||
@@ -1,88 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.conditional_edges import completeness_router
|
||||
from onyx.agents.agent_search.dr.conditional_edges import decision_router
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.nodes.dr_a0_clarification import clarifier
|
||||
from onyx.agents.agent_search.dr.nodes.dr_a1_orchestrator import orchestrator
|
||||
from onyx.agents.agent_search.dr.nodes.dr_a2_closer import closer
|
||||
from onyx.agents.agent_search.dr.nodes.dr_a3_logger import logging
|
||||
from onyx.agents.agent_search.dr.states import MainInput
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_graph_builder import (
|
||||
dr_basic_search_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_graph_builder import (
|
||||
dr_custom_tool_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_graph_builder import (
|
||||
dr_generic_internal_tool_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_graph_builder import (
|
||||
dr_image_generation_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.kg_search.dr_kg_search_graph_builder import (
|
||||
dr_kg_search_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.web_search.dr_ws_graph_builder import (
|
||||
dr_ws_graph_builder,
|
||||
)
|
||||
|
||||
# from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import search
|
||||
|
||||
|
||||
def dr_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for the deep research agent.
|
||||
"""
|
||||
|
||||
graph = StateGraph(state_schema=MainState, input=MainInput)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node(DRPath.CLARIFIER, clarifier)
|
||||
|
||||
graph.add_node(DRPath.ORCHESTRATOR, orchestrator)
|
||||
|
||||
basic_search_graph = dr_basic_search_graph_builder().compile()
|
||||
graph.add_node(DRPath.INTERNAL_SEARCH, basic_search_graph)
|
||||
|
||||
kg_search_graph = dr_kg_search_graph_builder().compile()
|
||||
graph.add_node(DRPath.KNOWLEDGE_GRAPH, kg_search_graph)
|
||||
|
||||
internet_search_graph = dr_ws_graph_builder().compile()
|
||||
graph.add_node(DRPath.WEB_SEARCH, internet_search_graph)
|
||||
|
||||
image_generation_graph = dr_image_generation_graph_builder().compile()
|
||||
graph.add_node(DRPath.IMAGE_GENERATION, image_generation_graph)
|
||||
|
||||
custom_tool_graph = dr_custom_tool_graph_builder().compile()
|
||||
graph.add_node(DRPath.GENERIC_TOOL, custom_tool_graph)
|
||||
|
||||
generic_internal_tool_graph = dr_generic_internal_tool_graph_builder().compile()
|
||||
graph.add_node(DRPath.GENERIC_INTERNAL_TOOL, generic_internal_tool_graph)
|
||||
|
||||
graph.add_node(DRPath.CLOSER, closer)
|
||||
graph.add_node(DRPath.LOGGER, logging)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key=DRPath.CLARIFIER)
|
||||
|
||||
graph.add_conditional_edges(DRPath.CLARIFIER, decision_router)
|
||||
|
||||
graph.add_conditional_edges(DRPath.ORCHESTRATOR, decision_router)
|
||||
|
||||
graph.add_edge(start_key=DRPath.INTERNAL_SEARCH, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.KNOWLEDGE_GRAPH, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.WEB_SEARCH, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.IMAGE_GENERATION, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.GENERIC_TOOL, end_key=DRPath.ORCHESTRATOR)
|
||||
graph.add_edge(start_key=DRPath.GENERIC_INTERNAL_TOOL, end_key=DRPath.ORCHESTRATOR)
|
||||
|
||||
graph.add_conditional_edges(DRPath.CLOSER, completeness_router)
|
||||
graph.add_edge(start_key=DRPath.LOGGER, end_key=END)
|
||||
|
||||
return graph
|
||||
@@ -1,131 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
|
||||
GeneratedImage,
|
||||
)
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.tools.tool import Tool
|
||||
|
||||
|
||||
class OrchestratorStep(BaseModel):
|
||||
tool: str
|
||||
questions: list[str]
|
||||
|
||||
|
||||
class OrchestratorDecisonsNoPlan(BaseModel):
|
||||
reasoning: str
|
||||
next_step: OrchestratorStep
|
||||
|
||||
|
||||
class OrchestrationPlan(BaseModel):
|
||||
reasoning: str
|
||||
plan: str
|
||||
|
||||
|
||||
class ClarificationGenerationResponse(BaseModel):
|
||||
clarification_needed: bool
|
||||
clarification_question: str
|
||||
|
||||
|
||||
class DecisionResponse(BaseModel):
|
||||
reasoning: str
|
||||
decision: str
|
||||
|
||||
|
||||
class QueryEvaluationResponse(BaseModel):
|
||||
reasoning: str
|
||||
query_permitted: bool
|
||||
|
||||
|
||||
class OrchestrationClarificationInfo(BaseModel):
|
||||
clarification_question: str
|
||||
clarification_response: str | None = None
|
||||
|
||||
|
||||
class WebSearchAnswer(BaseModel):
|
||||
urls_to_open_indices: list[int]
|
||||
|
||||
|
||||
class SearchAnswer(BaseModel):
|
||||
reasoning: str
|
||||
answer: str
|
||||
claims: list[str] | None = None
|
||||
|
||||
|
||||
class TestInfoCompleteResponse(BaseModel):
|
||||
reasoning: str
|
||||
complete: bool
|
||||
gaps: list[str]
|
||||
|
||||
|
||||
# TODO: revisit with custom tools implementation in v2
|
||||
# each tool should be a class with the attributes below, plus the actual tool implementation
|
||||
# this will also allow custom tools to have their own cost
|
||||
class OrchestratorTool(BaseModel):
|
||||
tool_id: int
|
||||
name: str
|
||||
llm_path: str # the path for the LLM to refer by
|
||||
path: DRPath # the actual path in the graph
|
||||
description: str
|
||||
metadata: dict[str, str]
|
||||
cost: float
|
||||
tool_object: Tool | None = None # None for CLOSER
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class IterationInstructions(BaseModel):
|
||||
iteration_nr: int
|
||||
plan: str | None
|
||||
reasoning: str
|
||||
purpose: str
|
||||
|
||||
|
||||
class IterationAnswer(BaseModel):
|
||||
tool: str
|
||||
tool_id: int
|
||||
iteration_nr: int
|
||||
parallelization_nr: int
|
||||
question: str
|
||||
reasoning: str | None
|
||||
answer: str
|
||||
cited_documents: dict[int, InferenceSection]
|
||||
background_info: str | None = None
|
||||
claims: list[str] | None = None
|
||||
additional_data: dict[str, str] | None = None
|
||||
response_type: str | None = None
|
||||
data: dict | list | str | int | float | bool | None = None
|
||||
file_ids: list[str] | None = None
|
||||
# TODO: This is not ideal, but we'll can rework the schema
|
||||
# for deep research later
|
||||
is_web_fetch: bool = False
|
||||
# for image generation step-types
|
||||
generated_images: list[GeneratedImage] | None = None
|
||||
# for multi-query search tools (v2 web search and internal search)
|
||||
# TODO: Clean this up to be more flexible to tools
|
||||
queries: list[str] | None = None
|
||||
|
||||
|
||||
class AggregatedDRContext(BaseModel):
|
||||
context: str
|
||||
cited_documents: list[InferenceSection]
|
||||
is_internet_marker_dict: dict[str, bool]
|
||||
global_iteration_responses: list[IterationAnswer]
|
||||
|
||||
|
||||
class DRPromptPurpose(str, Enum):
|
||||
PLAN = "PLAN"
|
||||
NEXT_STEP = "NEXT_STEP"
|
||||
NEXT_STEP_REASONING = "NEXT_STEP_REASONING"
|
||||
NEXT_STEP_PURPOSE = "NEXT_STEP_PURPOSE"
|
||||
CLARIFICATION = "CLARIFICATION"
|
||||
|
||||
|
||||
class BaseSearchProcessingResponse(BaseModel):
|
||||
specified_source_types: list[str]
|
||||
rewritten_query: str
|
||||
time_filter: str
|
||||
@@ -1,918 +0,0 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from braintrust import traced
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import AVERAGE_TOOL_COSTS
|
||||
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
|
||||
from onyx.agents.agent_search.dr.dr_prompt_builder import (
|
||||
get_dr_prompt_orchestration_templates,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import ClarificationGenerationResponse
|
||||
from onyx.agents.agent_search.dr.models import DecisionResponse
|
||||
from onyx.agents.agent_search.dr.models import DRPromptPurpose
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
from onyx.agents.agent_search.dr.process_llm_stream import (
|
||||
BasicSearchProcessedStreamResults,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.process_llm_stream import process_llm_stream
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.states import OrchestrationSetup
|
||||
from onyx.agents.agent_search.dr.utils import get_chat_history_string
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.chat.chat_utils import build_citation_map_from_numbers
|
||||
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
|
||||
from onyx.chat.memories import get_memories
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.prompt_builder.citations_prompt import build_citations_system_message
|
||||
from onyx.chat.prompt_builder.citations_prompt import build_citations_user_message
|
||||
from onyx.chat.stream_processing.citation_processing import (
|
||||
normalize_square_bracket_citations_to_double_with_links,
|
||||
)
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSourceDescription
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.db.chat import create_search_doc_from_saved_search_doc
|
||||
from onyx.db.chat import update_db_session_with_messages
|
||||
from onyx.db.connector import fetch_unique_document_sources
|
||||
from onyx.db.kg_config import get_kg_config_settings
|
||||
from onyx.db.models import SearchDoc
|
||||
from onyx.db.models import Tool
|
||||
from onyx.db.tools import get_tools
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.kg.utils.extraction_utils import get_entity_types_str
|
||||
from onyx.kg.utils.extraction_utils import get_relationship_types_str
|
||||
from onyx.llm.utils import check_number_of_tokens
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
from onyx.prompts.dr_prompts import ANSWER_PROMPT_WO_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import DECISION_PROMPT_W_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import DECISION_PROMPT_WO_TOOL_CALLING
|
||||
from onyx.prompts.dr_prompts import DEFAULT_DR_SYSTEM_PROMPT
|
||||
from onyx.prompts.dr_prompts import REPEAT_PROMPT
|
||||
from onyx.prompts.dr_prompts import TOOL_DESCRIPTION
|
||||
from onyx.prompts.prompt_template import PromptTemplate
|
||||
from onyx.prompts.prompt_utils import handle_company_awareness
|
||||
from onyx.prompts.prompt_utils import handle_memories
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.server.query_and_chat.streaming_models import StreamingType
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.knowledge_graph.knowledge_graph_tool import (
|
||||
KnowledgeGraphTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import (
|
||||
WebSearchTool,
|
||||
)
|
||||
from onyx.utils.b64 import get_image_type
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_available_tools(
|
||||
db_session: Session,
|
||||
graph_config: GraphConfig,
|
||||
kg_enabled: bool,
|
||||
active_source_types: list[DocumentSource],
|
||||
) -> dict[str, OrchestratorTool]:
|
||||
|
||||
available_tools: dict[str, OrchestratorTool] = {}
|
||||
|
||||
kg_enabled = graph_config.behavior.kg_config_settings.KG_ENABLED
|
||||
persona = graph_config.inputs.persona
|
||||
|
||||
if persona:
|
||||
include_kg = persona.name == TMP_DRALPHA_PERSONA_NAME and kg_enabled
|
||||
else:
|
||||
include_kg = False
|
||||
|
||||
tool_dict: dict[int, Tool] = {
|
||||
tool.id: tool for tool in get_tools(db_session, only_enabled=True)
|
||||
}
|
||||
|
||||
for tool in graph_config.tooling.tools:
|
||||
|
||||
if not tool.is_available(db_session):
|
||||
logger.info(f"Tool {tool.name} is not available, skipping")
|
||||
continue
|
||||
|
||||
tool_db_info = tool_dict.get(tool.id)
|
||||
if tool_db_info:
|
||||
incode_tool_id = tool_db_info.in_code_tool_id
|
||||
else:
|
||||
raise ValueError(f"Tool {tool.name} is not found in the database")
|
||||
|
||||
if isinstance(tool, WebSearchTool):
|
||||
llm_path = DRPath.WEB_SEARCH.value
|
||||
path = DRPath.WEB_SEARCH
|
||||
elif isinstance(tool, SearchTool):
|
||||
llm_path = DRPath.INTERNAL_SEARCH.value
|
||||
path = DRPath.INTERNAL_SEARCH
|
||||
elif isinstance(tool, KnowledgeGraphTool) and include_kg:
|
||||
# TODO (chris): move this into the `is_available` check
|
||||
if len(active_source_types) == 0:
|
||||
logger.error(
|
||||
"No active source types found, skipping Knowledge Graph tool"
|
||||
)
|
||||
continue
|
||||
llm_path = DRPath.KNOWLEDGE_GRAPH.value
|
||||
path = DRPath.KNOWLEDGE_GRAPH
|
||||
elif isinstance(tool, ImageGenerationTool):
|
||||
llm_path = DRPath.IMAGE_GENERATION.value
|
||||
path = DRPath.IMAGE_GENERATION
|
||||
elif incode_tool_id:
|
||||
# if incode tool id is found, it is a generic internal tool
|
||||
llm_path = DRPath.GENERIC_INTERNAL_TOOL.value
|
||||
path = DRPath.GENERIC_INTERNAL_TOOL
|
||||
else:
|
||||
# otherwise it is a custom tool
|
||||
llm_path = DRPath.GENERIC_TOOL.value
|
||||
path = DRPath.GENERIC_TOOL
|
||||
|
||||
if path not in {DRPath.GENERIC_INTERNAL_TOOL, DRPath.GENERIC_TOOL}:
|
||||
description = TOOL_DESCRIPTION.get(path, tool.description)
|
||||
cost = AVERAGE_TOOL_COSTS[path]
|
||||
else:
|
||||
description = tool.description
|
||||
cost = 1.0
|
||||
|
||||
tool_info = OrchestratorTool(
|
||||
tool_id=tool.id,
|
||||
name=tool.llm_name,
|
||||
llm_path=llm_path,
|
||||
path=path,
|
||||
description=description,
|
||||
metadata={},
|
||||
cost=cost,
|
||||
tool_object=tool,
|
||||
)
|
||||
|
||||
# TODO: handle custom tools with same name as other tools (e.g., CLOSER)
|
||||
available_tools[tool.llm_name] = tool_info
|
||||
|
||||
available_tool_paths = [tool.path for tool in available_tools.values()]
|
||||
|
||||
# make sure KG isn't enabled without internal search
|
||||
if (
|
||||
DRPath.KNOWLEDGE_GRAPH in available_tool_paths
|
||||
and DRPath.INTERNAL_SEARCH not in available_tool_paths
|
||||
):
|
||||
raise ValueError(
|
||||
"The Knowledge Graph is not supported without internal search tool"
|
||||
)
|
||||
|
||||
# add CLOSER tool, which is always available
|
||||
available_tools[DRPath.CLOSER.value] = OrchestratorTool(
|
||||
tool_id=-1,
|
||||
name=DRPath.CLOSER.value,
|
||||
llm_path=DRPath.CLOSER.value,
|
||||
path=DRPath.CLOSER,
|
||||
description=TOOL_DESCRIPTION[DRPath.CLOSER],
|
||||
metadata={},
|
||||
cost=0.0,
|
||||
tool_object=None,
|
||||
)
|
||||
|
||||
return available_tools
|
||||
|
||||
|
||||
def _construct_uploaded_text_context(files: list[InMemoryChatFile]) -> str:
|
||||
"""Construct the uploaded context from the files."""
|
||||
file_contents = []
|
||||
for file in files:
|
||||
if file.file_type in (
|
||||
ChatFileType.DOC,
|
||||
ChatFileType.PLAIN_TEXT,
|
||||
ChatFileType.CSV,
|
||||
):
|
||||
file_contents.append(file.content.decode("utf-8"))
|
||||
if len(file_contents) > 0:
|
||||
return "Uploaded context:\n\n\n" + "\n\n".join(file_contents)
|
||||
return ""
|
||||
|
||||
|
||||
def _construct_uploaded_image_context(
|
||||
files: list[InMemoryChatFile] | None = None,
|
||||
img_urls: list[str] | None = None,
|
||||
b64_imgs: list[str] | None = None,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Construct the uploaded image context from the files."""
|
||||
# Only include image files for user messages
|
||||
if files is None:
|
||||
return None
|
||||
|
||||
img_files = [file for file in files if file.file_type == ChatFileType.IMAGE]
|
||||
|
||||
img_urls = img_urls or []
|
||||
b64_imgs = b64_imgs or []
|
||||
|
||||
if not (img_files or img_urls or b64_imgs):
|
||||
return None
|
||||
|
||||
return cast(
|
||||
list[dict[str, Any]],
|
||||
[
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": (
|
||||
f"data:{get_image_type_from_bytes(file.content)};"
|
||||
f"base64,{file.to_base64()}"
|
||||
),
|
||||
},
|
||||
}
|
||||
for file in img_files
|
||||
]
|
||||
+ [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:{get_image_type(b64_img)};base64,{b64_img}",
|
||||
},
|
||||
}
|
||||
for b64_img in b64_imgs
|
||||
]
|
||||
+ [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": url,
|
||||
},
|
||||
}
|
||||
for url in img_urls
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _get_existing_clarification_request(
|
||||
graph_config: GraphConfig,
|
||||
) -> tuple[OrchestrationClarificationInfo, str, str] | None:
|
||||
"""
|
||||
Returns the clarification info, original question, and updated chat history if
|
||||
a clarification request and response exists, otherwise returns None.
|
||||
"""
|
||||
# check for clarification request and response in message history
|
||||
previous_raw_messages = graph_config.inputs.prompt_builder.raw_message_history
|
||||
|
||||
if len(previous_raw_messages) == 0 or (
|
||||
previous_raw_messages[-1].research_answer_purpose
|
||||
!= ResearchAnswerPurpose.CLARIFICATION_REQUEST
|
||||
):
|
||||
return None
|
||||
|
||||
# get the clarification request and response
|
||||
previous_messages = graph_config.inputs.prompt_builder.message_history
|
||||
last_message = previous_raw_messages[-1].message
|
||||
|
||||
clarification = OrchestrationClarificationInfo(
|
||||
clarification_question=last_message.strip(),
|
||||
clarification_response=graph_config.inputs.prompt_builder.raw_user_query,
|
||||
)
|
||||
original_question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
chat_history_string = "(No chat history yet available)"
|
||||
|
||||
# get the original user query and chat history string before the original query
|
||||
# e.g., if history = [user query, assistant clarification request, user clarification response],
|
||||
# previous_messages = [user query, assistant clarification request], we want the user query
|
||||
for i, message in enumerate(reversed(previous_messages), 1):
|
||||
if (
|
||||
isinstance(message, HumanMessage)
|
||||
and message.content
|
||||
and isinstance(message.content, str)
|
||||
):
|
||||
original_question = message.content
|
||||
chat_history_string = (
|
||||
get_chat_history_string(
|
||||
graph_config.inputs.prompt_builder.message_history[:-i],
|
||||
MAX_CHAT_HISTORY_MESSAGES,
|
||||
)
|
||||
or "(No chat history yet available)"
|
||||
)
|
||||
break
|
||||
|
||||
return clarification, original_question, chat_history_string
|
||||
|
||||
|
||||
def _persist_final_docs_and_citations(
|
||||
db_session: Session,
|
||||
context_llm_docs: list[Any] | None,
|
||||
full_answer: str | None,
|
||||
) -> tuple[list[SearchDoc], dict[int, int] | None]:
|
||||
"""Persist final documents from in-context docs and derive citation mapping.
|
||||
|
||||
Returns the list of persisted `SearchDoc` records and an optional
|
||||
citation map translating inline [[n]] references to DB doc indices.
|
||||
"""
|
||||
final_documents_db: list[SearchDoc] = []
|
||||
citations_map: dict[int, int] | None = None
|
||||
|
||||
if not context_llm_docs:
|
||||
return final_documents_db, citations_map
|
||||
|
||||
saved_search_docs = saved_search_docs_from_llm_docs(context_llm_docs)
|
||||
for saved_doc in saved_search_docs:
|
||||
db_doc = create_search_doc_from_saved_search_doc(saved_doc)
|
||||
db_session.add(db_doc)
|
||||
final_documents_db.append(db_doc)
|
||||
db_session.flush()
|
||||
|
||||
cited_numbers: set[int] = set()
|
||||
try:
|
||||
# Match [[1]] or [[1, 2]] optionally followed by a link like ([[1]](http...))
|
||||
matches = re.findall(
|
||||
r"\[\[(\d+(?:,\s*\d+)*)\]\](?:\([^)]*\))?", full_answer or ""
|
||||
)
|
||||
for match in matches:
|
||||
for num_str in match.split(","):
|
||||
num = int(num_str.strip())
|
||||
cited_numbers.add(num)
|
||||
except Exception:
|
||||
cited_numbers = set()
|
||||
|
||||
if cited_numbers and final_documents_db:
|
||||
translations = build_citation_map_from_numbers(
|
||||
cited_numbers=cited_numbers,
|
||||
db_docs=final_documents_db,
|
||||
)
|
||||
citations_map = translations or None
|
||||
|
||||
return final_documents_db, citations_map
|
||||
|
||||
|
||||
_ARTIFICIAL_ALL_ENCOMPASSING_TOOL = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "run_any_knowledge_retrieval_and_any_action_tool",
|
||||
"description": "Use this tool to get ANY external information \
|
||||
that is relevant to the question, or for any action to be taken, including image generation. In fact, \
|
||||
ANY tool mentioned can be accessed through this generic tool. If in doubt, use this tool.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"request": {
|
||||
"type": "string",
|
||||
"description": "The request to be made to the tool",
|
||||
},
|
||||
},
|
||||
"required": ["request"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def clarifier(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> OrchestrationSetup:
|
||||
"""
|
||||
Perform a quick search on the question as is and see whether a set of clarification
|
||||
questions is needed. For now this is based on the models
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
current_step_nr = 0
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
llm_provider = graph_config.tooling.primary_llm.config.model_provider
|
||||
llm_model_name = graph_config.tooling.primary_llm.config.model_name
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm_model_name,
|
||||
provider_type=llm_provider,
|
||||
)
|
||||
|
||||
max_input_tokens = get_max_input_tokens(
|
||||
model_name=llm_model_name,
|
||||
model_provider=llm_provider,
|
||||
)
|
||||
|
||||
use_tool_calling_llm = graph_config.tooling.using_tool_calling_llm
|
||||
db_session = graph_config.persistence.db_session
|
||||
|
||||
original_question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
research_type = graph_config.behavior.research_type
|
||||
|
||||
force_use_tool = graph_config.tooling.force_use_tool
|
||||
|
||||
message_id = graph_config.persistence.message_id
|
||||
|
||||
# Perform a commit to ensure the message_id is set and saved
|
||||
db_session.commit()
|
||||
|
||||
# get the connected tools and format for the Deep Research flow
|
||||
kg_enabled = graph_config.behavior.kg_config_settings.KG_ENABLED
|
||||
db_session = graph_config.persistence.db_session
|
||||
active_source_types = fetch_unique_document_sources(db_session)
|
||||
|
||||
available_tools = _get_available_tools(
|
||||
db_session, graph_config, kg_enabled, active_source_types
|
||||
)
|
||||
|
||||
available_tool_descriptions_str = "\n -" + "\n -".join(
|
||||
[tool.description for tool in available_tools.values()]
|
||||
)
|
||||
|
||||
kg_config = get_kg_config_settings()
|
||||
if kg_config.KG_ENABLED and kg_config.KG_EXPOSED:
|
||||
all_entity_types = get_entity_types_str(active=True)
|
||||
all_relationship_types = get_relationship_types_str(active=True)
|
||||
else:
|
||||
all_entity_types = ""
|
||||
all_relationship_types = ""
|
||||
|
||||
# if not active_source_types:
|
||||
# raise ValueError("No active source types found")
|
||||
|
||||
active_source_types_descriptions = [
|
||||
DocumentSourceDescription[source_type] for source_type in active_source_types
|
||||
]
|
||||
|
||||
if len(active_source_types_descriptions) > 0:
|
||||
active_source_type_descriptions_str = "\n -" + "\n -".join(
|
||||
active_source_types_descriptions
|
||||
)
|
||||
else:
|
||||
active_source_type_descriptions_str = ""
|
||||
|
||||
if graph_config.inputs.persona:
|
||||
assistant_system_prompt = PromptTemplate(
|
||||
graph_config.inputs.persona.system_prompt or DEFAULT_DR_SYSTEM_PROMPT
|
||||
).build()
|
||||
if graph_config.inputs.persona.task_prompt:
|
||||
assistant_task_prompt = (
|
||||
"\n\nHere are more specifications from the user:\n\n"
|
||||
+ PromptTemplate(graph_config.inputs.persona.task_prompt).build()
|
||||
)
|
||||
else:
|
||||
assistant_task_prompt = ""
|
||||
|
||||
else:
|
||||
assistant_system_prompt = PromptTemplate(DEFAULT_DR_SYSTEM_PROMPT).build()
|
||||
assistant_task_prompt = ""
|
||||
|
||||
if graph_config.inputs.project_instructions:
|
||||
assistant_system_prompt = (
|
||||
assistant_system_prompt
|
||||
+ PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
+ graph_config.inputs.project_instructions
|
||||
)
|
||||
user = (
|
||||
graph_config.tooling.search_tool.user
|
||||
if graph_config.tooling.search_tool
|
||||
else None
|
||||
)
|
||||
memories = get_memories(user, db_session)
|
||||
assistant_system_prompt = handle_company_awareness(assistant_system_prompt)
|
||||
assistant_system_prompt = handle_memories(assistant_system_prompt, memories)
|
||||
|
||||
chat_history_string = (
|
||||
get_chat_history_string(
|
||||
graph_config.inputs.prompt_builder.message_history,
|
||||
MAX_CHAT_HISTORY_MESSAGES,
|
||||
)
|
||||
or "(No chat history yet available)"
|
||||
)
|
||||
|
||||
uploaded_text_context = (
|
||||
_construct_uploaded_text_context(graph_config.inputs.files)
|
||||
if graph_config.inputs.files
|
||||
else ""
|
||||
)
|
||||
|
||||
uploaded_context_tokens = check_number_of_tokens(
|
||||
uploaded_text_context, llm_tokenizer.encode
|
||||
)
|
||||
|
||||
if uploaded_context_tokens > 0.5 * max_input_tokens:
|
||||
raise ValueError(
|
||||
f"Uploaded context is too long. {uploaded_context_tokens} tokens, "
|
||||
f"but for this model we only allow {0.5 * max_input_tokens} tokens for uploaded context"
|
||||
)
|
||||
|
||||
uploaded_image_context = _construct_uploaded_image_context(
|
||||
graph_config.inputs.files
|
||||
)
|
||||
|
||||
# Use project/search context docs if available to enable citation mapping
|
||||
context_llm_docs = getattr(
|
||||
graph_config.inputs.prompt_builder, "context_llm_docs", None
|
||||
)
|
||||
|
||||
if not (force_use_tool and force_use_tool.force_use):
|
||||
|
||||
if not use_tool_calling_llm or len(available_tools) == 1:
|
||||
if len(available_tools) > 1:
|
||||
decision_prompt = DECISION_PROMPT_WO_TOOL_CALLING.build(
|
||||
question=original_question,
|
||||
chat_history_string=chat_history_string,
|
||||
uploaded_context=uploaded_text_context or "",
|
||||
active_source_type_descriptions_str=active_source_type_descriptions_str,
|
||||
available_tool_descriptions_str=available_tool_descriptions_str,
|
||||
)
|
||||
|
||||
llm_decision = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt,
|
||||
decision_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
schema=DecisionResponse,
|
||||
)
|
||||
else:
|
||||
# if there is only one tool (Closer), we don't need to decide. It's an LLM answer
|
||||
llm_decision = DecisionResponse(decision="LLM", reasoning="")
|
||||
|
||||
if llm_decision.decision == "LLM" and research_type != ResearchType.DEEP:
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
MessageStart(
|
||||
content="",
|
||||
final_documents=[],
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
answer_prompt = ANSWER_PROMPT_WO_TOOL_CALLING.build(
|
||||
question=original_question,
|
||||
chat_history_string=chat_history_string,
|
||||
uploaded_context=uploaded_text_context or "",
|
||||
active_source_type_descriptions_str=active_source_type_descriptions_str,
|
||||
available_tool_descriptions_str=available_tool_descriptions_str,
|
||||
)
|
||||
|
||||
answer_tokens, _, _ = run_with_timeout(
|
||||
TF_DR_TIMEOUT_LONG,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt,
|
||||
answer_prompt + assistant_task_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
event_name="basic_response",
|
||||
writer=writer,
|
||||
answer_piece=StreamingType.MESSAGE_DELTA.value,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
ind=current_step_nr,
|
||||
context_docs=None,
|
||||
replace_citations=True,
|
||||
max_tokens=None,
|
||||
),
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(
|
||||
type="section_end",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
current_step_nr += 1
|
||||
|
||||
answer_str = cast(str, merge_content(*answer_tokens))
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
OverallStop(),
|
||||
writer,
|
||||
)
|
||||
|
||||
update_db_session_with_messages(
|
||||
db_session=db_session,
|
||||
chat_message_id=message_id,
|
||||
chat_session_id=graph_config.persistence.chat_session_id,
|
||||
is_agentic=graph_config.behavior.use_agentic_search,
|
||||
message=answer_str,
|
||||
update_parent_message=True,
|
||||
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
return OrchestrationSetup(
|
||||
original_question=original_question,
|
||||
chat_history_string="",
|
||||
tools_used=[DRPath.END.value],
|
||||
available_tools=available_tools,
|
||||
query_list=[],
|
||||
assistant_system_prompt=assistant_system_prompt,
|
||||
assistant_task_prompt=assistant_task_prompt,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
decision_prompt = DECISION_PROMPT_W_TOOL_CALLING.build(
|
||||
question=original_question,
|
||||
chat_history_string=chat_history_string,
|
||||
uploaded_context=uploaded_text_context or "",
|
||||
active_source_type_descriptions_str=active_source_type_descriptions_str,
|
||||
)
|
||||
|
||||
if context_llm_docs:
|
||||
persona = graph_config.inputs.persona
|
||||
if persona is not None:
|
||||
prompt_config = PromptConfig.from_model(
|
||||
persona, db_session=graph_config.persistence.db_session
|
||||
)
|
||||
else:
|
||||
prompt_config = PromptConfig(
|
||||
default_behavior_system_prompt=assistant_system_prompt,
|
||||
custom_instructions=None,
|
||||
reminder="",
|
||||
datetime_aware=True,
|
||||
)
|
||||
|
||||
system_prompt_to_use_content = build_citations_system_message(
|
||||
prompt_config
|
||||
).content
|
||||
system_prompt_to_use: str = cast(str, system_prompt_to_use_content)
|
||||
if graph_config.inputs.project_instructions:
|
||||
system_prompt_to_use = (
|
||||
system_prompt_to_use
|
||||
+ PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
+ graph_config.inputs.project_instructions
|
||||
)
|
||||
user_prompt_to_use = build_citations_user_message(
|
||||
user_query=original_question,
|
||||
files=[],
|
||||
prompt_config=prompt_config,
|
||||
context_docs=context_llm_docs,
|
||||
all_doc_useful=False,
|
||||
history_message=chat_history_string,
|
||||
context_type="user files",
|
||||
).content
|
||||
else:
|
||||
system_prompt_to_use = assistant_system_prompt
|
||||
user_prompt_to_use = decision_prompt + assistant_task_prompt
|
||||
|
||||
@traced(name="clarifier stream and process", type="llm")
|
||||
def stream_and_process() -> BasicSearchProcessedStreamResults:
|
||||
stream = graph_config.tooling.primary_llm.stream_langchain(
|
||||
prompt=create_question_prompt(
|
||||
cast(str, system_prompt_to_use),
|
||||
cast(str, user_prompt_to_use),
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
tools=([_ARTIFICIAL_ALL_ENCOMPASSING_TOOL]),
|
||||
tool_choice=(None),
|
||||
structured_response_format=graph_config.inputs.structured_response_format,
|
||||
)
|
||||
return process_llm_stream(
|
||||
messages=stream,
|
||||
should_stream_answer=True,
|
||||
writer=writer,
|
||||
ind=0,
|
||||
search_results=context_llm_docs,
|
||||
generate_final_answer=True,
|
||||
chat_message_id=str(graph_config.persistence.chat_session_id),
|
||||
)
|
||||
|
||||
# Deep research always continues to clarification or search
|
||||
if research_type != ResearchType.DEEP:
|
||||
full_response = stream_and_process()
|
||||
if len(full_response.ai_message_chunk.tool_calls) == 0:
|
||||
|
||||
if isinstance(full_response.full_answer, str):
|
||||
full_answer = (
|
||||
normalize_square_bracket_citations_to_double_with_links(
|
||||
full_response.full_answer
|
||||
)
|
||||
)
|
||||
else:
|
||||
full_answer = None
|
||||
|
||||
# Persist final documents and derive citations when using in-context docs
|
||||
final_documents_db, citations_map = (
|
||||
_persist_final_docs_and_citations(
|
||||
db_session=db_session,
|
||||
context_llm_docs=context_llm_docs,
|
||||
full_answer=full_answer,
|
||||
)
|
||||
)
|
||||
|
||||
update_db_session_with_messages(
|
||||
db_session=db_session,
|
||||
chat_message_id=message_id,
|
||||
chat_session_id=graph_config.persistence.chat_session_id,
|
||||
is_agentic=graph_config.behavior.use_agentic_search,
|
||||
message=full_answer,
|
||||
token_count=len(llm_tokenizer.encode(full_answer or "")),
|
||||
citations=citations_map,
|
||||
final_documents=final_documents_db or None,
|
||||
update_parent_message=True,
|
||||
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return OrchestrationSetup(
|
||||
original_question=original_question,
|
||||
chat_history_string="",
|
||||
tools_used=[DRPath.END.value],
|
||||
query_list=[],
|
||||
available_tools=available_tools,
|
||||
assistant_system_prompt=assistant_system_prompt,
|
||||
assistant_task_prompt=assistant_task_prompt,
|
||||
)
|
||||
|
||||
# Continue, as external knowledge is required.
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
clarification = None
|
||||
|
||||
if research_type == ResearchType.DEEP:
|
||||
result = _get_existing_clarification_request(graph_config)
|
||||
if result is not None:
|
||||
clarification, original_question, chat_history_string = result
|
||||
else:
|
||||
# generate clarification questions if needed
|
||||
chat_history_string = (
|
||||
get_chat_history_string(
|
||||
graph_config.inputs.prompt_builder.message_history,
|
||||
MAX_CHAT_HISTORY_MESSAGES,
|
||||
)
|
||||
or "(No chat history yet available)"
|
||||
)
|
||||
|
||||
base_clarification_prompt = get_dr_prompt_orchestration_templates(
|
||||
DRPromptPurpose.CLARIFICATION,
|
||||
research_type,
|
||||
entity_types_string=all_entity_types,
|
||||
relationship_types_string=all_relationship_types,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
clarification_prompt = base_clarification_prompt.build(
|
||||
question=original_question,
|
||||
chat_history_string=chat_history_string,
|
||||
)
|
||||
|
||||
try:
|
||||
clarification_response = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt,
|
||||
clarification_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
schema=ClarificationGenerationResponse,
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
# max_tokens=1500,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in clarification generation: {e}")
|
||||
raise e
|
||||
|
||||
if (
|
||||
clarification_response.clarification_needed
|
||||
and clarification_response.clarification_question
|
||||
):
|
||||
clarification = OrchestrationClarificationInfo(
|
||||
clarification_question=clarification_response.clarification_question,
|
||||
clarification_response=None,
|
||||
)
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
MessageStart(
|
||||
content="",
|
||||
final_documents=None,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
repeat_prompt = REPEAT_PROMPT.build(
|
||||
original_information=clarification_response.clarification_question
|
||||
)
|
||||
|
||||
_, _, _ = run_with_timeout(
|
||||
TF_DR_TIMEOUT_LONG,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=repeat_prompt,
|
||||
event_name="basic_response",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
answer_piece=StreamingType.MESSAGE_DELTA.value,
|
||||
ind=current_step_nr,
|
||||
# max_tokens=None,
|
||||
),
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(
|
||||
type="section_end",
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
1,
|
||||
OverallStop(),
|
||||
writer,
|
||||
)
|
||||
|
||||
update_db_session_with_messages(
|
||||
db_session=db_session,
|
||||
chat_message_id=message_id,
|
||||
chat_session_id=graph_config.persistence.chat_session_id,
|
||||
is_agentic=graph_config.behavior.use_agentic_search,
|
||||
message=clarification_response.clarification_question,
|
||||
update_parent_message=True,
|
||||
research_type=research_type,
|
||||
research_answer_purpose=ResearchAnswerPurpose.CLARIFICATION_REQUEST,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
else:
|
||||
chat_history_string = (
|
||||
get_chat_history_string(
|
||||
graph_config.inputs.prompt_builder.message_history,
|
||||
MAX_CHAT_HISTORY_MESSAGES,
|
||||
)
|
||||
or "(No chat history yet available)"
|
||||
)
|
||||
|
||||
if (
|
||||
clarification
|
||||
and clarification.clarification_question
|
||||
and clarification.clarification_response is None
|
||||
):
|
||||
|
||||
update_db_session_with_messages(
|
||||
db_session=db_session,
|
||||
chat_message_id=message_id,
|
||||
chat_session_id=graph_config.persistence.chat_session_id,
|
||||
is_agentic=graph_config.behavior.use_agentic_search,
|
||||
message=clarification.clarification_question,
|
||||
update_parent_message=True,
|
||||
research_type=research_type,
|
||||
research_answer_purpose=ResearchAnswerPurpose.CLARIFICATION_REQUEST,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
next_tool = DRPath.END.value
|
||||
else:
|
||||
next_tool = DRPath.ORCHESTRATOR.value
|
||||
|
||||
return OrchestrationSetup(
|
||||
original_question=original_question,
|
||||
chat_history_string=chat_history_string,
|
||||
tools_used=[next_tool],
|
||||
query_list=[],
|
||||
iteration_nr=0,
|
||||
current_step_nr=current_step_nr,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="clarifier",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
clarification=clarification,
|
||||
available_tools=available_tools,
|
||||
active_source_types=active_source_types,
|
||||
active_source_types_descriptions="\n".join(active_source_types_descriptions),
|
||||
assistant_system_prompt=assistant_system_prompt,
|
||||
assistant_task_prompt=assistant_task_prompt,
|
||||
uploaded_test_context=uploaded_text_context,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
)
|
||||
@@ -1,624 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import merge_content
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import DR_TIME_BUDGET_BY_TYPE
|
||||
from onyx.agents.agent_search.dr.constants import HIGH_LEVEL_PLAN_PREFIX
|
||||
from onyx.agents.agent_search.dr.dr_prompt_builder import (
|
||||
get_dr_prompt_orchestration_templates,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import DRPromptPurpose
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationPlan
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorDecisonsNoPlan
|
||||
from onyx.agents.agent_search.dr.states import IterationInstructions
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
|
||||
from onyx.agents.agent_search.dr.utils import aggregate_context
|
||||
from onyx.agents.agent_search.dr.utils import create_tool_call_string
|
||||
from onyx.agents.agent_search.dr.utils import get_prompt_question
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import run_with_timeout
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.kg.utils.extraction_utils import get_entity_types_str
|
||||
from onyx.kg.utils.extraction_utils import get_relationship_types_str
|
||||
from onyx.prompts.dr_prompts import DEFAULLT_DECISION_PROMPT
|
||||
from onyx.prompts.dr_prompts import REPEAT_PROMPT
|
||||
from onyx.prompts.dr_prompts import SUFFICIENT_INFORMATION_STRING
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.server.query_and_chat.streaming_models import StreamingType
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_DECISION_SYSTEM_PROMPT_PREFIX = "Here are general instructions by the user, which \
|
||||
may or may not influence the decision what to do next:\n\n"
|
||||
|
||||
|
||||
def _get_implied_next_tool_based_on_tool_call_history(
|
||||
tools_used: list[str],
|
||||
) -> str | None:
|
||||
"""
|
||||
Identify the next tool based on the tool call history. Initially, we only support
|
||||
special handling of the image generation tool.
|
||||
"""
|
||||
if tools_used[-1] == DRPath.IMAGE_GENERATION.value:
|
||||
return DRPath.LOGGER.value
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def orchestrator(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> OrchestrationUpdate:
|
||||
"""
|
||||
LangGraph node to decide the next step in the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = state.original_question
|
||||
if not question:
|
||||
raise ValueError("Question is required for orchestrator")
|
||||
|
||||
state.original_question
|
||||
|
||||
available_tools = state.available_tools
|
||||
|
||||
plan_of_record = state.plan_of_record
|
||||
clarification = state.clarification
|
||||
assistant_system_prompt = state.assistant_system_prompt
|
||||
|
||||
if assistant_system_prompt:
|
||||
decision_system_prompt: str = (
|
||||
DEFAULLT_DECISION_PROMPT
|
||||
+ _DECISION_SYSTEM_PROMPT_PREFIX
|
||||
+ assistant_system_prompt
|
||||
)
|
||||
else:
|
||||
decision_system_prompt = DEFAULLT_DECISION_PROMPT
|
||||
|
||||
iteration_nr = state.iteration_nr + 1
|
||||
current_step_nr = state.current_step_nr
|
||||
|
||||
research_type = graph_config.behavior.research_type
|
||||
remaining_time_budget = state.remaining_time_budget
|
||||
chat_history_string = state.chat_history_string or "(No chat history yet available)"
|
||||
answer_history_string = (
|
||||
aggregate_context(state.iteration_responses, include_documents=True).context
|
||||
or "(No answer history yet available)"
|
||||
)
|
||||
|
||||
next_tool_name = None
|
||||
|
||||
# Identify early exit condition based on tool call history
|
||||
|
||||
next_tool_based_on_tool_call_history = (
|
||||
_get_implied_next_tool_based_on_tool_call_history(state.tools_used)
|
||||
)
|
||||
|
||||
if next_tool_based_on_tool_call_history == DRPath.LOGGER.value:
|
||||
return OrchestrationUpdate(
|
||||
tools_used=[DRPath.LOGGER.value],
|
||||
query_list=[],
|
||||
iteration_nr=iteration_nr,
|
||||
current_step_nr=current_step_nr,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="orchestrator",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
plan_of_record=plan_of_record,
|
||||
remaining_time_budget=remaining_time_budget,
|
||||
iteration_instructions=[
|
||||
IterationInstructions(
|
||||
iteration_nr=iteration_nr,
|
||||
plan=plan_of_record.plan if plan_of_record else None,
|
||||
reasoning="",
|
||||
purpose="",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# no early exit forced. Continue.
|
||||
|
||||
available_tools = state.available_tools or {}
|
||||
|
||||
uploaded_context = state.uploaded_test_context or ""
|
||||
uploaded_image_context = state.uploaded_image_context or []
|
||||
|
||||
questions = [
|
||||
f"{iteration_response.tool}: {iteration_response.question}"
|
||||
for iteration_response in state.iteration_responses
|
||||
if len(iteration_response.question) > 0
|
||||
]
|
||||
|
||||
question_history_string = (
|
||||
"\n".join(f" - {question}" for question in questions)
|
||||
if questions
|
||||
else "(No question history yet available)"
|
||||
)
|
||||
|
||||
prompt_question = get_prompt_question(question, clarification)
|
||||
|
||||
gaps_str = (
|
||||
("\n - " + "\n - ".join(state.gaps))
|
||||
if state.gaps
|
||||
else "(No explicit gaps were pointed out so far)"
|
||||
)
|
||||
|
||||
all_entity_types = get_entity_types_str(active=True)
|
||||
all_relationship_types = get_relationship_types_str(active=True)
|
||||
|
||||
# default to closer
|
||||
query_list = ["Answer the question with the information you have."]
|
||||
decision_prompt = None
|
||||
|
||||
reasoning_result = "(No reasoning result provided yet.)"
|
||||
tool_calls_string = "(No tool calls provided yet.)"
|
||||
|
||||
if research_type not in ResearchType:
|
||||
raise ValueError(f"Invalid research type: {research_type}")
|
||||
|
||||
if research_type in [ResearchType.THOUGHTFUL, ResearchType.FAST]:
|
||||
if iteration_nr == 1:
|
||||
remaining_time_budget = DR_TIME_BUDGET_BY_TYPE[research_type]
|
||||
|
||||
elif remaining_time_budget <= 0:
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
return OrchestrationUpdate(
|
||||
tools_used=[DRPath.CLOSER.value],
|
||||
current_step_nr=current_step_nr,
|
||||
query_list=[],
|
||||
iteration_nr=iteration_nr,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="orchestrator",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
plan_of_record=plan_of_record,
|
||||
remaining_time_budget=remaining_time_budget,
|
||||
iteration_instructions=[
|
||||
IterationInstructions(
|
||||
iteration_nr=iteration_nr,
|
||||
plan=None,
|
||||
reasoning="Time to wrap up.",
|
||||
purpose="",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
elif iteration_nr > 1 and remaining_time_budget > 0:
|
||||
# for each iteration past the first one, we need to see whether we
|
||||
# have enough information to answer the question.
|
||||
# if we do, we can stop the iteration and return the answer.
|
||||
# if we do not, we need to continue the iteration.
|
||||
|
||||
base_reasoning_prompt = get_dr_prompt_orchestration_templates(
|
||||
DRPromptPurpose.NEXT_STEP_REASONING,
|
||||
ResearchType.THOUGHTFUL,
|
||||
entity_types_string=all_entity_types,
|
||||
relationship_types_string=all_relationship_types,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
|
||||
reasoning_prompt = base_reasoning_prompt.build(
|
||||
question=question,
|
||||
chat_history_string=chat_history_string,
|
||||
answer_history_string=answer_history_string,
|
||||
iteration_nr=str(iteration_nr),
|
||||
remaining_time_budget=str(remaining_time_budget),
|
||||
uploaded_context=uploaded_context,
|
||||
)
|
||||
|
||||
reasoning_tokens: list[str] = [""]
|
||||
|
||||
reasoning_tokens, _, _ = run_with_timeout(
|
||||
TF_DR_TIMEOUT_LONG,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
decision_system_prompt,
|
||||
reasoning_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
event_name="basic_response",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
answer_piece=StreamingType.REASONING_DELTA.value,
|
||||
ind=current_step_nr,
|
||||
# max_tokens=None,
|
||||
),
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
reasoning_result = cast(str, merge_content(*reasoning_tokens))
|
||||
|
||||
if SUFFICIENT_INFORMATION_STRING in reasoning_result:
|
||||
return OrchestrationUpdate(
|
||||
tools_used=[DRPath.CLOSER.value],
|
||||
current_step_nr=current_step_nr,
|
||||
query_list=[],
|
||||
iteration_nr=iteration_nr,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="orchestrator",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
plan_of_record=plan_of_record,
|
||||
remaining_time_budget=remaining_time_budget,
|
||||
iteration_instructions=[
|
||||
IterationInstructions(
|
||||
iteration_nr=iteration_nr,
|
||||
plan=None,
|
||||
reasoning=reasoning_result,
|
||||
purpose="",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# for Thoughtful mode, we force a tool if requested an available
|
||||
available_tools_for_decision = available_tools
|
||||
force_use_tool = graph_config.tooling.force_use_tool
|
||||
if iteration_nr == 1 and force_use_tool and force_use_tool.force_use:
|
||||
|
||||
forced_tool_name = force_use_tool.tool_name
|
||||
|
||||
available_tool_dict = {
|
||||
available_tool.tool_object.name: available_tool
|
||||
for _, available_tool in available_tools.items()
|
||||
if available_tool.tool_object
|
||||
}
|
||||
|
||||
if forced_tool_name in available_tool_dict.keys():
|
||||
forced_tool = available_tool_dict[forced_tool_name]
|
||||
|
||||
available_tools_for_decision = {forced_tool.name: forced_tool}
|
||||
|
||||
base_decision_prompt = get_dr_prompt_orchestration_templates(
|
||||
DRPromptPurpose.NEXT_STEP,
|
||||
ResearchType.THOUGHTFUL,
|
||||
entity_types_string=all_entity_types,
|
||||
relationship_types_string=all_relationship_types,
|
||||
available_tools=available_tools_for_decision,
|
||||
)
|
||||
decision_prompt = base_decision_prompt.build(
|
||||
question=question,
|
||||
chat_history_string=chat_history_string,
|
||||
answer_history_string=answer_history_string,
|
||||
iteration_nr=str(iteration_nr),
|
||||
remaining_time_budget=str(remaining_time_budget),
|
||||
reasoning_result=reasoning_result,
|
||||
uploaded_context=uploaded_context,
|
||||
)
|
||||
|
||||
if remaining_time_budget > 0:
|
||||
try:
|
||||
orchestrator_action = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
decision_system_prompt,
|
||||
decision_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
schema=OrchestratorDecisonsNoPlan,
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
# max_tokens=2500,
|
||||
)
|
||||
next_step = orchestrator_action.next_step
|
||||
next_tool_name = next_step.tool
|
||||
query_list = [q for q in (next_step.questions or [])]
|
||||
|
||||
tool_calls_string = create_tool_call_string(next_tool_name, query_list)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in approach extraction: {e}")
|
||||
raise e
|
||||
|
||||
if next_tool_name in available_tools.keys():
|
||||
remaining_time_budget -= available_tools[next_tool_name].cost
|
||||
else:
|
||||
logger.warning(f"Tool {next_tool_name} not found in available tools")
|
||||
remaining_time_budget -= 1.0
|
||||
|
||||
else:
|
||||
reasoning_result = "Time to wrap up."
|
||||
next_tool_name = DRPath.CLOSER.value
|
||||
|
||||
elif research_type == ResearchType.DEEP:
|
||||
if iteration_nr == 1 and not plan_of_record:
|
||||
# by default, we start a new iteration, but if there is a feedback request,
|
||||
# we start a new iteration 0 again (set a bit later)
|
||||
|
||||
remaining_time_budget = DR_TIME_BUDGET_BY_TYPE[ResearchType.DEEP]
|
||||
|
||||
base_plan_prompt = get_dr_prompt_orchestration_templates(
|
||||
DRPromptPurpose.PLAN,
|
||||
ResearchType.DEEP,
|
||||
entity_types_string=all_entity_types,
|
||||
relationship_types_string=all_relationship_types,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
plan_generation_prompt = base_plan_prompt.build(
|
||||
question=prompt_question,
|
||||
chat_history_string=chat_history_string,
|
||||
uploaded_context=uploaded_context,
|
||||
)
|
||||
|
||||
try:
|
||||
plan_of_record = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
decision_system_prompt,
|
||||
plan_generation_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
schema=OrchestrationPlan,
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
# max_tokens=3000,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in plan generation: {e}")
|
||||
raise
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
ReasoningStart(),
|
||||
writer,
|
||||
)
|
||||
|
||||
start_time = datetime.now()
|
||||
|
||||
repeat_plan_prompt = REPEAT_PROMPT.build(
|
||||
original_information=f"{HIGH_LEVEL_PLAN_PREFIX}\n\n {plan_of_record.plan}"
|
||||
)
|
||||
|
||||
_, _, _ = run_with_timeout(
|
||||
TF_DR_TIMEOUT_LONG,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=repeat_plan_prompt,
|
||||
event_name="basic_response",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
answer_piece=StreamingType.REASONING_DELTA.value,
|
||||
ind=current_step_nr,
|
||||
),
|
||||
)
|
||||
|
||||
end_time = datetime.now()
|
||||
logger.debug(f"Time taken for plan streaming: {end_time - start_time}")
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
current_step_nr += 1
|
||||
|
||||
if not plan_of_record:
|
||||
raise ValueError(
|
||||
"Plan information is required for iterative decision making"
|
||||
)
|
||||
|
||||
base_decision_prompt = get_dr_prompt_orchestration_templates(
|
||||
DRPromptPurpose.NEXT_STEP,
|
||||
ResearchType.DEEP,
|
||||
entity_types_string=all_entity_types,
|
||||
relationship_types_string=all_relationship_types,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
decision_prompt = base_decision_prompt.build(
|
||||
answer_history_string=answer_history_string,
|
||||
question_history_string=question_history_string,
|
||||
question=prompt_question,
|
||||
iteration_nr=str(iteration_nr),
|
||||
current_plan_of_record_string=plan_of_record.plan,
|
||||
chat_history_string=chat_history_string,
|
||||
remaining_time_budget=str(remaining_time_budget),
|
||||
gaps=gaps_str,
|
||||
uploaded_context=uploaded_context,
|
||||
)
|
||||
|
||||
if remaining_time_budget > 0:
|
||||
try:
|
||||
orchestrator_action = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
decision_system_prompt,
|
||||
decision_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
schema=OrchestratorDecisonsNoPlan,
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
# max_tokens=1500,
|
||||
)
|
||||
next_step = orchestrator_action.next_step
|
||||
next_tool_name = next_step.tool
|
||||
|
||||
query_list = [q for q in (next_step.questions or [])]
|
||||
reasoning_result = orchestrator_action.reasoning
|
||||
|
||||
tool_calls_string = create_tool_call_string(next_tool_name, query_list)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in approach extraction: {e}")
|
||||
raise e
|
||||
|
||||
if next_tool_name in available_tools.keys():
|
||||
remaining_time_budget -= available_tools[next_tool_name].cost
|
||||
else:
|
||||
logger.warning(f"Tool {next_tool_name} not found in available tools")
|
||||
remaining_time_budget -= 1.0
|
||||
else:
|
||||
reasoning_result = "Time to wrap up."
|
||||
next_tool_name = DRPath.CLOSER.value
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
ReasoningStart(),
|
||||
writer,
|
||||
)
|
||||
|
||||
repeat_reasoning_prompt = REPEAT_PROMPT.build(
|
||||
original_information=reasoning_result
|
||||
)
|
||||
|
||||
_, _, _ = run_with_timeout(
|
||||
TF_DR_TIMEOUT_LONG,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=repeat_reasoning_prompt,
|
||||
event_name="basic_response",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
answer_piece=StreamingType.REASONING_DELTA.value,
|
||||
ind=current_step_nr,
|
||||
# max_tokens=None,
|
||||
),
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Research type {research_type} is not implemented.")
|
||||
|
||||
base_next_step_purpose_prompt = get_dr_prompt_orchestration_templates(
|
||||
DRPromptPurpose.NEXT_STEP_PURPOSE,
|
||||
ResearchType.DEEP,
|
||||
entity_types_string=all_entity_types,
|
||||
relationship_types_string=all_relationship_types,
|
||||
available_tools=available_tools,
|
||||
)
|
||||
orchestration_next_step_purpose_prompt = base_next_step_purpose_prompt.build(
|
||||
question=prompt_question,
|
||||
reasoning_result=reasoning_result,
|
||||
tool_calls=tool_calls_string,
|
||||
)
|
||||
|
||||
purpose_tokens: list[str] = [""]
|
||||
purpose = ""
|
||||
|
||||
if research_type in [ResearchType.THOUGHTFUL, ResearchType.DEEP]:
|
||||
|
||||
try:
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
ReasoningStart(),
|
||||
writer,
|
||||
)
|
||||
|
||||
purpose_tokens, _, _ = run_with_timeout(
|
||||
TF_DR_TIMEOUT_LONG,
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
decision_system_prompt,
|
||||
orchestration_next_step_purpose_prompt,
|
||||
uploaded_image_context=uploaded_image_context,
|
||||
),
|
||||
event_name="basic_response",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
answer_piece=StreamingType.REASONING_DELTA.value,
|
||||
ind=current_step_nr,
|
||||
# max_tokens=None,
|
||||
),
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in orchestration next step purpose.")
|
||||
raise e
|
||||
|
||||
purpose = cast(str, merge_content(*purpose_tokens))
|
||||
|
||||
elif research_type == ResearchType.FAST:
|
||||
purpose = f"Answering the question using the {next_tool_name}"
|
||||
|
||||
if not next_tool_name:
|
||||
raise ValueError("The next step has not been defined. This should not happen.")
|
||||
|
||||
return OrchestrationUpdate(
|
||||
tools_used=[next_tool_name],
|
||||
query_list=query_list or [],
|
||||
iteration_nr=iteration_nr,
|
||||
current_step_nr=current_step_nr,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="orchestrator",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
plan_of_record=plan_of_record,
|
||||
remaining_time_budget=remaining_time_budget,
|
||||
iteration_instructions=[
|
||||
IterationInstructions(
|
||||
iteration_nr=iteration_nr,
|
||||
plan=plan_of_record.plan if plan_of_record else None,
|
||||
reasoning=reasoning_result,
|
||||
purpose=purpose,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,423 +0,0 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import MAX_CHAT_HISTORY_MESSAGES
|
||||
from onyx.agents.agent_search.dr.constants import MAX_NUM_CLOSER_SUGGESTIONS
|
||||
from onyx.agents.agent_search.dr.enums import DRPath
|
||||
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.dr.models import TestInfoCompleteResponse
|
||||
from onyx.agents.agent_search.dr.states import FinalUpdate
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.states import OrchestrationUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
|
||||
GeneratedImageFullResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.utils import aggregate_context
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.agents.agent_search.dr.utils import get_chat_history_string
|
||||
from onyx.agents.agent_search.dr.utils import get_prompt_question
|
||||
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.chat.chat_utils import llm_doc_from_inference_section
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.chat import create_search_doc_from_inference_section
|
||||
from onyx.db.chat import update_db_session_with_messages
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ResearchAgentIteration
|
||||
from onyx.db.models import ResearchAgentIterationSubStep
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.llm.utils import check_number_of_tokens
|
||||
from onyx.prompts.chat_prompts import PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
|
||||
from onyx.prompts.dr_prompts import FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
|
||||
from onyx.prompts.dr_prompts import TEST_INFO_COMPLETE_PROMPT
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.server.query_and_chat.streaming_models import StreamingType
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def extract_citation_numbers(text: str) -> list[int]:
|
||||
"""
|
||||
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
|
||||
Returns a list of all unique citation numbers found.
|
||||
"""
|
||||
# Pattern to match [[number]] or [[number1, number2, ...]]
|
||||
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
|
||||
matches = re.findall(pattern, text)
|
||||
|
||||
cited_numbers = []
|
||||
for match in matches:
|
||||
# Split by comma and extract all numbers
|
||||
numbers = [int(num.strip()) for num in match.split(",")]
|
||||
cited_numbers.extend(numbers)
|
||||
|
||||
return list(set(cited_numbers)) # Return unique numbers
|
||||
|
||||
|
||||
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
|
||||
citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
|
||||
numbers = [int(num.strip()) for num in citation_content.split(",")]
|
||||
|
||||
# For multiple citations like [[3, 5, 7]], create separate linked citations
|
||||
linked_citations = []
|
||||
for num in numbers:
|
||||
if num - 1 < len(docs): # Check bounds
|
||||
link = docs[num - 1].link or ""
|
||||
linked_citations.append(f"[[{num}]]({link})")
|
||||
else:
|
||||
linked_citations.append(f"[[{num}]]") # No link if out of bounds
|
||||
|
||||
return "".join(linked_citations)
|
||||
|
||||
|
||||
def insert_chat_message_search_doc_pair(
|
||||
message_id: int, search_doc_ids: list[int], db_session: Session
|
||||
) -> None:
|
||||
"""
|
||||
Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
|
||||
|
||||
Args:
|
||||
message_id: The ID of the chat message
|
||||
search_doc_id: The ID of the search document
|
||||
db_session: The database session
|
||||
"""
|
||||
for search_doc_id in search_doc_ids:
|
||||
chat_message_search_doc = ChatMessage__SearchDoc(
|
||||
chat_message_id=message_id, search_doc_id=search_doc_id
|
||||
)
|
||||
db_session.add(chat_message_search_doc)
|
||||
|
||||
|
||||
def save_iteration(
|
||||
state: MainState,
|
||||
graph_config: GraphConfig,
|
||||
aggregated_context: AggregatedDRContext,
|
||||
final_answer: str,
|
||||
all_cited_documents: list[InferenceSection],
|
||||
is_internet_marker_dict: dict[str, bool],
|
||||
) -> None:
|
||||
db_session = graph_config.persistence.db_session
|
||||
message_id = graph_config.persistence.message_id
|
||||
research_type = graph_config.behavior.research_type
|
||||
db_session = graph_config.persistence.db_session
|
||||
|
||||
# first, insert the search_docs
|
||||
search_docs = [
|
||||
create_search_doc_from_inference_section(
|
||||
inference_section=inference_section,
|
||||
is_internet=is_internet_marker_dict.get(
|
||||
inference_section.center_chunk.document_id, False
|
||||
), # TODO: revisit
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
for inference_section in all_cited_documents
|
||||
]
|
||||
|
||||
# then, map_search_docs to message
|
||||
insert_chat_message_search_doc_pair(
|
||||
message_id, [search_doc.id for search_doc in search_docs], db_session
|
||||
)
|
||||
|
||||
# lastly, insert the citations
|
||||
citation_dict: dict[int, int] = {}
|
||||
cited_doc_nrs = extract_citation_numbers(final_answer)
|
||||
for cited_doc_nr in cited_doc_nrs:
|
||||
citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
|
||||
|
||||
# TODO: generate plan as dict in the first place
|
||||
plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
|
||||
plan_of_record_dict = parse_plan_to_dict(plan_of_record)
|
||||
|
||||
# Update the chat message and its parent message in database
|
||||
update_db_session_with_messages(
|
||||
db_session=db_session,
|
||||
chat_message_id=message_id,
|
||||
chat_session_id=graph_config.persistence.chat_session_id,
|
||||
is_agentic=graph_config.behavior.use_agentic_search,
|
||||
message=final_answer,
|
||||
citations=citation_dict,
|
||||
research_type=research_type,
|
||||
research_plan=plan_of_record_dict,
|
||||
final_documents=search_docs,
|
||||
update_parent_message=True,
|
||||
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
|
||||
)
|
||||
|
||||
for iteration_preparation in state.iteration_instructions:
|
||||
research_agent_iteration_step = ResearchAgentIteration(
|
||||
primary_question_id=message_id,
|
||||
reasoning=iteration_preparation.reasoning,
|
||||
purpose=iteration_preparation.purpose,
|
||||
iteration_nr=iteration_preparation.iteration_nr,
|
||||
)
|
||||
db_session.add(research_agent_iteration_step)
|
||||
|
||||
for iteration_answer in aggregated_context.global_iteration_responses:
|
||||
|
||||
retrieved_search_docs = convert_inference_sections_to_search_docs(
|
||||
list(iteration_answer.cited_documents.values())
|
||||
)
|
||||
|
||||
# Convert SavedSearchDoc objects to JSON-serializable format
|
||||
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
|
||||
|
||||
research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
|
||||
primary_question_id=message_id,
|
||||
iteration_nr=iteration_answer.iteration_nr,
|
||||
iteration_sub_step_nr=iteration_answer.parallelization_nr,
|
||||
sub_step_instructions=iteration_answer.question,
|
||||
sub_step_tool_id=iteration_answer.tool_id,
|
||||
sub_answer=iteration_answer.answer,
|
||||
reasoning=iteration_answer.reasoning,
|
||||
claims=iteration_answer.claims,
|
||||
cited_doc_results=serialized_search_docs,
|
||||
generated_images=(
|
||||
GeneratedImageFullResult(images=iteration_answer.generated_images)
|
||||
if iteration_answer.generated_images
|
||||
else None
|
||||
),
|
||||
additional_data=iteration_answer.additional_data,
|
||||
queries=iteration_answer.queries,
|
||||
)
|
||||
db_session.add(research_agent_iteration_sub_step)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def closer(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> FinalUpdate | OrchestrationUpdate:
|
||||
"""
|
||||
LangGraph node to close the DR process and finalize the answer.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
# TODO: generate final answer using all the previous steps
|
||||
# (right now, answers from each step are concatenated onto each other)
|
||||
# Also, add missing fields once usage in UI is clear.
|
||||
|
||||
current_step_nr = state.current_step_nr
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
base_question = state.original_question
|
||||
if not base_question:
|
||||
raise ValueError("Question is required for closer")
|
||||
|
||||
research_type = graph_config.behavior.research_type
|
||||
|
||||
assistant_system_prompt: str = state.assistant_system_prompt or ""
|
||||
assistant_task_prompt = state.assistant_task_prompt
|
||||
|
||||
uploaded_context = state.uploaded_test_context or ""
|
||||
|
||||
clarification = state.clarification
|
||||
prompt_question = get_prompt_question(base_question, clarification)
|
||||
|
||||
chat_history_string = (
|
||||
get_chat_history_string(
|
||||
graph_config.inputs.prompt_builder.message_history,
|
||||
MAX_CHAT_HISTORY_MESSAGES,
|
||||
)
|
||||
or "(No chat history yet available)"
|
||||
)
|
||||
|
||||
aggregated_context_w_docs = aggregate_context(
|
||||
state.iteration_responses, include_documents=True
|
||||
)
|
||||
|
||||
aggregated_context_wo_docs = aggregate_context(
|
||||
state.iteration_responses, include_documents=False
|
||||
)
|
||||
|
||||
iteration_responses_w_docs_string = aggregated_context_w_docs.context
|
||||
iteration_responses_wo_docs_string = aggregated_context_wo_docs.context
|
||||
all_cited_documents = aggregated_context_w_docs.cited_documents
|
||||
|
||||
num_closer_suggestions = state.num_closer_suggestions
|
||||
|
||||
if (
|
||||
num_closer_suggestions < MAX_NUM_CLOSER_SUGGESTIONS
|
||||
and research_type == ResearchType.DEEP
|
||||
):
|
||||
test_info_complete_prompt = TEST_INFO_COMPLETE_PROMPT.build(
|
||||
base_question=prompt_question,
|
||||
questions_answers_claims=iteration_responses_wo_docs_string,
|
||||
chat_history_string=chat_history_string,
|
||||
high_level_plan=(
|
||||
state.plan_of_record.plan
|
||||
if state.plan_of_record
|
||||
else "No plan available"
|
||||
),
|
||||
)
|
||||
|
||||
test_info_complete_json = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt,
|
||||
test_info_complete_prompt + (assistant_task_prompt or ""),
|
||||
),
|
||||
schema=TestInfoCompleteResponse,
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
# max_tokens=1000,
|
||||
)
|
||||
|
||||
if test_info_complete_json.complete:
|
||||
pass
|
||||
|
||||
else:
|
||||
return OrchestrationUpdate(
|
||||
tools_used=[DRPath.ORCHESTRATOR.value],
|
||||
query_list=[],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="closer",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
gaps=test_info_complete_json.gaps,
|
||||
num_closer_suggestions=num_closer_suggestions + 1,
|
||||
)
|
||||
|
||||
retrieved_search_docs = convert_inference_sections_to_search_docs(
|
||||
all_cited_documents
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
MessageStart(
|
||||
content="",
|
||||
final_documents=retrieved_search_docs,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
if research_type in [ResearchType.THOUGHTFUL, ResearchType.FAST]:
|
||||
final_answer_base_prompt = FINAL_ANSWER_PROMPT_WITHOUT_SUB_ANSWERS
|
||||
elif research_type == ResearchType.DEEP:
|
||||
final_answer_base_prompt = FINAL_ANSWER_PROMPT_W_SUB_ANSWERS
|
||||
else:
|
||||
raise ValueError(f"Invalid research type: {research_type}")
|
||||
|
||||
estimated_final_answer_prompt_tokens = check_number_of_tokens(
|
||||
final_answer_base_prompt.build(
|
||||
base_question=prompt_question,
|
||||
iteration_responses_string=iteration_responses_w_docs_string,
|
||||
chat_history_string=chat_history_string,
|
||||
uploaded_context=uploaded_context,
|
||||
)
|
||||
)
|
||||
|
||||
# for DR, rely only on sub-answers and claims to save tokens if context is too long
|
||||
# TODO: consider compression step for Thoughtful mode if context is too long.
|
||||
# Should generally not be the case though.
|
||||
|
||||
max_allowed_input_tokens = graph_config.tooling.primary_llm.config.max_input_tokens
|
||||
|
||||
if (
|
||||
estimated_final_answer_prompt_tokens > 0.8 * max_allowed_input_tokens
|
||||
and research_type == ResearchType.DEEP
|
||||
):
|
||||
iteration_responses_string = iteration_responses_wo_docs_string
|
||||
else:
|
||||
iteration_responses_string = iteration_responses_w_docs_string
|
||||
|
||||
final_answer_prompt = final_answer_base_prompt.build(
|
||||
base_question=prompt_question,
|
||||
iteration_responses_string=iteration_responses_string,
|
||||
chat_history_string=chat_history_string,
|
||||
uploaded_context=uploaded_context,
|
||||
)
|
||||
|
||||
if graph_config.inputs.project_instructions:
|
||||
assistant_system_prompt = (
|
||||
assistant_system_prompt
|
||||
+ PROJECT_INSTRUCTIONS_SEPARATOR
|
||||
+ (graph_config.inputs.project_instructions or "")
|
||||
)
|
||||
|
||||
all_context_llmdocs = [
|
||||
llm_doc_from_inference_section(inference_section)
|
||||
for inference_section in all_cited_documents
|
||||
]
|
||||
|
||||
try:
|
||||
streamed_output, _, citation_infos = run_with_timeout(
|
||||
int(3 * TF_DR_TIMEOUT_LONG),
|
||||
lambda: stream_llm_answer(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt,
|
||||
final_answer_prompt + (assistant_task_prompt or ""),
|
||||
),
|
||||
event_name="basic_response",
|
||||
writer=writer,
|
||||
agent_answer_level=0,
|
||||
agent_answer_question_num=0,
|
||||
agent_answer_type="agent_level_answer",
|
||||
timeout_override=int(2 * TF_DR_TIMEOUT_LONG),
|
||||
answer_piece=StreamingType.MESSAGE_DELTA.value,
|
||||
ind=current_step_nr,
|
||||
context_docs=all_context_llmdocs,
|
||||
replace_citations=True,
|
||||
# max_tokens=None,
|
||||
),
|
||||
)
|
||||
|
||||
final_answer = "".join(streamed_output)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error in consolidate_research: {e}")
|
||||
|
||||
write_custom_event(current_step_nr, SectionEnd(), writer)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
write_custom_event(current_step_nr, CitationStart(), writer)
|
||||
write_custom_event(current_step_nr, CitationDelta(citations=citation_infos), writer)
|
||||
write_custom_event(current_step_nr, SectionEnd(), writer)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
# Log the research agent steps
|
||||
# save_iteration(
|
||||
# state,
|
||||
# graph_config,
|
||||
# aggregated_context,
|
||||
# final_answer,
|
||||
# all_cited_documents,
|
||||
# is_internet_marker_dict,
|
||||
# )
|
||||
|
||||
return FinalUpdate(
|
||||
final_answer=final_answer,
|
||||
all_cited_documents=all_cited_documents,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="closer",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,248 +0,0 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchAnswerPurpose
|
||||
from onyx.agents.agent_search.dr.models import AggregatedDRContext
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.states import MainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.models import (
|
||||
GeneratedImageFullResult,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.utils import aggregate_context
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.agents.agent_search.dr.utils import parse_plan_to_dict
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.chat import create_search_doc_from_inference_section
|
||||
from onyx.db.chat import update_db_session_with_messages
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ResearchAgentIteration
|
||||
from onyx.db.models import ResearchAgentIterationSubStep
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_citation_numbers(text: str) -> list[int]:
|
||||
"""
|
||||
Extract all citation numbers from text in the format [[<number>]] or [[<number_1>, <number_2>, ...]].
|
||||
Returns a list of all unique citation numbers found.
|
||||
"""
|
||||
# Pattern to match [[number]] or [[number1, number2, ...]]
|
||||
pattern = r"\[\[(\d+(?:,\s*\d+)*)\]\]"
|
||||
matches = re.findall(pattern, text)
|
||||
|
||||
cited_numbers = []
|
||||
for match in matches:
|
||||
# Split by comma and extract all numbers
|
||||
numbers = [int(num.strip()) for num in match.split(",")]
|
||||
cited_numbers.extend(numbers)
|
||||
|
||||
return list(set(cited_numbers)) # Return unique numbers
|
||||
|
||||
|
||||
def replace_citation_with_link(match: re.Match[str], docs: list[DbSearchDoc]) -> str:
|
||||
citation_content = match.group(1) # e.g., "3" or "3, 5, 7"
|
||||
numbers = [int(num.strip()) for num in citation_content.split(",")]
|
||||
|
||||
# For multiple citations like [[3, 5, 7]], create separate linked citations
|
||||
linked_citations = []
|
||||
for num in numbers:
|
||||
if num - 1 < len(docs): # Check bounds
|
||||
link = docs[num - 1].link or ""
|
||||
linked_citations.append(f"[[{num}]]({link})")
|
||||
else:
|
||||
linked_citations.append(f"[[{num}]]") # No link if out of bounds
|
||||
|
||||
return "".join(linked_citations)
|
||||
|
||||
|
||||
def _insert_chat_message_search_doc_pair(
|
||||
message_id: int, search_doc_ids: list[int], db_session: Session
|
||||
) -> None:
|
||||
"""
|
||||
Insert a pair of message_id and search_doc_id into the chat_message__search_doc table.
|
||||
|
||||
Args:
|
||||
message_id: The ID of the chat message
|
||||
search_doc_id: The ID of the search document
|
||||
db_session: The database session
|
||||
"""
|
||||
for search_doc_id in search_doc_ids:
|
||||
chat_message_search_doc = ChatMessage__SearchDoc(
|
||||
chat_message_id=message_id, search_doc_id=search_doc_id
|
||||
)
|
||||
db_session.add(chat_message_search_doc)
|
||||
|
||||
|
||||
def save_iteration(
|
||||
state: MainState,
|
||||
graph_config: GraphConfig,
|
||||
aggregated_context: AggregatedDRContext,
|
||||
final_answer: str,
|
||||
all_cited_documents: list[InferenceSection],
|
||||
is_internet_marker_dict: dict[str, bool],
|
||||
num_tokens: int,
|
||||
) -> None:
|
||||
db_session = graph_config.persistence.db_session
|
||||
message_id = graph_config.persistence.message_id
|
||||
research_type = graph_config.behavior.research_type
|
||||
db_session = graph_config.persistence.db_session
|
||||
|
||||
# first, insert the search_docs
|
||||
search_docs = [
|
||||
create_search_doc_from_inference_section(
|
||||
inference_section=inference_section,
|
||||
is_internet=is_internet_marker_dict.get(
|
||||
inference_section.center_chunk.document_id, False
|
||||
), # TODO: revisit
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
for inference_section in all_cited_documents
|
||||
]
|
||||
|
||||
# then, map_search_docs to message
|
||||
_insert_chat_message_search_doc_pair(
|
||||
message_id, [search_doc.id for search_doc in search_docs], db_session
|
||||
)
|
||||
|
||||
# lastly, insert the citations
|
||||
citation_dict: dict[int, int] = {}
|
||||
cited_doc_nrs = _extract_citation_numbers(final_answer)
|
||||
if search_docs:
|
||||
for cited_doc_nr in cited_doc_nrs:
|
||||
citation_dict[cited_doc_nr] = search_docs[cited_doc_nr - 1].id
|
||||
|
||||
# TODO: generate plan as dict in the first place
|
||||
plan_of_record = state.plan_of_record.plan if state.plan_of_record else ""
|
||||
plan_of_record_dict = parse_plan_to_dict(plan_of_record)
|
||||
|
||||
# Update the chat message and its parent message in database
|
||||
update_db_session_with_messages(
|
||||
db_session=db_session,
|
||||
chat_message_id=message_id,
|
||||
chat_session_id=graph_config.persistence.chat_session_id,
|
||||
is_agentic=graph_config.behavior.use_agentic_search,
|
||||
message=final_answer,
|
||||
citations=citation_dict,
|
||||
research_type=research_type,
|
||||
research_plan=plan_of_record_dict,
|
||||
final_documents=search_docs,
|
||||
update_parent_message=True,
|
||||
research_answer_purpose=ResearchAnswerPurpose.ANSWER,
|
||||
token_count=num_tokens,
|
||||
)
|
||||
|
||||
for iteration_preparation in state.iteration_instructions:
|
||||
research_agent_iteration_step = ResearchAgentIteration(
|
||||
primary_question_id=message_id,
|
||||
reasoning=iteration_preparation.reasoning,
|
||||
purpose=iteration_preparation.purpose,
|
||||
iteration_nr=iteration_preparation.iteration_nr,
|
||||
)
|
||||
db_session.add(research_agent_iteration_step)
|
||||
|
||||
for iteration_answer in aggregated_context.global_iteration_responses:
|
||||
|
||||
retrieved_search_docs = convert_inference_sections_to_search_docs(
|
||||
list(iteration_answer.cited_documents.values())
|
||||
)
|
||||
|
||||
# Convert SavedSearchDoc objects to JSON-serializable format
|
||||
serialized_search_docs = [doc.model_dump() for doc in retrieved_search_docs]
|
||||
|
||||
research_agent_iteration_sub_step = ResearchAgentIterationSubStep(
|
||||
primary_question_id=message_id,
|
||||
iteration_nr=iteration_answer.iteration_nr,
|
||||
iteration_sub_step_nr=iteration_answer.parallelization_nr,
|
||||
sub_step_instructions=iteration_answer.question,
|
||||
sub_step_tool_id=iteration_answer.tool_id,
|
||||
sub_answer=iteration_answer.answer,
|
||||
reasoning=iteration_answer.reasoning,
|
||||
claims=iteration_answer.claims,
|
||||
cited_doc_results=serialized_search_docs,
|
||||
generated_images=(
|
||||
GeneratedImageFullResult(images=iteration_answer.generated_images)
|
||||
if iteration_answer.generated_images
|
||||
else None
|
||||
),
|
||||
additional_data=iteration_answer.additional_data,
|
||||
queries=iteration_answer.queries,
|
||||
)
|
||||
db_session.add(research_agent_iteration_sub_step)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def logging(
|
||||
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> LoggerUpdate:
|
||||
"""
|
||||
LangGraph node to close the DR process and finalize the answer.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
# TODO: generate final answer using all the previous steps
|
||||
# (right now, answers from each step are concatenated onto each other)
|
||||
# Also, add missing fields once usage in UI is clear.
|
||||
|
||||
current_step_nr = state.current_step_nr
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
base_question = state.original_question
|
||||
if not base_question:
|
||||
raise ValueError("Question is required for closer")
|
||||
|
||||
aggregated_context = aggregate_context(
|
||||
state.iteration_responses, include_documents=True
|
||||
)
|
||||
|
||||
all_cited_documents = aggregated_context.cited_documents
|
||||
|
||||
is_internet_marker_dict = aggregated_context.is_internet_marker_dict
|
||||
|
||||
final_answer = state.final_answer or ""
|
||||
llm_provider = graph_config.tooling.primary_llm.config.model_provider
|
||||
llm_model_name = graph_config.tooling.primary_llm.config.model_name
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm_model_name,
|
||||
provider_type=llm_provider,
|
||||
)
|
||||
num_tokens = len(llm_tokenizer.encode(final_answer or ""))
|
||||
|
||||
write_custom_event(current_step_nr, OverallStop(), writer)
|
||||
|
||||
# Log the research agent steps
|
||||
save_iteration(
|
||||
state,
|
||||
graph_config,
|
||||
aggregated_context,
|
||||
final_answer,
|
||||
all_cited_documents,
|
||||
is_internet_marker_dict,
|
||||
num_tokens,
|
||||
)
|
||||
|
||||
return LoggerUpdate(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="main",
|
||||
node_name="logger",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,132 +0,0 @@
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langgraph.types import StreamWriter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.chat_utils import saved_search_docs_from_llm_docs
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
PassThroughAnswerResponseHandler,
|
||||
)
|
||||
from onyx.chat.stream_processing.utils import map_document_id_order
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.server.query_and_chat.streaming_models import CitationDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CitationStart
|
||||
from onyx.server.query_and_chat.streaming_models import MessageDelta
|
||||
from onyx.server.query_and_chat.streaming_models import MessageStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicSearchProcessedStreamResults(BaseModel):
|
||||
ai_message_chunk: AIMessageChunk = AIMessageChunk(content="")
|
||||
full_answer: str | None = None
|
||||
cited_references: list[InferenceSection] = []
|
||||
retrieved_documents: list[LlmDoc] = []
|
||||
|
||||
|
||||
def process_llm_stream(
|
||||
messages: Iterator[BaseMessage],
|
||||
should_stream_answer: bool,
|
||||
writer: StreamWriter,
|
||||
ind: int,
|
||||
search_results: list[LlmDoc] | None = None,
|
||||
generate_final_answer: bool = False,
|
||||
chat_message_id: str | None = None,
|
||||
) -> BasicSearchProcessedStreamResults:
|
||||
tool_call_chunk = AIMessageChunk(content="")
|
||||
|
||||
if search_results:
|
||||
answer_handler: AnswerResponseHandler = CitationResponseHandler(
|
||||
context_docs=search_results,
|
||||
doc_id_to_rank_map=map_document_id_order(search_results),
|
||||
)
|
||||
else:
|
||||
answer_handler = PassThroughAnswerResponseHandler()
|
||||
|
||||
full_answer = ""
|
||||
start_final_answer_streaming_set = False
|
||||
# Accumulate citation infos if handler emits them
|
||||
collected_citation_infos: list[CitationInfo] = []
|
||||
|
||||
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
|
||||
# the stream will contain AIMessageChunks with tool call information.
|
||||
for message in messages:
|
||||
|
||||
answer_piece = message.content
|
||||
if not isinstance(answer_piece, str):
|
||||
# this is only used for logging, so fine to
|
||||
# just add the string representation
|
||||
answer_piece = str(answer_piece)
|
||||
full_answer += answer_piece
|
||||
|
||||
if isinstance(message, AIMessageChunk) and (
|
||||
message.tool_call_chunks or message.tool_calls
|
||||
):
|
||||
tool_call_chunk += message # type: ignore
|
||||
elif should_stream_answer:
|
||||
for response_part in answer_handler.handle_response_part(message):
|
||||
|
||||
# only stream out answer parts
|
||||
if (
|
||||
isinstance(response_part, (OnyxAnswerPiece, AgentAnswerPiece))
|
||||
and generate_final_answer
|
||||
and response_part.answer_piece
|
||||
):
|
||||
if chat_message_id is None:
|
||||
raise ValueError(
|
||||
"chat_message_id is required when generating final answer"
|
||||
)
|
||||
|
||||
if not start_final_answer_streaming_set:
|
||||
# Convert LlmDocs to SavedSearchDocs
|
||||
saved_search_docs = saved_search_docs_from_llm_docs(
|
||||
search_results
|
||||
)
|
||||
write_custom_event(
|
||||
ind,
|
||||
MessageStart(content="", final_documents=saved_search_docs),
|
||||
writer,
|
||||
)
|
||||
start_final_answer_streaming_set = True
|
||||
|
||||
write_custom_event(
|
||||
ind,
|
||||
MessageDelta(content=response_part.answer_piece),
|
||||
writer,
|
||||
)
|
||||
# collect citation info objects
|
||||
elif isinstance(response_part, CitationInfo):
|
||||
collected_citation_infos.append(response_part)
|
||||
|
||||
if generate_final_answer and start_final_answer_streaming_set:
|
||||
# start_final_answer_streaming_set is only set if the answer is verbal and not a tool call
|
||||
write_custom_event(
|
||||
ind,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
# Emit citations section if any were collected
|
||||
if collected_citation_infos:
|
||||
write_custom_event(ind, CitationStart(), writer)
|
||||
write_custom_event(
|
||||
ind, CitationDelta(citations=collected_citation_infos), writer
|
||||
)
|
||||
write_custom_event(ind, SectionEnd(), writer)
|
||||
|
||||
logger.debug(f"Full answer: {full_answer}")
|
||||
return BasicSearchProcessedStreamResults(
|
||||
ai_message_chunk=cast(AIMessageChunk, tool_call_chunk), full_answer=full_answer
|
||||
)
|
||||
@@ -1,82 +0,0 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import Any
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.core_state import CoreState
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import IterationInstructions
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationClarificationInfo
|
||||
from onyx.agents.agent_search.dr.models import OrchestrationPlan
|
||||
from onyx.agents.agent_search.dr.models import OrchestratorTool
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.connector import DocumentSource
|
||||
|
||||
### States ###
|
||||
|
||||
|
||||
class LoggerUpdate(BaseModel):
|
||||
log_messages: Annotated[list[str], add] = []
|
||||
|
||||
|
||||
class OrchestrationUpdate(LoggerUpdate):
|
||||
tools_used: Annotated[list[str], add] = []
|
||||
query_list: list[str] = []
|
||||
iteration_nr: int = 0
|
||||
current_step_nr: int = 1
|
||||
plan_of_record: OrchestrationPlan | None = None # None for Thoughtful
|
||||
remaining_time_budget: float = 2.0 # set by default to about 2 searches
|
||||
num_closer_suggestions: int = 0 # how many times the closer was suggested
|
||||
gaps: list[str] = (
|
||||
[]
|
||||
) # gaps that may be identified by the closer before being able to answer the question.
|
||||
iteration_instructions: Annotated[list[IterationInstructions], add] = []
|
||||
|
||||
|
||||
class OrchestrationSetup(OrchestrationUpdate):
|
||||
original_question: str | None = None
|
||||
chat_history_string: str | None = None
|
||||
clarification: OrchestrationClarificationInfo | None = None
|
||||
available_tools: dict[str, OrchestratorTool] | None = None
|
||||
num_closer_suggestions: int = 0 # how many times the closer was suggested
|
||||
|
||||
active_source_types: list[DocumentSource] | None = None
|
||||
active_source_types_descriptions: str | None = None
|
||||
assistant_system_prompt: str | None = None
|
||||
assistant_task_prompt: str | None = None
|
||||
uploaded_test_context: str | None = None
|
||||
uploaded_image_context: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class AnswerUpdate(LoggerUpdate):
|
||||
iteration_responses: Annotated[list[IterationAnswer], add] = []
|
||||
|
||||
|
||||
class FinalUpdate(LoggerUpdate):
|
||||
final_answer: str | None = None
|
||||
all_cited_documents: list[InferenceSection] = []
|
||||
|
||||
|
||||
## Graph Input State
|
||||
class MainInput(CoreState):
|
||||
pass
|
||||
|
||||
|
||||
## Graph State
|
||||
class MainState(
|
||||
# This includes the core state
|
||||
MainInput,
|
||||
OrchestrationSetup,
|
||||
AnswerUpdate,
|
||||
FinalUpdate,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
## Graph Output State
|
||||
class MainOutput(TypedDict):
|
||||
log_messages: list[str]
|
||||
final_answer: str | None
|
||||
all_cited_documents: list[InferenceSection]
|
||||
@@ -1,47 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def basic_search_branch(
|
||||
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> LoggerUpdate:
|
||||
"""
|
||||
LangGraph node to perform a standard search as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
current_step_nr = state.current_step_nr
|
||||
|
||||
logger.debug(f"Search start for Basic Search {iteration_nr} at {datetime.now()}")
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SearchToolStart(
|
||||
is_internet_search=False,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
return LoggerUpdate(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="basic_search",
|
||||
node_name="branching",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,286 +0,0 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.enums import ResearchType
|
||||
from onyx.agents.agent_search.dr.models import BaseSearchProcessingResponse
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.models import SearchAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.dr.utils import convert_inference_sections_to_search_docs
|
||||
from onyx.agents.agent_search.dr.utils import extract_document_citations
|
||||
from onyx.agents.agent_search.kb_search.graph_utils import build_document_context
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.llm import invoke_llm_json
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.agents.agent_search.utils import create_question_prompt
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.connector import DocumentSource
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.prompts.dr_prompts import BASE_SEARCH_PROCESSING_PROMPT
|
||||
from onyx.prompts.dr_prompts import INTERNAL_SEARCH_PROMPTS
|
||||
from onyx.secondary_llm_flows.source_filter import strings_to_document_sources
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDelta
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def basic_search(
|
||||
state: BranchInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> BranchUpdate:
|
||||
"""
|
||||
LangGraph node to perform a standard search as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
parallelization_nr = state.parallelization_nr
|
||||
current_step_nr = state.current_step_nr
|
||||
assistant_system_prompt = state.assistant_system_prompt
|
||||
assistant_task_prompt = state.assistant_task_prompt
|
||||
|
||||
branch_query = state.branch_question
|
||||
if not branch_query:
|
||||
raise ValueError("branch_query is not set")
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
base_question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
research_type = graph_config.behavior.research_type
|
||||
|
||||
if not state.available_tools:
|
||||
raise ValueError("available_tools is not set")
|
||||
|
||||
elif len(state.tools_used) == 0:
|
||||
raise ValueError("tools_used is empty")
|
||||
|
||||
search_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
search_tool = cast(SearchTool, search_tool_info.tool_object)
|
||||
force_use_tool = graph_config.tooling.force_use_tool
|
||||
|
||||
# sanity check
|
||||
if search_tool != graph_config.tooling.search_tool:
|
||||
raise ValueError("search_tool does not match the configured search tool")
|
||||
|
||||
# rewrite query and identify source types
|
||||
active_source_types_str = ", ".join(
|
||||
[source.value for source in state.active_source_types or []]
|
||||
)
|
||||
|
||||
base_search_processing_prompt = BASE_SEARCH_PROCESSING_PROMPT.build(
|
||||
active_source_types_str=active_source_types_str,
|
||||
branch_query=branch_query,
|
||||
current_time=datetime.now().strftime("%Y-%m-%d %H:%M"),
|
||||
)
|
||||
|
||||
try:
|
||||
search_processing = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt, base_search_processing_prompt
|
||||
),
|
||||
schema=BaseSearchProcessingResponse,
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
# max_tokens=100,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not process query: {e}")
|
||||
raise e
|
||||
|
||||
rewritten_query = search_processing.rewritten_query
|
||||
|
||||
# give back the query so we can render it in the UI
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SearchToolDelta(
|
||||
queries=[rewritten_query],
|
||||
documents=[],
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
implied_start_date = search_processing.time_filter
|
||||
|
||||
# Validate time_filter format if it exists
|
||||
implied_time_filter = None
|
||||
if implied_start_date:
|
||||
|
||||
# Check if time_filter is in YYYY-MM-DD format
|
||||
date_pattern = r"^\d{4}-\d{2}-\d{2}$"
|
||||
if re.match(date_pattern, implied_start_date):
|
||||
implied_time_filter = datetime.strptime(implied_start_date, "%Y-%m-%d")
|
||||
|
||||
specified_source_types: list[DocumentSource] | None = (
|
||||
strings_to_document_sources(search_processing.specified_source_types)
|
||||
if search_processing.specified_source_types
|
||||
else None
|
||||
)
|
||||
|
||||
if specified_source_types is not None and len(specified_source_types) == 0:
|
||||
specified_source_types = None
|
||||
|
||||
logger.debug(
|
||||
f"Search start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
|
||||
retrieved_docs: list[InferenceSection] = []
|
||||
callback_container: list[list[InferenceSection]] = []
|
||||
|
||||
user_file_ids: list[UUID] | None = None
|
||||
project_id: int | None = None
|
||||
if force_use_tool.override_kwargs and isinstance(
|
||||
force_use_tool.override_kwargs, SearchToolOverrideKwargs
|
||||
):
|
||||
override_kwargs = force_use_tool.override_kwargs
|
||||
user_file_ids = override_kwargs.user_file_ids
|
||||
project_id = override_kwargs.project_id
|
||||
|
||||
# new db session to avoid concurrency issues
|
||||
with get_session_with_current_tenant() as search_db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=rewritten_query,
|
||||
document_sources=specified_source_types,
|
||||
time_filter=implied_time_filter,
|
||||
override_kwargs=SearchToolOverrideKwargs(
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=search_db_session,
|
||||
retrieved_sections_callback=callback_container.append,
|
||||
skip_query_analysis=True,
|
||||
original_query=rewritten_query,
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=project_id,
|
||||
),
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
response = cast(SearchResponseSummary, tool_response.response)
|
||||
retrieved_docs = response.top_sections
|
||||
|
||||
break
|
||||
|
||||
# render the retrieved docs in the UI
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SearchToolDelta(
|
||||
queries=[],
|
||||
documents=convert_inference_sections_to_search_docs(
|
||||
retrieved_docs, is_internet=False
|
||||
),
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
document_texts_list = []
|
||||
|
||||
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15]):
|
||||
if not isinstance(retrieved_doc, (InferenceSection, LlmDoc)):
|
||||
raise ValueError(f"Unexpected document type: {type(retrieved_doc)}")
|
||||
chunk_text = build_document_context(retrieved_doc, doc_num + 1)
|
||||
document_texts_list.append(chunk_text)
|
||||
|
||||
document_texts = "\n\n".join(document_texts_list)
|
||||
|
||||
logger.debug(
|
||||
f"Search end/LLM start for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
|
||||
# Built prompt
|
||||
|
||||
if research_type == ResearchType.DEEP:
|
||||
search_prompt = INTERNAL_SEARCH_PROMPTS[research_type].build(
|
||||
search_query=branch_query,
|
||||
base_question=base_question,
|
||||
document_text=document_texts,
|
||||
)
|
||||
|
||||
# Run LLM
|
||||
|
||||
# search_answer_json = None
|
||||
search_answer_json = invoke_llm_json(
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
prompt=create_question_prompt(
|
||||
assistant_system_prompt, search_prompt + (assistant_task_prompt or "")
|
||||
),
|
||||
schema=SearchAnswer,
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
# max_tokens=1500,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"LLM/all done for Standard Search {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
|
||||
# get cited documents
|
||||
answer_string = search_answer_json.answer
|
||||
claims = search_answer_json.claims or []
|
||||
reasoning = search_answer_json.reasoning
|
||||
# answer_string = ""
|
||||
# claims = []
|
||||
|
||||
(
|
||||
citation_numbers,
|
||||
answer_string,
|
||||
claims,
|
||||
) = extract_document_citations(answer_string, claims)
|
||||
|
||||
if citation_numbers and (
|
||||
(max(citation_numbers) > len(retrieved_docs)) or min(citation_numbers) < 1
|
||||
):
|
||||
raise ValueError("Citation numbers are out of range for retrieved docs.")
|
||||
|
||||
cited_documents = {
|
||||
citation_number: retrieved_docs[citation_number - 1]
|
||||
for citation_number in citation_numbers
|
||||
}
|
||||
|
||||
else:
|
||||
answer_string = ""
|
||||
claims = []
|
||||
cited_documents = {
|
||||
doc_num + 1: retrieved_doc
|
||||
for doc_num, retrieved_doc in enumerate(retrieved_docs[:15])
|
||||
}
|
||||
reasoning = ""
|
||||
|
||||
return BranchUpdate(
|
||||
branch_iteration_responses=[
|
||||
IterationAnswer(
|
||||
tool=search_tool_info.llm_path,
|
||||
tool_id=search_tool_info.tool_id,
|
||||
iteration_nr=iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
question=branch_query,
|
||||
answer=answer_string,
|
||||
claims=claims,
|
||||
cited_documents=cited_documents,
|
||||
reasoning=reasoning,
|
||||
additional_data=None,
|
||||
)
|
||||
],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="basic_search",
|
||||
node_name="searching",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,77 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_reducer(
|
||||
state: SubAgentMainState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> SubAgentUpdate:
|
||||
"""
|
||||
LangGraph node to perform a standard search as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
branch_updates = state.branch_iteration_responses
|
||||
current_iteration = state.iteration_nr
|
||||
current_step_nr = state.current_step_nr
|
||||
|
||||
new_updates = [
|
||||
update for update in branch_updates if update.iteration_nr == current_iteration
|
||||
]
|
||||
|
||||
[update.question for update in new_updates]
|
||||
doc_lists = [list(update.cited_documents.values()) for update in new_updates]
|
||||
|
||||
doc_list = []
|
||||
|
||||
for xs in doc_lists:
|
||||
for x in xs:
|
||||
doc_list.append(x)
|
||||
|
||||
# Convert InferenceSections to SavedSearchDocs
|
||||
search_docs = SearchDoc.from_chunks_or_sections(doc_list)
|
||||
retrieved_saved_search_docs = [
|
||||
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
|
||||
for search_doc in search_docs
|
||||
]
|
||||
|
||||
for retrieved_saved_search_doc in retrieved_saved_search_docs:
|
||||
retrieved_saved_search_doc.is_internet = False
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
return SubAgentUpdate(
|
||||
iteration_responses=new_updates,
|
||||
current_step_nr=current_step_nr,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="basic_search",
|
||||
node_name="consolidation",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,50 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_1_branch import (
|
||||
basic_search_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_2_act import (
|
||||
basic_search,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_basic_search_3_reduce import (
|
||||
is_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.basic_search.dr_image_generation_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_basic_search_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for Web Search Sub-Agent
|
||||
"""
|
||||
|
||||
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node("branch", basic_search_branch)
|
||||
|
||||
graph.add_node("act", basic_search)
|
||||
|
||||
graph.add_node("reducer", is_reducer)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="branch")
|
||||
|
||||
graph.add_conditional_edges("branch", branching_router)
|
||||
|
||||
graph.add_edge(start_key="act", end_key="reducer")
|
||||
|
||||
graph.add_edge(start_key="reducer", end_key=END)
|
||||
|
||||
return graph
|
||||
@@ -1,30 +0,0 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
|
||||
return [
|
||||
Send(
|
||||
"act",
|
||||
BranchInput(
|
||||
iteration_nr=state.iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
branch_question=query,
|
||||
current_step_nr=state.current_step_nr,
|
||||
context="",
|
||||
active_source_types=state.active_source_types,
|
||||
tools_used=state.tools_used,
|
||||
available_tools=state.available_tools,
|
||||
assistant_system_prompt=state.assistant_system_prompt,
|
||||
assistant_task_prompt=state.assistant_task_prompt,
|
||||
),
|
||||
)
|
||||
for parallelization_nr, query in enumerate(
|
||||
state.query_list[:MAX_DR_PARALLEL_SEARCH]
|
||||
)
|
||||
]
|
||||
@@ -1,36 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def custom_tool_branch(
|
||||
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> LoggerUpdate:
|
||||
"""
|
||||
LangGraph node to perform a generic tool call as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
|
||||
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
|
||||
|
||||
return LoggerUpdate(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="custom_tool",
|
||||
node_name="branching",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,169 +0,0 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_LONG
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomTool
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
||||
from onyx.tools.tool_implementations.mcp.mcp_tool import MCP_TOOL_RESPONSE_ID
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def custom_tool_act(
|
||||
state: BranchInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> BranchUpdate:
|
||||
"""
|
||||
LangGraph node to perform a generic tool call as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
parallelization_nr = state.parallelization_nr
|
||||
|
||||
if not state.available_tools:
|
||||
raise ValueError("available_tools is not set")
|
||||
|
||||
custom_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
custom_tool_name = custom_tool_info.name
|
||||
custom_tool = cast(CustomTool, custom_tool_info.tool_object)
|
||||
|
||||
branch_query = state.branch_question
|
||||
if not branch_query:
|
||||
raise ValueError("branch_query is not set")
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
base_question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
|
||||
logger.debug(
|
||||
f"Tool call start for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
|
||||
# get tool call args
|
||||
tool_args: dict | None = None
|
||||
if graph_config.tooling.using_tool_calling_llm:
|
||||
# get tool call args from tool-calling LLM
|
||||
tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
|
||||
query=branch_query,
|
||||
base_question=base_question,
|
||||
tool_description=custom_tool_info.description,
|
||||
)
|
||||
tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
|
||||
tool_use_prompt,
|
||||
tools=[custom_tool.tool_definition()],
|
||||
tool_choice="required",
|
||||
timeout_override=TF_DR_TIMEOUT_LONG,
|
||||
)
|
||||
|
||||
# make sure we got a tool call
|
||||
if (
|
||||
isinstance(tool_calling_msg, AIMessage)
|
||||
and len(tool_calling_msg.tool_calls) == 1
|
||||
):
|
||||
tool_args = tool_calling_msg.tool_calls[0]["args"]
|
||||
else:
|
||||
logger.warning("Tool-calling LLM did not emit a tool call")
|
||||
|
||||
if tool_args is None:
|
||||
# get tool call args from non-tool-calling LLM or for failed tool-calling LLM
|
||||
tool_args = custom_tool.get_args_for_non_tool_calling_llm(
|
||||
query=branch_query,
|
||||
history=[],
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
force_run=True,
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise ValueError("Failed to obtain tool arguments from LLM")
|
||||
|
||||
# run the tool
|
||||
response_summary: CustomToolCallSummary | None = None
|
||||
for tool_response in custom_tool.run(**tool_args):
|
||||
if tool_response.id in {CUSTOM_TOOL_RESPONSE_ID, MCP_TOOL_RESPONSE_ID}:
|
||||
response_summary = cast(CustomToolCallSummary, tool_response.response)
|
||||
break
|
||||
|
||||
if not response_summary:
|
||||
raise ValueError("Custom tool did not return a valid response summary")
|
||||
|
||||
# summarise tool result
|
||||
if not response_summary.response_type:
|
||||
raise ValueError("Response type is not returned.")
|
||||
|
||||
if response_summary.response_type == "json":
|
||||
tool_result_str = json.dumps(response_summary.tool_result, ensure_ascii=False)
|
||||
elif response_summary.response_type in {"image", "csv"}:
|
||||
tool_result_str = f"{response_summary.response_type} files: {response_summary.tool_result.file_ids}"
|
||||
else:
|
||||
tool_result_str = str(response_summary.tool_result)
|
||||
|
||||
tool_str = (
|
||||
f"Tool used: {custom_tool_name}\n"
|
||||
f"Description: {custom_tool_info.description}\n"
|
||||
f"Result: {tool_result_str}"
|
||||
)
|
||||
|
||||
tool_summary_prompt = CUSTOM_TOOL_USE_PROMPT.build(
|
||||
query=branch_query, base_question=base_question, tool_response=tool_str
|
||||
)
|
||||
answer_string = str(
|
||||
graph_config.tooling.primary_llm.invoke_langchain(
|
||||
tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
|
||||
).content
|
||||
).strip()
|
||||
|
||||
# get file_ids:
|
||||
file_ids = None
|
||||
if response_summary.response_type in {"image", "csv"} and hasattr(
|
||||
response_summary.tool_result, "file_ids"
|
||||
):
|
||||
file_ids = response_summary.tool_result.file_ids
|
||||
|
||||
logger.debug(
|
||||
f"Tool call end for {custom_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
|
||||
return BranchUpdate(
|
||||
branch_iteration_responses=[
|
||||
IterationAnswer(
|
||||
tool=custom_tool_name,
|
||||
tool_id=custom_tool_info.tool_id,
|
||||
iteration_nr=iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
question=branch_query,
|
||||
answer=answer_string,
|
||||
claims=[],
|
||||
cited_documents={},
|
||||
reasoning="",
|
||||
additional_data=None,
|
||||
response_type=response_summary.response_type,
|
||||
data=response_summary.tool_result,
|
||||
file_ids=file_ids,
|
||||
)
|
||||
],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="custom_tool",
|
||||
node_name="tool_calling",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,82 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def custom_tool_reducer(
|
||||
state: SubAgentMainState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> SubAgentUpdate:
|
||||
"""
|
||||
LangGraph node to perform a generic tool call as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
current_step_nr = state.current_step_nr
|
||||
|
||||
branch_updates = state.branch_iteration_responses
|
||||
current_iteration = state.iteration_nr
|
||||
|
||||
new_updates = [
|
||||
update for update in branch_updates if update.iteration_nr == current_iteration
|
||||
]
|
||||
|
||||
for new_update in new_updates:
|
||||
|
||||
if not new_update.response_type:
|
||||
raise ValueError("Response type is not returned.")
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
CustomToolStart(
|
||||
tool_name=new_update.tool,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
CustomToolDelta(
|
||||
tool_name=new_update.tool,
|
||||
response_type=new_update.response_type,
|
||||
data=new_update.data,
|
||||
file_ids=new_update.file_ids,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
return SubAgentUpdate(
|
||||
iteration_responses=new_updates,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="custom_tool",
|
||||
node_name="consolidation",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,28 +0,0 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import (
|
||||
SubAgentInput,
|
||||
)
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
|
||||
return [
|
||||
Send(
|
||||
"act",
|
||||
BranchInput(
|
||||
iteration_nr=state.iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
branch_question=query,
|
||||
context="",
|
||||
active_source_types=state.active_source_types,
|
||||
tools_used=state.tools_used,
|
||||
available_tools=state.available_tools,
|
||||
),
|
||||
)
|
||||
for parallelization_nr, query in enumerate(
|
||||
state.query_list[:1] # no parallel call for now
|
||||
)
|
||||
]
|
||||
@@ -1,50 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_1_branch import (
|
||||
custom_tool_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_2_act import (
|
||||
custom_tool_act,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_3_reduce import (
|
||||
custom_tool_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.custom_tool.dr_custom_tool_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_custom_tool_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for Generic Tool Sub-Agent
|
||||
"""
|
||||
|
||||
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node("branch", custom_tool_branch)
|
||||
|
||||
graph.add_node("act", custom_tool_act)
|
||||
|
||||
graph.add_node("reducer", custom_tool_reducer)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="branch")
|
||||
|
||||
graph.add_conditional_edges("branch", branching_router)
|
||||
|
||||
graph.add_edge(start_key="act", end_key="reducer")
|
||||
|
||||
graph.add_edge(start_key="reducer", end_key=END)
|
||||
|
||||
return graph
|
||||
@@ -1,36 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generic_internal_tool_branch(
|
||||
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> LoggerUpdate:
|
||||
"""
|
||||
LangGraph node to perform a generic tool call as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
|
||||
logger.debug(f"Search start for Generic Tool {iteration_nr} at {datetime.now()}")
|
||||
|
||||
return LoggerUpdate(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="generic_internal_tool",
|
||||
node_name="branching",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,149 +0,0 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import IterationAnswer
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import TF_DR_TIMEOUT_SHORT
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_PREP_PROMPT
|
||||
from onyx.prompts.dr_prompts import CUSTOM_TOOL_USE_PROMPT
|
||||
from onyx.prompts.dr_prompts import OKTA_TOOL_USE_SPECIAL_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generic_internal_tool_act(
|
||||
state: BranchInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> BranchUpdate:
|
||||
"""
|
||||
LangGraph node to perform a generic tool call as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
parallelization_nr = state.parallelization_nr
|
||||
|
||||
if not state.available_tools:
|
||||
raise ValueError("available_tools is not set")
|
||||
|
||||
generic_internal_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
generic_internal_tool_name = generic_internal_tool_info.llm_path
|
||||
generic_internal_tool = generic_internal_tool_info.tool_object
|
||||
|
||||
if generic_internal_tool is None:
|
||||
raise ValueError("generic_internal_tool is not set")
|
||||
|
||||
branch_query = state.branch_question
|
||||
if not branch_query:
|
||||
raise ValueError("branch_query is not set")
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
base_question = graph_config.inputs.prompt_builder.raw_user_query
|
||||
|
||||
logger.debug(
|
||||
f"Tool call start for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
|
||||
# get tool call args
|
||||
tool_args: dict | None = None
|
||||
if graph_config.tooling.using_tool_calling_llm:
|
||||
# get tool call args from tool-calling LLM
|
||||
tool_use_prompt = CUSTOM_TOOL_PREP_PROMPT.build(
|
||||
query=branch_query,
|
||||
base_question=base_question,
|
||||
tool_description=generic_internal_tool_info.description,
|
||||
)
|
||||
tool_calling_msg = graph_config.tooling.primary_llm.invoke_langchain(
|
||||
tool_use_prompt,
|
||||
tools=[generic_internal_tool.tool_definition()],
|
||||
tool_choice="required",
|
||||
timeout_override=TF_DR_TIMEOUT_SHORT,
|
||||
)
|
||||
|
||||
# make sure we got a tool call
|
||||
if (
|
||||
isinstance(tool_calling_msg, AIMessage)
|
||||
and len(tool_calling_msg.tool_calls) == 1
|
||||
):
|
||||
tool_args = tool_calling_msg.tool_calls[0]["args"]
|
||||
else:
|
||||
logger.warning("Tool-calling LLM did not emit a tool call")
|
||||
|
||||
if tool_args is None:
|
||||
# get tool call args from non-tool-calling LLM or for failed tool-calling LLM
|
||||
tool_args = generic_internal_tool.get_args_for_non_tool_calling_llm(
|
||||
query=branch_query,
|
||||
history=[],
|
||||
llm=graph_config.tooling.primary_llm,
|
||||
force_run=True,
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise ValueError("Failed to obtain tool arguments from LLM")
|
||||
|
||||
# run the tool
|
||||
tool_responses = list(generic_internal_tool.run(**tool_args))
|
||||
final_data = generic_internal_tool.final_result(*tool_responses)
|
||||
tool_result_str = json.dumps(final_data, ensure_ascii=False)
|
||||
|
||||
tool_str = (
|
||||
f"Tool used: {generic_internal_tool.display_name}\n"
|
||||
f"Description: {generic_internal_tool_info.description}\n"
|
||||
f"Result: {tool_result_str}"
|
||||
)
|
||||
|
||||
if generic_internal_tool.display_name == "Okta Profile":
|
||||
tool_prompt = OKTA_TOOL_USE_SPECIAL_PROMPT
|
||||
else:
|
||||
tool_prompt = CUSTOM_TOOL_USE_PROMPT
|
||||
|
||||
tool_summary_prompt = tool_prompt.build(
|
||||
query=branch_query, base_question=base_question, tool_response=tool_str
|
||||
)
|
||||
answer_string = str(
|
||||
graph_config.tooling.primary_llm.invoke_langchain(
|
||||
tool_summary_prompt, timeout_override=TF_DR_TIMEOUT_SHORT
|
||||
).content
|
||||
).strip()
|
||||
|
||||
logger.debug(
|
||||
f"Tool call end for {generic_internal_tool_name} {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
|
||||
return BranchUpdate(
|
||||
branch_iteration_responses=[
|
||||
IterationAnswer(
|
||||
tool=generic_internal_tool.llm_name,
|
||||
tool_id=generic_internal_tool_info.tool_id,
|
||||
iteration_nr=iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
question=branch_query,
|
||||
answer=answer_string,
|
||||
claims=[],
|
||||
cited_documents={},
|
||||
reasoning="",
|
||||
additional_data=None,
|
||||
response_type="text", # TODO: convert all response types to enums
|
||||
data=answer_string,
|
||||
)
|
||||
],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="custom_tool",
|
||||
node_name="tool_calling",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,82 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import CustomToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generic_internal_tool_reducer(
|
||||
state: SubAgentMainState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> SubAgentUpdate:
|
||||
"""
|
||||
LangGraph node to perform a generic tool call as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
current_step_nr = state.current_step_nr
|
||||
|
||||
branch_updates = state.branch_iteration_responses
|
||||
current_iteration = state.iteration_nr
|
||||
|
||||
new_updates = [
|
||||
update for update in branch_updates if update.iteration_nr == current_iteration
|
||||
]
|
||||
|
||||
for new_update in new_updates:
|
||||
|
||||
if not new_update.response_type:
|
||||
raise ValueError("Response type is not returned.")
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
CustomToolStart(
|
||||
tool_name=new_update.tool,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
CustomToolDelta(
|
||||
tool_name=new_update.tool,
|
||||
response_type=new_update.response_type,
|
||||
data=new_update.data,
|
||||
file_ids=[],
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
return SubAgentUpdate(
|
||||
iteration_responses=new_updates,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="custom_tool",
|
||||
node_name="consolidation",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,28 +0,0 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import (
|
||||
SubAgentInput,
|
||||
)
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
|
||||
return [
|
||||
Send(
|
||||
"act",
|
||||
BranchInput(
|
||||
iteration_nr=state.iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
branch_question=query,
|
||||
context="",
|
||||
active_source_types=state.active_source_types,
|
||||
tools_used=state.tools_used,
|
||||
available_tools=state.available_tools,
|
||||
),
|
||||
)
|
||||
for parallelization_nr, query in enumerate(
|
||||
state.query_list[:1] # no parallel call for now
|
||||
)
|
||||
]
|
||||
@@ -1,50 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_1_branch import (
|
||||
generic_internal_tool_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_2_act import (
|
||||
generic_internal_tool_act,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_3_reduce import (
|
||||
generic_internal_tool_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.generic_internal_tool.dr_generic_internal_tool_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_generic_internal_tool_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for Generic Tool Sub-Agent
|
||||
"""
|
||||
|
||||
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node("branch", generic_internal_tool_branch)
|
||||
|
||||
graph.add_node("act", generic_internal_tool_act)
|
||||
|
||||
graph.add_node("reducer", generic_internal_tool_reducer)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="branch")
|
||||
|
||||
graph.add_conditional_edges("branch", branching_router)
|
||||
|
||||
graph.add_edge(start_key="act", end_key="reducer")
|
||||
|
||||
graph.add_edge(start_key="reducer", end_key=END)
|
||||
|
||||
return graph
|
||||
@@ -1,45 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.states import LoggerUpdate
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def image_generation_branch(
|
||||
state: SubAgentInput, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> LoggerUpdate:
|
||||
"""
|
||||
LangGraph node to perform a image generation as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
|
||||
logger.debug(f"Image generation start {iteration_nr} at {datetime.now()}")
|
||||
|
||||
# tell frontend that we are starting the image generation tool
|
||||
write_custom_event(
|
||||
state.current_step_nr,
|
||||
ImageGenerationToolStart(),
|
||||
writer,
|
||||
)
|
||||
|
||||
return LoggerUpdate(
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="image_generation",
|
||||
node_name="branching",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,189 +0,0 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.models import GeneratedImage
|
||||
from onyx.agents.agent_search.dr.models import IterationAnswer
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchUpdate
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.file_store.utils import build_frontend_file_url
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeartbeat
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
IMAGE_GENERATION_HEARTBEAT_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
IMAGE_GENERATION_RESPONSE_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import ImageShape
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def image_generation(
|
||||
state: BranchInput,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> BranchUpdate:
|
||||
"""
|
||||
LangGraph node to perform a standard search as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
iteration_nr = state.iteration_nr
|
||||
parallelization_nr = state.parallelization_nr
|
||||
state.assistant_system_prompt
|
||||
state.assistant_task_prompt
|
||||
|
||||
branch_query = state.branch_question
|
||||
if not branch_query:
|
||||
raise ValueError("branch_query is not set")
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
graph_config.inputs.prompt_builder.raw_user_query
|
||||
graph_config.behavior.research_type
|
||||
|
||||
if not state.available_tools:
|
||||
raise ValueError("available_tools is not set")
|
||||
|
||||
image_tool_info = state.available_tools[state.tools_used[-1]]
|
||||
image_tool = cast(ImageGenerationTool, image_tool_info.tool_object)
|
||||
|
||||
image_prompt = branch_query
|
||||
requested_shape: ImageShape | None = None
|
||||
|
||||
try:
|
||||
parsed_query = json.loads(branch_query)
|
||||
except json.JSONDecodeError:
|
||||
parsed_query = None
|
||||
|
||||
if isinstance(parsed_query, dict):
|
||||
prompt_from_llm = parsed_query.get("prompt")
|
||||
if isinstance(prompt_from_llm, str) and prompt_from_llm.strip():
|
||||
image_prompt = prompt_from_llm.strip()
|
||||
|
||||
raw_shape = parsed_query.get("shape")
|
||||
if isinstance(raw_shape, str):
|
||||
try:
|
||||
requested_shape = ImageShape(raw_shape)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"Received unsupported image shape '%s' from LLM. Falling back to square.",
|
||||
raw_shape,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Image generation start for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
|
||||
# Generate images using the image generation tool
|
||||
image_generation_responses: list[ImageGenerationResponse] = []
|
||||
|
||||
if requested_shape is not None:
|
||||
tool_iterator = image_tool.run(
|
||||
prompt=image_prompt,
|
||||
shape=requested_shape.value,
|
||||
)
|
||||
else:
|
||||
tool_iterator = image_tool.run(prompt=image_prompt)
|
||||
|
||||
for tool_response in tool_iterator:
|
||||
if tool_response.id == IMAGE_GENERATION_HEARTBEAT_ID:
|
||||
# Stream heartbeat to frontend
|
||||
write_custom_event(
|
||||
state.current_step_nr,
|
||||
ImageGenerationToolHeartbeat(),
|
||||
writer,
|
||||
)
|
||||
elif tool_response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
response = cast(list[ImageGenerationResponse], tool_response.response)
|
||||
image_generation_responses = response
|
||||
break
|
||||
|
||||
# save images to file store
|
||||
file_ids = save_files(
|
||||
urls=[],
|
||||
base64_files=[img.image_data for img in image_generation_responses],
|
||||
)
|
||||
|
||||
final_generated_images = [
|
||||
GeneratedImage(
|
||||
file_id=file_id,
|
||||
url=build_frontend_file_url(file_id),
|
||||
revised_prompt=img.revised_prompt,
|
||||
shape=(requested_shape or ImageShape.SQUARE).value,
|
||||
)
|
||||
for file_id, img in zip(file_ids, image_generation_responses)
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
f"Image generation complete for {iteration_nr}.{parallelization_nr} at {datetime.now()}"
|
||||
)
|
||||
|
||||
# Create answer string describing the generated images
|
||||
if final_generated_images:
|
||||
image_descriptions = []
|
||||
for i, img in enumerate(final_generated_images, 1):
|
||||
if img.shape and img.shape != ImageShape.SQUARE.value:
|
||||
image_descriptions.append(
|
||||
f"Image {i}: {img.revised_prompt} (shape: {img.shape})"
|
||||
)
|
||||
else:
|
||||
image_descriptions.append(f"Image {i}: {img.revised_prompt}")
|
||||
|
||||
answer_string = (
|
||||
f"Generated {len(final_generated_images)} image(s) based on the request: {image_prompt}\n\n"
|
||||
+ "\n".join(image_descriptions)
|
||||
)
|
||||
if requested_shape:
|
||||
reasoning = (
|
||||
"Used image generation tool to create "
|
||||
f"{len(final_generated_images)} image(s) in {requested_shape.value} orientation."
|
||||
)
|
||||
else:
|
||||
reasoning = (
|
||||
"Used image generation tool to create "
|
||||
f"{len(final_generated_images)} image(s) based on the user's request."
|
||||
)
|
||||
else:
|
||||
answer_string = f"Failed to generate images for request: {image_prompt}"
|
||||
reasoning = "Image generation tool did not return any results."
|
||||
|
||||
return BranchUpdate(
|
||||
branch_iteration_responses=[
|
||||
IterationAnswer(
|
||||
tool=image_tool_info.llm_path,
|
||||
tool_id=image_tool_info.tool_id,
|
||||
iteration_nr=iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
question=branch_query,
|
||||
answer=answer_string,
|
||||
claims=[],
|
||||
cited_documents={},
|
||||
reasoning=reasoning,
|
||||
generated_images=final_generated_images,
|
||||
)
|
||||
],
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="image_generation",
|
||||
node_name="generating",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,71 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.dr.models import GeneratedImage
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentUpdate
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_reducer(
|
||||
state: SubAgentMainState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> SubAgentUpdate:
|
||||
"""
|
||||
LangGraph node to perform a standard search as part of the DR process.
|
||||
"""
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
branch_updates = state.branch_iteration_responses
|
||||
current_iteration = state.iteration_nr
|
||||
current_step_nr = state.current_step_nr
|
||||
|
||||
new_updates = [
|
||||
update for update in branch_updates if update.iteration_nr == current_iteration
|
||||
]
|
||||
generated_images: list[GeneratedImage] = []
|
||||
for update in new_updates:
|
||||
if update.generated_images:
|
||||
generated_images.extend(update.generated_images)
|
||||
|
||||
# Write the results to the stream
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
ImageGenerationToolDelta(
|
||||
images=generated_images,
|
||||
),
|
||||
writer,
|
||||
)
|
||||
|
||||
write_custom_event(
|
||||
current_step_nr,
|
||||
SectionEnd(),
|
||||
writer,
|
||||
)
|
||||
|
||||
current_step_nr += 1
|
||||
|
||||
return SubAgentUpdate(
|
||||
iteration_responses=new_updates,
|
||||
current_step_nr=current_step_nr,
|
||||
log_messages=[
|
||||
get_langgraph_node_log_string(
|
||||
graph_component="image_generation",
|
||||
node_name="consolidation",
|
||||
node_start_time=node_start_time,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,29 +0,0 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.dr.constants import MAX_DR_PARALLEL_SEARCH
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import BranchInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
|
||||
|
||||
def branching_router(state: SubAgentInput) -> list[Send | Hashable]:
|
||||
return [
|
||||
Send(
|
||||
"act",
|
||||
BranchInput(
|
||||
iteration_nr=state.iteration_nr,
|
||||
parallelization_nr=parallelization_nr,
|
||||
branch_question=query,
|
||||
context="",
|
||||
active_source_types=state.active_source_types,
|
||||
tools_used=state.tools_used,
|
||||
available_tools=state.available_tools,
|
||||
assistant_system_prompt=state.assistant_system_prompt,
|
||||
assistant_task_prompt=state.assistant_task_prompt,
|
||||
),
|
||||
)
|
||||
for parallelization_nr, query in enumerate(
|
||||
state.query_list[:MAX_DR_PARALLEL_SEARCH]
|
||||
)
|
||||
]
|
||||
@@ -1,50 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_1_branch import (
|
||||
image_generation_branch,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_2_act import (
|
||||
image_generation,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_3_reduce import (
|
||||
is_reducer,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.image_generation.dr_image_generation_conditional_edges import (
|
||||
branching_router,
|
||||
)
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentInput
|
||||
from onyx.agents.agent_search.dr.sub_agents.states import SubAgentMainState
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def dr_image_generation_graph_builder() -> StateGraph:
|
||||
"""
|
||||
LangGraph graph builder for Image Generation Sub-Agent
|
||||
"""
|
||||
|
||||
graph = StateGraph(state_schema=SubAgentMainState, input=SubAgentInput)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node("branch", image_generation_branch)
|
||||
|
||||
graph.add_node("act", image_generation)
|
||||
|
||||
graph.add_node("reducer", is_reducer)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(start_key=START, end_key="branch")
|
||||
|
||||
graph.add_conditional_edges("branch", branching_router)
|
||||
|
||||
graph.add_edge(start_key="act", end_key="reducer")
|
||||
|
||||
graph.add_edge(start_key="reducer", end_key=END)
|
||||
|
||||
return graph
|
||||
@@ -1,13 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GeneratedImage(BaseModel):
|
||||
file_id: str
|
||||
url: str
|
||||
revised_prompt: str
|
||||
shape: str | None = None
|
||||
|
||||
|
||||
# Needed for PydanticType
|
||||
class GeneratedImageFullResult(BaseModel):
|
||||
images: list[GeneratedImage]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user